diff options
author | Devtools Arcadia <[email protected]> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <[email protected]> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /contrib/libs/grpc/src/python/grpcio_tests |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'contrib/libs/grpc/src/python/grpcio_tests')
189 files changed, 28176 insertions, 0 deletions
diff --git a/contrib/libs/grpc/src/python/grpcio_tests/.yandex_meta/licenses.list.txt b/contrib/libs/grpc/src/python/grpcio_tests/.yandex_meta/licenses.list.txt new file mode 100644 index 00000000000..e0080a7b1fc --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/.yandex_meta/licenses.list.txt @@ -0,0 +1,72 @@ +====================Apache-2.0==================== +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +====================COPYRIGHT==================== + * Copyright 2015 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2016 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2017 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2018 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2020 gRPC authors. + + +====================COPYRIGHT==================== +# Copyright 2018 The gRPC Authors. + + +====================COPYRIGHT==================== +# Copyright 2019 The gRPC Authors. + + +====================COPYRIGHT==================== +# Copyright 2019 The gRPC authors. + + +====================COPYRIGHT==================== +# Copyright 2019 gRPC authors. + + +====================COPYRIGHT==================== +# Copyright 2019 the gRPC authors. + + +====================COPYRIGHT==================== +# Copyright 2020 The gRPC Authors. + + +====================COPYRIGHT==================== +# Copyright 2020 The gRPC authors. + + +====================COPYRIGHT==================== +// Copyright 2018 The gRPC Authors + + +====================COPYRIGHT==================== +// Copyright 2019 The gRPC Authors + + +====================COPYRIGHT==================== +// Copyright 2020 The gRPC Authors diff --git a/contrib/libs/grpc/src/python/grpcio_tests/commands.py b/contrib/libs/grpc/src/python/grpcio_tests/commands.py new file mode 100644 index 00000000000..889b0bd9dc3 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/commands.py @@ -0,0 +1,344 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Provides distutils command classes for the gRPC Python setup process.""" + +from distutils import errors as _errors +import glob +import os +import os.path +import platform +import re +import shutil +import sys + +import setuptools +from setuptools.command import build_ext +from setuptools.command import build_py +from setuptools.command import easy_install +from setuptools.command import install +from setuptools.command import test + +PYTHON_STEM = os.path.dirname(os.path.abspath(__file__)) +GRPC_STEM = os.path.abspath(PYTHON_STEM + '../../../../') +GRPC_PROTO_STEM = os.path.join(GRPC_STEM, 'src', 'proto') +PROTO_STEM = os.path.join(PYTHON_STEM, 'src', 'proto') +PYTHON_PROTO_TOP_LEVEL = os.path.join(PYTHON_STEM, 'src') + + +class CommandError(object): + pass + + +class GatherProto(setuptools.Command): + + description = 'gather proto dependencies' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + # TODO(atash) ensure that we're running from the repository directory when + # this command is used + try: + shutil.rmtree(PROTO_STEM) + except Exception as error: + # We don't care if this command fails + pass + shutil.copytree(GRPC_PROTO_STEM, PROTO_STEM) + for root, _, _ in os.walk(PYTHON_PROTO_TOP_LEVEL): + path = os.path.join(root, '__init__.py') + open(path, 'a').close() + + +class BuildPy(build_py.build_py): + """Custom project build command.""" + + def run(self): + try: + self.run_command('build_package_protos') + except CommandError as error: + sys.stderr.write('warning: %s\n' % error.message) + build_py.build_py.run(self) + + +class TestLite(setuptools.Command): + """Command to run tests without fetching or building anything.""" + + description = 'run tests without fetching or building anything.' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + # distutils requires this override. + pass + + def run(self): + self._add_eggs_to_path() + + import tests + loader = tests.Loader() + loader.loadTestsFromNames(['tests']) + runner = tests.Runner(dedicated_threads=True) + result = runner.run(loader.suite) + if not result.wasSuccessful(): + sys.exit('Test failure') + + def _add_eggs_to_path(self): + """Fetch install and test requirements""" + self.distribution.fetch_build_eggs(self.distribution.install_requires) + self.distribution.fetch_build_eggs(self.distribution.tests_require) + + +class TestPy3Only(setuptools.Command): + """Command to run tests for Python 3+ features. + + This does not include asyncio tests, which are housed in a separate + directory. + """ + + description = 'run tests for py3+ features' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + self._add_eggs_to_path() + import tests + loader = tests.Loader() + loader.loadTestsFromNames(['tests_py3_only']) + runner = tests.Runner() + result = runner.run(loader.suite) + if not result.wasSuccessful(): + sys.exit('Test failure') + + def _add_eggs_to_path(self): + self.distribution.fetch_build_eggs(self.distribution.install_requires) + self.distribution.fetch_build_eggs(self.distribution.tests_require) + + +class TestAio(setuptools.Command): + """Command to run aio tests without fetching or building anything.""" + + description = 'run aio tests without fetching or building anything.' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + self._add_eggs_to_path() + + import tests + loader = tests.Loader() + loader.loadTestsFromNames(['tests_aio']) + # Even without dedicated threads, the framework will somehow spawn a + # new thread for tests to run upon. New thread doesn't have event loop + # attached by default, so initialization is needed. + runner = tests.Runner(dedicated_threads=False) + result = runner.run(loader.suite) + if not result.wasSuccessful(): + sys.exit('Test failure') + + def _add_eggs_to_path(self): + """Fetch install and test requirements""" + self.distribution.fetch_build_eggs(self.distribution.install_requires) + self.distribution.fetch_build_eggs(self.distribution.tests_require) + + +class TestGevent(setuptools.Command): + """Command to run tests w/gevent.""" + + BANNED_TESTS = ( + # Fork support is not compatible with gevent + 'fork._fork_interop_test.ForkInteropTest', + # These tests send a lot of RPCs and are really slow on gevent. They will + # eventually succeed, but need to dig into performance issues. + 'unit._cython._no_messages_server_completion_queue_per_call_test.Test.test_rpcs', + 'unit._cython._no_messages_single_server_completion_queue_test.Test.test_rpcs', + 'unit._compression_test', + # TODO(https://github.com/grpc/grpc/issues/16890) enable this test + 'unit._cython._channel_test.ChannelTest.test_multiple_channels_lonely_connectivity', + # I have no idea why this doesn't work in gevent, but it shouldn't even be + # using the c-core + 'testing._client_test.ClientTest.test_infinite_request_stream_real_time', + # TODO(https://github.com/grpc/grpc/issues/15743) enable this test + 'unit._session_cache_test.SSLSessionCacheTest.testSSLSessionCacheLRU', + # TODO(https://github.com/grpc/grpc/issues/14789) enable this test + 'unit._server_ssl_cert_config_test', + # TODO(https://github.com/grpc/grpc/issues/14901) enable this test + 'protoc_plugin._python_plugin_test.PythonPluginTest', + 'protoc_plugin._python_plugin_test.SimpleStubsPluginTest', + # Beta API is unsupported for gevent + 'protoc_plugin.beta_python_plugin_test', + 'unit.beta._beta_features_test', + # TODO(https://github.com/grpc/grpc/issues/15411) unpin gevent version + # This test will stuck while running higher version of gevent + 'unit._auth_context_test.AuthContextTest.testSessionResumption', + # TODO(https://github.com/grpc/grpc/issues/15411) enable these tests + 'unit._channel_ready_future_test.ChannelReadyFutureTest.test_immediately_connectable_channel_connectivity', + "unit._cython._channel_test.ChannelTest.test_single_channel_lonely_connectivity", + 'unit._exit_test.ExitTest.test_in_flight_unary_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_unary_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_stream_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_stream_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_unary_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_stream_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_stream_stream_call', + # TODO(https://github.com/grpc/grpc/issues/18980): Reenable. + 'unit._signal_handling_test.SignalHandlingTest', + 'unit._metadata_flags_test', + 'health_check._health_servicer_test.HealthServicerTest.test_cancelled_watch_removed_from_watch_list', + # TODO(https://github.com/grpc/grpc/issues/17330) enable these three tests + 'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels', + 'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels_and_sockets', + 'channelz._channelz_servicer_test.ChannelzServicerTest.test_streaming_rpc', + # TODO(https://github.com/grpc/grpc/issues/15411) enable this test + 'unit._cython._channel_test.ChannelTest.test_negative_deadline_connectivity', + # TODO(https://github.com/grpc/grpc/issues/15411) enable this test + 'unit._local_credentials_test.LocalCredentialsTest', + # TODO(https://github.com/grpc/grpc/issues/22020) LocalCredentials + # aren't supported with custom io managers. + 'unit._contextvars_propagation_test', + 'testing._time_test.StrictRealTimeTest', + ) + BANNED_WINDOWS_TESTS = ( + # TODO(https://github.com/grpc/grpc/pull/15411) enable this test + 'unit._dns_resolver_test.DNSResolverTest.test_connect_loopback', + # TODO(https://github.com/grpc/grpc/pull/15411) enable this test + 'unit._server_test.ServerTest.test_failed_port_binding_exception', + ) + description = 'run tests with gevent. Assumes grpc/gevent are installed' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + # distutils requires this override. + pass + + def run(self): + from gevent import monkey + monkey.patch_all() + + import tests + + import grpc.experimental.gevent + grpc.experimental.gevent.init_gevent() + + import gevent + + import tests + loader = tests.Loader() + loader.loadTestsFromNames(['tests']) + runner = tests.Runner() + if sys.platform == 'win32': + runner.skip_tests(self.BANNED_TESTS + self.BANNED_WINDOWS_TESTS) + else: + runner.skip_tests(self.BANNED_TESTS) + result = gevent.spawn(runner.run, loader.suite) + result.join() + if not result.value.wasSuccessful(): + sys.exit('Test failure') + + +class RunInterop(test.test): + + description = 'run interop test client/server' + user_options = [ + ('args=', None, 'pass-thru arguments for the client/server'), + ('client', None, 'flag indicating to run the client'), + ('server', None, 'flag indicating to run the server'), + ('use-asyncio', None, 'flag indicating to run the asyncio stack') + ] + + def initialize_options(self): + self.args = '' + self.client = False + self.server = False + self.use_asyncio = False + + def finalize_options(self): + if self.client and self.server: + raise _errors.DistutilsOptionError( + 'you may only specify one of client or server') + + def run(self): + if self.distribution.install_requires: + self.distribution.fetch_build_eggs( + self.distribution.install_requires) + if self.distribution.tests_require: + self.distribution.fetch_build_eggs(self.distribution.tests_require) + if self.client: + self.run_client() + elif self.server: + self.run_server() + + def run_server(self): + # We import here to ensure that our setuptools parent has had a chance to + # edit the Python system path. + if self.use_asyncio: + import asyncio + from tests_aio.interop import server + sys.argv[1:] = self.args.split() + asyncio.get_event_loop().run_until_complete(server.serve()) + else: + from tests.interop import server + sys.argv[1:] = self.args.split() + server.serve() + + def run_client(self): + # We import here to ensure that our setuptools parent has had a chance to + # edit the Python system path. + from tests.interop import client + sys.argv[1:] = self.args.split() + client.test_interoperability() + + +class RunFork(test.test): + + description = 'run fork test client' + user_options = [('args=', 'a', 'pass-thru arguments for the client')] + + def initialize_options(self): + self.args = '' + + def finalize_options(self): + # distutils requires this override. + pass + + def run(self): + if self.distribution.install_requires: + self.distribution.fetch_build_eggs( + self.distribution.install_requires) + if self.distribution.tests_require: + self.distribution.fetch_build_eggs(self.distribution.tests_require) + # We import here to ensure that our setuptools parent has had a chance to + # edit the Python system path. + from tests.fork import client + sys.argv[1:] = self.args.split() + client.test_fork() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/grpc_version.py b/contrib/libs/grpc/src/python/grpcio_tests/grpc_version.py new file mode 100644 index 00000000000..219b336a429 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/grpc_version.py @@ -0,0 +1,17 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_tests/grpc_version.py.template`!!! + +VERSION = '1.33.2' diff --git a/contrib/libs/grpc/src/python/grpcio_tests/setup.py b/contrib/libs/grpc/src/python/grpcio_tests/setup.py new file mode 100644 index 00000000000..87cccda425b --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/setup.py @@ -0,0 +1,112 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A setup module for the gRPC Python package.""" + +import multiprocessing +import os +import os.path +import sys + +import setuptools + +import grpc_tools.command + +PY3 = sys.version_info.major == 3 + +# Ensure we're in the proper directory whether or not we're being used by pip. +os.chdir(os.path.dirname(os.path.abspath(__file__))) + +# Break import-style to ensure we can actually find our in-repo dependencies. +import commands +import grpc_version + +LICENSE = 'Apache License 2.0' + +PACKAGE_DIRECTORIES = { + '': '.', +} + +INSTALL_REQUIRES = ( + 'coverage>=4.0', 'grpcio>={version}'.format(version=grpc_version.VERSION), + 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION), + 'grpcio-status>={version}'.format(version=grpc_version.VERSION), + 'grpcio-tools>={version}'.format(version=grpc_version.VERSION), + 'grpcio-health-checking>={version}'.format(version=grpc_version.VERSION), + 'oauth2client>=1.4.7', 'protobuf>=3.6.0', 'six>=1.10', + 'google-auth>=1.17.2', 'requests>=2.14.2') + +if not PY3: + INSTALL_REQUIRES += ('futures>=2.2.0', 'enum34>=1.0.4') + +COMMAND_CLASS = { + # Run `preprocess` *before* doing any packaging! + 'preprocess': commands.GatherProto, + 'build_package_protos': grpc_tools.command.BuildPackageProtos, + 'build_py': commands.BuildPy, + 'run_fork': commands.RunFork, + 'run_interop': commands.RunInterop, + 'test_lite': commands.TestLite, + 'test_gevent': commands.TestGevent, + 'test_aio': commands.TestAio, + 'test_py3_only': commands.TestPy3Only, +} + +PACKAGE_DATA = { + 'tests.interop': [ + 'credentials/ca.pem', + 'credentials/server1.key', + 'credentials/server1.pem', + ], + 'tests.protoc_plugin.protos.invocation_testing': ['same.proto',], + 'tests.protoc_plugin.protos.invocation_testing.split_messages': [ + 'messages.proto', + ], + 'tests.protoc_plugin.protos.invocation_testing.split_services': [ + 'services.proto', + ], + 'tests.testing.proto': [ + 'requests.proto', + 'services.proto', + ], + 'tests.unit': [ + 'credentials/ca.pem', + 'credentials/server1.key', + 'credentials/server1.pem', + ], + 'tests': ['tests.json'], +} + +TEST_SUITE = 'tests' +TEST_LOADER = 'tests:Loader' +TEST_RUNNER = 'tests:Runner' +TESTS_REQUIRE = INSTALL_REQUIRES + +PACKAGES = setuptools.find_packages('.') + +if __name__ == "__main__": + multiprocessing.freeze_support() + setuptools.setup( + name='grpcio-tests', + version=grpc_version.VERSION, + license=LICENSE, + packages=list(PACKAGES), + package_dir=PACKAGE_DIRECTORIES, + package_data=PACKAGE_DATA, + install_requires=INSTALL_REQUIRES, + cmdclass=COMMAND_CLASS, + tests_require=TESTS_REQUIRE, + test_suite=TEST_SUITE, + test_loader=TEST_LOADER, + test_runner=TEST_RUNNER, + ) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/__init__.py new file mode 100644 index 00000000000..d2466fd0228 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from tests import _loader +from tests import _runner + +Loader = _loader.Loader +Runner = _runner.Runner diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/_loader.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/_loader.py new file mode 100644 index 00000000000..80c107aa8e4 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/_loader.py @@ -0,0 +1,106 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +import importlib +import pkgutil +import re +import unittest + +import coverage + +TEST_MODULE_REGEX = r'^.*_test$' + + +class Loader(object): + """Test loader for setuptools test suite support. + + Attributes: + suite (unittest.TestSuite): All tests collected by the loader. + loader (unittest.TestLoader): Standard Python unittest loader to be ran per + module discovered. + module_matcher (re.RegexObject): A regular expression object to match + against module names and determine whether or not the discovered module + contributes to the test suite. + """ + + def __init__(self): + self.suite = unittest.TestSuite() + self.loader = unittest.TestLoader() + self.module_matcher = re.compile(TEST_MODULE_REGEX) + + def loadTestsFromNames(self, names, module=None): + """Function mirroring TestLoader::loadTestsFromNames, as expected by + setuptools.setup argument `test_loader`.""" + # ensure that we capture decorators and definitions (else our coverage + # measure unnecessarily suffers) + coverage_context = coverage.Coverage(data_suffix=True) + coverage_context.start() + imported_modules = tuple( + importlib.import_module(name) for name in names) + for imported_module in imported_modules: + self.visit_module(imported_module) + for imported_module in imported_modules: + try: + package_paths = imported_module.__path__ + except AttributeError: + continue + self.walk_packages(package_paths) + coverage_context.stop() + coverage_context.save() + return self.suite + + def walk_packages(self, package_paths): + """Walks over the packages, dispatching `visit_module` calls. + + Args: + package_paths (list): A list of paths over which to walk through modules + along. + """ + for importer, module_name, is_package in ( + pkgutil.walk_packages(package_paths)): + module = importer.find_module(module_name).load_module(module_name) + self.visit_module(module) + + def visit_module(self, module): + """Visits the module, adding discovered tests to the test suite. + + Args: + module (module): Module to match against self.module_matcher; if matched + it has its tests loaded via self.loader into self.suite. + """ + if self.module_matcher.match(module.__name__): + module_suite = self.loader.loadTestsFromModule(module) + self.suite.addTest(module_suite) + + +def iterate_suite_cases(suite): + """Generator over all unittest.TestCases in a unittest.TestSuite. + + Args: + suite (unittest.TestSuite): Suite to iterate over in the generator. + + Returns: + generator: A generator over all unittest.TestCases in `suite`. + """ + for item in suite: + if isinstance(item, unittest.TestSuite): + for child_item in iterate_suite_cases(item): + yield child_item + elif isinstance(item, unittest.TestCase): + yield item + else: + raise ValueError('unexpected suite item of type {}'.format( + type(item))) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/_result.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/_result.py new file mode 100644 index 00000000000..389d5f4f96a --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/_result.py @@ -0,0 +1,439 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +import collections +import itertools +import traceback +import unittest +from xml.etree import ElementTree + +import coverage +from six import moves + +from tests import _loader + + +class CaseResult( + collections.namedtuple('CaseResult', [ + 'id', 'name', 'kind', 'stdout', 'stderr', 'skip_reason', 'traceback' + ])): + """A serializable result of a single test case. + + Attributes: + id (object): Any serializable object used to denote the identity of this + test case. + name (str or None): A human-readable name of the test case. + kind (CaseResult.Kind): The kind of test result. + stdout (object or None): Output on stdout, or None if nothing was captured. + stderr (object or None): Output on stderr, or None if nothing was captured. + skip_reason (object or None): The reason the test was skipped. Must be + something if self.kind is CaseResult.Kind.SKIP, else None. + traceback (object or None): The traceback of the test. Must be something if + self.kind is CaseResult.Kind.{ERROR, FAILURE, EXPECTED_FAILURE}, else + None. + """ + + class Kind(object): + UNTESTED = 'untested' + RUNNING = 'running' + ERROR = 'error' + FAILURE = 'failure' + SUCCESS = 'success' + SKIP = 'skip' + EXPECTED_FAILURE = 'expected failure' + UNEXPECTED_SUCCESS = 'unexpected success' + + def __new__(cls, + id=None, + name=None, + kind=None, + stdout=None, + stderr=None, + skip_reason=None, + traceback=None): + """Helper keyword constructor for the namedtuple. + + See this class' attributes for information on the arguments.""" + assert id is not None + assert name is None or isinstance(name, str) + if kind is CaseResult.Kind.UNTESTED: + pass + elif kind is CaseResult.Kind.RUNNING: + pass + elif kind is CaseResult.Kind.ERROR: + assert traceback is not None + elif kind is CaseResult.Kind.FAILURE: + assert traceback is not None + elif kind is CaseResult.Kind.SUCCESS: + pass + elif kind is CaseResult.Kind.SKIP: + assert skip_reason is not None + elif kind is CaseResult.Kind.EXPECTED_FAILURE: + assert traceback is not None + elif kind is CaseResult.Kind.UNEXPECTED_SUCCESS: + pass + else: + assert False + return super(cls, CaseResult).__new__(cls, id, name, kind, stdout, + stderr, skip_reason, traceback) + + def updated(self, + name=None, + kind=None, + stdout=None, + stderr=None, + skip_reason=None, + traceback=None): + """Get a new validated CaseResult with the fields updated. + + See this class' attributes for information on the arguments.""" + name = self.name if name is None else name + kind = self.kind if kind is None else kind + stdout = self.stdout if stdout is None else stdout + stderr = self.stderr if stderr is None else stderr + skip_reason = self.skip_reason if skip_reason is None else skip_reason + traceback = self.traceback if traceback is None else traceback + return CaseResult(id=self.id, + name=name, + kind=kind, + stdout=stdout, + stderr=stderr, + skip_reason=skip_reason, + traceback=traceback) + + +class AugmentedResult(unittest.TestResult): + """unittest.Result that keeps track of additional information. + + Uses CaseResult objects to store test-case results, providing additional + information beyond that of the standard Python unittest library, such as + standard output. + + Attributes: + id_map (callable): A unary callable mapping unittest.TestCase objects to + unique identifiers. + cases (dict): A dictionary mapping from the identifiers returned by id_map + to CaseResult objects corresponding to those IDs. + """ + + def __init__(self, id_map): + """Initialize the object with an identifier mapping. + + Arguments: + id_map (callable): Corresponds to the attribute `id_map`.""" + super(AugmentedResult, self).__init__() + self.id_map = id_map + self.cases = None + + def startTestRun(self): + """See unittest.TestResult.startTestRun.""" + super(AugmentedResult, self).startTestRun() + self.cases = dict() + + def startTest(self, test): + """See unittest.TestResult.startTest.""" + super(AugmentedResult, self).startTest(test) + case_id = self.id_map(test) + self.cases[case_id] = CaseResult(id=case_id, + name=test.id(), + kind=CaseResult.Kind.RUNNING) + + def addError(self, test, err): + """See unittest.TestResult.addError.""" + super(AugmentedResult, self).addError(test, err) + case_id = self.id_map(test) + self.cases[case_id] = self.cases[case_id].updated( + kind=CaseResult.Kind.ERROR, traceback=err) + + def addFailure(self, test, err): + """See unittest.TestResult.addFailure.""" + super(AugmentedResult, self).addFailure(test, err) + case_id = self.id_map(test) + self.cases[case_id] = self.cases[case_id].updated( + kind=CaseResult.Kind.FAILURE, traceback=err) + + def addSuccess(self, test): + """See unittest.TestResult.addSuccess.""" + super(AugmentedResult, self).addSuccess(test) + case_id = self.id_map(test) + self.cases[case_id] = self.cases[case_id].updated( + kind=CaseResult.Kind.SUCCESS) + + def addSkip(self, test, reason): + """See unittest.TestResult.addSkip.""" + super(AugmentedResult, self).addSkip(test, reason) + case_id = self.id_map(test) + self.cases[case_id] = self.cases[case_id].updated( + kind=CaseResult.Kind.SKIP, skip_reason=reason) + + def addExpectedFailure(self, test, err): + """See unittest.TestResult.addExpectedFailure.""" + super(AugmentedResult, self).addExpectedFailure(test, err) + case_id = self.id_map(test) + self.cases[case_id] = self.cases[case_id].updated( + kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=err) + + def addUnexpectedSuccess(self, test): + """See unittest.TestResult.addUnexpectedSuccess.""" + super(AugmentedResult, self).addUnexpectedSuccess(test) + case_id = self.id_map(test) + self.cases[case_id] = self.cases[case_id].updated( + kind=CaseResult.Kind.UNEXPECTED_SUCCESS) + + def set_output(self, test, stdout, stderr): + """Set the output attributes for the CaseResult corresponding to a test. + + Args: + test (unittest.TestCase): The TestCase to set the outputs of. + stdout (str): Output from stdout to assign to self.id_map(test). + stderr (str): Output from stderr to assign to self.id_map(test). + """ + case_id = self.id_map(test) + self.cases[case_id] = self.cases[case_id].updated( + stdout=stdout.decode(), stderr=stderr.decode()) + + def augmented_results(self, filter): + """Convenience method to retrieve filtered case results. + + Args: + filter (callable): A unary predicate to filter over CaseResult objects. + """ + return (self.cases[case_id] + for case_id in self.cases + if filter(self.cases[case_id])) + + +class CoverageResult(AugmentedResult): + """Extension to AugmentedResult adding coverage.py support per test.\ + + Attributes: + coverage_context (coverage.Coverage): coverage.py management object. + """ + + def __init__(self, id_map): + """See AugmentedResult.__init__.""" + super(CoverageResult, self).__init__(id_map=id_map) + self.coverage_context = None + + def startTest(self, test): + """See unittest.TestResult.startTest. + + Additionally initializes and begins code coverage tracking.""" + super(CoverageResult, self).startTest(test) + self.coverage_context = coverage.Coverage(data_suffix=True) + self.coverage_context.start() + + def stopTest(self, test): + """See unittest.TestResult.stopTest. + + Additionally stops and deinitializes code coverage tracking.""" + super(CoverageResult, self).stopTest(test) + self.coverage_context.stop() + self.coverage_context.save() + self.coverage_context = None + + +class _Colors(object): + """Namespaced constants for terminal color magic numbers.""" + HEADER = '\033[95m' + INFO = '\033[94m' + OK = '\033[92m' + WARN = '\033[93m' + FAIL = '\033[91m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + END = '\033[0m' + + +class TerminalResult(CoverageResult): + """Extension to CoverageResult adding basic terminal reporting.""" + + def __init__(self, out, id_map): + """Initialize the result object. + + Args: + out (file-like): Output file to which terminal-colored live results will + be written. + id_map (callable): See AugmentedResult.__init__. + """ + super(TerminalResult, self).__init__(id_map=id_map) + self.out = out + + def startTestRun(self): + """See unittest.TestResult.startTestRun.""" + super(TerminalResult, self).startTestRun() + self.out.write(_Colors.HEADER + 'Testing gRPC Python...\n' + + _Colors.END) + + def stopTestRun(self): + """See unittest.TestResult.stopTestRun.""" + super(TerminalResult, self).stopTestRun() + self.out.write(summary(self)) + self.out.flush() + + def addError(self, test, err): + """See unittest.TestResult.addError.""" + super(TerminalResult, self).addError(test, err) + self.out.write(_Colors.FAIL + 'ERROR {}\n'.format(test.id()) + + _Colors.END) + self.out.flush() + + def addFailure(self, test, err): + """See unittest.TestResult.addFailure.""" + super(TerminalResult, self).addFailure(test, err) + self.out.write(_Colors.FAIL + 'FAILURE {}\n'.format(test.id()) + + _Colors.END) + self.out.flush() + + def addSuccess(self, test): + """See unittest.TestResult.addSuccess.""" + super(TerminalResult, self).addSuccess(test) + self.out.write(_Colors.OK + 'SUCCESS {}\n'.format(test.id()) + + _Colors.END) + self.out.flush() + + def addSkip(self, test, reason): + """See unittest.TestResult.addSkip.""" + super(TerminalResult, self).addSkip(test, reason) + self.out.write(_Colors.INFO + 'SKIP {}\n'.format(test.id()) + + _Colors.END) + self.out.flush() + + def addExpectedFailure(self, test, err): + """See unittest.TestResult.addExpectedFailure.""" + super(TerminalResult, self).addExpectedFailure(test, err) + self.out.write(_Colors.INFO + 'FAILURE_OK {}\n'.format(test.id()) + + _Colors.END) + self.out.flush() + + def addUnexpectedSuccess(self, test): + """See unittest.TestResult.addUnexpectedSuccess.""" + super(TerminalResult, self).addUnexpectedSuccess(test) + self.out.write(_Colors.INFO + 'UNEXPECTED_OK {}\n'.format(test.id()) + + _Colors.END) + self.out.flush() + + +def _traceback_string(type, value, trace): + """Generate a descriptive string of a Python exception traceback. + + Args: + type (class): The type of the exception. + value (Exception): The value of the exception. + trace (traceback): Traceback of the exception. + + Returns: + str: Formatted exception descriptive string. + """ + buffer = moves.cStringIO() + traceback.print_exception(type, value, trace, file=buffer) + return buffer.getvalue() + + +def summary(result): + """A summary string of a result object. + + Args: + result (AugmentedResult): The result object to get the summary of. + + Returns: + str: The summary string. + """ + assert isinstance(result, AugmentedResult) + untested = list( + result.augmented_results( + lambda case_result: case_result.kind is CaseResult.Kind.UNTESTED)) + running = list( + result.augmented_results( + lambda case_result: case_result.kind is CaseResult.Kind.RUNNING)) + failures = list( + result.augmented_results( + lambda case_result: case_result.kind is CaseResult.Kind.FAILURE)) + errors = list( + result.augmented_results( + lambda case_result: case_result.kind is CaseResult.Kind.ERROR)) + successes = list( + result.augmented_results( + lambda case_result: case_result.kind is CaseResult.Kind.SUCCESS)) + skips = list( + result.augmented_results( + lambda case_result: case_result.kind is CaseResult.Kind.SKIP)) + expected_failures = list( + result.augmented_results(lambda case_result: case_result.kind is + CaseResult.Kind.EXPECTED_FAILURE)) + unexpected_successes = list( + result.augmented_results(lambda case_result: case_result.kind is + CaseResult.Kind.UNEXPECTED_SUCCESS)) + running_names = [case.name for case in running] + finished_count = (len(failures) + len(errors) + len(successes) + + len(expected_failures) + len(unexpected_successes)) + statistics = ('{finished} tests finished:\n' + '\t{successful} successful\n' + '\t{unsuccessful} unsuccessful\n' + '\t{skipped} skipped\n' + '\t{expected_fail} expected failures\n' + '\t{unexpected_successful} unexpected successes\n' + 'Interrupted Tests:\n' + '\t{interrupted}\n'.format( + finished=finished_count, + successful=len(successes), + unsuccessful=(len(failures) + len(errors)), + skipped=len(skips), + expected_fail=len(expected_failures), + unexpected_successful=len(unexpected_successes), + interrupted=str(running_names))) + tracebacks = '\n\n'.join([ + (_Colors.FAIL + '{test_name}' + _Colors.END + '\n' + _Colors.BOLD + + 'traceback:' + _Colors.END + '\n' + '{traceback}\n' + _Colors.BOLD + + 'stdout:' + _Colors.END + '\n' + '{stdout}\n' + _Colors.BOLD + + 'stderr:' + _Colors.END + '\n' + '{stderr}\n').format( + test_name=result.name, + traceback=_traceback_string(*result.traceback), + stdout=result.stdout, + stderr=result.stderr) + for result in itertools.chain(failures, errors) + ]) + notes = 'Unexpected successes: {}\n'.format( + [result.name for result in unexpected_successes]) + return statistics + '\nErrors/Failures: \n' + tracebacks + '\n' + notes + + +def jenkins_junit_xml(result): + """An XML tree object that when written is recognizable by Jenkins. + + Args: + result (AugmentedResult): The result object to get the junit xml output of. + + Returns: + ElementTree.ElementTree: The XML tree. + """ + assert isinstance(result, AugmentedResult) + root = ElementTree.Element('testsuites') + suite = ElementTree.SubElement(root, 'testsuite', { + 'name': 'Python gRPC tests', + }) + for case in result.cases.values(): + if case.kind is CaseResult.Kind.SUCCESS: + ElementTree.SubElement(suite, 'testcase', { + 'name': case.name, + }) + elif case.kind in (CaseResult.Kind.ERROR, CaseResult.Kind.FAILURE): + case_xml = ElementTree.SubElement(suite, 'testcase', { + 'name': case.name, + }) + error_xml = ElementTree.SubElement(case_xml, 'error', {}) + error_xml.text = ''.format(case.stderr, case.traceback) + return ElementTree.ElementTree(element=root) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/_runner.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/_runner.py new file mode 100644 index 00000000000..39da0399b02 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/_runner.py @@ -0,0 +1,239 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +import collections +import os +import select +import signal +import sys +import tempfile +import threading +import time +import unittest +import uuid + +import six +from six import moves + +from tests import _loader +from tests import _result + + +class CaptureFile(object): + """A context-managed file to redirect output to a byte array. + + Use by invoking `start` (`__enter__`) and at some point invoking `stop` + (`__exit__`). At any point after the initial call to `start` call `output` to + get the current redirected output. Note that we don't currently use file + locking, so calling `output` between calls to `start` and `stop` may muddle + the result (you should only be doing this during a Python-handled interrupt as + a last ditch effort to provide output to the user). + + Attributes: + _redirected_fd (int): File descriptor of file to redirect writes from. + _saved_fd (int): A copy of the original value of the redirected file + descriptor. + _into_file (TemporaryFile or None): File to which writes are redirected. + Only non-None when self is started. + """ + + def __init__(self, fd): + self._redirected_fd = fd + self._saved_fd = os.dup(self._redirected_fd) + self._into_file = None + + def output(self): + """Get all output from the redirected-to file if it exists.""" + if self._into_file: + self._into_file.seek(0) + return bytes(self._into_file.read()) + else: + return bytes() + + def start(self): + """Start redirection of writes to the file descriptor.""" + self._into_file = tempfile.TemporaryFile() + os.dup2(self._into_file.fileno(), self._redirected_fd) + + def stop(self): + """Stop redirection of writes to the file descriptor.""" + # n.b. this dup2 call auto-closes self._redirected_fd + os.dup2(self._saved_fd, self._redirected_fd) + + def write_bypass(self, value): + """Bypass the redirection and write directly to the original file. + + Arguments: + value (str): What to write to the original file. + """ + if six.PY3 and not isinstance(value, six.binary_type): + value = bytes(value, 'ascii') + if self._saved_fd is None: + os.write(self._redirect_fd, value) + else: + os.write(self._saved_fd, value) + + def __enter__(self): + self.start() + return self + + def __exit__(self, type, value, traceback): + self.stop() + + def close(self): + """Close any resources used by self not closed by stop().""" + os.close(self._saved_fd) + + +class AugmentedCase(collections.namedtuple('AugmentedCase', ['case', 'id'])): + """A test case with a guaranteed unique externally specified identifier. + + Attributes: + case (unittest.TestCase): TestCase we're decorating with an additional + identifier. + id (object): Any identifier that may be considered 'unique' for testing + purposes. + """ + + def __new__(cls, case, id=None): + if id is None: + id = uuid.uuid4() + return super(cls, AugmentedCase).__new__(cls, case, id) + + +# NOTE(lidiz) This complex wrapper is not triggering setUpClass nor +# tearDownClass. Do not use those methods, or fix this wrapper! +class Runner(object): + + def __init__(self, dedicated_threads=False): + """Constructs the Runner object. + + Args: + dedicated_threads: A bool indicates whether to spawn each unit test + in separate thread or not. + """ + self._skipped_tests = [] + self._dedicated_threads = dedicated_threads + + def skip_tests(self, tests): + self._skipped_tests = tests + + def run(self, suite): + """See setuptools' test_runner setup argument for information.""" + # only run test cases with id starting with given prefix + testcase_filter = os.getenv('GRPC_PYTHON_TESTRUNNER_FILTER') + filtered_cases = [] + for case in _loader.iterate_suite_cases(suite): + if not testcase_filter or case.id().startswith(testcase_filter): + filtered_cases.append(case) + + # Ensure that every test case has no collision with any other test case in + # the augmented results. + augmented_cases = [ + AugmentedCase(case, uuid.uuid4()) for case in filtered_cases + ] + case_id_by_case = dict((augmented_case.case, augmented_case.id) + for augmented_case in augmented_cases) + result_out = moves.cStringIO() + result = _result.TerminalResult( + result_out, id_map=lambda case: case_id_by_case[case]) + stdout_pipe = CaptureFile(sys.stdout.fileno()) + stderr_pipe = CaptureFile(sys.stderr.fileno()) + kill_flag = [False] + + def sigint_handler(signal_number, frame): + if signal_number == signal.SIGINT: + kill_flag[0] = True # Python 2.7 not having 'local'... :-( + signal.signal(signal_number, signal.SIG_DFL) + + def fault_handler(signal_number, frame): + stdout_pipe.write_bypass( + 'Received fault signal {}\nstdout:\n{}\n\nstderr:{}\n'.format( + signal_number, stdout_pipe.output(), stderr_pipe.output())) + os._exit(1) + + def check_kill_self(): + if kill_flag[0]: + stdout_pipe.write_bypass('Stopping tests short...') + result.stopTestRun() + stdout_pipe.write_bypass(result_out.getvalue()) + stdout_pipe.write_bypass('\ninterrupted stdout:\n{}\n'.format( + stdout_pipe.output().decode())) + stderr_pipe.write_bypass('\ninterrupted stderr:\n{}\n'.format( + stderr_pipe.output().decode())) + os._exit(1) + + def try_set_handler(name, handler): + try: + signal.signal(getattr(signal, name), handler) + except AttributeError: + pass + + try_set_handler('SIGINT', sigint_handler) + try_set_handler('SIGBUS', fault_handler) + try_set_handler('SIGABRT', fault_handler) + try_set_handler('SIGFPE', fault_handler) + try_set_handler('SIGILL', fault_handler) + # Sometimes output will lag after a test has successfully finished; we + # ignore such writes to our pipes. + try_set_handler('SIGPIPE', signal.SIG_IGN) + + # Run the tests + result.startTestRun() + for augmented_case in augmented_cases: + for skipped_test in self._skipped_tests: + if skipped_test in augmented_case.case.id(): + break + else: + sys.stdout.write('Running {}\n'.format( + augmented_case.case.id())) + sys.stdout.flush() + if self._dedicated_threads: + # (Deprecated) Spawns dedicated thread for each test case. + case_thread = threading.Thread( + target=augmented_case.case.run, args=(result,)) + try: + with stdout_pipe, stderr_pipe: + case_thread.start() + # If the thread is exited unexpected, stop testing. + while case_thread.is_alive(): + check_kill_self() + time.sleep(0) + case_thread.join() + except: # pylint: disable=try-except-raise + # re-raise the exception after forcing the with-block to end + raise + # Records the result of the test case run. + result.set_output(augmented_case.case, stdout_pipe.output(), + stderr_pipe.output()) + sys.stdout.write(result_out.getvalue()) + sys.stdout.flush() + result_out.truncate(0) + check_kill_self() + else: + # Donates current thread to test case execution. + augmented_case.case.run(result) + result.stopTestRun() + stdout_pipe.close() + stderr_pipe.close() + + # Report results + sys.stdout.write(result_out.getvalue()) + sys.stdout.flush() + signal.signal(signal.SIGINT, signal.SIG_DFL) + with open('report.xml', 'wb') as report_xml_file: + _result.jenkins_junit_xml(result).write(report_xml_file) + return result diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/_sanity/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/_sanity/__init__.py new file mode 100644 index 00000000000..5772620b602 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/_sanity/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/_sanity/_sanity_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/_sanity/_sanity_test.py new file mode 100644 index 00000000000..3aa92f37fb1 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/_sanity/_sanity_test.py @@ -0,0 +1,48 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import pkgutil +import unittest + +import six + +import tests + + +class SanityTest(unittest.TestCase): + + maxDiff = 32768 + + TEST_PKG_MODULE_NAME = 'tests' + TEST_PKG_PATH = 'tests' + + def testTestsJsonUpToDate(self): + """Autodiscovers all test suites and checks that tests.json is up to date""" + loader = tests.Loader() + loader.loadTestsFromNames([self.TEST_PKG_MODULE_NAME]) + test_suite_names = sorted({ + test_case_class.id().rsplit('.', 1)[0] for test_case_class in + tests._loader.iterate_suite_cases(loader.suite) + }) + + tests_json_string = pkgutil.get_data(self.TEST_PKG_PATH, 'tests.json') + tests_json = json.loads( + tests_json_string.decode() if six.PY3 else tests_json_string) + + self.assertSequenceEqual(tests_json, test_suite_names) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/bazel_namespace_package_hack.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/bazel_namespace_package_hack.py new file mode 100644 index 00000000000..994a8e1e800 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/bazel_namespace_package_hack.py @@ -0,0 +1,40 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import site +import sys + +_GRPC_BAZEL_RUNTIME_ENV = "GRPC_BAZEL_RUNTIME" + + +# TODO(https://github.com/bazelbuild/bazel/issues/6844) Bazel failed to +# interpret namespace packages correctly. This monkey patch will force the +# Python process to parse the .pth file in the sys.path to resolve namespace +# package in the right place. +# Analysis in depth: https://github.com/bazelbuild/rules_python/issues/55 +def sys_path_to_site_dir_hack(): + """Add valid sys.path item to site directory to parse the .pth files.""" + # Only run within our Bazel environment + if not os.environ.get(_GRPC_BAZEL_RUNTIME_ENV): + return + items = [] + for item in sys.path: + if os.path.exists(item): + # The only difference between sys.path and site-directory is + # whether the .pth file will be parsed or not. A site-directory + # will always exist in sys.path, but not another way around. + items.append(item) + for item in items: + site.addsitedir(item) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/channelz/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/channelz/__init__.py new file mode 100644 index 00000000000..38fdfc9c5cf --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/channelz/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py new file mode 100644 index 00000000000..784307ae005 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py @@ -0,0 +1,469 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_channelz.v1.channelz.""" + +import unittest + +from concurrent import futures + +import grpc + +from grpc_channelz.v1 import channelz +from grpc_channelz.v1 import channelz_pb2 +from grpc_channelz.v1 import channelz_pb2_grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_SUCCESSFUL_UNARY_UNARY = '/test/SuccessfulUnaryUnary' +_FAILED_UNARY_UNARY = '/test/FailedUnaryUnary' +_SUCCESSFUL_STREAM_STREAM = '/test/SuccessfulStreamStream' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + +_DISABLE_REUSE_PORT = (('grpc.so_reuseport', 0),) +_ENABLE_CHANNELZ = (('grpc.enable_channelz', 1),) +_DISABLE_CHANNELZ = (('grpc.enable_channelz', 0),) + + +def _successful_unary_unary(request, servicer_context): + return _RESPONSE + + +def _failed_unary_unary(request, servicer_context): + servicer_context.set_code(grpc.StatusCode.INTERNAL) + servicer_context.set_details("Channelz Test Intended Failure") + + +def _successful_stream_stream(request_iterator, servicer_context): + for _ in request_iterator: + yield _RESPONSE + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _SUCCESSFUL_UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_successful_unary_unary) + elif handler_call_details.method == _FAILED_UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_failed_unary_unary) + elif handler_call_details.method == _SUCCESSFUL_STREAM_STREAM: + return grpc.stream_stream_rpc_method_handler( + _successful_stream_stream) + else: + return None + + +class _ChannelServerPair(object): + + def __init__(self): + # Server will enable channelz service + self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=3), + options=_DISABLE_REUSE_PORT + + _ENABLE_CHANNELZ) + port = self.server.add_insecure_port('[::]:0') + self.server.add_generic_rpc_handlers((_GenericHandler(),)) + self.server.start() + + # Channel will enable channelz service... + self.channel = grpc.insecure_channel('localhost:%d' % port, + _ENABLE_CHANNELZ) + + +def _generate_channel_server_pairs(n): + return [_ChannelServerPair() for i in range(n)] + + +def _close_channel_server_pairs(pairs): + for pair in pairs: + pair.server.stop(None) + pair.channel.close() + + +class ChannelzServicerTest(unittest.TestCase): + + def _send_successful_unary_unary(self, idx): + _, r = self._pairs[idx].channel.unary_unary( + _SUCCESSFUL_UNARY_UNARY).with_call(_REQUEST) + self.assertEqual(r.code(), grpc.StatusCode.OK) + + def _send_failed_unary_unary(self, idx): + try: + self._pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY).with_call( + _REQUEST) + except grpc.RpcError: + return + else: + self.fail("This call supposed to fail") + + def _send_successful_stream_stream(self, idx): + response_iterator = self._pairs[idx].channel.stream_stream( + _SUCCESSFUL_STREAM_STREAM).__call__( + iter([_REQUEST] * test_constants.STREAM_LENGTH)) + cnt = 0 + for _ in response_iterator: + cnt += 1 + self.assertEqual(cnt, test_constants.STREAM_LENGTH) + + def _get_channel_id(self, idx): + """Channel id may not be consecutive""" + resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertGreater(len(resp.channel), idx) + return resp.channel[idx].ref.channel_id + + def setUp(self): + self._pairs = [] + # This server is for Channelz info fetching only + # It self should not enable Channelz + self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=3), + options=_DISABLE_REUSE_PORT + + _DISABLE_CHANNELZ) + port = self._server.add_insecure_port('[::]:0') + channelz.add_channelz_servicer(self._server) + self._server.start() + + # This channel is used to fetch Channelz info only + # Channelz should not be enabled + self._channel = grpc.insecure_channel('localhost:%d' % port, + _DISABLE_CHANNELZ) + self._channelz_stub = channelz_pb2_grpc.ChannelzStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + _close_channel_server_pairs(self._pairs) + + def test_get_top_channels_basic(self): + self._pairs = _generate_channel_server_pairs(1) + resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertEqual(len(resp.channel), 1) + self.assertEqual(resp.end, True) + + def test_get_top_channels_high_start_id(self): + self._pairs = _generate_channel_server_pairs(1) + resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=10000)) + self.assertEqual(len(resp.channel), 0) + self.assertEqual(resp.end, True) + + def test_successful_request(self): + self._pairs = _generate_channel_server_pairs(1) + self._send_successful_unary_unary(0) + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(resp.channel.data.calls_started, 1) + self.assertEqual(resp.channel.data.calls_succeeded, 1) + self.assertEqual(resp.channel.data.calls_failed, 0) + + def test_failed_request(self): + self._pairs = _generate_channel_server_pairs(1) + self._send_failed_unary_unary(0) + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(resp.channel.data.calls_started, 1) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, 1) + + def test_many_requests(self): + self._pairs = _generate_channel_server_pairs(1) + k_success = 7 + k_failed = 9 + for i in range(k_success): + self._send_successful_unary_unary(0) + for i in range(k_failed): + self._send_failed_unary_unary(0) + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + def test_many_channel(self): + k_channels = 4 + self._pairs = _generate_channel_server_pairs(k_channels) + resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertEqual(len(resp.channel), k_channels) + + def test_many_requests_many_channel(self): + k_channels = 4 + self._pairs = _generate_channel_server_pairs(k_channels) + k_success = 11 + k_failed = 13 + for i in range(k_success): + self._send_successful_unary_unary(0) + self._send_successful_unary_unary(2) + for i in range(k_failed): + self._send_failed_unary_unary(1) + self._send_failed_unary_unary(2) + + # The first channel saw only successes + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(resp.channel.data.calls_started, k_success) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, 0) + + # The second channel saw only failures + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(1))) + self.assertEqual(resp.channel.data.calls_started, k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + # The third channel saw both successes and failures + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(2))) + self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + # The fourth channel saw nothing + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(3))) + self.assertEqual(resp.channel.data.calls_started, 0) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, 0) + + def test_many_subchannels(self): + k_channels = 4 + self._pairs = _generate_channel_server_pairs(k_channels) + k_success = 17 + k_failed = 19 + for i in range(k_success): + self._send_successful_unary_unary(0) + self._send_successful_unary_unary(2) + for i in range(k_failed): + self._send_failed_unary_unary(1) + self._send_failed_unary_unary(2) + + gtc_resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertEqual(len(gtc_resp.channel), k_channels) + for i in range(k_channels): + # If no call performed in the channel, there shouldn't be any subchannel + if gtc_resp.channel[i].data.calls_started == 0: + self.assertEqual(len(gtc_resp.channel[i].subchannel_ref), 0) + continue + + # Otherwise, the subchannel should exist + self.assertGreater(len(gtc_resp.channel[i].subchannel_ref), 0) + gsc_resp = self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gtc_resp.channel[i].subchannel_ref[0]. + subchannel_id)) + self.assertEqual(gtc_resp.channel[i].data.calls_started, + gsc_resp.subchannel.data.calls_started) + self.assertEqual(gtc_resp.channel[i].data.calls_succeeded, + gsc_resp.subchannel.data.calls_succeeded) + self.assertEqual(gtc_resp.channel[i].data.calls_failed, + gsc_resp.subchannel.data.calls_failed) + + def test_server_basic(self): + self._pairs = _generate_channel_server_pairs(1) + resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(resp.server), 1) + + def test_get_one_server(self): + self._pairs = _generate_channel_server_pairs(1) + gss_resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(gss_resp.server), 1) + gs_resp = self._channelz_stub.GetServer( + channelz_pb2.GetServerRequest( + server_id=gss_resp.server[0].ref.server_id)) + self.assertEqual(gss_resp.server[0].ref.server_id, + gs_resp.server.ref.server_id) + + def test_server_call(self): + self._pairs = _generate_channel_server_pairs(1) + k_success = 23 + k_failed = 29 + for i in range(k_success): + self._send_successful_unary_unary(0) + for i in range(k_failed): + self._send_failed_unary_unary(0) + + resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(resp.server), 1) + self.assertEqual(resp.server[0].data.calls_started, + k_success + k_failed) + self.assertEqual(resp.server[0].data.calls_succeeded, k_success) + self.assertEqual(resp.server[0].data.calls_failed, k_failed) + + def test_many_subchannels_and_sockets(self): + k_channels = 4 + self._pairs = _generate_channel_server_pairs(k_channels) + k_success = 3 + k_failed = 5 + for i in range(k_success): + self._send_successful_unary_unary(0) + self._send_successful_unary_unary(2) + for i in range(k_failed): + self._send_failed_unary_unary(1) + self._send_failed_unary_unary(2) + + gtc_resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertEqual(len(gtc_resp.channel), k_channels) + for i in range(k_channels): + # If no call performed in the channel, there shouldn't be any subchannel + if gtc_resp.channel[i].data.calls_started == 0: + self.assertEqual(len(gtc_resp.channel[i].subchannel_ref), 0) + continue + + # Otherwise, the subchannel should exist + self.assertGreater(len(gtc_resp.channel[i].subchannel_ref), 0) + gsc_resp = self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gtc_resp.channel[i].subchannel_ref[0]. + subchannel_id)) + self.assertEqual(len(gsc_resp.subchannel.socket_ref), 1) + + gs_resp = self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) + self.assertEqual(gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.streams_started) + self.assertEqual(gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.streams_succeeded) + # Calls started == messages sent, only valid for unary calls + self.assertEqual(gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.messages_sent) + # Only receive responses when the RPC was successful + self.assertEqual(gsc_resp.subchannel.data.calls_succeeded, + gs_resp.socket.data.messages_received) + + def test_streaming_rpc(self): + self._pairs = _generate_channel_server_pairs(1) + # In C++, the argument for _send_successful_stream_stream is message length. + # Here the argument is still channel idx, to be consistent with the other two. + self._send_successful_stream_stream(0) + + gc_resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(gc_resp.channel.data.calls_started, 1) + self.assertEqual(gc_resp.channel.data.calls_succeeded, 1) + self.assertEqual(gc_resp.channel.data.calls_failed, 0) + # Subchannel exists + self.assertGreater(len(gc_resp.channel.subchannel_ref), 0) + + gsc_resp = self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gc_resp.channel.subchannel_ref[0].subchannel_id)) + self.assertEqual(gsc_resp.subchannel.data.calls_started, 1) + self.assertEqual(gsc_resp.subchannel.data.calls_succeeded, 1) + self.assertEqual(gsc_resp.subchannel.data.calls_failed, 0) + # Socket exists + self.assertEqual(len(gsc_resp.subchannel.socket_ref), 1) + + gs_resp = self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) + self.assertEqual(gs_resp.socket.data.streams_started, 1) + self.assertEqual(gs_resp.socket.data.streams_succeeded, 1) + self.assertEqual(gs_resp.socket.data.streams_failed, 0) + self.assertEqual(gs_resp.socket.data.messages_sent, + test_constants.STREAM_LENGTH) + self.assertEqual(gs_resp.socket.data.messages_received, + test_constants.STREAM_LENGTH) + + def test_server_sockets(self): + self._pairs = _generate_channel_server_pairs(1) + self._send_successful_unary_unary(0) + self._send_failed_unary_unary(0) + + gs_resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(gs_resp.server), 1) + self.assertEqual(gs_resp.server[0].data.calls_started, 2) + self.assertEqual(gs_resp.server[0].data.calls_succeeded, 1) + self.assertEqual(gs_resp.server[0].data.calls_failed, 1) + + gss_resp = self._channelz_stub.GetServerSockets( + channelz_pb2.GetServerSocketsRequest( + server_id=gs_resp.server[0].ref.server_id, start_socket_id=0)) + # If the RPC call failed, it will raise a grpc.RpcError + # So, if there is no exception raised, considered pass + + def test_server_listen_sockets(self): + self._pairs = _generate_channel_server_pairs(1) + + gss_resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(gss_resp.server), 1) + self.assertEqual(len(gss_resp.server[0].listen_socket), 1) + + gs_resp = self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=gss_resp.server[0].listen_socket[0].socket_id)) + # If the RPC call failed, it will raise a grpc.RpcError + # So, if there is no exception raised, considered pass + + def test_invalid_query_get_server(self): + try: + self._channelz_stub.GetServer( + channelz_pb2.GetServerRequest(server_id=10000)) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + def test_invalid_query_get_channel(self): + try: + self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=10000)) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + def test_invalid_query_get_subchannel(self): + try: + self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest(subchannel_id=10000)) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + def test_invalid_query_get_socket(self): + try: + self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest(socket_id=10000)) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + def test_invalid_query_get_server_sockets(self): + try: + self._channelz_stub.GetServerSockets( + channelz_pb2.GetServerSocketsRequest( + server_id=10000, + start_socket_id=0, + )) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/__init__.py new file mode 100644 index 00000000000..9a26bac0101 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/_fork_interop_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/_fork_interop_test.py new file mode 100644 index 00000000000..e2eff257fa1 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/_fork_interop_test.py @@ -0,0 +1,151 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Client-side fork interop tests as a unit test.""" + +import six +import subprocess +import sys +import threading +import unittest +from grpc._cython import cygrpc +from tests.fork import methods + +# New instance of multiprocessing.Process using fork without exec can and will +# hang if the Python process has any other threads running. This includes the +# additional thread spawned by our _runner.py class. So in order to test our +# compatibility with multiprocessing, we first fork+exec a new process to ensure +# we don't have any conflicting background threads. +_CLIENT_FORK_SCRIPT_TEMPLATE = """if True: + import os + import sys + from grpc._cython import cygrpc + from tests.fork import methods + + cygrpc._GRPC_ENABLE_FORK_SUPPORT = True + os.environ['GRPC_POLL_STRATEGY'] = 'epoll1' + methods.TestCase.%s.run_test({ + 'server_host': 'localhost', + 'server_port': %d, + 'use_tls': False + }) +""" +_SUBPROCESS_TIMEOUT_S = 30 + + + sys.platform.startswith("linux"), + "not supported on windows, and fork+exec networking blocked on mac") [email protected](six.PY2, "https://github.com/grpc/grpc/issues/18075") +class ForkInteropTest(unittest.TestCase): + + def setUp(self): + start_server_script = """if True: + import sys + import time + + import grpc + from src.proto.grpc.testing import test_pb2_grpc + from tests.interop import service as interop_service + from tests.unit import test_common + + server = test_common.test_server() + test_pb2_grpc.add_TestServiceServicer_to_server( + interop_service.TestService(), server) + port = server.add_insecure_port('[::]:0') + server.start() + print(port) + sys.stdout.flush() + while True: + time.sleep(1) + """ + self._server_process = subprocess.Popen( + [sys.executable, '-c', start_server_script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + timer = threading.Timer(_SUBPROCESS_TIMEOUT_S, + self._server_process.kill) + try: + timer.start() + self._port = int(self._server_process.stdout.readline()) + except ValueError: + raise Exception('Failed to get port from server') + finally: + timer.cancel() + + def testConnectivityWatch(self): + self._verifyTestCase(methods.TestCase.CONNECTIVITY_WATCH) + + def testCloseChannelBeforeFork(self): + self._verifyTestCase(methods.TestCase.CLOSE_CHANNEL_BEFORE_FORK) + + def testAsyncUnarySameChannel(self): + self._verifyTestCase(methods.TestCase.ASYNC_UNARY_SAME_CHANNEL) + + def testAsyncUnaryNewChannel(self): + self._verifyTestCase(methods.TestCase.ASYNC_UNARY_NEW_CHANNEL) + + def testBlockingUnarySameChannel(self): + self._verifyTestCase(methods.TestCase.BLOCKING_UNARY_SAME_CHANNEL) + + def testBlockingUnaryNewChannel(self): + self._verifyTestCase(methods.TestCase.BLOCKING_UNARY_NEW_CHANNEL) + + def testInProgressBidiContinueCall(self): + self._verifyTestCase(methods.TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL) + + def testInProgressBidiSameChannelAsyncCall(self): + self._verifyTestCase( + methods.TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL) + + def testInProgressBidiSameChannelBlockingCall(self): + self._verifyTestCase( + methods.TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL) + + def testInProgressBidiNewChannelAsyncCall(self): + self._verifyTestCase( + methods.TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL) + + def testInProgressBidiNewChannelBlockingCall(self): + self._verifyTestCase( + methods.TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL) + + def tearDown(self): + self._server_process.kill() + + def _verifyTestCase(self, test_case): + script = _CLIENT_FORK_SCRIPT_TEMPLATE % (test_case.name, self._port) + process = subprocess.Popen([sys.executable, '-c', script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + timer = threading.Timer(_SUBPROCESS_TIMEOUT_S, process.kill) + try: + timer.start() + try: + out, err = process.communicate(timeout=_SUBPROCESS_TIMEOUT_S) + except TypeError: + # The timeout parameter was added in Python 3.3. + out, err = process.communicate() + except subprocess.TimeoutExpired: + process.kill() + raise RuntimeError('Process failed to terminate') + finally: + timer.cancel() + self.assertEqual( + 0, process.returncode, + 'process failed with exit code %d (stdout: %s, stderr: %s)' % + (process.returncode, out, err)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/client.py new file mode 100644 index 00000000000..852e6da4d69 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/client.py @@ -0,0 +1,72 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the GRPC interoperability test client.""" + +import argparse +import logging +import sys + +from tests.fork import methods + + +def _args(): + + def parse_bool(value): + if value == 'true': + return True + if value == 'false': + return False + raise argparse.ArgumentTypeError('Only true/false allowed') + + parser = argparse.ArgumentParser() + parser.add_argument('--server_host', + default="localhost", + type=str, + help='the host to which to connect') + parser.add_argument('--server_port', + type=int, + required=True, + help='the port to which to connect') + parser.add_argument('--test_case', + default='large_unary', + type=str, + help='the test case to execute') + parser.add_argument('--use_tls', + default=False, + type=parse_bool, + help='require a secure connection') + return parser.parse_args() + + +def _test_case_from_arg(test_case_arg): + for test_case in methods.TestCase: + if test_case_arg == test_case.value: + return test_case + else: + raise ValueError('No test case "%s"!' % test_case_arg) + + +def test_fork(): + logging.basicConfig(level=logging.INFO) + args = vars(_args()) + if args['test_case'] == "all": + for test_case in methods.TestCase: + test_case.run_test(args) + else: + test_case = _test_case_from_arg(args['test_case']) + test_case.run_test(args) + + +if __name__ == '__main__': + test_fork() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/methods.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/methods.py new file mode 100644 index 00000000000..2123c699161 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/fork/methods.py @@ -0,0 +1,451 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementations of fork support test methods.""" + +import enum +import json +import logging +import multiprocessing +import os +import threading +import time + +import grpc + +from six.moves import queue + +from src.proto.grpc.testing import empty_pb2 +from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import test_pb2_grpc + +_LOGGER = logging.getLogger(__name__) +_RPC_TIMEOUT_S = 10 +_CHILD_FINISH_TIMEOUT_S = 60 + + +def _channel(args): + target = '{}:{}'.format(args['server_host'], args['server_port']) + if args['use_tls']: + channel_credentials = grpc.ssl_channel_credentials() + channel = grpc.secure_channel(target, channel_credentials) + else: + channel = grpc.insecure_channel(target) + return channel + + +def _validate_payload_type_and_length(response, expected_type, expected_length): + if response.payload.type is not expected_type: + raise ValueError('expected payload type %s, got %s' % + (expected_type, type(response.payload.type))) + elif len(response.payload.body) != expected_length: + raise ValueError('expected payload body size %d, got %d' % + (expected_length, len(response.payload.body))) + + +def _async_unary(stub): + size = 314159 + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=size, + payload=messages_pb2.Payload(body=b'\x00' * 271828)) + response_future = stub.UnaryCall.future(request, timeout=_RPC_TIMEOUT_S) + response = response_future.result() + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) + + +def _blocking_unary(stub): + size = 314159 + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=size, + payload=messages_pb2.Payload(body=b'\x00' * 271828)) + response = stub.UnaryCall(request, timeout=_RPC_TIMEOUT_S) + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) + + +class _Pipe(object): + + def __init__(self): + self._condition = threading.Condition() + self._values = [] + self._open = True + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + with self._condition: + while not self._values and self._open: + self._condition.wait() + if self._values: + return self._values.pop(0) + else: + raise StopIteration() + + def add(self, value): + with self._condition: + self._values.append(value) + self._condition.notify() + + def close(self): + with self._condition: + self._open = False + self._condition.notify() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + +class _ChildProcess(object): + + def __init__(self, task, args=None): + if args is None: + args = () + self._exceptions = multiprocessing.Queue() + + def record_exceptions(): + try: + task(*args) + except grpc.RpcError as rpc_error: + self._exceptions.put('RpcError: %s' % rpc_error) + except Exception as e: # pylint: disable=broad-except + self._exceptions.put(e) + + self._process = multiprocessing.Process(target=record_exceptions) + + def start(self): + self._process.start() + + def finish(self): + self._process.join(timeout=_CHILD_FINISH_TIMEOUT_S) + if self._process.is_alive(): + raise RuntimeError('Child process did not terminate') + if self._process.exitcode != 0: + raise ValueError('Child process failed with exitcode %d' % + self._process.exitcode) + try: + exception = self._exceptions.get(block=False) + raise ValueError('Child process failed: %s' % exception) + except queue.Empty: + pass + + +def _async_unary_same_channel(channel): + + def child_target(): + try: + _async_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + + stub = test_pb2_grpc.TestServiceStub(channel) + _async_unary(stub) + child_process = _ChildProcess(child_target) + child_process.start() + _async_unary(stub) + child_process.finish() + + +def _async_unary_new_channel(channel, args): + + def child_target(): + with _channel(args) as child_channel: + child_stub = test_pb2_grpc.TestServiceStub(child_channel) + _async_unary(child_stub) + child_channel.close() + + stub = test_pb2_grpc.TestServiceStub(channel) + _async_unary(stub) + child_process = _ChildProcess(child_target) + child_process.start() + _async_unary(stub) + child_process.finish() + + +def _blocking_unary_same_channel(channel): + + def child_target(): + try: + _blocking_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + + stub = test_pb2_grpc.TestServiceStub(channel) + _blocking_unary(stub) + child_process = _ChildProcess(child_target) + child_process.start() + child_process.finish() + + +def _blocking_unary_new_channel(channel, args): + + def child_target(): + with _channel(args) as child_channel: + child_stub = test_pb2_grpc.TestServiceStub(child_channel) + _blocking_unary(child_stub) + + stub = test_pb2_grpc.TestServiceStub(channel) + _blocking_unary(stub) + child_process = _ChildProcess(child_target) + child_process.start() + _blocking_unary(stub) + child_process.finish() + + +# Verify that the fork channel registry can handle already closed channels +def _close_channel_before_fork(channel, args): + + def child_target(): + new_channel.close() + with _channel(args) as child_channel: + child_stub = test_pb2_grpc.TestServiceStub(child_channel) + _blocking_unary(child_stub) + + stub = test_pb2_grpc.TestServiceStub(channel) + _blocking_unary(stub) + channel.close() + + with _channel(args) as new_channel: + new_stub = test_pb2_grpc.TestServiceStub(new_channel) + child_process = _ChildProcess(child_target) + child_process.start() + _blocking_unary(new_stub) + child_process.finish() + + +def _connectivity_watch(channel, args): + + parent_states = [] + parent_channel_ready_event = threading.Event() + + def child_target(): + + child_channel_ready_event = threading.Event() + + def child_connectivity_callback(state): + if state is grpc.ChannelConnectivity.READY: + child_channel_ready_event.set() + + with _channel(args) as child_channel: + child_stub = test_pb2_grpc.TestServiceStub(child_channel) + child_channel.subscribe(child_connectivity_callback) + _async_unary(child_stub) + if not child_channel_ready_event.wait(timeout=_RPC_TIMEOUT_S): + raise ValueError('Channel did not move to READY') + if len(parent_states) > 1: + raise ValueError( + 'Received connectivity updates on parent callback', + parent_states) + child_channel.unsubscribe(child_connectivity_callback) + + def parent_connectivity_callback(state): + parent_states.append(state) + if state is grpc.ChannelConnectivity.READY: + parent_channel_ready_event.set() + + channel.subscribe(parent_connectivity_callback) + stub = test_pb2_grpc.TestServiceStub(channel) + child_process = _ChildProcess(child_target) + child_process.start() + _async_unary(stub) + if not parent_channel_ready_event.wait(timeout=_RPC_TIMEOUT_S): + raise ValueError('Channel did not move to READY') + channel.unsubscribe(parent_connectivity_callback) + child_process.finish() + + +def _ping_pong_with_child_processes_after_first_response( + channel, args, child_target, run_after_close=True): + request_response_sizes = ( + 31415, + 9, + 2653, + 58979, + ) + request_payload_sizes = ( + 27182, + 8, + 1828, + 45904, + ) + stub = test_pb2_grpc.TestServiceStub(channel) + pipe = _Pipe() + parent_bidi_call = stub.FullDuplexCall(pipe) + child_processes = [] + first_message_received = False + for response_size, payload_size in zip(request_response_sizes, + request_payload_sizes): + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters( + size=response_size),), + payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + pipe.add(request) + if first_message_received: + child_process = _ChildProcess(child_target, + (parent_bidi_call, channel, args)) + child_process.start() + child_processes.append(child_process) + response = next(parent_bidi_call) + first_message_received = True + child_process = _ChildProcess(child_target, + (parent_bidi_call, channel, args)) + child_process.start() + child_processes.append(child_process) + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, + response_size) + pipe.close() + if run_after_close: + child_process = _ChildProcess(child_target, + (parent_bidi_call, channel, args)) + child_process.start() + child_processes.append(child_process) + for child_process in child_processes: + child_process.finish() + + +def _in_progress_bidi_continue_call(channel): + + def child_target(parent_bidi_call, parent_channel, args): + stub = test_pb2_grpc.TestServiceStub(parent_channel) + try: + _async_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + inherited_code = parent_bidi_call.code() + inherited_details = parent_bidi_call.details() + if inherited_code != grpc.StatusCode.CANCELLED: + raise ValueError('Expected inherited code CANCELLED, got %s' % + inherited_code) + if inherited_details != 'Channel closed due to fork': + raise ValueError( + 'Expected inherited details Channel closed due to fork, got %s' + % inherited_details) + + # Don't run child_target after closing the parent call, as the call may have + # received a status from the server before fork occurs. + _ping_pong_with_child_processes_after_first_response(channel, + None, + child_target, + run_after_close=False) + + +def _in_progress_bidi_same_channel_async_call(channel): + + def child_target(parent_bidi_call, parent_channel, args): + stub = test_pb2_grpc.TestServiceStub(parent_channel) + try: + _async_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + + _ping_pong_with_child_processes_after_first_response( + channel, None, child_target) + + +def _in_progress_bidi_same_channel_blocking_call(channel): + + def child_target(parent_bidi_call, parent_channel, args): + stub = test_pb2_grpc.TestServiceStub(parent_channel) + try: + _blocking_unary(stub) + raise Exception( + 'Child should not be able to re-use channel after fork') + except ValueError as expected_value_error: + pass + + _ping_pong_with_child_processes_after_first_response( + channel, None, child_target) + + +def _in_progress_bidi_new_channel_async_call(channel, args): + + def child_target(parent_bidi_call, parent_channel, args): + with _channel(args) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + _async_unary(stub) + + _ping_pong_with_child_processes_after_first_response( + channel, args, child_target) + + +def _in_progress_bidi_new_channel_blocking_call(channel, args): + + def child_target(parent_bidi_call, parent_channel, args): + with _channel(args) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + _blocking_unary(stub) + + _ping_pong_with_child_processes_after_first_response( + channel, args, child_target) + + +class TestCase(enum.Enum): + + CONNECTIVITY_WATCH = 'connectivity_watch' + CLOSE_CHANNEL_BEFORE_FORK = 'close_channel_before_fork' + ASYNC_UNARY_SAME_CHANNEL = 'async_unary_same_channel' + ASYNC_UNARY_NEW_CHANNEL = 'async_unary_new_channel' + BLOCKING_UNARY_SAME_CHANNEL = 'blocking_unary_same_channel' + BLOCKING_UNARY_NEW_CHANNEL = 'blocking_unary_new_channel' + IN_PROGRESS_BIDI_CONTINUE_CALL = 'in_progress_bidi_continue_call' + IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = 'in_progress_bidi_same_channel_async_call' + IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_same_channel_blocking_call' + IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = 'in_progress_bidi_new_channel_async_call' + IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_new_channel_blocking_call' + + def run_test(self, args): + _LOGGER.info("Running %s", self) + channel = _channel(args) + if self is TestCase.ASYNC_UNARY_SAME_CHANNEL: + _async_unary_same_channel(channel) + elif self is TestCase.ASYNC_UNARY_NEW_CHANNEL: + _async_unary_new_channel(channel, args) + elif self is TestCase.BLOCKING_UNARY_SAME_CHANNEL: + _blocking_unary_same_channel(channel) + elif self is TestCase.BLOCKING_UNARY_NEW_CHANNEL: + _blocking_unary_new_channel(channel, args) + elif self is TestCase.CLOSE_CHANNEL_BEFORE_FORK: + _close_channel_before_fork(channel, args) + elif self is TestCase.CONNECTIVITY_WATCH: + _connectivity_watch(channel, args) + elif self is TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL: + _in_progress_bidi_continue_call(channel) + elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL: + _in_progress_bidi_same_channel_async_call(channel) + elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL: + _in_progress_bidi_same_channel_blocking_call(channel) + elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL: + _in_progress_bidi_new_channel_async_call(channel, args) + elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL: + _in_progress_bidi_new_channel_blocking_call(channel, args) + else: + raise NotImplementedError('Test case "%s" not implemented!' % + self.name) + channel.close() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/health_check/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/health_check/__init__.py new file mode 100644 index 00000000000..5772620b602 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/health_check/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py new file mode 100644 index 00000000000..01345aaca08 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py @@ -0,0 +1,282 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_health.v1.health.""" + +import logging +import threading +import time +import unittest + +import grpc + +from grpc_health.v1 import health +from grpc_health.v1 import health_pb2 +from grpc_health.v1 import health_pb2_grpc + +from tests.unit import test_common +from tests.unit import thread_pool +from tests.unit.framework.common import test_constants + +from six.moves import queue + +_SERVING_SERVICE = 'grpc.test.TestServiceServing' +_UNKNOWN_SERVICE = 'grpc.test.TestServiceUnknown' +_NOT_SERVING_SERVICE = 'grpc.test.TestServiceNotServing' +_WATCH_SERVICE = 'grpc.test.WatchService' + + +def _consume_responses(response_iterator, response_queue): + for response in response_iterator: + response_queue.put(response) + + +class BaseWatchTests(object): + + class WatchTests(unittest.TestCase): + + def start_server(self, non_blocking=False, thread_pool=None): + self._thread_pool = thread_pool + self._servicer = health.HealthServicer( + experimental_non_blocking=non_blocking, + experimental_thread_pool=thread_pool) + self._servicer.set(_SERVING_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + self._servicer.set(_UNKNOWN_SERVICE, + health_pb2.HealthCheckResponse.UNKNOWN) + self._servicer.set(_NOT_SERVING_SERVICE, + health_pb2.HealthCheckResponse.NOT_SERVING) + self._server = test_common.test_server() + port = self._server.add_insecure_port('[::]:0') + health_pb2_grpc.add_HealthServicer_to_server( + self._servicer, self._server) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = health_pb2_grpc.HealthStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def test_watch_empty_service(self): + request = health_pb2.HealthCheckRequest(service='') + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread(target=_consume_responses, + args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response.status) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + if self._thread_pool is not None: + self.assertTrue(self._thread_pool.was_used()) + + def test_watch_new_service(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread(target=_consume_responses, + args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.NOT_SERVING) + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + response.status) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + def test_watch_service_isolation(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread(target=_consume_responses, + args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + self._servicer.set('some-other-service', + health_pb2.HealthCheckResponse.SERVING) + with self.assertRaises(queue.Empty): + response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + def test_two_watchers(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue1 = queue.Queue() + response_queue2 = queue.Queue() + rendezvous1 = self._stub.Watch(request) + rendezvous2 = self._stub.Watch(request) + thread1 = threading.Thread(target=_consume_responses, + args=(rendezvous1, response_queue1)) + thread2 = threading.Thread(target=_consume_responses, + args=(rendezvous2, response_queue2)) + thread1.start() + thread2.start() + + response1 = response_queue1.get( + timeout=test_constants.SHORT_TIMEOUT) + response2 = response_queue2.get( + timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response1.status) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response2.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + response1 = response_queue1.get( + timeout=test_constants.SHORT_TIMEOUT) + response2 = response_queue2.get( + timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response1.status) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response2.status) + + rendezvous1.cancel() + rendezvous2.cancel() + thread1.join() + thread2.join() + self.assertTrue(response_queue1.empty()) + self.assertTrue(response_queue2.empty()) + + @unittest.skip("https://github.com/grpc/grpc/issues/18127") + def test_cancelled_watch_removed_from_watch_list(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread(target=_consume_responses, + args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + rendezvous.cancel() + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + thread.join() + + # Wait, if necessary, for serving thread to process client cancellation + timeout = time.time() + test_constants.TIME_ALLOWANCE + while (time.time() < timeout and + self._servicer._send_response_callbacks[_WATCH_SERVICE]): + time.sleep(1) + self.assertFalse( + self._servicer._send_response_callbacks[_WATCH_SERVICE], + 'watch set should be empty') + self.assertTrue(response_queue.empty()) + + def test_graceful_shutdown(self): + request = health_pb2.HealthCheckRequest(service='') + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread(target=_consume_responses, + args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response.status) + + self._servicer.enter_graceful_shutdown() + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + response.status) + + # This should be a no-op. + self._servicer.set('', health_pb2.HealthCheckResponse.SERVING) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + +class HealthServicerTest(BaseWatchTests.WatchTests): + + def setUp(self): + self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None) + super(HealthServicerTest, + self).start_server(non_blocking=True, + thread_pool=self._thread_pool) + + def test_check_empty_service(self): + request = health_pb2.HealthCheckRequest() + resp = self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status) + + def test_check_serving_service(self): + request = health_pb2.HealthCheckRequest(service=_SERVING_SERVICE) + resp = self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status) + + def test_check_unknown_service(self): + request = health_pb2.HealthCheckRequest(service=_UNKNOWN_SERVICE) + resp = self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.UNKNOWN, resp.status) + + def test_check_not_serving_service(self): + request = health_pb2.HealthCheckRequest(service=_NOT_SERVING_SERVICE) + resp = self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + resp.status) + + def test_check_not_found_service(self): + request = health_pb2.HealthCheckRequest(service='not-found') + with self.assertRaises(grpc.RpcError) as context: + resp = self._stub.Check(request) + + self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code()) + + def test_health_service_name(self): + self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health') + + +class HealthServicerBackwardsCompatibleWatchTest(BaseWatchTests.WatchTests): + + def setUp(self): + super(HealthServicerBackwardsCompatibleWatchTest, + self).start_server(non_blocking=False, thread_pool=None) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/http2/negative_http2_client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/http2/negative_http2_client.py new file mode 100644 index 00000000000..0753872b5e4 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/http2/negative_http2_client.py @@ -0,0 +1,158 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python client used to test negative http2 conditions.""" + +import argparse + +import grpc +import time +from src.proto.grpc.testing import test_pb2_grpc +from src.proto.grpc.testing import messages_pb2 + + +def _validate_payload_type_and_length(response, expected_type, expected_length): + if response.payload.type is not expected_type: + raise ValueError('expected payload type %s, got %s' % + (expected_type, type(response.payload.type))) + elif len(response.payload.body) != expected_length: + raise ValueError('expected payload body size %d, got %d' % + (expected_length, len(response.payload.body))) + + +def _expect_status_code(call, expected_code): + if call.code() != expected_code: + raise ValueError('expected code %s, got %s' % + (expected_code, call.code())) + + +def _expect_status_details(call, expected_details): + if call.details() != expected_details: + raise ValueError('expected message %s, got %s' % + (expected_details, call.details())) + + +def _validate_status_code_and_details(call, expected_code, expected_details): + _expect_status_code(call, expected_code) + _expect_status_details(call, expected_details) + + +# common requests +_REQUEST_SIZE = 314159 +_RESPONSE_SIZE = 271828 + +_SIMPLE_REQUEST = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=_RESPONSE_SIZE, + payload=messages_pb2.Payload(body=b'\x00' * _REQUEST_SIZE)) + + +def _goaway(stub): + first_response = stub.UnaryCall(_SIMPLE_REQUEST) + _validate_payload_type_and_length(first_response, messages_pb2.COMPRESSABLE, + _RESPONSE_SIZE) + time.sleep(1) + second_response = stub.UnaryCall(_SIMPLE_REQUEST) + _validate_payload_type_and_length(second_response, + messages_pb2.COMPRESSABLE, _RESPONSE_SIZE) + + +def _rst_after_header(stub): + resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST) + _validate_status_code_and_details(resp_future, grpc.StatusCode.INTERNAL, + "Received RST_STREAM with error code 0") + + +def _rst_during_data(stub): + resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST) + _validate_status_code_and_details(resp_future, grpc.StatusCode.INTERNAL, + "Received RST_STREAM with error code 0") + + +def _rst_after_data(stub): + resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST) + _validate_status_code_and_details(resp_future, grpc.StatusCode.INTERNAL, + "Received RST_STREAM with error code 0") + + +def _ping(stub): + response = stub.UnaryCall(_SIMPLE_REQUEST) + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, + _RESPONSE_SIZE) + + +def _max_streams(stub): + # send one req to ensure server sets MAX_STREAMS + response = stub.UnaryCall(_SIMPLE_REQUEST) + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, + _RESPONSE_SIZE) + + # give the streams a workout + futures = [] + for _ in range(15): + futures.append(stub.UnaryCall.future(_SIMPLE_REQUEST)) + for future in futures: + _validate_payload_type_and_length(future.result(), + messages_pb2.COMPRESSABLE, + _RESPONSE_SIZE) + + +def _run_test_case(test_case, stub): + if test_case == 'goaway': + _goaway(stub) + elif test_case == 'rst_after_header': + _rst_after_header(stub) + elif test_case == 'rst_during_data': + _rst_during_data(stub) + elif test_case == 'rst_after_data': + _rst_after_data(stub) + elif test_case == 'ping': + _ping(stub) + elif test_case == 'max_streams': + _max_streams(stub) + else: + raise ValueError("Invalid test case: %s" % test_case) + + +def _args(): + parser = argparse.ArgumentParser() + parser.add_argument('--server_host', + help='the host to which to connect', + type=str, + default="127.0.0.1") + parser.add_argument('--server_port', + help='the port to which to connect', + type=int, + default="8080") + parser.add_argument('--test_case', + help='the test case to execute', + type=str, + default="goaway") + return parser.parse_args() + + +def _stub(server_host, server_port): + target = '{}:{}'.format(server_host, server_port) + channel = grpc.insecure_channel(target) + grpc.channel_ready_future(channel).result() + return test_pb2_grpc.TestServiceStub(channel) + + +def main(): + args = _args() + stub = _stub(args.server_host, args.server_port) + _run_test_case(args.test_case, stub) + + +if __name__ == '__main__': + main() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/__init__.py new file mode 100644 index 00000000000..5fb4f3c3cfd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py new file mode 100644 index 00000000000..fecf31767a7 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py @@ -0,0 +1,44 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Insecure client-server interoperability as a unit test.""" + +import unittest + +import grpc +from src.proto.grpc.testing import test_pb2_grpc + +from tests.interop import _intraop_test_case +from tests.interop import service +from tests.interop import server +from tests.unit import test_common + + +class InsecureIntraopTest(_intraop_test_case.IntraopTestCase, + unittest.TestCase): + + def setUp(self): + self.server = test_common.test_server() + test_pb2_grpc.add_TestServiceServicer_to_server(service.TestService(), + self.server) + port = self.server.add_insecure_port('[::]:0') + self.server.start() + self.stub = test_pb2_grpc.TestServiceStub( + grpc.insecure_channel('localhost:{}'.format(port))) + + def tearDown(self): + self.server.stop(None) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_intraop_test_case.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_intraop_test_case.py new file mode 100644 index 00000000000..007db7ab41b --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_intraop_test_case.py @@ -0,0 +1,51 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common code for unit tests of the interoperability test code.""" + +from tests.interop import methods + + +class IntraopTestCase(object): + """Unit test methods. + + This class must be mixed in with unittest.TestCase and a class that defines + setUp and tearDown methods that manage a stub attribute. + """ + + def testEmptyUnary(self): + methods.TestCase.EMPTY_UNARY.test_interoperability(self.stub, None) + + def testLargeUnary(self): + methods.TestCase.LARGE_UNARY.test_interoperability(self.stub, None) + + def testServerStreaming(self): + methods.TestCase.SERVER_STREAMING.test_interoperability(self.stub, None) + + def testClientStreaming(self): + methods.TestCase.CLIENT_STREAMING.test_interoperability(self.stub, None) + + def testPingPong(self): + methods.TestCase.PING_PONG.test_interoperability(self.stub, None) + + def testCancelAfterBegin(self): + methods.TestCase.CANCEL_AFTER_BEGIN.test_interoperability( + self.stub, None) + + def testCancelAfterFirstResponse(self): + methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE.test_interoperability( + self.stub, None) + + def testTimeoutOnSleepingServer(self): + methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER.test_interoperability( + self.stub, None) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py new file mode 100644 index 00000000000..bf1f1b118b3 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py @@ -0,0 +1,54 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Secure client-server interoperability as a unit test.""" + +import unittest + +import grpc +from src.proto.grpc.testing import test_pb2_grpc + +from tests.interop import _intraop_test_case +from tests.interop import service +from tests.interop import resources +from tests.unit import test_common + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' + + +class SecureIntraopTest(_intraop_test_case.IntraopTestCase, unittest.TestCase): + + def setUp(self): + self.server = test_common.test_server() + test_pb2_grpc.add_TestServiceServicer_to_server(service.TestService(), + self.server) + port = self.server.add_secure_port( + '[::]:0', + grpc.ssl_server_credentials([(resources.private_key(), + resources.certificate_chain())])) + self.server.start() + self.stub = test_pb2_grpc.TestServiceStub( + grpc.secure_channel( + 'localhost:{}'.format(port), + grpc.ssl_channel_credentials( + resources.test_root_certificates()), (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, + ),))) + + def tearDown(self): + self.server.stop(None) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/client.py new file mode 100644 index 00000000000..4d35f7ca32a --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/client.py @@ -0,0 +1,180 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the GRPC interoperability test client.""" + +import argparse +import os + +from google import auth as google_auth +from google.auth import jwt as google_auth_jwt +import grpc +from src.proto.grpc.testing import test_pb2_grpc + +from tests.interop import methods +from tests.interop import resources + + +def parse_interop_client_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--server_host', + default="localhost", + type=str, + help='the host to which to connect') + parser.add_argument('--server_port', + type=int, + required=True, + help='the port to which to connect') + parser.add_argument('--test_case', + default='large_unary', + type=str, + help='the test case to execute') + parser.add_argument('--use_tls', + default=False, + type=resources.parse_bool, + help='require a secure connection') + parser.add_argument('--use_alts', + default=False, + type=resources.parse_bool, + help='require an ALTS secure connection') + parser.add_argument('--use_test_ca', + default=False, + type=resources.parse_bool, + help='replace platform root CAs with ca.pem') + parser.add_argument('--custom_credentials_type', + choices=["compute_engine_channel_creds"], + default=None, + help='use google default credentials') + parser.add_argument('--server_host_override', + type=str, + help='the server host to which to claim to connect') + parser.add_argument('--oauth_scope', + type=str, + help='scope for OAuth tokens') + parser.add_argument('--default_service_account', + type=str, + help='email address of the default service account') + parser.add_argument( + "--grpc_test_use_grpclb_with_child_policy", + type=str, + help=( + "If non-empty, set a static service config on channels created by " + + "grpc::CreateTestChannel, that configures the grpclb LB policy " + + "with a child policy being the value of this flag (e.g. round_robin " + + "or pick_first).")) + return parser.parse_args() + + +def _create_call_credentials(args): + if args.test_case == 'oauth2_auth_token': + google_credentials, unused_project_id = google_auth.default( + scopes=[args.oauth_scope]) + google_credentials.refresh(google_auth.transport.requests.Request()) + return grpc.access_token_call_credentials(google_credentials.token) + elif args.test_case == 'compute_engine_creds': + google_credentials, unused_project_id = google_auth.default( + scopes=[args.oauth_scope]) + return grpc.metadata_call_credentials( + google_auth.transport.grpc.AuthMetadataPlugin( + credentials=google_credentials, + request=google_auth.transport.requests.Request())) + elif args.test_case == 'jwt_token_creds': + google_credentials = google_auth_jwt.OnDemandCredentials.from_service_account_file( + os.environ[google_auth.environment_vars.CREDENTIALS]) + return grpc.metadata_call_credentials( + google_auth.transport.grpc.AuthMetadataPlugin( + credentials=google_credentials, request=None)) + else: + return None + + +def get_secure_channel_parameters(args): + call_credentials = _create_call_credentials(args) + + channel_opts = () + if args.grpc_test_use_grpclb_with_child_policy: + channel_opts += (( + "grpc.service_config", + '{"loadBalancingConfig": [{"grpclb": {"childPolicy": [{"%s": {}}]}}]}' + % args.grpc_test_use_grpclb_with_child_policy),) + if args.custom_credentials_type is not None: + if args.custom_credentials_type == "compute_engine_channel_creds": + assert call_credentials is None + google_credentials, unused_project_id = google_auth.default( + scopes=[args.oauth_scope]) + call_creds = grpc.metadata_call_credentials( + google_auth.transport.grpc.AuthMetadataPlugin( + credentials=google_credentials, + request=google_auth.transport.requests.Request())) + channel_credentials = grpc.compute_engine_channel_credentials( + call_creds) + else: + raise ValueError("Unknown credentials type '{}'".format( + args.custom_credentials_type)) + elif args.use_tls: + if args.use_test_ca: + root_certificates = resources.test_root_certificates() + else: + root_certificates = None # will load default roots. + + channel_credentials = grpc.ssl_channel_credentials(root_certificates) + if call_credentials is not None: + channel_credentials = grpc.composite_channel_credentials( + channel_credentials, call_credentials) + + if args.server_host_override: + channel_opts += (( + 'grpc.ssl_target_name_override', + args.server_host_override, + ),) + elif args.use_alts: + channel_credentials = grpc.alts_channel_credentials() + + return channel_credentials, channel_opts + + +def _create_channel(args): + target = '{}:{}'.format(args.server_host, args.server_port) + + if args.use_tls or args.use_alts or args.custom_credentials_type is not None: + channel_credentials, options = get_secure_channel_parameters(args) + return grpc.secure_channel(target, channel_credentials, options) + else: + return grpc.insecure_channel(target) + + +def create_stub(channel, args): + if args.test_case == "unimplemented_service": + return test_pb2_grpc.UnimplementedServiceStub(channel) + else: + return test_pb2_grpc.TestServiceStub(channel) + + +def _test_case_from_arg(test_case_arg): + for test_case in methods.TestCase: + if test_case_arg == test_case.value: + return test_case + else: + raise ValueError('No test case "%s"!' % test_case_arg) + + +def test_interoperability(): + args = parse_interop_client_args() + channel = _create_channel(args) + stub = create_stub(channel, args) + test_case = _test_case_from_arg(args.test_case) + test_case.test_interoperability(stub, args) + + +if __name__ == '__main__': + test_interoperability() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/methods.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/methods.py new file mode 100644 index 00000000000..44a1c38bb93 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/methods.py @@ -0,0 +1,482 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementations of interoperability test methods.""" + +# NOTE(lidiz) This module only exists in Bazel BUILD file, for more details +# please refer to comments in the "bazel_namespace_package_hack" module. +try: + from tests import bazel_namespace_package_hack + bazel_namespace_package_hack.sys_path_to_site_dir_hack() +except ImportError: + pass + +import enum +import json +import os +import threading +import time + +from google import auth as google_auth +from google.auth import environment_vars as google_auth_environment_vars +from google.auth.transport import grpc as google_auth_transport_grpc +from google.auth.transport import requests as google_auth_transport_requests +import grpc + +from src.proto.grpc.testing import empty_pb2 +from src.proto.grpc.testing import messages_pb2 + +_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial" +_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin" + + +def _expect_status_code(call, expected_code): + if call.code() != expected_code: + raise ValueError('expected code %s, got %s' % + (expected_code, call.code())) + + +def _expect_status_details(call, expected_details): + if call.details() != expected_details: + raise ValueError('expected message %s, got %s' % + (expected_details, call.details())) + + +def _validate_status_code_and_details(call, expected_code, expected_details): + _expect_status_code(call, expected_code) + _expect_status_details(call, expected_details) + + +def _validate_payload_type_and_length(response, expected_type, expected_length): + if response.payload.type is not expected_type: + raise ValueError('expected payload type %s, got %s' % + (expected_type, type(response.payload.type))) + elif len(response.payload.body) != expected_length: + raise ValueError('expected payload body size %d, got %d' % + (expected_length, len(response.payload.body))) + + +def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope, + call_credentials): + size = 314159 + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=size, + payload=messages_pb2.Payload(body=b'\x00' * 271828), + fill_username=fill_username, + fill_oauth_scope=fill_oauth_scope) + response_future = stub.UnaryCall.future(request, + credentials=call_credentials) + response = response_future.result() + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) + return response + + +def _empty_unary(stub): + response = stub.EmptyCall(empty_pb2.Empty()) + if not isinstance(response, empty_pb2.Empty): + raise TypeError('response is of type "%s", not empty_pb2.Empty!' % + type(response)) + + +def _large_unary(stub): + _large_unary_common_behavior(stub, False, False, None) + + +def _client_streaming(stub): + payload_body_sizes = ( + 27182, + 8, + 1828, + 45904, + ) + payloads = (messages_pb2.Payload(body=b'\x00' * size) + for size in payload_body_sizes) + requests = (messages_pb2.StreamingInputCallRequest(payload=payload) + for payload in payloads) + response = stub.StreamingInputCall(requests) + if response.aggregated_payload_size != 74922: + raise ValueError('incorrect size %d!' % + response.aggregated_payload_size) + + +def _server_streaming(stub): + sizes = ( + 31415, + 9, + 2653, + 58979, + ) + + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=( + messages_pb2.ResponseParameters(size=sizes[0]), + messages_pb2.ResponseParameters(size=sizes[1]), + messages_pb2.ResponseParameters(size=sizes[2]), + messages_pb2.ResponseParameters(size=sizes[3]), + )) + response_iterator = stub.StreamingOutputCall(request) + for index, response in enumerate(response_iterator): + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, + sizes[index]) + + +class _Pipe(object): + + def __init__(self): + self._condition = threading.Condition() + self._values = [] + self._open = True + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + with self._condition: + while not self._values and self._open: + self._condition.wait() + if self._values: + return self._values.pop(0) + else: + raise StopIteration() + + def add(self, value): + with self._condition: + self._values.append(value) + self._condition.notify() + + def close(self): + with self._condition: + self._open = False + self._condition.notify() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + +def _ping_pong(stub): + request_response_sizes = ( + 31415, + 9, + 2653, + 58979, + ) + request_payload_sizes = ( + 27182, + 8, + 1828, + 45904, + ) + + with _Pipe() as pipe: + response_iterator = stub.FullDuplexCall(pipe) + for response_size, payload_size in zip(request_response_sizes, + request_payload_sizes): + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters( + size=response_size),), + payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + pipe.add(request) + response = next(response_iterator) + _validate_payload_type_and_length(response, + messages_pb2.COMPRESSABLE, + response_size) + + +def _cancel_after_begin(stub): + with _Pipe() as pipe: + response_future = stub.StreamingInputCall.future(pipe) + response_future.cancel() + if not response_future.cancelled(): + raise ValueError('expected cancelled method to return True') + if response_future.code() is not grpc.StatusCode.CANCELLED: + raise ValueError('expected status code CANCELLED') + + +def _cancel_after_first_response(stub): + request_response_sizes = ( + 31415, + 9, + 2653, + 58979, + ) + request_payload_sizes = ( + 27182, + 8, + 1828, + 45904, + ) + with _Pipe() as pipe: + response_iterator = stub.FullDuplexCall(pipe) + + response_size = request_response_sizes[0] + payload_size = request_payload_sizes[0] + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters( + size=response_size),), + payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + pipe.add(request) + response = next(response_iterator) + # We test the contents of `response` in the Ping Pong test - don't check + # them here. + response_iterator.cancel() + + try: + next(response_iterator) + except grpc.RpcError as rpc_error: + if rpc_error.code() is not grpc.StatusCode.CANCELLED: + raise + else: + raise ValueError('expected call to be cancelled') + + +def _timeout_on_sleeping_server(stub): + request_payload_size = 27182 + with _Pipe() as pipe: + response_iterator = stub.FullDuplexCall(pipe, timeout=0.001) + + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + payload=messages_pb2.Payload(body=b'\x00' * request_payload_size)) + pipe.add(request) + try: + next(response_iterator) + except grpc.RpcError as rpc_error: + if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED: + raise + else: + raise ValueError('expected call to exceed deadline') + + +def _empty_stream(stub): + with _Pipe() as pipe: + response_iterator = stub.FullDuplexCall(pipe) + pipe.close() + try: + next(response_iterator) + raise ValueError('expected exactly 0 responses') + except StopIteration: + pass + + +def _status_code_and_message(stub): + details = 'test status message' + code = 2 + status = grpc.StatusCode.UNKNOWN # code = 2 + + # Test with a UnaryCall + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=1, + payload=messages_pb2.Payload(body=b'\x00'), + response_status=messages_pb2.EchoStatus(code=code, message=details)) + response_future = stub.UnaryCall.future(request) + _validate_status_code_and_details(response_future, status, details) + + # Test with a FullDuplexCall + with _Pipe() as pipe: + response_iterator = stub.FullDuplexCall(pipe) + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters(size=1),), + payload=messages_pb2.Payload(body=b'\x00'), + response_status=messages_pb2.EchoStatus(code=code, message=details)) + pipe.add(request) # sends the initial request. + try: + next(response_iterator) + except grpc.RpcError as rpc_error: + assert rpc_error.code() == status + # Dropping out of with block closes the pipe + _validate_status_code_and_details(response_iterator, status, details) + + +def _unimplemented_method(test_service_stub): + response_future = (test_service_stub.UnimplementedCall.future( + empty_pb2.Empty())) + _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED) + + +def _unimplemented_service(unimplemented_service_stub): + response_future = (unimplemented_service_stub.UnimplementedCall.future( + empty_pb2.Empty())) + _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED) + + +def _custom_metadata(stub): + initial_metadata_value = "test_initial_metadata_value" + trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b" + metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value), + (_TRAILING_METADATA_KEY, trailing_metadata_value)) + + def _validate_metadata(response): + initial_metadata = dict(response.initial_metadata()) + if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value: + raise ValueError('expected initial metadata %s, got %s' % + (initial_metadata_value, + initial_metadata[_INITIAL_METADATA_KEY])) + trailing_metadata = dict(response.trailing_metadata()) + if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value: + raise ValueError('expected trailing metadata %s, got %s' % + (trailing_metadata_value, + trailing_metadata[_TRAILING_METADATA_KEY])) + + # Testing with UnaryCall + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=1, + payload=messages_pb2.Payload(body=b'\x00')) + response_future = stub.UnaryCall.future(request, metadata=metadata) + _validate_metadata(response_future) + + # Testing with FullDuplexCall + with _Pipe() as pipe: + response_iterator = stub.FullDuplexCall(pipe, metadata=metadata) + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters(size=1),)) + pipe.add(request) # Sends the request + next(response_iterator) # Causes server to send trailing metadata + # Dropping out of the with block closes the pipe + _validate_metadata(response_iterator) + + +def _compute_engine_creds(stub, args): + response = _large_unary_common_behavior(stub, True, True, None) + if args.default_service_account != response.username: + raise ValueError('expected username %s, got %s' % + (args.default_service_account, response.username)) + + +def _oauth2_auth_token(stub, args): + json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + response = _large_unary_common_behavior(stub, True, True, None) + if wanted_email != response.username: + raise ValueError('expected username %s, got %s' % + (wanted_email, response.username)) + if args.oauth_scope.find(response.oauth_scope) == -1: + raise ValueError( + 'expected to find oauth scope "{}" in received "{}"'.format( + response.oauth_scope, args.oauth_scope)) + + +def _jwt_token_creds(stub, args): + json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + response = _large_unary_common_behavior(stub, True, False, None) + if wanted_email != response.username: + raise ValueError('expected username %s, got %s' % + (wanted_email, response.username)) + + +def _per_rpc_creds(stub, args): + json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + google_credentials, unused_project_id = google_auth.default( + scopes=[args.oauth_scope]) + call_credentials = grpc.metadata_call_credentials( + google_auth_transport_grpc.AuthMetadataPlugin( + credentials=google_credentials, + request=google_auth_transport_requests.Request())) + response = _large_unary_common_behavior(stub, True, False, call_credentials) + if wanted_email != response.username: + raise ValueError('expected username %s, got %s' % + (wanted_email, response.username)) + + +def _special_status_message(stub, args): + details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode( + 'utf-8') + code = 2 + status = grpc.StatusCode.UNKNOWN # code = 2 + + # Test with a UnaryCall + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=1, + payload=messages_pb2.Payload(body=b'\x00'), + response_status=messages_pb2.EchoStatus(code=code, message=details)) + response_future = stub.UnaryCall.future(request) + _validate_status_code_and_details(response_future, status, details) + + +class TestCase(enum.Enum): + EMPTY_UNARY = 'empty_unary' + LARGE_UNARY = 'large_unary' + SERVER_STREAMING = 'server_streaming' + CLIENT_STREAMING = 'client_streaming' + PING_PONG = 'ping_pong' + CANCEL_AFTER_BEGIN = 'cancel_after_begin' + CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response' + EMPTY_STREAM = 'empty_stream' + STATUS_CODE_AND_MESSAGE = 'status_code_and_message' + UNIMPLEMENTED_METHOD = 'unimplemented_method' + UNIMPLEMENTED_SERVICE = 'unimplemented_service' + CUSTOM_METADATA = "custom_metadata" + COMPUTE_ENGINE_CREDS = 'compute_engine_creds' + OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' + JWT_TOKEN_CREDS = 'jwt_token_creds' + PER_RPC_CREDS = 'per_rpc_creds' + TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' + SPECIAL_STATUS_MESSAGE = 'special_status_message' + + def test_interoperability(self, stub, args): + if self is TestCase.EMPTY_UNARY: + _empty_unary(stub) + elif self is TestCase.LARGE_UNARY: + _large_unary(stub) + elif self is TestCase.SERVER_STREAMING: + _server_streaming(stub) + elif self is TestCase.CLIENT_STREAMING: + _client_streaming(stub) + elif self is TestCase.PING_PONG: + _ping_pong(stub) + elif self is TestCase.CANCEL_AFTER_BEGIN: + _cancel_after_begin(stub) + elif self is TestCase.CANCEL_AFTER_FIRST_RESPONSE: + _cancel_after_first_response(stub) + elif self is TestCase.TIMEOUT_ON_SLEEPING_SERVER: + _timeout_on_sleeping_server(stub) + elif self is TestCase.EMPTY_STREAM: + _empty_stream(stub) + elif self is TestCase.STATUS_CODE_AND_MESSAGE: + _status_code_and_message(stub) + elif self is TestCase.UNIMPLEMENTED_METHOD: + _unimplemented_method(stub) + elif self is TestCase.UNIMPLEMENTED_SERVICE: + _unimplemented_service(stub) + elif self is TestCase.CUSTOM_METADATA: + _custom_metadata(stub) + elif self is TestCase.COMPUTE_ENGINE_CREDS: + _compute_engine_creds(stub, args) + elif self is TestCase.OAUTH2_AUTH_TOKEN: + _oauth2_auth_token(stub, args) + elif self is TestCase.JWT_TOKEN_CREDS: + _jwt_token_creds(stub, args) + elif self is TestCase.PER_RPC_CREDS: + _per_rpc_creds(stub, args) + elif self is TestCase.SPECIAL_STATUS_MESSAGE: + _special_status_message(stub, args) + else: + raise NotImplementedError('Test case "%s" not implemented!' % + self.name) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/resources.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/resources.py new file mode 100644 index 00000000000..a55919a60ae --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/resources.py @@ -0,0 +1,42 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constants and functions for data used in interoperability testing.""" + +import argparse +import pkgutil +import os + +_ROOT_CERTIFICATES_RESOURCE_PATH = 'credentials/ca.pem' +_PRIVATE_KEY_RESOURCE_PATH = 'credentials/server1.key' +_CERTIFICATE_CHAIN_RESOURCE_PATH = 'credentials/server1.pem' + + +def test_root_certificates(): + return pkgutil.get_data(__name__, _ROOT_CERTIFICATES_RESOURCE_PATH) + + +def private_key(): + return pkgutil.get_data(__name__, _PRIVATE_KEY_RESOURCE_PATH) + + +def certificate_chain(): + return pkgutil.get_data(__name__, _CERTIFICATE_CHAIN_RESOURCE_PATH) + + +def parse_bool(value): + if value == 'true': + return True + if value == 'false': + return False + raise argparse.ArgumentTypeError('Only true/false allowed') diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/server.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/server.py new file mode 100644 index 00000000000..c85adb0b0bb --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/server.py @@ -0,0 +1,76 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the GRPC interoperability test server.""" + +import argparse +from concurrent import futures +import logging + +import grpc +from src.proto.grpc.testing import test_pb2_grpc + +from tests.interop import service +from tests.interop import resources +from tests.unit import test_common + +logging.basicConfig() +_LOGGER = logging.getLogger(__name__) + + +def parse_interop_server_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--port', + type=int, + required=True, + help='the port on which to serve') + parser.add_argument('--use_tls', + default=False, + type=resources.parse_bool, + help='require a secure connection') + parser.add_argument('--use_alts', + default=False, + type=resources.parse_bool, + help='require an ALTS connection') + return parser.parse_args() + + +def get_server_credentials(use_tls): + if use_tls: + private_key = resources.private_key() + certificate_chain = resources.certificate_chain() + return grpc.ssl_server_credentials(((private_key, certificate_chain),)) + else: + return grpc.alts_server_credentials() + + +def serve(): + args = parse_interop_server_arguments() + + server = test_common.test_server() + test_pb2_grpc.add_TestServiceServicer_to_server(service.TestService(), + server) + if args.use_tls or args.use_alts: + credentials = get_server_credentials(args.use_tls) + server.add_secure_port('[::]:{}'.format(args.port), credentials) + else: + server.add_insecure_port('[::]:{}'.format(args.port)) + + server.start() + _LOGGER.info('Server serving.') + server.wait_for_termination() + _LOGGER.info('Server stopped; exiting.') + + +if __name__ == '__main__': + serve() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/service.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/service.py new file mode 100644 index 00000000000..08bb0c45a24 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/interop/service.py @@ -0,0 +1,96 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python implementation of the TestServicer.""" + +import time + +import grpc + +from src.proto.grpc.testing import empty_pb2 +from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import test_pb2_grpc + +_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial" +_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin" +_US_IN_A_SECOND = 1000 * 1000 + + +def _maybe_echo_metadata(servicer_context): + """Copies metadata from request to response if it is present.""" + invocation_metadata = dict(servicer_context.invocation_metadata()) + if _INITIAL_METADATA_KEY in invocation_metadata: + initial_metadatum = (_INITIAL_METADATA_KEY, + invocation_metadata[_INITIAL_METADATA_KEY]) + servicer_context.send_initial_metadata((initial_metadatum,)) + if _TRAILING_METADATA_KEY in invocation_metadata: + trailing_metadatum = (_TRAILING_METADATA_KEY, + invocation_metadata[_TRAILING_METADATA_KEY]) + servicer_context.set_trailing_metadata((trailing_metadatum,)) + + +def _maybe_echo_status_and_message(request, servicer_context): + """Sets the response context code and details if the request asks for them""" + if request.HasField('response_status'): + servicer_context.set_code(request.response_status.code) + servicer_context.set_details(request.response_status.message) + + +class TestService(test_pb2_grpc.TestServiceServicer): + + def EmptyCall(self, request, context): + _maybe_echo_metadata(context) + return empty_pb2.Empty() + + def UnaryCall(self, request, context): + _maybe_echo_metadata(context) + _maybe_echo_status_and_message(request, context) + return messages_pb2.SimpleResponse( + payload=messages_pb2.Payload(type=messages_pb2.COMPRESSABLE, + body=b'\x00' * request.response_size)) + + def StreamingOutputCall(self, request, context): + _maybe_echo_status_and_message(request, context) + for response_parameters in request.response_parameters: + if response_parameters.interval_us != 0: + time.sleep(response_parameters.interval_us / _US_IN_A_SECOND) + yield messages_pb2.StreamingOutputCallResponse( + payload=messages_pb2.Payload(type=request.response_type, + body=b'\x00' * + response_parameters.size)) + + def StreamingInputCall(self, request_iterator, context): + aggregate_size = 0 + for request in request_iterator: + if request.payload is not None and request.payload.body: + aggregate_size += len(request.payload.body) + return messages_pb2.StreamingInputCallResponse( + aggregated_payload_size=aggregate_size) + + def FullDuplexCall(self, request_iterator, context): + _maybe_echo_metadata(context) + for request in request_iterator: + _maybe_echo_status_and_message(request, context) + for response_parameters in request.response_parameters: + if response_parameters.interval_us != 0: + time.sleep(response_parameters.interval_us / + _US_IN_A_SECOND) + yield messages_pb2.StreamingOutputCallResponse( + payload=messages_pb2.Payload(type=request.payload.type, + body=b'\x00' * + response_parameters.size)) + + # NOTE(nathaniel): Apparently this is the same as the full-duplex call? + # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)... + def HalfDuplexCall(self, request_iterator, context): + return self.FullDuplexCall(request_iterator, context) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/__init__.py new file mode 100644 index 00000000000..5772620b602 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/benchmark_client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/benchmark_client.py new file mode 100644 index 00000000000..17835e7c0db --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/benchmark_client.py @@ -0,0 +1,202 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC).""" + +import abc +import threading +import time + +from concurrent import futures +from six.moves import queue + +import grpc +from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import benchmark_service_pb2_grpc +from tests.unit import resources +from tests.unit import test_common + +_TIMEOUT = 60 * 60 * 24 + + +class GenericStub(object): + + def __init__(self, channel): + self.UnaryCall = channel.unary_unary( + '/grpc.testing.BenchmarkService/UnaryCall') + self.StreamingCall = channel.stream_stream( + '/grpc.testing.BenchmarkService/StreamingCall') + + +class BenchmarkClient: + """Benchmark client interface that exposes a non-blocking send_request().""" + + __metaclass__ = abc.ABCMeta + + def __init__(self, server, config, hist): + # Create the stub + if config.HasField('security_params'): + creds = grpc.ssl_channel_credentials( + resources.test_root_certificates()) + channel = test_common.test_secure_channel( + server, creds, config.security_params.server_host_override) + else: + channel = grpc.insecure_channel(server) + + # waits for the channel to be ready before we start sending messages + grpc.channel_ready_future(channel).result() + + if config.payload_config.WhichOneof('payload') == 'simple_params': + self._generic = False + self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub( + channel) + payload = messages_pb2.Payload( + body=bytes(b'\0' * + config.payload_config.simple_params.req_size)) + self._request = messages_pb2.SimpleRequest( + payload=payload, + response_size=config.payload_config.simple_params.resp_size) + else: + self._generic = True + self._stub = GenericStub(channel) + self._request = bytes(b'\0' * + config.payload_config.bytebuf_params.req_size) + + self._hist = hist + self._response_callbacks = [] + + def add_response_callback(self, callback): + """callback will be invoked as callback(client, query_time)""" + self._response_callbacks.append(callback) + + @abc.abstractmethod + def send_request(self): + """Non-blocking wrapper for a client's request operation.""" + raise NotImplementedError() + + def start(self): + pass + + def stop(self): + pass + + def _handle_response(self, client, query_time): + self._hist.add(query_time * 1e9) # Report times in nanoseconds + for callback in self._response_callbacks: + callback(client, query_time) + + +class UnarySyncBenchmarkClient(BenchmarkClient): + + def __init__(self, server, config, hist): + super(UnarySyncBenchmarkClient, self).__init__(server, config, hist) + self._pool = futures.ThreadPoolExecutor( + max_workers=config.outstanding_rpcs_per_channel) + + def send_request(self): + # Send requests in separate threads to support multiple outstanding rpcs + # (See src/proto/grpc/testing/control.proto) + self._pool.submit(self._dispatch_request) + + def stop(self): + self._pool.shutdown(wait=True) + self._stub = None + + def _dispatch_request(self): + start_time = time.time() + self._stub.UnaryCall(self._request, _TIMEOUT) + end_time = time.time() + self._handle_response(self, end_time - start_time) + + +class UnaryAsyncBenchmarkClient(BenchmarkClient): + + def send_request(self): + # Use the Future callback api to support multiple outstanding rpcs + start_time = time.time() + response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT) + response_future.add_done_callback( + lambda resp: self._response_received(start_time, resp)) + + def _response_received(self, start_time, resp): + resp.result() + end_time = time.time() + self._handle_response(self, end_time - start_time) + + def stop(self): + self._stub = None + + +class _SyncStream(object): + + def __init__(self, stub, generic, request, handle_response): + self._stub = stub + self._generic = generic + self._request = request + self._handle_response = handle_response + self._is_streaming = False + self._request_queue = queue.Queue() + self._send_time_queue = queue.Queue() + + def send_request(self): + self._send_time_queue.put(time.time()) + self._request_queue.put(self._request) + + def start(self): + self._is_streaming = True + response_stream = self._stub.StreamingCall(self._request_generator(), + _TIMEOUT) + for _ in response_stream: + self._handle_response( + self, + time.time() - self._send_time_queue.get_nowait()) + + def stop(self): + self._is_streaming = False + + def _request_generator(self): + while self._is_streaming: + try: + request = self._request_queue.get(block=True, timeout=1.0) + yield request + except queue.Empty: + pass + + +class StreamingSyncBenchmarkClient(BenchmarkClient): + + def __init__(self, server, config, hist): + super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist) + self._pool = futures.ThreadPoolExecutor( + max_workers=config.outstanding_rpcs_per_channel) + self._streams = [ + _SyncStream(self._stub, self._generic, self._request, + self._handle_response) + for _ in range(config.outstanding_rpcs_per_channel) + ] + self._curr_stream = 0 + + def send_request(self): + # Use a round_robin scheduler to determine what stream to send on + self._streams[self._curr_stream].send_request() + self._curr_stream = (self._curr_stream + 1) % len(self._streams) + + def start(self): + for stream in self._streams: + self._pool.submit(stream.start) + + def stop(self): + for stream in self._streams: + stream.stop() + self._pool.shutdown(wait=True) + self._stub = None diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/benchmark_server.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/benchmark_server.py new file mode 100644 index 00000000000..75280bd7719 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/benchmark_server.py @@ -0,0 +1,44 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import benchmark_service_pb2_grpc + + +class BenchmarkServer(benchmark_service_pb2_grpc.BenchmarkServiceServicer): + """Synchronous Server implementation for the Benchmark service.""" + + def UnaryCall(self, request, context): + payload = messages_pb2.Payload(body=b'\0' * request.response_size) + return messages_pb2.SimpleResponse(payload=payload) + + def StreamingCall(self, request_iterator, context): + for request in request_iterator: + payload = messages_pb2.Payload(body=b'\0' * request.response_size) + yield messages_pb2.SimpleResponse(payload=payload) + + +class GenericBenchmarkServer(benchmark_service_pb2_grpc.BenchmarkServiceServicer + ): + """Generic Server implementation for the Benchmark service.""" + + def __init__(self, resp_size): + self._response = b'\0' * resp_size + + def UnaryCall(self, request, context): + return self._response + + def StreamingCall(self, request_iterator, context): + for request in request_iterator: + yield self._response diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/client_runner.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/client_runner.py new file mode 100644 index 00000000000..c5d299f6463 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/client_runner.py @@ -0,0 +1,90 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines behavior for WHEN clients send requests. + +Each client exposes a non-blocking send_request() method that the +ClientRunner invokes either periodically or in response to some event. +""" + +import abc +import threading +import time + + +class ClientRunner: + """Abstract interface for sending requests from clients.""" + + __metaclass__ = abc.ABCMeta + + def __init__(self, client): + self._client = client + + @abc.abstractmethod + def start(self): + raise NotImplementedError() + + @abc.abstractmethod + def stop(self): + raise NotImplementedError() + + +class OpenLoopClientRunner(ClientRunner): + + def __init__(self, client, interval_generator): + super(OpenLoopClientRunner, self).__init__(client) + self._is_running = False + self._interval_generator = interval_generator + self._dispatch_thread = threading.Thread(target=self._dispatch_requests, + args=()) + + def start(self): + self._is_running = True + self._client.start() + self._dispatch_thread.start() + + def stop(self): + self._is_running = False + self._client.stop() + self._dispatch_thread.join() + self._client = None + + def _dispatch_requests(self): + while self._is_running: + self._client.send_request() + time.sleep(next(self._interval_generator)) + + +class ClosedLoopClientRunner(ClientRunner): + + def __init__(self, client, request_count): + super(ClosedLoopClientRunner, self).__init__(client) + self._is_running = False + self._request_count = request_count + # Send a new request on each response for closed loop + self._client.add_response_callback(self._send_request) + + def start(self): + self._is_running = True + self._client.start() + for _ in range(self._request_count): + self._client.send_request() + + def stop(self): + self._is_running = False + self._client.stop() + self._client = None + + def _send_request(self, client, response_time): + if self._is_running: + client.send_request() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/histogram.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/histogram.py new file mode 100644 index 00000000000..8139a6ee2fb --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/histogram.py @@ -0,0 +1,80 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import threading + +from src.proto.grpc.testing import stats_pb2 + + +class Histogram(object): + """Histogram class used for recording performance testing data. + + This class is thread safe. + """ + + def __init__(self, resolution, max_possible): + self._lock = threading.Lock() + self._resolution = resolution + self._max_possible = max_possible + self._sum = 0 + self._sum_of_squares = 0 + self.multiplier = 1.0 + self._resolution + self._count = 0 + self._min = self._max_possible + self._max = 0 + self._buckets = [0] * (self._bucket_for(self._max_possible) + 1) + + def reset(self): + with self._lock: + self._sum = 0 + self._sum_of_squares = 0 + self._count = 0 + self._min = self._max_possible + self._max = 0 + self._buckets = [0] * (self._bucket_for(self._max_possible) + 1) + + def add(self, val): + with self._lock: + self._sum += val + self._sum_of_squares += val * val + self._count += 1 + self._min = min(self._min, val) + self._max = max(self._max, val) + self._buckets[self._bucket_for(val)] += 1 + + def get_data(self): + with self._lock: + data = stats_pb2.HistogramData() + data.bucket.extend(self._buckets) + data.min_seen = self._min + data.max_seen = self._max + data.sum = self._sum + data.sum_of_squares = self._sum_of_squares + data.count = self._count + return data + + def merge(self, another_data): + with self._lock: + for i in range(len(self._buckets)): + self._buckets[i] += another_data.bucket[i] + self._min = min(self._min, another_data.min_seen) + self._max = max(self._max, another_data.max_seen) + self._sum += another_data.sum + self._sum_of_squares += another_data.sum_of_squares + self._count += another_data.count + + def _bucket_for(self, val): + val = min(val, self._max_possible) + return int(math.log(val, self.multiplier)) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/qps_worker.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/qps_worker.py new file mode 100644 index 00000000000..a7e692821ac --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/qps_worker.py @@ -0,0 +1,46 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The entry point for the qps worker.""" + +import argparse +import time + +import grpc +from src.proto.grpc.testing import worker_service_pb2_grpc + +from tests.qps import worker_server +from tests.unit import test_common + + +def run_worker_server(port): + server = test_common.test_server() + servicer = worker_server.WorkerServer() + worker_service_pb2_grpc.add_WorkerServiceServicer_to_server( + servicer, server) + server.add_insecure_port('[::]:{}'.format(port)) + server.start() + servicer.wait_for_quit() + server.stop(0) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='gRPC Python performance testing worker') + parser.add_argument('--driver_port', + type=int, + dest='port', + help='The port the worker should listen on') + args = parser.parse_args() + + run_worker_server(args.port) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/worker_server.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/worker_server.py new file mode 100644 index 00000000000..65b081e5d1c --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/qps/worker_server.py @@ -0,0 +1,186 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import random +import threading +import time + +from concurrent import futures +import grpc +from src.proto.grpc.testing import control_pb2 +from src.proto.grpc.testing import benchmark_service_pb2_grpc +from src.proto.grpc.testing import worker_service_pb2_grpc +from src.proto.grpc.testing import stats_pb2 + +from tests.qps import benchmark_client +from tests.qps import benchmark_server +from tests.qps import client_runner +from tests.qps import histogram +from tests.unit import resources +from tests.unit import test_common + + +class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer): + """Python Worker Server implementation.""" + + def __init__(self): + self._quit_event = threading.Event() + + def RunServer(self, request_iterator, context): + config = next(request_iterator).setup #pylint: disable=stop-iteration-return + server, port = self._create_server(config) + cores = multiprocessing.cpu_count() + server.start() + start_time = time.time() + yield self._get_server_status(start_time, start_time, port, cores) + + for request in request_iterator: + end_time = time.time() + status = self._get_server_status(start_time, end_time, port, cores) + if request.mark.reset: + start_time = end_time + yield status + server.stop(None) + + def _get_server_status(self, start_time, end_time, port, cores): + end_time = time.time() + elapsed_time = end_time - start_time + stats = stats_pb2.ServerStats(time_elapsed=elapsed_time, + time_user=elapsed_time, + time_system=elapsed_time) + return control_pb2.ServerStatus(stats=stats, port=port, cores=cores) + + def _create_server(self, config): + if config.async_server_threads == 0: + # This is the default concurrent.futures thread pool size, but + # None doesn't seem to work + server_threads = multiprocessing.cpu_count() * 5 + else: + server_threads = config.async_server_threads + server = test_common.test_server(max_workers=server_threads) + if config.server_type == control_pb2.ASYNC_SERVER: + servicer = benchmark_server.BenchmarkServer() + benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server( + servicer, server) + elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER: + resp_size = config.payload_config.bytebuf_params.resp_size + servicer = benchmark_server.GenericBenchmarkServer(resp_size) + method_implementations = { + 'StreamingCall': + grpc.stream_stream_rpc_method_handler(servicer.StreamingCall + ), + 'UnaryCall': + grpc.unary_unary_rpc_method_handler(servicer.UnaryCall), + } + handler = grpc.method_handlers_generic_handler( + 'grpc.testing.BenchmarkService', method_implementations) + server.add_generic_rpc_handlers((handler,)) + else: + raise Exception('Unsupported server type {}'.format( + config.server_type)) + + if config.HasField('security_params'): # Use SSL + server_creds = grpc.ssl_server_credentials( + ((resources.private_key(), resources.certificate_chain()),)) + port = server.add_secure_port('[::]:{}'.format(config.port), + server_creds) + else: + port = server.add_insecure_port('[::]:{}'.format(config.port)) + + return (server, port) + + def RunClient(self, request_iterator, context): + config = next(request_iterator).setup #pylint: disable=stop-iteration-return + client_runners = [] + qps_data = histogram.Histogram(config.histogram_params.resolution, + config.histogram_params.max_possible) + start_time = time.time() + + # Create a client for each channel + for i in range(config.client_channels): + server = config.server_targets[i % len(config.server_targets)] + runner = self._create_client_runner(server, config, qps_data) + client_runners.append(runner) + runner.start() + + end_time = time.time() + yield self._get_client_status(start_time, end_time, qps_data) + + # Respond to stat requests + for request in request_iterator: + end_time = time.time() + status = self._get_client_status(start_time, end_time, qps_data) + if request.mark.reset: + qps_data.reset() + start_time = time.time() + yield status + + # Cleanup the clients + for runner in client_runners: + runner.stop() + + def _get_client_status(self, start_time, end_time, qps_data): + latencies = qps_data.get_data() + end_time = time.time() + elapsed_time = end_time - start_time + stats = stats_pb2.ClientStats(latencies=latencies, + time_elapsed=elapsed_time, + time_user=elapsed_time, + time_system=elapsed_time) + return control_pb2.ClientStatus(stats=stats) + + def _create_client_runner(self, server, config, qps_data): + if config.client_type == control_pb2.SYNC_CLIENT: + if config.rpc_type == control_pb2.UNARY: + client = benchmark_client.UnarySyncBenchmarkClient( + server, config, qps_data) + elif config.rpc_type == control_pb2.STREAMING: + client = benchmark_client.StreamingSyncBenchmarkClient( + server, config, qps_data) + elif config.client_type == control_pb2.ASYNC_CLIENT: + if config.rpc_type == control_pb2.UNARY: + client = benchmark_client.UnaryAsyncBenchmarkClient( + server, config, qps_data) + else: + raise Exception('Async streaming client not supported') + else: + raise Exception('Unsupported client type {}'.format( + config.client_type)) + + # In multi-channel tests, we split the load across all channels + load_factor = float(config.client_channels) + if config.load_params.WhichOneof('load') == 'closed_loop': + runner = client_runner.ClosedLoopClientRunner( + client, config.outstanding_rpcs_per_channel) + else: # Open loop Poisson + alpha = config.load_params.poisson.offered_load / load_factor + + def poisson(): + while True: + yield random.expovariate(alpha) + + runner = client_runner.OpenLoopClientRunner(client, poisson()) + + return runner + + def CoreCount(self, request, context): + return control_pb2.CoreResponse(cores=multiprocessing.cpu_count()) + + def QuitWorker(self, request, context): + self._quit_event.set() + return control_pb2.Void() + + def wait_for_quit(self): + self._quit_event.wait() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/reflection/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/reflection/__init__.py new file mode 100644 index 00000000000..5772620b602 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/reflection/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py new file mode 100644 index 00000000000..169e55022da --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py @@ -0,0 +1,195 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_reflection.v1alpha.reflection.""" + +import unittest + +import grpc + +from grpc_reflection.v1alpha import reflection +from grpc_reflection.v1alpha import reflection_pb2 +from grpc_reflection.v1alpha import reflection_pb2_grpc + +from google.protobuf import descriptor_pool +from google.protobuf import descriptor_pb2 + +from src.proto.grpc.testing import empty_pb2 +from src.proto.grpc.testing.proto2 import empty2_extensions_pb2 + +from tests.unit import test_common + +_EMPTY_PROTO_FILE_NAME = 'src/proto/grpc/testing/empty.proto' +_EMPTY_PROTO_SYMBOL_NAME = 'grpc.testing.Empty' +_SERVICE_NAMES = ('Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman', + 'Galilei') +_EMPTY_EXTENSIONS_SYMBOL_NAME = 'grpc.testing.proto2.EmptyWithExtensions' +_EMPTY_EXTENSIONS_NUMBERS = ( + 124, + 125, + 126, + 127, + 128, +) + + +def _file_descriptor_to_proto(descriptor): + proto = descriptor_pb2.FileDescriptorProto() + descriptor.CopyToProto(proto) + return proto.SerializeToString() + + +class ReflectionServicerTest(unittest.TestCase): + + # TODO(https://github.com/grpc/grpc/issues/17844) + # Bazel + Python 3 will result in creating two different instance of + # DESCRIPTOR for each message. So, the equal comparison between protobuf + # returned by stub and manually crafted protobuf will always fail. + def _assert_sequence_of_proto_equal(self, x, y): + self.assertSequenceEqual( + tuple(proto.SerializeToString() for proto in x), + tuple(proto.SerializeToString() for proto in y), + ) + + def setUp(self): + self._server = test_common.test_server() + reflection.enable_server_reflection(_SERVICE_NAMES, self._server) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def testFileByName(self): + requests = ( + reflection_pb2.ServerReflectionRequest( + file_by_filename=_EMPTY_PROTO_FILE_NAME), + reflection_pb2.ServerReflectionRequest( + file_by_filename='i-donut-exist'), + ) + responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host='', + file_descriptor_response=reflection_pb2.FileDescriptorResponse( + file_descriptor_proto=( + _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))), + reflection_pb2.ServerReflectionResponse( + valid_host='', + error_response=reflection_pb2.ErrorResponse( + error_code=grpc.StatusCode.NOT_FOUND.value[0], + error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), + )), + ) + self._assert_sequence_of_proto_equal(expected_responses, responses) + + def testFileBySymbol(self): + requests = ( + reflection_pb2.ServerReflectionRequest( + file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME), + reflection_pb2.ServerReflectionRequest( + file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo' + ), + ) + responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host='', + file_descriptor_response=reflection_pb2.FileDescriptorResponse( + file_descriptor_proto=( + _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))), + reflection_pb2.ServerReflectionResponse( + valid_host='', + error_response=reflection_pb2.ErrorResponse( + error_code=grpc.StatusCode.NOT_FOUND.value[0], + error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), + )), + ) + self._assert_sequence_of_proto_equal(expected_responses, responses) + + def testFileContainingExtension(self): + requests = ( + reflection_pb2.ServerReflectionRequest( + file_containing_extension=reflection_pb2.ExtensionRequest( + containing_type=_EMPTY_EXTENSIONS_SYMBOL_NAME, + extension_number=125, + ),), + reflection_pb2.ServerReflectionRequest( + file_containing_extension=reflection_pb2.ExtensionRequest( + containing_type='i.donut.exist.co.uk.org.net.me.name.foo', + extension_number=55, + ),), + ) + responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host='', + file_descriptor_response=reflection_pb2.FileDescriptorResponse( + file_descriptor_proto=(_file_descriptor_to_proto( + empty2_extensions_pb2.DESCRIPTOR),))), + reflection_pb2.ServerReflectionResponse( + valid_host='', + error_response=reflection_pb2.ErrorResponse( + error_code=grpc.StatusCode.NOT_FOUND.value[0], + error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), + )), + ) + self._assert_sequence_of_proto_equal(expected_responses, responses) + + def testExtensionNumbersOfType(self): + requests = ( + reflection_pb2.ServerReflectionRequest( + all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME), + reflection_pb2.ServerReflectionRequest( + all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo' + ), + ) + responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host='', + all_extension_numbers_response=reflection_pb2. + ExtensionNumberResponse( + base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME, + extension_number=_EMPTY_EXTENSIONS_NUMBERS)), + reflection_pb2.ServerReflectionResponse( + valid_host='', + error_response=reflection_pb2.ErrorResponse( + error_code=grpc.StatusCode.NOT_FOUND.value[0], + error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), + )), + ) + self._assert_sequence_of_proto_equal(expected_responses, responses) + + def testListServices(self): + requests = (reflection_pb2.ServerReflectionRequest(list_services='',),) + responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) + expected_responses = (reflection_pb2.ServerReflectionResponse( + valid_host='', + list_services_response=reflection_pb2.ListServiceResponse( + service=tuple( + reflection_pb2.ServiceResponse(name=name) + for name in _SERVICE_NAMES))),) + self._assert_sequence_of_proto_equal(expected_responses, responses) + + def testReflectionServiceName(self): + self.assertEqual(reflection.SERVICE_NAME, + 'grpc.reflection.v1alpha.ServerReflection') + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/status/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/status/__init__.py new file mode 100644 index 00000000000..38fdfc9c5cf --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/status/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/status/_grpc_status_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/status/_grpc_status_test.py new file mode 100644 index 00000000000..54a3b624203 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/status/_grpc_status_test.py @@ -0,0 +1,180 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_status.""" + +# NOTE(lidiz) This module only exists in Bazel BUILD file, for more details +# please refer to comments in the "bazel_namespace_package_hack" module. +try: + from tests import bazel_namespace_package_hack + bazel_namespace_package_hack.sys_path_to_site_dir_hack() +except ImportError: + pass + +import unittest + +import logging +import traceback + +import grpc +from grpc_status import rpc_status + +from tests.unit import test_common + +from google.protobuf import any_pb2 +from google.rpc import code_pb2, status_pb2, error_details_pb2 + +_STATUS_OK = '/test/StatusOK' +_STATUS_NOT_OK = '/test/StatusNotOk' +_ERROR_DETAILS = '/test/ErrorDetails' +_INCONSISTENT = '/test/Inconsistent' +_INVALID_CODE = '/test/InvalidCode' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + +_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' + +_STATUS_DETAILS = 'This is an error detail' +_STATUS_DETAILS_ANOTHER = 'This is another error detail' + + +def _ok_unary_unary(request, servicer_context): + return _RESPONSE + + +def _not_ok_unary_unary(request, servicer_context): + servicer_context.abort(grpc.StatusCode.INTERNAL, _STATUS_DETAILS) + + +def _error_details_unary_unary(request, servicer_context): + details = any_pb2.Any() + details.Pack( + error_details_pb2.DebugInfo(stack_entries=traceback.format_stack(), + detail='Intentionally invoked')) + rich_status = status_pb2.Status( + code=code_pb2.INTERNAL, + message=_STATUS_DETAILS, + details=[details], + ) + servicer_context.abort_with_status(rpc_status.to_status(rich_status)) + + +def _inconsistent_unary_unary(request, servicer_context): + rich_status = status_pb2.Status( + code=code_pb2.INTERNAL, + message=_STATUS_DETAILS, + ) + servicer_context.set_code(grpc.StatusCode.NOT_FOUND) + servicer_context.set_details(_STATUS_DETAILS_ANOTHER) + # User put inconsistent status information in trailing metadata + servicer_context.set_trailing_metadata( + ((_GRPC_DETAILS_METADATA_KEY, rich_status.SerializeToString()),)) + + +def _invalid_code_unary_unary(request, servicer_context): + rich_status = status_pb2.Status( + code=42, + message='Invalid code', + ) + servicer_context.abort_with_status(rpc_status.to_status(rich_status)) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _STATUS_OK: + return grpc.unary_unary_rpc_method_handler(_ok_unary_unary) + elif handler_call_details.method == _STATUS_NOT_OK: + return grpc.unary_unary_rpc_method_handler(_not_ok_unary_unary) + elif handler_call_details.method == _ERROR_DETAILS: + return grpc.unary_unary_rpc_method_handler( + _error_details_unary_unary) + elif handler_call_details.method == _INCONSISTENT: + return grpc.unary_unary_rpc_method_handler( + _inconsistent_unary_unary) + elif handler_call_details.method == _INVALID_CODE: + return grpc.unary_unary_rpc_method_handler( + _invalid_code_unary_unary) + else: + return None + + +class StatusTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers((_GenericHandler(),)) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def test_status_ok(self): + _, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST) + + # Succeed RPC doesn't have status + status = rpc_status.from_call(call) + self.assertIs(status, None) + + def test_status_not_ok(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + # Failed RPC doesn't automatically generate status + status = rpc_status.from_call(rpc_error) + self.assertIs(status, None) + + def test_error_details(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST) + rpc_error = exception_context.exception + + status = rpc_status.from_call(rpc_error) + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(status.code, code_pb2.Code.Value('INTERNAL')) + + # Check if the underlying proto message is intact + self.assertEqual( + status.details[0].Is(error_details_pb2.DebugInfo.DESCRIPTOR), True) + info = error_details_pb2.DebugInfo() + status.details[0].Unpack(info) + self.assertIn('_error_details_unary_unary', info.stack_entries[-1]) + + def test_code_message_validation(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST) + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND) + + # Code/Message validation failed + self.assertRaises(ValueError, rpc_status.from_call, rpc_error) + + def test_invalid_code(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST) + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) + # Invalid status code exception raised during coversion + self.assertIn('Invalid status code', rpc_error.details()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/__init__.py new file mode 100644 index 00000000000..5772620b602 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/client.py new file mode 100644 index 00000000000..01c14ba3e20 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/client.py @@ -0,0 +1,159 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Entry point for running stress tests.""" + +import argparse +from concurrent import futures +import threading + +import grpc +from six.moves import queue +from src.proto.grpc.testing import metrics_pb2_grpc +from src.proto.grpc.testing import test_pb2_grpc + +from tests.interop import methods +from tests.interop import resources +from tests.qps import histogram +from tests.stress import metrics_server +from tests.stress import test_runner + + +def _args(): + parser = argparse.ArgumentParser( + description='gRPC Python stress test client') + parser.add_argument( + '--server_addresses', + help='comma separated list of hostname:port to run servers on', + default='localhost:8080', + type=str) + parser.add_argument( + '--test_cases', + help='comma separated list of testcase:weighting of tests to run', + default='large_unary:100', + type=str) + parser.add_argument('--test_duration_secs', + help='number of seconds to run the stress test', + default=-1, + type=int) + parser.add_argument('--num_channels_per_server', + help='number of channels per server', + default=1, + type=int) + parser.add_argument('--num_stubs_per_channel', + help='number of stubs to create per channel', + default=1, + type=int) + parser.add_argument('--metrics_port', + help='the port to listen for metrics requests on', + default=8081, + type=int) + parser.add_argument( + '--use_test_ca', + help='Whether to use our fake CA. Requires --use_tls=true', + default=False, + type=bool) + parser.add_argument('--use_tls', + help='Whether to use TLS', + default=False, + type=bool) + parser.add_argument('--server_host_override', + help='the server host to which to claim to connect', + type=str) + return parser.parse_args() + + +def _test_case_from_arg(test_case_arg): + for test_case in methods.TestCase: + if test_case_arg == test_case.value: + return test_case + else: + raise ValueError('No test case {}!'.format(test_case_arg)) + + +def _parse_weighted_test_cases(test_case_args): + weighted_test_cases = {} + for test_case_arg in test_case_args.split(','): + name, weight = test_case_arg.split(':', 1) + test_case = _test_case_from_arg(name) + weighted_test_cases[test_case] = int(weight) + return weighted_test_cases + + +def _get_channel(target, args): + if args.use_tls: + if args.use_test_ca: + root_certificates = resources.test_root_certificates() + else: + root_certificates = None # will load default roots. + channel_credentials = grpc.ssl_channel_credentials( + root_certificates=root_certificates) + options = (( + 'grpc.ssl_target_name_override', + args.server_host_override, + ),) + channel = grpc.secure_channel(target, + channel_credentials, + options=options) + else: + channel = grpc.insecure_channel(target) + + # waits for the channel to be ready before we start sending messages + grpc.channel_ready_future(channel).result() + return channel + + +def run_test(args): + test_cases = _parse_weighted_test_cases(args.test_cases) + test_server_targets = args.server_addresses.split(',') + # Propagate any client exceptions with a queue + exception_queue = queue.Queue() + stop_event = threading.Event() + hist = histogram.Histogram(1, 1) + runners = [] + + server = grpc.server(futures.ThreadPoolExecutor(max_workers=25)) + metrics_pb2_grpc.add_MetricsServiceServicer_to_server( + metrics_server.MetricsServer(hist), server) + server.add_insecure_port('[::]:{}'.format(args.metrics_port)) + server.start() + + for test_server_target in test_server_targets: + for _ in range(args.num_channels_per_server): + channel = _get_channel(test_server_target, args) + for _ in range(args.num_stubs_per_channel): + stub = test_pb2_grpc.TestServiceStub(channel) + runner = test_runner.TestRunner(stub, test_cases, hist, + exception_queue, stop_event) + runners.append(runner) + + for runner in runners: + runner.start() + try: + timeout_secs = args.test_duration_secs + if timeout_secs < 0: + timeout_secs = None + raise exception_queue.get(block=True, timeout=timeout_secs) + except queue.Empty: + # No exceptions thrown, success + pass + finally: + stop_event.set() + for runner in runners: + runner.join() + runner = None + server.stop(None) + + +if __name__ == '__main__': + run_test(_args()) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/metrics_server.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/metrics_server.py new file mode 100644 index 00000000000..33a74b4a388 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/metrics_server.py @@ -0,0 +1,45 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MetricsService for publishing stress test qps data.""" + +import time + +from src.proto.grpc.testing import metrics_pb2 +from src.proto.grpc.testing import metrics_pb2_grpc + +GAUGE_NAME = 'python_overall_qps' + + +class MetricsServer(metrics_pb2_grpc.MetricsServiceServicer): + + def __init__(self, histogram): + self._start_time = time.time() + self._histogram = histogram + + def _get_qps(self): + count = self._histogram.get_data().count + delta = time.time() - self._start_time + self._histogram.reset() + self._start_time = time.time() + return int(count / delta) + + def GetAllGauges(self, request, context): + qps = self._get_qps() + return [metrics_pb2.GaugeResponse(name=GAUGE_NAME, long_value=qps)] + + def GetGauge(self, request, context): + if request.name != GAUGE_NAME: + raise Exception('Gauge {} does not exist'.format(request.name)) + qps = self._get_qps() + return metrics_pb2.GaugeResponse(name=GAUGE_NAME, long_value=qps) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/test_runner.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/test_runner.py new file mode 100644 index 00000000000..1b6003fc698 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/test_runner.py @@ -0,0 +1,58 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Thread that sends random weighted requests on a TestService stub.""" + +import random +import threading +import time +import traceback + + +def _weighted_test_case_generator(weighted_cases): + weight_sum = sum(weighted_cases.itervalues()) + + while True: + val = random.uniform(0, weight_sum) + partial_sum = 0 + for case in weighted_cases: + partial_sum += weighted_cases[case] + if val <= partial_sum: + yield case + break + + +class TestRunner(threading.Thread): + + def __init__(self, stub, test_cases, hist, exception_queue, stop_event): + super(TestRunner, self).__init__() + self._exception_queue = exception_queue + self._stop_event = stop_event + self._stub = stub + self._test_cases = _weighted_test_case_generator(test_cases) + self._histogram = hist + + def run(self): + while not self._stop_event.is_set(): + try: + test_case = next(self._test_cases) + start_time = time.time() + test_case.test_interoperability(self._stub, None) + end_time = time.time() + self._histogram.add((end_time - start_time) * 1e9) + except Exception as e: # pylint: disable=broad-except + traceback.print_exc() + self._exception_queue.put( + Exception( + "An exception occurred during test {}".format( + test_case), e)) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/unary_stream_benchmark.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/unary_stream_benchmark.py new file mode 100644 index 00000000000..cd872ece29d --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/stress/unary_stream_benchmark.py @@ -0,0 +1,102 @@ +# Copyright 2019 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import threading +import grpc +import grpc.experimental +import subprocess +import sys +import time +import contextlib + +_PORT = 5741 +_MESSAGE_SIZE = 4 +_RESPONSE_COUNT = 32 * 1024 + +_SERVER_CODE = """ +import datetime +import threading +import grpc +from concurrent import futures +from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2 +from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2_grpc + +class Handler(unary_stream_benchmark_pb2_grpc.UnaryStreamBenchmarkServiceServicer): + + def Benchmark(self, request, context): + payload = b'\\x00\\x01' * int(request.message_size / 2) + for _ in range(request.response_count): + yield unary_stream_benchmark_pb2.BenchmarkResponse(response=payload) + + +server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) +server.add_insecure_port('[::]:%d') +unary_stream_benchmark_pb2_grpc.add_UnaryStreamBenchmarkServiceServicer_to_server(Handler(), server) +server.start() +server.wait_for_termination() +""" % _PORT + +try: + from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2 + from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2_grpc + + _GRPC_CHANNEL_OPTIONS = [ + ('grpc.max_metadata_size', 16 * 1024 * 1024), + ('grpc.max_receive_message_length', 64 * 1024 * 1024), + (grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1), + ] + + @contextlib.contextmanager + def _running_server(): + server_process = subprocess.Popen([sys.executable, '-c', _SERVER_CODE], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + try: + yield + finally: + server_process.terminate() + server_process.wait() + sys.stdout.write("stdout: {}".format(server_process.stdout.read())) + sys.stdout.flush() + sys.stdout.write("stderr: {}".format(server_process.stderr.read())) + sys.stdout.flush() + + def profile(message_size, response_count): + request = unary_stream_benchmark_pb2.BenchmarkRequest( + message_size=message_size, response_count=response_count) + with grpc.insecure_channel('[::]:{}'.format(_PORT), + options=_GRPC_CHANNEL_OPTIONS) as channel: + stub = unary_stream_benchmark_pb2_grpc.UnaryStreamBenchmarkServiceStub( + channel) + start = datetime.datetime.now() + call = stub.Benchmark(request, wait_for_ready=True) + for message in call: + pass + end = datetime.datetime.now() + return end - start + + def main(): + with _running_server(): + for i in range(1000): + latency = profile(_MESSAGE_SIZE, 1024) + sys.stdout.write("{}\n".format(latency.total_seconds())) + sys.stdout.flush() + + if __name__ == '__main__': + main() + +except ImportError: + # NOTE(rbellevi): The test runner should not load this module. + pass diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/__init__.py new file mode 100644 index 00000000000..1e120359cf9 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_application_common.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_application_common.py new file mode 100644 index 00000000000..3226d1fb020 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_application_common.py @@ -0,0 +1,43 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An example gRPC Python-using application's common code elements.""" + +from tests.testing.proto import requests_pb2 +from tests.testing.proto import services_pb2 + +SERVICE_NAME = 'tests_of_grpc_testing.FirstService' +UNARY_UNARY_METHOD_NAME = 'UnUn' +UNARY_STREAM_METHOD_NAME = 'UnStre' +STREAM_UNARY_METHOD_NAME = 'StreUn' +STREAM_STREAM_METHOD_NAME = 'StreStre' + +UNARY_UNARY_REQUEST = requests_pb2.Up(first_up_field=2) +ERRONEOUS_UNARY_UNARY_REQUEST = requests_pb2.Up(first_up_field=3) +UNARY_UNARY_RESPONSE = services_pb2.Down(first_down_field=5) +ERRONEOUS_UNARY_UNARY_RESPONSE = services_pb2.Down(first_down_field=7) +UNARY_STREAM_REQUEST = requests_pb2.Charm(first_charm_field=11) +STREAM_UNARY_REQUEST = requests_pb2.Charm(first_charm_field=13) +STREAM_UNARY_RESPONSE = services_pb2.Strange(first_strange_field=17) +STREAM_STREAM_REQUEST = requests_pb2.Top(first_top_field=19) +STREAM_STREAM_RESPONSE = services_pb2.Bottom(first_bottom_field=23) +TWO_STREAM_STREAM_RESPONSES = (STREAM_STREAM_RESPONSE,) * 2 +ABORT_REQUEST = requests_pb2.Up(first_up_field=42) +ABORT_SUCCESS_QUERY = requests_pb2.Up(first_up_field=43) +ABORT_NO_STATUS_RESPONSE = services_pb2.Down(first_down_field=50) +ABORT_SUCCESS_RESPONSE = services_pb2.Down(first_down_field=51) +ABORT_FAILURE_RESPONSE = services_pb2.Down(first_down_field=52) +STREAM_STREAM_MUTATING_REQUEST = requests_pb2.Top(first_top_field=24601) +STREAM_STREAM_MUTATING_COUNT = 2 + +INFINITE_REQUEST_STREAM_TIMEOUT = 0.2 diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_application_testing_common.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_application_testing_common.py new file mode 100644 index 00000000000..9c9e485a783 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_application_testing_common.py @@ -0,0 +1,33 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc_testing + +from tests.testing.proto import requests_pb2 +from tests.testing.proto import services_pb2 + +# TODO(https://github.com/grpc/grpc/issues/11657): Eliminate this entirely. +# TODO(https://github.com/google/protobuf/issues/3452): Eliminate this if/else. +if services_pb2.DESCRIPTOR.services_by_name.get('FirstService') is None: + FIRST_SERVICE = 'Fix protobuf issue 3452!' + FIRST_SERVICE_UNUN = 'Fix protobuf issue 3452!' + FIRST_SERVICE_UNSTRE = 'Fix protobuf issue 3452!' + FIRST_SERVICE_STREUN = 'Fix protobuf issue 3452!' + FIRST_SERVICE_STRESTRE = 'Fix protobuf issue 3452!' +else: + FIRST_SERVICE = services_pb2.DESCRIPTOR.services_by_name['FirstService'] + FIRST_SERVICE_UNUN = FIRST_SERVICE.methods_by_name['UnUn'] + FIRST_SERVICE_UNSTRE = FIRST_SERVICE.methods_by_name['UnStre'] + FIRST_SERVICE_STREUN = FIRST_SERVICE.methods_by_name['StreUn'] + FIRST_SERVICE_STRESTRE = FIRST_SERVICE.methods_by_name['StreStre'] diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_client_application.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_client_application.py new file mode 100644 index 00000000000..57fa5109139 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_client_application.py @@ -0,0 +1,236 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An example gRPC Python-using client-side application.""" + +import collections +import enum +import threading +import time + +import grpc +from tests.unit.framework.common import test_constants + +from tests.testing.proto import requests_pb2 +from tests.testing.proto import services_pb2 +from tests.testing.proto import services_pb2_grpc + +from tests.testing import _application_common + + +class Scenario(enum.Enum): + UNARY_UNARY = 'unary unary' + UNARY_STREAM = 'unary stream' + STREAM_UNARY = 'stream unary' + STREAM_STREAM = 'stream stream' + CONCURRENT_STREAM_UNARY = 'concurrent stream unary' + CONCURRENT_STREAM_STREAM = 'concurrent stream stream' + CANCEL_UNARY_UNARY = 'cancel unary unary' + CANCEL_UNARY_STREAM = 'cancel unary stream' + INFINITE_REQUEST_STREAM = 'infinite request stream' + + +class Outcome(collections.namedtuple('Outcome', ('kind', 'code', 'details'))): + """Outcome of a client application scenario. + + Attributes: + kind: A Kind value describing the overall kind of scenario execution. + code: A grpc.StatusCode value. Only valid if kind is Kind.RPC_ERROR. + details: A status details string. Only valid if kind is Kind.RPC_ERROR. + """ + + @enum.unique + class Kind(enum.Enum): + SATISFACTORY = 'satisfactory' + UNSATISFACTORY = 'unsatisfactory' + RPC_ERROR = 'rpc error' + + +_SATISFACTORY_OUTCOME = Outcome(Outcome.Kind.SATISFACTORY, None, None) +_UNSATISFACTORY_OUTCOME = Outcome(Outcome.Kind.UNSATISFACTORY, None, None) + + +class _Pipe(object): + + def __init__(self): + self._condition = threading.Condition() + self._values = [] + self._open = True + + def __iter__(self): + return self + + def _next(self): + with self._condition: + while True: + if self._values: + return self._values.pop(0) + elif not self._open: + raise StopIteration() + else: + self._condition.wait() + + def __next__(self): # (Python 3 Iterator Protocol) + return self._next() + + def next(self): # (Python 2 Iterator Protocol) + return self._next() + + def add(self, value): + with self._condition: + self._values.append(value) + self._condition.notify_all() + + def close(self): + with self._condition: + self._open = False + self._condition.notify_all() + + +def _run_unary_unary(stub): + response = stub.UnUn(_application_common.UNARY_UNARY_REQUEST) + if _application_common.UNARY_UNARY_RESPONSE == response: + return _SATISFACTORY_OUTCOME + else: + return _UNSATISFACTORY_OUTCOME + + +def _run_unary_stream(stub): + response_iterator = stub.UnStre(_application_common.UNARY_STREAM_REQUEST) + try: + next(response_iterator) + except StopIteration: + return _SATISFACTORY_OUTCOME + else: + return _UNSATISFACTORY_OUTCOME + + +def _run_stream_unary(stub): + response, call = stub.StreUn.with_call( + iter((_application_common.STREAM_UNARY_REQUEST,) * 3)) + if (_application_common.STREAM_UNARY_RESPONSE == response and + call.code() is grpc.StatusCode.OK): + return _SATISFACTORY_OUTCOME + else: + return _UNSATISFACTORY_OUTCOME + + +def _run_stream_stream(stub): + request_pipe = _Pipe() + response_iterator = stub.StreStre(iter(request_pipe)) + request_pipe.add(_application_common.STREAM_STREAM_REQUEST) + first_responses = next(response_iterator), next(response_iterator) + request_pipe.add(_application_common.STREAM_STREAM_REQUEST) + second_responses = next(response_iterator), next(response_iterator) + request_pipe.close() + try: + next(response_iterator) + except StopIteration: + unexpected_extra_response = False + else: + unexpected_extra_response = True + if (first_responses == _application_common.TWO_STREAM_STREAM_RESPONSES and + second_responses == _application_common.TWO_STREAM_STREAM_RESPONSES + and not unexpected_extra_response): + return _SATISFACTORY_OUTCOME + else: + return _UNSATISFACTORY_OUTCOME + + +def _run_concurrent_stream_unary(stub): + future_calls = tuple( + stub.StreUn.future(iter((_application_common.STREAM_UNARY_REQUEST,) * + 3)) + for _ in range(test_constants.THREAD_CONCURRENCY)) + for future_call in future_calls: + if future_call.code() is grpc.StatusCode.OK: + response = future_call.result() + if _application_common.STREAM_UNARY_RESPONSE != response: + return _UNSATISFACTORY_OUTCOME + else: + return _UNSATISFACTORY_OUTCOME + else: + return _SATISFACTORY_OUTCOME + + +def _run_concurrent_stream_stream(stub): + condition = threading.Condition() + outcomes = [None] * test_constants.RPC_CONCURRENCY + + def run_stream_stream(index): + outcome = _run_stream_stream(stub) + with condition: + outcomes[index] = outcome + condition.notify() + + for index in range(test_constants.RPC_CONCURRENCY): + thread = threading.Thread(target=run_stream_stream, args=(index,)) + thread.start() + with condition: + while True: + if all(outcomes): + for outcome in outcomes: + if outcome.kind is not Outcome.Kind.SATISFACTORY: + return _UNSATISFACTORY_OUTCOME + else: + return _SATISFACTORY_OUTCOME + else: + condition.wait() + + +def _run_cancel_unary_unary(stub): + response_future_call = stub.UnUn.future( + _application_common.UNARY_UNARY_REQUEST) + initial_metadata = response_future_call.initial_metadata() + cancelled = response_future_call.cancel() + if initial_metadata is not None and cancelled: + return _SATISFACTORY_OUTCOME + else: + return _UNSATISFACTORY_OUTCOME + + +def _run_infinite_request_stream(stub): + + def infinite_request_iterator(): + while True: + yield _application_common.STREAM_UNARY_REQUEST + + response_future_call = stub.StreUn.future( + infinite_request_iterator(), + timeout=_application_common.INFINITE_REQUEST_STREAM_TIMEOUT) + if response_future_call.code() is grpc.StatusCode.DEADLINE_EXCEEDED: + return _SATISFACTORY_OUTCOME + else: + return _UNSATISFACTORY_OUTCOME + + +_IMPLEMENTATIONS = { + Scenario.UNARY_UNARY: _run_unary_unary, + Scenario.UNARY_STREAM: _run_unary_stream, + Scenario.STREAM_UNARY: _run_stream_unary, + Scenario.STREAM_STREAM: _run_stream_stream, + Scenario.CONCURRENT_STREAM_UNARY: _run_concurrent_stream_unary, + Scenario.CONCURRENT_STREAM_STREAM: _run_concurrent_stream_stream, + Scenario.CANCEL_UNARY_UNARY: _run_cancel_unary_unary, + Scenario.INFINITE_REQUEST_STREAM: _run_infinite_request_stream, +} + + +def run(scenario, channel): + stub = services_pb2_grpc.FirstServiceStub(channel) + try: + return _IMPLEMENTATIONS[scenario](stub) + except grpc.RpcError as rpc_error: + return Outcome(Outcome.Kind.RPC_ERROR, rpc_error.code(), + rpc_error.details()) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_client_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_client_test.py new file mode 100644 index 00000000000..5b051c39390 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_client_test.py @@ -0,0 +1,308 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent import futures +import time +import unittest + +import grpc +from grpc.framework.foundation import logging_pool +from tests.unit.framework.common import test_constants +import grpc_testing + +from tests.testing import _application_common +from tests.testing import _application_testing_common +from tests.testing import _client_application +from tests.testing.proto import requests_pb2 +from tests.testing.proto import services_pb2 + + +# TODO(https://github.com/google/protobuf/issues/3452): Drop this skip. + services_pb2.DESCRIPTOR.services_by_name.get('FirstService') is None, + 'Fix protobuf issue 3452!') +class ClientTest(unittest.TestCase): + + def setUp(self): + # In this test the client-side application under test executes in + # a separate thread while we retain use of the test thread to "play + # server". + self._client_execution_thread_pool = logging_pool.pool(1) + + self._fake_time = grpc_testing.strict_fake_time(time.time()) + self._real_time = grpc_testing.strict_real_time() + self._fake_time_channel = grpc_testing.channel( + services_pb2.DESCRIPTOR.services_by_name.values(), self._fake_time) + self._real_time_channel = grpc_testing.channel( + services_pb2.DESCRIPTOR.services_by_name.values(), self._real_time) + + def tearDown(self): + self._client_execution_thread_pool.shutdown(wait=True) + + def test_successful_unary_unary(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, _client_application.Scenario.UNARY_UNARY, + self._real_time_channel) + invocation_metadata, request, rpc = ( + self._real_time_channel.take_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN)) + rpc.send_initial_metadata(()) + rpc.terminate(_application_common.UNARY_UNARY_RESPONSE, (), + grpc.StatusCode.OK, '') + application_return_value = application_future.result() + + self.assertEqual(_application_common.UNARY_UNARY_REQUEST, request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY) + + def test_successful_unary_stream(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, _client_application.Scenario.UNARY_STREAM, + self._fake_time_channel) + invocation_metadata, request, rpc = ( + self._fake_time_channel.take_unary_stream( + _application_testing_common.FIRST_SERVICE_UNSTRE)) + rpc.send_initial_metadata(()) + rpc.terminate((), grpc.StatusCode.OK, '') + application_return_value = application_future.result() + + self.assertEqual(_application_common.UNARY_STREAM_REQUEST, request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY) + + def test_successful_stream_unary(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, _client_application.Scenario.STREAM_UNARY, + self._real_time_channel) + invocation_metadata, rpc = self._real_time_channel.take_stream_unary( + _application_testing_common.FIRST_SERVICE_STREUN) + rpc.send_initial_metadata(()) + first_request = rpc.take_request() + second_request = rpc.take_request() + third_request = rpc.take_request() + rpc.requests_closed() + rpc.terminate(_application_common.STREAM_UNARY_RESPONSE, (), + grpc.StatusCode.OK, '') + application_return_value = application_future.result() + + self.assertEqual(_application_common.STREAM_UNARY_REQUEST, + first_request) + self.assertEqual(_application_common.STREAM_UNARY_REQUEST, + second_request) + self.assertEqual(_application_common.STREAM_UNARY_REQUEST, + third_request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY) + + def test_successful_stream_stream(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, _client_application.Scenario.STREAM_STREAM, + self._fake_time_channel) + invocation_metadata, rpc = self._fake_time_channel.take_stream_stream( + _application_testing_common.FIRST_SERVICE_STRESTRE) + first_request = rpc.take_request() + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + second_request = rpc.take_request() + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.requests_closed() + rpc.terminate((), grpc.StatusCode.OK, '') + application_return_value = application_future.result() + + self.assertEqual(_application_common.STREAM_STREAM_REQUEST, + first_request) + self.assertEqual(_application_common.STREAM_STREAM_REQUEST, + second_request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY) + + def test_concurrent_stream_stream(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, + _client_application.Scenario.CONCURRENT_STREAM_STREAM, + self._real_time_channel) + rpcs = [] + for _ in range(test_constants.RPC_CONCURRENCY): + invocation_metadata, rpc = ( + self._real_time_channel.take_stream_stream( + _application_testing_common.FIRST_SERVICE_STRESTRE)) + rpcs.append(rpc) + requests = {} + for rpc in rpcs: + requests[rpc] = [rpc.take_request()] + for rpc in rpcs: + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + for rpc in rpcs: + requests[rpc].append(rpc.take_request()) + for rpc in rpcs: + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + for rpc in rpcs: + rpc.requests_closed() + for rpc in rpcs: + rpc.terminate((), grpc.StatusCode.OK, '') + application_return_value = application_future.result() + + for requests_of_one_rpc in requests.values(): + for request in requests_of_one_rpc: + self.assertEqual(_application_common.STREAM_STREAM_REQUEST, + request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY) + + def test_cancelled_unary_unary(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, + _client_application.Scenario.CANCEL_UNARY_UNARY, + self._fake_time_channel) + invocation_metadata, request, rpc = ( + self._fake_time_channel.take_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN)) + rpc.send_initial_metadata(()) + rpc.cancelled() + application_return_value = application_future.result() + + self.assertEqual(_application_common.UNARY_UNARY_REQUEST, request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY) + + def test_status_stream_unary(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, + _client_application.Scenario.CONCURRENT_STREAM_UNARY, + self._fake_time_channel) + rpcs = tuple( + self._fake_time_channel.take_stream_unary( + _application_testing_common.FIRST_SERVICE_STREUN)[1] + for _ in range(test_constants.THREAD_CONCURRENCY)) + for rpc in rpcs: + rpc.take_request() + rpc.take_request() + rpc.take_request() + rpc.requests_closed() + rpc.send_initial_metadata((( + 'my_metadata_key', + 'My Metadata Value!', + ),)) + for rpc in rpcs[:-1]: + rpc.terminate(_application_common.STREAM_UNARY_RESPONSE, (), + grpc.StatusCode.OK, '') + rpcs[-1].terminate(_application_common.STREAM_UNARY_RESPONSE, (), + grpc.StatusCode.RESOURCE_EXHAUSTED, + 'nope; not able to handle all those RPCs!') + application_return_value = application_future.result() + + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.UNSATISFACTORY) + + def test_status_stream_stream(self): + code = grpc.StatusCode.DEADLINE_EXCEEDED + details = 'test deadline exceeded!' + + application_future = self._client_execution_thread_pool.submit( + _client_application.run, _client_application.Scenario.STREAM_STREAM, + self._real_time_channel) + invocation_metadata, rpc = self._real_time_channel.take_stream_stream( + _application_testing_common.FIRST_SERVICE_STRESTRE) + first_request = rpc.take_request() + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + second_request = rpc.take_request() + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.requests_closed() + rpc.terminate((), code, details) + application_return_value = application_future.result() + + self.assertEqual(_application_common.STREAM_STREAM_REQUEST, + first_request) + self.assertEqual(_application_common.STREAM_STREAM_REQUEST, + second_request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.RPC_ERROR) + self.assertIs(application_return_value.code, code) + self.assertEqual(application_return_value.details, details) + + def test_misbehaving_server_unary_unary(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, _client_application.Scenario.UNARY_UNARY, + self._fake_time_channel) + invocation_metadata, request, rpc = ( + self._fake_time_channel.take_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN)) + rpc.send_initial_metadata(()) + rpc.terminate(_application_common.ERRONEOUS_UNARY_UNARY_RESPONSE, (), + grpc.StatusCode.OK, '') + application_return_value = application_future.result() + + self.assertEqual(_application_common.UNARY_UNARY_REQUEST, request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.UNSATISFACTORY) + + def test_misbehaving_server_stream_stream(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, _client_application.Scenario.STREAM_STREAM, + self._real_time_channel) + invocation_metadata, rpc = self._real_time_channel.take_stream_stream( + _application_testing_common.FIRST_SERVICE_STRESTRE) + first_request = rpc.take_request() + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + second_request = rpc.take_request() + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) + rpc.requests_closed() + rpc.terminate((), grpc.StatusCode.OK, '') + application_return_value = application_future.result() + + self.assertEqual(_application_common.STREAM_STREAM_REQUEST, + first_request) + self.assertEqual(_application_common.STREAM_STREAM_REQUEST, + second_request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.UNSATISFACTORY) + + def test_infinite_request_stream_real_time(self): + application_future = self._client_execution_thread_pool.submit( + _client_application.run, + _client_application.Scenario.INFINITE_REQUEST_STREAM, + self._real_time_channel) + invocation_metadata, rpc = self._real_time_channel.take_stream_unary( + _application_testing_common.FIRST_SERVICE_STREUN) + rpc.send_initial_metadata(()) + first_request = rpc.take_request() + second_request = rpc.take_request() + third_request = rpc.take_request() + self._real_time.sleep_for( + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT) + rpc.terminate(_application_common.STREAM_UNARY_RESPONSE, (), + grpc.StatusCode.DEADLINE_EXCEEDED, '') + application_return_value = application_future.result() + + self.assertEqual(_application_common.STREAM_UNARY_REQUEST, + first_request) + self.assertEqual(_application_common.STREAM_UNARY_REQUEST, + second_request) + self.assertEqual(_application_common.STREAM_UNARY_REQUEST, + third_request) + self.assertIs(application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_server_application.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_server_application.py new file mode 100644 index 00000000000..51ed977b8fe --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_server_application.py @@ -0,0 +1,95 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An example gRPC Python-using server-side application.""" + +import grpc + +import threading + +# requests_pb2 is a semantic dependency of this module. +from tests.testing import _application_common +from tests.testing.proto import requests_pb2 # pylint: disable=unused-import +from tests.testing.proto import services_pb2 +from tests.testing.proto import services_pb2_grpc + + +class FirstServiceServicer(services_pb2_grpc.FirstServiceServicer): + """Services RPCs.""" + + def __init__(self): + self._abort_lock = threading.RLock() + self._abort_response = _application_common.ABORT_NO_STATUS_RESPONSE + + def UnUn(self, request, context): + if request == _application_common.UNARY_UNARY_REQUEST: + return _application_common.UNARY_UNARY_RESPONSE + elif request == _application_common.ABORT_REQUEST: + with self._abort_lock: + try: + context.abort(grpc.StatusCode.PERMISSION_DENIED, + "Denying permission to test abort.") + except Exception as e: # pylint: disable=broad-except + self._abort_response = _application_common.ABORT_SUCCESS_RESPONSE + else: + self._abort_status = _application_common.ABORT_FAILURE_RESPONSE + return None # NOTE: For the linter. + elif request == _application_common.ABORT_SUCCESS_QUERY: + with self._abort_lock: + return self._abort_response + else: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details('Something is wrong with your request!') + return services_pb2.Down() + + def UnStre(self, request, context): + if _application_common.UNARY_STREAM_REQUEST != request: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details('Something is wrong with your request!') + return + yield services_pb2.Strange() # pylint: disable=unreachable + + def StreUn(self, request_iterator, context): + context.send_initial_metadata((( + 'server_application_metadata_key', + 'Hi there!', + ),)) + for request in request_iterator: + if request != _application_common.STREAM_UNARY_REQUEST: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details('Something is wrong with your request!') + return services_pb2.Strange() + elif not context.is_active(): + return services_pb2.Strange() + else: + return _application_common.STREAM_UNARY_RESPONSE + + def StreStre(self, request_iterator, context): + valid_requests = (_application_common.STREAM_STREAM_REQUEST, + _application_common.STREAM_STREAM_MUTATING_REQUEST) + for request in request_iterator: + if request not in valid_requests: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details('Something is wrong with your request!') + return + elif not context.is_active(): + return + elif request == _application_common.STREAM_STREAM_REQUEST: + yield _application_common.STREAM_STREAM_RESPONSE + yield _application_common.STREAM_STREAM_RESPONSE + elif request == _application_common.STREAM_STREAM_MUTATING_REQUEST: + response = services_pb2.Bottom() + for i in range( + _application_common.STREAM_STREAM_MUTATING_COUNT): + response.first_bottom_field = i + yield response diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_server_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_server_test.py new file mode 100644 index 00000000000..617a41b7e54 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_server_test.py @@ -0,0 +1,207 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +import grpc +import grpc_testing + +from tests.testing import _application_common +from tests.testing import _application_testing_common +from tests.testing import _server_application +from tests.testing.proto import services_pb2 + + +class FirstServiceServicerTest(unittest.TestCase): + + def setUp(self): + self._real_time = grpc_testing.strict_real_time() + self._fake_time = grpc_testing.strict_fake_time(time.time()) + servicer = _server_application.FirstServiceServicer() + descriptors_to_servicers = { + _application_testing_common.FIRST_SERVICE: servicer + } + self._real_time_server = grpc_testing.server_from_dictionary( + descriptors_to_servicers, self._real_time) + self._fake_time_server = grpc_testing.server_from_dictionary( + descriptors_to_servicers, self._fake_time) + + def test_successful_unary_unary(self): + rpc = self._real_time_server.invoke_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN, (), + _application_common.UNARY_UNARY_REQUEST, None) + initial_metadata = rpc.initial_metadata() + response, trailing_metadata, code, details = rpc.termination() + + self.assertEqual(_application_common.UNARY_UNARY_RESPONSE, response) + self.assertIs(code, grpc.StatusCode.OK) + + def test_successful_unary_stream(self): + rpc = self._real_time_server.invoke_unary_stream( + _application_testing_common.FIRST_SERVICE_UNSTRE, (), + _application_common.UNARY_STREAM_REQUEST, None) + initial_metadata = rpc.initial_metadata() + trailing_metadata, code, details = rpc.termination() + + self.assertIs(code, grpc.StatusCode.OK) + + def test_successful_stream_unary(self): + rpc = self._real_time_server.invoke_stream_unary( + _application_testing_common.FIRST_SERVICE_STREUN, (), None) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + rpc.requests_closed() + initial_metadata = rpc.initial_metadata() + response, trailing_metadata, code, details = rpc.termination() + + self.assertEqual(_application_common.STREAM_UNARY_RESPONSE, response) + self.assertIs(code, grpc.StatusCode.OK) + + def test_successful_stream_stream(self): + rpc = self._real_time_server.invoke_stream_stream( + _application_testing_common.FIRST_SERVICE_STRESTRE, (), None) + rpc.send_request(_application_common.STREAM_STREAM_REQUEST) + initial_metadata = rpc.initial_metadata() + responses = [ + rpc.take_response(), + rpc.take_response(), + ] + rpc.send_request(_application_common.STREAM_STREAM_REQUEST) + rpc.send_request(_application_common.STREAM_STREAM_REQUEST) + responses.extend([ + rpc.take_response(), + rpc.take_response(), + rpc.take_response(), + rpc.take_response(), + ]) + rpc.requests_closed() + trailing_metadata, code, details = rpc.termination() + + for response in responses: + self.assertEqual(_application_common.STREAM_STREAM_RESPONSE, + response) + self.assertIs(code, grpc.StatusCode.OK) + + def test_mutating_stream_stream(self): + rpc = self._real_time_server.invoke_stream_stream( + _application_testing_common.FIRST_SERVICE_STRESTRE, (), None) + rpc.send_request(_application_common.STREAM_STREAM_MUTATING_REQUEST) + initial_metadata = rpc.initial_metadata() + responses = [ + rpc.take_response() + for _ in range(_application_common.STREAM_STREAM_MUTATING_COUNT) + ] + rpc.send_request(_application_common.STREAM_STREAM_MUTATING_REQUEST) + responses.extend([ + rpc.take_response() + for _ in range(_application_common.STREAM_STREAM_MUTATING_COUNT) + ]) + rpc.requests_closed() + _, _, _ = rpc.termination() + expected_responses = ( + services_pb2.Bottom(first_bottom_field=0), + services_pb2.Bottom(first_bottom_field=1), + services_pb2.Bottom(first_bottom_field=0), + services_pb2.Bottom(first_bottom_field=1), + ) + self.assertSequenceEqual(expected_responses, responses) + + def test_server_rpc_idempotence(self): + rpc = self._real_time_server.invoke_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN, (), + _application_common.UNARY_UNARY_REQUEST, None) + first_initial_metadata = rpc.initial_metadata() + second_initial_metadata = rpc.initial_metadata() + third_initial_metadata = rpc.initial_metadata() + first_termination = rpc.termination() + second_termination = rpc.termination() + third_termination = rpc.termination() + + for later_initial_metadata in ( + second_initial_metadata, + third_initial_metadata, + ): + self.assertEqual(first_initial_metadata, later_initial_metadata) + response = first_termination[0] + terminal_metadata = first_termination[1] + code = first_termination[2] + details = first_termination[3] + for later_termination in ( + second_termination, + third_termination, + ): + self.assertEqual(response, later_termination[0]) + self.assertEqual(terminal_metadata, later_termination[1]) + self.assertIs(code, later_termination[2]) + self.assertEqual(details, later_termination[3]) + self.assertEqual(_application_common.UNARY_UNARY_RESPONSE, response) + self.assertIs(code, grpc.StatusCode.OK) + + def test_misbehaving_client_unary_unary(self): + rpc = self._real_time_server.invoke_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN, (), + _application_common.ERRONEOUS_UNARY_UNARY_REQUEST, None) + initial_metadata = rpc.initial_metadata() + response, trailing_metadata, code, details = rpc.termination() + + self.assertIsNot(code, grpc.StatusCode.OK) + + def test_infinite_request_stream_real_time(self): + rpc = self._real_time_server.invoke_stream_unary( + _application_testing_common.FIRST_SERVICE_STREUN, (), + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + initial_metadata = rpc.initial_metadata() + self._real_time.sleep_for( + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT * 2) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + response, trailing_metadata, code, details = rpc.termination() + + self.assertIs(code, grpc.StatusCode.DEADLINE_EXCEEDED) + + def test_infinite_request_stream_fake_time(self): + rpc = self._fake_time_server.invoke_stream_unary( + _application_testing_common.FIRST_SERVICE_STREUN, (), + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + initial_metadata = rpc.initial_metadata() + self._fake_time.sleep_for( + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT * 2) + rpc.send_request(_application_common.STREAM_UNARY_REQUEST) + response, trailing_metadata, code, details = rpc.termination() + + self.assertIs(code, grpc.StatusCode.DEADLINE_EXCEEDED) + + def test_servicer_context_abort(self): + rpc = self._real_time_server.invoke_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN, (), + _application_common.ABORT_REQUEST, None) + _, _, code, _ = rpc.termination() + self.assertIs(code, grpc.StatusCode.PERMISSION_DENIED) + rpc = self._real_time_server.invoke_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN, (), + _application_common.ABORT_SUCCESS_QUERY, None) + response, _, code, _ = rpc.termination() + self.assertEqual(_application_common.ABORT_SUCCESS_RESPONSE, response) + self.assertIs(code, grpc.StatusCode.OK) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_time_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_time_test.py new file mode 100644 index 00000000000..cab665c045c --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/_time_test.py @@ -0,0 +1,165 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import threading +import time +import unittest + +import grpc_testing + +_QUANTUM = 0.3 +_MANY = 10000 +# Tests that run in real time can either wait for the scheduler to +# eventually run what needs to be run (and risk timing out) or declare +# that the scheduler didn't schedule work reasonably fast enough. We +# choose the latter for this test. +_PATHOLOGICAL_SCHEDULING = 'pathological thread scheduling!' + + +class _TimeNoter(object): + + def __init__(self, time): + self._condition = threading.Condition() + self._time = time + self._call_times = [] + + def __call__(self): + with self._condition: + self._call_times.append(self._time.time()) + + def call_times(self): + with self._condition: + return tuple(self._call_times) + + +class TimeTest(object): + + def test_sleep_for(self): + start_time = self._time.time() + self._time.sleep_for(_QUANTUM) + end_time = self._time.time() + + self.assertLessEqual(start_time + _QUANTUM, end_time) + + def test_sleep_until(self): + start_time = self._time.time() + self._time.sleep_until(start_time + _QUANTUM) + end_time = self._time.time() + + self.assertLessEqual(start_time + _QUANTUM, end_time) + + def test_call_in(self): + time_noter = _TimeNoter(self._time) + + start_time = self._time.time() + self._time.call_in(time_noter, _QUANTUM) + self._time.sleep_for(_QUANTUM * 2) + call_times = time_noter.call_times() + + self.assertTrue(call_times, msg=_PATHOLOGICAL_SCHEDULING) + self.assertLessEqual(start_time + _QUANTUM, call_times[0]) + + def test_call_at(self): + time_noter = _TimeNoter(self._time) + + start_time = self._time.time() + self._time.call_at(time_noter, self._time.time() + _QUANTUM) + self._time.sleep_for(_QUANTUM * 2) + call_times = time_noter.call_times() + + self.assertTrue(call_times, msg=_PATHOLOGICAL_SCHEDULING) + self.assertLessEqual(start_time + _QUANTUM, call_times[0]) + + def test_cancel(self): + time_noter = _TimeNoter(self._time) + + future = self._time.call_in(time_noter, _QUANTUM * 2) + self._time.sleep_for(_QUANTUM) + cancelled = future.cancel() + self._time.sleep_for(_QUANTUM * 2) + call_times = time_noter.call_times() + + self.assertFalse(call_times, msg=_PATHOLOGICAL_SCHEDULING) + self.assertTrue(cancelled) + self.assertTrue(future.cancelled()) + + def test_many(self): + test_events = tuple(threading.Event() for _ in range(_MANY)) + possibly_cancelled_futures = {} + background_noise_futures = [] + + for test_event in test_events: + possibly_cancelled_futures[test_event] = self._time.call_in( + test_event.set, _QUANTUM * (2 + random.random())) + for _ in range(_MANY): + background_noise_futures.append( + self._time.call_in(threading.Event().set, + _QUANTUM * 1000 * random.random())) + self._time.sleep_for(_QUANTUM) + cancelled = set() + for test_event, test_future in possibly_cancelled_futures.items(): + if bool(random.randint(0, 1)) and test_future.cancel(): + cancelled.add(test_event) + self._time.sleep_for(_QUANTUM * 3) + + for test_event in test_events: + (self.assertFalse if test_event in cancelled else self.assertTrue)( + test_event.is_set()) + for background_noise_future in background_noise_futures: + background_noise_future.cancel() + + def test_same_behavior_used_several_times(self): + time_noter = _TimeNoter(self._time) + + start_time = self._time.time() + first_future_at_one = self._time.call_in(time_noter, _QUANTUM) + second_future_at_one = self._time.call_in(time_noter, _QUANTUM) + first_future_at_three = self._time.call_in(time_noter, _QUANTUM * 3) + second_future_at_three = self._time.call_in(time_noter, _QUANTUM * 3) + self._time.sleep_for(_QUANTUM * 2) + first_future_at_one_cancelled = first_future_at_one.cancel() + second_future_at_one_cancelled = second_future_at_one.cancel() + first_future_at_three_cancelled = first_future_at_three.cancel() + self._time.sleep_for(_QUANTUM * 2) + second_future_at_three_cancelled = second_future_at_three.cancel() + first_future_at_three_cancelled_again = first_future_at_three.cancel() + call_times = time_noter.call_times() + + self.assertEqual(3, len(call_times), msg=_PATHOLOGICAL_SCHEDULING) + self.assertFalse(first_future_at_one_cancelled) + self.assertFalse(second_future_at_one_cancelled) + self.assertTrue(first_future_at_three_cancelled) + self.assertFalse(second_future_at_three_cancelled) + self.assertTrue(first_future_at_three_cancelled_again) + self.assertLessEqual(start_time + _QUANTUM, call_times[0]) + self.assertLessEqual(start_time + _QUANTUM, call_times[1]) + self.assertLessEqual(start_time + _QUANTUM * 3, call_times[2]) + + +class StrictRealTimeTest(TimeTest, unittest.TestCase): + + def setUp(self): + self._time = grpc_testing.strict_real_time() + + +class StrictFakeTimeTest(TimeTest, unittest.TestCase): + + def setUp(self): + self._time = grpc_testing.strict_fake_time( + random.randint(0, int(time.time()))) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/proto/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/proto/__init__.py new file mode 100644 index 00000000000..1e120359cf9 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/testing/proto/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/__init__.py new file mode 100644 index 00000000000..5fb4f3c3cfd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_abort_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_abort_test.py new file mode 100644 index 00000000000..d2eaf97d5f4 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_abort_test.py @@ -0,0 +1,154 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server context abort mechanism""" + +import unittest +import collections +import gc +import logging +import weakref + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_ABORT = '/test/abort' +_ABORT_WITH_STATUS = '/test/AbortWithStatus' +_INVALID_CODE = '/test/InvalidCode' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_ABORT_DETAILS = 'Abandon ship!' +_ABORT_METADATA = (('a-trailing-metadata', '42'),) + + +class _Status( + collections.namedtuple('_Status', + ('code', 'details', 'trailing_metadata')), + grpc.Status): + pass + + +class _Object(object): + pass + + +do_not_leak_me = _Object() + + +def abort_unary_unary(request, servicer_context): + this_should_not_be_leaked = do_not_leak_me + servicer_context.abort( + grpc.StatusCode.INTERNAL, + _ABORT_DETAILS, + ) + raise Exception('This line should not be executed!') + + +def abort_with_status_unary_unary(request, servicer_context): + servicer_context.abort_with_status( + _Status( + code=grpc.StatusCode.INTERNAL, + details=_ABORT_DETAILS, + trailing_metadata=_ABORT_METADATA, + )) + raise Exception('This line should not be executed!') + + +def invalid_code_unary_unary(request, servicer_context): + servicer_context.abort( + 42, + _ABORT_DETAILS, + ) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _ABORT: + return grpc.unary_unary_rpc_method_handler(abort_unary_unary) + elif handler_call_details.method == _ABORT_WITH_STATUS: + return grpc.unary_unary_rpc_method_handler( + abort_with_status_unary_unary) + elif handler_call_details.method == _INVALID_CODE: + return grpc.stream_stream_rpc_method_handler( + invalid_code_unary_unary) + else: + return None + + +class AbortTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._channel.close() + self._server.stop(0) + + def test_abort(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ABORT)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + + # This test ensures that abort() does not store the raised exception, which + # on Python 3 (via the `__traceback__` attribute) holds a reference to + # all local vars. Storing the raised exception can prevent GC and stop the + # grpc_call from being unref'ed, even after server shutdown. + @unittest.skip("https://github.com/grpc/grpc/issues/17927") + def test_abort_does_not_leak_local_vars(self): + global do_not_leak_me # pylint: disable=global-statement + weak_ref = weakref.ref(do_not_leak_me) + + # Servicer will abort() after creating a local ref to do_not_leak_me. + with self.assertRaises(grpc.RpcError): + self._channel.unary_unary(_ABORT)(_REQUEST) + + # Server may still have a stack frame reference to the exception even + # after client sees error, so ensure server has shutdown. + self._server.stop(None) + do_not_leak_me = None + self.assertIsNone(weak_ref()) + + def test_abort_with_status(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + self.assertEqual(rpc_error.trailing_metadata(), _ABORT_METADATA) + + def test_invalid_code(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INVALID_CODE)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_api_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_api_test.py new file mode 100644 index 00000000000..a459ee6e192 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_api_test.py @@ -0,0 +1,118 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of gRPC Python's application-layer API.""" + +import unittest +import logging + +import six + +import grpc + +from tests.unit import _from_grpc_import_star + + +class AllTest(unittest.TestCase): + + def testAll(self): + expected_grpc_code_elements = ( + 'FutureTimeoutError', + 'FutureCancelledError', + 'Future', + 'ChannelConnectivity', + 'Compression', + 'StatusCode', + 'Status', + 'RpcError', + 'RpcContext', + 'Call', + 'ChannelCredentials', + 'CallCredentials', + 'AuthMetadataContext', + 'AuthMetadataPluginCallback', + 'AuthMetadataPlugin', + 'ServerCertificateConfiguration', + 'ServerCredentials', + 'UnaryUnaryMultiCallable', + 'UnaryStreamMultiCallable', + 'StreamUnaryMultiCallable', + 'StreamStreamMultiCallable', + 'UnaryUnaryClientInterceptor', + 'UnaryStreamClientInterceptor', + 'StreamUnaryClientInterceptor', + 'StreamStreamClientInterceptor', + 'Channel', + 'ServicerContext', + 'RpcMethodHandler', + 'HandlerCallDetails', + 'GenericRpcHandler', + 'ServiceRpcHandler', + 'Server', + 'ServerInterceptor', + 'LocalConnectionType', + 'local_channel_credentials', + 'local_server_credentials', + 'alts_channel_credentials', + 'alts_server_credentials', + 'unary_unary_rpc_method_handler', + 'unary_stream_rpc_method_handler', + 'stream_unary_rpc_method_handler', + 'ClientCallDetails', + 'stream_stream_rpc_method_handler', + 'method_handlers_generic_handler', + 'ssl_channel_credentials', + 'metadata_call_credentials', + 'access_token_call_credentials', + 'composite_call_credentials', + 'composite_channel_credentials', + 'ssl_server_credentials', + 'ssl_server_certificate_configuration', + 'dynamic_ssl_server_credentials', + 'channel_ready_future', + 'insecure_channel', + 'secure_channel', + 'intercept_channel', + 'server', + 'protos', + 'services', + 'protos_and_services', + ) + + six.assertCountEqual(self, expected_grpc_code_elements, + _from_grpc_import_star.GRPC_ELEMENTS) + + +class ChannelConnectivityTest(unittest.TestCase): + + def testChannelConnectivity(self): + self.assertSequenceEqual(( + grpc.ChannelConnectivity.IDLE, + grpc.ChannelConnectivity.CONNECTING, + grpc.ChannelConnectivity.READY, + grpc.ChannelConnectivity.TRANSIENT_FAILURE, + grpc.ChannelConnectivity.SHUTDOWN, + ), tuple(grpc.ChannelConnectivity)) + + +class ChannelTest(unittest.TestCase): + + def test_secure_channel(self): + channel_credentials = grpc.ssl_channel_credentials() + channel = grpc.secure_channel('google.com:443', channel_credentials) + channel.close() + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_auth_context_test.py new file mode 100644 index 00000000000..817c528237b --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -0,0 +1,193 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests exposure of SSL auth context""" + +import pickle +import unittest +import logging + +import grpc +from grpc import _channel +from grpc.experimental import session_cache +import six + +from tests.unit import test_common +from tests.unit import resources + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_CLIENT_IDS = ( + b'*.test.google.fr', + b'waterzooi.test.google.be', + b'*.test.youtube.com', + b'192.168.1.3', +) +_ID = 'id' +_ID_KEY = 'id_key' +_AUTH_CTX = 'auth_ctx' + +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() +_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) +_PROPERTY_OPTIONS = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, +),) + + +def handle_unary_unary(request, servicer_context): + return pickle.dumps({ + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context() + }) + + +class AuthContextTest(unittest.TestCase): + + def testInsecure(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + port = server.add_insecure_port('[::]:0') + server.start() + + with grpc.insecure_channel('localhost:%d' % port) as channel: + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + server.stop(None) + + auth_data = pickle.loads(response) + self.assertIsNone(auth_data[_ID]) + self.assertIsNone(auth_data[_ID_KEY]) + self.assertDictEqual({}, auth_data[_AUTH_CTX]) + + def testSecureNoCert(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + server.start() + + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel = grpc.secure_channel('localhost:{}'.format(port), + channel_creds, + options=_PROPERTY_OPTIONS) + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + channel.close() + server.stop(None) + + auth_data = pickle.loads(response) + self.assertIsNone(auth_data[_ID]) + self.assertIsNone(auth_data[_ID_KEY]) + self.assertDictEqual( + { + 'security_level': [b'TSI_PRIVACY_AND_INTEGRITY'], + 'transport_security_type': [b'ssl'], + 'ssl_session_reused': [b'false'], + }, auth_data[_AUTH_CTX]) + + def testSecureClientCert(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials( + _SERVER_CERTS, + root_certificates=_TEST_ROOT_CERTIFICATES, + require_client_auth=True) + port = server.add_secure_port('[::]:0', server_cred) + server.start() + + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES, + private_key=_PRIVATE_KEY, + certificate_chain=_CERTIFICATE_CHAIN) + channel = grpc.secure_channel('localhost:{}'.format(port), + channel_creds, + options=_PROPERTY_OPTIONS) + + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + channel.close() + server.stop(None) + + auth_data = pickle.loads(response) + auth_ctx = auth_data[_AUTH_CTX] + six.assertCountEqual(self, _CLIENT_IDS, auth_data[_ID]) + self.assertEqual('x509_subject_alternative_name', auth_data[_ID_KEY]) + self.assertSequenceEqual([b'ssl'], auth_ctx['transport_security_type']) + self.assertSequenceEqual([b'*.test.google.com'], + auth_ctx['x509_common_name']) + + def _do_one_shot_client_rpc(self, channel_creds, channel_options, port, + expect_ssl_session_reused): + channel = grpc.secure_channel('localhost:{}'.format(port), + channel_creds, + options=channel_options) + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + auth_data = pickle.loads(response) + self.assertEqual(expect_ssl_session_reused, + auth_data[_AUTH_CTX]['ssl_session_reused']) + channel.close() + + def testSessionResumption(self): + # Set up a secure server + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + server.start() + + # Create a cache for TLS session tickets + cache = session_cache.ssl_session_cache_lru(1) + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel_options = _PROPERTY_OPTIONS + ( + ('grpc.ssl_session_cache', cache),) + + # Initial connection has no session to resume + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'false']) + + # Subsequent connections resume sessions + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'true']) + server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_auth_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_auth_test.py new file mode 100644 index 00000000000..d9df2add4f2 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_auth_test.py @@ -0,0 +1,82 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of standard AuthMetadataPlugins.""" + +import collections +import threading +import unittest +import logging + +from grpc import _auth + + +class MockGoogleCreds(object): + + def get_access_token(self): + token = collections.namedtuple('MockAccessTokenInfo', + ('access_token', 'expires_in')) + token.access_token = 'token' + return token + + +class MockExceptionGoogleCreds(object): + + def get_access_token(self): + raise Exception() + + +class GoogleCallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials_success(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertIsNone(error) + callback_event.set() + + call_creds = _auth.GoogleCallCredentials(MockGoogleCreds()) + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + def test_google_call_credentials_error(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertIsNotNone(error) + callback_event.set() + + call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds()) + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + +class AccessTokenAuthMetadataPluginTest(unittest.TestCase): + + def test_google_call_credentials_success(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertIsNone(error) + callback_event.set() + + metadata_plugin = _auth.AccessTokenAuthMetadataPlugin('token') + metadata_plugin(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_args_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_args_test.py new file mode 100644 index 00000000000..2f2eea61dbd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_args_test.py @@ -0,0 +1,65 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of channel arguments on client/server side.""" + +from concurrent import futures +import unittest +import logging + +import grpc + + +class TestPointerWrapper(object): + + def __int__(self): + return 123456 + + +TEST_CHANNEL_ARGS = ( + ('arg1', b'bytes_val'), + ('arg2', 'str_val'), + ('arg3', 1), + (b'arg4', 'str_val'), + ('arg6', TestPointerWrapper()), +) + +INVALID_TEST_CHANNEL_ARGS = [ + { + 'foo': 'bar' + }, + (('key',),), + 'str', +] + + +class ChannelArgsTest(unittest.TestCase): + + def test_client(self): + grpc.insecure_channel('localhost:8080', options=TEST_CHANNEL_ARGS) + + def test_server(self): + grpc.server(futures.ThreadPoolExecutor(max_workers=1), + options=TEST_CHANNEL_ARGS) + + def test_invalid_client_args(self): + for invalid_arg in INVALID_TEST_CHANNEL_ARGS: + self.assertRaises(ValueError, + grpc.insecure_channel, + 'localhost:8080', + options=invalid_arg) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_close_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_close_test.py new file mode 100644 index 00000000000..47f52b4890e --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_close_test.py @@ -0,0 +1,220 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server and client side compression.""" + +import itertools +import logging +import threading +import time +import unittest + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_BEAT = 0.5 +_SOME_TIME = 5 +_MORE_TIME = 10 + +_STREAM_URI = 'Meffod' +_UNARY_URI = 'MeffodMan' + + +class _StreamingMethodHandler(grpc.RpcMethodHandler): + + request_streaming = True + response_streaming = True + request_deserializer = None + response_serializer = None + + def stream_stream(self, request_iterator, servicer_context): + for request in request_iterator: + yield request * 2 + + +class _UnaryMethodHandler(grpc.RpcMethodHandler): + + request_streaming = False + response_streaming = False + request_deserializer = None + response_serializer = None + + def unary_unary(self, request, servicer_context): + return request * 2 + + +_STREAMING_METHOD_HANDLER = _StreamingMethodHandler() +_UNARY_METHOD_HANDLER = _UnaryMethodHandler() + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _STREAM_URI: + return _STREAMING_METHOD_HANDLER + else: + return _UNARY_METHOD_HANDLER + + +_GENERIC_HANDLER = _GenericHandler() + + +class _Pipe(object): + + def __init__(self, values): + self._condition = threading.Condition() + self._values = list(values) + self._open = True + + def __iter__(self): + return self + + def _next(self): + with self._condition: + while not self._values and self._open: + self._condition.wait() + if self._values: + return self._values.pop(0) + else: + raise StopIteration() + + def next(self): + return self._next() + + def __next__(self): + return self._next() + + def add(self, value): + with self._condition: + self._values.append(value) + self._condition.notify() + + def close(self): + with self._condition: + self._open = False + self._condition.notify() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + +class ChannelCloseTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server( + max_workers=test_constants.THREAD_CONCURRENCY) + self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,)) + self._port = self._server.add_insecure_port('[::]:0') + self._server.start() + + def tearDown(self): + self._server.stop(None) + + def test_close_immediately_after_call_invocation(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream(_STREAM_URI) + request_iterator = _Pipe(()) + response_iterator = multi_callable(request_iterator) + channel.close() + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_close_while_call_active(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream(_STREAM_URI) + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + channel.close() + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_context_manager_close_while_call_active(self): + with grpc.insecure_channel('localhost:{}'.format( + self._port)) as channel: # pylint: disable=bad-continuation + multi_callable = channel.stream_stream(_STREAM_URI) + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_context_manager_close_while_many_calls_active(self): + with grpc.insecure_channel('localhost:{}'.format( + self._port)) as channel: # pylint: disable=bad-continuation + multi_callable = channel.stream_stream(_STREAM_URI) + request_iterators = tuple( + _Pipe((b'abc',)) + for _ in range(test_constants.THREAD_CONCURRENCY)) + response_iterators = [] + for request_iterator in request_iterators: + response_iterator = multi_callable(request_iterator) + next(response_iterator) + response_iterators.append(response_iterator) + for request_iterator in request_iterators: + request_iterator.close() + + for response_iterator in response_iterators: + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_many_concurrent_closes(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream(_STREAM_URI) + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + start = time.time() + end = start + _MORE_TIME + + def sleep_some_time_then_close(): + time.sleep(_SOME_TIME) + channel.close() + + for _ in range(test_constants.THREAD_CONCURRENCY): + close_thread = threading.Thread(target=sleep_some_time_then_close) + close_thread.start() + while True: + request_iterator.add(b'def') + time.sleep(_BEAT) + if end < time.time(): + break + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_exception_in_callback(self): + with grpc.insecure_channel('localhost:{}'.format( + self._port)) as channel: + stream_multi_callable = channel.stream_stream(_STREAM_URI) + endless_iterator = itertools.repeat(b'abc') + stream_response_iterator = stream_multi_callable(endless_iterator) + future = channel.unary_unary(_UNARY_URI).future(b'abc') + + def on_done_callback(future): + raise Exception("This should not cause a deadlock.") + + future.add_done_callback(on_done_callback) + future.result() + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py new file mode 100644 index 00000000000..d1b4c3c932f --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py @@ -0,0 +1,155 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc._channel.Channel connectivity.""" + +import logging +import threading +import time +import unittest + +import grpc +from tests.unit.framework.common import test_constants +from tests.unit import thread_pool + + +def _ready_in_connectivities(connectivities): + return grpc.ChannelConnectivity.READY in connectivities + + +def _last_connectivity_is_not_ready(connectivities): + return connectivities[-1] is not grpc.ChannelConnectivity.READY + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._connectivities = [] + + def update(self, connectivity): + with self._condition: + self._connectivities.append(connectivity) + self._condition.notify() + + def connectivities(self): + with self._condition: + return tuple(self._connectivities) + + def block_until_connectivities_satisfy(self, predicate): + with self._condition: + while True: + connectivities = tuple(self._connectivities) + if predicate(connectivities): + return connectivities + else: + self._condition.wait() + + +class ChannelConnectivityTest(unittest.TestCase): + + def test_lonely_channel_connectivity(self): + callback = _Callback() + + channel = grpc.insecure_channel('localhost:12345') + channel.subscribe(callback.update, try_to_connect=False) + first_connectivities = callback.block_until_connectivities_satisfy(bool) + channel.subscribe(callback.update, try_to_connect=True) + second_connectivities = callback.block_until_connectivities_satisfy( + lambda connectivities: 2 <= len(connectivities)) + # Wait for a connection that will never happen. + time.sleep(test_constants.SHORT_TIMEOUT) + third_connectivities = callback.connectivities() + channel.unsubscribe(callback.update) + fourth_connectivities = callback.connectivities() + channel.unsubscribe(callback.update) + fifth_connectivities = callback.connectivities() + + channel.close() + + self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), + first_connectivities) + self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities) + self.assertNotIn(grpc.ChannelConnectivity.READY, third_connectivities) + self.assertNotIn(grpc.ChannelConnectivity.READY, fourth_connectivities) + self.assertNotIn(grpc.ChannelConnectivity.READY, fifth_connectivities) + + def test_immediately_connectable_channel_connectivity(self): + recording_thread_pool = thread_pool.RecordingThreadPool( + max_workers=None) + server = grpc.server(recording_thread_pool, + options=(('grpc.so_reuseport', 0),)) + port = server.add_insecure_port('[::]:0') + server.start() + first_callback = _Callback() + second_callback = _Callback() + + channel = grpc.insecure_channel('localhost:{}'.format(port)) + channel.subscribe(first_callback.update, try_to_connect=False) + first_connectivities = first_callback.block_until_connectivities_satisfy( + bool) + # Wait for a connection that will never happen because try_to_connect=True + # has not yet been passed. + time.sleep(test_constants.SHORT_TIMEOUT) + second_connectivities = first_callback.connectivities() + channel.subscribe(second_callback.update, try_to_connect=True) + third_connectivities = first_callback.block_until_connectivities_satisfy( + lambda connectivities: 2 <= len(connectivities)) + fourth_connectivities = second_callback.block_until_connectivities_satisfy( + bool) + # Wait for a connection that will happen (or may already have happened). + first_callback.block_until_connectivities_satisfy( + _ready_in_connectivities) + second_callback.block_until_connectivities_satisfy( + _ready_in_connectivities) + channel.close() + server.stop(None) + + self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), + first_connectivities) + self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), + second_connectivities) + self.assertNotIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE, + third_connectivities) + self.assertNotIn(grpc.ChannelConnectivity.SHUTDOWN, + third_connectivities) + self.assertNotIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE, + fourth_connectivities) + self.assertNotIn(grpc.ChannelConnectivity.SHUTDOWN, + fourth_connectivities) + self.assertFalse(recording_thread_pool.was_used()) + + def test_reachable_then_unreachable_channel_connectivity(self): + recording_thread_pool = thread_pool.RecordingThreadPool( + max_workers=None) + server = grpc.server(recording_thread_pool, + options=(('grpc.so_reuseport', 0),)) + port = server.add_insecure_port('[::]:0') + server.start() + callback = _Callback() + + channel = grpc.insecure_channel('localhost:{}'.format(port)) + channel.subscribe(callback.update, try_to_connect=True) + callback.block_until_connectivities_satisfy(_ready_in_connectivities) + # Now take down the server and confirm that channel readiness is repudiated. + server.stop(None) + callback.block_until_connectivities_satisfy( + _last_connectivity_is_not_ready) + channel.unsubscribe(callback.update) + channel.close() + self.assertFalse(recording_thread_pool.was_used()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py new file mode 100644 index 00000000000..ca9ebc16fe9 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py @@ -0,0 +1,97 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc.channel_ready_future.""" + +import threading +import unittest +import logging + +import grpc +from tests.unit.framework.common import test_constants +from tests.unit import thread_pool + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + + def accept_value(self, value): + with self._condition: + self._value = value + self._condition.notify_all() + + def block_until_called(self): + with self._condition: + while self._value is None: + self._condition.wait() + return self._value + + +class ChannelReadyFutureTest(unittest.TestCase): + + def test_lonely_channel_connectivity(self): + channel = grpc.insecure_channel('localhost:12345') + callback = _Callback() + + ready_future = grpc.channel_ready_future(channel) + ready_future.add_done_callback(callback.accept_value) + with self.assertRaises(grpc.FutureTimeoutError): + ready_future.result(timeout=test_constants.SHORT_TIMEOUT) + self.assertFalse(ready_future.cancelled()) + self.assertFalse(ready_future.done()) + self.assertTrue(ready_future.running()) + ready_future.cancel() + value_passed_to_callback = callback.block_until_called() + self.assertIs(ready_future, value_passed_to_callback) + self.assertTrue(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + + channel.close() + + def test_immediately_connectable_channel_connectivity(self): + recording_thread_pool = thread_pool.RecordingThreadPool( + max_workers=None) + server = grpc.server(recording_thread_pool, + options=(('grpc.so_reuseport', 0),)) + port = server.add_insecure_port('[::]:0') + server.start() + channel = grpc.insecure_channel('localhost:{}'.format(port)) + callback = _Callback() + + ready_future = grpc.channel_ready_future(channel) + ready_future.add_done_callback(callback.accept_value) + self.assertIsNone( + ready_future.result(timeout=test_constants.LONG_TIMEOUT)) + value_passed_to_callback = callback.block_until_called() + self.assertIs(ready_future, value_passed_to_callback) + self.assertFalse(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + # Cancellation after maturity has no effect. + ready_future.cancel() + self.assertFalse(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + self.assertFalse(recording_thread_pool.was_used()) + + channel.close() + server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_compression_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_compression_test.py new file mode 100644 index 00000000000..bc58e1032ca --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_compression_test.py @@ -0,0 +1,382 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server and client side compression.""" + +import unittest + +import contextlib +from concurrent import futures +import functools +import itertools +import logging +import os + +import grpc +from grpc import _grpcio_metadata + +from tests.unit import test_common +from tests.unit.framework.common import test_constants +from tests.unit import _tcp_proxy + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + +# Cut down on test time. +_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16 + +_HOST = 'localhost' + +_REQUEST = b'\x00' * 100 +_COMPRESSION_RATIO_THRESHOLD = 0.05 +_COMPRESSION_METHODS = ( + None, + # Disabled for test tractability. + # grpc.Compression.NoCompression, + # grpc.Compression.Deflate, + grpc.Compression.Gzip, +) +_COMPRESSION_NAMES = { + None: 'Uncompressed', + grpc.Compression.NoCompression: 'NoCompression', + grpc.Compression.Deflate: 'DeflateCompression', + grpc.Compression.Gzip: 'GzipCompression', +} + +_TEST_OPTIONS = { + 'client_streaming': (True, False), + 'server_streaming': (True, False), + 'channel_compression': _COMPRESSION_METHODS, + 'multicallable_compression': _COMPRESSION_METHODS, + 'server_compression': _COMPRESSION_METHODS, + 'server_call_compression': _COMPRESSION_METHODS, +} + + +def _make_handle_unary_unary(pre_response_callback): + + def _handle_unary(request, servicer_context): + if pre_response_callback: + pre_response_callback(request, servicer_context) + return request + + return _handle_unary + + +def _make_handle_unary_stream(pre_response_callback): + + def _handle_unary_stream(request, servicer_context): + if pre_response_callback: + pre_response_callback(request, servicer_context) + for _ in range(_STREAM_LENGTH): + yield request + + return _handle_unary_stream + + +def _make_handle_stream_unary(pre_response_callback): + + def _handle_stream_unary(request_iterator, servicer_context): + if pre_response_callback: + pre_response_callback(request_iterator, servicer_context) + response = None + for request in request_iterator: + if not response: + response = request + return response + + return _handle_stream_unary + + +def _make_handle_stream_stream(pre_response_callback): + + def _handle_stream(request_iterator, servicer_context): + # TODO(issue:#6891) We should be able to remove this loop, + # and replace with return; yield + for request in request_iterator: + if pre_response_callback: + pre_response_callback(request, servicer_context) + yield request + + return _handle_stream + + +def set_call_compression(compression_method, request_or_iterator, + servicer_context): + del request_or_iterator + servicer_context.set_compression(compression_method) + + +def disable_next_compression(request, servicer_context): + del request + servicer_context.disable_next_message_compression() + + +def disable_first_compression(request, servicer_context): + if int(request.decode('ascii')) == 0: + servicer_context.disable_next_message_compression() + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, + pre_response_callback): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + + if self.request_streaming and self.response_streaming: + self.stream_stream = _make_handle_stream_stream( + pre_response_callback) + elif not self.request_streaming and not self.response_streaming: + self.unary_unary = _make_handle_unary_unary(pre_response_callback) + elif not self.request_streaming and self.response_streaming: + self.unary_stream = _make_handle_unary_stream(pre_response_callback) + else: + self.stream_unary = _make_handle_stream_unary(pre_response_callback) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, pre_response_callback): + self._pre_response_callback = pre_response_callback + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False, self._pre_response_callback) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True, self._pre_response_callback) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False, self._pre_response_callback) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True, self._pre_response_callback) + else: + return None + + +def _instrumented_client_server_pair(channel_kwargs, server_kwargs, + server_handler): + server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs) + server.add_generic_rpc_handlers((server_handler,)) + server_port = server.add_insecure_port('{}:0'.format(_HOST)) + server.start() + with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy: + proxy_port = proxy.get_port() + with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port), + **channel_kwargs) as client_channel: + try: + yield client_channel, proxy, server + finally: + server.stop(None) + + +def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function, + server_kwargs, server_handler, message): + with _instrumented_client_server_pair(channel_kwargs, server_kwargs, + server_handler) as pipeline: + client_channel, proxy, server = pipeline + client_function(client_channel, multicallable_kwargs, message) + return proxy.get_byte_count() + + +def _get_compression_ratios(client_function, first_channel_kwargs, + first_multicallable_kwargs, first_server_kwargs, + first_server_handler, second_channel_kwargs, + second_multicallable_kwargs, second_server_kwargs, + second_server_handler, message): + try: + # This test requires the byte length of each connection to be deterministic. As + # it turns out, flow control puts bytes on the wire in a nondeterministic + # manner. We disable it here in order to measure compression ratios + # deterministically. + os.environ['GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'] = 'true' + first_bytes_sent, first_bytes_received = _get_byte_counts( + first_channel_kwargs, first_multicallable_kwargs, client_function, + first_server_kwargs, first_server_handler, message) + second_bytes_sent, second_bytes_received = _get_byte_counts( + second_channel_kwargs, second_multicallable_kwargs, client_function, + second_server_kwargs, second_server_handler, message) + return ((second_bytes_sent - first_bytes_sent) / + float(first_bytes_sent), + (second_bytes_received - first_bytes_received) / + float(first_bytes_received)) + finally: + del os.environ['GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'] + + +def _unary_unary_client(channel, multicallable_kwargs, message): + multi_callable = channel.unary_unary(_UNARY_UNARY) + response = multi_callable(message, **multicallable_kwargs) + if response != message: + raise RuntimeError("Request '{}' != Response '{}'".format( + message, response)) + + +def _unary_stream_client(channel, multicallable_kwargs, message): + multi_callable = channel.unary_stream(_UNARY_STREAM) + response_iterator = multi_callable(message, **multicallable_kwargs) + for response in response_iterator: + if response != message: + raise RuntimeError("Request '{}' != Response '{}'".format( + message, response)) + + +def _stream_unary_client(channel, multicallable_kwargs, message): + multi_callable = channel.stream_unary(_STREAM_UNARY) + requests = (_REQUEST for _ in range(_STREAM_LENGTH)) + response = multi_callable(requests, **multicallable_kwargs) + if response != message: + raise RuntimeError("Request '{}' != Response '{}'".format( + message, response)) + + +def _stream_stream_client(channel, multicallable_kwargs, message): + multi_callable = channel.stream_stream(_STREAM_STREAM) + request_prefix = str(0).encode('ascii') * 100 + requests = ( + request_prefix + str(i).encode('ascii') for i in range(_STREAM_LENGTH)) + response_iterator = multi_callable(requests, **multicallable_kwargs) + for i, response in enumerate(response_iterator): + if int(response.decode('ascii')) != i: + raise RuntimeError("Request '{}' != Response '{}'".format( + i, response)) + + +class CompressionTest(unittest.TestCase): + + def assertCompressed(self, compression_ratio): + self.assertLess( + compression_ratio, + -1.0 * _COMPRESSION_RATIO_THRESHOLD, + msg='Actual compression ratio: {}'.format(compression_ratio)) + + def assertNotCompressed(self, compression_ratio): + self.assertGreaterEqual( + compression_ratio, + -1.0 * _COMPRESSION_RATIO_THRESHOLD, + msg='Actual compession ratio: {}'.format(compression_ratio)) + + def assertConfigurationCompressed(self, client_streaming, server_streaming, + channel_compression, + multicallable_compression, + server_compression, + server_call_compression): + client_side_compressed = channel_compression or multicallable_compression + server_side_compressed = server_compression or server_call_compression + channel_kwargs = { + 'compression': channel_compression, + } if channel_compression else {} + multicallable_kwargs = { + 'compression': multicallable_compression, + } if multicallable_compression else {} + + client_function = None + if not client_streaming and not server_streaming: + client_function = _unary_unary_client + elif not client_streaming and server_streaming: + client_function = _unary_stream_client + elif client_streaming and not server_streaming: + client_function = _stream_unary_client + else: + client_function = _stream_stream_client + + server_kwargs = { + 'compression': server_compression, + } if server_compression else {} + server_handler = _GenericHandler( + functools.partial(set_call_compression, grpc.Compression.Gzip) + ) if server_call_compression else _GenericHandler(None) + sent_ratio, received_ratio = _get_compression_ratios( + client_function, {}, {}, {}, _GenericHandler(None), channel_kwargs, + multicallable_kwargs, server_kwargs, server_handler, _REQUEST) + + if client_side_compressed: + self.assertCompressed(sent_ratio) + else: + self.assertNotCompressed(sent_ratio) + + if server_side_compressed: + self.assertCompressed(received_ratio) + else: + self.assertNotCompressed(received_ratio) + + def testDisableNextCompressionStreaming(self): + server_kwargs = { + 'compression': grpc.Compression.Deflate, + } + _, received_ratio = _get_compression_ratios( + _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {}, + server_kwargs, _GenericHandler(disable_next_compression), _REQUEST) + self.assertNotCompressed(received_ratio) + + def testDisableNextCompressionStreamingResets(self): + server_kwargs = { + 'compression': grpc.Compression.Deflate, + } + _, received_ratio = _get_compression_ratios( + _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {}, + server_kwargs, _GenericHandler(disable_first_compression), _REQUEST) + self.assertCompressed(received_ratio) + + +def _get_compression_str(name, value): + return '{}{}'.format(name, _COMPRESSION_NAMES[value]) + + +def _get_compression_test_name(client_streaming, server_streaming, + channel_compression, multicallable_compression, + server_compression, server_call_compression): + client_arity = 'Stream' if client_streaming else 'Unary' + server_arity = 'Stream' if server_streaming else 'Unary' + arity = '{}{}'.format(client_arity, server_arity) + channel_compression_str = _get_compression_str('Channel', + channel_compression) + multicallable_compression_str = _get_compression_str( + 'Multicallable', multicallable_compression) + server_compression_str = _get_compression_str('Server', server_compression) + server_call_compression_str = _get_compression_str('ServerCall', + server_call_compression) + return 'test{}{}{}{}{}'.format(arity, channel_compression_str, + multicallable_compression_str, + server_compression_str, + server_call_compression_str) + + +def _test_options(): + for test_parameters in itertools.product(*_TEST_OPTIONS.values()): + yield dict(zip(_TEST_OPTIONS.keys(), test_parameters)) + + +for options in _test_options(): + + def test_compression(**kwargs): + + def _test_compression(self): + self.assertConfigurationCompressed(**kwargs) + + return _test_compression + + setattr(CompressionTest, _get_compression_test_name(**options), + test_compression(**options)) + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py new file mode 100644 index 00000000000..fec0fbd7df4 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py @@ -0,0 +1,118 @@ +# Copyright 2020 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of propagation of contextvars to AuthMetadataPlugin threads..""" + +import contextlib +import logging +import os +import sys +import unittest + +import grpc + +from tests.unit import test_common + +_UNARY_UNARY = "/test/UnaryUnary" +_REQUEST = b"0000" + + +def _unary_unary_handler(request, context): + return request + + +def contextvars_supported(): + try: + import contextvars + return True + except ImportError: + return False + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) + else: + raise NotImplementedError() + + +def _server(): + try: + server = test_common.test_server() + target = 'localhost:0' + port = server.add_insecure_port(target) + server.add_generic_rpc_handlers((_GenericHandler(),)) + server.start() + yield port + finally: + server.stop(None) + + +if contextvars_supported(): + import contextvars + + _EXPECTED_VALUE = 24601 + test_var = contextvars.ContextVar("test_var", default=None) + + def set_up_expected_context(): + test_var.set(_EXPECTED_VALUE) + + class TestCallCredentials(grpc.AuthMetadataPlugin): + + def __call__(self, context, callback): + if test_var.get() != _EXPECTED_VALUE: + raise AssertionError("{} != {}".format(test_var.get(), + _EXPECTED_VALUE)) + callback((), None) + + def assert_called(self, test): + test.assertTrue(self._invoked) + test.assertEqual(_EXPECTED_VALUE, self._recorded_value) + +else: + + def set_up_expected_context(): + pass + + class TestCallCredentials(grpc.AuthMetadataPlugin): + + def __call__(self, context, callback): + callback((), None) + + +# TODO(https://github.com/grpc/grpc/issues/22257) [email protected](os.name == "nt", "LocalCredentials not supported on Windows.") +class ContextVarsPropagationTest(unittest.TestCase): + + def test_propagation_to_auth_plugin(self): + set_up_expected_context() + with _server() as port: + target = "localhost:{}".format(port) + local_credentials = grpc.local_channel_credentials() + test_call_credentials = TestCallCredentials() + call_credentials = grpc.metadata_call_credentials( + test_call_credentials, "test call credentials") + composite_credentials = grpc.composite_channel_credentials( + local_credentials, call_credentials) + with grpc.secure_channel(target, composite_credentials) as channel: + stub = channel.unary_unary(_UNARY_UNARY) + response = stub(_REQUEST, wait_for_ready=True) + self.assertEqual(_REQUEST, response) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_credentials_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_credentials_test.py new file mode 100644 index 00000000000..187a6f03881 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_credentials_test.py @@ -0,0 +1,70 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of credentials.""" + +import unittest +import logging +import six + +import grpc + + +class CredentialsTest(unittest.TestCase): + + def test_call_credentials_composition(self): + first = grpc.access_token_call_credentials('abc') + second = grpc.access_token_call_credentials('def') + third = grpc.access_token_call_credentials('ghi') + + first_and_second = grpc.composite_call_credentials(first, second) + first_second_and_third = grpc.composite_call_credentials( + first, second, third) + + self.assertIsInstance(first_and_second, grpc.CallCredentials) + self.assertIsInstance(first_second_and_third, grpc.CallCredentials) + + def test_channel_credentials_composition(self): + first_call_credentials = grpc.access_token_call_credentials('abc') + second_call_credentials = grpc.access_token_call_credentials('def') + third_call_credentials = grpc.access_token_call_credentials('ghi') + channel_credentials = grpc.ssl_channel_credentials() + + channel_and_first = grpc.composite_channel_credentials( + channel_credentials, first_call_credentials) + channel_first_and_second = grpc.composite_channel_credentials( + channel_credentials, first_call_credentials, + second_call_credentials) + channel_first_second_and_third = grpc.composite_channel_credentials( + channel_credentials, first_call_credentials, + second_call_credentials, third_call_credentials) + + self.assertIsInstance(channel_and_first, grpc.ChannelCredentials) + self.assertIsInstance(channel_first_and_second, grpc.ChannelCredentials) + self.assertIsInstance(channel_first_second_and_third, + grpc.ChannelCredentials) + + @unittest.skipIf(six.PY2, 'only invalid in Python3') + def test_invalid_string_certificate(self): + self.assertRaises( + TypeError, + grpc.ssl_channel_credentials, + root_certificates='A Certificate', + private_key=None, + certificate_chain=None, + ) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/__init__.py new file mode 100644 index 00000000000..5fb4f3c3cfd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py new file mode 100644 index 00000000000..b279f3d07c5 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -0,0 +1,223 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test making many calls and immediately cancelling most of them.""" + +import threading +import unittest + +from grpc._cython import cygrpc +from grpc.framework.foundation import logging_pool +from tests.unit.framework.common import test_constants +from tests.unit._cython import test_utilities + +_EMPTY_FLAGS = 0 +_EMPTY_METADATA = () + +_SERVER_SHUTDOWN_TAG = 'server_shutdown' +_REQUEST_CALL_TAG = 'request_call' +_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server' +_RECEIVE_MESSAGE_TAG = 'receive_message' +_SERVER_COMPLETE_CALL_TAG = 'server_complete_call' + +_SUCCESS_CALL_FRACTION = 1.0 / 8.0 +_SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) +_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS + + +class _State(object): + + def __init__(self): + self.condition = threading.Condition() + self.handlers_released = False + self.parked_handlers = 0 + self.handled_rpcs = 0 + + +def _is_cancellation_event(event): + return (event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and + event.batch_operations[0].cancelled()) + + +class _Handler(object): + + def __init__(self, state, completion_queue, rpc_event): + self._state = state + self._lock = threading.Lock() + self._completion_queue = completion_queue + self._call = rpc_event.call + + def __call__(self): + with self._state.condition: + self._state.parked_handlers += 1 + if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY: + self._state.condition.notify_all() + while not self._state.handlers_released: + self._state.condition.wait() + + with self._lock: + self._call.start_server_batch( + (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),), + _RECEIVE_CLOSE_ON_SERVER_TAG) + self._call.start_server_batch( + (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), + _RECEIVE_MESSAGE_TAG) + first_event = self._completion_queue.poll() + if _is_cancellation_event(first_event): + self._completion_queue.poll() + else: + with self._lock: + operations = ( + cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, + _EMPTY_FLAGS), + cygrpc.SendMessageOperation(b'\x79\x57', _EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', + _EMPTY_FLAGS), + ) + self._call.start_server_batch(operations, + _SERVER_COMPLETE_CALL_TAG) + self._completion_queue.poll() + self._completion_queue.poll() + + +def _serve(state, server, server_completion_queue, thread_pool): + for _ in range(test_constants.RPC_CONCURRENCY): + call_completion_queue = cygrpc.CompletionQueue() + server.request_call(call_completion_queue, server_completion_queue, + _REQUEST_CALL_TAG) + rpc_event = server_completion_queue.poll() + thread_pool.submit(_Handler(state, call_completion_queue, rpc_event)) + with state.condition: + state.handled_rpcs += 1 + if test_constants.RPC_CONCURRENCY <= state.handled_rpcs: + state.condition.notify_all() + server_completion_queue.poll() + + +class _QueueDriver(object): + + def __init__(self, condition, completion_queue, due): + self._condition = condition + self._completion_queue = completion_queue + self._due = due + self._events = [] + self._returned = False + + def start(self): + + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._due.remove(event.tag) + self._condition.notify_all() + if not self._due: + self._returned = True + return + + thread = threading.Thread(target=in_thread) + thread.start() + + def events(self, at_least): + with self._condition: + while len(self._events) < at_least: + self._condition.wait() + return tuple(self._events) + + +class CancelManyCallsTest(unittest.TestCase): + + def testCancelManyCalls(self): + server_thread_pool = logging_pool.pool( + test_constants.THREAD_CONCURRENCY) + + server_completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server([ + ( + b'grpc.so_reuseport', + 0, + ), + ]) + server.register_completion_queue(server_completion_queue) + port = server.add_http2_port(b'[::]:0') + server.start() + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None, + None) + + state = _State() + + server_thread_args = ( + state, + server, + server_completion_queue, + server_thread_pool, + ) + server_thread = threading.Thread(target=_serve, args=server_thread_args) + server_thread.start() + + client_condition = threading.Condition() + client_due = set() + + with client_condition: + client_calls = [] + for index in range(test_constants.RPC_CONCURRENCY): + tag = 'client_complete_call_{0:04d}_tag'.format(index) + client_call = channel.integrated_call( + _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, + None, (( + ( + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.SendMessageOperation(b'\x45\x56', + _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation( + _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + tag, + ),)) + client_due.add(tag) + client_calls.append(client_call) + + client_events_future = test_utilities.SimpleFuture(lambda: tuple( + channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS))) + + with state.condition: + while True: + if state.parked_handlers < test_constants.THREAD_CONCURRENCY: + state.condition.wait() + elif state.handled_rpcs < test_constants.RPC_CONCURRENCY: + state.condition.wait() + else: + state.handlers_released = True + state.condition.notify_all() + break + + client_events_future.result() + with client_condition: + for client_call in client_calls: + client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!') + for _ in range(_UNSUCCESSFUL_CALLS): + channel.next_call_event() + + channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!') + with state.condition: + server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py new file mode 100644 index 00000000000..54f620523ea --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py @@ -0,0 +1,70 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import threading +import unittest + +from grpc._cython import cygrpc + +from tests.unit.framework.common import test_constants + + +def _channel(): + return cygrpc.Channel(b'localhost:54321', (), None) + + +def _connectivity_loop(channel): + for _ in range(100): + connectivity = channel.check_connectivity_state(True) + channel.watch_connectivity_state(connectivity, time.time() + 0.2) + + +def _create_loop_destroy(): + channel = _channel() + _connectivity_loop(channel) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') + + +def _in_parallel(behavior, arguments): + threads = tuple( + threading.Thread(target=behavior, args=arguments) + for _ in range(test_constants.THREAD_CONCURRENCY)) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + +class ChannelTest(unittest.TestCase): + + def test_single_channel_lonely_connectivity(self): + channel = _channel() + _connectivity_loop(channel) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') + + def test_multiple_channels_lonely_connectivity(self): + _in_parallel(_create_loop_destroy, ()) + + def test_negative_deadline_connectivity(self): + channel = _channel() + connectivity = channel.check_connectivity_state(True) + channel.watch_connectivity_state(connectivity, -3.14) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') + # NOTE(lidiz) The negative timeout should not trigger SIGABRT. + # Bug report: https://github.com/grpc/grpc/issues/18244 + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_common.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_common.py new file mode 100644 index 00000000000..d8210f36f80 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_common.py @@ -0,0 +1,123 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common utilities for tests of the Cython layer of gRPC Python.""" + +import collections +import threading + +from grpc._cython import cygrpc + +RPC_COUNT = 4000 + +EMPTY_FLAGS = 0 + +INVOCATION_METADATA = ( + ('client-md-key', 'client-md-key'), + ('client-md-key-bin', b'\x00\x01' * 3000), +) + +INITIAL_METADATA = ( + ('server-initial-md-key', 'server-initial-md-value'), + ('server-initial-md-key-bin', b'\x00\x02' * 3000), +) + +TRAILING_METADATA = ( + ('server-trailing-md-key', 'server-trailing-md-value'), + ('server-trailing-md-key-bin', b'\x00\x03' * 3000), +) + + +class QueueDriver(object): + + def __init__(self, condition, completion_queue): + self._condition = condition + self._completion_queue = completion_queue + self._due = collections.defaultdict(int) + self._events = collections.defaultdict(list) + + def add_due(self, tags): + if not self._due: + + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events[event.tag].append(event) + self._due[event.tag] -= 1 + self._condition.notify_all() + if self._due[event.tag] <= 0: + self._due.pop(event.tag) + if not self._due: + return + + thread = threading.Thread(target=in_thread) + thread.start() + for tag in tags: + self._due[tag] += 1 + + def event_with_tag(self, tag): + with self._condition: + while True: + if self._events[tag]: + return self._events[tag].pop(0) + else: + self._condition.wait() + + +def execute_many_times(behavior): + return tuple(behavior() for _ in range(RPC_COUNT)) + + +class OperationResult( + collections.namedtuple('OperationResult', ( + 'start_batch_result', + 'completion_type', + 'success', + ))): + pass + + +SUCCESSFUL_OPERATION_RESULT = OperationResult( + cygrpc.CallError.ok, cygrpc.CompletionType.operation_complete, True) + + +class RpcTest(object): + + def setUp(self): + self.server_completion_queue = cygrpc.CompletionQueue() + self.server = cygrpc.Server([(b'grpc.so_reuseport', 0)]) + self.server.register_completion_queue(self.server_completion_queue) + port = self.server.add_http2_port(b'[::]:0') + self.server.start() + self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [], + None) + + self._server_shutdown_tag = 'server_shutdown_tag' + self.server_condition = threading.Condition() + self.server_driver = QueueDriver(self.server_condition, + self.server_completion_queue) + with self.server_condition: + self.server_driver.add_due({ + self._server_shutdown_tag, + }) + + self.client_condition = threading.Condition() + self.client_completion_queue = cygrpc.CompletionQueue() + self.client_driver = QueueDriver(self.client_condition, + self.client_completion_queue) + + def tearDown(self): + self.server.shutdown(self.server_completion_queue, + self._server_shutdown_tag) + self.server.cancel_all_calls() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py new file mode 100644 index 00000000000..5a5dedd5f26 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +import unittest + +from grpc._cython import cygrpc + + +def _get_number_active_threads(): + return cygrpc._fork_state.active_thread_count._num_active_threads + + [email protected](os.name == 'nt', 'Posix-specific tests') +class ForkPosixTester(unittest.TestCase): + + def setUp(self): + self._saved_fork_support_flag = cygrpc._GRPC_ENABLE_FORK_SUPPORT + cygrpc._GRPC_ENABLE_FORK_SUPPORT = True + + def testForkManagedThread(self): + + def cb(): + self.assertEqual(1, _get_number_active_threads()) + + thread = cygrpc.ForkManagedThread(cb) + thread.start() + thread.join() + self.assertEqual(0, _get_number_active_threads()) + + def testForkManagedThreadThrowsException(self): + + def cb(): + self.assertEqual(1, _get_number_active_threads()) + raise Exception("expected exception") + + thread = cygrpc.ForkManagedThread(cb) + thread.start() + thread.join() + self.assertEqual(0, _get_number_active_threads()) + + def tearDown(self): + cygrpc._GRPC_ENABLE_FORK_SUPPORT = self._saved_fork_support_flag + + [email protected](os.name == 'nt', 'Windows-specific tests') +class ForkWindowsTester(unittest.TestCase): + + def testForkManagedThreadIsNoOp(self): + + def cb(): + pass + + thread = cygrpc.ForkManagedThread(cb) + thread.start() + thread.join() + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py new file mode 100644 index 00000000000..144a2fcae3f --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py @@ -0,0 +1,132 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test a corner-case at the level of the Cython API.""" + +import threading +import unittest + +from grpc._cython import cygrpc + +from tests.unit._cython import _common +from tests.unit._cython import test_utilities + + +class Test(_common.RpcTest, unittest.TestCase): + + def _do_rpcs(self): + server_call_condition = threading.Condition() + server_call_completion_queue = cygrpc.CompletionQueue() + server_call_driver = _common.QueueDriver(server_call_condition, + server_call_completion_queue) + + server_request_call_tag = 'server_request_call_tag' + server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' + server_complete_rpc_tag = 'server_complete_rpc_tag' + + with self.server_condition: + server_request_call_start_batch_result = self.server.request_call( + server_call_completion_queue, self.server_completion_queue, + server_request_call_tag) + self.server_driver.add_due({ + server_request_call_tag, + }) + + client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' + client_complete_rpc_tag = 'client_complete_rpc_tag' + client_call = self.channel.integrated_call( + _common.EMPTY_FLAGS, b'/twinkies', None, None, + _common.INVOCATION_METADATA, None, [( + [ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], + client_receive_initial_metadata_tag, + )]) + client_call.operate([ + cygrpc.SendInitialMetadataOperation(_common.INVOCATION_METADATA, + _common.EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), + ], client_complete_rpc_tag) + + client_events_future = test_utilities.SimpleFuture(lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(), + ]) + + server_request_call_event = self.server_driver.event_with_tag( + server_request_call_tag) + + with server_call_condition: + server_send_initial_metadata_start_batch_result = ( + server_request_call_event.call.start_server_batch([ + cygrpc.SendInitialMetadataOperation( + _common.INITIAL_METADATA, _common.EMPTY_FLAGS), + ], server_send_initial_metadata_tag)) + server_call_driver.add_due({ + server_send_initial_metadata_tag, + }) + server_send_initial_metadata_event = server_call_driver.event_with_tag( + server_send_initial_metadata_tag) + + with server_call_condition: + server_complete_rpc_start_batch_result = ( + server_request_call_event.call.start_server_batch([ + cygrpc.ReceiveCloseOnServerOperation(_common.EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + _common.TRAILING_METADATA, cygrpc.StatusCode.ok, + b'test details', _common.EMPTY_FLAGS), + ], server_complete_rpc_tag)) + server_call_driver.add_due({ + server_complete_rpc_tag, + }) + server_complete_rpc_event = server_call_driver.event_with_tag( + server_complete_rpc_tag) + + client_events = client_events_future.result() + if client_events[0].tag is client_receive_initial_metadata_tag: + client_receive_initial_metadata_event = client_events[0] + client_complete_rpc_event = client_events[1] + else: + client_complete_rpc_event = client_events[0] + client_receive_initial_metadata_event = client_events[1] + + return ( + _common.OperationResult(server_request_call_start_batch_result, + server_request_call_event.completion_type, + server_request_call_event.success), + _common.OperationResult( + cygrpc.CallError.ok, + client_receive_initial_metadata_event.completion_type, + client_receive_initial_metadata_event.success), + _common.OperationResult(cygrpc.CallError.ok, + client_complete_rpc_event.completion_type, + client_complete_rpc_event.success), + _common.OperationResult( + server_send_initial_metadata_start_batch_result, + server_send_initial_metadata_event.completion_type, + server_send_initial_metadata_event.success), + _common.OperationResult(server_complete_rpc_start_batch_result, + server_complete_rpc_event.completion_type, + server_complete_rpc_event.success), + ) + + def test_rpcs(self): + expecteds = [(_common.SUCCESSFUL_OPERATION_RESULT,) * 5 + ] * _common.RPC_COUNT + actuallys = _common.execute_many_times(self._do_rpcs) + self.assertSequenceEqual(expecteds, actuallys) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py new file mode 100644 index 00000000000..38964768db7 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py @@ -0,0 +1,126 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test a corner-case at the level of the Cython API.""" + +import threading +import unittest + +from grpc._cython import cygrpc + +from tests.unit._cython import _common +from tests.unit._cython import test_utilities + + +class Test(_common.RpcTest, unittest.TestCase): + + def _do_rpcs(self): + server_request_call_tag = 'server_request_call_tag' + server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' + server_complete_rpc_tag = 'server_complete_rpc_tag' + + with self.server_condition: + server_request_call_start_batch_result = self.server.request_call( + self.server_completion_queue, self.server_completion_queue, + server_request_call_tag) + self.server_driver.add_due({ + server_request_call_tag, + }) + + client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' + client_complete_rpc_tag = 'client_complete_rpc_tag' + client_call = self.channel.integrated_call( + _common.EMPTY_FLAGS, b'/twinkies', None, None, + _common.INVOCATION_METADATA, None, [ + ( + [ + cygrpc.SendInitialMetadataOperation( + _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation( + _common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation( + _common.EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ), + ]) + client_call.operate([ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag) + + client_events_future = test_utilities.SimpleFuture(lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(), + ]) + server_request_call_event = self.server_driver.event_with_tag( + server_request_call_tag) + + with self.server_condition: + server_send_initial_metadata_start_batch_result = ( + server_request_call_event.call.start_server_batch([ + cygrpc.SendInitialMetadataOperation( + _common.INITIAL_METADATA, _common.EMPTY_FLAGS), + ], server_send_initial_metadata_tag)) + self.server_driver.add_due({ + server_send_initial_metadata_tag, + }) + server_send_initial_metadata_event = self.server_driver.event_with_tag( + server_send_initial_metadata_tag) + + with self.server_condition: + server_complete_rpc_start_batch_result = ( + server_request_call_event.call.start_server_batch([ + cygrpc.ReceiveCloseOnServerOperation(_common.EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + _common.TRAILING_METADATA, cygrpc.StatusCode.ok, + 'test details', _common.EMPTY_FLAGS), + ], server_complete_rpc_tag)) + self.server_driver.add_due({ + server_complete_rpc_tag, + }) + server_complete_rpc_event = self.server_driver.event_with_tag( + server_complete_rpc_tag) + + client_events = client_events_future.result() + client_receive_initial_metadata_event = client_events[0] + client_complete_rpc_event = client_events[1] + + return ( + _common.OperationResult(server_request_call_start_batch_result, + server_request_call_event.completion_type, + server_request_call_event.success), + _common.OperationResult( + cygrpc.CallError.ok, + client_receive_initial_metadata_event.completion_type, + client_receive_initial_metadata_event.success), + _common.OperationResult(cygrpc.CallError.ok, + client_complete_rpc_event.completion_type, + client_complete_rpc_event.success), + _common.OperationResult( + server_send_initial_metadata_start_batch_result, + server_send_initial_metadata_event.completion_type, + server_send_initial_metadata_event.success), + _common.OperationResult(server_complete_rpc_start_batch_result, + server_complete_rpc_event.completion_type, + server_complete_rpc_event.success), + ) + + def test_rpcs(self): + expecteds = [(_common.SUCCESSFUL_OPERATION_RESULT,) * 5 + ] * _common.RPC_COUNT + actuallys = _common.execute_many_times(self._do_rpcs) + self.assertSequenceEqual(expecteds, actuallys) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py new file mode 100644 index 00000000000..8a903bfaf91 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -0,0 +1,240 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test a corner-case at the level of the Cython API.""" + +import threading +import unittest + +from grpc._cython import cygrpc +from tests.unit._cython import test_utilities + +_EMPTY_FLAGS = 0 +_EMPTY_METADATA = () + + +class _ServerDriver(object): + + def __init__(self, completion_queue, shutdown_tag): + self._condition = threading.Condition() + self._completion_queue = completion_queue + self._shutdown_tag = shutdown_tag + self._events = [] + self._saw_shutdown_tag = False + + def start(self): + + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._condition.notify() + if event.tag is self._shutdown_tag: + self._saw_shutdown_tag = True + break + + thread = threading.Thread(target=in_thread) + thread.start() + + def done(self): + with self._condition: + return self._saw_shutdown_tag + + def first_event(self): + with self._condition: + while not self._events: + self._condition.wait() + return self._events[0] + + def events(self): + with self._condition: + while not self._saw_shutdown_tag: + self._condition.wait() + return tuple(self._events) + + +class _QueueDriver(object): + + def __init__(self, condition, completion_queue, due): + self._condition = condition + self._completion_queue = completion_queue + self._due = due + self._events = [] + self._returned = False + + def start(self): + + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._due.remove(event.tag) + self._condition.notify_all() + if not self._due: + self._returned = True + return + + thread = threading.Thread(target=in_thread) + thread.start() + + def done(self): + with self._condition: + return self._returned + + def event_with_tag(self, tag): + with self._condition: + while True: + for event in self._events: + if event.tag is tag: + return event + self._condition.wait() + + def events(self): + with self._condition: + while not self._returned: + self._condition.wait() + return tuple(self._events) + + +class ReadSomeButNotAllResponsesTest(unittest.TestCase): + + def testReadSomeButNotAllResponses(self): + server_completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server([( + b'grpc.so_reuseport', + 0, + )]) + server.register_completion_queue(server_completion_queue) + port = server.add_http2_port(b'[::]:0') + server.start() + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set(), + None) + + server_shutdown_tag = 'server_shutdown_tag' + server_driver = _ServerDriver(server_completion_queue, + server_shutdown_tag) + server_driver.start() + + client_condition = threading.Condition() + client_due = set() + + server_call_condition = threading.Condition() + server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' + server_send_first_message_tag = 'server_send_first_message_tag' + server_send_second_message_tag = 'server_send_second_message_tag' + server_complete_rpc_tag = 'server_complete_rpc_tag' + server_call_due = set(( + server_send_initial_metadata_tag, + server_send_first_message_tag, + server_send_second_message_tag, + server_complete_rpc_tag, + )) + server_call_completion_queue = cygrpc.CompletionQueue() + server_call_driver = _QueueDriver(server_call_condition, + server_call_completion_queue, + server_call_due) + server_call_driver.start() + + server_rpc_tag = 'server_rpc_tag' + request_call_result = server.request_call(server_call_completion_queue, + server_completion_queue, + server_rpc_tag) + + client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' + client_complete_rpc_tag = 'client_complete_rpc_tag' + client_call = channel.segregated_call( + _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, None, ( + ( + [ + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + ], + client_receive_initial_metadata_tag, + ), + ( + [ + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ), + )) + client_receive_initial_metadata_event_future = test_utilities.SimpleFuture( + client_call.next_event) + + server_rpc_event = server_driver.first_event() + + with server_call_condition: + server_send_initial_metadata_start_batch_result = ( + server_rpc_event.call.start_server_batch([ + cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, + _EMPTY_FLAGS), + ], server_send_initial_metadata_tag)) + server_send_first_message_start_batch_result = ( + server_rpc_event.call.start_server_batch([ + cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS), + ], server_send_first_message_tag)) + server_send_initial_metadata_event = server_call_driver.event_with_tag( + server_send_initial_metadata_tag) + server_send_first_message_event = server_call_driver.event_with_tag( + server_send_first_message_tag) + with server_call_condition: + server_send_second_message_start_batch_result = ( + server_rpc_event.call.start_server_batch([ + cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS), + ], server_send_second_message_tag)) + server_complete_rpc_start_batch_result = ( + server_rpc_event.call.start_server_batch([ + cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + (), cygrpc.StatusCode.ok, b'test details', + _EMPTY_FLAGS), + ], server_complete_rpc_tag)) + server_send_second_message_event = server_call_driver.event_with_tag( + server_send_second_message_tag) + server_complete_rpc_event = server_call_driver.event_with_tag( + server_complete_rpc_tag) + server_call_driver.events() + + client_recieve_initial_metadata_event = client_receive_initial_metadata_event_future.result( + ) + + client_receive_first_message_tag = 'client_receive_first_message_tag' + client_call.operate([ + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], client_receive_first_message_tag) + client_receive_first_message_event = client_call.next_event() + + client_call_cancel_result = client_call.cancel( + cygrpc.StatusCode.cancelled, 'Cancelled during test!') + client_complete_rpc_event = client_call.next_event() + + channel.close(cygrpc.StatusCode.unknown, 'Channel closed!') + server.shutdown(server_completion_queue, server_shutdown_tag) + server.cancel_all_calls() + server_driver.events() + + self.assertEqual(cygrpc.CallError.ok, request_call_result) + self.assertEqual(cygrpc.CallError.ok, + server_send_initial_metadata_start_batch_result) + self.assertIs(server_rpc_tag, server_rpc_event.tag) + self.assertEqual(cygrpc.CompletionType.operation_complete, + server_rpc_event.completion_type) + self.assertIsInstance(server_rpc_event.call, cygrpc.Call) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_server_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_server_test.py new file mode 100644 index 00000000000..bbd25457b3e --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/_server_test.py @@ -0,0 +1,49 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test servers at the level of the Cython API.""" + +import threading +import time +import unittest + +from grpc._cython import cygrpc + + +class Test(unittest.TestCase): + + def test_lonely_server(self): + server_call_completion_queue = cygrpc.CompletionQueue() + server_shutdown_completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server(None) + server.register_completion_queue(server_call_completion_queue) + server.register_completion_queue(server_shutdown_completion_queue) + port = server.add_http2_port(b'[::]:0') + server.start() + + server_request_call_tag = 'server_request_call_tag' + server_request_call_start_batch_result = server.request_call( + server_call_completion_queue, server_call_completion_queue, + server_request_call_tag) + + time.sleep(4) + + server_shutdown_tag = 'server_shutdown_tag' + server_shutdown_result = server.shutdown( + server_shutdown_completion_queue, server_shutdown_tag) + server_request_call_event = server_call_completion_queue.poll() + server_shutdown_event = server_shutdown_completion_queue.poll() + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py new file mode 100644 index 00000000000..1182f83a425 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -0,0 +1,422 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import threading +import unittest +import platform + +from grpc._cython import cygrpc +from tests.unit._cython import test_utilities +from tests.unit import test_common +from tests.unit import resources + +_SSL_HOST_OVERRIDE = b'foo.test.google.fr' +_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key' +_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value' +_EMPTY_FLAGS = 0 + + +def _metadata_plugin(context, callback): + callback((( + _CALL_CREDENTIALS_METADATA_KEY, + _CALL_CREDENTIALS_METADATA_VALUE, + ),), cygrpc.StatusCode.ok, b'') + + +class TypeSmokeTest(unittest.TestCase): + + def testCompletionQueueUpDown(self): + completion_queue = cygrpc.CompletionQueue() + del completion_queue + + def testServerUpDown(self): + server = cygrpc.Server(set([ + ( + b'grpc.so_reuseport', + 0, + ), + ])) + del server + + def testChannelUpDown(self): + channel = cygrpc.Channel(b'[::]:0', None, None) + channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!') + + def test_metadata_plugin_call_credentials_up_down(self): + cygrpc.MetadataPluginCallCredentials(_metadata_plugin, + b'test plugin name!') + + def testServerStartNoExplicitShutdown(self): + server = cygrpc.Server([ + ( + b'grpc.so_reuseport', + 0, + ), + ]) + completion_queue = cygrpc.CompletionQueue() + server.register_completion_queue(completion_queue) + port = server.add_http2_port(b'[::]:0') + self.assertIsInstance(port, int) + server.start() + del server + + def testServerStartShutdown(self): + completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server([ + ( + b'grpc.so_reuseport', + 0, + ), + ]) + server.add_http2_port(b'[::]:0') + server.register_completion_queue(completion_queue) + server.start() + shutdown_tag = object() + server.shutdown(completion_queue, shutdown_tag) + event = completion_queue.poll() + self.assertEqual(cygrpc.CompletionType.operation_complete, + event.completion_type) + self.assertIs(shutdown_tag, event.tag) + del server + del completion_queue + + +class ServerClientMixin(object): + + def setUpMixin(self, server_credentials, client_credentials, host_override): + self.server_completion_queue = cygrpc.CompletionQueue() + self.server = cygrpc.Server([ + ( + b'grpc.so_reuseport', + 0, + ), + ]) + self.server.register_completion_queue(self.server_completion_queue) + if server_credentials: + self.port = self.server.add_http2_port(b'[::]:0', + server_credentials) + else: + self.port = self.server.add_http2_port(b'[::]:0') + self.server.start() + self.client_completion_queue = cygrpc.CompletionQueue() + if client_credentials: + client_channel_arguments = (( + cygrpc.ChannelArgKey.ssl_target_name_override, + host_override, + ),) + self.client_channel = cygrpc.Channel( + 'localhost:{}'.format(self.port).encode(), + client_channel_arguments, client_credentials) + else: + self.client_channel = cygrpc.Channel( + 'localhost:{}'.format(self.port).encode(), set(), None) + if host_override: + self.host_argument = None # default host + self.expected_host = host_override + else: + # arbitrary host name necessitating no further identification + self.host_argument = b'hostess' + self.expected_host = self.host_argument + + def tearDownMixin(self): + self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!') + del self.client_channel + del self.server + del self.client_completion_queue + del self.server_completion_queue + + def _perform_queue_operations(self, operations, call, queue, deadline, + description): + """Perform the operations with given call, queue, and deadline. + + Invocation errors are reported with as an exception with `description` + in the message. Performs the operations asynchronously, returning a + future. + """ + + def performer(): + tag = object() + try: + call_result = call.start_client_batch(operations, tag) + self.assertEqual(cygrpc.CallError.ok, call_result) + event = queue.poll(deadline=deadline) + self.assertEqual(cygrpc.CompletionType.operation_complete, + event.completion_type) + self.assertTrue(event.success) + self.assertIs(tag, event.tag) + except Exception as error: + raise Exception("Error in '{}': {}".format( + description, error.message)) + return event + + return test_utilities.SimpleFuture(performer) + + def test_echo(self): + DEADLINE = time.time() + 5 + DEADLINE_TOLERANCE = 0.25 + CLIENT_METADATA_ASCII_KEY = 'key' + CLIENT_METADATA_ASCII_VALUE = 'val' + CLIENT_METADATA_BIN_KEY = 'key-bin' + CLIENT_METADATA_BIN_VALUE = b'\0' * 1000 + SERVER_INITIAL_METADATA_KEY = 'init_me_me_me' + SERVER_INITIAL_METADATA_VALUE = 'whodawha?' + SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought' + SERVER_TRAILING_METADATA_VALUE = 'zomg it is' + SERVER_STATUS_CODE = cygrpc.StatusCode.ok + SERVER_STATUS_DETAILS = 'our work is never over' + REQUEST = b'in death a member of project mayhem has a name' + RESPONSE = b'his name is robert paulson' + METHOD = b'twinkies' + + server_request_tag = object() + request_call_result = self.server.request_call( + self.server_completion_queue, self.server_completion_queue, + server_request_tag) + + self.assertEqual(cygrpc.CallError.ok, request_call_result) + + client_call_tag = object() + client_initial_metadata = ( + ( + CLIENT_METADATA_ASCII_KEY, + CLIENT_METADATA_ASCII_VALUE, + ), + ( + CLIENT_METADATA_BIN_KEY, + CLIENT_METADATA_BIN_VALUE, + ), + ) + client_call = self.client_channel.integrated_call( + 0, METHOD, self.host_argument, DEADLINE, client_initial_metadata, + None, [ + ( + [ + cygrpc.SendInitialMetadataOperation( + client_initial_metadata, _EMPTY_FLAGS), + cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + client_call_tag, + ), + ]) + client_event_future = test_utilities.SimpleFuture( + self.client_channel.next_call_event) + + request_event = self.server_completion_queue.poll(deadline=DEADLINE) + self.assertEqual(cygrpc.CompletionType.operation_complete, + request_event.completion_type) + self.assertIsInstance(request_event.call, cygrpc.Call) + self.assertIs(server_request_tag, request_event.tag) + self.assertTrue( + test_common.metadata_transmitted(client_initial_metadata, + request_event.invocation_metadata)) + self.assertEqual(METHOD, request_event.call_details.method) + self.assertEqual(self.expected_host, request_event.call_details.host) + self.assertLess(abs(DEADLINE - request_event.call_details.deadline), + DEADLINE_TOLERANCE) + + server_call_tag = object() + server_call = request_event.call + server_initial_metadata = (( + SERVER_INITIAL_METADATA_KEY, + SERVER_INITIAL_METADATA_VALUE, + ),) + server_trailing_metadata = (( + SERVER_TRAILING_METADATA_KEY, + SERVER_TRAILING_METADATA_VALUE, + ),) + server_start_batch_result = server_call.start_server_batch([ + cygrpc.SendInitialMetadataOperation(server_initial_metadata, + _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS), + cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + server_trailing_metadata, SERVER_STATUS_CODE, + SERVER_STATUS_DETAILS, _EMPTY_FLAGS) + ], server_call_tag) + self.assertEqual(cygrpc.CallError.ok, server_start_batch_result) + + server_event = self.server_completion_queue.poll(deadline=DEADLINE) + client_event = client_event_future.result() + + self.assertEqual(6, len(client_event.batch_operations)) + found_client_op_types = set() + for client_result in client_event.batch_operations: + # we expect each op type to be unique + self.assertNotIn(client_result.type(), found_client_op_types) + found_client_op_types.add(client_result.type()) + if client_result.type( + ) == cygrpc.OperationType.receive_initial_metadata: + self.assertTrue( + test_common.metadata_transmitted( + server_initial_metadata, + client_result.initial_metadata())) + elif client_result.type() == cygrpc.OperationType.receive_message: + self.assertEqual(RESPONSE, client_result.message()) + elif client_result.type( + ) == cygrpc.OperationType.receive_status_on_client: + self.assertTrue( + test_common.metadata_transmitted( + server_trailing_metadata, + client_result.trailing_metadata())) + self.assertEqual(SERVER_STATUS_DETAILS, client_result.details()) + self.assertEqual(SERVER_STATUS_CODE, client_result.code()) + self.assertEqual( + set([ + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.send_message, + cygrpc.OperationType.send_close_from_client, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.receive_status_on_client + ]), found_client_op_types) + + self.assertEqual(5, len(server_event.batch_operations)) + found_server_op_types = set() + for server_result in server_event.batch_operations: + self.assertNotIn(server_result.type(), found_server_op_types) + found_server_op_types.add(server_result.type()) + if server_result.type() == cygrpc.OperationType.receive_message: + self.assertEqual(REQUEST, server_result.message()) + elif server_result.type( + ) == cygrpc.OperationType.receive_close_on_server: + self.assertFalse(server_result.cancelled()) + self.assertEqual( + set([ + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.send_message, + cygrpc.OperationType.receive_close_on_server, + cygrpc.OperationType.send_status_from_server + ]), found_server_op_types) + + del client_call + del server_call + + def test_6522(self): + DEADLINE = time.time() + 5 + DEADLINE_TOLERANCE = 0.25 + METHOD = b'twinkies' + + empty_metadata = () + + # Prologue + server_request_tag = object() + self.server.request_call(self.server_completion_queue, + self.server_completion_queue, + server_request_tag) + client_call = self.client_channel.segregated_call( + 0, METHOD, self.host_argument, DEADLINE, None, None, + ([( + [ + cygrpc.SendInitialMetadataOperation(empty_metadata, + _EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + ], + object(), + ), + ( + [ + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + object(), + )])) + + client_initial_metadata_event_future = test_utilities.SimpleFuture( + client_call.next_event) + + request_event = self.server_completion_queue.poll(deadline=DEADLINE) + server_call = request_event.call + + def perform_server_operations(operations, description): + return self._perform_queue_operations(operations, server_call, + self.server_completion_queue, + DEADLINE, description) + + server_event_future = perform_server_operations([ + cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS), + ], "Server prologue") + + client_initial_metadata_event_future.result() # force completion + server_event_future.result() + + # Messaging + for _ in range(10): + client_call.operate([ + cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], "Client message") + client_message_event_future = test_utilities.SimpleFuture( + client_call.next_event) + server_event_future = perform_server_operations([ + cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], "Server receive") + + client_message_event_future.result() # force completion + server_event_future.result() + + # Epilogue + client_call.operate([ + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + ], "Client epilogue") + # One for ReceiveStatusOnClient, one for SendCloseFromClient. + client_events_future = test_utilities.SimpleFuture(lambda: { + client_call.next_event(), + client_call.next_event(), + }) + + server_event_future = perform_server_operations([ + cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS) + ], "Server epilogue") + + client_events_future.result() # force completion + server_event_future.result() + + +class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin): + + def setUp(self): + self.setUpMixin(None, None, None) + + def tearDown(self): + self.tearDownMixin() + + +class SecureServerSecureClient(unittest.TestCase, ServerClientMixin): + + def setUp(self): + server_credentials = cygrpc.server_credentials_ssl( + None, [ + cygrpc.SslPemKeyCertPair(resources.private_key(), + resources.certificate_chain()) + ], False) + client_credentials = cygrpc.SSLChannelCredentials( + resources.test_root_certificates(), None, None) + self.setUpMixin(server_credentials, client_credentials, + _SSL_HOST_OVERRIDE) + + def tearDown(self): + self.tearDownMixin() + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py new file mode 100644 index 00000000000..7d5eaaaa842 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py @@ -0,0 +1,52 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading + +from grpc._cython import cygrpc + + +class SimpleFuture(object): + """A simple future mechanism.""" + + def __init__(self, function, *args, **kwargs): + + def wrapped_function(): + try: + self._result = function(*args, **kwargs) + except Exception as error: # pylint: disable=broad-except + self._error = error + + self._result = None + self._error = None + self._thread = threading.Thread(target=wrapped_function) + self._thread.start() + + def result(self): + """The resulting value of this future. + + Re-raises any exceptions. + """ + self._thread.join() + if self._error: + # TODO(atash): re-raise exceptions in a way that preserves tracebacks + raise self._error # pylint: disable=raising-bad-type + return self._result + + +class CompletionQueuePollFuture(SimpleFuture): + + def __init__(self, completion_queue, deadline): + super(CompletionQueuePollFuture, + self).__init__(lambda: completion_queue.poll(deadline=deadline)) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py new file mode 100644 index 00000000000..43141255f1c --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py @@ -0,0 +1,63 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for an actual dns resolution.""" + +import unittest +import logging +import six + +import grpc +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_METHOD = '/ANY/METHOD' +_REQUEST = b'\x00\x00\x00' +_RESPONSE = _REQUEST + + +class GenericHandler(grpc.GenericRpcHandler): + + def service(self, unused_handler_details): + return grpc.unary_unary_rpc_method_handler( + lambda request, unused_context: request, + ) + + +class DNSResolverTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers((GenericHandler(),)) + self._port = self._server.add_insecure_port('[::]:0') + self._server.start() + + def tearDown(self): + self._server.stop(None) + + def test_connect_loopback(self): + # NOTE(https://github.com/grpc/grpc/issues/18422) + # In short, Gevent + C-Ares = Segfault. The C-Ares driver is not + # supported by custom io manager like "gevent" or "libuv". + with grpc.insecure_channel('loopback4.unittest.grpc.io:%d' % + self._port) as channel: + self.assertEqual( + channel.unary_unary(_METHOD)( + _REQUEST, + timeout=test_constants.SHORT_TIMEOUT, + ), _RESPONSE) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_dynamic_stubs_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_dynamic_stubs_test.py new file mode 100644 index 00000000000..d2d8ce9f60b --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_dynamic_stubs_test.py @@ -0,0 +1,119 @@ +# Copyright 2019 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of dynamic stub import API.""" + +import contextlib +import functools +import logging +import multiprocessing +import os +import sys +import unittest + + +def _grpc_tools_unimportable(): + original_sys_path = sys.path + sys.path = [path for path in sys.path if "grpcio_tools" not in path] + try: + import grpc_tools + except ImportError: + pass + else: + del grpc_tools + sys.path = original_sys_path + raise unittest.SkipTest("Failed to make grpc_tools unimportable.") + try: + yield + finally: + sys.path = original_sys_path + + +def _collect_errors(fn): + + @functools.wraps(fn) + def _wrapped(error_queue): + try: + fn() + except Exception as e: + error_queue.put(e) + raise + + return _wrapped + + +def _run_in_subprocess(test_case): + sys.path.insert( + 0, os.path.join(os.path.realpath(os.path.dirname(__file__)), "..")) + error_queue = multiprocessing.Queue() + proc = multiprocessing.Process(target=test_case, args=(error_queue,)) + proc.start() + proc.join() + sys.path.pop(0) + if not error_queue.empty(): + raise error_queue.get() + assert proc.exitcode == 0, "Process exited with code {}".format( + proc.exitcode) + + +def _assert_unimplemented(msg_substr): + import grpc + try: + protos, services = grpc.protos_and_services( + "tests/unit/data/foo/bar.proto") + except NotImplementedError as e: + assert msg_substr in str(e), "{} was not in '{}'".format( + msg_substr, str(e)) + else: + assert False, "Did not raise NotImplementedError" + + +@_collect_errors +def _test_sunny_day(): + if sys.version_info[0] == 3: + import grpc + protos, services = grpc.protos_and_services( + os.path.join("tests", "unit", "data", "foo", "bar.proto")) + assert protos.BarMessage is not None + assert services.BarStub is not None + else: + _assert_unimplemented("Python 3") + + +@_collect_errors +def _test_grpc_tools_unimportable(): + with _grpc_tools_unimportable(): + if sys.version_info[0] == 3: + _assert_unimplemented("grpcio-tools") + else: + _assert_unimplemented("Python 3") + + +# NOTE(rbellevi): multiprocessing.Process fails to pickle function objects +# when they do not come from the "__main__" module, so this test passes +# if run directly on Windows, but not if started by the test runner. [email protected](os.name == "nt", "Windows multiprocessing unsupported") +class DynamicStubTest(unittest.TestCase): + + @unittest.skip('grpcio-tools package required') + def test_sunny_day(self): + _run_in_subprocess(_test_sunny_day) + + def test_grpc_tools_unimportable(self): + _run_in_subprocess(_test_grpc_tools_unimportable) + + +if __name__ == "__main__": + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_empty_message_test.py new file mode 100644 index 00000000000..f27ea422d0c --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_empty_message_test.py @@ -0,0 +1,124 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import logging + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_REQUEST = b'' +_RESPONSE = b'' + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + + +def handle_unary_unary(request, servicer_context): + return _RESPONSE + + +def handle_unary_stream(request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + yield _RESPONSE + + +def handle_stream_unary(request_iterator, servicer_context): + for request in request_iterator: + pass + return _RESPONSE + + +def handle_stream_stream(request_iterator, servicer_context): + for request in request_iterator: + yield _RESPONSE + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + self.stream_stream = handle_stream_stream + elif self.request_streaming: + self.stream_unary = handle_stream_unary + elif self.response_streaming: + self.unary_stream = handle_unary_stream + else: + self.unary_unary = handle_unary_unary + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True) + else: + return None + + +class EmptyMessageTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers((_GenericHandler(),)) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(0) + self._channel.close() + + def testUnaryUnary(self): + response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST) + self.assertEqual(_RESPONSE, response) + + def testUnaryStream(self): + response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST) + self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH, + list(response_iterator)) + + def testStreamUnary(self): + response = self._channel.stream_unary(_STREAM_UNARY)(iter( + [_REQUEST] * test_constants.STREAM_LENGTH)) + self.assertEqual(_RESPONSE, response) + + def testStreamStream(self): + response_iterator = self._channel.stream_stream(_STREAM_STREAM)(iter( + [_REQUEST] * test_constants.STREAM_LENGTH)) + self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH, + list(response_iterator)) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py new file mode 100644 index 00000000000..e58007ad3ed --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py @@ -0,0 +1,87 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests 'utf-8' encoded error message.""" + +import unittest +import weakref + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_UNICODE_ERROR_MESSAGES = [ + b'\xe2\x80\x9d'.decode('utf-8'), + b'abc\x80\xd0\xaf'.decode('latin-1'), + b'\xc3\xa9'.decode('utf-8'), +] + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming=None, response_streaming=None): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + + def unary_unary(self, request, servicer_context): + servicer_context.set_code(grpc.StatusCode.UNKNOWN) + servicer_context.set_details(request.decode('utf-8')) + return _RESPONSE + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, test): + self._test = test + + def service(self, handler_call_details): + return _MethodHandler() + + +class ErrorMessageEncodingTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers( + (_GenericHandler(weakref.proxy(self)),)) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(0) + self._channel.close() + + def testMessageEncoding(self): + for message in _UNICODE_ERROR_MESSAGES: + multi_callable = self._channel.unary_unary(_UNARY_UNARY) + with self.assertRaises(grpc.RpcError) as cm: + multi_callable(message.encode('utf-8')) + + self.assertEqual(cm.exception.code(), grpc.StatusCode.UNKNOWN) + self.assertEqual(cm.exception.details(), message) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_exit_scenarios.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_exit_scenarios.py new file mode 100644 index 00000000000..48ea054d2d7 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_exit_scenarios.py @@ -0,0 +1,236 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines a number of module-scope gRPC scenarios to test clean exit.""" + +import argparse +import threading +import time +import logging + +import grpc + +from tests.unit.framework.common import test_constants + +WAIT_TIME = 1000 + +REQUEST = b'request' + +UNSTARTED_SERVER = 'unstarted_server' +RUNNING_SERVER = 'running_server' +POLL_CONNECTIVITY_NO_SERVER = 'poll_connectivity_no_server' +POLL_CONNECTIVITY = 'poll_connectivity' +IN_FLIGHT_UNARY_UNARY_CALL = 'in_flight_unary_unary_call' +IN_FLIGHT_UNARY_STREAM_CALL = 'in_flight_unary_stream_call' +IN_FLIGHT_STREAM_UNARY_CALL = 'in_flight_stream_unary_call' +IN_FLIGHT_STREAM_STREAM_CALL = 'in_flight_stream_stream_call' +IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = 'in_flight_partial_unary_stream_call' +IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = 'in_flight_partial_stream_unary_call' +IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = 'in_flight_partial_stream_stream_call' + +UNARY_UNARY = b'/test/UnaryUnary' +UNARY_STREAM = b'/test/UnaryStream' +STREAM_UNARY = b'/test/StreamUnary' +STREAM_STREAM = b'/test/StreamStream' +PARTIAL_UNARY_STREAM = b'/test/PartialUnaryStream' +PARTIAL_STREAM_UNARY = b'/test/PartialStreamUnary' +PARTIAL_STREAM_STREAM = b'/test/PartialStreamStream' + +TEST_TO_METHOD = { + IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY, + IN_FLIGHT_UNARY_STREAM_CALL: UNARY_STREAM, + IN_FLIGHT_STREAM_UNARY_CALL: STREAM_UNARY, + IN_FLIGHT_STREAM_STREAM_CALL: STREAM_STREAM, + IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL: PARTIAL_UNARY_STREAM, + IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL: PARTIAL_STREAM_UNARY, + IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL: PARTIAL_STREAM_STREAM, +} + + +def hang_unary_unary(request, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_unary_stream(request, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_unary_stream(request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + yield request + time.sleep(WAIT_TIME) + + +def hang_stream_unary(request_iterator, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_stream_unary(request_iterator, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + next(request_iterator) + time.sleep(WAIT_TIME) + + +def hang_stream_stream(request_iterator, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_stream_stream(request_iterator, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + yield next(request_iterator) #pylint: disable=stop-iteration-return + time.sleep(WAIT_TIME) + + +class MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, partial_hang): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + if partial_hang: + self.stream_stream = hang_partial_stream_stream + else: + self.stream_stream = hang_stream_stream + elif self.request_streaming: + if partial_hang: + self.stream_unary = hang_partial_stream_unary + else: + self.stream_unary = hang_stream_unary + elif self.response_streaming: + if partial_hang: + self.unary_stream = hang_partial_unary_stream + else: + self.unary_stream = hang_unary_stream + else: + self.unary_unary = hang_unary_unary + + +class GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == UNARY_UNARY: + return MethodHandler(False, False, False) + elif handler_call_details.method == UNARY_STREAM: + return MethodHandler(False, True, False) + elif handler_call_details.method == STREAM_UNARY: + return MethodHandler(True, False, False) + elif handler_call_details.method == STREAM_STREAM: + return MethodHandler(True, True, False) + elif handler_call_details.method == PARTIAL_UNARY_STREAM: + return MethodHandler(False, True, True) + elif handler_call_details.method == PARTIAL_STREAM_UNARY: + return MethodHandler(True, False, True) + elif handler_call_details.method == PARTIAL_STREAM_STREAM: + return MethodHandler(True, True, True) + else: + return None + + +# Traditional executors will not exit until all their +# current jobs complete. Because we submit jobs that will +# never finish, we don't want to block exit on these jobs. +class DaemonPool(object): + + def submit(self, fn, *args, **kwargs): + thread = threading.Thread(target=fn, args=args, kwargs=kwargs) + thread.daemon = True + thread.start() + + def shutdown(self, wait=True): + pass + + +def infinite_request_iterator(): + while True: + yield REQUEST + + +if __name__ == '__main__': + logging.basicConfig() + parser = argparse.ArgumentParser() + parser.add_argument('scenario', type=str) + parser.add_argument('--wait_for_interrupt', + dest='wait_for_interrupt', + action='store_true') + args = parser.parse_args() + + if args.scenario == UNSTARTED_SERVER: + server = grpc.server(DaemonPool(), options=(('grpc.so_reuseport', 0),)) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == RUNNING_SERVER: + server = grpc.server(DaemonPool(), options=(('grpc.so_reuseport', 0),)) + port = server.add_insecure_port('[::]:0') + server.start() + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == POLL_CONNECTIVITY_NO_SERVER: + channel = grpc.insecure_channel('localhost:12345') + + def connectivity_callback(connectivity): + pass + + channel.subscribe(connectivity_callback, try_to_connect=True) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == POLL_CONNECTIVITY: + server = grpc.server(DaemonPool(), options=(('grpc.so_reuseport', 0),)) + port = server.add_insecure_port('[::]:0') + server.start() + channel = grpc.insecure_channel('localhost:%d' % port) + + def connectivity_callback(connectivity): + pass + + channel.subscribe(connectivity_callback, try_to_connect=True) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + + else: + handler = GenericHandler() + server = grpc.server(DaemonPool(), options=(('grpc.so_reuseport', 0),)) + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((handler,)) + server.start() + channel = grpc.insecure_channel('localhost:%d' % port) + + method = TEST_TO_METHOD[args.scenario] + + if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL: + multi_callable = channel.unary_unary(method) + future = multi_callable.future(REQUEST) + result, call = multi_callable.with_call(REQUEST) + elif (args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or + args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL): + multi_callable = channel.unary_stream(method) + response_iterator = multi_callable(REQUEST) + for response in response_iterator: + pass + elif (args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or + args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL): + multi_callable = channel.stream_unary(method) + future = multi_callable.future(infinite_request_iterator()) + result, call = multi_callable.with_call( + iter([REQUEST] * test_constants.STREAM_LENGTH)) + elif (args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or + args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL): + multi_callable = channel.stream_stream(method) + response_iterator = multi_callable(infinite_request_iterator()) + for response in response_iterator: + pass diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_exit_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_exit_test.py new file mode 100644 index 00000000000..4cf5ab63bdf --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_exit_test.py @@ -0,0 +1,261 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests clean exit of server/client on Python Interpreter exit/sigint. + +The tests in this module spawn a subprocess for each test case, the +test is considered successful if it doesn't hang/timeout. +""" + +import atexit +import os +import signal +import six +import subprocess +import sys +import threading +import datetime +import time +import unittest +import logging + +from tests.unit import _exit_scenarios + +# SCENARIO_FILE = os.path.abspath( +# os.path.join(os.path.dirname(os.path.realpath(__file__)), +# '_exit_scenarios.py')) +INTERPRETER = sys.executable +BASE_COMMAND = [INTERPRETER, '-m', 'tests.unit._exit_scenarios'] +BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt'] + +INIT_TIME = datetime.timedelta(seconds=1) +WAIT_CHECK_INTERVAL = datetime.timedelta(milliseconds=100) +WAIT_CHECK_DEFAULT_TIMEOUT = datetime.timedelta(seconds=5) + +processes = [] +process_lock = threading.Lock() + + +# Make sure we attempt to clean up any +# processes we may have left running +def cleanup_processes(): + with process_lock: + for process in processes: + try: + process.kill() + except Exception: # pylint: disable=broad-except + pass + + +atexit.register(cleanup_processes) + + +def _process_wait_with_timeout(process, timeout=WAIT_CHECK_DEFAULT_TIMEOUT): + """A funciton to mimic 3.3+ only timeout argument in process.wait.""" + deadline = datetime.datetime.now() + timeout + while (process.poll() is None) and (datetime.datetime.now() < deadline): + time.sleep(WAIT_CHECK_INTERVAL.total_seconds()) + if process.returncode is None: + raise RuntimeError('Process failed to exit within %s' % timeout) + + +def interrupt_and_wait(process): + with process_lock: + processes.append(process) + time.sleep(INIT_TIME.total_seconds()) + os.kill(process.pid, signal.SIGINT) + _process_wait_with_timeout(process) + + +def wait(process): + with process_lock: + processes.append(process) + _process_wait_with_timeout(process) + + +# TODO(lidiz) enable exit tests once the root cause found. [email protected]('https://github.com/grpc/grpc/issues/23982') [email protected]('https://github.com/grpc/grpc/issues/23028') +class ExitTest(unittest.TestCase): + + def test_unstarted_server(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen(BASE_COMMAND + + [_exit_scenarios.UNSTARTED_SERVER], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + wait(process) + + def test_unstarted_server_terminate(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen(BASE_SIGTERM_COMMAND + + [_exit_scenarios.UNSTARTED_SERVER], + stdout=sys.stdout, + env=env) + interrupt_and_wait(process) + + def test_running_server(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen(BASE_COMMAND + + [_exit_scenarios.RUNNING_SERVER], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + wait(process) + + def test_running_server_terminate(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen(BASE_SIGTERM_COMMAND + + [_exit_scenarios.RUNNING_SERVER], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + def test_poll_connectivity_no_server(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + wait(process) + + def test_poll_connectivity_no_server_terminate(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + def test_poll_connectivity(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen(BASE_COMMAND + + [_exit_scenarios.POLL_CONNECTIVITY], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + wait(process) + + def test_poll_connectivity_terminate(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen(BASE_SIGTERM_COMMAND + + [_exit_scenarios.POLL_CONNECTIVITY], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') + def test_in_flight_unary_unary_call(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen(BASE_COMMAND + + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') + def test_in_flight_unary_stream_call(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') + def test_in_flight_stream_unary_call(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') + def test_in_flight_stream_stream_call(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') + def test_in_flight_partial_unary_stream_call(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + + [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') + def test_in_flight_partial_stream_unary_call(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') + def test_in_flight_partial_stream_stream_call(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + interrupt_and_wait(process) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py new file mode 100644 index 00000000000..1ada25382de --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py @@ -0,0 +1,23 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_BEFORE_IMPORT = tuple(globals()) + +from grpc import * # pylint: disable=wildcard-import,unused-wildcard-import + +_AFTER_IMPORT = tuple(globals()) + +GRPC_ELEMENTS = tuple( + element for element in _AFTER_IMPORT + if element not in _BEFORE_IMPORT and element != '_BEFORE_IMPORT') diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_grpc_shutdown_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_grpc_shutdown_test.py new file mode 100644 index 00000000000..1c4890b97f1 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_grpc_shutdown_test.py @@ -0,0 +1,54 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests the gRPC Core shutdown path.""" + +import time +import threading +import unittest +import datetime + +import grpc + +_TIMEOUT_FOR_SEGFAULT = datetime.timedelta(seconds=10) + + +class GrpcShutdownTest(unittest.TestCase): + + def test_channel_close_with_connectivity_watcher(self): + """Originated by https://github.com/grpc/grpc/issues/20299. + + The grpc_shutdown happens synchronously, but there might be Core object + references left in Cython which might lead to ABORT or SIGSEGV. + """ + connection_failed = threading.Event() + + def on_state_change(state): + if state in (grpc.ChannelConnectivity.TRANSIENT_FAILURE, + grpc.ChannelConnectivity.SHUTDOWN): + connection_failed.set() + + # Connects to an void address, and subscribes state changes + channel = grpc.insecure_channel("0.1.1.1:12345") + channel.subscribe(on_state_change, True) + + deadline = datetime.datetime.now() + _TIMEOUT_FOR_SEGFAULT + + while datetime.datetime.now() < deadline: + time.sleep(0.1) + if connection_failed.is_set(): + channel.close() + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_interceptor_test.py new file mode 100644 index 00000000000..619db7b3ffd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -0,0 +1,708 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of gRPC Python interceptors.""" + +import collections +import itertools +import threading +import unittest +import logging +import os +from concurrent import futures + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit import test_common +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] + +_EXCEPTION_REQUEST = b'\x09\x0a' + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + + +class _ApplicationErrorStandin(Exception): + pass + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + self._called = False + + def __call__(self, value): + with self._condition: + self._value = value + self._called = True + self._condition.notify_all() + + def value(self): + with self._condition: + while not self._called: + self._condition.wait() + return self._value + + +class _Handler(object): + + def __init__(self, control): + self._control = control + + def handle_unary_unary(self, request, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + if request == _EXCEPTION_REQUEST: + raise _ApplicationErrorStandin() + return request + + def handle_unary_stream(self, request, servicer_context): + if request == _EXCEPTION_REQUEST: + raise _ApplicationErrorStandin() + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + + def handle_stream_unary(self, request_iterator, servicer_context): + if servicer_context is not None: + servicer_context.invocation_metadata() + self._control.control() + response_elements = [] + for request in request_iterator: + self._control.control() + response_elements.append(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + if _EXCEPTION_REQUEST in response_elements: + raise _ApplicationErrorStandin() + return b''.join(response_elements) + + def handle_stream_stream(self, request_iterator, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + for request in request_iterator: + if request == _EXCEPTION_REQUEST: + raise _ApplicationErrorStandin() + self._control.control() + yield request + self._control.control() + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, + request_deserializer, response_serializer, unary_unary, + unary_stream, stream_unary, stream_stream): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = request_deserializer + self.response_serializer = response_serializer + self.unary_unary = unary_unary + self.unary_stream = unary_stream + self.stream_unary = stream_unary + self.stream_stream = stream_stream + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, handler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False, None, None, + self._handler.handle_unary_unary, None, None, + None) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, None, + self._handler.handle_stream_unary, None) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True, None, None, None, None, None, + self._handler.handle_stream_stream) + else: + return None + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream(_UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary(_STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +class _ClientCallDetails( + collections.namedtuple( + '_ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials')), + grpc.ClientCallDetails): + pass + + +class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor): + + def __init__(self, interceptor_function): + self._fn = interceptor_function + + def intercept_unary_unary(self, continuation, client_call_details, request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_unary_stream(self, continuation, client_call_details, + request): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, iter((request,)), False, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, False) + response = continuation(new_details, next(new_request_iterator)) + return postprocess(response) if postprocess else response + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + new_details, new_request_iterator, postprocess = self._fn( + client_call_details, request_iterator, True, True) + response_it = continuation(new_details, new_request_iterator) + return postprocess(response_it) if postprocess else response_it + + +class _LoggingInterceptor(grpc.ServerInterceptor, + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor): + + def __init__(self, tag, record): + self.tag = tag + self.record = record + + def intercept_service(self, continuation, handler_call_details): + self.record.append(self.tag + ':intercept_service') + return continuation(handler_call_details) + + def intercept_unary_unary(self, continuation, client_call_details, request): + self.record.append(self.tag + ':intercept_unary_unary') + result = continuation(client_call_details, request) + assert isinstance( + result, + grpc.Call), '{} ({}) is not an instance of grpc.Call'.format( + result, type(result)) + assert isinstance( + result, + grpc.Future), '{} ({}) is not an instance of grpc.Future'.format( + result, type(result)) + return result + + def intercept_unary_stream(self, continuation, client_call_details, + request): + self.record.append(self.tag + ':intercept_unary_stream') + return continuation(client_call_details, request) + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + self.record.append(self.tag + ':intercept_stream_unary') + result = continuation(client_call_details, request_iterator) + assert isinstance( + result, + grpc.Call), '{} is not an instance of grpc.Call'.format(result) + assert isinstance( + result, + grpc.Future), '{} is not an instance of grpc.Future'.format(result) + return result + + def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + self.record.append(self.tag + ':intercept_stream_stream') + return continuation(client_call_details, request_iterator) + + +class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor): + + def intercept_unary_unary(self, ignored_continuation, + ignored_client_call_details, ignored_request): + raise test_control.Defect() + + +def _wrap_request_iterator_stream_interceptor(wrapper): + + def intercept_call(client_call_details, request_iterator, request_streaming, + ignored_response_streaming): + if request_streaming: + return client_call_details, wrapper(request_iterator), None + else: + return client_call_details, request_iterator, None + + return _GenericClientInterceptor(intercept_call) + + +def _append_request_header_interceptor(header, value): + + def intercept_call(client_call_details, request_iterator, + ignored_request_streaming, ignored_response_streaming): + metadata = [] + if client_call_details.metadata: + metadata = list(client_call_details.metadata) + metadata.append(( + header, + value, + )) + client_call_details = _ClientCallDetails( + client_call_details.method, client_call_details.timeout, metadata, + client_call_details.credentials) + return client_call_details, request_iterator, None + + return _GenericClientInterceptor(intercept_call) + + +class _GenericServerInterceptor(grpc.ServerInterceptor): + + def __init__(self, fn): + self._fn = fn + + def intercept_service(self, continuation, handler_call_details): + return self._fn(continuation, handler_call_details) + + +def _filter_server_interceptor(condition, interceptor): + + def intercept_service(continuation, handler_call_details): + if condition(handler_call_details): + return interceptor.intercept_service(continuation, + handler_call_details) + return continuation(handler_call_details) + + return _GenericServerInterceptor(intercept_service) + + +class InterceptorTest(unittest.TestCase): + + def setUp(self): + self._control = test_control.PauseFailControl() + self._handler = _Handler(self._control) + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + + self._record = [] + conditional_interceptor = _filter_server_interceptor( + lambda x: ('secret', '42') in x.invocation_metadata, + _LoggingInterceptor('s3', self._record)) + + self._server = grpc.server(self._server_pool, + options=(('grpc.so_reuseport', 0),), + interceptors=( + _LoggingInterceptor('s1', self._record), + conditional_interceptor, + _LoggingInterceptor('s2', self._record), + )) + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(None) + self._server_pool.shutdown(wait=True) + self._channel.close() + + def testTripleRequestMessagesClientInterceptor(self): + + def triple(request_iterator): + while True: + try: + item = next(request_iterator) + yield item + yield item + yield item + except StopIteration: + break + + interceptor = _wrap_request_iterator_stream_interceptor(triple) + channel = grpc.intercept_channel(self._channel, interceptor) + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + iter(requests), + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + responses = tuple(response_iterator) + self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + iter(requests), + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + responses = tuple(response_iterator) + self.assertEqual(len(responses), test_constants.STREAM_LENGTH) + + def testDefectiveClientInterceptor(self): + interceptor = _DefectiveClientInterceptor() + defective_channel = grpc.intercept_channel(self._channel, interceptor) + + request = b'\x07\x08' + + multi_callable = _unary_unary_multi_callable(defective_channel) + call_future = multi_callable.future( + request, + metadata=(('test', + 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + + self.assertIsNotNone(call_future.exception()) + self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL) + + def testInterceptedHeaderManipulationWithServerSideVerification(self): + request = b'\x07\x08' + + channel = grpc.intercept_channel( + self._channel, _append_request_header_interceptor('secret', '42')) + channel = grpc.intercept_channel( + channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + self._record[:] = [] + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable.with_call( + request, + metadata=( + ('test', + 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's3:intercept_service', + 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + + self._record[:] = [] + + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable( + request, + metadata=(('test', + 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestBlockingUnaryResponseWithError(self): + request = _EXCEPTION_REQUEST + + self._record[:] = [] + + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request, + metadata=(('test', + 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + exception = exception_context.exception + self.assertFalse(exception.cancelled()) + self.assertFalse(exception.running()) + self.assertTrue(exception.done()) + with self.assertRaises(grpc.RpcError): + exception.result() + self.assertIsInstance(exception.exception(), grpc.RpcError) + + def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + self._record[:] = [] + + multi_callable = _unary_unary_multi_callable(channel) + multi_callable.with_call( + request, + metadata=( + ('test', + 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_unary_multi_callable(channel) + response_future = multi_callable.future( + request, + metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),)) + response_future.result() + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestStreamResponse(self): + request = b'\x37\x58' + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_stream_multi_callable(channel) + response_iterator = multi_callable( + request, + metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_unary_stream', 'c2:intercept_unary_stream', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedUnaryRequestStreamResponseWithError(self): + request = _EXCEPTION_REQUEST + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _unary_stream_multi_callable(channel) + response_iterator = multi_callable( + request, + metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) + with self.assertRaises(grpc.RpcError) as exception_context: + tuple(response_iterator) + exception = exception_context.exception + self.assertFalse(exception.cancelled()) + self.assertFalse(exception.running()) + self.assertTrue(exception.done()) + with self.assertRaises(grpc.RpcError): + exception.result() + self.assertIsInstance(exception.exception(), grpc.RpcError) + + def testInterceptedStreamRequestBlockingUnaryResponse(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + multi_callable( + request_iterator, + metadata=(('test', + 'InterceptedStreamRequestBlockingUnaryResponse'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + multi_callable.with_call( + request_iterator, + metadata=( + ('test', + 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestFutureUnaryResponse(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) + response_future.result() + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestFutureUnaryResponseWithError(self): + requests = tuple( + _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_unary_multi_callable(channel) + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + exception = exception_context.exception + self.assertFalse(exception.cancelled()) + self.assertFalse(exception.running()) + self.assertTrue(exception.done()) + with self.assertRaises(grpc.RpcError): + exception.result() + self.assertIsInstance(exception.exception(), grpc.RpcError) + + def testInterceptedStreamRequestStreamResponse(self): + requests = tuple( + b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertSequenceEqual(self._record, [ + 'c1:intercept_stream_stream', 'c2:intercept_stream_stream', + 's1:intercept_service', 's2:intercept_service' + ]) + + def testInterceptedStreamRequestStreamResponseWithError(self): + requests = tuple( + _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + self._record[:] = [] + channel = grpc.intercept_channel( + self._channel, _LoggingInterceptor('c1', self._record), + _LoggingInterceptor('c2', self._record)) + + multi_callable = _stream_stream_multi_callable(channel) + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) + with self.assertRaises(grpc.RpcError) as exception_context: + tuple(response_iterator) + exception = exception_context.exception + self.assertFalse(exception.cancelled()) + self.assertFalse(exception.running()) + self.assertTrue(exception.done()) + with self.assertRaises(grpc.RpcError): + exception.result() + self.assertIsInstance(exception.exception(), grpc.RpcError) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py new file mode 100644 index 00000000000..d1f1499d8cd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -0,0 +1,140 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of RPCs made against gRPC Python's application-layer API.""" + +import unittest +import logging + +import grpc + +from tests.unit.framework.common import test_constants + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream(_UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary(_STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +class InvalidMetadataTest(unittest.TestCase): + + def setUp(self): + self._channel = grpc.insecure_channel('localhost:8080') + self._unary_unary = _unary_unary_multi_callable(self._channel) + self._unary_stream = _unary_stream_multi_callable(self._channel) + self._stream_unary = _stream_unary_multi_callable(self._channel) + self._stream_stream = _stream_stream_multi_callable(self._channel) + + def tearDown(self): + self._channel.close() + + def testUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._unary_unary(request, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) + + def testUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponseWithCall'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._unary_unary.with_call(request, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) + + def testUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._unary_unary.future(request, metadata=metadata) + + def testUnaryRequestStreamResponse(self): + request = b'\x37\x58' + metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._unary_stream(request, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) + + def testStreamRequestBlockingUnaryResponse(self): + request_iterator = ( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._stream_unary(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) + + def testStreamRequestBlockingUnaryResponseWithCall(self): + request_iterator = ( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),) + expected_error_details = "metadata was invalid: %s" % metadata + multi_callable = _stream_unary_multi_callable(self._channel) + with self.assertRaises(ValueError) as exception_context: + multi_callable.with_call(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) + + def testStreamRequestFutureUnaryResponse(self): + request_iterator = ( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._stream_unary.future(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) + + def testStreamRequestStreamResponse(self): + request_iterator = ( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + metadata = (('InVaLiD', 'StreamRequestStreamResponse'),) + expected_error_details = "metadata was invalid: %s" % metadata + with self.assertRaises(ValueError) as exception_context: + self._stream_stream(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) + + def testInvalidMetadata(self): + self.assertRaises(TypeError, self._unary_unary, b'', metadata=42) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py new file mode 100644 index 00000000000..a0208b51df4 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py @@ -0,0 +1,266 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import logging + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' +_DEFECTIVE_GENERIC_RPC_HANDLER = '/test/DefectiveGenericRpcHandler' + + +class _Handler(object): + + def __init__(self, control): + self._control = control + + def handle_unary_unary(self, request, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + return request + + def handle_unary_stream(self, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + + def handle_stream_unary(self, request_iterator, servicer_context): + if servicer_context is not None: + servicer_context.invocation_metadata() + self._control.control() + response_elements = [] + for request in request_iterator: + self._control.control() + response_elements.append(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + return b''.join(response_elements) + + def handle_stream_stream(self, request_iterator, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + for request in request_iterator: + self._control.control() + yield request + self._control.control() + + def defective_generic_rpc_handler(self): + raise test_control.Defect() + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, + request_deserializer, response_serializer, unary_unary, + unary_stream, stream_unary, stream_stream): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = request_deserializer + self.response_serializer = response_serializer + self.unary_unary = unary_unary + self.unary_stream = unary_stream + self.stream_unary = stream_unary + self.stream_stream = stream_stream + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, handler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False, None, None, + self._handler.handle_unary_unary, None, None, + None) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, None, + self._handler.handle_stream_unary, None) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True, None, None, None, None, None, + self._handler.handle_stream_stream) + elif handler_call_details.method == _DEFECTIVE_GENERIC_RPC_HANDLER: + return self._handler.defective_generic_rpc_handler() + else: + return None + + +class FailAfterFewIterationsCounter(object): + + def __init__(self, high, bytestring): + self._current = 0 + self._high = high + self._bytestring = bytestring + + def __iter__(self): + return self + + def __next__(self): + if self._current >= self._high: + raise test_control.Defect() + else: + self._current += 1 + return self._bytestring + + next = __next__ + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream(_UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary(_STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +def _defective_handler_multi_callable(channel): + return channel.unary_unary(_DEFECTIVE_GENERIC_RPC_HANDLER) + + +class InvocationDefectsTest(unittest.TestCase): + """Tests the handling of exception-raising user code on the client-side.""" + + def setUp(self): + self._control = test_control.PauseFailControl() + self._handler = _Handler(self._control) + + self._server = test_common.test_server() + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(0) + self._channel.close() + + def testIterableStreamRequestBlockingUnaryResponse(self): + requests = object() + multi_callable = _stream_unary_multi_callable(self._channel) + + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + requests, + metadata=(('test', + 'IterableStreamRequestBlockingUnaryResponse'),)) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + def testIterableStreamRequestFutureUnaryResponse(self): + requests = object() + multi_callable = _stream_unary_multi_callable(self._channel) + response_future = multi_callable.future( + requests, + metadata=(('test', 'IterableStreamRequestFutureUnaryResponse'),)) + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + def testIterableStreamRequestStreamResponse(self): + requests = object() + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + requests, + metadata=(('test', 'IterableStreamRequestStreamResponse'),)) + + with self.assertRaises(grpc.RpcError) as exception_context: + next(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + def testIteratorStreamRequestStreamResponse(self): + requests_iterator = FailAfterFewIterationsCounter( + test_constants.STREAM_LENGTH // 2, b'\x07\x08') + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + requests_iterator, + metadata=(('test', 'IteratorStreamRequestStreamResponse'),)) + + with self.assertRaises(grpc.RpcError) as exception_context: + for _ in range(test_constants.STREAM_LENGTH // 2 + 1): + next(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + def testDefectiveGenericRpcHandlerUnaryResponse(self): + request = b'\x07\x08' + multi_callable = _defective_handler_multi_callable(self._channel) + + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable(request, + metadata=(('test', + 'DefectiveGenericRpcHandlerUnary'),)) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_local_credentials_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_local_credentials_test.py new file mode 100644 index 00000000000..cd1f71dbeed --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_local_credentials_test.py @@ -0,0 +1,77 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of RPCs made using local credentials.""" + +import unittest +import os +from concurrent.futures import ThreadPoolExecutor +import grpc + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + return grpc.unary_unary_rpc_method_handler( + lambda request, unused_context: request) + + +class LocalCredentialsTest(unittest.TestCase): + + def _create_server(self): + server = grpc.server(ThreadPoolExecutor()) + server.add_generic_rpc_handlers((_GenericHandler(),)) + return server + + @unittest.skipIf(os.name == 'nt', + 'TODO(https://github.com/grpc/grpc/issues/20078)') + def test_local_tcp(self): + server_addr = 'localhost:{}' + channel_creds = grpc.local_channel_credentials( + grpc.LocalConnectionType.LOCAL_TCP) + server_creds = grpc.local_server_credentials( + grpc.LocalConnectionType.LOCAL_TCP) + + server = self._create_server() + port = server.add_secure_port(server_addr.format(0), server_creds) + server.start() + with grpc.secure_channel(server_addr.format(port), + channel_creds) as channel: + self.assertEqual( + b'abc', + channel.unary_unary('/test/method')(b'abc', + wait_for_ready=True)) + server.stop(None) + + @unittest.skipIf(os.name == 'nt', + 'Unix Domain Socket is not supported on Windows') + def test_uds(self): + server_addr = 'unix:/tmp/grpc_fullstack_test' + channel_creds = grpc.local_channel_credentials( + grpc.LocalConnectionType.UDS) + server_creds = grpc.local_server_credentials( + grpc.LocalConnectionType.UDS) + + server = self._create_server() + server.add_secure_port(server_addr, server_creds) + server.start() + with grpc.secure_channel(server_addr, channel_creds) as channel: + self.assertEqual( + b'abc', + channel.unary_unary('/test/method')(b'abc', + wait_for_ready=True)) + server.stop(None) + + +if __name__ == '__main__': + unittest.main() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_logging_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_logging_test.py new file mode 100644 index 00000000000..1304bb55879 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_logging_test.py @@ -0,0 +1,103 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of gRPC Python's interaction with the python logging module""" + +import unittest +import logging +import grpc +import os +import subprocess +import sys + +INTERPRETER = sys.executable + + +class LoggingTest(unittest.TestCase): + + def test_logger_not_occupied(self): + script = """if True: + import logging + + import grpc + + if len(logging.getLogger().handlers) != 0: + raise Exception('expected 0 logging handlers') + + """ + self._verifyScriptSucceeds(script) + + def test_handler_found(self): + script = """if True: + import logging + + import grpc + """ + out, err = self._verifyScriptSucceeds(script) + self.assertEqual(0, len(err), 'unexpected output to stderr') + + def test_can_configure_logger(self): + script = """if True: + import logging + import six + + import grpc + + + intended_stream = six.StringIO() + logging.basicConfig(stream=intended_stream) + + if len(logging.getLogger().handlers) != 1: + raise Exception('expected 1 logging handler') + + if logging.getLogger().handlers[0].stream is not intended_stream: + raise Exception('wrong handler stream') + + """ + self._verifyScriptSucceeds(script) + + def test_grpc_logger(self): + script = """if True: + import logging + + import grpc + + if "grpc" not in logging.Logger.manager.loggerDict: + raise Exception('grpc logger not found') + + root_logger = logging.getLogger("grpc") + if len(root_logger.handlers) != 1: + raise Exception('expected 1 root logger handler') + if not isinstance(root_logger.handlers[0], logging.NullHandler): + raise Exception('expected logging.NullHandler') + + """ + self._verifyScriptSucceeds(script) + + def _verifyScriptSucceeds(self, script): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen([INTERPRETER, '-c', script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env) + out, err = process.communicate() + self.assertEqual( + 0, process.returncode, + 'process failed with exit code %d (stdout: %s, stderr: %s)' % + (process.returncode, out, err)) + return out, err + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py new file mode 100644 index 00000000000..5b06eb2bfe8 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -0,0 +1,663 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests application-provided metadata, status code, and details.""" + +import threading +import unittest +import logging + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZED_REQUEST = b'\x46\x47\x48' +_SERIALIZED_RESPONSE = b'\x49\x50\x51' + +_REQUEST_SERIALIZER = lambda unused_request: _SERIALIZED_REQUEST +_REQUEST_DESERIALIZER = lambda unused_serialized_request: object() +_RESPONSE_SERIALIZER = lambda unused_response: _SERIALIZED_RESPONSE +_RESPONSE_DESERIALIZER = lambda unused_serialized_response: object() + +_SERVICE = 'test.TestService' +_UNARY_UNARY = 'UnaryUnary' +_UNARY_STREAM = 'UnaryStream' +_STREAM_UNARY = 'StreamUnary' +_STREAM_STREAM = 'StreamStream' + +_CLIENT_METADATA = (('client-md-key', 'client-md-key'), ('client-md-key-bin', + b'\x00\x01')) + +_SERVER_INITIAL_METADATA = (('server-initial-md-key', + 'server-initial-md-value'), + ('server-initial-md-key-bin', b'\x00\x02')) + +_SERVER_TRAILING_METADATA = (('server-trailing-md-key', + 'server-trailing-md-value'), + ('server-trailing-md-key-bin', b'\x00\x03')) + +_NON_OK_CODE = grpc.StatusCode.NOT_FOUND +_DETAILS = 'Test details!' + +# calling abort should always fail an RPC, even for "invalid" codes +_ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK) +_EXPECTED_CLIENT_CODES = (_NON_OK_CODE, grpc.StatusCode.UNKNOWN, + grpc.StatusCode.UNKNOWN) +_EXPECTED_DETAILS = (_DETAILS, _DETAILS, '') + + +class _Servicer(object): + + def __init__(self): + self._lock = threading.Lock() + self._abort_call = False + self._code = None + self._details = None + self._exception = False + self._return_none = False + self._received_client_metadata = None + + def unary_unary(self, request, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + if self._exception: + raise test_control.Defect() + else: + return None if self._return_none else object() + + def unary_stream(self, request, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + for _ in range(test_constants.STREAM_LENGTH // 2): + yield _SERIALIZED_RESPONSE + if self._exception: + raise test_control.Defect() + + def stream_unary(self, request_iterator, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the + # request iterator. + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + if self._exception: + raise test_control.Defect() + else: + return None if self._return_none else _SERIALIZED_RESPONSE + + def stream_stream(self, request_iterator, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the + # request iterator. + list(request_iterator) + if self._abort_call: + context.abort(self._code, self._details) + else: + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + for _ in range(test_constants.STREAM_LENGTH // 3): + yield object() + if self._exception: + raise test_control.Defect() + + def set_abort_call(self): + with self._lock: + self._abort_call = True + + def set_code(self, code): + with self._lock: + self._code = code + + def set_details(self, details): + with self._lock: + self._details = details + + def set_exception(self): + with self._lock: + self._exception = True + + def set_return_none(self): + with self._lock: + self._return_none = True + + def received_client_metadata(self): + with self._lock: + return self._received_client_metadata + + +def _generic_handler(servicer): + method_handlers = { + _UNARY_UNARY: + grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=_REQUEST_DESERIALIZER, + response_serializer=_RESPONSE_SERIALIZER), + _UNARY_STREAM: + grpc.unary_stream_rpc_method_handler(servicer.unary_stream), + _STREAM_UNARY: + grpc.stream_unary_rpc_method_handler(servicer.stream_unary), + _STREAM_STREAM: + grpc.stream_stream_rpc_method_handler( + servicer.stream_stream, + request_deserializer=_REQUEST_DESERIALIZER, + response_serializer=_RESPONSE_SERIALIZER), + } + return grpc.method_handlers_generic_handler(_SERVICE, method_handlers) + + +class MetadataCodeDetailsTest(unittest.TestCase): + + def setUp(self): + self._servicer = _Servicer() + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers( + (_generic_handler(self._servicer),)) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + + self._channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._unary_unary = self._channel.unary_unary( + '/'.join(( + '', + _SERVICE, + _UNARY_UNARY, + )), + request_serializer=_REQUEST_SERIALIZER, + response_deserializer=_RESPONSE_DESERIALIZER, + ) + self._unary_stream = self._channel.unary_stream( + '/'.join(( + '', + _SERVICE, + _UNARY_STREAM, + )),) + self._stream_unary = self._channel.stream_unary( + '/'.join(( + '', + _SERVICE, + _STREAM_UNARY, + )),) + self._stream_stream = self._channel.stream_stream( + '/'.join(( + '', + _SERVICE, + _STREAM_STREAM, + )), + request_serializer=_REQUEST_SERIALIZER, + response_deserializer=_RESPONSE_DESERIALIZER, + ) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def testSuccessfulUnaryUnary(self): + self._servicer.set_details(_DETAILS) + + unused_response, call = self._unary_unary.with_call( + object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + call.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, + call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + + def testSuccessfulUnaryStream(self): + self._servicer.set_details(_DETAILS) + + response_iterator_call = self._unary_stream(_SERIALIZED_REQUEST, + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + list(response_iterator_call) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) + + def testSuccessfulStreamUnary(self): + self._servicer.set_details(_DETAILS) + + unused_response, call = self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + call.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, + call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + + def testSuccessfulStreamStream(self): + self._servicer.set_details(_DETAILS) + + response_iterator_call = self._stream_stream(iter( + [object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + list(response_iterator_call) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) + + def testAbortedUnaryUnary(self): + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(expected_code, exception_context.exception.code()) + self.assertEqual(expected_details, + exception_context.exception.details()) + + def testAbortedUnaryStream(self): + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = \ + response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(expected_code, response_iterator_call.code()) + self.assertEqual(expected_details, response_iterator_call.details()) + + def testAbortedStreamUnary(self): + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call(iter([_SERIALIZED_REQUEST] * + test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(expected_code, exception_context.exception.code()) + self.assertEqual(expected_details, + exception_context.exception.details()) + + def testAbortedStreamStream(self): + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = \ + response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(expected_code, response_iterator_call.code()) + self.assertEqual(expected_details, response_iterator_call.details()) + + def testCustomCodeUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + response_iterator_call = self._unary_stream(_SERIALIZED_REQUEST, + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + list(response_iterator_call) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testCustomCodeStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call(iter([_SERIALIZED_REQUEST] * + test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + response_iterator_call = self._stream_stream(iter( + [object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError) as exception_context: + list(response_iterator_call) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + response_iterator_call = self._unary_stream(_SERIALIZED_REQUEST, + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + list(response_iterator_call) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testCustomCodeExceptionStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call(iter([_SERIALIZED_REQUEST] * + test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + response_iterator_call = self._stream_stream(iter( + [object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + list(response_iterator_call) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, response_iterator_call.code()) + self.assertEqual(_DETAILS, response_iterator_call.details()) + + def testCustomCodeReturnNoneUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_return_none() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeReturnNoneStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_return_none() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call(iter([_SERIALIZED_REQUEST] * + test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py new file mode 100644 index 00000000000..e2b36b1c70f --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -0,0 +1,260 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests metadata flags feature by testing wait-for-ready semantics""" + +import time +import weakref +import unittest +import threading +import logging +import socket +from six.moves import queue + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants +import tests.unit.framework.common +from tests.unit.framework.common import get_socket + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + + +def handle_unary_unary(test, request, servicer_context): + return _RESPONSE + + +def handle_unary_stream(test, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + yield _RESPONSE + + +def handle_stream_unary(test, request_iterator, servicer_context): + for _ in request_iterator: + pass + return _RESPONSE + + +def handle_stream_stream(test, request_iterator, servicer_context): + for _ in request_iterator: + yield _RESPONSE + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, test, request_streaming, response_streaming): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + self.stream_stream = lambda req, ctx: handle_stream_stream( + test, req, ctx) + elif self.request_streaming: + self.stream_unary = lambda req, ctx: handle_stream_unary( + test, req, ctx) + elif self.response_streaming: + self.unary_stream = lambda req, ctx: handle_unary_stream( + test, req, ctx) + else: + self.unary_unary = lambda req, ctx: handle_unary_unary( + test, req, ctx) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, test): + self._test = test + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(self._test, False, False) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(self._test, False, True) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(self._test, True, False) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(self._test, True, True) + else: + return None + + +def create_dummy_channel(): + """Creating dummy channels is a workaround for retries""" + host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,)) + sock.close() + return grpc.insecure_channel('{}:{}'.format(host, port)) + + +def perform_unary_unary_call(channel, wait_for_ready=None): + channel.unary_unary(_UNARY_UNARY).__call__( + _REQUEST, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +def perform_unary_unary_with_call(channel, wait_for_ready=None): + channel.unary_unary(_UNARY_UNARY).with_call( + _REQUEST, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +def perform_unary_unary_future(channel, wait_for_ready=None): + channel.unary_unary(_UNARY_UNARY).future( + _REQUEST, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready).result( + timeout=test_constants.LONG_TIMEOUT) + + +def perform_unary_stream_call(channel, wait_for_ready=None): + response_iterator = channel.unary_stream(_UNARY_STREAM).__call__( + _REQUEST, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + for _ in response_iterator: + pass + + +def perform_stream_unary_call(channel, wait_for_ready=None): + channel.stream_unary(_STREAM_UNARY).__call__( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +def perform_stream_unary_with_call(channel, wait_for_ready=None): + channel.stream_unary(_STREAM_UNARY).with_call( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +def perform_stream_unary_future(channel, wait_for_ready=None): + channel.stream_unary(_STREAM_UNARY).future( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready).result( + timeout=test_constants.LONG_TIMEOUT) + + +def perform_stream_stream_call(channel, wait_for_ready=None): + response_iterator = channel.stream_stream(_STREAM_STREAM).__call__( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + for _ in response_iterator: + pass + + +_ALL_CALL_CASES = [ + perform_unary_unary_call, perform_unary_unary_with_call, + perform_unary_unary_future, perform_unary_stream_call, + perform_stream_unary_call, perform_stream_unary_with_call, + perform_stream_unary_future, perform_stream_stream_call +] + + +class MetadataFlagsTest(unittest.TestCase): + + def check_connection_does_failfast(self, fn, channel, wait_for_ready=None): + try: + fn(channel, wait_for_ready) + self.fail("The Call should fail") + except BaseException as e: # pylint: disable=broad-except + self.assertIs(grpc.StatusCode.UNAVAILABLE, e.code()) + + def test_call_wait_for_ready_default(self): + for perform_call in _ALL_CALL_CASES: + with create_dummy_channel() as channel: + self.check_connection_does_failfast(perform_call, channel) + + def test_call_wait_for_ready_disabled(self): + for perform_call in _ALL_CALL_CASES: + with create_dummy_channel() as channel: + self.check_connection_does_failfast(perform_call, + channel, + wait_for_ready=False) + + def test_call_wait_for_ready_enabled(self): + # To test the wait mechanism, Python thread is required to make + # client set up first without handling them case by case. + # Also, Python thread don't pass the unhandled exceptions to + # main thread. So, it need another method to store the + # exceptions and raise them again in main thread. + unhandled_exceptions = queue.Queue() + + # We just need an unused TCP port + host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,)) + sock.close() + + addr = '{}:{}'.format(host, port) + wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) + + def wait_for_transient_failure(channel_connectivity): + if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: + wg.done() + + def test_call(perform_call): + with grpc.insecure_channel(addr) as channel: + try: + channel.subscribe(wait_for_transient_failure) + perform_call(channel, wait_for_ready=True) + except BaseException as e: # pylint: disable=broad-except + # If the call failed, the thread would be destroyed. The + # channel object can be collected before calling the + # callback, which will result in a deadlock. + wg.done() + unhandled_exceptions.put(e, True) + + test_threads = [] + for perform_call in _ALL_CALL_CASES: + test_thread = threading.Thread(target=test_call, + args=(perform_call,)) + test_thread.daemon = True + test_thread.exception = None + test_thread.start() + test_threads.append(test_thread) + + # Start the server after the connections are waiting + wg.wait() + server = test_common.test_server(reuse_port=True) + server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),)) + server.add_insecure_port(addr) + server.start() + + for test_thread in test_threads: + test_thread.join() + + # Stop the server to make test end properly + server.stop(0) + + if not unhandled_exceptions.empty(): + raise unhandled_exceptions.get(True) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_test.py new file mode 100644 index 00000000000..3e7717b04c7 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -0,0 +1,242 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server and client side metadata API.""" + +import unittest +import weakref +import logging + +import grpc +from grpc import _channel + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'), + ('grpc.secondary_user_agent', 'secondary-agent')) + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + +_INVOCATION_METADATA = ( + ( + b'invocation-md-key', + u'invocation-md-value', + ), + ( + u'invocation-md-key-bin', + b'\x00\x01', + ), +) +_EXPECTED_INVOCATION_METADATA = ( + ( + 'invocation-md-key', + 'invocation-md-value', + ), + ( + 'invocation-md-key-bin', + b'\x00\x01', + ), +) + +_INITIAL_METADATA = ((b'initial-md-key', u'initial-md-value'), + (u'initial-md-key-bin', b'\x00\x02')) +_EXPECTED_INITIAL_METADATA = ( + ( + 'initial-md-key', + 'initial-md-value', + ), + ( + 'initial-md-key-bin', + b'\x00\x02', + ), +) + +_TRAILING_METADATA = ( + ( + 'server-trailing-md-key', + 'server-trailing-md-value', + ), + ( + 'server-trailing-md-key-bin', + b'\x00\x03', + ), +) +_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA + + +def _user_agent(metadata): + for key, val in metadata: + if key == 'user-agent': + return val + raise KeyError('No user agent!') + + +def validate_client_metadata(test, servicer_context): + invocation_metadata = servicer_context.invocation_metadata() + test.assertTrue( + test_common.metadata_transmitted(_EXPECTED_INVOCATION_METADATA, + invocation_metadata)) + user_agent = _user_agent(invocation_metadata) + test.assertTrue( + user_agent.startswith('primary-agent ' + _channel._USER_AGENT)) + test.assertTrue(user_agent.endswith('secondary-agent')) + + +def handle_unary_unary(test, request, servicer_context): + validate_client_metadata(test, servicer_context) + servicer_context.send_initial_metadata(_INITIAL_METADATA) + servicer_context.set_trailing_metadata(_TRAILING_METADATA) + return _RESPONSE + + +def handle_unary_stream(test, request, servicer_context): + validate_client_metadata(test, servicer_context) + servicer_context.send_initial_metadata(_INITIAL_METADATA) + servicer_context.set_trailing_metadata(_TRAILING_METADATA) + for _ in range(test_constants.STREAM_LENGTH): + yield _RESPONSE + + +def handle_stream_unary(test, request_iterator, servicer_context): + validate_client_metadata(test, servicer_context) + servicer_context.send_initial_metadata(_INITIAL_METADATA) + servicer_context.set_trailing_metadata(_TRAILING_METADATA) + # TODO(issue:#6891) We should be able to remove this loop + for request in request_iterator: + pass + return _RESPONSE + + +def handle_stream_stream(test, request_iterator, servicer_context): + validate_client_metadata(test, servicer_context) + servicer_context.send_initial_metadata(_INITIAL_METADATA) + servicer_context.set_trailing_metadata(_TRAILING_METADATA) + # TODO(issue:#6891) We should be able to remove this loop, + # and replace with return; yield + for request in request_iterator: + yield _RESPONSE + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, test, request_streaming, response_streaming): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + self.stream_stream = lambda x, y: handle_stream_stream(test, x, y) + elif self.request_streaming: + self.stream_unary = lambda x, y: handle_stream_unary(test, x, y) + elif self.response_streaming: + self.unary_stream = lambda x, y: handle_unary_stream(test, x, y) + else: + self.unary_unary = lambda x, y: handle_unary_unary(test, x, y) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, test): + self._test = test + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(self._test, False, False) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(self._test, False, True) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(self._test, True, False) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(self._test, True, True) + else: + return None + + +class MetadataTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers( + (_GenericHandler(weakref.proxy(self)),)) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + self._channel = grpc.insecure_channel('localhost:%d' % port, + options=_CHANNEL_ARGS) + + def tearDown(self): + self._server.stop(0) + self._channel.close() + + def testUnaryUnary(self): + multi_callable = self._channel.unary_unary(_UNARY_UNARY) + unused_response, call = multi_callable.with_call( + _REQUEST, metadata=_INVOCATION_METADATA) + self.assertTrue( + test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA, + call.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA, + call.trailing_metadata())) + + def testUnaryStream(self): + multi_callable = self._channel.unary_stream(_UNARY_STREAM) + call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA) + self.assertTrue( + test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA, + call.initial_metadata())) + for _ in call: + pass + self.assertTrue( + test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA, + call.trailing_metadata())) + + def testStreamUnary(self): + multi_callable = self._channel.stream_unary(_STREAM_UNARY) + unused_response, call = multi_callable.with_call( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_INVOCATION_METADATA) + self.assertTrue( + test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA, + call.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA, + call.trailing_metadata())) + + def testStreamStream(self): + multi_callable = self._channel.stream_stream(_STREAM_STREAM) + call = multi_callable(iter([_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_INVOCATION_METADATA) + self.assertTrue( + test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA, + call.initial_metadata())) + for _ in call: + pass + self.assertTrue( + test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA, + call.trailing_metadata())) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_reconnect_test.py new file mode 100644 index 00000000000..16feb4b1ff4 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_reconnect_test.py @@ -0,0 +1,69 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests that a channel will reconnect if a connection is dropped""" + +import socket +import time +import logging +import unittest + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import bound_socket + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x01' + +_UNARY_UNARY = '/test/UnaryUnary' + + +def _handle_unary_unary(unused_request, unused_servicer_context): + return _RESPONSE + + +class ReconnectTest(unittest.TestCase): + + def test_reconnect(self): + server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(_handle_unary_unary) + }) + options = (('grpc.so_reuseport', 1),) + with bound_socket() as (host, port): + addr = '{}:{}'.format(host, port) + server = grpc.server(server_pool, (handler,), options=options) + server.add_insecure_port(addr) + server.start() + channel = grpc.insecure_channel(addr) + multi_callable = channel.unary_unary(_UNARY_UNARY) + self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) + server.stop(None) + # By default, the channel connectivity is checked every 5s + # GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS can be set to change + # this. + time.sleep(5.1) + server = grpc.server(server_pool, (handler,), options=options) + server.add_insecure_port(addr) + server.start() + self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) + server.stop(None) + channel.close() + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py new file mode 100644 index 00000000000..ecd2ccadbde --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py @@ -0,0 +1,259 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server responding with RESOURCE_EXHAUSTED.""" + +import threading +import unittest +import logging + +import grpc +from grpc import _channel +from grpc.framework.foundation import logging_pool + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' + + +class _TestTrigger(object): + + def __init__(self, total_call_count): + self._total_call_count = total_call_count + self._pending_calls = 0 + self._triggered = False + self._finish_condition = threading.Condition() + self._start_condition = threading.Condition() + + # Wait for all calls be blocked in their handler + def await_calls(self): + with self._start_condition: + while self._pending_calls < self._total_call_count: + self._start_condition.wait() + + # Block in a response handler and wait for a trigger + def await_trigger(self): + with self._start_condition: + self._pending_calls += 1 + self._start_condition.notify() + + with self._finish_condition: + if not self._triggered: + self._finish_condition.wait() + + # Finish all response handlers + def trigger(self): + with self._finish_condition: + self._triggered = True + self._finish_condition.notify_all() + + +def handle_unary_unary(trigger, request, servicer_context): + trigger.await_trigger() + return _RESPONSE + + +def handle_unary_stream(trigger, request, servicer_context): + trigger.await_trigger() + for _ in range(test_constants.STREAM_LENGTH): + yield _RESPONSE + + +def handle_stream_unary(trigger, request_iterator, servicer_context): + trigger.await_trigger() + # TODO(issue:#6891) We should be able to remove this loop + for request in request_iterator: + pass + return _RESPONSE + + +def handle_stream_stream(trigger, request_iterator, servicer_context): + trigger.await_trigger() + # TODO(issue:#6891) We should be able to remove this loop, + # and replace with return; yield + for request in request_iterator: + yield _RESPONSE + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, trigger, request_streaming, response_streaming): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + self.stream_stream = ( + lambda x, y: handle_stream_stream(trigger, x, y)) + elif self.request_streaming: + self.stream_unary = lambda x, y: handle_stream_unary(trigger, x, y) + elif self.response_streaming: + self.unary_stream = lambda x, y: handle_unary_stream(trigger, x, y) + else: + self.unary_unary = lambda x, y: handle_unary_unary(trigger, x, y) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, trigger): + self._trigger = trigger + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(self._trigger, False, False) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(self._trigger, False, True) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(self._trigger, True, False) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(self._trigger, True, True) + else: + return None + + +class ResourceExhaustedTest(unittest.TestCase): + + def setUp(self): + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY) + self._server = grpc.server( + self._server_pool, + handlers=(_GenericHandler(self._trigger),), + options=(('grpc.so_reuseport', 0),), + maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(0) + self._channel.close() + + def testUnaryUnary(self): + multi_callable = self._channel.unary_unary(_UNARY_UNARY) + futures = [] + for _ in range(test_constants.THREAD_CONCURRENCY): + futures.append(multi_callable.future(_REQUEST)) + + self._trigger.await_calls() + + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable(_REQUEST) + + self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, + exception_context.exception.code()) + + future_exception = multi_callable.future(_REQUEST) + self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, + future_exception.exception().code()) + + self._trigger.trigger() + for future in futures: + self.assertEqual(_RESPONSE, future.result()) + + # Ensure a new request can be handled + self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) + + def testUnaryStream(self): + multi_callable = self._channel.unary_stream(_UNARY_STREAM) + calls = [] + for _ in range(test_constants.THREAD_CONCURRENCY): + calls.append(multi_callable(_REQUEST)) + + self._trigger.await_calls() + + with self.assertRaises(grpc.RpcError) as exception_context: + next(multi_callable(_REQUEST)) + + self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, + exception_context.exception.code()) + + self._trigger.trigger() + + for call in calls: + for response in call: + self.assertEqual(_RESPONSE, response) + + # Ensure a new request can be handled + new_call = multi_callable(_REQUEST) + for response in new_call: + self.assertEqual(_RESPONSE, response) + + def testStreamUnary(self): + multi_callable = self._channel.stream_unary(_STREAM_UNARY) + futures = [] + request = iter([_REQUEST] * test_constants.STREAM_LENGTH) + for _ in range(test_constants.THREAD_CONCURRENCY): + futures.append(multi_callable.future(request)) + + self._trigger.await_calls() + + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable(request) + + self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, + exception_context.exception.code()) + + future_exception = multi_callable.future(request) + self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, + future_exception.exception().code()) + + self._trigger.trigger() + + for future in futures: + self.assertEqual(_RESPONSE, future.result()) + + # Ensure a new request can be handled + self.assertEqual(_RESPONSE, multi_callable(request)) + + def testStreamStream(self): + multi_callable = self._channel.stream_stream(_STREAM_STREAM) + calls = [] + request = iter([_REQUEST] * test_constants.STREAM_LENGTH) + for _ in range(test_constants.THREAD_CONCURRENCY): + calls.append(multi_callable(request)) + + self._trigger.await_calls() + + with self.assertRaises(grpc.RpcError) as exception_context: + next(multi_callable(request)) + + self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, + exception_context.exception.code()) + + self._trigger.trigger() + + for call in calls: + for response in call: + self.assertEqual(_RESPONSE, response) + + # Ensure a new request can be handled + new_call = multi_callable(request) + for response in new_call: + self.assertEqual(_RESPONSE, response) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_part_1_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_part_1_test.py new file mode 100644 index 00000000000..9b0cb29a0d5 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_part_1_test.py @@ -0,0 +1,232 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of RPCs made against gRPC Python's application-layer API.""" + +import itertools +import threading +import unittest +import logging +from concurrent import futures + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit._rpc_test_helpers import ( + TIMEOUT_SHORT, Callback, unary_unary_multi_callable, + unary_stream_multi_callable, unary_stream_non_blocking_multi_callable, + stream_unary_multi_callable, stream_stream_multi_callable, + stream_stream_non_blocking_multi_callable, BaseRPCTest) +from tests.unit.framework.common import test_constants + + +class RPCPart1Test(BaseRPCTest, unittest.TestCase): + + def testExpiredStreamRequestBlockingUnaryResponse(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = stream_unary_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request_iterator, + timeout=TIMEOUT_SHORT, + metadata=(('test', + 'ExpiredStreamRequestBlockingUnaryResponse'),)) + + self.assertIsInstance(exception_context.exception, grpc.RpcError) + self.assertIsInstance(exception_context.exception, grpc.Call) + self.assertIsNotNone(exception_context.exception.initial_metadata()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertIsNotNone(exception_context.exception.details()) + self.assertIsNotNone(exception_context.exception.trailing_metadata()) + + def testExpiredStreamRequestFutureUnaryResponse(self): + requests = tuple( + b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + callback = Callback() + + multi_callable = stream_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request_iterator, + timeout=TIMEOUT_SHORT, + metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),)) + with self.assertRaises(grpc.FutureTimeoutError): + response_future.result(timeout=TIMEOUT_SHORT / 2.0) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIsNotNone(response_future.traceback()) + self.assertIs(response_future, value_passed_to_callback) + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + + def testExpiredStreamRequestStreamResponse(self): + self._expired_stream_request_stream_response( + stream_stream_multi_callable(self._channel)) + + def testExpiredStreamRequestStreamResponseNonBlocking(self): + self._expired_stream_request_stream_response( + stream_stream_non_blocking_multi_callable(self._channel)) + + def testFailedUnaryRequestBlockingUnaryResponse(self): + request = b'\x37\x17' + + multi_callable = unary_unary_multi_callable(self._channel) + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable.with_call( + request, + metadata=(('test', + 'FailedUnaryRequestBlockingUnaryResponse'),)) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + # sanity checks on to make sure returned string contains default members + # of the error + debug_error_string = exception_context.exception.debug_error_string() + self.assertIn('created', debug_error_string) + self.assertIn('description', debug_error_string) + self.assertIn('file', debug_error_string) + self.assertIn('file_line', debug_error_string) + + def testFailedUnaryRequestFutureUnaryResponse(self): + request = b'\x37\x17' + callback = Callback() + + multi_callable = unary_unary_multi_callable(self._channel) + with self._control.fail(): + response_future = multi_callable.future( + request, + metadata=(('test', 'FailedUnaryRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + self.assertIsInstance(response_future, grpc.Future) + self.assertIsInstance(response_future, grpc.Call) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIsNotNone(response_future.traceback()) + self.assertIs(grpc.StatusCode.UNKNOWN, + response_future.exception().code()) + self.assertIs(response_future, value_passed_to_callback) + + def testFailedUnaryRequestStreamResponse(self): + self._failed_unary_request_stream_response( + unary_stream_multi_callable(self._channel)) + + def testFailedUnaryRequestStreamResponseNonBlocking(self): + self._failed_unary_request_stream_response( + unary_stream_non_blocking_multi_callable(self._channel)) + + def testFailedStreamRequestBlockingUnaryResponse(self): + requests = tuple( + b'\x47\x58' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = stream_unary_multi_callable(self._channel) + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request_iterator, + metadata=(('test', + 'FailedStreamRequestBlockingUnaryResponse'),)) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + def testFailedStreamRequestFutureUnaryResponse(self): + requests = tuple( + b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + callback = Callback() + + multi_callable = stream_unary_multi_callable(self._channel) + with self._control.fail(): + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'FailedStreamRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs(grpc.StatusCode.UNKNOWN, response_future.code()) + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIsNotNone(response_future.traceback()) + self.assertIs(response_future, value_passed_to_callback) + + def testFailedStreamRequestStreamResponse(self): + self._failed_stream_request_stream_response( + stream_stream_multi_callable(self._channel)) + + def testFailedStreamRequestStreamResponseNonBlocking(self): + self._failed_stream_request_stream_response( + stream_stream_non_blocking_multi_callable(self._channel)) + + def testIgnoredUnaryRequestFutureUnaryResponse(self): + request = b'\x37\x17' + + multi_callable = unary_unary_multi_callable(self._channel) + multi_callable.future( + request, + metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),)) + + def testIgnoredUnaryRequestStreamResponse(self): + self._ignored_unary_stream_request_future_unary_response( + unary_stream_multi_callable(self._channel)) + + def testIgnoredUnaryRequestStreamResponseNonBlocking(self): + self._ignored_unary_stream_request_future_unary_response( + unary_stream_non_blocking_multi_callable(self._channel)) + + def testIgnoredStreamRequestFutureUnaryResponse(self): + requests = tuple( + b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = stream_unary_multi_callable(self._channel) + multi_callable.future( + request_iterator, + metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),)) + + def testIgnoredStreamRequestStreamResponse(self): + self._ignored_stream_request_stream_response( + stream_stream_multi_callable(self._channel)) + + def testIgnoredStreamRequestStreamResponseNonBlocking(self): + self._ignored_stream_request_stream_response( + stream_stream_non_blocking_multi_callable(self._channel)) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_part_2_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_part_2_test.py new file mode 100644 index 00000000000..0e559efec2a --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_part_2_test.py @@ -0,0 +1,426 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of RPCs made against gRPC Python's application-layer API.""" + +import itertools +import threading +import unittest +import logging +from concurrent import futures + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit._rpc_test_helpers import ( + TIMEOUT_SHORT, Callback, unary_unary_multi_callable, + unary_stream_multi_callable, unary_stream_non_blocking_multi_callable, + stream_unary_multi_callable, stream_stream_multi_callable, + stream_stream_non_blocking_multi_callable, BaseRPCTest) +from tests.unit.framework.common import test_constants + + +class RPCPart2Test(BaseRPCTest, unittest.TestCase): + + def testDefaultThreadPoolIsUsed(self): + self._consume_one_stream_response_unary_request( + unary_stream_multi_callable(self._channel)) + self.assertFalse(self._thread_pool.was_used()) + + def testExperimentalThreadPoolIsUsed(self): + self._consume_one_stream_response_unary_request( + unary_stream_non_blocking_multi_callable(self._channel)) + self.assertTrue(self._thread_pool.was_used()) + + def testUnrecognizedMethod(self): + request = b'abc' + + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary('NoSuchMethod')(request) + + self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, + exception_context.exception.code()) + + def testSuccessfulUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = unary_unary_multi_callable(self._channel) + response = multi_callable( + request, + metadata=(('test', 'SuccessfulUnaryRequestBlockingUnaryResponse'),)) + + self.assertEqual(expected_response, response) + + def testSuccessfulUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = unary_unary_multi_callable(self._channel) + response, call = multi_callable.with_call( + request, + metadata=(('test', + 'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),)) + + self.assertEqual(expected_response, response) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual('', call.debug_error_string()) + + def testSuccessfulUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = unary_unary_multi_callable(self._channel) + response_future = multi_callable.future( + request, + metadata=(('test', 'SuccessfulUnaryRequestFutureUnaryResponse'),)) + response = response_future.result() + + self.assertIsInstance(response_future, grpc.Future) + self.assertIsInstance(response_future, grpc.Call) + self.assertEqual(expected_response, response) + self.assertIsNone(response_future.exception()) + self.assertIsNone(response_future.traceback()) + + def testSuccessfulUnaryRequestStreamResponse(self): + request = b'\x37\x58' + expected_responses = tuple( + self._handler.handle_unary_stream(request, None)) + + multi_callable = unary_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request, + metadata=(('test', 'SuccessfulUnaryRequestStreamResponse'),)) + responses = tuple(response_iterator) + + self.assertSequenceEqual(expected_responses, responses) + + def testSuccessfulStreamRequestBlockingUnaryResponse(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary( + iter(requests), None) + request_iterator = iter(requests) + + multi_callable = stream_unary_multi_callable(self._channel) + response = multi_callable( + request_iterator, + metadata=(('test', + 'SuccessfulStreamRequestBlockingUnaryResponse'),)) + + self.assertEqual(expected_response, response) + + def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary( + iter(requests), None) + request_iterator = iter(requests) + + multi_callable = stream_unary_multi_callable(self._channel) + response, call = multi_callable.with_call( + request_iterator, + metadata=( + ('test', + 'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),)) + + self.assertEqual(expected_response, response) + self.assertIs(grpc.StatusCode.OK, call.code()) + + def testSuccessfulStreamRequestFutureUnaryResponse(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary( + iter(requests), None) + request_iterator = iter(requests) + + multi_callable = stream_unary_multi_callable(self._channel) + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'SuccessfulStreamRequestFutureUnaryResponse'),)) + response = response_future.result() + + self.assertEqual(expected_response, response) + self.assertIsNone(response_future.exception()) + self.assertIsNone(response_future.traceback()) + + def testSuccessfulStreamRequestStreamResponse(self): + requests = tuple( + b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) + + expected_responses = tuple( + self._handler.handle_stream_stream(iter(requests), None)) + request_iterator = iter(requests) + + multi_callable = stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'SuccessfulStreamRequestStreamResponse'),)) + responses = tuple(response_iterator) + + self.assertSequenceEqual(expected_responses, responses) + + def testSequentialInvocations(self): + first_request = b'\x07\x08' + second_request = b'\x0809' + expected_first_response = self._handler.handle_unary_unary( + first_request, None) + expected_second_response = self._handler.handle_unary_unary( + second_request, None) + + multi_callable = unary_unary_multi_callable(self._channel) + first_response = multi_callable(first_request, + metadata=(('test', + 'SequentialInvocations'),)) + second_response = multi_callable(second_request, + metadata=(('test', + 'SequentialInvocations'),)) + + self.assertEqual(expected_first_response, first_response) + self.assertEqual(expected_second_response, second_response) + + def testConcurrentBlockingInvocations(self): + pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary( + iter(requests), None) + expected_responses = [expected_response + ] * test_constants.THREAD_CONCURRENCY + response_futures = [None] * test_constants.THREAD_CONCURRENCY + + multi_callable = stream_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + request_iterator = iter(requests) + response_future = pool.submit( + multi_callable, + request_iterator, + metadata=(('test', 'ConcurrentBlockingInvocations'),)) + response_futures[index] = response_future + responses = tuple( + response_future.result() for response_future in response_futures) + + pool.shutdown(wait=True) + self.assertSequenceEqual(expected_responses, responses) + + def testConcurrentFutureInvocations(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary( + iter(requests), None) + expected_responses = [expected_response + ] * test_constants.THREAD_CONCURRENCY + response_futures = [None] * test_constants.THREAD_CONCURRENCY + + multi_callable = stream_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + request_iterator = iter(requests) + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'ConcurrentFutureInvocations'),)) + response_futures[index] = response_future + responses = tuple( + response_future.result() for response_future in response_futures) + + self.assertSequenceEqual(expected_responses, responses) + + def testWaitingForSomeButNotAllConcurrentFutureInvocations(self): + pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + request = b'\x67\x68' + expected_response = self._handler.handle_unary_unary(request, None) + response_futures = [None] * test_constants.THREAD_CONCURRENCY + lock = threading.Lock() + test_is_running_cell = [True] + + def wrap_future(future): + + def wrap(): + try: + return future.result() + except grpc.RpcError: + with lock: + if test_is_running_cell[0]: + raise + return None + + return wrap + + multi_callable = unary_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + inner_response_future = multi_callable.future( + request, + metadata=( + ('test', + 'WaitingForSomeButNotAllConcurrentFutureInvocations'),)) + outer_response_future = pool.submit( + wrap_future(inner_response_future)) + response_futures[index] = outer_response_future + + some_completed_response_futures_iterator = itertools.islice( + futures.as_completed(response_futures), + test_constants.THREAD_CONCURRENCY // 2) + for response_future in some_completed_response_futures_iterator: + self.assertEqual(expected_response, response_future.result()) + with lock: + test_is_running_cell[0] = False + + def testConsumingOneStreamResponseUnaryRequest(self): + self._consume_one_stream_response_unary_request( + unary_stream_multi_callable(self._channel)) + + def testConsumingOneStreamResponseUnaryRequestNonBlocking(self): + self._consume_one_stream_response_unary_request( + unary_stream_non_blocking_multi_callable(self._channel)) + + def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self): + self._consume_some_but_not_all_stream_responses_unary_request( + unary_stream_multi_callable(self._channel)) + + def testConsumingSomeButNotAllStreamResponsesUnaryRequestNonBlocking(self): + self._consume_some_but_not_all_stream_responses_unary_request( + unary_stream_non_blocking_multi_callable(self._channel)) + + def testConsumingSomeButNotAllStreamResponsesStreamRequest(self): + self._consume_some_but_not_all_stream_responses_stream_request( + stream_stream_multi_callable(self._channel)) + + def testConsumingSomeButNotAllStreamResponsesStreamRequestNonBlocking(self): + self._consume_some_but_not_all_stream_responses_stream_request( + stream_stream_non_blocking_multi_callable(self._channel)) + + def testConsumingTooManyStreamResponsesStreamRequest(self): + self._consume_too_many_stream_responses_stream_request( + stream_stream_multi_callable(self._channel)) + + def testConsumingTooManyStreamResponsesStreamRequestNonBlocking(self): + self._consume_too_many_stream_responses_stream_request( + stream_stream_non_blocking_multi_callable(self._channel)) + + def testCancelledUnaryRequestUnaryResponse(self): + request = b'\x07\x17' + + multi_callable = unary_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request, + metadata=(('test', 'CancelledUnaryRequestUnaryResponse'),)) + response_future.cancel() + + self.assertIs(grpc.StatusCode.CANCELLED, response_future.code()) + self.assertTrue(response_future.cancelled()) + with self.assertRaises(grpc.FutureCancelledError): + response_future.result() + with self.assertRaises(grpc.FutureCancelledError): + response_future.exception() + with self.assertRaises(grpc.FutureCancelledError): + response_future.traceback() + + def testCancelledUnaryRequestStreamResponse(self): + self._cancelled_unary_request_stream_response( + unary_stream_multi_callable(self._channel)) + + def testCancelledUnaryRequestStreamResponseNonBlocking(self): + self._cancelled_unary_request_stream_response( + unary_stream_non_blocking_multi_callable(self._channel)) + + def testCancelledStreamRequestUnaryResponse(self): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = stream_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request_iterator, + metadata=(('test', 'CancelledStreamRequestUnaryResponse'),)) + self._control.block_until_paused() + response_future.cancel() + + self.assertIs(grpc.StatusCode.CANCELLED, response_future.code()) + self.assertTrue(response_future.cancelled()) + with self.assertRaises(grpc.FutureCancelledError): + response_future.result() + with self.assertRaises(grpc.FutureCancelledError): + response_future.exception() + with self.assertRaises(grpc.FutureCancelledError): + response_future.traceback() + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + + def testCancelledStreamRequestStreamResponse(self): + self._cancelled_stream_request_stream_response( + stream_stream_multi_callable(self._channel)) + + def testCancelledStreamRequestStreamResponseNonBlocking(self): + self._cancelled_stream_request_stream_response( + stream_stream_non_blocking_multi_callable(self._channel)) + + def testExpiredUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x17' + + multi_callable = unary_unary_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable.with_call( + request, + timeout=TIMEOUT_SHORT, + metadata=(('test', + 'ExpiredUnaryRequestBlockingUnaryResponse'),)) + + self.assertIsInstance(exception_context.exception, grpc.Call) + self.assertIsNotNone(exception_context.exception.initial_metadata()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertIsNotNone(exception_context.exception.details()) + self.assertIsNotNone(exception_context.exception.trailing_metadata()) + + def testExpiredUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x17' + callback = Callback() + + multi_callable = unary_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request, + timeout=TIMEOUT_SHORT, + metadata=(('test', 'ExpiredUnaryRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + self.assertIs(response_future, value_passed_to_callback) + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIsNotNone(response_future.traceback()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + response_future.exception().code()) + + def testExpiredUnaryRequestStreamResponse(self): + self._expired_unary_request_stream_response( + unary_stream_multi_callable(self._channel)) + + def testExpiredUnaryRequestStreamResponseNonBlocking(self): + self._expired_unary_request_stream_response( + unary_stream_non_blocking_multi_callable(self._channel)) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py new file mode 100644 index 00000000000..a3f18a9a490 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py @@ -0,0 +1,417 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test helpers for RPC invocation tests.""" + +import datetime +import threading + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit import test_common +from tests.unit import thread_pool +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] + +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_UNARY_STREAM_NON_BLOCKING = '/test/UnaryStreamNonBlocking' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' +_STREAM_STREAM_NON_BLOCKING = '/test/StreamStreamNonBlocking' + +TIMEOUT_SHORT = datetime.timedelta(seconds=1).total_seconds() + + +class Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + self._called = False + + def __call__(self, value): + with self._condition: + self._value = value + self._called = True + self._condition.notify_all() + + def value(self): + with self._condition: + while not self._called: + self._condition.wait() + return self._value + + +class _Handler(object): + + def __init__(self, control, thread_pool): + self._control = control + self._thread_pool = thread_pool + non_blocking_functions = (self.handle_unary_stream_non_blocking, + self.handle_stream_stream_non_blocking) + for non_blocking_function in non_blocking_functions: + non_blocking_function.__func__.experimental_non_blocking = True + non_blocking_function.__func__.experimental_thread_pool = self._thread_pool + + def handle_unary_unary(self, request, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + # TODO(https://github.com/grpc/grpc/issues/8483): test the values + # returned by these methods rather than only "smoke" testing that + # the return after having been called. + servicer_context.is_active() + servicer_context.time_remaining() + return request + + def handle_unary_stream(self, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + + def handle_unary_stream_non_blocking(self, request, servicer_context, + on_next): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + on_next(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + on_next(None) + + def handle_stream_unary(self, request_iterator, servicer_context): + if servicer_context is not None: + servicer_context.invocation_metadata() + self._control.control() + response_elements = [] + for request in request_iterator: + self._control.control() + response_elements.append(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + return b''.join(response_elements) + + def handle_stream_stream(self, request_iterator, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + for request in request_iterator: + self._control.control() + yield request + self._control.control() + + def handle_stream_stream_non_blocking(self, request_iterator, + servicer_context, on_next): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata((( + 'testkey', + 'testvalue', + ),)) + for request in request_iterator: + self._control.control() + on_next(request) + self._control.control() + on_next(None) + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, + request_deserializer, response_serializer, unary_unary, + unary_stream, stream_unary, stream_stream): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = request_deserializer + self.response_serializer = response_serializer + self.unary_unary = unary_unary + self.unary_stream = unary_stream + self.stream_unary = stream_unary + self.stream_stream = stream_stream + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, handler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False, None, None, + self._handler.handle_unary_unary, None, None, + None) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler(False, True, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _UNARY_STREAM_NON_BLOCKING: + return _MethodHandler( + False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream_non_blocking, None, None) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler(True, False, _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, None, None, + self._handler.handle_stream_unary, None) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True, None, None, None, None, None, + self._handler.handle_stream_stream) + elif handler_call_details.method == _STREAM_STREAM_NON_BLOCKING: + return _MethodHandler( + True, True, None, None, None, None, None, + self._handler.handle_stream_stream_non_blocking) + else: + return None + + +def unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def unary_stream_multi_callable(channel): + return channel.unary_stream(_UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def unary_stream_non_blocking_multi_callable(channel): + return channel.unary_stream(_UNARY_STREAM_NON_BLOCKING, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def stream_unary_multi_callable(channel): + return channel.stream_unary(_STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +def stream_stream_non_blocking_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING) + + +class BaseRPCTest(object): + + def setUp(self): + self._control = test_control.PauseFailControl() + self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None) + self._handler = _Handler(self._control, self._thread_pool) + + self._server = test_common.test_server() + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def _consume_one_stream_response_unary_request(self, multi_callable): + request = b'\x57\x38' + + response_iterator = multi_callable( + request, + metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),)) + next(response_iterator) + + def _consume_some_but_not_all_stream_responses_unary_request( + self, multi_callable): + request = b'\x57\x38' + + response_iterator = multi_callable( + request, + metadata=(('test', + 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(response_iterator) + + def _consume_some_but_not_all_stream_responses_stream_request( + self, multi_callable): + requests = tuple( + b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + response_iterator = multi_callable( + request_iterator, + metadata=(('test', + 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(response_iterator) + + def _consume_too_many_stream_responses_stream_request(self, multi_callable): + requests = tuple( + b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + response_iterator = multi_callable( + request_iterator, + metadata=(('test', + 'ConsumingTooManyStreamResponsesStreamRequest'),)) + for _ in range(test_constants.STREAM_LENGTH): + next(response_iterator) + for _ in range(test_constants.STREAM_LENGTH): + with self.assertRaises(StopIteration): + next(response_iterator) + + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.OK, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def _cancelled_unary_request_stream_response(self, multi_callable): + request = b'\x07\x19' + + with self._control.pause(): + response_iterator = multi_callable( + request, + metadata=(('test', 'CancelledUnaryRequestStreamResponse'),)) + self._control.block_until_paused() + response_iterator.cancel() + + with self.assertRaises(grpc.RpcError) as exception_context: + next(response_iterator) + self.assertIs(grpc.StatusCode.CANCELLED, + exception_context.exception.code()) + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def _cancelled_stream_request_stream_response(self, multi_callable): + requests = tuple( + b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + with self._control.pause(): + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'CancelledStreamRequestStreamResponse'),)) + response_iterator.cancel() + + with self.assertRaises(grpc.RpcError): + next(response_iterator) + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def _expired_unary_request_stream_response(self, multi_callable): + request = b'\x07\x19' + + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request, + timeout=test_constants.SHORT_TIMEOUT, + metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + response_iterator.code()) + + def _expired_stream_request_stream_response(self, multi_callable): + requests = tuple( + b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request_iterator, + timeout=test_constants.SHORT_TIMEOUT, + metadata=(('test', 'ExpiredStreamRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, + response_iterator.code()) + + def _failed_unary_request_stream_response(self, multi_callable): + request = b'\x37\x17' + + with self.assertRaises(grpc.RpcError) as exception_context: + with self._control.fail(): + response_iterator = multi_callable( + request, + metadata=(('test', 'FailedUnaryRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + def _failed_stream_request_stream_response(self, multi_callable): + requests = tuple( + b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request_iterator, + metadata=(('test', 'FailedStreamRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code()) + + def _ignored_unary_stream_request_future_unary_response( + self, multi_callable): + request = b'\x37\x17' + + multi_callable(request, + metadata=(('test', + 'IgnoredUnaryRequestStreamResponse'),)) + + def _ignored_stream_request_stream_response(self, multi_callable): + requests = tuple( + b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable(request_iterator, + metadata=(('test', + 'IgnoredStreamRequestStreamResponse'),)) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py new file mode 100644 index 00000000000..1d1fdba11ee --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py @@ -0,0 +1,97 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines a number of module-scope gRPC scenarios to test server shutdown.""" + +import argparse +import os +import threading +import time +import logging + +import grpc +from tests.unit import test_common + +from concurrent import futures +from six.moves import queue + +WAIT_TIME = 1000 + +REQUEST = b'request' +RESPONSE = b'response' + +SERVER_RAISES_EXCEPTION = 'server_raises_exception' +SERVER_DEALLOCATED = 'server_deallocated' +SERVER_FORK_CAN_EXIT = 'server_fork_can_exit' + +FORK_EXIT = '/test/ForkExit' + + +def fork_and_exit(request, servicer_context): + pid = os.fork() + if pid == 0: + os._exit(0) + return RESPONSE + + +class GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == FORK_EXIT: + return grpc.unary_unary_rpc_method_handler(fork_and_exit) + else: + return None + + +def run_server(port_queue): + server = test_common.test_server() + port = server.add_insecure_port('[::]:0') + port_queue.put(port) + server.add_generic_rpc_handlers((GenericHandler(),)) + server.start() + # threading.Event.wait() does not exhibit the bug identified in + # https://github.com/grpc/grpc/issues/17093, sleep instead + time.sleep(WAIT_TIME) + + +def run_test(args): + if args.scenario == SERVER_RAISES_EXCEPTION: + server = test_common.test_server() + server.start() + raise Exception() + elif args.scenario == SERVER_DEALLOCATED: + server = test_common.test_server() + server.start() + server.__del__() + while server._state.stage != grpc._server._ServerStage.STOPPED: + pass + elif args.scenario == SERVER_FORK_CAN_EXIT: + port_queue = queue.Queue() + thread = threading.Thread(target=run_server, args=(port_queue,)) + thread.daemon = True + thread.start() + port = port_queue.get() + channel = grpc.insecure_channel('localhost:%d' % port) + multi_callable = channel.unary_unary(FORK_EXIT) + result, call = multi_callable.with_call(REQUEST, wait_for_ready=True) + os.wait() + else: + raise ValueError('unknown test scenario') + + +if __name__ == '__main__': + logging.basicConfig() + parser = argparse.ArgumentParser() + parser.add_argument('scenario', type=str) + args = parser.parse_args() + run_test(args) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py new file mode 100644 index 00000000000..c1dc7585f80 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests clean shutdown of server on various interpreter exit conditions. + +The tests in this module spawn a subprocess for each test case, the +test is considered successful if it doesn't hang/timeout. +""" + +import atexit +import os +import subprocess +import sys +import threading +import unittest +import logging + +from tests.unit import _server_shutdown_scenarios + +INTERPRETER = sys.executable +BASE_COMMAND = [INTERPRETER, '-m', 'tests.unit._server_shutdown_scenarios'] + +processes = [] +process_lock = threading.Lock() + + +# Make sure we attempt to clean up any +# processes we may have left running +def cleanup_processes(): + with process_lock: + for process in processes: + try: + process.kill() + except Exception: # pylint: disable=broad-except + pass + + +atexit.register(cleanup_processes) + + +def wait(process): + with process_lock: + processes.append(process) + process.wait() + + +class ServerShutdown(unittest.TestCase): + + # Currently we shut down a server (if possible) after the Python server + # instance is garbage collected. This behavior may change in the future. + def test_deallocated_server_stops(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_DEALLOCATED], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + wait(process) + + def test_server_exception_exits(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_RAISES_EXCEPTION], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + wait(process) + + @unittest.skipIf(os.name == 'nt', 'fork not supported on windows') + def test_server_fork_can_exit(self): + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_FORK_CAN_EXIT], + stdout=sys.stdout, + stderr=sys.stderr, + env=env) + wait(process) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py new file mode 100644 index 00000000000..35d992a33d6 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py @@ -0,0 +1,511 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server certificate rotation. + +Here we test various aspects of gRPC Python, and in some cases gRPC +Core by extension, support for server certificate rotation. + +* ServerSSLCertReloadTestWithClientAuth: test ability to rotate + server's SSL cert for use in future channels with clients while not + affecting any existing channel. The server requires client + authentication. + +* ServerSSLCertReloadTestWithoutClientAuth: like + ServerSSLCertReloadTestWithClientAuth except that the server does + not authenticate the client. + +* ServerSSLCertReloadTestCertConfigReuse: tests gRPC Python's ability + to deal with user's reuse of ServerCertificateConfiguration instances. +""" + +import abc +import collections +import os +import six +import threading +import unittest +import logging + +from concurrent import futures + +import grpc +from tests.unit import resources +from tests.unit import test_common +from tests.testing import _application_common +from tests.testing import _server_application +from tests.testing.proto import services_pb2_grpc + +CA_1_PEM = resources.cert_hier_1_root_ca_cert() +CA_2_PEM = resources.cert_hier_2_root_ca_cert() + +CLIENT_KEY_1_PEM = resources.cert_hier_1_client_1_key() +CLIENT_CERT_CHAIN_1_PEM = (resources.cert_hier_1_client_1_cert() + + resources.cert_hier_1_intermediate_ca_cert()) + +CLIENT_KEY_2_PEM = resources.cert_hier_2_client_1_key() +CLIENT_CERT_CHAIN_2_PEM = (resources.cert_hier_2_client_1_cert() + + resources.cert_hier_2_intermediate_ca_cert()) + +SERVER_KEY_1_PEM = resources.cert_hier_1_server_1_key() +SERVER_CERT_CHAIN_1_PEM = (resources.cert_hier_1_server_1_cert() + + resources.cert_hier_1_intermediate_ca_cert()) + +SERVER_KEY_2_PEM = resources.cert_hier_2_server_1_key() +SERVER_CERT_CHAIN_2_PEM = (resources.cert_hier_2_server_1_cert() + + resources.cert_hier_2_intermediate_ca_cert()) + +# for use with the CertConfigFetcher. Roughly a simple custom mock +# implementation +Call = collections.namedtuple('Call', ['did_raise', 'returned_cert_config']) + + +def _create_channel(port, credentials): + return grpc.secure_channel('localhost:{}'.format(port), credentials) + + +def _create_client_stub(channel, expect_success): + if expect_success: + # per Nathaniel: there's some robustness issue if we start + # using a channel without waiting for it to be actually ready + grpc.channel_ready_future(channel).result(timeout=10) + return services_pb2_grpc.FirstServiceStub(channel) + + +class CertConfigFetcher(object): + + def __init__(self): + self._lock = threading.Lock() + self._calls = [] + self._should_raise = False + self._cert_config = None + + def reset(self): + with self._lock: + self._calls = [] + self._should_raise = False + self._cert_config = None + + def configure(self, should_raise, cert_config): + assert not (should_raise and cert_config), ( + "should not specify both should_raise and a cert_config at the same time" + ) + with self._lock: + self._should_raise = should_raise + self._cert_config = cert_config + + def getCalls(self): + with self._lock: + return self._calls + + def __call__(self): + with self._lock: + if self._should_raise: + self._calls.append(Call(True, None)) + raise ValueError('just for fun, should not affect the test') + else: + self._calls.append(Call(False, self._cert_config)) + return self._cert_config + + +class _ServerSSLCertReloadTest( + six.with_metaclass(abc.ABCMeta, unittest.TestCase)): + + def __init__(self, *args, **kwargs): + super(_ServerSSLCertReloadTest, self).__init__(*args, **kwargs) + self.server = None + self.port = None + + @abc.abstractmethod + def require_client_auth(self): + raise NotImplementedError() + + def setUp(self): + self.server = test_common.test_server() + services_pb2_grpc.add_FirstServiceServicer_to_server( + _server_application.FirstServiceServicer(), self.server) + switch_cert_on_client_num = 10 + initial_cert_config = grpc.ssl_server_certificate_configuration( + [(SERVER_KEY_1_PEM, SERVER_CERT_CHAIN_1_PEM)], + root_certificates=CA_2_PEM) + self.cert_config_fetcher = CertConfigFetcher() + server_credentials = grpc.dynamic_ssl_server_credentials( + initial_cert_config, + self.cert_config_fetcher, + require_client_authentication=self.require_client_auth()) + self.port = self.server.add_secure_port('[::]:0', server_credentials) + self.server.start() + + def tearDown(self): + if self.server: + self.server.stop(None) + + def _perform_rpc(self, client_stub, expect_success): + # we don't care about the actual response of the rpc; only + # whether we can perform it or not, and if not, the status + # code must be UNAVAILABLE + request = _application_common.UNARY_UNARY_REQUEST + if expect_success: + response = client_stub.UnUn(request) + self.assertEqual(response, _application_common.UNARY_UNARY_RESPONSE) + else: + with self.assertRaises(grpc.RpcError) as exception_context: + client_stub.UnUn(request) + # If TLS 1.2 is used, then the client receives an alert message + # before the handshake is complete, so the status is UNAVAILABLE. If + # TLS 1.3 is used, then the client receives the alert message after + # the handshake is complete, so the TSI handshaker returns the + # TSI_PROTOCOL_FAILURE result. This result does not have a + # corresponding status code, so this yields an UNKNOWN status. + self.assertTrue(exception_context.exception.code( + ) in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]) + + def _do_one_shot_client_rpc(self, + expect_success, + root_certificates=None, + private_key=None, + certificate_chain=None): + credentials = grpc.ssl_channel_credentials( + root_certificates=root_certificates, + private_key=private_key, + certificate_chain=certificate_chain) + with _create_channel(self.port, credentials) as client_channel: + client_stub = _create_client_stub(client_channel, expect_success) + self._perform_rpc(client_stub, expect_success) + + def _test(self): + # things should work... + self.cert_config_fetcher.configure(False, None) + self._do_one_shot_client_rpc(True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertIsNone(actual_calls[0].returned_cert_config) + + # client should reject server... + # fails because client trusts ca2 and so will reject server + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + self._do_one_shot_client_rpc(False, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertGreaterEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + for i, call in enumerate(actual_calls): + self.assertFalse(call.did_raise, 'i= {}'.format(i)) + self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i)) + + # should work again... + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(True, None) + self._do_one_shot_client_rpc(True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertTrue(actual_calls[0].did_raise) + self.assertIsNone(actual_calls[0].returned_cert_config) + + # if with_client_auth, then client should be rejected by + # server because client uses key/cert1, but server trusts ca2, + # so server will reject + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + self._do_one_shot_client_rpc(not self.require_client_auth(), + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertGreaterEqual(len(actual_calls), 1) + for i, call in enumerate(actual_calls): + self.assertFalse(call.did_raise, 'i= {}'.format(i)) + self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i)) + + # should work again... + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + self._do_one_shot_client_rpc(True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertIsNone(actual_calls[0].returned_cert_config) + + # now create the "persistent" clients + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + channel_A = _create_channel( + self.port, + grpc.ssl_channel_credentials( + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM)) + persistent_client_stub_A = _create_client_stub(channel_A, True) + self._perform_rpc(persistent_client_stub_A, True) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertIsNone(actual_calls[0].returned_cert_config) + + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + channel_B = _create_channel( + self.port, + grpc.ssl_channel_credentials( + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM)) + persistent_client_stub_B = _create_client_stub(channel_B, True) + self._perform_rpc(persistent_client_stub_B, True) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertIsNone(actual_calls[0].returned_cert_config) + + # moment of truth!! client should reject server because the + # server switch cert... + cert_config = grpc.ssl_server_certificate_configuration( + [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)], + root_certificates=CA_1_PEM) + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, cert_config) + self._do_one_shot_client_rpc(False, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertGreaterEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + for i, call in enumerate(actual_calls): + self.assertFalse(call.did_raise, 'i= {}'.format(i)) + self.assertEqual(call.returned_cert_config, cert_config, + 'i= {}'.format(i)) + + # now should work again... + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + self._do_one_shot_client_rpc(True, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertIsNone(actual_calls[0].returned_cert_config) + + # client should be rejected by server if with_client_auth + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + self._do_one_shot_client_rpc(not self.require_client_auth(), + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertGreaterEqual(len(actual_calls), 1) + for i, call in enumerate(actual_calls): + self.assertFalse(call.did_raise, 'i= {}'.format(i)) + self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i)) + + # here client should reject server... + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + self._do_one_shot_client_rpc(False, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertGreaterEqual(len(actual_calls), 1) + for i, call in enumerate(actual_calls): + self.assertFalse(call.did_raise, 'i= {}'.format(i)) + self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i)) + + # persistent clients should continue to work + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + self._perform_rpc(persistent_client_stub_A, True) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 0) + + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, None) + self._perform_rpc(persistent_client_stub_B, True) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 0) + + channel_A.close() + channel_B.close() + + +class ServerSSLCertConfigFetcherParamsChecks(unittest.TestCase): + + def test_check_on_initial_config(self): + with self.assertRaises(TypeError): + grpc.dynamic_ssl_server_credentials(None, str) + with self.assertRaises(TypeError): + grpc.dynamic_ssl_server_credentials(1, str) + + def test_check_on_config_fetcher(self): + cert_config = grpc.ssl_server_certificate_configuration( + [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)], + root_certificates=CA_1_PEM) + with self.assertRaises(TypeError): + grpc.dynamic_ssl_server_credentials(cert_config, None) + with self.assertRaises(TypeError): + grpc.dynamic_ssl_server_credentials(cert_config, 1) + + +class ServerSSLCertReloadTestWithClientAuth(_ServerSSLCertReloadTest): + + def require_client_auth(self): + return True + + test = _ServerSSLCertReloadTest._test + + +class ServerSSLCertReloadTestWithoutClientAuth(_ServerSSLCertReloadTest): + + def require_client_auth(self): + return False + + test = _ServerSSLCertReloadTest._test + + +class ServerSSLCertReloadTestCertConfigReuse(_ServerSSLCertReloadTest): + """Ensures that `ServerCertificateConfiguration` instances can be reused. + + Because gRPC Core takes ownership of the + `grpc_ssl_server_certificate_config` encapsulated by + `ServerCertificateConfiguration`, this test reuses the same + `ServerCertificateConfiguration` instances multiple times to make sure + gRPC Python takes care of maintaining the validity of + `ServerCertificateConfiguration` instances, so that such instances can be + re-used by user application. + """ + + def require_client_auth(self): + return True + + def setUp(self): + self.server = test_common.test_server() + services_pb2_grpc.add_FirstServiceServicer_to_server( + _server_application.FirstServiceServicer(), self.server) + self.cert_config_A = grpc.ssl_server_certificate_configuration( + [(SERVER_KEY_1_PEM, SERVER_CERT_CHAIN_1_PEM)], + root_certificates=CA_2_PEM) + self.cert_config_B = grpc.ssl_server_certificate_configuration( + [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)], + root_certificates=CA_1_PEM) + self.cert_config_fetcher = CertConfigFetcher() + server_credentials = grpc.dynamic_ssl_server_credentials( + self.cert_config_A, + self.cert_config_fetcher, + require_client_authentication=True) + self.port = self.server.add_secure_port('[::]:0', server_credentials) + self.server.start() + + def test_cert_config_reuse(self): + + # succeed with A + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, self.cert_config_A) + self._do_one_shot_client_rpc(True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertEqual(actual_calls[0].returned_cert_config, + self.cert_config_A) + + # fail with A + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, self.cert_config_A) + self._do_one_shot_client_rpc(False, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertGreaterEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + for i, call in enumerate(actual_calls): + self.assertFalse(call.did_raise, 'i= {}'.format(i)) + self.assertEqual(call.returned_cert_config, self.cert_config_A, + 'i= {}'.format(i)) + + # succeed again with A + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, self.cert_config_A) + self._do_one_shot_client_rpc(True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertEqual(actual_calls[0].returned_cert_config, + self.cert_config_A) + + # succeed with B + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, self.cert_config_B) + self._do_one_shot_client_rpc(True, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertEqual(actual_calls[0].returned_cert_config, + self.cert_config_B) + + # fail with B + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, self.cert_config_B) + self._do_one_shot_client_rpc(False, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertGreaterEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + for i, call in enumerate(actual_calls): + self.assertFalse(call.did_raise, 'i= {}'.format(i)) + self.assertEqual(call.returned_cert_config, self.cert_config_B, + 'i= {}'.format(i)) + + # succeed again with B + self.cert_config_fetcher.reset() + self.cert_config_fetcher.configure(False, self.cert_config_B) + self._do_one_shot_client_rpc(True, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + actual_calls = self.cert_config_fetcher.getCalls() + self.assertEqual(len(actual_calls), 1) + self.assertFalse(actual_calls[0].did_raise) + self.assertEqual(actual_calls[0].returned_cert_config, + self.cert_config_B) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_test.py new file mode 100644 index 00000000000..3c519219d59 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_test.py @@ -0,0 +1,69 @@ +# Copyright 2018 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent import futures +import unittest +import logging + +import grpc + +from tests.unit import resources + + +class _ActualGenericRpcHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + return None + + +class ServerTest(unittest.TestCase): + + def test_not_a_generic_rpc_handler_at_construction(self): + with self.assertRaises(AttributeError) as exception_context: + grpc.server(futures.ThreadPoolExecutor(max_workers=5), + handlers=[ + _ActualGenericRpcHandler(), + object(), + ]) + self.assertIn('grpc.GenericRpcHandler', + str(exception_context.exception)) + + def test_not_a_generic_rpc_handler_after_construction(self): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=5)) + with self.assertRaises(AttributeError) as exception_context: + server.add_generic_rpc_handlers([ + _ActualGenericRpcHandler(), + object(), + ]) + self.assertIn('grpc.GenericRpcHandler', + str(exception_context.exception)) + + def test_failed_port_binding_exception(self): + server = grpc.server(None, options=(('grpc.so_reuseport', 0),)) + port = server.add_insecure_port('localhost:0') + bind_address = "localhost:%d" % port + + with self.assertRaises(RuntimeError): + server.add_insecure_port(bind_address) + + server_credentials = grpc.ssl_server_credentials([ + (resources.private_key(), resources.certificate_chain()) + ]) + with self.assertRaises(RuntimeError): + server.add_secure_port(bind_address, server_credentials) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_wait_for_termination_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_wait_for_termination_test.py new file mode 100644 index 00000000000..3dd95ea8bf6 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_server_wait_for_termination_test.py @@ -0,0 +1,91 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import datetime +from concurrent import futures +import unittest +import time +import threading +import six + +import grpc +from tests.unit.framework.common import test_constants + +_WAIT_FOR_BLOCKING = datetime.timedelta(seconds=1) + + +def _block_on_waiting(server, termination_event, timeout=None): + server.start() + server.wait_for_termination(timeout=timeout) + termination_event.set() + + +class ServerWaitForTerminationTest(unittest.TestCase): + + def test_unblock_by_invoking_stop(self): + termination_event = threading.Event() + server = grpc.server(futures.ThreadPoolExecutor()) + + wait_thread = threading.Thread(target=_block_on_waiting, + args=( + server, + termination_event, + )) + wait_thread.daemon = True + wait_thread.start() + time.sleep(_WAIT_FOR_BLOCKING.total_seconds()) + + server.stop(None) + termination_event.wait(timeout=test_constants.SHORT_TIMEOUT) + self.assertTrue(termination_event.is_set()) + + def test_unblock_by_del(self): + termination_event = threading.Event() + server = grpc.server(futures.ThreadPoolExecutor()) + + wait_thread = threading.Thread(target=_block_on_waiting, + args=( + server, + termination_event, + )) + wait_thread.daemon = True + wait_thread.start() + time.sleep(_WAIT_FOR_BLOCKING.total_seconds()) + + # Invoke manually here, in Python 2 it will be invoked by GC sometime. + server.__del__() + termination_event.wait(timeout=test_constants.SHORT_TIMEOUT) + self.assertTrue(termination_event.is_set()) + + def test_unblock_by_timeout(self): + termination_event = threading.Event() + server = grpc.server(futures.ThreadPoolExecutor()) + + wait_thread = threading.Thread(target=_block_on_waiting, + args=( + server, + termination_event, + test_constants.SHORT_TIMEOUT / 2, + )) + wait_thread.daemon = True + wait_thread.start() + + termination_event.wait(timeout=test_constants.SHORT_TIMEOUT) + self.assertTrue(termination_event.is_set()) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_session_cache_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_session_cache_test.py new file mode 100644 index 00000000000..9bff4d2af00 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_session_cache_test.py @@ -0,0 +1,140 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests experimental TLS Session Resumption API""" + +import pickle +import unittest +import logging + +import grpc +from grpc import _channel +from grpc.experimental import session_cache + +from tests.unit import test_common +from tests.unit import resources + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_ID = 'id' +_ID_KEY = 'id_key' +_AUTH_CTX = 'auth_ctx' + +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() +_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) +_PROPERTY_OPTIONS = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, +),) + + +def handle_unary_unary(request, servicer_context): + return pickle.dumps({ + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context() + }) + + +def start_secure_server(): + handler = grpc.method_handlers_generic_handler( + 'test', + {'UnaryUnary': grpc.unary_unary_rpc_method_handler(handle_unary_unary)}) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + server.start() + + return server, port + + +class SSLSessionCacheTest(unittest.TestCase): + + def _do_one_shot_client_rpc(self, channel_creds, channel_options, port, + expect_ssl_session_reused): + channel = grpc.secure_channel('localhost:{}'.format(port), + channel_creds, + options=channel_options) + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + auth_data = pickle.loads(response) + self.assertEqual(expect_ssl_session_reused, + auth_data[_AUTH_CTX]['ssl_session_reused']) + channel.close() + + def testSSLSessionCacheLRU(self): + server_1, port_1 = start_secure_server() + + cache = session_cache.ssl_session_cache_lru(1) + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel_options = _PROPERTY_OPTIONS + ( + ('grpc.ssl_session_cache', cache),) + + # Initial connection has no session to resume + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Connection to server_1 resumes from initial session + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'true']) + + # Connection to a different server with the same name overwrites the cache entry + server_2, port_2 = start_secure_server() + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port_2, + expect_ssl_session_reused=[b'false']) + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port_2, + expect_ssl_session_reused=[b'true']) + server_2.stop(None) + + # Connection to server_1 now falls back to full TLS handshake + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Re-creating server_1 causes old sessions to become invalid + server_1.stop(None) + server_1, port_1 = start_secure_server() + + # Old sessions should no longer be valid + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Resumption should work for subsequent connections + self._do_one_shot_client_rpc(channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'true']) + server_1.stop(None) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_signal_client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_signal_client.py new file mode 100644 index 00000000000..0be1270749c --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_signal_client.py @@ -0,0 +1,119 @@ +# Copyright 2019 the gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Client for testing responsiveness to signals.""" + +from __future__ import print_function + +import argparse +import functools +import logging +import signal +import sys + +import grpc + +SIGTERM_MESSAGE = "Handling sigterm!" + +UNARY_UNARY = "/test/Unary" +UNARY_STREAM = "/test/ServerStreaming" + +_MESSAGE = b'\x00\x00\x00' + +_ASSERTION_MESSAGE = "Control flow should never reach here." + +# NOTE(gnossen): We use a global variable here so that the signal handler can be +# installed before the RPC begins. If we do not do this, then we may receive the +# SIGINT before the signal handler is installed. I'm not happy with per-process +# global state, but the per-process global state that is signal handlers +# somewhat forces my hand. +per_process_rpc_future = None + + +def handle_sigint(unused_signum, unused_frame): + print(SIGTERM_MESSAGE) + if per_process_rpc_future is not None: + per_process_rpc_future.cancel() + sys.stderr.flush() + # This sys.exit(0) avoids an exception caused by the cancelled RPC. + sys.exit(0) + + +def main_unary(server_target): + """Initiate a unary RPC to be interrupted by a SIGINT.""" + global per_process_rpc_future # pylint: disable=global-statement + with grpc.insecure_channel(server_target) as channel: + multicallable = channel.unary_unary(UNARY_UNARY) + signal.signal(signal.SIGINT, handle_sigint) + per_process_rpc_future = multicallable.future(_MESSAGE, + wait_for_ready=True) + result = per_process_rpc_future.result() + assert False, _ASSERTION_MESSAGE + + +def main_streaming(server_target): + """Initiate a streaming RPC to be interrupted by a SIGINT.""" + global per_process_rpc_future # pylint: disable=global-statement + with grpc.insecure_channel(server_target) as channel: + signal.signal(signal.SIGINT, handle_sigint) + per_process_rpc_future = channel.unary_stream(UNARY_STREAM)( + _MESSAGE, wait_for_ready=True) + for result in per_process_rpc_future: + pass + assert False, _ASSERTION_MESSAGE + + +def main_unary_with_exception(server_target): + """Initiate a unary RPC with a signal handler that will raise.""" + channel = grpc.insecure_channel(server_target) + try: + channel.unary_unary(UNARY_UNARY)(_MESSAGE, wait_for_ready=True) + except KeyboardInterrupt: + sys.stderr.write("Running signal handler.\n") + sys.stderr.flush() + + # This call should not hang. + channel.close() + + +def main_streaming_with_exception(server_target): + """Initiate a streaming RPC with a signal handler that will raise.""" + channel = grpc.insecure_channel(server_target) + try: + for _ in channel.unary_stream(UNARY_STREAM)(_MESSAGE, + wait_for_ready=True): + pass + except KeyboardInterrupt: + sys.stderr.write("Running signal handler.\n") + sys.stderr.flush() + + # This call should not hang. + channel.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Signal test client.') + parser.add_argument('server', help='Server target') + parser.add_argument('arity', help='Arity', choices=('unary', 'streaming')) + parser.add_argument('--exception', + help='Whether the signal throws an exception', + action='store_true') + args = parser.parse_args() + if args.arity == 'unary' and not args.exception: + main_unary(args.server) + elif args.arity == 'streaming' and not args.exception: + main_streaming(args.server) + elif args.arity == 'unary' and args.exception: + main_unary_with_exception(args.server) + else: + main_streaming_with_exception(args.server) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_signal_handling_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_signal_handling_test.py new file mode 100644 index 00000000000..a05e42d5a31 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_signal_handling_test.py @@ -0,0 +1,200 @@ +# Copyright 2019 the gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test of responsiveness to signals.""" + +import logging +import os +import signal +import subprocess +import tempfile +import threading +import unittest +import sys + +import grpc + +from tests.unit import test_common +from tests.unit import _signal_client + +_CLIENT_PATH = None +if sys.executable is not None: + _CLIENT_PATH = 'tests.unit._signal_client' +else: + # NOTE(rbellevi): For compatibility with internal testing. + if len(sys.argv) != 2: + raise RuntimeError("Must supply path to executable client.") + client_name = sys.argv[1].split("/")[-1] + del sys.argv[1] # For compatibility with test runner. + _CLIENT_PATH = os.path.realpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name)) + +_HOST = 'localhost' + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self): + self._connected_clients_lock = threading.RLock() + self._connected_clients_event = threading.Event() + self._connected_clients = 0 + + self._unary_unary_handler = grpc.unary_unary_rpc_method_handler( + self._handle_unary_unary) + self._unary_stream_handler = grpc.unary_stream_rpc_method_handler( + self._handle_unary_stream) + + def _on_client_connect(self): + with self._connected_clients_lock: + self._connected_clients += 1 + self._connected_clients_event.set() + + def _on_client_disconnect(self): + with self._connected_clients_lock: + self._connected_clients -= 1 + if self._connected_clients == 0: + self._connected_clients_event.clear() + + def await_connected_client(self): + """Blocks until a client connects to the server.""" + self._connected_clients_event.wait() + + def _handle_unary_unary(self, request, servicer_context): + """Handles a unary RPC. + + Blocks until the client disconnects and then echoes. + """ + stop_event = threading.Event() + + def on_rpc_end(): + self._on_client_disconnect() + stop_event.set() + + servicer_context.add_callback(on_rpc_end) + self._on_client_connect() + stop_event.wait() + return request + + def _handle_unary_stream(self, request, servicer_context): + """Handles a server streaming RPC. + + Blocks until the client disconnects and then echoes. + """ + stop_event = threading.Event() + + def on_rpc_end(): + self._on_client_disconnect() + stop_event.set() + + servicer_context.add_callback(on_rpc_end) + self._on_client_connect() + stop_event.wait() + yield request + + def service(self, handler_call_details): + if handler_call_details.method == _signal_client.UNARY_UNARY: + return self._unary_unary_handler + elif handler_call_details.method == _signal_client.UNARY_STREAM: + return self._unary_stream_handler + else: + return None + + +def _read_stream(stream): + stream.seek(0) + return stream.read() + + +def _start_client(args, stdout, stderr): + invocation = None + if sys.executable is not None: + invocation = (sys.executable, '-m', _CLIENT_PATH) + tuple(args) + else: + invocation = (_CLIENT_PATH,) + tuple(args) + env = os.environ.copy() + env['Y_PYTHON_ENTRY_POINT'] = ':main' + return subprocess.Popen(invocation, stdout=stdout, stderr=stderr, env=env) + + +class SignalHandlingTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + self._port = self._server.add_insecure_port('{}:0'.format(_HOST)) + self._handler = _GenericHandler() + self._server.add_generic_rpc_handlers((self._handler,)) + self._server.start() + + def tearDown(self): + self._server.stop(None) + + @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + def testUnary(self): + """Tests that the server unary code path does not stall signal handlers.""" + server_target = '{}:{}'.format(_HOST, self._port) + with tempfile.TemporaryFile(mode='r') as client_stdout: + with tempfile.TemporaryFile(mode='r') as client_stderr: + client = _start_client((server_target, 'unary'), client_stdout, + client_stderr) + self._handler.await_connected_client() + client.send_signal(signal.SIGINT) + self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) + client_stdout.seek(0) + self.assertIn(_signal_client.SIGTERM_MESSAGE, + client_stdout.read()) + + @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + def testStreaming(self): + """Tests that the server streaming code path does not stall signal handlers.""" + server_target = '{}:{}'.format(_HOST, self._port) + with tempfile.TemporaryFile(mode='r') as client_stdout: + with tempfile.TemporaryFile(mode='r') as client_stderr: + client = _start_client((server_target, 'streaming'), + client_stdout, client_stderr) + self._handler.await_connected_client() + client.send_signal(signal.SIGINT) + self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) + client_stdout.seek(0) + self.assertIn(_signal_client.SIGTERM_MESSAGE, + client_stdout.read()) + + @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + def testUnaryWithException(self): + server_target = '{}:{}'.format(_HOST, self._port) + with tempfile.TemporaryFile(mode='r') as client_stdout: + with tempfile.TemporaryFile(mode='r') as client_stderr: + client = _start_client(('--exception', server_target, 'unary'), + client_stdout, client_stderr) + self._handler.await_connected_client() + client.send_signal(signal.SIGINT) + client.wait() + self.assertEqual(0, client.returncode) + + @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + def testStreamingHandlerWithException(self): + server_target = '{}:{}'.format(_HOST, self._port) + with tempfile.TemporaryFile(mode='r') as client_stdout: + with tempfile.TemporaryFile(mode='r') as client_stderr: + client = _start_client( + ('--exception', server_target, 'streaming'), client_stdout, + client_stderr) + self._handler.await_connected_client() + client.send_signal(signal.SIGINT) + client.wait() + print(_read_stream(client_stderr)) + self.assertEqual(0, client.returncode) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_tcp_proxy.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_tcp_proxy.py new file mode 100644 index 00000000000..84dc0e2d6cf --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_tcp_proxy.py @@ -0,0 +1,141 @@ +# Copyright 2019 the gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Proxies a TCP connection between a single client-server pair. + +This proxy is not suitable for production, but should work well for cases in +which a test needs to spy on the bytes put on the wire between a server and +a client. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import select +import socket +import threading + +from tests.unit.framework.common import get_socket + +_TCP_PROXY_BUFFER_SIZE = 1024 +_TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500) + + +def _init_proxy_socket(gateway_address, gateway_port): + proxy_socket = socket.create_connection((gateway_address, gateway_port)) + return proxy_socket + + +class TcpProxy(object): + """Proxies a TCP connection between one client and one server.""" + + def __init__(self, bind_address, gateway_address, gateway_port): + self._bind_address = bind_address + self._gateway_address = gateway_address + self._gateway_port = gateway_port + + self._byte_count_lock = threading.RLock() + self._sent_byte_count = 0 + self._received_byte_count = 0 + + self._stop_event = threading.Event() + + self._port = None + self._listen_socket = None + self._proxy_socket = None + + # The following three attributes are owned by the serving thread. + self._northbound_data = b"" + self._southbound_data = b"" + self._client_sockets = [] + + self._thread = threading.Thread(target=self._run_proxy) + + def start(self): + _, self._port, self._listen_socket = get_socket( + bind_address=self._bind_address) + self._proxy_socket = _init_proxy_socket(self._gateway_address, + self._gateway_port) + self._thread.start() + + def get_port(self): + return self._port + + def _handle_reads(self, sockets_to_read): + for socket_to_read in sockets_to_read: + if socket_to_read is self._listen_socket: + client_socket, client_address = socket_to_read.accept() + self._client_sockets.append(client_socket) + elif socket_to_read is self._proxy_socket: + data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) + with self._byte_count_lock: + self._received_byte_count += len(data) + self._northbound_data += data + elif socket_to_read in self._client_sockets: + data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) + if data: + with self._byte_count_lock: + self._sent_byte_count += len(data) + self._southbound_data += data + else: + self._client_sockets.remove(socket_to_read) + else: + raise RuntimeError('Unidentified socket appeared in read set.') + + def _handle_writes(self, sockets_to_write): + for socket_to_write in sockets_to_write: + if socket_to_write is self._proxy_socket: + if self._southbound_data: + self._proxy_socket.sendall(self._southbound_data) + self._southbound_data = b"" + elif socket_to_write in self._client_sockets: + if self._northbound_data: + socket_to_write.sendall(self._northbound_data) + self._northbound_data = b"" + + def _run_proxy(self): + while not self._stop_event.is_set(): + expected_reads = (self._listen_socket, self._proxy_socket) + tuple( + self._client_sockets) + expected_writes = expected_reads + sockets_to_read, sockets_to_write, _ = select.select( + expected_reads, expected_writes, (), + _TCP_PROXY_TIMEOUT.total_seconds()) + self._handle_reads(sockets_to_read) + self._handle_writes(sockets_to_write) + for client_socket in self._client_sockets: + client_socket.close() + + def stop(self): + self._stop_event.set() + self._thread.join() + self._listen_socket.close() + self._proxy_socket.close() + + def get_byte_count(self): + with self._byte_count_lock: + return self._sent_byte_count, self._received_byte_count + + def reset_byte_count(self): + with self._byte_count_lock: + self._byte_count = 0 + self._received_byte_count = 0 + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_version_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_version_test.py new file mode 100644 index 00000000000..3d37b319e5a --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/_version_test.py @@ -0,0 +1,30 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test for grpc.__version__""" + +import unittest +import grpc +import logging +from grpc import _grpcio_metadata + + +class VersionTest(unittest.TestCase): + + def test_get_version(self): + self.assertEqual(grpc.__version__, _grpcio_metadata.__version__) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/__init__.py new file mode 100644 index 00000000000..5fb4f3c3cfd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py new file mode 100644 index 00000000000..a111d687641 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py @@ -0,0 +1,354 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests Face interface compliance of the gRPC Python Beta API.""" + +import threading +import unittest + +from grpc.beta import implementations +from grpc.beta import interfaces +from grpc.framework.common import cardinality +from grpc.framework.interfaces.face import utilities +from tests.unit import resources +from tests.unit.beta import test_utilities +from tests.unit.framework.common import test_constants + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' + +_PER_RPC_CREDENTIALS_METADATA_KEY = b'my-call-credentials-metadata-key' +_PER_RPC_CREDENTIALS_METADATA_VALUE = b'my-call-credentials-metadata-value' + +_GROUP = 'group' +_UNARY_UNARY = 'unary-unary' +_UNARY_STREAM = 'unary-stream' +_STREAM_UNARY = 'stream-unary' +_STREAM_STREAM = 'stream-stream' + +_REQUEST = b'abc' +_RESPONSE = b'123' + + +class _Servicer(object): + + def __init__(self): + self._condition = threading.Condition() + self._peer = None + self._serviced = False + + def unary_unary(self, request, context): + with self._condition: + self._request = request + self._peer = context.protocol_context().peer() + self._invocation_metadata = context.invocation_metadata() + context.protocol_context().disable_next_response_compression() + self._serviced = True + self._condition.notify_all() + return _RESPONSE + + def unary_stream(self, request, context): + with self._condition: + self._request = request + self._peer = context.protocol_context().peer() + self._invocation_metadata = context.invocation_metadata() + context.protocol_context().disable_next_response_compression() + self._serviced = True + self._condition.notify_all() + return + yield # pylint: disable=unreachable + + def stream_unary(self, request_iterator, context): + for request in request_iterator: + self._request = request + with self._condition: + self._peer = context.protocol_context().peer() + self._invocation_metadata = context.invocation_metadata() + context.protocol_context().disable_next_response_compression() + self._serviced = True + self._condition.notify_all() + return _RESPONSE + + def stream_stream(self, request_iterator, context): + for request in request_iterator: + with self._condition: + self._peer = context.protocol_context().peer() + context.protocol_context().disable_next_response_compression() + yield _RESPONSE + with self._condition: + self._invocation_metadata = context.invocation_metadata() + self._serviced = True + self._condition.notify_all() + + def peer(self): + with self._condition: + return self._peer + + def block_until_serviced(self): + with self._condition: + while not self._serviced: + self._condition.wait() + + +class _BlockingIterator(object): + + def __init__(self, upstream): + self._condition = threading.Condition() + self._upstream = upstream + self._allowed = [] + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + with self._condition: + while True: + if self._allowed is None: + raise StopIteration() + elif self._allowed: + return self._allowed.pop(0) + else: + self._condition.wait() + + def allow(self): + with self._condition: + try: + self._allowed.append(next(self._upstream)) + except StopIteration: + self._allowed = None + self._condition.notify_all() + + +def _metadata_plugin(context, callback): + callback([ + (_PER_RPC_CREDENTIALS_METADATA_KEY, _PER_RPC_CREDENTIALS_METADATA_VALUE) + ], None) + + +class BetaFeaturesTest(unittest.TestCase): + + def setUp(self): + self._servicer = _Servicer() + method_implementations = { + (_GROUP, _UNARY_UNARY): + utilities.unary_unary_inline(self._servicer.unary_unary), + (_GROUP, _UNARY_STREAM): + utilities.unary_stream_inline(self._servicer.unary_stream), + (_GROUP, _STREAM_UNARY): + utilities.stream_unary_inline(self._servicer.stream_unary), + (_GROUP, _STREAM_STREAM): + utilities.stream_stream_inline(self._servicer.stream_stream), + } + + cardinalities = { + _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY, + _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM, + _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY, + _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM, + } + + server_options = implementations.server_options( + thread_pool_size=test_constants.POOL_SIZE) + self._server = implementations.server(method_implementations, + options=server_options) + server_credentials = implementations.ssl_server_credentials([ + ( + resources.private_key(), + resources.certificate_chain(), + ), + ]) + port = self._server.add_secure_port('[::]:0', server_credentials) + self._server.start() + self._channel_credentials = implementations.ssl_channel_credentials( + resources.test_root_certificates()) + self._call_credentials = implementations.metadata_call_credentials( + _metadata_plugin) + channel = test_utilities.not_really_secure_channel( + 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE) + stub_options = implementations.stub_options( + thread_pool_size=test_constants.POOL_SIZE) + self._dynamic_stub = implementations.dynamic_stub(channel, + _GROUP, + cardinalities, + options=stub_options) + + def tearDown(self): + self._dynamic_stub = None + self._server.stop(test_constants.SHORT_TIMEOUT).wait() + + def test_unary_unary(self): + call_options = interfaces.grpc_call_options( + disable_compression=True, credentials=self._call_credentials) + response = getattr(self._dynamic_stub, + _UNARY_UNARY)(_REQUEST, + test_constants.LONG_TIMEOUT, + protocol_options=call_options) + self.assertEqual(_RESPONSE, response) + self.assertIsNotNone(self._servicer.peer()) + invocation_metadata = [ + (metadatum.key, metadatum.value) + for metadatum in self._servicer._invocation_metadata + ] + self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE), + invocation_metadata) + + def test_unary_stream(self): + call_options = interfaces.grpc_call_options( + disable_compression=True, credentials=self._call_credentials) + response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)( + _REQUEST, + test_constants.LONG_TIMEOUT, + protocol_options=call_options) + self._servicer.block_until_serviced() + self.assertIsNotNone(self._servicer.peer()) + invocation_metadata = [ + (metadatum.key, metadatum.value) + for metadatum in self._servicer._invocation_metadata + ] + self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE), + invocation_metadata) + + def test_stream_unary(self): + call_options = interfaces.grpc_call_options( + credentials=self._call_credentials) + request_iterator = _BlockingIterator(iter((_REQUEST,))) + response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future( + request_iterator, + test_constants.LONG_TIMEOUT, + protocol_options=call_options) + response_future.protocol_context().disable_next_request_compression() + request_iterator.allow() + response_future.protocol_context().disable_next_request_compression() + request_iterator.allow() + self._servicer.block_until_serviced() + self.assertIsNotNone(self._servicer.peer()) + self.assertEqual(_RESPONSE, response_future.result()) + invocation_metadata = [ + (metadatum.key, metadatum.value) + for metadatum in self._servicer._invocation_metadata + ] + self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE), + invocation_metadata) + + def test_stream_stream(self): + call_options = interfaces.grpc_call_options( + credentials=self._call_credentials) + request_iterator = _BlockingIterator(iter((_REQUEST,))) + response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)( + request_iterator, + test_constants.SHORT_TIMEOUT, + protocol_options=call_options) + response_iterator.protocol_context().disable_next_request_compression() + request_iterator.allow() + response = next(response_iterator) + response_iterator.protocol_context().disable_next_request_compression() + request_iterator.allow() + self._servicer.block_until_serviced() + self.assertIsNotNone(self._servicer.peer()) + self.assertEqual(_RESPONSE, response) + invocation_metadata = [ + (metadatum.key, metadatum.value) + for metadatum in self._servicer._invocation_metadata + ] + self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE), + invocation_metadata) + + +class ContextManagementAndLifecycleTest(unittest.TestCase): + + def setUp(self): + self._servicer = _Servicer() + self._method_implementations = { + (_GROUP, _UNARY_UNARY): + utilities.unary_unary_inline(self._servicer.unary_unary), + (_GROUP, _UNARY_STREAM): + utilities.unary_stream_inline(self._servicer.unary_stream), + (_GROUP, _STREAM_UNARY): + utilities.stream_unary_inline(self._servicer.stream_unary), + (_GROUP, _STREAM_STREAM): + utilities.stream_stream_inline(self._servicer.stream_stream), + } + + self._cardinalities = { + _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY, + _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM, + _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY, + _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM, + } + + self._server_options = implementations.server_options( + thread_pool_size=test_constants.POOL_SIZE) + self._server_credentials = implementations.ssl_server_credentials([ + ( + resources.private_key(), + resources.certificate_chain(), + ), + ]) + self._channel_credentials = implementations.ssl_channel_credentials( + resources.test_root_certificates()) + self._stub_options = implementations.stub_options( + thread_pool_size=test_constants.POOL_SIZE) + + def test_stub_context(self): + server = implementations.server(self._method_implementations, + options=self._server_options) + port = server.add_secure_port('[::]:0', self._server_credentials) + server.start() + + channel = test_utilities.not_really_secure_channel( + 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE) + dynamic_stub = implementations.dynamic_stub(channel, + _GROUP, + self._cardinalities, + options=self._stub_options) + for _ in range(100): + with dynamic_stub: + pass + for _ in range(10): + with dynamic_stub: + call_options = interfaces.grpc_call_options( + disable_compression=True) + response = getattr(dynamic_stub, + _UNARY_UNARY)(_REQUEST, + test_constants.LONG_TIMEOUT, + protocol_options=call_options) + self.assertEqual(_RESPONSE, response) + self.assertIsNotNone(self._servicer.peer()) + + server.stop(test_constants.SHORT_TIMEOUT).wait() + + def test_server_lifecycle(self): + for _ in range(100): + server = implementations.server(self._method_implementations, + options=self._server_options) + port = server.add_secure_port('[::]:0', self._server_credentials) + server.start() + server.stop(test_constants.SHORT_TIMEOUT).wait() + for _ in range(100): + server = implementations.server(self._method_implementations, + options=self._server_options) + server.add_secure_port('[::]:0', self._server_credentials) + server.add_insecure_port('[::]:0') + with server: + server.stop(test_constants.SHORT_TIMEOUT) + server.stop(test_constants.SHORT_TIMEOUT) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py new file mode 100644 index 00000000000..1416902eab8 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py @@ -0,0 +1,32 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc.beta._connectivity_channel.""" + +import unittest + +from grpc.beta import interfaces + + +class ConnectivityStatesTest(unittest.TestCase): + + def testBetaConnectivityStates(self): + self.assertIsNotNone(interfaces.ChannelConnectivity.IDLE) + self.assertIsNotNone(interfaces.ChannelConnectivity.CONNECTING) + self.assertIsNotNone(interfaces.ChannelConnectivity.READY) + self.assertIsNotNone(interfaces.ChannelConnectivity.TRANSIENT_FAILURE) + self.assertIsNotNone(interfaces.ChannelConnectivity.FATAL_FAILURE) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py new file mode 100644 index 00000000000..75a615eeffb --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py @@ -0,0 +1,55 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests the implementations module of the gRPC Python Beta API.""" + +import datetime +import unittest + +from oauth2client import client as oauth2client_client + +from grpc.beta import implementations +from tests.unit import resources + + +class ChannelCredentialsTest(unittest.TestCase): + + def test_runtime_provided_root_certificates(self): + channel_credentials = implementations.ssl_channel_credentials() + self.assertIsInstance(channel_credentials, + implementations.ChannelCredentials) + + def test_application_provided_root_certificates(self): + channel_credentials = implementations.ssl_channel_credentials( + resources.test_root_certificates()) + self.assertIsInstance(channel_credentials, + implementations.ChannelCredentials) + + +class CallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials(self): + creds = oauth2client_client.GoogleCredentials( + 'token', 'client_id', 'secret', 'refresh_token', + datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/', + 'user_agent') + call_creds = implementations.google_call_credentials(creds) + self.assertIsInstance(call_creds, implementations.CallCredentials) + + def test_access_token_call_credentials(self): + call_creds = implementations.access_token_call_credentials('token') + self.assertIsInstance(call_creds, implementations.CallCredentials) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py new file mode 100644 index 00000000000..837d2bbebf2 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py @@ -0,0 +1,59 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of RPC-method-not-found behavior.""" + +import unittest + +from grpc.beta import implementations +from grpc.beta import interfaces +from grpc.framework.interfaces.face import face +from tests.unit.framework.common import test_constants + + +class NotFoundTest(unittest.TestCase): + + def setUp(self): + self._server = implementations.server({}) + port = self._server.add_insecure_port('[::]:0') + channel = implementations.insecure_channel('localhost', port) + self._generic_stub = implementations.generic_stub(channel) + self._server.start() + + def tearDown(self): + self._server.stop(0).wait() + self._generic_stub = None + + def test_blocking_unary_unary_not_found(self): + with self.assertRaises(face.LocalError) as exception_assertion_context: + self._generic_stub.blocking_unary_unary('groop', + 'meffod', + b'abc', + test_constants.LONG_TIMEOUT, + with_call=True) + self.assertIs(exception_assertion_context.exception.code, + interfaces.StatusCode.UNIMPLEMENTED) + + def test_future_stream_unary_not_found(self): + rpc_future = self._generic_stub.future_stream_unary( + 'grupe', 'mevvod', iter([b'def']), test_constants.LONG_TIMEOUT) + with self.assertRaises(face.LocalError) as exception_assertion_context: + rpc_future.result() + self.assertIs(exception_assertion_context.exception.code, + interfaces.StatusCode.UNIMPLEMENTED) + self.assertIs(rpc_future.exception().code, + interfaces.StatusCode.UNIMPLEMENTED) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py new file mode 100644 index 00000000000..e0422627962 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py @@ -0,0 +1,93 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc.beta.utilities.""" + +import threading +import time +import unittest + +from grpc.beta import implementations +from grpc.beta import utilities +from grpc.framework.foundation import future +from tests.unit.framework.common import test_constants + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + + def accept_value(self, value): + with self._condition: + self._value = value + self._condition.notify_all() + + def block_until_called(self): + with self._condition: + while self._value is None: + self._condition.wait() + return self._value + + [email protected]('https://github.com/grpc/grpc/issues/16134') +class ChannelConnectivityTest(unittest.TestCase): + + def test_lonely_channel_connectivity(self): + channel = implementations.insecure_channel('localhost', 12345) + callback = _Callback() + + ready_future = utilities.channel_ready_future(channel) + ready_future.add_done_callback(callback.accept_value) + with self.assertRaises(future.TimeoutError): + ready_future.result(timeout=test_constants.SHORT_TIMEOUT) + self.assertFalse(ready_future.cancelled()) + self.assertFalse(ready_future.done()) + self.assertTrue(ready_future.running()) + ready_future.cancel() + value_passed_to_callback = callback.block_until_called() + self.assertIs(ready_future, value_passed_to_callback) + self.assertTrue(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + + def test_immediately_connectable_channel_connectivity(self): + server = implementations.server({}) + port = server.add_insecure_port('[::]:0') + server.start() + channel = implementations.insecure_channel('localhost', port) + callback = _Callback() + + try: + ready_future = utilities.channel_ready_future(channel) + ready_future.add_done_callback(callback.accept_value) + self.assertIsNone( + ready_future.result(timeout=test_constants.LONG_TIMEOUT)) + value_passed_to_callback = callback.block_until_called() + self.assertIs(ready_future, value_passed_to_callback) + self.assertFalse(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + # Cancellation after maturity has no effect. + ready_future.cancel() + self.assertFalse(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + finally: + ready_future.cancel() + server.stop(0) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/test_utilities.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/test_utilities.py new file mode 100644 index 00000000000..c8d920d35e9 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/beta/test_utilities.py @@ -0,0 +1,40 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test-appropriate entry points into the gRPC Python Beta API.""" + +import grpc +from grpc.beta import implementations + + +def not_really_secure_channel(host, port, channel_credentials, + server_host_override): + """Creates an insecure Channel to a remote host. + + Args: + host: The name of the remote host to which to connect. + port: The port of the remote host to which to connect. + channel_credentials: The implementations.ChannelCredentials with which to + connect. + server_host_override: The target name used for SSL host name checking. + + Returns: + An implementations.Channel to the remote host through which RPCs may be + conducted. + """ + target = '%s:%d' % (host, port) + channel = grpc.secure_channel(target, channel_credentials, (( + 'grpc.ssl_target_name_override', + server_host_override, + ),)) + return implementations.Channel(channel) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/README.md b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/README.md new file mode 100644 index 00000000000..100b43c1aaf --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/README.md @@ -0,0 +1,15 @@ +These are test keys *NOT* to be used in production. + +The `certificate_hierarchy_1` and `certificate_hierarchy_2` contain +two disjoint but similarly organized certificate hierarchies. Each +contains: + +* The respective root CA cert in `certs/ca.cert.pem` + +* The intermediate CA cert in + `intermediate/certs/intermediate.cert.pem`, signed by the root CA + +* A client cert and a server cert--both signed by the intermediate + CA--in `intermediate/certs/client.cert.pem` and + `intermediate/certs/localhost-1.cert.pem`; the corresponding keys + are in `intermediate/private` diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/ca.pem b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/ca.pem new file mode 100755 index 00000000000..49d39cd8ed5 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/ca.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDWjCCAkKgAwIBAgIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQEL +BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTIw +MDMxNzE4NTk1MVoXDTMwMDMxNTE4NTk1MVowVjELMAkGA1UEBhMCQVUxEzARBgNV +BAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0 +ZDEPMA0GA1UEAwwGdGVzdGNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAsGL0oXflF0LzoM+Bh+qUU9yhqzw2w8OOX5mu/iNCyUOBrqaHi7mGHx73GD01 +diNzCzvlcQqdNIH6NQSL7DTpBjca66jYT9u73vZe2MDrr1nVbuLvfu9850cdxiUO +Inv5xf8+sTHG0C+a+VAvMhsLiRjsq+lXKRJyk5zkbbsETybqpxoJ+K7CoSy3yc/k +QIY3TipwEtwkKP4hzyo6KiGd/DPexie4nBUInN3bS1BUeNZ5zeaIC2eg3bkeeW7c +qT55b+Yen6CxY0TEkzBK6AKt/WUialKMgT0wbTxRZO7kUCH3Sq6e/wXeFdJ+HvdV +LPlAg5TnMaNpRdQih/8nRFpsdwIDAQABoyAwHjAMBgNVHRMEBTADAQH/MA4GA1Ud +DwEB/wQEAwICBDANBgkqhkiG9w0BAQsFAAOCAQEAkTrKZjBrJXHps/HrjNCFPb5a +THuGPCSsepe1wkKdSp1h4HGRpLoCgcLysCJ5hZhRpHkRihhef+rFHEe60UePQO3S +CVTtdJB4CYWpcNyXOdqefrbJW5QNljxgi6Fhvs7JJkBqdXIkWXtFk2eRgOIP2Eo9 +/OHQHlYnwZFrk6sp4wPyR+A95S0toZBcyDVz7u+hOW0pGK3wviOe9lvRgj/H3Pwt +bewb0l+MhRig0/DVHamyVxrDRbqInU1/GTNCwcZkXKYFWSf92U+kIcTth24Q1gcw +eZiLl5FfrWokUNytFElXob0V0a5/kbhiLc3yWmvWqHTpqCALbVyF+rKJo2f5Kw== +-----END CERTIFICATE----- diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/server1.key b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/server1.key new file mode 100755 index 00000000000..086462992cf --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/server1.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDnE443EknxvxBq +6+hvn/t09hl8hx366EBYvZmVM/NC+7igXRAjiJiA/mIaCvL3MS0Iz5hBLxSGICU+ +WproA3GCIFITIwcf/ETyWj/5xpgZ4AKrLrjQmmX8mhwUajfF3UvwMJrCOVqPp67t +PtP+2kBXaqrXdvnvXR41FsIB8V7zIAuIZB6bHQhiGVlc1sgZYsE2EGG9WMmHtS86 +qkAOTjG2XyjmPTGAwhGDpYkYrpzp99IiDh4/Veai81hn0ssQkbry0XRD/Ig3jcHh +23WiriPNJ0JsbgXUSLKRPZObA9VgOLy2aXoN84IMaeK3yy+cwSYG/99w93fUZJte +MXwz4oYZAgMBAAECggEBAIVn2Ncai+4xbH0OLWckabwgyJ4IM9rDc0LIU368O1kU +koais8qP9dujAWgfoh3sGh/YGgKn96VnsZjKHlyMgF+r4TaDJn3k2rlAOWcurGlj +1qaVlsV4HiEzp7pxiDmHhWvp4672Bb6iBG+bsjCUOEk/n9o9KhZzIBluRhtxCmw5 +nw4Do7z00PTvN81260uPWSc04IrytvZUiAIx/5qxD72bij2xJ8t/I9GI8g4FtoVB +8pB6S/hJX1PZhh9VlU6Yk+TOfOVnbebG4W5138LkB835eqk3Zz0qsbc2euoi8Hxi +y1VGwQEmMQ63jXz4c6g+X55ifvUK9Jpn5E8pq+pMd7ECgYEA93lYq+Cr54K4ey5t +sWMa+ye5RqxjzgXj2Kqr55jb54VWG7wp2iGbg8FMlkQwzTJwebzDyCSatguEZLuB +gRGroRnsUOy9vBvhKPOch9bfKIl6qOgzMJB267fBVWx5ybnRbWN/I7RvMQf3k+9y +biCIVnxDLEEYyx7z85/5qxsXg/MCgYEA7wmWKtCTn032Hy9P8OL49T0X6Z8FlkDC +Rk42ygrc/MUbugq9RGUxcCxoImOG9JXUpEtUe31YDm2j+/nbvrjl6/bP2qWs0V7l +dTJl6dABP51pCw8+l4cWgBBX08Lkeen812AAFNrjmDCjX6rHjWHLJcpS18fnRRkP +V1d/AHWX7MMCgYEA6Gsw2guhp0Zf2GCcaNK5DlQab8OL4Hwrpttzo4kuTlwtqNKp +Q9H4al9qfF4Cr1TFya98+EVYf8yFRM3NLNjZpe3gwYf2EerlJj7VLcahw0KKzoN1 +QBENfwgPLRk5sDkx9VhSmcfl/diLroZdpAwtv3vo4nEoxeuGFbKTGx3Qkf0CgYEA +xyR+dcb05Ygm3w4klHQTowQ10s1H80iaUcZBgQuR1ghEtDbUPZHsoR5t1xCB02ys +DgAwLv1bChIvxvH/L6KM8ovZ2LekBX4AviWxoBxJnfz/EVau98B0b1auRN6eSC83 +FRuGldlSOW1z/nSh8ViizSYE5H5HX1qkXEippvFRE88CgYB3Bfu3YQY60ITWIShv +nNkdcbTT9eoP9suaRJjw92Ln+7ZpALYlQMKUZmJ/5uBmLs4RFwUTQruLOPL4yLTH +awADWUzs3IRr1fwn9E+zM8JVyKCnUEM3w4N5UZskGO2klashAd30hWO+knRv/y0r +uGIYs9Ek7YXlXIRVrzMwcsrt1w== +-----END PRIVATE KEY----- diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/server1.pem b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/server1.pem new file mode 100755 index 00000000000..88244f856c6 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/credentials/server1.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtDCCApygAwIBAgIUbJfTREJ6k6/+oInWhV1O1j3ZT0IwDQYJKoZIhvcNAQEL +BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTIw +MDMxODAzMTA0MloXDTMwMDMxNjAzMTA0MlowZTELMAkGA1UEBhMCVVMxETAPBgNV +BAgMCElsbGlub2lzMRAwDgYDVQQHDAdDaGljYWdvMRUwEwYDVQQKDAxFeGFtcGxl +LCBDby4xGjAYBgNVBAMMESoudGVzdC5nb29nbGUuY29tMIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEA5xOONxJJ8b8Qauvob5/7dPYZfIcd+uhAWL2ZlTPz +Qvu4oF0QI4iYgP5iGgry9zEtCM+YQS8UhiAlPlqa6ANxgiBSEyMHH/xE8lo/+caY +GeACqy640Jpl/JocFGo3xd1L8DCawjlaj6eu7T7T/tpAV2qq13b5710eNRbCAfFe +8yALiGQemx0IYhlZXNbIGWLBNhBhvVjJh7UvOqpADk4xtl8o5j0xgMIRg6WJGK6c +6ffSIg4eP1XmovNYZ9LLEJG68tF0Q/yIN43B4dt1oq4jzSdCbG4F1EiykT2TmwPV +YDi8tml6DfOCDGnit8svnMEmBv/fcPd31GSbXjF8M+KGGQIDAQABo2swaTAJBgNV +HRMEAjAAMAsGA1UdDwQEAwIF4DBPBgNVHREESDBGghAqLnRlc3QuZ29vZ2xlLmZy +ghh3YXRlcnpvb2kudGVzdC5nb29nbGUuYmWCEioudGVzdC55b3V0dWJlLmNvbYcE +wKgBAzANBgkqhkiG9w0BAQsFAAOCAQEAS8hDQA8PSgipgAml7Q3/djwQ644ghWQv +C2Kb+r30RCY1EyKNhnQnIIh/OUbBZvh0M0iYsy6xqXgfDhCB93AA6j0i5cS8fkhH +Jl4RK0tSkGQ3YNY4NzXwQP/vmUgfkw8VBAZ4Y4GKxppdATjffIW+srbAmdDruIRM +wPeikgOoRrXf0LA1fi4TqxARzeRwenQpayNfGHTvVF9aJkl8HoaMunTAdG5pIVcr +9GKi/gEMpXUJbbVv3U5frX1Wo4CFo+rZWJ/LyCMeb0jciNLxSdMwj/E/ZuExlyeZ +gc9ctPjSMvgSyXEKv6Vwobleeg88V2ZgzenziORoWj4KszG/lbQZvg== +-----END CERTIFICATE----- diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/__init__.py new file mode 100644 index 00000000000..5fb4f3c3cfd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/__init__.py new file mode 100644 index 00000000000..8b58a0c46af --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/__init__.py @@ -0,0 +1,102 @@ +# Copyright 2019 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os +import socket +import errno + +_DEFAULT_SOCK_OPTIONS = (socket.SO_REUSEADDR, + socket.SO_REUSEPORT) if os.name != 'nt' else ( + socket.SO_REUSEADDR,) +_UNRECOVERABLE_ERRNOS = (errno.EADDRINUSE, errno.ENOSR) + + +def get_socket(bind_address='localhost', + port=0, + listen=True, + sock_options=_DEFAULT_SOCK_OPTIONS): + """Opens a socket. + + Useful for reserving a port for a system-under-test. + + Args: + bind_address: The host to which to bind. + port: The port to which to bind. + listen: A boolean value indicating whether or not to listen on the socket. + sock_options: A sequence of socket options to apply to the socket. + + Returns: + A tuple containing: + - the address to which the socket is bound + - the port to which the socket is bound + - the socket object itself + """ + _sock_options = sock_options if sock_options else [] + if socket.has_ipv6: + address_families = (socket.AF_INET6, socket.AF_INET) + else: + address_families = (socket.AF_INET) + for address_family in address_families: + try: + sock = socket.socket(address_family, socket.SOCK_STREAM) + for sock_option in _sock_options: + sock.setsockopt(socket.SOL_SOCKET, sock_option, 1) + sock.bind((bind_address, port)) + if listen: + sock.listen(1) + return bind_address, sock.getsockname()[1], sock + except OSError as os_error: + sock.close() + if os_error.errno in _UNRECOVERABLE_ERRNOS: + raise + else: + continue + # For PY2, socket.error is a child class of IOError; for PY3, it is + # pointing to OSError. We need this catch to make it 2/3 agnostic. + except socket.error: # pylint: disable=duplicate-except + sock.close() + continue + raise RuntimeError("Failed to bind to {} with sock_options {}".format( + bind_address, sock_options)) + + +def bound_socket(bind_address='localhost', + port=0, + listen=True, + sock_options=_DEFAULT_SOCK_OPTIONS): + """Opens a socket bound to an arbitrary port. + + Useful for reserving a port for a system-under-test. + + Args: + bind_address: The host to which to bind. + port: The port to which to bind. + listen: A boolean value indicating whether or not to listen on the socket. + sock_options: A sequence of socket options to apply to the socket. + + Yields: + A tuple containing: + - the address to which the socket is bound + - the port to which the socket is bound + """ + host, port, sock = get_socket(bind_address=bind_address, + port=port, + listen=listen, + sock_options=sock_options) + try: + yield host, port + finally: + sock.close() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_constants.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_constants.py new file mode 100644 index 00000000000..2b9eb2e35bd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_constants.py @@ -0,0 +1,45 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constants shared among tests throughout RPC Framework.""" + +# Value for maximum duration in seconds that a test is allowed for its actual +# behavioral logic, excluding all time spent deliberately waiting in the test. +TIME_ALLOWANCE = 10 +# Value for maximum duration in seconds of RPCs that may time out as part of a +# test. +SHORT_TIMEOUT = 4 +# Absurdly large value for maximum duration in seconds for should-not-time-out +# RPCs made during tests. +LONG_TIMEOUT = 3000 +# Values to supply on construction of an object that will service RPCs; these +# should not be used as the actual timeout values of any RPCs made during tests. +DEFAULT_TIMEOUT = 300 +MAXIMUM_TIMEOUT = 3600 + +# The number of payloads to transmit in streaming tests. +STREAM_LENGTH = 200 + +# The size of payloads to transmit in tests. +PAYLOAD_SIZE = 256 * 1024 + 17 + +# The concurrency to use in tests of concurrent RPCs that will not create as +# many threads as RPCs. +RPC_CONCURRENCY = 200 + +# The concurrency to use in tests of concurrent RPCs that will create as many +# threads as RPCs. +THREAD_CONCURRENCY = 25 + +# The size of thread pools to use in tests. +POOL_SIZE = 10 diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_control.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_control.py new file mode 100644 index 00000000000..6a422825cc7 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_control.py @@ -0,0 +1,97 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Code for instructing systems under test to block or fail.""" + +import abc +import contextlib +import threading + +import six + + +class Defect(Exception): + """Simulates a programming defect raised into in a system under test. + + Use of a standard exception type is too easily misconstrued as an actual + defect in either the test infrastructure or the system under test. + """ + + +class Control(six.with_metaclass(abc.ABCMeta)): + """An object that accepts program control from a system under test. + + Systems under test passed a Control should call its control() method + frequently during execution. The control() method may block, raise an + exception, or do nothing, all according to the enclosing test's desire for + the system under test to simulate hanging, failing, or functioning. + """ + + @abc.abstractmethod + def control(self): + """Potentially does anything.""" + raise NotImplementedError() + + +class PauseFailControl(Control): + """A Control that can be used to pause or fail code under control. + + This object is only safe for use from two threads: one of the system under + test calling control and the other from the test system calling pause, + block_until_paused, and fail. + """ + + def __init__(self): + self._condition = threading.Condition() + self._pause = False + self._paused = False + self._fail = False + + def control(self): + with self._condition: + if self._fail: + raise Defect() + + while self._pause: + self._paused = True + self._condition.notify_all() + self._condition.wait() + self._paused = False + + @contextlib.contextmanager + def pause(self): + """Pauses code under control while controlling code is in context.""" + with self._condition: + self._pause = True + yield + with self._condition: + self._pause = False + self._condition.notify_all() + + def block_until_paused(self): + """Blocks controlling code until code under control is paused. + + May only be called within the context of a pause call. + """ + with self._condition: + while not self._paused: + self._condition.wait() + + @contextlib.contextmanager + def fail(self): + """Fails code under control while controlling code is in context.""" + with self._condition: + self._fail = True + yield + with self._condition: + self._fail = False diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_coverage.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_coverage.py new file mode 100644 index 00000000000..f90a11963fb --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/common/test_coverage.py @@ -0,0 +1,101 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Governs coverage for tests of RPCs throughout RPC Framework.""" + +import abc + +import six + +# This code is designed for use with the unittest module. +# pylint: disable=invalid-name + + +class Coverage(six.with_metaclass(abc.ABCMeta)): + """Specification of test coverage.""" + + @abc.abstractmethod + def testSuccessfulUnaryRequestUnaryResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testSuccessfulUnaryRequestStreamResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testSuccessfulStreamRequestUnaryResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testSuccessfulStreamRequestStreamResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testSequentialInvocations(self): + raise NotImplementedError() + + @abc.abstractmethod + def testParallelInvocations(self): + raise NotImplementedError() + + @abc.abstractmethod + def testWaitingForSomeButNotAllParallelInvocations(self): + raise NotImplementedError() + + @abc.abstractmethod + def testCancelledUnaryRequestUnaryResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testCancelledUnaryRequestStreamResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testCancelledStreamRequestUnaryResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testCancelledStreamRequestStreamResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testExpiredUnaryRequestUnaryResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testExpiredUnaryRequestStreamResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testExpiredStreamRequestUnaryResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testExpiredStreamRequestStreamResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testFailedUnaryRequestUnaryResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testFailedUnaryRequestStreamResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testFailedStreamRequestUnaryResponse(self): + raise NotImplementedError() + + @abc.abstractmethod + def testFailedStreamRequestStreamResponse(self): + raise NotImplementedError() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/__init__.py new file mode 100644 index 00000000000..5fb4f3c3cfd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py new file mode 100644 index 00000000000..c4ea03177cc --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py @@ -0,0 +1,73 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for grpc.framework.foundation.logging_pool.""" + +import threading +import unittest + +from grpc.framework.foundation import logging_pool + +_POOL_SIZE = 16 + + +class _CallableObject(object): + + def __init__(self): + self._lock = threading.Lock() + self._passed_values = [] + + def __call__(self, value): + with self._lock: + self._passed_values.append(value) + + def passed_values(self): + with self._lock: + return tuple(self._passed_values) + + +class LoggingPoolTest(unittest.TestCase): + + def testUpAndDown(self): + pool = logging_pool.pool(_POOL_SIZE) + pool.shutdown(wait=True) + + with logging_pool.pool(_POOL_SIZE) as pool: + self.assertIsNotNone(pool) + + def testTaskExecuted(self): + test_list = [] + + with logging_pool.pool(_POOL_SIZE) as pool: + pool.submit(lambda: test_list.append(object())).result() + + self.assertTrue(test_list) + + def testException(self): + with logging_pool.pool(_POOL_SIZE) as pool: + raised_exception = pool.submit(lambda: 1 / 0).exception() + + self.assertIsNotNone(raised_exception) + + def testCallableObjectExecuted(self): + callable_object = _CallableObject() + passed_object = object() + with logging_pool.pool(_POOL_SIZE) as pool: + future = pool.submit(callable_object, passed_object) + self.assertIsNone(future.result()) + self.assertSequenceEqual((passed_object,), + callable_object.passed_values()) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py new file mode 100644 index 00000000000..dd5c5b3b031 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py @@ -0,0 +1,57 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for testing stream-related code.""" + +from grpc.framework.foundation import stream + + +class TestConsumer(stream.Consumer): + """A stream.Consumer instrumented for testing. + + Attributes: + calls: A sequence of value-termination pairs describing the history of calls + made on this object. + """ + + def __init__(self): + self.calls = [] + + def consume(self, value): + """See stream.Consumer.consume for specification.""" + self.calls.append((value, False)) + + def terminate(self): + """See stream.Consumer.terminate for specification.""" + self.calls.append((None, True)) + + def consume_and_terminate(self, value): + """See stream.Consumer.consume_and_terminate for specification.""" + self.calls.append((value, True)) + + def is_legal(self): + """Reports whether or not a legal sequence of calls has been made.""" + terminated = False + for value, terminal in self.calls: + if terminated: + return False + elif terminal: + terminated = True + elif value is None: + return False + else: # pylint: disable=useless-else-on-loop + return True + + def values(self): + """Returns the sequence of values that have been passed to this Consumer.""" + return [value for value, _ in self.calls if value] diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/resources.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/resources.py new file mode 100644 index 00000000000..6efd870fc86 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/resources.py @@ -0,0 +1,113 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constants and functions for data used in testing.""" + +import os +import pkgutil + +_ROOT_CERTIFICATES_RESOURCE_PATH = 'credentials/ca.pem' +_PRIVATE_KEY_RESOURCE_PATH = 'credentials/server1.key' +_CERTIFICATE_CHAIN_RESOURCE_PATH = 'credentials/server1.pem' + + +def test_root_certificates(): + return pkgutil.get_data(__name__, _ROOT_CERTIFICATES_RESOURCE_PATH) + + +def private_key(): + return pkgutil.get_data(__name__, _PRIVATE_KEY_RESOURCE_PATH) + + +def certificate_chain(): + return pkgutil.get_data(__name__, _CERTIFICATE_CHAIN_RESOURCE_PATH) + + +def cert_hier_1_root_ca_cert(): + return pkgutil.get_data( + __name__, 'credentials/certificate_hierarchy_1/certs/ca.cert.pem') + + +def cert_hier_1_intermediate_ca_cert(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_1/intermediate/certs/intermediate.cert.pem' + ) + + +def cert_hier_1_client_1_key(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_1/intermediate/private/client.key.pem' + ) + + +def cert_hier_1_client_1_cert(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_1/intermediate/certs/client.cert.pem' + ) + + +def cert_hier_1_server_1_key(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_1/intermediate/private/localhost-1.key.pem' + ) + + +def cert_hier_1_server_1_cert(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_1/intermediate/certs/localhost-1.cert.pem' + ) + + +def cert_hier_2_root_ca_cert(): + return pkgutil.get_data( + __name__, 'credentials/certificate_hierarchy_2/certs/ca.cert.pem') + + +def cert_hier_2_intermediate_ca_cert(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_2/intermediate/certs/intermediate.cert.pem' + ) + + +def cert_hier_2_client_1_key(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_2/intermediate/private/client.key.pem' + ) + + +def cert_hier_2_client_1_cert(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_2/intermediate/certs/client.cert.pem' + ) + + +def cert_hier_2_server_1_key(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_2/intermediate/private/localhost-1.key.pem' + ) + + +def cert_hier_2_server_1_cert(): + return pkgutil.get_data( + __name__, + 'credentials/certificate_hierarchy_2/intermediate/certs/localhost-1.cert.pem' + ) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/test_common.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/test_common.py new file mode 100644 index 00000000000..59ded0752fd --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/test_common.py @@ -0,0 +1,134 @@ +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common code used throughout tests of gRPC.""" + +import collections +import threading + +from concurrent import futures +import grpc +import six + +INVOCATION_INITIAL_METADATA = ( + ('0', 'abc'), + ('1', 'def'), + ('2', 'ghi'), +) +SERVICE_INITIAL_METADATA = ( + ('3', 'jkl'), + ('4', 'mno'), + ('5', 'pqr'), +) +SERVICE_TERMINAL_METADATA = ( + ('6', 'stu'), + ('7', 'vwx'), + ('8', 'yza'), +) +DETAILS = 'test details' + + +def metadata_transmitted(original_metadata, transmitted_metadata): + """Judges whether or not metadata was acceptably transmitted. + + gRPC is allowed to insert key-value pairs into the metadata values given by + applications and to reorder key-value pairs with different keys but it is not + allowed to alter existing key-value pairs or to reorder key-value pairs with + the same key. + + Args: + original_metadata: A metadata value used in a test of gRPC. An iterable over + iterables of length 2. + transmitted_metadata: A metadata value corresponding to original_metadata + after having been transmitted via gRPC. An iterable over iterables of + length 2. + + Returns: + A boolean indicating whether transmitted_metadata accurately reflects + original_metadata after having been transmitted via gRPC. + """ + original = collections.defaultdict(list) + for key, value in original_metadata: + original[key].append(value) + transmitted = collections.defaultdict(list) + for key, value in transmitted_metadata: + transmitted[key].append(value) + + for key, values in six.iteritems(original): + transmitted_values = transmitted[key] + transmitted_iterator = iter(transmitted_values) + try: + for value in values: + while True: + transmitted_value = next(transmitted_iterator) + if value == transmitted_value: + break + except StopIteration: + return False + else: + return True + + +def test_secure_channel(target, channel_credentials, server_host_override): + """Creates an insecure Channel to a remote host. + + Args: + host: The name of the remote host to which to connect. + port: The port of the remote host to which to connect. + channel_credentials: The implementations.ChannelCredentials with which to + connect. + server_host_override: The target name used for SSL host name checking. + + Returns: + An implementations.Channel to the remote host through which RPCs may be + conducted. + """ + channel = grpc.secure_channel(target, channel_credentials, (( + 'grpc.ssl_target_name_override', + server_host_override, + ),)) + return channel + + +def test_server(max_workers=10, reuse_port=False): + """Creates an insecure grpc server. + + These servers have SO_REUSEPORT disabled to prevent cross-talk. + """ + return grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers), + options=(('grpc.so_reuseport', int(reuse_port)),)) + + +class WaitGroup(object): + + def __init__(self, n=0): + self.count = n + self.cv = threading.Condition() + + def add(self, n): + self.cv.acquire() + self.count += n + self.cv.release() + + def done(self): + self.cv.acquire() + self.count -= 1 + if self.count == 0: + self.cv.notify_all() + self.cv.release() + + def wait(self): + self.cv.acquire() + while self.count > 0: + self.cv.wait() + self.cv.release() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/thread_pool.py b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/thread_pool.py new file mode 100644 index 00000000000..094e203cd95 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests/unit/thread_pool.py @@ -0,0 +1,34 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from concurrent import futures + + +class RecordingThreadPool(futures.ThreadPoolExecutor): + """A thread pool that records if used.""" + + def __init__(self, max_workers): + self._tp_executor = futures.ThreadPoolExecutor(max_workers=max_workers) + self._lock = threading.Lock() + self._was_used = False + + def submit(self, fn, *args, **kwargs): # pylint: disable=arguments-differ + with self._lock: + self._was_used = True + self._tp_executor.submit(fn, *args, **kwargs) + + def was_used(self): + with self._lock: + return self._was_used diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/__init__.py new file mode 100644 index 00000000000..8ddd3106965 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from tests import _loader +from tests import _runner + +Loader = _loader.Loader +Runner = _runner.Runner diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/_sanity/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/_sanity/__init__.py new file mode 100644 index 00000000000..f4b321fc5b2 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/_sanity/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/_sanity/_sanity_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/_sanity/_sanity_test.py new file mode 100644 index 00000000000..e74dec0739b --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/_sanity/_sanity_test.py @@ -0,0 +1,27 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tests._sanity import _sanity_test + + +class AioSanityTest(_sanity_test.SanityTest): + + TEST_PKG_MODULE_NAME = 'tests_aio' + TEST_PKG_PATH = 'tests_aio' + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/benchmark_client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/benchmark_client.py new file mode 100644 index 00000000000..51a046c20c7 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/benchmark_client.py @@ -0,0 +1,155 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python AsyncIO Benchmark Clients.""" + +import abc +import asyncio +import time +import logging +import random + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2, + messages_pb2) +from tests.qps import histogram +from tests.unit import resources + + +class GenericStub(object): + + def __init__(self, channel: aio.Channel): + self.UnaryCall = channel.unary_unary( + '/grpc.testing.BenchmarkService/UnaryCall') + self.StreamingCall = channel.stream_stream( + '/grpc.testing.BenchmarkService/StreamingCall') + + +class BenchmarkClient(abc.ABC): + """Benchmark client interface that exposes a non-blocking send_request().""" + + def __init__(self, address: str, config: control_pb2.ClientConfig, + hist: histogram.Histogram): + # Disables underlying reuse of subchannels + unique_option = (('iv', random.random()),) + + # Parses the channel argument from config + channel_args = tuple( + (arg.name, arg.str_value) if arg.HasField('str_value') else ( + arg.name, int(arg.int_value)) for arg in config.channel_args) + + # Creates the channel + if config.HasField('security_params'): + channel_credentials = grpc.ssl_channel_credentials( + resources.test_root_certificates(),) + server_host_override_option = (( + 'grpc.ssl_target_name_override', + config.security_params.server_host_override, + ),) + self._channel = aio.secure_channel( + address, channel_credentials, + unique_option + channel_args + server_host_override_option) + else: + self._channel = aio.insecure_channel(address, + options=unique_option + + channel_args) + + # Creates the stub + if config.payload_config.WhichOneof('payload') == 'simple_params': + self._generic = False + self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub( + self._channel) + payload = messages_pb2.Payload( + body=b'\0' * config.payload_config.simple_params.req_size) + self._request = messages_pb2.SimpleRequest( + payload=payload, + response_size=config.payload_config.simple_params.resp_size) + else: + self._generic = True + self._stub = GenericStub(self._channel) + self._request = b'\0' * config.payload_config.bytebuf_params.req_size + + self._hist = hist + self._response_callbacks = [] + self._concurrency = config.outstanding_rpcs_per_channel + + async def run(self) -> None: + await self._channel.channel_ready() + + async def stop(self) -> None: + await self._channel.close() + + def _record_query_time(self, query_time: float) -> None: + self._hist.add(query_time * 1e9) + + +class UnaryAsyncBenchmarkClient(BenchmarkClient): + + def __init__(self, address: str, config: control_pb2.ClientConfig, + hist: histogram.Histogram): + super().__init__(address, config, hist) + self._running = None + self._stopped = asyncio.Event() + + async def _send_request(self): + start_time = time.monotonic() + await self._stub.UnaryCall(self._request) + self._record_query_time(time.monotonic() - start_time) + + async def _send_indefinitely(self) -> None: + while self._running: + await self._send_request() + + async def run(self) -> None: + await super().run() + self._running = True + senders = (self._send_indefinitely() for _ in range(self._concurrency)) + await asyncio.gather(*senders) + self._stopped.set() + + async def stop(self) -> None: + self._running = False + await self._stopped.wait() + await super().stop() + + +class StreamingAsyncBenchmarkClient(BenchmarkClient): + + def __init__(self, address: str, config: control_pb2.ClientConfig, + hist: histogram.Histogram): + super().__init__(address, config, hist) + self._running = None + self._stopped = asyncio.Event() + + async def _one_streaming_call(self): + call = self._stub.StreamingCall() + while self._running: + start_time = time.time() + await call.write(self._request) + await call.read() + self._record_query_time(time.time() - start_time) + await call.done_writing() + + async def run(self): + await super().run() + self._running = True + senders = (self._one_streaming_call() for _ in range(self._concurrency)) + await asyncio.gather(*senders) + self._stopped.set() + + async def stop(self): + self._running = False + await self._stopped.wait() + await super().stop() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/benchmark_servicer.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/benchmark_servicer.py new file mode 100644 index 00000000000..50d3065cd19 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/benchmark_servicer.py @@ -0,0 +1,55 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The Python AsyncIO Benchmark Servicers.""" + +import asyncio +import logging +import unittest + +from grpc.experimental import aio + +from src.proto.grpc.testing import benchmark_service_pb2_grpc, messages_pb2 + + +class BenchmarkServicer(benchmark_service_pb2_grpc.BenchmarkServiceServicer): + + async def UnaryCall(self, request, unused_context): + payload = messages_pb2.Payload(body=b'\0' * request.response_size) + return messages_pb2.SimpleResponse(payload=payload) + + async def StreamingFromServer(self, request, unused_context): + payload = messages_pb2.Payload(body=b'\0' * request.response_size) + # Sends response at full capacity! + while True: + yield messages_pb2.SimpleResponse(payload=payload) + + async def StreamingCall(self, request_iterator, unused_context): + async for request in request_iterator: + payload = messages_pb2.Payload(body=b'\0' * request.response_size) + yield messages_pb2.SimpleResponse(payload=payload) + + +class GenericBenchmarkServicer( + benchmark_service_pb2_grpc.BenchmarkServiceServicer): + """Generic (no-codec) Server implementation for the Benchmark service.""" + + def __init__(self, resp_size): + self._response = '\0' * resp_size + + async def UnaryCall(self, unused_request, unused_context): + return self._response + + async def StreamingCall(self, request_iterator, unused_context): + async for _ in request_iterator: + yield self._response diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/server.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/server.py new file mode 100644 index 00000000000..561298a626b --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/server.py @@ -0,0 +1,46 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import unittest + +from grpc.experimental import aio + +from src.proto.grpc.testing import benchmark_service_pb2_grpc +from tests_aio.benchmark import benchmark_servicer + + +async def _start_async_server(): + server = aio.server() + + port = server.add_insecure_port('localhost:%s' % 50051) + servicer = benchmark_servicer.BenchmarkServicer() + benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server( + servicer, server) + + await server.start() + logging.info('Benchmark server started at :%d' % port) + await server.wait_for_termination() + + +def main(): + loop = asyncio.get_event_loop() + loop.create_task(_start_async_server()) + loop.run_forever() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + main() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/worker.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/worker.py new file mode 100644 index 00000000000..dc16f050872 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/worker.py @@ -0,0 +1,59 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +import logging + +from grpc.experimental import aio + +from src.proto.grpc.testing import worker_service_pb2_grpc +from tests_aio.benchmark import worker_servicer + + +async def run_worker_server(port: int) -> None: + server = aio.server() + + servicer = worker_servicer.WorkerServicer() + worker_service_pb2_grpc.add_WorkerServiceServicer_to_server( + servicer, server) + + server.add_insecure_port('[::]:{}'.format(port)) + + await server.start() + + await servicer.wait_for_quit() + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + parser = argparse.ArgumentParser( + description='gRPC Python performance testing worker') + parser.add_argument('--driver_port', + type=int, + dest='port', + help='The port the worker should listen on') + parser.add_argument('--uvloop', + action='store_true', + help='Use uvloop or not') + args = parser.parse_args() + + if args.uvloop: + import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + loop = uvloop.new_event_loop() + asyncio.set_event_loop(loop) + + asyncio.get_event_loop().run_until_complete(run_worker_server(args.port)) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py new file mode 100644 index 00000000000..4f80095cd20 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py @@ -0,0 +1,367 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import collections +import logging +import multiprocessing +import os +import sys +import time +from typing import Tuple + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2, + stats_pb2, worker_service_pb2_grpc) +from tests.qps import histogram +from tests.unit import resources +from tests.unit.framework.common import get_socket +from tests_aio.benchmark import benchmark_client, benchmark_servicer + +_NUM_CORES = multiprocessing.cpu_count() +_WORKER_ENTRY_FILE = os.path.join( + os.path.split(os.path.abspath(__file__))[0], 'worker.py') + +_LOGGER = logging.getLogger(__name__) + + +class _SubWorker( + collections.namedtuple('_SubWorker', + ['process', 'port', 'channel', 'stub'])): + """A data class that holds information about a child qps worker.""" + + def _repr(self): + return f'<_SubWorker pid={self.process.pid} port={self.port}>' + + def __repr__(self): + return self._repr() + + def __str__(self): + return self._repr() + + +def _get_server_status(start_time: float, end_time: float, + port: int) -> control_pb2.ServerStatus: + """Creates ServerStatus proto message.""" + end_time = time.monotonic() + elapsed_time = end_time - start_time + # TODO(lidiz) Collect accurate time system to compute QPS/core-second. + stats = stats_pb2.ServerStats(time_elapsed=elapsed_time, + time_user=elapsed_time, + time_system=elapsed_time) + return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES) + + +def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]: + """Creates a server object according to the ServerConfig.""" + channel_args = tuple( + (arg.name, + arg.str_value) if arg.HasField('str_value') else (arg.name, + int(arg.int_value)) + for arg in config.channel_args) + + server = aio.server(options=channel_args + (('grpc.so_reuseport', 1),)) + if config.server_type == control_pb2.ASYNC_SERVER: + servicer = benchmark_servicer.BenchmarkServicer() + benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server( + servicer, server) + elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER: + resp_size = config.payload_config.bytebuf_params.resp_size + servicer = benchmark_servicer.GenericBenchmarkServicer(resp_size) + method_implementations = { + 'StreamingCall': + grpc.stream_stream_rpc_method_handler(servicer.StreamingCall), + 'UnaryCall': + grpc.unary_unary_rpc_method_handler(servicer.UnaryCall), + } + handler = grpc.method_handlers_generic_handler( + 'grpc.testing.BenchmarkService', method_implementations) + server.add_generic_rpc_handlers((handler,)) + else: + raise NotImplementedError('Unsupported server type {}'.format( + config.server_type)) + + if config.HasField('security_params'): # Use SSL + server_creds = grpc.ssl_server_credentials( + ((resources.private_key(), resources.certificate_chain()),)) + port = server.add_secure_port('[::]:{}'.format(config.port), + server_creds) + else: + port = server.add_insecure_port('[::]:{}'.format(config.port)) + + return server, port + + +def _get_client_status(start_time: float, end_time: float, + qps_data: histogram.Histogram + ) -> control_pb2.ClientStatus: + """Creates ClientStatus proto message.""" + latencies = qps_data.get_data() + end_time = time.monotonic() + elapsed_time = end_time - start_time + # TODO(lidiz) Collect accurate time system to compute QPS/core-second. + stats = stats_pb2.ClientStats(latencies=latencies, + time_elapsed=elapsed_time, + time_user=elapsed_time, + time_system=elapsed_time) + return control_pb2.ClientStatus(stats=stats) + + +def _create_client(server: str, config: control_pb2.ClientConfig, + qps_data: histogram.Histogram + ) -> benchmark_client.BenchmarkClient: + """Creates a client object according to the ClientConfig.""" + if config.load_params.WhichOneof('load') != 'closed_loop': + raise NotImplementedError( + f'Unsupported load parameter {config.load_params}') + + if config.client_type == control_pb2.ASYNC_CLIENT: + if config.rpc_type == control_pb2.UNARY: + client_type = benchmark_client.UnaryAsyncBenchmarkClient + elif config.rpc_type == control_pb2.STREAMING: + client_type = benchmark_client.StreamingAsyncBenchmarkClient + else: + raise NotImplementedError( + f'Unsupported rpc_type [{config.rpc_type}]') + else: + raise NotImplementedError( + f'Unsupported client type {config.client_type}') + + return client_type(server, config, qps_data) + + +def _pick_an_unused_port() -> int: + """Picks an unused TCP port.""" + _, port, sock = get_socket() + sock.close() + return port + + +async def _create_sub_worker() -> _SubWorker: + """Creates a child qps worker as a subprocess.""" + port = _pick_an_unused_port() + + _LOGGER.info('Creating sub worker at port [%d]...', port) + process = await asyncio.create_subprocess_exec(sys.executable, + _WORKER_ENTRY_FILE, + '--driver_port', str(port)) + _LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port, + process.pid) + channel = aio.insecure_channel(f'localhost:{port}') + _LOGGER.info('Waiting for sub worker at port [%d]', port) + await channel.channel_ready() + stub = worker_service_pb2_grpc.WorkerServiceStub(channel) + return _SubWorker( + process=process, + port=port, + channel=channel, + stub=stub, + ) + + +class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer): + """Python Worker Server implementation.""" + + def __init__(self): + self._loop = asyncio.get_event_loop() + self._quit_event = asyncio.Event() + + async def _run_single_server(self, config, request_iterator, context): + server, port = _create_server(config) + await server.start() + _LOGGER.info('Server started at port [%d]', port) + + start_time = time.monotonic() + await context.write(_get_server_status(start_time, start_time, port)) + + async for request in request_iterator: + end_time = time.monotonic() + status = _get_server_status(start_time, end_time, port) + if request.mark.reset: + start_time = end_time + await context.write(status) + await server.stop(None) + + async def RunServer(self, request_iterator, context): + config_request = await context.read() + config = config_request.setup + _LOGGER.info('Received ServerConfig: %s', config) + + if config.server_processes <= 0: + _LOGGER.info('Using server_processes == [%d]', _NUM_CORES) + config.server_processes = _NUM_CORES + + if config.port == 0: + config.port = _pick_an_unused_port() + _LOGGER.info('Port picked [%d]', config.port) + + if config.server_processes == 1: + # If server_processes == 1, start the server in this process. + await self._run_single_server(config, request_iterator, context) + else: + # If server_processes > 1, offload to other processes. + sub_workers = await asyncio.gather(*( + _create_sub_worker() for _ in range(config.server_processes))) + + calls = [worker.stub.RunServer() for worker in sub_workers] + + config_request.setup.server_processes = 1 + + for call in calls: + await call.write(config_request) + # An empty status indicates the peer is ready + await call.read() + + start_time = time.monotonic() + await context.write( + _get_server_status( + start_time, + start_time, + config.port, + )) + + _LOGGER.info('Servers are ready to serve.') + + async for request in request_iterator: + end_time = time.monotonic() + + for call in calls: + await call.write(request) + # Reports from sub workers doesn't matter + await call.read() + + status = _get_server_status( + start_time, + end_time, + config.port, + ) + if request.mark.reset: + start_time = end_time + await context.write(status) + + for call in calls: + await call.done_writing() + + for worker in sub_workers: + await worker.stub.QuitWorker(control_pb2.Void()) + await worker.channel.close() + _LOGGER.info('Waiting for [%s] to quit...', worker) + await worker.process.wait() + + async def _run_single_client(self, config, request_iterator, context): + running_tasks = [] + qps_data = histogram.Histogram(config.histogram_params.resolution, + config.histogram_params.max_possible) + start_time = time.monotonic() + + # Create a client for each channel as asyncio.Task + for i in range(config.client_channels): + server = config.server_targets[i % len(config.server_targets)] + client = _create_client(server, config, qps_data) + _LOGGER.info('Client created against server [%s]', server) + running_tasks.append(self._loop.create_task(client.run())) + + end_time = time.monotonic() + await context.write(_get_client_status(start_time, end_time, qps_data)) + + # Respond to stat requests + async for request in request_iterator: + end_time = time.monotonic() + status = _get_client_status(start_time, end_time, qps_data) + if request.mark.reset: + qps_data.reset() + start_time = time.monotonic() + await context.write(status) + + # Cleanup the clients + for task in running_tasks: + task.cancel() + + async def RunClient(self, request_iterator, context): + config_request = await context.read() + config = config_request.setup + _LOGGER.info('Received ClientConfig: %s', config) + + if config.client_processes <= 0: + _LOGGER.info('client_processes can\'t be [%d]', + config.client_processes) + _LOGGER.info('Using client_processes == [%d]', _NUM_CORES) + config.client_processes = _NUM_CORES + + if config.client_processes == 1: + # If client_processes == 1, run the benchmark in this process. + await self._run_single_client(config, request_iterator, context) + else: + # If client_processes > 1, offload the work to other processes. + sub_workers = await asyncio.gather(*( + _create_sub_worker() for _ in range(config.client_processes))) + + calls = [worker.stub.RunClient() for worker in sub_workers] + + config_request.setup.client_processes = 1 + + for call in calls: + await call.write(config_request) + # An empty status indicates the peer is ready + await call.read() + + start_time = time.monotonic() + result = histogram.Histogram(config.histogram_params.resolution, + config.histogram_params.max_possible) + end_time = time.monotonic() + await context.write(_get_client_status(start_time, end_time, + result)) + + async for request in request_iterator: + end_time = time.monotonic() + + for call in calls: + _LOGGER.debug('Fetching status...') + await call.write(request) + sub_status = await call.read() + result.merge(sub_status.stats.latencies) + _LOGGER.debug('Update from sub worker count=[%d]', + sub_status.stats.latencies.count) + + status = _get_client_status(start_time, end_time, result) + if request.mark.reset: + result.reset() + start_time = time.monotonic() + _LOGGER.debug('Reporting count=[%d]', + status.stats.latencies.count) + await context.write(status) + + for call in calls: + await call.done_writing() + + for worker in sub_workers: + await worker.stub.QuitWorker(control_pb2.Void()) + await worker.channel.close() + _LOGGER.info('Waiting for sub worker [%s] to quit...', worker) + await worker.process.wait() + _LOGGER.info('Sub worker [%s] quit', worker) + + @staticmethod + async def CoreCount(unused_request, unused_context): + return control_pb2.CoreResponse(cores=_NUM_CORES) + + async def QuitWorker(self, unused_request, unused_context): + _LOGGER.info('QuitWorker command received.') + self._quit_event.set() + return control_pb2.Void() + + async def wait_for_quit(self): + await self._quit_event.wait() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/channelz/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/channelz/__init__.py new file mode 100644 index 00000000000..1517f71d093 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/channelz/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py new file mode 100644 index 00000000000..d6e9fd42791 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py @@ -0,0 +1,474 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_channelz.v1.channelz.""" + +import unittest +import logging +import asyncio + +import grpc +from grpc.experimental import aio + +from grpc_channelz.v1 import channelz +from grpc_channelz.v1 import channelz_pb2 +from grpc_channelz.v1 import channelz_pb2_grpc + +from tests.unit.framework.common import test_constants +from tests_aio.unit._test_base import AioTestBase + +_SUCCESSFUL_UNARY_UNARY = '/test/SuccessfulUnaryUnary' +_FAILED_UNARY_UNARY = '/test/FailedUnaryUnary' +_SUCCESSFUL_STREAM_STREAM = '/test/SuccessfulStreamStream' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + +_DISABLE_REUSE_PORT = (('grpc.so_reuseport', 0),) +_ENABLE_CHANNELZ = (('grpc.enable_channelz', 1),) +_DISABLE_CHANNELZ = (('grpc.enable_channelz', 0),) + +_LARGE_UNASSIGNED_ID = 10000 + + +async def _successful_unary_unary(request, servicer_context): + return _RESPONSE + + +async def _failed_unary_unary(request, servicer_context): + servicer_context.set_code(grpc.StatusCode.INTERNAL) + servicer_context.set_details("Channelz Test Intended Failure") + + +async def _successful_stream_stream(request_iterator, servicer_context): + async for _ in request_iterator: + yield _RESPONSE + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _SUCCESSFUL_UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_successful_unary_unary) + elif handler_call_details.method == _FAILED_UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_failed_unary_unary) + elif handler_call_details.method == _SUCCESSFUL_STREAM_STREAM: + return grpc.stream_stream_rpc_method_handler( + _successful_stream_stream) + else: + return None + + +class _ChannelServerPair: + + def __init__(self): + self.address = '' + self.server = None + self.channel = None + self.server_ref_id = None + self.channel_ref_id = None + + async def start(self): + # Server will enable channelz service + self.server = aio.server(options=_DISABLE_REUSE_PORT + _ENABLE_CHANNELZ) + port = self.server.add_insecure_port('[::]:0') + self.address = 'localhost:%d' % port + self.server.add_generic_rpc_handlers((_GenericHandler(),)) + await self.server.start() + + # Channel will enable channelz service... + self.channel = aio.insecure_channel(self.address, + options=_ENABLE_CHANNELZ) + + async def bind_channelz(self, channelz_stub): + resp = await channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + for channel in resp.channel: + if channel.data.target == self.address: + self.channel_ref_id = channel.ref.channel_id + + resp = await channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.server_ref_id = resp.server[-1].ref.server_id + + async def stop(self): + await self.channel.close() + await self.server.stop(None) + + +async def _create_channel_server_pairs(n, channelz_stub=None): + """Create channel-server pairs.""" + pairs = [_ChannelServerPair() for i in range(n)] + for pair in pairs: + await pair.start() + if channelz_stub: + await pair.bind_channelz(channelz_stub) + return pairs + + +async def _destroy_channel_server_pairs(pairs): + for pair in pairs: + await pair.stop() + + +class ChannelzServicerTest(AioTestBase): + + async def setUp(self): + # This server is for Channelz info fetching only + # It self should not enable Channelz + self._server = aio.server(options=_DISABLE_REUSE_PORT + + _DISABLE_CHANNELZ) + port = self._server.add_insecure_port('[::]:0') + channelz.add_channelz_servicer(self._server) + await self._server.start() + + # This channel is used to fetch Channelz info only + # Channelz should not be enabled + self._channel = aio.insecure_channel('localhost:%d' % port, + options=_DISABLE_CHANNELZ) + self._channelz_stub = channelz_pb2_grpc.ChannelzStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def _get_server_by_ref_id(self, ref_id): + """Server id may not be consecutive""" + resp = await self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=ref_id)) + self.assertEqual(ref_id, resp.server[0].ref.server_id) + return resp.server[0] + + async def _send_successful_unary_unary(self, pair): + call = pair.channel.unary_unary(_SUCCESSFUL_UNARY_UNARY)(_REQUEST) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def _send_failed_unary_unary(self, pair): + try: + await pair.channel.unary_unary(_FAILED_UNARY_UNARY)(_REQUEST) + except grpc.RpcError: + return + else: + self.fail("This call supposed to fail") + + async def _send_successful_stream_stream(self, pair): + call = pair.channel.stream_stream(_SUCCESSFUL_STREAM_STREAM)(iter( + [_REQUEST] * test_constants.STREAM_LENGTH)) + cnt = 0 + async for _ in call: + cnt += 1 + self.assertEqual(cnt, test_constants.STREAM_LENGTH) + + async def test_get_top_channels_high_start_id(self): + pairs = await _create_channel_server_pairs(1) + + resp = await self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest( + start_channel_id=_LARGE_UNASSIGNED_ID)) + self.assertEqual(len(resp.channel), 0) + self.assertEqual(resp.end, True) + + await _destroy_channel_server_pairs(pairs) + + async def test_successful_request(self): + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + await self._send_successful_unary_unary(pairs[0]) + resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + + self.assertEqual(resp.channel.data.calls_started, 1) + self.assertEqual(resp.channel.data.calls_succeeded, 1) + self.assertEqual(resp.channel.data.calls_failed, 0) + + await _destroy_channel_server_pairs(pairs) + + async def test_failed_request(self): + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + await self._send_failed_unary_unary(pairs[0]) + resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + self.assertEqual(resp.channel.data.calls_started, 1) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, 1) + + await _destroy_channel_server_pairs(pairs) + + async def test_many_requests(self): + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + k_success = 7 + k_failed = 9 + for i in range(k_success): + await self._send_successful_unary_unary(pairs[0]) + for i in range(k_failed): + await self._send_failed_unary_unary(pairs[0]) + resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + await _destroy_channel_server_pairs(pairs) + + async def test_many_requests_many_channel(self): + k_channels = 4 + pairs = await _create_channel_server_pairs(k_channels, + self._channelz_stub) + k_success = 11 + k_failed = 13 + for i in range(k_success): + await self._send_successful_unary_unary(pairs[0]) + await self._send_successful_unary_unary(pairs[2]) + for i in range(k_failed): + await self._send_failed_unary_unary(pairs[1]) + await self._send_failed_unary_unary(pairs[2]) + + # The first channel saw only successes + resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + self.assertEqual(resp.channel.data.calls_started, k_success) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, 0) + + # The second channel saw only failures + resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=pairs[1].channel_ref_id)) + self.assertEqual(resp.channel.data.calls_started, k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + # The third channel saw both successes and failures + resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=pairs[2].channel_ref_id)) + self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + # The fourth channel saw nothing + resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=pairs[3].channel_ref_id)) + self.assertEqual(resp.channel.data.calls_started, 0) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, 0) + + await _destroy_channel_server_pairs(pairs) + + async def test_many_subchannels(self): + k_channels = 4 + pairs = await _create_channel_server_pairs(k_channels, + self._channelz_stub) + k_success = 17 + k_failed = 19 + for i in range(k_success): + await self._send_successful_unary_unary(pairs[0]) + await self._send_successful_unary_unary(pairs[2]) + for i in range(k_failed): + await self._send_failed_unary_unary(pairs[1]) + await self._send_failed_unary_unary(pairs[2]) + + for i in range(k_channels): + gc_resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest( + channel_id=pairs[i].channel_ref_id)) + # If no call performed in the channel, there shouldn't be any subchannel + if gc_resp.channel.data.calls_started == 0: + self.assertEqual(len(gc_resp.channel.subchannel_ref), 0) + continue + + # Otherwise, the subchannel should exist + self.assertGreater(len(gc_resp.channel.subchannel_ref), 0) + gsc_resp = await self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gc_resp.channel.subchannel_ref[0]. + subchannel_id)) + self.assertEqual(gc_resp.channel.data.calls_started, + gsc_resp.subchannel.data.calls_started) + self.assertEqual(gc_resp.channel.data.calls_succeeded, + gsc_resp.subchannel.data.calls_succeeded) + self.assertEqual(gc_resp.channel.data.calls_failed, + gsc_resp.subchannel.data.calls_failed) + + await _destroy_channel_server_pairs(pairs) + + async def test_server_call(self): + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + k_success = 23 + k_failed = 29 + for i in range(k_success): + await self._send_successful_unary_unary(pairs[0]) + for i in range(k_failed): + await self._send_failed_unary_unary(pairs[0]) + + resp = await self._get_server_by_ref_id(pairs[0].server_ref_id) + self.assertEqual(resp.data.calls_started, k_success + k_failed) + self.assertEqual(resp.data.calls_succeeded, k_success) + self.assertEqual(resp.data.calls_failed, k_failed) + + await _destroy_channel_server_pairs(pairs) + + async def test_many_subchannels_and_sockets(self): + k_channels = 4 + pairs = await _create_channel_server_pairs(k_channels, + self._channelz_stub) + k_success = 3 + k_failed = 5 + for i in range(k_success): + await self._send_successful_unary_unary(pairs[0]) + await self._send_successful_unary_unary(pairs[2]) + for i in range(k_failed): + await self._send_failed_unary_unary(pairs[1]) + await self._send_failed_unary_unary(pairs[2]) + + for i in range(k_channels): + gc_resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest( + channel_id=pairs[i].channel_ref_id)) + + # If no call performed in the channel, there shouldn't be any subchannel + if gc_resp.channel.data.calls_started == 0: + self.assertEqual(len(gc_resp.channel.subchannel_ref), 0) + continue + + # Otherwise, the subchannel should exist + self.assertGreater(len(gc_resp.channel.subchannel_ref), 0) + gsc_resp = await self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gc_resp.channel.subchannel_ref[0]. + subchannel_id)) + self.assertEqual(len(gsc_resp.subchannel.socket_ref), 1) + + gs_resp = await self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) + self.assertEqual(gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.streams_started) + self.assertEqual(0, gs_resp.socket.data.streams_failed) + # Calls started == messages sent, only valid for unary calls + self.assertEqual(gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.messages_sent) + + await _destroy_channel_server_pairs(pairs) + + async def test_streaming_rpc(self): + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + # In C++, the argument for _send_successful_stream_stream is message length. + # Here the argument is still channel idx, to be consistent with the other two. + await self._send_successful_stream_stream(pairs[0]) + + gc_resp = await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + self.assertEqual(gc_resp.channel.data.calls_started, 1) + self.assertEqual(gc_resp.channel.data.calls_succeeded, 1) + self.assertEqual(gc_resp.channel.data.calls_failed, 0) + # Subchannel exists + self.assertGreater(len(gc_resp.channel.subchannel_ref), 0) + + gsc_resp = await self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gc_resp.channel.subchannel_ref[0].subchannel_id)) + self.assertEqual(gsc_resp.subchannel.data.calls_started, 1) + self.assertEqual(gsc_resp.subchannel.data.calls_succeeded, 1) + self.assertEqual(gsc_resp.subchannel.data.calls_failed, 0) + # Socket exists + self.assertEqual(len(gsc_resp.subchannel.socket_ref), 1) + + gs_resp = await self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) + self.assertEqual(gs_resp.socket.data.streams_started, 1) + self.assertEqual(gs_resp.socket.data.streams_succeeded, 1) + self.assertEqual(gs_resp.socket.data.streams_failed, 0) + self.assertEqual(gs_resp.socket.data.messages_sent, + test_constants.STREAM_LENGTH) + self.assertEqual(gs_resp.socket.data.messages_received, + test_constants.STREAM_LENGTH) + + await _destroy_channel_server_pairs(pairs) + + async def test_server_sockets(self): + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + await self._send_successful_unary_unary(pairs[0]) + await self._send_failed_unary_unary(pairs[0]) + + resp = await self._get_server_by_ref_id(pairs[0].server_ref_id) + self.assertEqual(resp.data.calls_started, 2) + self.assertEqual(resp.data.calls_succeeded, 1) + self.assertEqual(resp.data.calls_failed, 1) + + gss_resp = await self._channelz_stub.GetServerSockets( + channelz_pb2.GetServerSocketsRequest(server_id=resp.ref.server_id, + start_socket_id=0)) + # If the RPC call failed, it will raise a grpc.RpcError + # So, if there is no exception raised, considered pass + await _destroy_channel_server_pairs(pairs) + + async def test_server_listen_sockets(self): + pairs = await _create_channel_server_pairs(1, self._channelz_stub) + + resp = await self._get_server_by_ref_id(pairs[0].server_ref_id) + self.assertEqual(len(resp.listen_socket), 1) + + gs_resp = await self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=resp.listen_socket[0].socket_id)) + # If the RPC call failed, it will raise a grpc.RpcError + # So, if there is no exception raised, considered pass + await _destroy_channel_server_pairs(pairs) + + async def test_invalid_query_get_server(self): + with self.assertRaises(aio.AioRpcError) as exception_context: + await self._channelz_stub.GetServer( + channelz_pb2.GetServerRequest(server_id=_LARGE_UNASSIGNED_ID)) + self.assertEqual(grpc.StatusCode.NOT_FOUND, + exception_context.exception.code()) + + async def test_invalid_query_get_channel(self): + with self.assertRaises(aio.AioRpcError) as exception_context: + await self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=_LARGE_UNASSIGNED_ID)) + self.assertEqual(grpc.StatusCode.NOT_FOUND, + exception_context.exception.code()) + + async def test_invalid_query_get_subchannel(self): + with self.assertRaises(aio.AioRpcError) as exception_context: + await self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=_LARGE_UNASSIGNED_ID)) + self.assertEqual(grpc.StatusCode.NOT_FOUND, + exception_context.exception.code()) + + async def test_invalid_query_get_socket(self): + with self.assertRaises(aio.AioRpcError) as exception_context: + await self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest(socket_id=_LARGE_UNASSIGNED_ID)) + self.assertEqual(grpc.StatusCode.NOT_FOUND, + exception_context.exception.code()) + + async def test_invalid_query_get_server_sockets(self): + with self.assertRaises(aio.AioRpcError) as exception_context: + await self._channelz_stub.GetServerSockets( + channelz_pb2.GetServerSocketsRequest( + server_id=_LARGE_UNASSIGNED_ID, + start_socket_id=0, + )) + self.assertEqual(grpc.StatusCode.NOT_FOUND, + exception_context.exception.code()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/health_check/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/health_check/__init__.py new file mode 100644 index 00000000000..1517f71d093 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/health_check/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py new file mode 100644 index 00000000000..a539dbf1409 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py @@ -0,0 +1,282 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests AsyncIO version of grpcio-health-checking.""" + +import asyncio +import logging +import time +import random +import unittest + +import grpc + +from grpc_health.v1 import health +from grpc_health.v1 import health_pb2 +from grpc_health.v1 import health_pb2_grpc +from grpc.experimental import aio + +from tests.unit.framework.common import test_constants + +from tests_aio.unit._test_base import AioTestBase + +_SERVING_SERVICE = 'grpc.test.TestServiceServing' +_UNKNOWN_SERVICE = 'grpc.test.TestServiceUnknown' +_NOT_SERVING_SERVICE = 'grpc.test.TestServiceNotServing' +_WATCH_SERVICE = 'grpc.test.WatchService' + +_LARGE_NUMBER_OF_STATUS_CHANGES = 1000 + + +async def _pipe_to_queue(call, queue): + async for response in call: + await queue.put(response) + + +class HealthServicerTest(AioTestBase): + + async def setUp(self): + self._servicer = health.aio.HealthServicer() + await self._servicer.set(_SERVING_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + await self._servicer.set(_UNKNOWN_SERVICE, + health_pb2.HealthCheckResponse.UNKNOWN) + await self._servicer.set(_NOT_SERVING_SERVICE, + health_pb2.HealthCheckResponse.NOT_SERVING) + self._server = aio.server() + port = self._server.add_insecure_port('[::]:0') + health_pb2_grpc.add_HealthServicer_to_server(self._servicer, + self._server) + await self._server.start() + + self._channel = aio.insecure_channel('localhost:%d' % port) + self._stub = health_pb2_grpc.HealthStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_check_empty_service(self): + request = health_pb2.HealthCheckRequest() + resp = await self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status) + + async def test_check_serving_service(self): + request = health_pb2.HealthCheckRequest(service=_SERVING_SERVICE) + resp = await self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status) + + async def test_check_unknown_service(self): + request = health_pb2.HealthCheckRequest(service=_UNKNOWN_SERVICE) + resp = await self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.UNKNOWN, resp.status) + + async def test_check_not_serving_service(self): + request = health_pb2.HealthCheckRequest(service=_NOT_SERVING_SERVICE) + resp = await self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + resp.status) + + async def test_check_not_found_service(self): + request = health_pb2.HealthCheckRequest(service='not-found') + with self.assertRaises(aio.AioRpcError) as context: + await self._stub.Check(request) + + self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code()) + + async def test_health_service_name(self): + self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health') + + async def test_watch_empty_service(self): + request = health_pb2.HealthCheckRequest(service=health.OVERALL_HEALTH) + + call = self._stub.Watch(request) + queue = asyncio.Queue() + task = self.loop.create_task(_pipe_to_queue(call, queue)) + + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + (await queue.get()).status) + + call.cancel() + + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(queue.empty()) + + async def test_watch_new_service(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + call = self._stub.Watch(request) + queue = asyncio.Queue() + task = self.loop.create_task(_pipe_to_queue(call, queue)) + + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue.get()).status) + + await self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + (await queue.get()).status) + + await self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.NOT_SERVING) + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + (await queue.get()).status) + + call.cancel() + + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(queue.empty()) + + async def test_watch_service_isolation(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + call = self._stub.Watch(request) + queue = asyncio.Queue() + task = self.loop.create_task(_pipe_to_queue(call, queue)) + + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue.get()).status) + + await self._servicer.set('some-other-service', + health_pb2.HealthCheckResponse.SERVING) + # The change of health status in other service should be isolated. + # Hence, no additional notification should be observed. + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(queue.get(), test_constants.SHORT_TIMEOUT) + + call.cancel() + + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(queue.empty()) + + async def test_two_watchers(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + queue1 = asyncio.Queue() + queue2 = asyncio.Queue() + call1 = self._stub.Watch(request) + call2 = self._stub.Watch(request) + task1 = self.loop.create_task(_pipe_to_queue(call1, queue1)) + task2 = self.loop.create_task(_pipe_to_queue(call2, queue2)) + + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue1.get()).status) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue2.get()).status) + + await self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + (await queue1.get()).status) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + (await queue2.get()).status) + + call1.cancel() + call2.cancel() + + with self.assertRaises(asyncio.CancelledError): + await task1 + + with self.assertRaises(asyncio.CancelledError): + await task2 + + self.assertTrue(queue1.empty()) + self.assertTrue(queue2.empty()) + + async def test_cancelled_watch_removed_from_watch_list(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + call = self._stub.Watch(request) + queue = asyncio.Queue() + task = self.loop.create_task(_pipe_to_queue(call, queue)) + + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue.get()).status) + + call.cancel() + await self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + + with self.assertRaises(asyncio.CancelledError): + await task + + # Wait for the serving coroutine to process client cancellation. + timeout = time.monotonic() + test_constants.TIME_ALLOWANCE + while (time.monotonic() < timeout and self._servicer._server_watchers): + await asyncio.sleep(1) + self.assertFalse(self._servicer._server_watchers, + 'There should not be any watcher left') + self.assertTrue(queue.empty()) + + async def test_graceful_shutdown(self): + request = health_pb2.HealthCheckRequest(service=health.OVERALL_HEALTH) + call = self._stub.Watch(request) + queue = asyncio.Queue() + task = self.loop.create_task(_pipe_to_queue(call, queue)) + + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + (await queue.get()).status) + + await self._servicer.enter_graceful_shutdown() + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + (await queue.get()).status) + + # This should be a no-op. + await self._servicer.set(health.OVERALL_HEALTH, + health_pb2.HealthCheckResponse.SERVING) + + resp = await self._stub.Check(request) + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + resp.status) + + call.cancel() + + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(queue.empty()) + + async def test_no_duplicate_status(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + call = self._stub.Watch(request) + queue = asyncio.Queue() + task = self.loop.create_task(_pipe_to_queue(call, queue)) + + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue.get()).status) + last_status = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + + for _ in range(_LARGE_NUMBER_OF_STATUS_CHANGES): + if random.randint(0, 1) == 0: + status = health_pb2.HealthCheckResponse.SERVING + else: + status = health_pb2.HealthCheckResponse.NOT_SERVING + + await self._servicer.set(_WATCH_SERVICE, status) + if status != last_status: + self.assertEqual(status, (await queue.get()).status) + last_status = status + + call.cancel() + + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(queue.empty()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/__init__.py new file mode 100644 index 00000000000..b71ddbd314c --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/client.py new file mode 100644 index 00000000000..a4c5e12ceda --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/client.py @@ -0,0 +1,61 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +import logging +import os + +import grpc +from grpc.experimental import aio + +from tests.interop import client as interop_client_lib +from tests_aio.interop import methods + +_LOGGER = logging.getLogger(__name__) +_LOGGER.setLevel(logging.DEBUG) + + +def _create_channel(args): + target = f'{args.server_host}:{args.server_port}' + + if args.use_tls or args.use_alts or args.custom_credentials_type is not None: + channel_credentials, options = interop_client_lib.get_secure_channel_parameters( + args) + return aio.secure_channel(target, channel_credentials, options) + else: + return aio.insecure_channel(target) + + +def _test_case_from_arg(test_case_arg): + for test_case in methods.TestCase: + if test_case_arg == test_case.value: + return test_case + else: + raise ValueError('No test case "%s"!' % test_case_arg) + + +async def test_interoperability(): + + args = interop_client_lib.parse_interop_client_args() + channel = _create_channel(args) + stub = interop_client_lib.create_stub(channel, args) + test_case = _test_case_from_arg(args.test_case) + await methods.test_interoperability(test_case, stub, args) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + asyncio.get_event_loop().set_debug(True) + asyncio.get_event_loop().run_until_complete(test_interoperability()) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py new file mode 100644 index 00000000000..0db15be3a94 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py @@ -0,0 +1,134 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conducts interop tests locally.""" + +import logging +import unittest + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import test_pb2_grpc +from tests.interop import resources +from tests_aio.interop import methods +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' + + +class InteropTestCaseMixin: + """Unit test methods. + + This class must be mixed in with unittest.TestCase and a class that defines + setUp and tearDown methods that manage a stub attribute. + """ + _stub: test_pb2_grpc.TestServiceStub + + async def test_empty_unary(self): + await methods.test_interoperability(methods.TestCase.EMPTY_UNARY, + self._stub, None) + + async def test_large_unary(self): + await methods.test_interoperability(methods.TestCase.LARGE_UNARY, + self._stub, None) + + async def test_server_streaming(self): + await methods.test_interoperability(methods.TestCase.SERVER_STREAMING, + self._stub, None) + + async def test_client_streaming(self): + await methods.test_interoperability(methods.TestCase.CLIENT_STREAMING, + self._stub, None) + + async def test_ping_pong(self): + await methods.test_interoperability(methods.TestCase.PING_PONG, + self._stub, None) + + async def test_cancel_after_begin(self): + await methods.test_interoperability(methods.TestCase.CANCEL_AFTER_BEGIN, + self._stub, None) + + async def test_cancel_after_first_response(self): + await methods.test_interoperability( + methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE, self._stub, None) + + async def test_timeout_on_sleeping_server(self): + await methods.test_interoperability( + methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER, self._stub, None) + + async def test_empty_stream(self): + await methods.test_interoperability(methods.TestCase.EMPTY_STREAM, + self._stub, None) + + async def test_status_code_and_message(self): + await methods.test_interoperability( + methods.TestCase.STATUS_CODE_AND_MESSAGE, self._stub, None) + + async def test_unimplemented_method(self): + await methods.test_interoperability( + methods.TestCase.UNIMPLEMENTED_METHOD, self._stub, None) + + async def test_unimplemented_service(self): + await methods.test_interoperability( + methods.TestCase.UNIMPLEMENTED_SERVICE, self._stub, None) + + async def test_custom_metadata(self): + await methods.test_interoperability(methods.TestCase.CUSTOM_METADATA, + self._stub, None) + + async def test_special_status_message(self): + await methods.test_interoperability( + methods.TestCase.SPECIAL_STATUS_MESSAGE, self._stub, None) + + +class InsecureLocalInteropTest(InteropTestCaseMixin, AioTestBase): + + async def setUp(self): + address, self._server = await start_test_server() + self._channel = aio.insecure_channel(address) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + +class SecureLocalInteropTest(InteropTestCaseMixin, AioTestBase): + + async def setUp(self): + server_credentials = grpc.ssl_server_credentials([ + (resources.private_key(), resources.certificate_chain()) + ]) + channel_credentials = grpc.ssl_channel_credentials( + resources.test_root_certificates()) + channel_options = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, + ),) + + address, self._server = await start_test_server( + secure=True, server_credentials=server_credentials) + self._channel = aio.secure_channel(address, channel_credentials, + channel_options) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/methods.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/methods.py new file mode 100644 index 00000000000..aa39976981f --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/methods.py @@ -0,0 +1,456 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementations of interoperability test methods.""" + +import argparse +import asyncio +import collections +import datetime +import enum +import inspect +import json +import os +import threading +import time +from typing import Any, Optional, Union + +import grpc +from google import auth as google_auth +from google.auth import environment_vars as google_auth_environment_vars +from google.auth.transport import grpc as google_auth_transport_grpc +from google.auth.transport import requests as google_auth_transport_requests +from grpc.experimental import aio + +from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc + +_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial" +_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin" + + +async def _expect_status_code(call: aio.Call, + expected_code: grpc.StatusCode) -> None: + code = await call.code() + if code != expected_code: + raise ValueError('expected code %s, got %s' % + (expected_code, await call.code())) + + +async def _expect_status_details(call: aio.Call, expected_details: str) -> None: + details = await call.details() + if details != expected_details: + raise ValueError('expected message %s, got %s' % + (expected_details, await call.details())) + + +async def _validate_status_code_and_details(call: aio.Call, + expected_code: grpc.StatusCode, + expected_details: str) -> None: + await _expect_status_code(call, expected_code) + await _expect_status_details(call, expected_details) + + +def _validate_payload_type_and_length( + response: Union[messages_pb2.SimpleResponse, messages_pb2. + StreamingOutputCallResponse], expected_type: Any, + expected_length: int) -> None: + if response.payload.type is not expected_type: + raise ValueError('expected payload type %s, got %s' % + (expected_type, type(response.payload.type))) + elif len(response.payload.body) != expected_length: + raise ValueError('expected payload body size %d, got %d' % + (expected_length, len(response.payload.body))) + + +async def _large_unary_common_behavior( + stub: test_pb2_grpc.TestServiceStub, fill_username: bool, + fill_oauth_scope: bool, call_credentials: Optional[grpc.CallCredentials] +) -> messages_pb2.SimpleResponse: + size = 314159 + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=size, + payload=messages_pb2.Payload(body=b'\x00' * 271828), + fill_username=fill_username, + fill_oauth_scope=fill_oauth_scope) + response = await stub.UnaryCall(request, credentials=call_credentials) + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) + return response + + +async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None: + response = await stub.EmptyCall(empty_pb2.Empty()) + if not isinstance(response, empty_pb2.Empty): + raise TypeError('response is of type "%s", not empty_pb2.Empty!' % + type(response)) + + +async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None: + await _large_unary_common_behavior(stub, False, False, None) + + +async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None: + payload_body_sizes = ( + 27182, + 8, + 1828, + 45904, + ) + + async def request_gen(): + for size in payload_body_sizes: + yield messages_pb2.StreamingInputCallRequest( + payload=messages_pb2.Payload(body=b'\x00' * size)) + + response = await stub.StreamingInputCall(request_gen()) + if response.aggregated_payload_size != sum(payload_body_sizes): + raise ValueError('incorrect size %d!' % + response.aggregated_payload_size) + + +async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None: + sizes = ( + 31415, + 9, + 2653, + 58979, + ) + + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=( + messages_pb2.ResponseParameters(size=sizes[0]), + messages_pb2.ResponseParameters(size=sizes[1]), + messages_pb2.ResponseParameters(size=sizes[2]), + messages_pb2.ResponseParameters(size=sizes[3]), + )) + call = stub.StreamingOutputCall(request) + for size in sizes: + response = await call.read() + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, + size) + + +async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None: + request_response_sizes = ( + 31415, + 9, + 2653, + 58979, + ) + request_payload_sizes = ( + 27182, + 8, + 1828, + 45904, + ) + + call = stub.FullDuplexCall() + for response_size, payload_size in zip(request_response_sizes, + request_payload_sizes): + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters( + size=response_size),), + payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + + await call.write(request) + response = await call.read() + _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, + response_size) + await call.done_writing() + await _validate_status_code_and_details(call, grpc.StatusCode.OK, '') + + +async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub): + call = stub.StreamingInputCall() + call.cancel() + if not call.cancelled(): + raise ValueError('expected cancelled method to return True') + code = await call.code() + if code is not grpc.StatusCode.CANCELLED: + raise ValueError('expected status code CANCELLED') + + +async def _cancel_after_first_response(stub: test_pb2_grpc.TestServiceStub): + request_response_sizes = ( + 31415, + 9, + 2653, + 58979, + ) + request_payload_sizes = ( + 27182, + 8, + 1828, + 45904, + ) + + call = stub.FullDuplexCall() + + response_size = request_response_sizes[0] + payload_size = request_payload_sizes[0] + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters( + size=response_size),), + payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + + await call.write(request) + await call.read() + + call.cancel() + + try: + await call.read() + except asyncio.CancelledError: + assert await call.code() is grpc.StatusCode.CANCELLED + else: + raise ValueError('expected call to be cancelled') + + +async def _timeout_on_sleeping_server(stub: test_pb2_grpc.TestServiceStub): + request_payload_size = 27182 + time_limit = datetime.timedelta(seconds=1) + + call = stub.FullDuplexCall(timeout=time_limit.total_seconds()) + + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + payload=messages_pb2.Payload(body=b'\x00' * request_payload_size), + response_parameters=(messages_pb2.ResponseParameters( + interval_us=int(time_limit.total_seconds() * 2 * 10**6)),)) + await call.write(request) + await call.done_writing() + try: + await call.read() + except aio.AioRpcError as rpc_error: + if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED: + raise + else: + raise ValueError('expected call to exceed deadline') + + +async def _empty_stream(stub: test_pb2_grpc.TestServiceStub): + call = stub.FullDuplexCall() + await call.done_writing() + assert await call.read() == aio.EOF + + +async def _status_code_and_message(stub: test_pb2_grpc.TestServiceStub): + details = 'test status message' + status = grpc.StatusCode.UNKNOWN # code = 2 + + # Test with a UnaryCall + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=1, + payload=messages_pb2.Payload(body=b'\x00'), + response_status=messages_pb2.EchoStatus(code=status.value[0], + message=details)) + call = stub.UnaryCall(request) + await _validate_status_code_and_details(call, status, details) + + # Test with a FullDuplexCall + call = stub.FullDuplexCall() + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters(size=1),), + payload=messages_pb2.Payload(body=b'\x00'), + response_status=messages_pb2.EchoStatus(code=status.value[0], + message=details)) + await call.write(request) # sends the initial request. + await call.done_writing() + try: + await call.read() + except aio.AioRpcError as rpc_error: + assert rpc_error.code() == status + await _validate_status_code_and_details(call, status, details) + + +async def _unimplemented_method(stub: test_pb2_grpc.TestServiceStub): + call = stub.UnimplementedCall(empty_pb2.Empty()) + await _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED) + + +async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub): + call = stub.UnimplementedCall(empty_pb2.Empty()) + await _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED) + + +async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub): + initial_metadata_value = "test_initial_metadata_value" + trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b" + metadata = aio.Metadata( + (_INITIAL_METADATA_KEY, initial_metadata_value), + (_TRAILING_METADATA_KEY, trailing_metadata_value), + ) + + async def _validate_metadata(call): + initial_metadata = await call.initial_metadata() + if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value: + raise ValueError('expected initial metadata %s, got %s' % + (initial_metadata_value, + initial_metadata[_INITIAL_METADATA_KEY])) + + trailing_metadata = await call.trailing_metadata() + if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value: + raise ValueError('expected trailing metadata %s, got %s' % + (trailing_metadata_value, + trailing_metadata[_TRAILING_METADATA_KEY])) + + # Testing with UnaryCall + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=1, + payload=messages_pb2.Payload(body=b'\x00')) + call = stub.UnaryCall(request, metadata=metadata) + await _validate_metadata(call) + + # Testing with FullDuplexCall + call = stub.FullDuplexCall(metadata=metadata) + request = messages_pb2.StreamingOutputCallRequest( + response_type=messages_pb2.COMPRESSABLE, + response_parameters=(messages_pb2.ResponseParameters(size=1),)) + await call.write(request) + await call.read() + await call.done_writing() + await _validate_metadata(call) + + +async def _compute_engine_creds(stub: test_pb2_grpc.TestServiceStub, + args: argparse.Namespace): + response = await _large_unary_common_behavior(stub, True, True, None) + if args.default_service_account != response.username: + raise ValueError('expected username %s, got %s' % + (args.default_service_account, response.username)) + + +async def _oauth2_auth_token(stub: test_pb2_grpc.TestServiceStub, + args: argparse.Namespace): + json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + response = await _large_unary_common_behavior(stub, True, True, None) + if wanted_email != response.username: + raise ValueError('expected username %s, got %s' % + (wanted_email, response.username)) + if args.oauth_scope.find(response.oauth_scope) == -1: + raise ValueError( + 'expected to find oauth scope "{}" in received "{}"'.format( + response.oauth_scope, args.oauth_scope)) + + +async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub): + json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + response = await _large_unary_common_behavior(stub, True, False, None) + if wanted_email != response.username: + raise ValueError('expected username %s, got %s' % + (wanted_email, response.username)) + + +async def _per_rpc_creds(stub: test_pb2_grpc.TestServiceStub, + args: argparse.Namespace): + json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + google_credentials, unused_project_id = google_auth.default( + scopes=[args.oauth_scope]) + call_credentials = grpc.metadata_call_credentials( + google_auth_transport_grpc.AuthMetadataPlugin( + credentials=google_credentials, + request=google_auth_transport_requests.Request())) + response = await _large_unary_common_behavior(stub, True, False, + call_credentials) + if wanted_email != response.username: + raise ValueError('expected username %s, got %s' % + (wanted_email, response.username)) + + +async def _special_status_message(stub: test_pb2_grpc.TestServiceStub): + details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode( + 'utf-8') + status = grpc.StatusCode.UNKNOWN # code = 2 + + # Test with a UnaryCall + request = messages_pb2.SimpleRequest( + response_type=messages_pb2.COMPRESSABLE, + response_size=1, + payload=messages_pb2.Payload(body=b'\x00'), + response_status=messages_pb2.EchoStatus(code=status.value[0], + message=details)) + call = stub.UnaryCall(request) + await _validate_status_code_and_details(call, status, details) + + +class TestCase(enum.Enum): + EMPTY_UNARY = 'empty_unary' + LARGE_UNARY = 'large_unary' + SERVER_STREAMING = 'server_streaming' + CLIENT_STREAMING = 'client_streaming' + PING_PONG = 'ping_pong' + CANCEL_AFTER_BEGIN = 'cancel_after_begin' + CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response' + TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' + EMPTY_STREAM = 'empty_stream' + STATUS_CODE_AND_MESSAGE = 'status_code_and_message' + UNIMPLEMENTED_METHOD = 'unimplemented_method' + UNIMPLEMENTED_SERVICE = 'unimplemented_service' + CUSTOM_METADATA = "custom_metadata" + COMPUTE_ENGINE_CREDS = 'compute_engine_creds' + OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' + JWT_TOKEN_CREDS = 'jwt_token_creds' + PER_RPC_CREDS = 'per_rpc_creds' + SPECIAL_STATUS_MESSAGE = 'special_status_message' + + +_TEST_CASE_IMPLEMENTATION_MAPPING = { + TestCase.EMPTY_UNARY: _empty_unary, + TestCase.LARGE_UNARY: _large_unary, + TestCase.SERVER_STREAMING: _server_streaming, + TestCase.CLIENT_STREAMING: _client_streaming, + TestCase.PING_PONG: _ping_pong, + TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin, + TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response, + TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server, + TestCase.EMPTY_STREAM: _empty_stream, + TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message, + TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method, + TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service, + TestCase.CUSTOM_METADATA: _custom_metadata, + TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds, + TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token, + TestCase.JWT_TOKEN_CREDS: _jwt_token_creds, + TestCase.PER_RPC_CREDS: _per_rpc_creds, + TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message, +} + + +async def test_interoperability(case: TestCase, + stub: test_pb2_grpc.TestServiceStub, + args: Optional[argparse.Namespace] = None + ) -> None: + method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case) + if method is None: + raise NotImplementedError(f'Test case "{case}" not implemented!') + else: + num_params = len(inspect.signature(method).parameters) + if num_params == 1: + await method(stub) + elif num_params == 2: + if args is not None: + await method(stub, args) + else: + raise ValueError(f'Failed to run case [{case}]: args is None') + else: + raise ValueError(f'Invalid number of parameters [{num_params}]') diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/server.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/server.py new file mode 100644 index 00000000000..509abdf0b2f --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/interop/server.py @@ -0,0 +1,49 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The gRPC interoperability test server using AsyncIO stack.""" + +import asyncio +import argparse +import logging + +import grpc + +from tests.interop import server as interop_server_lib +from tests_aio.unit import _test_server + +logging.basicConfig(level=logging.DEBUG) +_LOGGER = logging.getLogger(__name__) +_LOGGER.setLevel(logging.DEBUG) + + +async def serve(): + args = interop_server_lib.parse_interop_server_arguments() + + if args.use_tls or args.use_alts: + credentials = interop_server_lib.get_server_credentials(args.use_tls) + address, server = await _test_server.start_test_server( + port=args.port, secure=True, server_credentials=credentials) + else: + address, server = await _test_server.start_test_server( + port=args.port, + secure=False, + ) + + _LOGGER.info('Server serving at %s', address) + await server.wait_for_termination() + _LOGGER.info('Server stopped; exiting.') + + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(serve()) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/reflection/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/reflection/__init__.py new file mode 100644 index 00000000000..5772620b602 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/reflection/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py new file mode 100644 index 00000000000..edd2d79eabe --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py @@ -0,0 +1,193 @@ +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_reflection.v1alpha.reflection.""" + +import logging +import unittest + +import grpc +from google.protobuf import descriptor_pb2 +from grpc.experimental import aio + +from grpc_reflection.v1alpha import (reflection, reflection_pb2, + reflection_pb2_grpc) +from src.proto.grpc.testing import empty_pb2 +from src.proto.grpc.testing.proto2 import empty2_extensions_pb2 +from tests_aio.unit._test_base import AioTestBase + +_EMPTY_PROTO_FILE_NAME = 'src/proto/grpc/testing/empty.proto' +_EMPTY_PROTO_SYMBOL_NAME = 'grpc.testing.Empty' +_SERVICE_NAMES = ('Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman', + 'Galilei') +_EMPTY_EXTENSIONS_SYMBOL_NAME = 'grpc.testing.proto2.EmptyWithExtensions' +_EMPTY_EXTENSIONS_NUMBERS = ( + 124, + 125, + 126, + 127, + 128, +) + + +def _file_descriptor_to_proto(descriptor): + proto = descriptor_pb2.FileDescriptorProto() + descriptor.CopyToProto(proto) + return proto.SerializeToString() + + +class ReflectionServicerTest(AioTestBase): + + async def setUp(self): + self._server = aio.server() + reflection.enable_server_reflection(_SERVICE_NAMES, self._server) + port = self._server.add_insecure_port('[::]:0') + await self._server.start() + + self._channel = aio.insecure_channel('localhost:%d' % port) + self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel) + + async def tearDown(self): + await self._server.stop(None) + await self._channel.close() + + async def test_file_by_name(self): + requests = ( + reflection_pb2.ServerReflectionRequest( + file_by_filename=_EMPTY_PROTO_FILE_NAME), + reflection_pb2.ServerReflectionRequest( + file_by_filename='i-donut-exist'), + ) + responses = [] + async for response in self._stub.ServerReflectionInfo(iter(requests)): + responses.append(response) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host='', + file_descriptor_response=reflection_pb2.FileDescriptorResponse( + file_descriptor_proto=( + _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))), + reflection_pb2.ServerReflectionResponse( + valid_host='', + error_response=reflection_pb2.ErrorResponse( + error_code=grpc.StatusCode.NOT_FOUND.value[0], + error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), + )), + ) + self.assertSequenceEqual(expected_responses, responses) + + async def test_file_by_symbol(self): + requests = ( + reflection_pb2.ServerReflectionRequest( + file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME), + reflection_pb2.ServerReflectionRequest( + file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo' + ), + ) + responses = [] + async for response in self._stub.ServerReflectionInfo(iter(requests)): + responses.append(response) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host='', + file_descriptor_response=reflection_pb2.FileDescriptorResponse( + file_descriptor_proto=( + _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))), + reflection_pb2.ServerReflectionResponse( + valid_host='', + error_response=reflection_pb2.ErrorResponse( + error_code=grpc.StatusCode.NOT_FOUND.value[0], + error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), + )), + ) + self.assertSequenceEqual(expected_responses, responses) + + async def test_file_containing_extension(self): + requests = ( + reflection_pb2.ServerReflectionRequest( + file_containing_extension=reflection_pb2.ExtensionRequest( + containing_type=_EMPTY_EXTENSIONS_SYMBOL_NAME, + extension_number=125, + ),), + reflection_pb2.ServerReflectionRequest( + file_containing_extension=reflection_pb2.ExtensionRequest( + containing_type='i.donut.exist.co.uk.org.net.me.name.foo', + extension_number=55, + ),), + ) + responses = [] + async for response in self._stub.ServerReflectionInfo(iter(requests)): + responses.append(response) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host='', + file_descriptor_response=reflection_pb2.FileDescriptorResponse( + file_descriptor_proto=(_file_descriptor_to_proto( + empty2_extensions_pb2.DESCRIPTOR),))), + reflection_pb2.ServerReflectionResponse( + valid_host='', + error_response=reflection_pb2.ErrorResponse( + error_code=grpc.StatusCode.NOT_FOUND.value[0], + error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), + )), + ) + self.assertSequenceEqual(expected_responses, responses) + + async def test_extension_numbers_of_type(self): + requests = ( + reflection_pb2.ServerReflectionRequest( + all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME), + reflection_pb2.ServerReflectionRequest( + all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo' + ), + ) + responses = [] + async for response in self._stub.ServerReflectionInfo(iter(requests)): + responses.append(response) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host='', + all_extension_numbers_response=reflection_pb2. + ExtensionNumberResponse( + base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME, + extension_number=_EMPTY_EXTENSIONS_NUMBERS)), + reflection_pb2.ServerReflectionResponse( + valid_host='', + error_response=reflection_pb2.ErrorResponse( + error_code=grpc.StatusCode.NOT_FOUND.value[0], + error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), + )), + ) + self.assertSequenceEqual(expected_responses, responses) + + async def test_list_services(self): + requests = (reflection_pb2.ServerReflectionRequest(list_services='',),) + responses = [] + async for response in self._stub.ServerReflectionInfo(iter(requests)): + responses.append(response) + expected_responses = (reflection_pb2.ServerReflectionResponse( + valid_host='', + list_services_response=reflection_pb2.ListServiceResponse( + service=tuple( + reflection_pb2.ServiceResponse(name=name) + for name in _SERVICE_NAMES))),) + self.assertSequenceEqual(expected_responses, responses) + + def test_reflection_service_name(self): + self.assertEqual(reflection.SERVICE_NAME, + 'grpc.reflection.v1alpha.ServerReflection') + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/status/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/status/__init__.py new file mode 100644 index 00000000000..1517f71d093 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/status/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/status/grpc_status_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/status/grpc_status_test.py new file mode 100644 index 00000000000..980cf5a67e7 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/status/grpc_status_test.py @@ -0,0 +1,175 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_status with gRPC AsyncIO stack.""" + +import logging +import traceback +import unittest + +import grpc +from google.protobuf import any_pb2 +from google.rpc import code_pb2, error_details_pb2, status_pb2 +from grpc.experimental import aio + +from grpc_status import rpc_status +from tests_aio.unit._test_base import AioTestBase + +_STATUS_OK = '/test/StatusOK' +_STATUS_NOT_OK = '/test/StatusNotOk' +_ERROR_DETAILS = '/test/ErrorDetails' +_INCONSISTENT = '/test/Inconsistent' +_INVALID_CODE = '/test/InvalidCode' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + +_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' + +_STATUS_DETAILS = 'This is an error detail' +_STATUS_DETAILS_ANOTHER = 'This is another error detail' + + +async def _ok_unary_unary(request, servicer_context): + return _RESPONSE + + +async def _not_ok_unary_unary(request, servicer_context): + await servicer_context.abort(grpc.StatusCode.INTERNAL, _STATUS_DETAILS) + + +async def _error_details_unary_unary(request, servicer_context): + details = any_pb2.Any() + details.Pack( + error_details_pb2.DebugInfo(stack_entries=traceback.format_stack(), + detail='Intentionally invoked')) + rich_status = status_pb2.Status( + code=code_pb2.INTERNAL, + message=_STATUS_DETAILS, + details=[details], + ) + await servicer_context.abort_with_status(rpc_status.to_status(rich_status)) + + +async def _inconsistent_unary_unary(request, servicer_context): + rich_status = status_pb2.Status( + code=code_pb2.INTERNAL, + message=_STATUS_DETAILS, + ) + servicer_context.set_code(grpc.StatusCode.NOT_FOUND) + servicer_context.set_details(_STATUS_DETAILS_ANOTHER) + # User put inconsistent status information in trailing metadata + servicer_context.set_trailing_metadata( + ((_GRPC_DETAILS_METADATA_KEY, rich_status.SerializeToString()),)) + + +async def _invalid_code_unary_unary(request, servicer_context): + rich_status = status_pb2.Status( + code=42, + message='Invalid code', + ) + await servicer_context.abort_with_status(rpc_status.to_status(rich_status)) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _STATUS_OK: + return grpc.unary_unary_rpc_method_handler(_ok_unary_unary) + elif handler_call_details.method == _STATUS_NOT_OK: + return grpc.unary_unary_rpc_method_handler(_not_ok_unary_unary) + elif handler_call_details.method == _ERROR_DETAILS: + return grpc.unary_unary_rpc_method_handler( + _error_details_unary_unary) + elif handler_call_details.method == _INCONSISTENT: + return grpc.unary_unary_rpc_method_handler( + _inconsistent_unary_unary) + elif handler_call_details.method == _INVALID_CODE: + return grpc.unary_unary_rpc_method_handler( + _invalid_code_unary_unary) + else: + return None + + +class StatusTest(AioTestBase): + + async def setUp(self): + self._server = aio.server() + self._server.add_generic_rpc_handlers((_GenericHandler(),)) + port = self._server.add_insecure_port('[::]:0') + await self._server.start() + + self._channel = aio.insecure_channel('localhost:%d' % port) + + async def tearDown(self): + await self._server.stop(None) + await self._channel.close() + + async def test_status_ok(self): + call = self._channel.unary_unary(_STATUS_OK)(_REQUEST) + + # Succeed RPC doesn't have status + status = await rpc_status.aio.from_call(call) + self.assertIs(status, None) + + async def test_status_not_ok(self): + call = self._channel.unary_unary(_STATUS_NOT_OK)(_REQUEST) + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + # Failed RPC doesn't automatically generate status + status = await rpc_status.aio.from_call(call) + self.assertIs(status, None) + + async def test_error_details(self): + call = self._channel.unary_unary(_ERROR_DETAILS)(_REQUEST) + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + rpc_error = exception_context.exception + + status = await rpc_status.aio.from_call(call) + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(status.code, code_pb2.Code.Value('INTERNAL')) + + # Check if the underlying proto message is intact + self.assertTrue(status.details[0].Is( + error_details_pb2.DebugInfo.DESCRIPTOR)) + info = error_details_pb2.DebugInfo() + status.details[0].Unpack(info) + self.assertIn('_error_details_unary_unary', info.stack_entries[-1]) + + async def test_code_message_validation(self): + call = self._channel.unary_unary(_INCONSISTENT)(_REQUEST) + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND) + + # Code/Message validation failed + with self.assertRaises(ValueError): + await rpc_status.aio.from_call(call) + + async def test_invalid_code(self): + with self.assertRaises(aio.AioRpcError) as exception_context: + await self._channel.unary_unary(_INVALID_CODE)(_REQUEST) + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) + # Invalid status code exception raised during coversion + self.assertIn('Invalid status code', rpc_error.details()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/__init__.py new file mode 100644 index 00000000000..f4b321fc5b2 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_common.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_common.py new file mode 100644 index 00000000000..016280a1528 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -0,0 +1,99 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import grpc +from typing import AsyncIterable +from grpc.experimental import aio +from grpc.aio._typing import MetadatumType, MetadataKey, MetadataValue +from grpc.aio._metadata import Metadata + +from tests.unit.framework.common import test_constants + + +def seen_metadata(expected: Metadata, actual: Metadata): + return not bool(set(tuple(expected)) - set(tuple(actual))) + + +def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue, + actual: Metadata) -> bool: + obtained = actual[expected_key] + return obtained == expected_value + + +async def block_until_certain_state(channel: aio.Channel, + expected_state: grpc.ChannelConnectivity): + state = channel.get_state() + while state != expected_state: + await channel.wait_for_state_change(state) + state = channel.get_state() + + +def inject_callbacks(call: aio.Call): + first_callback_ran = asyncio.Event() + + def first_callback(call): + # Validate that all resopnses have been received + # and the call is an end state. + assert call.done() + first_callback_ran.set() + + second_callback_ran = asyncio.Event() + + def second_callback(call): + # Validate that all responses have been received + # and the call is an end state. + assert call.done() + second_callback_ran.set() + + call.add_done_callback(first_callback) + call.add_done_callback(second_callback) + + async def validation(): + await asyncio.wait_for( + asyncio.gather(first_callback_ran.wait(), + second_callback_ran.wait()), + test_constants.SHORT_TIMEOUT) + + return validation() + + +class CountingRequestIterator: + + def __init__(self, request_iterator): + self.request_cnt = 0 + self._request_iterator = request_iterator + + async def _forward_requests(self): + async for request in self._request_iterator: + self.request_cnt += 1 + yield request + + def __aiter__(self): + return self._forward_requests() + + +class CountingResponseIterator: + + def __init__(self, response_iterator): + self.response_cnt = 0 + self._response_iterator = response_iterator + + async def _forward_responses(self): + async for response in self._response_iterator: + self.response_cnt += 1 + yield response + + def __aiter__(self): + return self._forward_responses() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_constants.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_constants.py new file mode 100644 index 00000000000..986a6f9d842 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_constants.py @@ -0,0 +1,16 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +UNREACHABLE_TARGET = '0.0.0.1:1111' +UNARY_CALL_WITH_SLEEP_VALUE = 0.2 diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py new file mode 100644 index 00000000000..c0594cb06ab --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py @@ -0,0 +1,137 @@ +# Copyright 2020 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the metadata abstraction that's used in the asynchronous driver.""" +import logging +import unittest + +from grpc.experimental.aio import Metadata + + +class TestTypeMetadata(unittest.TestCase): + """Tests for the metadata type""" + + _DEFAULT_DATA = (("key1", "value1"), ("key2", "value2")) + _MULTI_ENTRY_DATA = (("key1", "value1"), ("key1", "other value 1"), + ("key2", "value2")) + + def test_init_metadata(self): + test_cases = { + "emtpy": (), + "with-single-data": self._DEFAULT_DATA, + "with-multi-data": self._MULTI_ENTRY_DATA, + } + for case, args in test_cases.items(): + with self.subTest(case=case): + metadata = Metadata(*args) + self.assertEqual(len(metadata), len(args)) + + def test_get_item(self): + metadata = Metadata(("key", "value1"), ("key", "value2"), + ("key2", "other value")) + self.assertEqual(metadata["key"], "value1") + self.assertEqual(metadata["key2"], "other value") + self.assertEqual(metadata.get("key"), "value1") + self.assertEqual(metadata.get("key2"), "other value") + + with self.assertRaises(KeyError): + metadata["key not found"] + self.assertIsNone(metadata.get("key not found")) + + def test_add_value(self): + metadata = Metadata() + metadata.add("key", "value") + metadata.add("key", "second value") + metadata.add("key2", "value2") + + self.assertEqual(metadata["key"], "value") + self.assertEqual(metadata["key2"], "value2") + + def test_get_all_items(self): + metadata = Metadata(*self._MULTI_ENTRY_DATA) + self.assertEqual(metadata.get_all("key1"), ["value1", "other value 1"]) + self.assertEqual(metadata.get_all("key2"), ["value2"]) + self.assertEqual(metadata.get_all("non existing key"), []) + + def test_container(self): + metadata = Metadata(*self._MULTI_ENTRY_DATA) + self.assertIn("key1", metadata) + + def test_equals(self): + metadata = Metadata() + for key, value in self._DEFAULT_DATA: + metadata.add(key, value) + metadata2 = Metadata(*self._DEFAULT_DATA) + + self.assertEqual(metadata, metadata2) + self.assertNotEqual(metadata, "foo") + + def test_repr(self): + metadata = Metadata(*self._DEFAULT_DATA) + expected = "Metadata({0!r})".format(self._DEFAULT_DATA) + self.assertEqual(repr(metadata), expected) + + def test_set(self): + metadata = Metadata(*self._MULTI_ENTRY_DATA) + override_value = "override value" + for _ in range(3): + metadata["key1"] = override_value + + self.assertEqual(metadata["key1"], override_value) + self.assertEqual(metadata.get_all("key1"), + [override_value, "other value 1"]) + + empty_metadata = Metadata() + for _ in range(3): + empty_metadata["key"] = override_value + + self.assertEqual(empty_metadata["key"], override_value) + self.assertEqual(empty_metadata.get_all("key"), [override_value]) + + def test_set_all(self): + metadata = Metadata(*self._DEFAULT_DATA) + metadata.set_all("key", ["value1", b"new value 2"]) + + self.assertEqual(metadata["key"], "value1") + self.assertEqual(metadata.get_all("key"), ["value1", b"new value 2"]) + + def test_delete_values(self): + metadata = Metadata(*self._MULTI_ENTRY_DATA) + del metadata["key1"] + self.assertEqual(metadata.get("key1"), "other value 1") + + metadata.delete_all("key1") + self.assertNotIn("key1", metadata) + + metadata.delete_all("key2") + self.assertEqual(len(metadata), 0) + + with self.assertRaises(KeyError): + del metadata["other key"] + + def test_metadata_from_tuple(self): + scenarios = ( + (None, Metadata()), + (Metadata(), Metadata()), + (self._DEFAULT_DATA, Metadata(*self._DEFAULT_DATA)), + (self._MULTI_ENTRY_DATA, Metadata(*self._MULTI_ENTRY_DATA)), + (Metadata(*self._DEFAULT_DATA), Metadata(*self._DEFAULT_DATA)), + ) + for source, expected in scenarios: + with self.subTest(raw_metadata=source, expected=expected): + self.assertEqual(expected, Metadata.from_tuple(source)) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_test_base.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_test_base.py new file mode 100644 index 00000000000..ec5f2112da0 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_test_base.py @@ -0,0 +1,66 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import functools +import asyncio +from typing import Callable +import unittest +from grpc.experimental import aio + +__all__ = 'AioTestBase' + +_COROUTINE_FUNCTION_ALLOWLIST = ['setUp', 'tearDown'] + + +def _async_to_sync_decorator(f: Callable, loop: asyncio.AbstractEventLoop): + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return loop.run_until_complete(f(*args, **kwargs)) + + return wrapper + + +def _get_default_loop(debug=True): + try: + loop = asyncio.get_event_loop() + except: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + finally: + loop.set_debug(debug) + return loop + + +# NOTE(gnossen) this test class can also be implemented with metaclass. +class AioTestBase(unittest.TestCase): + # NOTE(lidi) We need to pick a loop for entire testing phase, otherwise it + # will trigger create new loops in new threads, leads to deadlock. + _TEST_LOOP = _get_default_loop() + + @property + def loop(self): + return self._TEST_LOOP + + def __getattribute__(self, name): + """Overrides the loading logic to support coroutine functions.""" + attr = super().__getattribute__(name) + + # If possible, converts the coroutine into a sync function. + if name.startswith('test_') or name in _COROUTINE_FUNCTION_ALLOWLIST: + if asyncio.iscoroutinefunction(attr): + return _async_to_sync_decorator(attr, self._TEST_LOOP) + # For other attributes, let them pass. + return attr diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_test_server.py new file mode 100644 index 00000000000..5e5081a38d0 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -0,0 +1,143 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime + +import grpc +from grpc.experimental import aio +from tests.unit import resources + +from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc +from tests_aio.unit import _constants + +_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial" +_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin" + + +async def _maybe_echo_metadata(servicer_context): + """Copies metadata from request to response if it is present.""" + invocation_metadata = dict(servicer_context.invocation_metadata()) + if _INITIAL_METADATA_KEY in invocation_metadata: + initial_metadatum = (_INITIAL_METADATA_KEY, + invocation_metadata[_INITIAL_METADATA_KEY]) + await servicer_context.send_initial_metadata((initial_metadatum,)) + if _TRAILING_METADATA_KEY in invocation_metadata: + trailing_metadatum = (_TRAILING_METADATA_KEY, + invocation_metadata[_TRAILING_METADATA_KEY]) + servicer_context.set_trailing_metadata((trailing_metadatum,)) + + +async def _maybe_echo_status(request: messages_pb2.SimpleRequest, + servicer_context): + """Echos the RPC status if demanded by the request.""" + if request.HasField('response_status'): + await servicer_context.abort(request.response_status.code, + request.response_status.message) + + +class TestServiceServicer(test_pb2_grpc.TestServiceServicer): + + async def UnaryCall(self, request, context): + await _maybe_echo_metadata(context) + await _maybe_echo_status(request, context) + return messages_pb2.SimpleResponse( + payload=messages_pb2.Payload(type=messages_pb2.COMPRESSABLE, + body=b'\x00' * request.response_size)) + + async def EmptyCall(self, request, context): + return empty_pb2.Empty() + + async def StreamingOutputCall( + self, request: messages_pb2.StreamingOutputCallRequest, + unused_context): + for response_parameters in request.response_parameters: + if response_parameters.interval_us != 0: + await asyncio.sleep( + datetime.timedelta(microseconds=response_parameters. + interval_us).total_seconds()) + yield messages_pb2.StreamingOutputCallResponse( + payload=messages_pb2.Payload(type=request.response_type, + body=b'\x00' * + response_parameters.size)) + + # Next methods are extra ones that are registred programatically + # when the sever is instantiated. They are not being provided by + # the proto file. + async def UnaryCallWithSleep(self, unused_request, unused_context): + await asyncio.sleep(_constants.UNARY_CALL_WITH_SLEEP_VALUE) + return messages_pb2.SimpleResponse() + + async def StreamingInputCall(self, request_async_iterator, unused_context): + aggregate_size = 0 + async for request in request_async_iterator: + if request.payload is not None and request.payload.body: + aggregate_size += len(request.payload.body) + return messages_pb2.StreamingInputCallResponse( + aggregated_payload_size=aggregate_size) + + async def FullDuplexCall(self, request_async_iterator, context): + await _maybe_echo_metadata(context) + async for request in request_async_iterator: + await _maybe_echo_status(request, context) + for response_parameters in request.response_parameters: + if response_parameters.interval_us != 0: + await asyncio.sleep( + datetime.timedelta(microseconds=response_parameters. + interval_us).total_seconds()) + yield messages_pb2.StreamingOutputCallResponse( + payload=messages_pb2.Payload(type=request.payload.type, + body=b'\x00' * + response_parameters.size)) + + +def _create_extra_generic_handler(servicer: TestServiceServicer): + # Add programatically extra methods not provided by the proto file + # that are used during the tests + rpc_method_handlers = { + 'UnaryCallWithSleep': + grpc.unary_unary_rpc_method_handler( + servicer.UnaryCallWithSleep, + request_deserializer=messages_pb2.SimpleRequest.FromString, + response_serializer=messages_pb2.SimpleResponse. + SerializeToString) + } + return grpc.method_handlers_generic_handler('grpc.testing.TestService', + rpc_method_handlers) + + +async def start_test_server(port=0, + secure=False, + server_credentials=None, + interceptors=None): + server = aio.server(options=(('grpc.so_reuseport', 0),), + interceptors=interceptors) + servicer = TestServiceServicer() + test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server) + + server.add_generic_rpc_handlers((_create_extra_generic_handler(servicer),)) + + if secure: + if server_credentials is None: + server_credentials = grpc.ssl_server_credentials([ + (resources.private_key(), resources.certificate_chain()) + ]) + port = server.add_secure_port('[::]:%d' % port, server_credentials) + else: + port = server.add_insecure_port('[::]:%d' % port) + + await server.start() + + # NOTE(lidizheng) returning the server to prevent it from deallocation + return 'localhost:%d' % port, server diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/abort_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/abort_test.py new file mode 100644 index 00000000000..828b6884dfa --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/abort_test.py @@ -0,0 +1,151 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import unittest +import time +import gc + +import grpc +from grpc.experimental import aio +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants + +_UNARY_UNARY_ABORT = '/test/UnaryUnaryAbort' +_SUPPRESS_ABORT = '/test/SuppressAbort' +_REPLACE_ABORT = '/test/ReplaceAbort' +_ABORT_AFTER_REPLY = '/test/AbortAfterReply' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' +_NUM_STREAM_RESPONSES = 5 + +_ABORT_CODE = grpc.StatusCode.RESOURCE_EXHAUSTED +_ABORT_DETAILS = 'Dummy error details' + + +class _GenericHandler(grpc.GenericRpcHandler): + + @staticmethod + async def _unary_unary_abort(unused_request, context): + await context.abort(_ABORT_CODE, _ABORT_DETAILS) + raise RuntimeError('This line should not be executed') + + @staticmethod + async def _suppress_abort(unused_request, context): + try: + await context.abort(_ABORT_CODE, _ABORT_DETAILS) + except aio.AbortError as e: + pass + return _RESPONSE + + @staticmethod + async def _replace_abort(unused_request, context): + try: + await context.abort(_ABORT_CODE, _ABORT_DETAILS) + except aio.AbortError as e: + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, + 'Override abort!') + + @staticmethod + async def _abort_after_reply(unused_request, context): + yield _RESPONSE + await context.abort(_ABORT_CODE, _ABORT_DETAILS) + raise RuntimeError('This line should not be executed') + + def service(self, handler_details): + if handler_details.method == _UNARY_UNARY_ABORT: + return grpc.unary_unary_rpc_method_handler(self._unary_unary_abort) + if handler_details.method == _SUPPRESS_ABORT: + return grpc.unary_unary_rpc_method_handler(self._suppress_abort) + if handler_details.method == _REPLACE_ABORT: + return grpc.unary_unary_rpc_method_handler(self._replace_abort) + if handler_details.method == _ABORT_AFTER_REPLY: + return grpc.unary_stream_rpc_method_handler(self._abort_after_reply) + + +async def _start_test_server(): + server = aio.server() + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((_GenericHandler(),)) + await server.start() + return 'localhost:%d' % port, server + + +class TestAbort(AioTestBase): + + async def setUp(self): + address, self._server = await _start_test_server() + self._channel = aio.insecure_channel(address) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_unary_unary_abort(self): + method = self._channel.unary_unary(_UNARY_UNARY_ABORT) + call = method(_REQUEST) + + self.assertEqual(_ABORT_CODE, await call.code()) + self.assertEqual(_ABORT_DETAILS, await call.details()) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + rpc_error = exception_context.exception + self.assertEqual(_ABORT_CODE, rpc_error.code()) + self.assertEqual(_ABORT_DETAILS, rpc_error.details()) + + async def test_suppress_abort(self): + method = self._channel.unary_unary(_SUPPRESS_ABORT) + call = method(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + rpc_error = exception_context.exception + self.assertEqual(_ABORT_CODE, rpc_error.code()) + self.assertEqual(_ABORT_DETAILS, rpc_error.details()) + + async def test_replace_abort(self): + method = self._channel.unary_unary(_REPLACE_ABORT) + call = method(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + rpc_error = exception_context.exception + self.assertEqual(_ABORT_CODE, rpc_error.code()) + self.assertEqual(_ABORT_DETAILS, rpc_error.details()) + + async def test_abort_after_reply(self): + method = self._channel.unary_stream(_ABORT_AFTER_REPLY) + call = method(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.read() + await call.read() + + rpc_error = exception_context.exception + self.assertEqual(_ABORT_CODE, rpc_error.code()) + self.assertEqual(_ABORT_DETAILS, rpc_error.details()) + + self.assertEqual(_ABORT_CODE, await call.code()) + self.assertEqual(_ABORT_DETAILS, await call.details()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py new file mode 100644 index 00000000000..b7b18e08f6e --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py @@ -0,0 +1,52 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests AioRpcError class.""" + +import logging +import unittest + +import grpc + +from grpc.experimental import aio +from grpc.aio._call import AioRpcError +from tests_aio.unit._test_base import AioTestBase + +_TEST_INITIAL_METADATA = aio.Metadata( + ('initial metadata key', 'initial metadata value')) +_TEST_TRAILING_METADATA = aio.Metadata( + ('trailing metadata key', 'trailing metadata value')) +_TEST_DEBUG_ERROR_STRING = '{This is a debug string}' + + +class TestAioRpcError(unittest.TestCase): + + def test_attributes(self): + aio_rpc_error = AioRpcError(grpc.StatusCode.CANCELLED, + initial_metadata=_TEST_INITIAL_METADATA, + trailing_metadata=_TEST_TRAILING_METADATA, + details="details", + debug_error_string=_TEST_DEBUG_ERROR_STRING) + self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(aio_rpc_error.details(), 'details') + self.assertEqual(aio_rpc_error.initial_metadata(), + _TEST_INITIAL_METADATA) + self.assertEqual(aio_rpc_error.trailing_metadata(), + _TEST_TRAILING_METADATA) + self.assertEqual(aio_rpc_error.debug_error_string(), + _TEST_DEBUG_ERROR_STRING) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py new file mode 100644 index 00000000000..fb303714682 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py @@ -0,0 +1,194 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Porting auth context tests from sync stack.""" + +import pickle +import unittest +import logging + +import grpc +from grpc.experimental import aio +from grpc.experimental import session_cache +import six + +from tests.unit import resources +from tests_aio.unit._test_base import AioTestBase + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_CLIENT_IDS = ( + b'*.test.google.fr', + b'waterzooi.test.google.be', + b'*.test.youtube.com', + b'192.168.1.3', +) +_ID = 'id' +_ID_KEY = 'id_key' +_AUTH_CTX = 'auth_ctx' + +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() +_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) +_PROPERTY_OPTIONS = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, +),) + + +async def handle_unary_unary(unused_request: bytes, + servicer_context: aio.ServicerContext): + return pickle.dumps({ + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context() + }) + + +class TestAuthContext(AioTestBase): + + async def test_insecure(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = aio.server() + server.add_generic_rpc_handlers((handler,)) + port = server.add_insecure_port('[::]:0') + await server.start() + + async with aio.insecure_channel('localhost:%d' % port) as channel: + response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) + await server.stop(None) + + auth_data = pickle.loads(response) + self.assertIsNone(auth_data[_ID]) + self.assertIsNone(auth_data[_ID_KEY]) + self.assertDictEqual({}, auth_data[_AUTH_CTX]) + + async def test_secure_no_cert(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = aio.server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + await server.start() + + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel = aio.secure_channel('localhost:{}'.format(port), + channel_creds, + options=_PROPERTY_OPTIONS) + response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) + await channel.close() + await server.stop(None) + + auth_data = pickle.loads(response) + self.assertIsNone(auth_data[_ID]) + self.assertIsNone(auth_data[_ID_KEY]) + self.assertDictEqual( + { + 'security_level': [b'TSI_PRIVACY_AND_INTEGRITY'], + 'transport_security_type': [b'ssl'], + 'ssl_session_reused': [b'false'], + }, auth_data[_AUTH_CTX]) + + async def test_secure_client_cert(self): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = aio.server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials( + _SERVER_CERTS, + root_certificates=_TEST_ROOT_CERTIFICATES, + require_client_auth=True) + port = server.add_secure_port('[::]:0', server_cred) + await server.start() + + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES, + private_key=_PRIVATE_KEY, + certificate_chain=_CERTIFICATE_CHAIN) + channel = aio.secure_channel('localhost:{}'.format(port), + channel_creds, + options=_PROPERTY_OPTIONS) + + response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) + await channel.close() + await server.stop(None) + + auth_data = pickle.loads(response) + auth_ctx = auth_data[_AUTH_CTX] + self.assertCountEqual(_CLIENT_IDS, auth_data[_ID]) + self.assertEqual('x509_subject_alternative_name', auth_data[_ID_KEY]) + self.assertSequenceEqual([b'ssl'], auth_ctx['transport_security_type']) + self.assertSequenceEqual([b'*.test.google.com'], + auth_ctx['x509_common_name']) + + async def _do_one_shot_client_rpc(self, channel_creds, channel_options, + port, expect_ssl_session_reused): + channel = aio.secure_channel('localhost:{}'.format(port), + channel_creds, + options=channel_options) + response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) + auth_data = pickle.loads(response) + self.assertEqual(expect_ssl_session_reused, + auth_data[_AUTH_CTX]['ssl_session_reused']) + await channel.close() + + async def test_session_resumption(self): + # Set up a secure server + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = aio.server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + await server.start() + + # Create a cache for TLS session tickets + cache = session_cache.ssl_session_cache_lru(1) + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel_options = _PROPERTY_OPTIONS + ( + ('grpc.ssl_session_cache', cache),) + + # Initial connection has no session to resume + await self._do_one_shot_client_rpc(channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'false']) + + # Subsequent connections resume sessions + await self._do_one_shot_client_rpc(channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'true']) + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/call_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/call_test.py new file mode 100644 index 00000000000..1961226fa6d --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -0,0 +1,814 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior of the Call classes.""" + +import asyncio +import logging +import unittest +import datetime + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._constants import UNREACHABLE_TARGET + +_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds() + +_NUM_STREAM_RESPONSES = 5 +_RESPONSE_PAYLOAD_SIZE = 42 +_REQUEST_PAYLOAD_SIZE = 7 +_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' +_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) +_INFINITE_INTERVAL_US = 2**31 - 1 + + +class _MulticallableTestMixin(): + + async def setUp(self): + address, self._server = await start_test_server() + self._channel = aio.insecure_channel(address) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + +class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): + + async def test_call_to_string(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + self.assertTrue(str(call) is not None) + self.assertTrue(repr(call) is not None) + + await call + + self.assertTrue(str(call) is not None) + self.assertTrue(repr(call) is not None) + + async def test_call_ok(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + self.assertFalse(call.done()) + + response = await call + + self.assertTrue(call.done()) + self.assertIsInstance(response, messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + # Response is cached at call object level, reentrance + # returns again the same response + response_retry = await call + self.assertIs(response, response_retry) + + async def test_call_rpc_error(self): + async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + call = stub.UnaryCall(messages_pb2.SimpleRequest()) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + + async def test_call_code_awaitable(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_call_details_awaitable(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + self.assertEqual('', await call.details()) + + async def test_call_initial_metadata_awaitable(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + self.assertEqual(aio.Metadata(), await call.initial_metadata()) + + async def test_call_trailing_metadata_awaitable(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + self.assertEqual(aio.Metadata(), await call.trailing_metadata()) + + async def test_call_initial_metadata_cancelable(self): + coro_started = asyncio.Event() + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + async def coro(): + coro_started.set() + await call.initial_metadata() + + task = self.loop.create_task(coro()) + await coro_started.wait() + task.cancel() + + # Test that initial metadata can still be asked thought + # a cancellation happened with the previous task + self.assertEqual(aio.Metadata(), await call.initial_metadata()) + + async def test_call_initial_metadata_multiple_waiters(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + async def coro(): + return await call.initial_metadata() + + task1 = self.loop.create_task(coro()) + task2 = self.loop.create_task(coro()) + + await call + expected = [aio.Metadata() for _ in range(2)] + self.assertEqual(expected, await asyncio.gather(*[task1, task2])) + + async def test_call_code_cancelable(self): + coro_started = asyncio.Event() + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + async def coro(): + coro_started.set() + await call.code() + + task = self.loop.create_task(coro()) + await coro_started.wait() + task.cancel() + + # Test that code can still be asked thought + # a cancellation happened with the previous task + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_call_code_multiple_waiters(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + async def coro(): + return await call.code() + + task1 = self.loop.create_task(coro()) + task2 = self.loop.create_task(coro()) + + await call + + self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await + asyncio.gather(task1, task2)) + + async def test_cancel_unary_unary(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + + self.assertTrue(call.cancel()) + self.assertFalse(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call + + # The info in the RpcError should match the info in Call object. + self.assertTrue(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + 'Locally cancelled by application!') + + async def test_cancel_unary_unary_in_task(self): + coro_started = asyncio.Event() + call = self._stub.EmptyCall(messages_pb2.SimpleRequest()) + + async def another_coro(): + coro_started.set() + await call + + task = self.loop.create_task(another_coro()) + await coro_started.wait() + + self.assertFalse(task.done()) + task.cancel() + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + with self.assertRaises(asyncio.CancelledError): + await task + + async def test_passing_credentials_fails_over_insecure_channel(self): + call_credentials = grpc.composite_call_credentials( + grpc.access_token_call_credentials("abc"), + grpc.access_token_call_credentials("def"), + ) + with self.assertRaisesRegex( + aio.UsageError, + "Call credentials are only valid on secure channels"): + self._stub.UnaryCall(messages_pb2.SimpleRequest(), + credentials=call_credentials) + + +class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase): + + async def test_call_rpc_error(self): + channel = aio.insecure_channel(UNREACHABLE_TARGET) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(aio.AioRpcError) as exception_context: + async for response in call: + pass + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + await channel.close() + + async def test_cancel_unary_stream(self): + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = self._stub.StreamingOutputCall(request) + self.assertFalse(call.cancelled()) + + response = await call.read() + self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertTrue(call.cancel()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await + call.details()) + self.assertFalse(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call.read() + self.assertTrue(call.cancelled()) + + async def test_multiple_cancel_unary_stream(self): + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = self._stub.StreamingOutputCall(request) + self.assertFalse(call.cancelled()) + + response = await call.read() + self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertTrue(call.cancel()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call.read() + + async def test_early_cancel_unary_stream(self): + """Test cancellation before receiving messages.""" + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = self._stub.StreamingOutputCall(request) + + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertFalse(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call.read() + + self.assertTrue(call.cancelled()) + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await + call.details()) + + async def test_late_cancel_unary_stream(self): + """Test cancellation after received all messages.""" + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + + # Invokes the actual RPC + call = self._stub.StreamingOutputCall(request) + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + # After all messages received, it is possible that the final state + # is received or on its way. It's basically a data race, so our + # expectation here is do not crash :) + call.cancel() + self.assertIn(await call.code(), + [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED]) + + async def test_too_many_reads_unary_stream(self): + """Test calling read after received all messages fails.""" + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + + # Invokes the actual RPC + call = self._stub.StreamingOutputCall(request) + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + self.assertIs(await call.read(), aio.EOF) + + # After the RPC is finished, further reads will lead to exception. + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertIs(await call.read(), aio.EOF) + + async def test_unary_stream_async_generator(self): + """Sunny day test case for unary_stream.""" + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + + # Invokes the actual RPC + call = self._stub.StreamingOutputCall(request) + self.assertFalse(call.cancelled()) + + async for response in call: + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_cancel_unary_stream_in_task_using_read(self): + coro_started = asyncio.Event() + + # Configs the server method to block forever + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_INFINITE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = self._stub.StreamingOutputCall(request) + + async def another_coro(): + coro_started.set() + await call.read() + + task = self.loop.create_task(another_coro()) + await coro_started.wait() + + self.assertFalse(task.done()) + task.cancel() + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + with self.assertRaises(asyncio.CancelledError): + await task + + async def test_cancel_unary_stream_in_task_using_async_for(self): + coro_started = asyncio.Event() + + # Configs the server method to block forever + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_INFINITE_INTERVAL_US, + )) + + # Invokes the actual RPC + call = self._stub.StreamingOutputCall(request) + + async def another_coro(): + coro_started.set() + async for _ in call: + pass + + task = self.loop.create_task(another_coro()) + await coro_started.wait() + + self.assertFalse(task.done()) + task.cancel() + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + with self.assertRaises(asyncio.CancelledError): + await task + + async def test_time_remaining(self): + request = messages_pb2.StreamingOutputCallRequest() + # First message comes back immediately + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + # Second message comes back after a unit of wait time + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + call = self._stub.StreamingOutputCall(request, + timeout=_SHORT_TIMEOUT_S * 2) + + response = await call.read() + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + # Should be around the same as the timeout + remained_time = call.time_remaining() + self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2) + self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2) + + response = await call.read() + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + # Should be around the timeout minus a unit of wait time + remained_time = call.time_remaining() + self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2) + self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2) + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + +class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase): + + async def test_cancel_stream_unary(self): + call = self._stub.StreamingInputCall() + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Sends out requests + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + + await call.done_writing() + + with self.assertRaises(asyncio.CancelledError): + await call + + async def test_early_cancel_stream_unary(self): + call = self._stub.StreamingInputCall() + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + + with self.assertRaises(asyncio.InvalidStateError): + await call.write(messages_pb2.StreamingInputCallRequest()) + + # Should be no-op + await call.done_writing() + + with self.assertRaises(asyncio.CancelledError): + await call + + async def test_write_after_done_writing(self): + call = self._stub.StreamingInputCall() + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Sends out requests + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + + # Should be no-op + await call.done_writing() + + with self.assertRaises(asyncio.InvalidStateError): + await call.write(messages_pb2.StreamingInputCallRequest()) + + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_error_in_async_generator(self): + # Server will pause between responses + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # We expect the request iterator to receive the exception + request_iterator_received_the_exception = asyncio.Event() + + async def request_iterator(): + with self.assertRaises(asyncio.CancelledError): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + await asyncio.sleep(_SHORT_TIMEOUT_S) + request_iterator_received_the_exception.set() + + call = self._stub.StreamingInputCall(request_iterator()) + + # Cancel the RPC after at least one response + async def cancel_later(): + await asyncio.sleep(_SHORT_TIMEOUT_S * 2) + call.cancel() + + cancel_later_task = self.loop.create_task(cancel_later()) + + with self.assertRaises(asyncio.CancelledError): + await call + + await request_iterator_received_the_exception.wait() + + # No failures in the cancel later task! + await cancel_later_task + + async def test_normal_iterable_requests(self): + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + requests = [request] * _NUM_STREAM_RESPONSES + + # Sends out requests + call = self._stub.StreamingInputCall(requests) + + # RPC should succeed + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_call_rpc_error(self): + async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + # The error should be raised automatically without any traffic. + call = stub.StreamingInputCall() + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + + async def test_timeout(self): + call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S) + + # The error should be raised automatically without any traffic. + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code()) + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code()) + + +# Prepares the request that stream in a ping-pong manner. +_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() +_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + +class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase): + + async def test_cancel(self): + # Invokes the actual RPC + call = self._stub.FullDuplexCall() + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + response = await call.read() + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_cancel_with_pending_read(self): + call = self._stub.FullDuplexCall() + + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_cancel_with_ongoing_read(self): + call = self._stub.FullDuplexCall() + coro_started = asyncio.Event() + + async def read_coro(): + coro_started.set() + await call.read() + + read_task = self.loop.create_task(read_coro()) + await coro_started.wait() + self.assertFalse(read_task.done()) + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_early_cancel(self): + call = self._stub.FullDuplexCall() + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_cancel_after_done_writing(self): + call = self._stub.FullDuplexCall() + await call.done_writing() + + # Cancels the RPC + self.assertFalse(call.done()) + self.assertFalse(call.cancelled()) + self.assertTrue(call.cancel()) + self.assertTrue(call.cancelled()) + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_late_cancel(self): + call = self._stub.FullDuplexCall() + await call.done_writing() + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + # Cancels the RPC + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancelled()) + + # Status is still OK + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_async_generator(self): + + async def request_generator(): + yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE + yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE + + call = self._stub.FullDuplexCall(request_generator()) + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_too_many_reads(self): + + async def request_generator(): + for _ in range(_NUM_STREAM_RESPONSES): + yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE + + call = self._stub.FullDuplexCall(request_generator()) + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + self.assertIs(await call.read(), aio.EOF) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + # After the RPC finished, the read should also produce EOF + self.assertIs(await call.read(), aio.EOF) + + async def test_read_write_after_done_writing(self): + call = self._stub.FullDuplexCall() + + # Writes two requests, and pending two requests + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + await call.done_writing() + + # Further write should fail + with self.assertRaises(asyncio.InvalidStateError): + await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) + + # But read should be unaffected + response = await call.read() + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + response = await call.read() + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_error_in_async_generator(self): + # Server will pause between responses + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + interval_us=_RESPONSE_INTERVAL_US, + )) + + # We expect the request iterator to receive the exception + request_iterator_received_the_exception = asyncio.Event() + + async def request_iterator(): + with self.assertRaises(asyncio.CancelledError): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + await asyncio.sleep(_SHORT_TIMEOUT_S) + request_iterator_received_the_exception.set() + + call = self._stub.FullDuplexCall(request_iterator()) + + # Cancel the RPC after at least one response + async def cancel_later(): + await asyncio.sleep(_SHORT_TIMEOUT_S * 2) + call.cancel() + + cancel_later_task = self.loop.create_task(cancel_later()) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + await request_iterator_received_the_exception.wait() + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + # No failures in the cancel later task! + await cancel_later_task + + async def test_normal_iterable_requests(self): + requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES + + call = self._stub.FullDuplexCall(iter(requests)) + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py new file mode 100644 index 00000000000..8bf2dc8b1f1 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py @@ -0,0 +1,176 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior around the Core channel arguments.""" + +import asyncio +import logging +import platform +import random +import errno +import unittest + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests.unit.framework import common +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + +_RANDOM_SEED = 42 + +_ENABLE_REUSE_PORT = 'SO_REUSEPORT enabled' +_DISABLE_REUSE_PORT = 'SO_REUSEPORT disabled' +_SOCKET_OPT_SO_REUSEPORT = 'grpc.so_reuseport' +_OPTIONS = ( + (_ENABLE_REUSE_PORT, ((_SOCKET_OPT_SO_REUSEPORT, 1),)), + (_DISABLE_REUSE_PORT, ((_SOCKET_OPT_SO_REUSEPORT, 0),)), +) + +_NUM_SERVER_CREATED = 5 + +_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH = 'grpc.max_receive_message_length' +_MAX_MESSAGE_LENGTH = 1024 + +_ADDRESS_TOKEN_ERRNO = errno.EADDRINUSE, errno.ENOSR + + +class _TestPointerWrapper(object): + + def __int__(self): + return 123456 + + +_TEST_CHANNEL_ARGS = ( + ('arg1', b'bytes_val'), + ('arg2', 'str_val'), + ('arg3', 1), + (b'arg4', 'str_val'), + ('arg6', _TestPointerWrapper()), +) + +_INVALID_TEST_CHANNEL_ARGS = [ + { + 'foo': 'bar' + }, + (('key',),), + 'str', +] + + +async def test_if_reuse_port_enabled(server: aio.Server): + port = server.add_insecure_port('localhost:0') + await server.start() + + try: + with common.bound_socket( + bind_address='localhost', + port=port, + listen=False, + ) as (unused_host, bound_port): + assert bound_port == port + except OSError as e: + if e.errno in _ADDRESS_TOKEN_ERRNO: + return False + else: + logging.exception(e) + raise + else: + return True + + +class TestChannelArgument(AioTestBase): + + async def setUp(self): + random.seed(_RANDOM_SEED) + + @unittest.skipIf(platform.system() == 'Windows', + 'SO_REUSEPORT only available in Linux-like OS.') + async def test_server_so_reuse_port_is_set_properly(self): + + async def test_body(): + fact, options = random.choice(_OPTIONS) + server = aio.server(options=options) + try: + result = await test_if_reuse_port_enabled(server) + if fact == _ENABLE_REUSE_PORT and not result: + self.fail( + 'Enabled reuse port in options, but not observed in socket' + ) + elif fact == _DISABLE_REUSE_PORT and result: + self.fail( + 'Disabled reuse port in options, but observed in socket' + ) + finally: + await server.stop(None) + + # Creating a lot of servers concurrently + await asyncio.gather(*(test_body() for _ in range(_NUM_SERVER_CREATED))) + + async def test_client(self): + # Do not segfault, or raise exception! + channel = aio.insecure_channel('[::]:0', options=_TEST_CHANNEL_ARGS) + await channel.close() + + async def test_server(self): + # Do not segfault, or raise exception! + server = aio.server(options=_TEST_CHANNEL_ARGS) + await server.stop(None) + + async def test_invalid_client_args(self): + for invalid_arg in _INVALID_TEST_CHANNEL_ARGS: + self.assertRaises((ValueError, TypeError), + aio.insecure_channel, + '[::]:0', + options=invalid_arg) + + async def test_max_message_length_applied(self): + address, server = await start_test_server() + + async with aio.insecure_channel( + address, + options=((_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, + _MAX_MESSAGE_LENGTH),)) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + + request = messages_pb2.StreamingOutputCallRequest() + # First request will pass + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_MAX_MESSAGE_LENGTH // 2,)) + # Second request should fail + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_MAX_MESSAGE_LENGTH * 2,)) + + call = stub.StreamingOutputCall(request) + + response = await call.read() + self.assertEqual(_MAX_MESSAGE_LENGTH // 2, + len(response.payload.body)) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.read() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, + rpc_error.code()) + self.assertIn(str(_MAX_MESSAGE_LENGTH), rpc_error.details()) + + self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, await + call.code()) + + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py new file mode 100644 index 00000000000..75e4703d869 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py @@ -0,0 +1,69 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the channel_ready function.""" + +import asyncio +import gc +import logging +import socket +import time +import unittest + +import grpc +from grpc.experimental import aio + +from tests.unit.framework.common import get_socket, test_constants +from tests_aio.unit import _common +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + + +class TestChannelReady(AioTestBase): + + async def setUp(self): + address, self._port, self._socket = get_socket( + listen=False, sock_options=(socket.SO_REUSEADDR,)) + self._channel = aio.insecure_channel(f"{address}:{self._port}") + self._socket.close() + + async def tearDown(self): + await self._channel.close() + + async def test_channel_ready_success(self): + # Start `channel_ready` as another Task + channel_ready_task = self.loop.create_task( + self._channel.channel_ready()) + + # Wait for TRANSIENT_FAILURE + await _common.block_until_certain_state( + self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE) + + try: + # Start the server + _, server = await start_test_server(port=self._port) + + # The RPC should recover itself + await channel_ready_task + finally: + await server.stop(None) + + async def test_channel_ready_blocked(self): + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(self._channel.channel_ready(), + test_constants.SHORT_TIMEOUT) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_test.py new file mode 100644 index 00000000000..58cd555491d --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -0,0 +1,230 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior of the grpc.aio.Channel class.""" + +import logging +import os +import unittest + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests.unit.framework.common import test_constants +from tests_aio.unit._constants import (UNARY_CALL_WITH_SLEEP_VALUE, + UNREACHABLE_TARGET) +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + +_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' +_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' +_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' + +_INVOCATION_METADATA = ( + ('x-grpc-test-echo-initial', 'initial-md-value'), + ('x-grpc-test-echo-trailing-bin', b'\x00\x02'), +) + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + + +class TestChannel(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_async_context(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + _UNARY_CALL_METHOD, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + await hi(messages_pb2.SimpleRequest()) + + async def test_unary_unary(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + _UNARY_CALL_METHOD, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + response = await hi(messages_pb2.SimpleRequest()) + + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + async def test_unary_call_times_out(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + _UNARY_CALL_METHOD_WITH_SLEEP, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + + with self.assertRaises(grpc.RpcError) as exception_context: + await hi(messages_pb2.SimpleRequest(), + timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) + + _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code()) + self.assertEqual(details.title(), + exception_context.exception.details()) + self.assertIsNotNone(exception_context.exception.initial_metadata()) + self.assertIsNotNone( + exception_context.exception.trailing_metadata()) + + @unittest.skipIf(os.name == 'nt', + 'TODO: https://github.com/grpc/grpc/issues/21658') + async def test_unary_call_does_not_times_out(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + _UNARY_CALL_METHOD_WITH_SLEEP, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + + call = hi(messages_pb2.SimpleRequest(), + timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_unary_stream(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + # Invokes the actual RPC + call = stub.StreamingOutputCall(request) + + # Validates the responses + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + await channel.close() + + async def test_stream_unary_using_write(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Invokes the actual RPC + call = stub.StreamingInputCall() + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Sends out requests + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + await call.done_writing() + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + await channel.close() + + async def test_stream_unary_using_async_gen(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + # Invokes the actual RPC + call = stub.StreamingInputCall(gen()) + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + await channel.close() + + async def test_stream_stream_using_read_write(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Invokes the actual RPC + call = stub.FullDuplexCall() + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + await channel.close() + + async def test_stream_stream_using_async_gen(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + # Invokes the actual RPC + call = stub.FullDuplexCall(gen()) + + async for response in call: + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + await channel.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py new file mode 100644 index 00000000000..ce6a7bc04d6 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py @@ -0,0 +1,202 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import unittest + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._common import CountingResponseIterator, CountingRequestIterator +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + +_NUM_STREAM_RESPONSES = 5 +_NUM_STREAM_REQUESTS = 5 +_RESPONSE_PAYLOAD_SIZE = 7 + + +class _StreamStreamInterceptorEmpty(aio.StreamStreamClientInterceptor): + + async def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + return await continuation(client_call_details, request_iterator) + + def assert_in_final_state(self, test: unittest.TestCase): + pass + + +class _StreamStreamInterceptorWithRequestAndResponseIterator( + aio.StreamStreamClientInterceptor): + + async def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + self.request_iterator = CountingRequestIterator(request_iterator) + call = await continuation(client_call_details, self.request_iterator) + self.response_iterator = CountingResponseIterator(call) + return self.response_iterator + + def assert_in_final_state(self, test: unittest.TestCase): + test.assertEqual(_NUM_STREAM_REQUESTS, + self.request_iterator.request_cnt) + test.assertEqual(_NUM_STREAM_RESPONSES, + self.response_iterator.response_cnt) + + +class TestStreamStreamClientInterceptor(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_intercepts(self): + + for interceptor_class in ( + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.FullDuplexCall(request_iterator()) + + await call.wait_for_connection() + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_intercepts_using_write_and_read(self): + for interceptor_class in ( + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)) + + call = stub.FullDuplexCall() + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_multiple_interceptors_request_iterator(self): + for interceptor_class in ( + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator): + + with self.subTest(name=interceptor_class): + + interceptors = [interceptor_class(), interceptor_class()] + channel = aio.insecure_channel(self._server_target, + interceptors=interceptors) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)) + + call = stub.FullDuplexCall() + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + for interceptor in interceptors: + interceptor.assert_in_final_state(self) + + await channel.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py new file mode 100644 index 00000000000..b9a04af00dc --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py @@ -0,0 +1,517 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import unittest +import datetime + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._constants import UNREACHABLE_TARGET +from tests_aio.unit._common import inject_callbacks +from tests_aio.unit._common import CountingRequestIterator +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + +_SHORT_TIMEOUT_S = 1.0 + +_NUM_STREAM_REQUESTS = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) + + +class _StreamUnaryInterceptorEmpty(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + return await continuation(client_call_details, request_iterator) + + def assert_in_final_state(self, test: unittest.TestCase): + pass + + +class _StreamUnaryInterceptorWithRequestIterator( + aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + self.request_iterator = CountingRequestIterator(request_iterator) + call = await continuation(client_call_details, self.request_iterator) + return call + + def assert_in_final_state(self, test: unittest.TestCase): + test.assertEqual(_NUM_STREAM_REQUESTS, + self.request_iterator.request_cnt) + + +class TestStreamUnaryClientInterceptor(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_intercepts(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + response = await call + + self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_intercepts_using_write(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + call = stub.StreamingInputCall() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + await call.done_writing() + + response = await call + + self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_add_done_callback_interceptor_task_not_finished(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + validation = inject_callbacks(call) + + response = await call + + await validation + + await channel.close() + + async def test_add_done_callback_interceptor_task_finished(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + response = await call + + validation = inject_callbacks(call) + + await validation + + await channel.close() + + async def test_multiple_interceptors_request_iterator(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + + interceptors = [interceptor_class(), interceptor_class()] + channel = aio.insecure_channel(self._server_target, + interceptors=interceptors) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + response = await call + + self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + for interceptor in interceptors: + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_intercepts_request_iterator_rpc_error(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + # When there is an error the request iterator is no longer + # consumed. + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + + await channel.close() + + async def test_intercepts_request_iterator_rpc_error_using_write(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + call = stub.StreamingInputCall() + + # When there is an error during the write, exception is raised. + with self.assertRaises(asyncio.InvalidStateError): + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + + await channel.close() + + async def test_cancel_before_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, + client_call_details, + request_iterator): + interceptor_reached.set() + await wait_for_ever + + channel = aio.insecure_channel(self._server_target, + interceptors=[Interceptor()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + call = stub.StreamingInputCall() + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + # When there is an error during the write, exception is raised. + with self.assertRaises(asyncio.InvalidStateError): + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + await channel.close() + + async def test_cancel_after_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, + client_call_details, + request_iterator): + call = await continuation(client_call_details, request_iterator) + interceptor_reached.set() + await wait_for_ever + + channel = aio.insecure_channel(self._server_target, + interceptors=[Interceptor()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + call = stub.StreamingInputCall() + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + # When there is an error during the write, exception is raised. + with self.assertRaises(asyncio.InvalidStateError): + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + await channel.close() + + async def test_cancel_while_writing(self): + # Test cancelation before making any write or after doing at least 1 + for num_writes_before_cancel in (0, 1): + with self.subTest(name="Num writes before cancel: {}".format( + num_writes_before_cancel)): + + channel = aio.insecure_channel( + UNREACHABLE_TARGET, + interceptors=[_StreamUnaryInterceptorWithRequestIterator()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + call = stub.StreamingInputCall() + + with self.assertRaises(asyncio.InvalidStateError): + for i in range(_NUM_STREAM_REQUESTS): + if i == num_writes_before_cancel: + self.assertTrue(call.cancel()) + await call.write(request) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + + await channel.close() + + async def test_cancel_by_the_interceptor(self): + + class Interceptor(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, + client_call_details, + request_iterator): + call = await continuation(client_call_details, request_iterator) + call.cancel() + return call + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + call = stub.StreamingInputCall() + + with self.assertRaises(asyncio.InvalidStateError): + for i in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + + await channel.close() + + async def test_exception_raised_by_interceptor(self): + + class InterceptorException(Exception): + pass + + class Interceptor(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, + client_call_details, + request_iterator): + raise InterceptorException + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + call = stub.StreamingInputCall() + + with self.assertRaises(InterceptorException): + for i in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(InterceptorException): + await call + + await channel.close() + + async def test_intercepts_prohibit_mixing_style(self): + channel = aio.insecure_channel( + self._server_target, interceptors=[_StreamUnaryInterceptorEmpty()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + with self.assertRaises(grpc._cython.cygrpc.UsageError): + await call.write(request) + + with self.assertRaises(grpc._cython.cygrpc.UsageError): + await call.done_writing() + + await channel.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py new file mode 100644 index 00000000000..fd542fd16e9 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py @@ -0,0 +1,395 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import unittest +import datetime + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._constants import UNREACHABLE_TARGET +from tests_aio.unit._common import inject_callbacks +from tests_aio.unit._common import CountingResponseIterator +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + +_SHORT_TIMEOUT_S = 1.0 + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 7 +_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) + + +class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, client_call_details, + request): + return await continuation(client_call_details, request) + + def assert_in_final_state(self, test: unittest.TestCase): + pass + + +class _UnaryStreamInterceptorWithResponseIterator( + aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, client_call_details, + request): + call = await continuation(client_call_details, request) + self.response_iterator = CountingResponseIterator(call) + return self.response_iterator + + def assert_in_final_state(self, test: unittest.TestCase): + test.assertEqual(_NUM_STREAM_RESPONSES, + self.response_iterator.response_cnt) + + +class TestUnaryStreamClientInterceptor(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_intercepts(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend([ + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ] * _NUM_STREAM_RESPONSES) + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + await call.wait_for_connection() + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_add_done_callback_interceptor_task_not_finished(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend([ + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ] * _NUM_STREAM_RESPONSES) + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + validation = inject_callbacks(call) + + async for response in call: + pass + + await validation + + await channel.close() + + async def test_add_done_callback_interceptor_task_finished(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend([ + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ] * _NUM_STREAM_RESPONSES) + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + # This ensures that the callbacks will be registered + # with the intercepted call rather than saving in the + # pending state list. + await call.wait_for_connection() + + validation = inject_callbacks(call) + + async for response in call: + pass + + await validation + + await channel.close() + + async def test_response_iterator_using_read(self): + interceptor = _UnaryStreamInterceptorWithResponseIterator() + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend( + [messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] * + _NUM_STREAM_RESPONSES) + + call = stub.StreamingOutputCall(request) + + response_cnt = 0 + for response in range(_NUM_STREAM_RESPONSES): + response = await call.read() + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(interceptor.response_iterator.response_cnt, + _NUM_STREAM_RESPONSES) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + await channel.close() + + async def test_multiple_interceptors_response_iterator(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + + interceptors = [interceptor_class(), interceptor_class()] + + channel = aio.insecure_channel(self._server_target, + interceptors=interceptors) + stub = test_pb2_grpc.TestServiceStub(channel) + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend([ + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ] * _NUM_STREAM_RESPONSES) + + call = stub.StreamingOutputCall(request) + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + await channel.close() + + async def test_intercepts_response_iterator_rpc_error(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(aio.AioRpcError) as exception_context: + async for response in call: + pass + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + await channel.close() + + async def test_cancel_before_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + interceptor_reached.set() + await wait_for_ever + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + pass + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + await channel.close() + + async def test_cancel_after_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + interceptor_reached.set() + await wait_for_ever + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + pass + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + await channel.close() + + async def test_cancel_consuming_response_iterator(self): + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend( + [messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] * + _NUM_STREAM_RESPONSES) + + channel = aio.insecure_channel( + self._server_target, + interceptors=[_UnaryStreamInterceptorWithResponseIterator()]) + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + call.cancel() + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + await channel.close() + + async def test_cancel_by_the_interceptor(self): + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + return call + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + pass + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + await channel.close() + + async def test_exception_raised_by_interceptor(self): + + class InterceptorException(Exception): + pass + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + raise InterceptorException + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(InterceptorException): + async for response in call: + pass + + await channel.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py new file mode 100644 index 00000000000..e64daec7df4 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py @@ -0,0 +1,699 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import unittest + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._test_server import start_test_server, _INITIAL_METADATA_KEY, _TRAILING_METADATA_KEY +from tests_aio.unit import _constants +from tests_aio.unit import _common +from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + +_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' +_INITIAL_METADATA_TO_INJECT = aio.Metadata( + (_INITIAL_METADATA_KEY, 'extra info'), + (_TRAILING_METADATA_KEY, b'\x13\x37'), +) +_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED = 1.0 + + +class TestUnaryUnaryClientInterceptor(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + def test_invalid_interceptor(self): + + class InvalidInterceptor: + """Just an invalid Interceptor""" + + with self.assertRaises(ValueError): + aio.insecure_channel("", interceptors=[InvalidInterceptor()]) + + async def test_executed_right_order(self): + + interceptors_executed = [] + + class Interceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for testing if the interceptor is being called""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + interceptors_executed.append(self) + call = await continuation(client_call_details, request) + return call + + interceptors = [Interceptor() for i in range(2)] + + async with aio.insecure_channel(self._server_target, + interceptors=interceptors) as channel: + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + # Check that all interceptors were executed, and were executed + # in the right order. + self.assertSequenceEqual(interceptors_executed, interceptors) + + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + @unittest.expectedFailure + # TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is + # implemented in the client-side, this test must be implemented. + def test_modify_metadata(self): + raise NotImplementedError() + + @unittest.expectedFailure + # TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is + # implemented in the client-side, this test must be implemented. + def test_modify_credentials(self): + raise NotImplementedError() + + async def test_status_code_Ok(self): + + class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for observing status code Ok returned by the RPC""" + + def __init__(self): + self.status_code_Ok_observed = False + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + code = await call.code() + if code == grpc.StatusCode.OK: + self.status_code_Ok_observed = True + + return call + + interceptor = StatusCodeOkInterceptor() + + async with aio.insecure_channel(self._server_target, + interceptors=[interceptor]) as channel: + + # when no error StatusCode.OK must be observed + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + await multicallable(messages_pb2.SimpleRequest()) + + self.assertTrue(interceptor.status_code_Ok_observed) + + async def test_add_timeout(self): + + class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for adding a timeout to the RPC""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready) + return await continuation(new_client_call_details, request) + + interceptor = TimeoutInterceptor() + + async with aio.insecure_channel(self._server_target, + interceptors=[interceptor]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCallWithSleep', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await + call.code()) + + async def test_retry(self): + + class RetryInterceptor(aio.UnaryUnaryClientInterceptor): + """Simulates a Retry Interceptor which ends up by making + two RPC calls.""" + + def __init__(self): + self.calls = [] + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready) + + try: + call = await continuation(new_client_call_details, request) + await call + except grpc.RpcError: + pass + + self.calls.append(call) + + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=None, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready) + + call = await continuation(new_client_call_details, request) + self.calls.append(call) + return call + + interceptor = RetryInterceptor() + + async with aio.insecure_channel(self._server_target, + interceptors=[interceptor]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCallWithSleep', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + + await call + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + # Check that two calls were made, first one finishing with + # a deadline and second one finishing ok.. + self.assertEqual(len(interceptor.calls), 2) + self.assertEqual(await interceptor.calls[0].code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await interceptor.calls[1].code(), + grpc.StatusCode.OK) + + async def test_rpcresponse(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + """Raw responses are seen as reegular calls""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + response = await call + return call + + class ResponseInterceptor(aio.UnaryUnaryClientInterceptor): + """Return a raw response""" + response = messages_pb2.SimpleResponse() + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + return ResponseInterceptor.response + + interceptor, interceptor_response = Interceptor(), ResponseInterceptor() + + async with aio.insecure_channel( + self._server_target, + interceptors=[interceptor, interceptor_response]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + # Check that the response returned is the one returned by the + # interceptor + self.assertEqual(id(response), id(ResponseInterceptor.response)) + + # Check all of the UnaryUnaryCallResponse attributes + self.assertTrue(call.done()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + self.assertEqual(await call.debug_error_string(), None) + + +class TestInterceptedUnaryUnaryCall(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_call_ok(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + + async def test_call_ok_awaited(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + + async def test_call_rpc_error(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCallWithSleep', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable( + messages_pb2.SimpleRequest(), + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + + async def test_call_rpc_error_awaited(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCallWithSleep', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable( + messages_pb2.SimpleRequest(), + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual(await call.trailing_metadata(), aio.Metadata()) + + async def test_cancel_before_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + interceptor_reached.set() + await wait_for_ever + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_after_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + interceptor_reached.set() + await wait_for_ever + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_inside_interceptor_after_rpc_awaiting(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + await call + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), aio.Metadata()) + self.assertEqual( + await call.trailing_metadata(), aio.Metadata(), + "When the raw response is None, empty metadata is returned") + + async def test_initial_metadata_modification(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + new_metadata = aio.Metadata(*client_call_details.metadata, + *_INITIAL_METADATA_TO_INJECT) + new_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=new_metadata, + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready, + ) + return await continuation(new_details, request) + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.UnaryCall(messages_pb2.SimpleRequest()) + + # Expected to see the echoed initial metadata + self.assertTrue( + _common.seen_metadatum( + expected_key=_INITIAL_METADATA_KEY, + expected_value=_INITIAL_METADATA_TO_INJECT[ + _INITIAL_METADATA_KEY], + actual=await call.initial_metadata(), + )) + # Expected to see the echoed trailing metadata + self.assertTrue( + _common.seen_metadatum( + expected_key=_TRAILING_METADATA_KEY, + expected_value=_INITIAL_METADATA_TO_INJECT[ + _TRAILING_METADATA_KEY], + actual=await call.trailing_metadata(), + )) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_add_done_callback_before_finishes(self): + called = asyncio.Event() + interceptor_can_continue = asyncio.Event() + + def callback(call): + called.set() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + await interceptor_can_continue.wait() + call = await continuation(client_call_details, request) + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + call.add_done_callback(callback) + interceptor_can_continue.set() + await call + + try: + await asyncio.wait_for( + called.wait(), + timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) + except: + self.fail("Callback was not called") + + async def test_add_done_callback_after_finishes(self): + called = asyncio.Event() + + def callback(call): + called.set() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + call = await continuation(client_call_details, request) + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + await call + + call.add_done_callback(callback) + + try: + await asyncio.wait_for( + called.wait(), + timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) + except: + self.fail("Callback was not called") + + async def test_add_done_callback_after_finishes_before_await(self): + called = asyncio.Event() + + def callback(call): + called.set() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + call = await continuation(client_call_details, request) + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + call.add_done_callback(callback) + + await call + + try: + await asyncio.wait_for( + called.wait(), + timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) + except: + self.fail("Callback was not called") + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py new file mode 100644 index 00000000000..20543e95bf7 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -0,0 +1,138 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior of closing a grpc.aio.Channel.""" + +import asyncio +import logging +import unittest + +import grpc +from grpc.experimental import aio +from grpc.aio import _base_call + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + +_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' +_LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60 + + +class TestCloseChannel(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_graceful_close(self): + channel = aio.insecure_channel(self._server_target) + UnaryCallWithSleep = channel.unary_unary( + _UNARY_CALL_METHOD_WITH_SLEEP, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + + call = UnaryCallWithSleep(messages_pb2.SimpleRequest()) + + await channel.close(grace=_LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE) + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_none_graceful_close(self): + channel = aio.insecure_channel(self._server_target) + UnaryCallWithSleep = channel.unary_unary( + _UNARY_CALL_METHOD_WITH_SLEEP, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + + call = UnaryCallWithSleep(messages_pb2.SimpleRequest()) + + await channel.close(None) + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_close_unary_unary(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)] + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + async def test_close_unary_stream(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + request = messages_pb2.StreamingOutputCallRequest() + calls = [stub.StreamingOutputCall(request) for _ in range(2)] + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + async def test_close_stream_unary(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + calls = [stub.StreamingInputCall() for _ in range(2)] + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + async def test_close_stream_stream(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + calls = [stub.FullDuplexCall() for _ in range(2)] + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + async def test_close_async_context(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + calls = [ + stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2) + ] + + for call in calls: + self.assertTrue(call.cancelled()) + + async def test_channel_isolation(self): + async with aio.insecure_channel(self._server_target) as channel1: + async with aio.insecure_channel(self._server_target) as channel2: + stub1 = test_pb2_grpc.TestServiceStub(channel1) + stub2 = test_pb2_grpc.TestServiceStub(channel2) + + call1 = stub1.UnaryCall(messages_pb2.SimpleRequest()) + call2 = stub2.UnaryCall(messages_pb2.SimpleRequest()) + + self.assertFalse(call1.cancelled()) + self.assertTrue(call2.cancelled()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py new file mode 100644 index 00000000000..0bb3a3acc89 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py @@ -0,0 +1,380 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the compatibility between AsyncIO stack and the old stack.""" + +import asyncio +import logging +import os +import random +import threading +import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Iterable, Sequence, Tuple + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests.unit.framework.common import test_constants +from tests_aio.unit import _common +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import TestServiceServicer, start_test_server + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 +_REQUEST = b'\x03\x07' +_ADHOC_METHOD = '/test/AdHoc' + + +def _unique_options() -> Sequence[Tuple[str, float]]: + return (('iv', random.random()),) + + +class _AdhocGenericHandler(grpc.GenericRpcHandler): + _handler: grpc.RpcMethodHandler + + def __init__(self): + self._handler = None + + def set_adhoc_handler(self, handler: grpc.RpcMethodHandler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _ADHOC_METHOD: + return self._handler + else: + return None + + + os.environ.get('GRPC_ASYNCIO_ENGINE', '').lower() == 'custom_io_manager', + 'Compatible mode needs POLLER completion queue.') +class TestCompatibility(AioTestBase): + + async def setUp(self): + self._async_server = aio.server( + options=(('grpc.so_reuseport', 0),), + migration_thread_pool=ThreadPoolExecutor()) + + test_pb2_grpc.add_TestServiceServicer_to_server(TestServiceServicer(), + self._async_server) + self._adhoc_handlers = _AdhocGenericHandler() + self._async_server.add_generic_rpc_handlers((self._adhoc_handlers,)) + + port = self._async_server.add_insecure_port('[::]:0') + address = 'localhost:%d' % port + await self._async_server.start() + + # Create async stub + self._async_channel = aio.insecure_channel(address, + options=_unique_options()) + self._async_stub = test_pb2_grpc.TestServiceStub(self._async_channel) + + # Create sync stub + self._sync_channel = grpc.insecure_channel(address, + options=_unique_options()) + self._sync_stub = test_pb2_grpc.TestServiceStub(self._sync_channel) + + async def tearDown(self): + self._sync_channel.close() + await self._async_channel.close() + await self._async_server.stop(None) + + async def _run_in_another_thread(self, func: Callable[[], None]): + work_done = asyncio.Event(loop=self.loop) + + def thread_work(): + func() + self.loop.call_soon_threadsafe(work_done.set) + + thread = threading.Thread(target=thread_work, daemon=True) + thread.start() + await work_done.wait() + thread.join() + + async def test_unary_unary(self): + # Calling async API in this thread + await self._async_stub.UnaryCall(messages_pb2.SimpleRequest(), + timeout=test_constants.LONG_TIMEOUT) + + # Calling sync API in a different thread + def sync_work() -> None: + response, call = self._sync_stub.UnaryCall.with_call( + messages_pb2.SimpleRequest(), + timeout=test_constants.LONG_TIMEOUT) + self.assertIsInstance(response, messages_pb2.SimpleResponse) + self.assertEqual(grpc.StatusCode.OK, call.code()) + + await self._run_in_another_thread(sync_work) + + async def test_unary_stream(self): + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + # Calling async API in this thread + call = self._async_stub.StreamingOutputCall(request) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.read() + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + # Calling sync API in a different thread + def sync_work() -> None: + response_iterator = self._sync_stub.StreamingOutputCall(request) + for response in response_iterator: + assert _RESPONSE_PAYLOAD_SIZE == len(response.payload.body) + self.assertEqual(grpc.StatusCode.OK, response_iterator.code()) + + await self._run_in_another_thread(sync_work) + + async def test_stream_unary(self): + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Calling async API in this thread + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + response = await self._async_stub.StreamingInputCall(gen()) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + # Calling sync API in a different thread + def sync_work() -> None: + response = self._sync_stub.StreamingInputCall( + iter([request] * _NUM_STREAM_RESPONSES)) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + await self._run_in_another_thread(sync_work) + + async def test_stream_stream(self): + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + # Calling async API in this thread + call = self._async_stub.FullDuplexCall() + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + assert _RESPONSE_PAYLOAD_SIZE == len(response.payload.body) + + await call.done_writing() + assert await call.code() == grpc.StatusCode.OK + + # Calling sync API in a different thread + def sync_work() -> None: + response_iterator = self._sync_stub.FullDuplexCall(iter([request])) + for response in response_iterator: + assert _RESPONSE_PAYLOAD_SIZE == len(response.payload.body) + self.assertEqual(grpc.StatusCode.OK, response_iterator.code()) + + await self._run_in_another_thread(sync_work) + + async def test_server(self): + + class GenericHandlers(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + return grpc.unary_unary_rpc_method_handler(lambda x, _: x) + + # It's fine to instantiate server object in the event loop thread. + # The server will spawn its own serving thread. + server = grpc.server(ThreadPoolExecutor(), + handlers=(GenericHandlers(),)) + port = server.add_insecure_port('localhost:0') + server.start() + + def sync_work() -> None: + for _ in range(100): + with grpc.insecure_channel('localhost:%d' % port) as channel: + response = channel.unary_unary('/test/test')(b'\x07\x08') + self.assertEqual(response, b'\x07\x08') + + await self._run_in_another_thread(sync_work) + + async def test_many_loop(self): + address, server = await start_test_server() + + # Run another loop in another thread + def sync_work(): + + async def async_work(): + # Create async stub + async_channel = aio.insecure_channel(address, + options=_unique_options()) + async_stub = test_pb2_grpc.TestServiceStub(async_channel) + + call = async_stub.UnaryCall(messages_pb2.SimpleRequest()) + response = await call + self.assertIsInstance(response, messages_pb2.SimpleResponse) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + loop = asyncio.new_event_loop() + loop.run_until_complete(async_work()) + + await self._run_in_another_thread(sync_work) + await server.stop(None) + + async def test_sync_unary_unary_success(self): + + @grpc.unary_unary_rpc_method_handler + def echo_unary_unary(request: bytes, unused_context): + return request + + self._adhoc_handlers.set_adhoc_handler(echo_unary_unary) + response = await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST + ) + self.assertEqual(_REQUEST, response) + + async def test_sync_unary_unary_metadata(self): + metadata = (('unique', 'key-42'),) + + @grpc.unary_unary_rpc_method_handler + def metadata_unary_unary(request: bytes, context: grpc.ServicerContext): + context.send_initial_metadata(metadata) + return request + + self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary) + call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST) + self.assertTrue( + _common.seen_metadata(aio.Metadata(*metadata), await + call.initial_metadata())) + + async def test_sync_unary_unary_abort(self): + + @grpc.unary_unary_rpc_method_handler + def abort_unary_unary(request: bytes, context: grpc.ServicerContext): + context.abort(grpc.StatusCode.INTERNAL, 'Test') + + self._adhoc_handlers.set_adhoc_handler(abort_unary_unary) + with self.assertRaises(aio.AioRpcError) as exception_context: + await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST) + self.assertEqual(grpc.StatusCode.INTERNAL, + exception_context.exception.code()) + + async def test_sync_unary_unary_set_code(self): + + @grpc.unary_unary_rpc_method_handler + def set_code_unary_unary(request: bytes, context: grpc.ServicerContext): + context.set_code(grpc.StatusCode.INTERNAL) + + self._adhoc_handlers.set_adhoc_handler(set_code_unary_unary) + with self.assertRaises(aio.AioRpcError) as exception_context: + await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST) + self.assertEqual(grpc.StatusCode.INTERNAL, + exception_context.exception.code()) + + async def test_sync_unary_stream_success(self): + + @grpc.unary_stream_rpc_method_handler + def echo_unary_stream(request: bytes, unused_context): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + self._adhoc_handlers.set_adhoc_handler(echo_unary_stream) + call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST) + async for response in call: + self.assertEqual(_REQUEST, response) + + async def test_sync_unary_stream_error(self): + + @grpc.unary_stream_rpc_method_handler + def error_unary_stream(request: bytes, unused_context): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + raise RuntimeError('Test') + + self._adhoc_handlers.set_adhoc_handler(error_unary_stream) + call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST) + with self.assertRaises(aio.AioRpcError) as exception_context: + async for response in call: + self.assertEqual(_REQUEST, response) + self.assertEqual(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + async def test_sync_stream_unary_success(self): + + @grpc.stream_unary_rpc_method_handler + def echo_stream_unary(request_iterator: Iterable[bytes], + unused_context): + self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES) + return _REQUEST + + self._adhoc_handlers.set_adhoc_handler(echo_stream_unary) + request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES) + response = await self._async_channel.stream_unary(_ADHOC_METHOD)( + request_iterator) + self.assertEqual(_REQUEST, response) + + async def test_sync_stream_unary_error(self): + + @grpc.stream_unary_rpc_method_handler + def echo_stream_unary(request_iterator: Iterable[bytes], + unused_context): + self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES) + raise RuntimeError('Test') + + self._adhoc_handlers.set_adhoc_handler(echo_stream_unary) + request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES) + with self.assertRaises(aio.AioRpcError) as exception_context: + response = await self._async_channel.stream_unary(_ADHOC_METHOD)( + request_iterator) + self.assertEqual(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + async def test_sync_stream_stream_success(self): + + @grpc.stream_stream_rpc_method_handler + def echo_stream_stream(request_iterator: Iterable[bytes], + unused_context): + for request in request_iterator: + yield request + + self._adhoc_handlers.set_adhoc_handler(echo_stream_stream) + request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES) + call = self._async_channel.stream_stream(_ADHOC_METHOD)( + request_iterator) + async for response in call: + self.assertEqual(_REQUEST, response) + + async def test_sync_stream_stream_error(self): + + @grpc.stream_stream_rpc_method_handler + def echo_stream_stream(request_iterator: Iterable[bytes], + unused_context): + for request in request_iterator: + yield request + raise RuntimeError('test') + + self._adhoc_handlers.set_adhoc_handler(echo_stream_stream) + request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES) + call = self._async_channel.stream_stream(_ADHOC_METHOD)( + request_iterator) + with self.assertRaises(aio.AioRpcError) as exception_context: + async for response in call: + self.assertEqual(_REQUEST, response) + self.assertEqual(grpc.StatusCode.UNKNOWN, + exception_context.exception.code()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/compression_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/compression_test.py new file mode 100644 index 00000000000..9d93885ea23 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/compression_test.py @@ -0,0 +1,196 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior around the compression mechanism.""" + +import asyncio +import logging +import platform +import random +import unittest + +import grpc +from grpc.experimental import aio + +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit import _common + +_GZIP_CHANNEL_ARGUMENT = ('grpc.default_compression_algorithm', 2) +_GZIP_DISABLED_CHANNEL_ARGUMENT = ('grpc.compression_enabled_algorithms_bitset', + 3) +_DEFLATE_DISABLED_CHANNEL_ARGUMENT = ( + 'grpc.compression_enabled_algorithms_bitset', 5) + +_TEST_UNARY_UNARY = '/test/TestUnaryUnary' +_TEST_SET_COMPRESSION = '/test/TestSetCompression' +_TEST_DISABLE_COMPRESSION_UNARY = '/test/TestDisableCompressionUnary' +_TEST_DISABLE_COMPRESSION_STREAM = '/test/TestDisableCompressionStream' + +_REQUEST = b'\x01' * 100 +_RESPONSE = b'\x02' * 100 + + +async def _test_unary_unary(unused_request, unused_context): + return _RESPONSE + + +async def _test_set_compression(unused_request_iterator, context): + assert _REQUEST == await context.read() + context.set_compression(grpc.Compression.Deflate) + await context.write(_RESPONSE) + try: + context.set_compression(grpc.Compression.Deflate) + except RuntimeError: + # NOTE(lidiz) Testing if the servicer context raises exception when + # the set_compression method is called after initial_metadata sent. + # After the initial_metadata sent, the server-side has no control over + # which compression algorithm it should use. + pass + else: + raise ValueError( + 'Expecting exceptions if set_compression is not effective') + + +async def _test_disable_compression_unary(request, context): + assert _REQUEST == request + context.set_compression(grpc.Compression.Deflate) + context.disable_next_message_compression() + return _RESPONSE + + +async def _test_disable_compression_stream(unused_request_iterator, context): + assert _REQUEST == await context.read() + context.set_compression(grpc.Compression.Deflate) + await context.write(_RESPONSE) + context.disable_next_message_compression() + await context.write(_RESPONSE) + await context.write(_RESPONSE) + + +_ROUTING_TABLE = { + _TEST_UNARY_UNARY: + grpc.unary_unary_rpc_method_handler(_test_unary_unary), + _TEST_SET_COMPRESSION: + grpc.stream_stream_rpc_method_handler(_test_set_compression), + _TEST_DISABLE_COMPRESSION_UNARY: + grpc.unary_unary_rpc_method_handler(_test_disable_compression_unary), + _TEST_DISABLE_COMPRESSION_STREAM: + grpc.stream_stream_rpc_method_handler(_test_disable_compression_stream), +} + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + return _ROUTING_TABLE.get(handler_call_details.method) + + +async def _start_test_server(options=None): + server = aio.server(options=options) + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((_GenericHandler(),)) + await server.start() + return f'localhost:{port}', server + + +class TestCompression(AioTestBase): + + async def setUp(self): + server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,) + self._address, self._server = await _start_test_server(server_options) + self._channel = aio.insecure_channel(self._address) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_channel_level_compression_baned_compression(self): + # GZIP is disabled, this call should fail + async with aio.insecure_channel( + self._address, compression=grpc.Compression.Gzip) as channel: + multicallable = channel.unary_unary(_TEST_UNARY_UNARY) + call = multicallable(_REQUEST) + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) + + async def test_channel_level_compression_allowed_compression(self): + # Deflate is allowed, this call should succeed + async with aio.insecure_channel( + self._address, compression=grpc.Compression.Deflate) as channel: + multicallable = channel.unary_unary(_TEST_UNARY_UNARY) + call = multicallable(_REQUEST) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_client_call_level_compression_baned_compression(self): + multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY) + + # GZIP is disabled, this call should fail + call = multicallable(_REQUEST, compression=grpc.Compression.Gzip) + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) + + async def test_client_call_level_compression_allowed_compression(self): + multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY) + + # Deflate is allowed, this call should succeed + call = multicallable(_REQUEST, compression=grpc.Compression.Deflate) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_server_call_level_compression(self): + multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION) + call = multicallable() + await call.write(_REQUEST) + await call.done_writing() + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_server_disable_compression_unary(self): + multicallable = self._channel.unary_unary( + _TEST_DISABLE_COMPRESSION_UNARY) + call = multicallable(_REQUEST) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_server_disable_compression_stream(self): + multicallable = self._channel.stream_stream( + _TEST_DISABLE_COMPRESSION_STREAM) + call = multicallable() + await call.write(_REQUEST) + await call.done_writing() + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_server_default_compression_algorithm(self): + server = aio.server(compression=grpc.Compression.Deflate) + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((_GenericHandler(),)) + await server.start() + + async with aio.insecure_channel(f'localhost:{port}') as channel: + multicallable = channel.unary_unary(_TEST_UNARY_UNARY) + call = multicallable(_REQUEST) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py new file mode 100644 index 00000000000..7f98329070b --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py @@ -0,0 +1,112 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior of the connectivity state.""" + +import asyncio +import logging +import threading +import time +import unittest + +import grpc +from grpc.experimental import aio + +from tests.unit.framework.common import test_constants +from tests_aio.unit import _common +from tests_aio.unit._constants import UNREACHABLE_TARGET +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + + +class TestConnectivityState(AioTestBase): + + async def setUp(self): + self._server_address, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_unavailable_backend(self): + async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(False)) + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(True)) + + # Should not time out + await asyncio.wait_for( + _common.block_until_certain_state( + channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE), + test_constants.SHORT_TIMEOUT) + + async def test_normal_backend(self): + async with aio.insecure_channel(self._server_address) as channel: + current_state = channel.get_state(True) + self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state) + + # Should not time out + await asyncio.wait_for( + _common.block_until_certain_state( + channel, grpc.ChannelConnectivity.READY), + test_constants.SHORT_TIMEOUT) + + async def test_timeout(self): + async with aio.insecure_channel(self._server_address) as channel: + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(False)) + + # If timed out, the function should return None. + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + _common.block_until_certain_state( + channel, grpc.ChannelConnectivity.READY), + test_constants.SHORT_TIMEOUT) + + async def test_shutdown(self): + channel = aio.insecure_channel(self._server_address) + + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(False)) + + # Waiting for changes in a separate coroutine + wait_started = asyncio.Event() + + async def a_pending_wait(): + wait_started.set() + await channel.wait_for_state_change(grpc.ChannelConnectivity.IDLE) + + pending_task = self.loop.create_task(a_pending_wait()) + await wait_started.wait() + + await channel.close() + + self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, + channel.get_state(True)) + + self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, + channel.get_state(False)) + + # Make sure there isn't any exception in the task + await pending_task + + # It can raise exceptions since it is an usage error, but it should not + # segfault or abort. + with self.assertRaises(aio.UsageError): + await channel.wait_for_state_change( + grpc.ChannelConnectivity.SHUTDOWN) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py new file mode 100644 index 00000000000..ea5f4621afb --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py @@ -0,0 +1,65 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the server context ability to access peer info.""" + +import asyncio +import logging +import os +import unittest +from typing import Callable, Iterable, Sequence, Tuple + +import grpc +from grpc.experimental import aio + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests.unit.framework.common import test_constants +from tests_aio.unit import _common +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import TestServiceServicer, start_test_server + +_REQUEST = b'\x03\x07' +_TEST_METHOD = '/test/UnaryUnary' + + +class TestContextPeer(AioTestBase): + + async def test_peer(self): + + @grpc.unary_unary_rpc_method_handler + async def check_peer_unary_unary(request: bytes, + context: aio.ServicerContext): + self.assertEqual(_REQUEST, request) + # The peer address could be ipv4 or ipv6 + self.assertIn('ip', context.peer()) + return request + + # Creates a server + server = aio.server() + handlers = grpc.method_handlers_generic_handler( + 'test', {'UnaryUnary': check_peer_unary_unary}) + server.add_generic_rpc_handlers((handlers,)) + port = server.add_insecure_port('[::]:0') + await server.start() + + # Creates a channel + async with aio.insecure_channel('localhost:%d' % port) as channel: + response = await channel.unary_unary(_TEST_METHOD)(_REQUEST) + self.assertEqual(_REQUEST, response) + + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py new file mode 100644 index 00000000000..481bafd5679 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py @@ -0,0 +1,124 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the done callbacks mechanism.""" + +import asyncio +import logging +import unittest +import time +import gc + +import grpc +from grpc.experimental import aio +from tests_aio.unit._common import inject_callbacks +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_server import start_test_server + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + + +class TestDoneCallback(AioTestBase): + + async def setUp(self): + address, self._server = await start_test_server() + self._channel = aio.insecure_channel(address) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_add_after_done(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + validation = inject_callbacks(call) + await validation + + async def test_unary_unary(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + validation = inject_callbacks(call) + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + await validation + + async def test_unary_stream(self): + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + call = self._stub.StreamingOutputCall(request) + validation = inject_callbacks(call) + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + await validation + + async def test_stream_unary(self): + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + call = self._stub.StreamingInputCall(gen()) + validation = inject_callbacks(call) + + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + await validation + + async def test_stream_stream(self): + call = self._stub.FullDuplexCall() + validation = inject_callbacks(call) + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + await validation + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/init_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/init_test.py new file mode 100644 index 00000000000..b9183a22c75 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/init_test.py @@ -0,0 +1,33 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import unittest + + +class TestInit(unittest.TestCase): + + def test_grpc(self): + import grpc # pylint: disable=wrong-import-position + channel = grpc.aio.insecure_channel('dummy') + self.assertIsInstance(channel, grpc.aio.Channel) + + def test_grpc_dot_aio(self): + import grpc.aio # pylint: disable=wrong-import-position + channel = grpc.aio.insecure_channel('dummy') + self.assertIsInstance(channel, grpc.aio.Channel) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/metadata_test.py new file mode 100644 index 00000000000..2261446b3ea --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -0,0 +1,297 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior around the metadata mechanism.""" + +import asyncio +import logging +import platform +import random +import unittest + +import grpc +from grpc.experimental import aio + +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit import _common + +_TEST_CLIENT_TO_SERVER = '/test/TestClientToServer' +_TEST_SERVER_TO_CLIENT = '/test/TestServerToClient' +_TEST_TRAILING_METADATA = '/test/TestTrailingMetadata' +_TEST_ECHO_INITIAL_METADATA = '/test/TestEchoInitialMetadata' +_TEST_GENERIC_HANDLER = '/test/TestGenericHandler' +_TEST_UNARY_STREAM = '/test/TestUnaryStream' +_TEST_STREAM_UNARY = '/test/TestStreamUnary' +_TEST_STREAM_STREAM = '/test/TestStreamStream' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + +_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata( + ('client-to-server', 'question'), + ('client-to-server-bin', b'\x07\x07\x07'), +) +_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata( + ('server-to-client', 'answer'), + ('server-to-client-bin', b'\x06\x06\x06'), +) +_TRAILING_METADATA = aio.Metadata( + ('a-trailing-metadata', 'stack-trace'), + ('a-trailing-metadata-bin', b'\x05\x05\x05'), +) +_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata( + ('a-must-have-key', 'secret'),) + +_INVALID_METADATA_TEST_CASES = ( + ( + TypeError, + ((42, 42),), + ), + ( + TypeError, + (({}, {}),), + ), + ( + TypeError, + ((None, {}),), + ), + ( + TypeError, + (({}, {}),), + ), + ( + TypeError, + (('normal', object()),), + ), +) + + +class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): + + def __init__(self): + self._routing_table = { + _TEST_CLIENT_TO_SERVER: + grpc.unary_unary_rpc_method_handler(self._test_client_to_server + ), + _TEST_SERVER_TO_CLIENT: + grpc.unary_unary_rpc_method_handler(self._test_server_to_client + ), + _TEST_TRAILING_METADATA: + grpc.unary_unary_rpc_method_handler(self._test_trailing_metadata + ), + _TEST_UNARY_STREAM: + grpc.unary_stream_rpc_method_handler(self._test_unary_stream), + _TEST_STREAM_UNARY: + grpc.stream_unary_rpc_method_handler(self._test_stream_unary), + _TEST_STREAM_STREAM: + grpc.stream_stream_rpc_method_handler(self._test_stream_stream), + } + + @staticmethod + async def _test_client_to_server(request, context): + assert _REQUEST == request + assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata()) + return _RESPONSE + + @staticmethod + async def _test_server_to_client(request, context): + assert _REQUEST == request + await context.send_initial_metadata( + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + return _RESPONSE + + @staticmethod + async def _test_trailing_metadata(request, context): + assert _REQUEST == request + context.set_trailing_metadata(_TRAILING_METADATA) + return _RESPONSE + + @staticmethod + async def _test_unary_stream(request, context): + assert _REQUEST == request + assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata()) + await context.send_initial_metadata( + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + yield _RESPONSE + context.set_trailing_metadata(_TRAILING_METADATA) + + @staticmethod + async def _test_stream_unary(request_iterator, context): + assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata()) + await context.send_initial_metadata( + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + + async for request in request_iterator: + assert _REQUEST == request + + context.set_trailing_metadata(_TRAILING_METADATA) + return _RESPONSE + + @staticmethod + async def _test_stream_stream(request_iterator, context): + assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata()) + await context.send_initial_metadata( + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + + async for request in request_iterator: + assert _REQUEST == request + + yield _RESPONSE + context.set_trailing_metadata(_TRAILING_METADATA) + + def service(self, handler_call_details): + return self._routing_table.get(handler_call_details.method) + + +class _TestGenericHandlerItself(grpc.GenericRpcHandler): + + @staticmethod + async def _method(request, unused_context): + assert _REQUEST == request + return _RESPONSE + + def service(self, handler_call_details): + assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER, + handler_call_details.invocation_metadata) + return grpc.unary_unary_rpc_method_handler(self._method) + + +async def _start_test_server(): + server = aio.server() + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers(( + _TestGenericHandlerForMethods(), + _TestGenericHandlerItself(), + )) + await server.start() + return 'localhost:%d' % port, server + + +class TestMetadata(AioTestBase): + + async def setUp(self): + address, self._server = await _start_test_server() + self._client = aio.insecure_channel(address) + + async def tearDown(self): + await self._client.close() + await self._server.stop(None) + + async def test_from_client_to_server(self): + multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) + call = multicallable(_REQUEST, + metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_from_server_to_client(self): + multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT) + call = multicallable(_REQUEST) + + self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await + call.initial_metadata()) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_trailing_metadata(self): + multicallable = self._client.unary_unary(_TEST_TRAILING_METADATA) + call = multicallable(_REQUEST) + self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_from_client_to_server_with_list(self): + multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) + call = multicallable( + _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) # pytype: disable=wrong-arg-types + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + @unittest.skipIf(platform.system() == 'Windows', + 'https://github.com/grpc/grpc/issues/21943') + async def test_invalid_metadata(self): + multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) + for exception_type, metadata in _INVALID_METADATA_TEST_CASES: + with self.subTest(metadata=metadata): + with self.assertRaises(exception_type): + call = multicallable(_REQUEST, metadata=metadata) + await call + + async def test_generic_handler(self): + multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER) + call = multicallable(_REQUEST, + metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_unary_stream(self): + multicallable = self._client.unary_stream(_TEST_UNARY_STREAM) + call = multicallable(_REQUEST, + metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) + + self.assertTrue( + _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await + call.initial_metadata())) + + self.assertSequenceEqual([_RESPONSE], + [request async for request in call]) + + self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_stream_unary(self): + multicallable = self._client.stream_unary(_TEST_STREAM_UNARY) + call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) + await call.write(_REQUEST) + await call.done_writing() + + self.assertTrue( + _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await + call.initial_metadata())) + self.assertEqual(_RESPONSE, await call) + + self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_stream_stream(self): + multicallable = self._client.stream_stream(_TEST_STREAM_STREAM) + call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) + await call.write(_REQUEST) + await call.done_writing() + + self.assertTrue( + _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await + call.initial_metadata())) + self.assertSequenceEqual([_RESPONSE], + [request async for request in call]) + self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_compatibility_with_tuple(self): + metadata_obj = aio.Metadata(('key', '42'), ('key-2', 'value')) + self.assertEqual(metadata_obj, tuple(metadata_obj)) + self.assertEqual(tuple(metadata_obj), metadata_obj) + + expected_sum = tuple(metadata_obj) + (('third', '3'),) + self.assertEqual(expected_sum, metadata_obj + (('third', '3'),)) + self.assertEqual(expected_sum, metadata_obj + aio.Metadata( + ('third', '3'))) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/outside_init_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/outside_init_test.py new file mode 100644 index 00000000000..879796cf0f5 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/outside_init_test.py @@ -0,0 +1,74 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior around the metadata mechanism.""" + +import asyncio +import logging +import unittest +from grpc.experimental import aio +import grpc + +from tests_aio.unit._test_server import start_test_server +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + +_NUM_OF_LOOPS = 50 + + +class TestOutsideInit(unittest.TestCase): + + def test_behavior_outside_asyncio(self): + # Ensures non-AsyncIO object can be initiated + channel_creds = grpc.ssl_channel_credentials() + + # Ensures AsyncIO API not raising outside of AsyncIO. + # NOTE(lidiz) This behavior is bound with GAPIC generator, and required + # by test frameworks like pytest. In test frameworks, objects shared + # across cases need to be created outside of AsyncIO coroutines. + aio.insecure_channel('') + aio.secure_channel('', channel_creds) + aio.server() + aio.init_grpc_aio() + aio.shutdown_grpc_aio() + + def test_multi_ephemeral_loops(self): + # Initializes AIO module outside. It's part of the test. We especially + # want to ensure the closing of the default loop won't cause deadlocks. + aio.init_grpc_aio() + + async def ping_pong(): + address, server = await start_test_server() + channel = aio.insecure_channel(address) + stub = test_pb2_grpc.TestServiceStub(channel) + + await stub.UnaryCall(messages_pb2.SimpleRequest()) + + await channel.close() + await server.stop(None) + + for i in range(_NUM_OF_LOOPS): + old_loop = asyncio.get_event_loop() + old_loop.close() + + loop = asyncio.new_event_loop() + loop.set_debug(True) + asyncio.set_event_loop(loop) + + loop.run_until_complete(ping_pong()) + + aio.shutdown_grpc_aio() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py new file mode 100644 index 00000000000..7efaddd607e --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py @@ -0,0 +1,130 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests the behaviour of the Call classes under a secure channel.""" + +import unittest +import logging + +import grpc +from grpc.experimental import aio +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server +from tests.unit import resources + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_NUM_STREAM_RESPONSES = 5 +_RESPONSE_PAYLOAD_SIZE = 42 + + +class _SecureCallMixin: + """A Mixin to run the call tests over a secure channel.""" + + async def setUp(self): + server_credentials = grpc.ssl_server_credentials([ + (resources.private_key(), resources.certificate_chain()) + ]) + channel_credentials = grpc.ssl_channel_credentials( + resources.test_root_certificates()) + + self._server_address, self._server = await start_test_server( + secure=True, server_credentials=server_credentials) + channel_options = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, + ),) + self._channel = aio.secure_channel(self._server_address, + channel_credentials, channel_options) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + +class TestUnaryUnarySecureCall(_SecureCallMixin, AioTestBase): + """unary_unary Calls made over a secure channel.""" + + async def test_call_ok_over_secure_channel(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + response = await call + self.assertIsInstance(response, messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_call_with_credentials(self): + call_credentials = grpc.composite_call_credentials( + grpc.access_token_call_credentials("abc"), + grpc.access_token_call_credentials("def"), + ) + call = self._stub.UnaryCall(messages_pb2.SimpleRequest(), + credentials=call_credentials) + response = await call + + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + +class TestUnaryStreamSecureCall(_SecureCallMixin, AioTestBase): + """unary_stream calls over a secure channel""" + + async def test_unary_stream_async_generator_secure(self): + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,) + for _ in range(_NUM_STREAM_RESPONSES)) + call_credentials = grpc.composite_call_credentials( + grpc.access_token_call_credentials("abc"), + grpc.access_token_call_credentials("def"), + ) + call = self._stub.StreamingOutputCall(request, + credentials=call_credentials) + + async for response in call: + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(len(response.payload.body), _RESPONSE_PAYLOAD_SIZE) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + +# Prepares the request that stream in a ping-pong manner. +_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() +_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + +class TestStreamStreamSecureCall(_SecureCallMixin, AioTestBase): + _STREAM_ITERATIONS = 2 + + async def test_async_generator_secure_channel(self): + + async def request_generator(): + for _ in range(self._STREAM_ITERATIONS): + yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE + + call_credentials = grpc.composite_call_credentials( + grpc.access_token_call_credentials("abc"), + grpc.access_token_call_credentials("def"), + ) + + call = self._stub.FullDuplexCall(request_generator(), + credentials=call_credentials) + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py new file mode 100644 index 00000000000..d891ecdb771 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -0,0 +1,330 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test the functionality of server interceptors.""" + +import asyncio +import functools +import logging +import unittest +from typing import Any, Awaitable, Callable, Tuple + +import grpc +from grpc.experimental import aio, wrap_server_method_handler + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + + +class _LoggingInterceptor(aio.ServerInterceptor): + + def __init__(self, tag: str, record: list) -> None: + self.tag = tag + self.record = record + + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + self.record.append(self.tag + ':intercept_service') + return await continuation(handler_call_details) + + +class _GenericInterceptor(aio.ServerInterceptor): + + def __init__(self, fn: Callable[[ + Callable[[grpc.HandlerCallDetails], Awaitable[grpc. + RpcMethodHandler]], + grpc.HandlerCallDetails + ], Any]) -> None: + self._fn = fn + + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + return await self._fn(continuation, handler_call_details) + + +def _filter_server_interceptor(condition: Callable, + interceptor: aio.ServerInterceptor + ) -> aio.ServerInterceptor: + + async def intercept_service( + continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + if condition(handler_call_details): + return await interceptor.intercept_service(continuation, + handler_call_details) + return await continuation(handler_call_details) + + return _GenericInterceptor(intercept_service) + + +class _CacheInterceptor(aio.ServerInterceptor): + """An interceptor that caches response based on request message.""" + + def __init__(self, cache_store=None): + self.cache_store = cache_store or {} + + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + # Get the actual handler + handler = await continuation(handler_call_details) + + # Only intercept unary call RPCs + if handler and (handler.request_streaming or # pytype: disable=attribute-error + handler.response_streaming): # pytype: disable=attribute-error + return handler + + def wrapper(behavior: Callable[ + [messages_pb2.SimpleRequest, aio. + ServicerContext], messages_pb2.SimpleResponse]): + + @functools.wraps(behavior) + async def wrapper(request: messages_pb2.SimpleRequest, + context: aio.ServicerContext + ) -> messages_pb2.SimpleResponse: + if request.response_size not in self.cache_store: + self.cache_store[request.response_size] = await behavior( + request, context) + return self.cache_store[request.response_size] + + return wrapper + + return wrap_server_method_handler(wrapper, handler) + + +async def _create_server_stub_pair( + *interceptors: aio.ServerInterceptor +) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: + """Creates a server-stub pair with given interceptors. + + Returning the server object to protect it from being garbage collected. + """ + server_target, server = await start_test_server(interceptors=interceptors) + channel = aio.insecure_channel(server_target) + return server, test_pb2_grpc.TestServiceStub(channel) + + +class TestServerInterceptor(AioTestBase): + + async def test_invalid_interceptor(self): + + class InvalidInterceptor: + """Just an invalid Interceptor""" + + with self.assertRaises(ValueError): + server_target, _ = await start_test_server( + interceptors=(InvalidInterceptor(),)) + + async def test_executed_right_order(self): + record = [] + server_target, _ = await start_test_server(interceptors=( + _LoggingInterceptor('log1', record), + _LoggingInterceptor('log2', record), + )) + + async with aio.insecure_channel(server_target) as channel: + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + # Check that all interceptors were executed, and were executed + # in the right order. + self.assertSequenceEqual([ + 'log1:intercept_service', + 'log2:intercept_service', + ], record) + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + async def test_response_ok(self): + record = [] + server_target, _ = await start_test_server( + interceptors=(_LoggingInterceptor('log1', record),)) + + async with aio.insecure_channel(server_target) as channel: + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + code = await call.code() + + self.assertSequenceEqual(['log1:intercept_service'], record) + self.assertIsInstance(response, messages_pb2.SimpleResponse) + self.assertEqual(code, grpc.StatusCode.OK) + + async def test_apply_different_interceptors_by_metadata(self): + record = [] + conditional_interceptor = _filter_server_interceptor( + lambda x: ('secret', '42') in x.invocation_metadata, + _LoggingInterceptor('log3', record)) + server_target, _ = await start_test_server(interceptors=( + _LoggingInterceptor('log1', record), + conditional_interceptor, + _LoggingInterceptor('log2', record), + )) + + async with aio.insecure_channel(server_target) as channel: + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + metadata = aio.Metadata(('key', 'value'),) + call = multicallable(messages_pb2.SimpleRequest(), + metadata=metadata) + await call + self.assertSequenceEqual([ + 'log1:intercept_service', + 'log2:intercept_service', + ], record) + + record.clear() + metadata = aio.Metadata(('key', 'value'), ('secret', '42')) + call = multicallable(messages_pb2.SimpleRequest(), + metadata=metadata) + await call + self.assertSequenceEqual([ + 'log1:intercept_service', + 'log3:intercept_service', + 'log2:intercept_service', + ], record) + + async def test_response_caching(self): + # Prepares a preset value to help testing + interceptor = _CacheInterceptor({ + 42: + messages_pb2.SimpleResponse(payload=messages_pb2.Payload( + body=b'\x42')) + }) + + # Constructs a server with the cache interceptor + server, stub = await _create_server_stub_pair(interceptor) + + # Tests if the cache store is used + response = await stub.UnaryCall( + messages_pb2.SimpleRequest(response_size=42)) + self.assertEqual(1, len(interceptor.cache_store[42].payload.body)) + self.assertEqual(interceptor.cache_store[42], response) + + # Tests response can be cached + response = await stub.UnaryCall( + messages_pb2.SimpleRequest(response_size=1337)) + self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body)) + self.assertEqual(interceptor.cache_store[1337], response) + response = await stub.UnaryCall( + messages_pb2.SimpleRequest(response_size=1337)) + self.assertEqual(interceptor.cache_store[1337], response) + + async def test_interceptor_unary_stream(self): + record = [] + server, stub = await _create_server_stub_pair( + _LoggingInterceptor('log_unary_stream', record)) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + + # Tests if the cache store is used + call = stub.StreamingOutputCall(request) + + # Ensures the RPC goes fine + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + self.assertSequenceEqual([ + 'log_unary_stream:intercept_service', + ], record) + + async def test_interceptor_stream_unary(self): + record = [] + server, stub = await _create_server_stub_pair( + _LoggingInterceptor('log_stream_unary', record)) + + # Invokes the actual RPC + call = stub.StreamingInputCall() + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + # Sends out requests + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + await call.done_writing() + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + self.assertSequenceEqual([ + 'log_stream_unary:intercept_service', + ], record) + + async def test_interceptor_stream_stream(self): + record = [] + server, stub = await _create_server_stub_pair( + _LoggingInterceptor('log_stream_stream', record)) + + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + # Invokes the actual RPC + call = stub.StreamingInputCall(gen()) + + # Validates the responses + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + self.assertSequenceEqual([ + 'log_stream_stream:intercept_service', + ], record) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/server_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/server_test.py new file mode 100644 index 00000000000..61d1edd5231 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -0,0 +1,486 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import gc +import logging +import socket +import time +import unittest + +import grpc +from grpc.experimental import aio + +from tests.unit import resources +from tests.unit.framework.common import test_constants +from tests_aio.unit._test_base import AioTestBase + +_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary' +_BLOCK_FOREVER = '/test/BlockForever' +_BLOCK_BRIEFLY = '/test/BlockBriefly' +_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen' +_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter' +_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed' +_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen' +_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter' +_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed' +_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen' +_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter' +_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed' +_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod' +_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream' +_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary' +_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' +_NUM_STREAM_REQUESTS = 3 +_NUM_STREAM_RESPONSES = 5 + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self): + self._called = asyncio.get_event_loop().create_future() + self._routing_table = { + _SIMPLE_UNARY_UNARY: + grpc.unary_unary_rpc_method_handler(self._unary_unary), + _BLOCK_FOREVER: + grpc.unary_unary_rpc_method_handler(self._block_forever), + _BLOCK_BRIEFLY: + grpc.unary_unary_rpc_method_handler(self._block_briefly), + _UNARY_STREAM_ASYNC_GEN: + grpc.unary_stream_rpc_method_handler( + self._unary_stream_async_gen), + _UNARY_STREAM_READER_WRITER: + grpc.unary_stream_rpc_method_handler( + self._unary_stream_reader_writer), + _UNARY_STREAM_EVILLY_MIXED: + grpc.unary_stream_rpc_method_handler( + self._unary_stream_evilly_mixed), + _STREAM_UNARY_ASYNC_GEN: + grpc.stream_unary_rpc_method_handler( + self._stream_unary_async_gen), + _STREAM_UNARY_READER_WRITER: + grpc.stream_unary_rpc_method_handler( + self._stream_unary_reader_writer), + _STREAM_UNARY_EVILLY_MIXED: + grpc.stream_unary_rpc_method_handler( + self._stream_unary_evilly_mixed), + _STREAM_STREAM_ASYNC_GEN: + grpc.stream_stream_rpc_method_handler( + self._stream_stream_async_gen), + _STREAM_STREAM_READER_WRITER: + grpc.stream_stream_rpc_method_handler( + self._stream_stream_reader_writer), + _STREAM_STREAM_EVILLY_MIXED: + grpc.stream_stream_rpc_method_handler( + self._stream_stream_evilly_mixed), + _ERROR_IN_STREAM_STREAM: + grpc.stream_stream_rpc_method_handler( + self._error_in_stream_stream), + _ERROR_WITHOUT_RAISE_IN_UNARY_UNARY: + grpc.unary_unary_rpc_method_handler( + self._error_without_raise_in_unary_unary), + _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM: + grpc.stream_stream_rpc_method_handler( + self._error_without_raise_in_stream_stream), + } + + @staticmethod + async def _unary_unary(unused_request, unused_context): + return _RESPONSE + + async def _block_forever(self, unused_request, unused_context): + await asyncio.get_event_loop().create_future() + + async def _block_briefly(self, unused_request, unused_context): + await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2) + return _RESPONSE + + async def _unary_stream_async_gen(self, unused_request, unused_context): + for _ in range(_NUM_STREAM_RESPONSES): + yield _RESPONSE + + async def _unary_stream_reader_writer(self, unused_request, context): + for _ in range(_NUM_STREAM_RESPONSES): + await context.write(_RESPONSE) + + async def _unary_stream_evilly_mixed(self, unused_request, context): + yield _RESPONSE + for _ in range(_NUM_STREAM_RESPONSES - 1): + await context.write(_RESPONSE) + + async def _stream_unary_async_gen(self, request_iterator, unused_context): + request_count = 0 + async for request in request_iterator: + assert _REQUEST == request + request_count += 1 + assert _NUM_STREAM_REQUESTS == request_count + return _RESPONSE + + async def _stream_unary_reader_writer(self, unused_request, context): + for _ in range(_NUM_STREAM_REQUESTS): + assert _REQUEST == await context.read() + return _RESPONSE + + async def _stream_unary_evilly_mixed(self, request_iterator, context): + assert _REQUEST == await context.read() + request_count = 0 + async for request in request_iterator: + assert _REQUEST == request + request_count += 1 + assert _NUM_STREAM_REQUESTS - 1 == request_count + return _RESPONSE + + async def _stream_stream_async_gen(self, request_iterator, unused_context): + request_count = 0 + async for request in request_iterator: + assert _REQUEST == request + request_count += 1 + assert _NUM_STREAM_REQUESTS == request_count + + for _ in range(_NUM_STREAM_RESPONSES): + yield _RESPONSE + + async def _stream_stream_reader_writer(self, unused_request, context): + for _ in range(_NUM_STREAM_REQUESTS): + assert _REQUEST == await context.read() + for _ in range(_NUM_STREAM_RESPONSES): + await context.write(_RESPONSE) + + async def _stream_stream_evilly_mixed(self, request_iterator, context): + assert _REQUEST == await context.read() + request_count = 0 + async for request in request_iterator: + assert _REQUEST == request + request_count += 1 + assert _NUM_STREAM_REQUESTS - 1 == request_count + + yield _RESPONSE + for _ in range(_NUM_STREAM_RESPONSES - 1): + await context.write(_RESPONSE) + + async def _error_in_stream_stream(self, request_iterator, unused_context): + async for request in request_iterator: + assert _REQUEST == request + raise RuntimeError('A testing RuntimeError!') + yield _RESPONSE + + async def _error_without_raise_in_unary_unary(self, request, context): + assert _REQUEST == request + context.set_code(grpc.StatusCode.INTERNAL) + + async def _error_without_raise_in_stream_stream(self, request_iterator, + context): + async for request in request_iterator: + assert _REQUEST == request + context.set_code(grpc.StatusCode.INTERNAL) + + def service(self, handler_details): + self._called.set_result(None) + return self._routing_table.get(handler_details.method) + + async def wait_for_call(self): + await self._called + + +async def _start_test_server(): + server = aio.server() + port = server.add_insecure_port('[::]:0') + generic_handler = _GenericHandler() + server.add_generic_rpc_handlers((generic_handler,)) + await server.start() + return 'localhost:%d' % port, server, generic_handler + + +class TestServer(AioTestBase): + + async def setUp(self): + addr, self._server, self._generic_handler = await _start_test_server() + self._channel = aio.insecure_channel(addr) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_unary_unary(self): + unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY) + response = await unary_unary_call(_REQUEST) + self.assertEqual(response, _RESPONSE) + + async def test_unary_stream_async_generator(self): + unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN) + call = unary_stream_call(_REQUEST) + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertEqual(_RESPONSE, response) + + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_unary_stream_reader_writer(self): + unary_stream_call = self._channel.unary_stream( + _UNARY_STREAM_READER_WRITER) + call = unary_stream_call(_REQUEST) + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_unary_stream_evilly_mixed(self): + unary_stream_call = self._channel.unary_stream( + _UNARY_STREAM_EVILLY_MIXED) + call = unary_stream_call(_REQUEST) + + # Uses reader API + self.assertEqual(_RESPONSE, await call.read()) + + # Uses async generator API, mixed! + with self.assertRaises(aio.UsageError): + async for response in call: + self.assertEqual(_RESPONSE, response) + + async def test_stream_unary_async_generator(self): + stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN) + call = stream_unary_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + response = await call + self.assertEqual(_RESPONSE, response) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_unary_reader_writer(self): + stream_unary_call = self._channel.stream_unary( + _STREAM_UNARY_READER_WRITER) + call = stream_unary_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + response = await call + self.assertEqual(_RESPONSE, response) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_unary_evilly_mixed(self): + stream_unary_call = self._channel.stream_unary( + _STREAM_UNARY_EVILLY_MIXED) + call = stream_unary_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + response = await call + self.assertEqual(_RESPONSE, response) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_stream_async_generator(self): + stream_stream_call = self._channel.stream_stream( + _STREAM_STREAM_ASYNC_GEN) + call = stream_stream_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_stream_reader_writer(self): + stream_stream_call = self._channel.stream_stream( + _STREAM_STREAM_READER_WRITER) + call = stream_stream_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_stream_evilly_mixed(self): + stream_stream_call = self._channel.stream_stream( + _STREAM_STREAM_EVILLY_MIXED) + call = stream_stream_call() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + for _ in range(_NUM_STREAM_RESPONSES): + response = await call.read() + self.assertEqual(_RESPONSE, response) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_shutdown(self): + await self._server.stop(None) + # Ensures no SIGSEGV triggered, and ends within timeout. + + async def test_shutdown_after_call(self): + await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) + + await self._server.stop(None) + + async def test_graceful_shutdown_success(self): + call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) + await self._generic_handler.wait_for_call() + + shutdown_start_time = time.time() + await self._server.stop(test_constants.SHORT_TIMEOUT) + grace_period_length = time.time() - shutdown_start_time + self.assertGreater(grace_period_length, + test_constants.SHORT_TIMEOUT / 3) + + # Validates the states. + self.assertEqual(_RESPONSE, await call) + self.assertTrue(call.done()) + + async def test_graceful_shutdown_failed(self): + call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) + await self._generic_handler.wait_for_call() + + await self._server.stop(test_constants.SHORT_TIMEOUT) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + async def test_concurrent_graceful_shutdown(self): + call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) + await self._generic_handler.wait_for_call() + + # Expects the shortest grace period to be effective. + shutdown_start_time = time.time() + await asyncio.gather( + self._server.stop(test_constants.LONG_TIMEOUT), + self._server.stop(test_constants.SHORT_TIMEOUT), + self._server.stop(test_constants.LONG_TIMEOUT), + ) + grace_period_length = time.time() - shutdown_start_time + self.assertGreater(grace_period_length, + test_constants.SHORT_TIMEOUT / 3) + + self.assertEqual(_RESPONSE, await call) + self.assertTrue(call.done()) + + async def test_concurrent_graceful_shutdown_immediate(self): + call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST) + await self._generic_handler.wait_for_call() + + # Expects no grace period, due to the "server.stop(None)". + await asyncio.gather( + self._server.stop(test_constants.LONG_TIMEOUT), + self._server.stop(None), + self._server.stop(test_constants.SHORT_TIMEOUT), + self._server.stop(test_constants.LONG_TIMEOUT), + ) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + async def test_shutdown_before_call(self): + await self._server.stop(None) + + # Ensures the server is cleaned up at this point. + # Some proper exception should be raised. + with self.assertRaises(aio.AioRpcError): + await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST) + + async def test_unimplemented(self): + call = self._channel.unary_unary(_UNIMPLEMENTED_METHOD) + with self.assertRaises(aio.AioRpcError) as exception_context: + await call(_REQUEST) + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) + + async def test_shutdown_during_stream_stream(self): + stream_stream_call = self._channel.stream_stream( + _STREAM_STREAM_ASYNC_GEN) + call = stream_stream_call() + + # Don't half close the RPC yet, keep it alive. + await call.write(_REQUEST) + await self._server.stop(None) + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + # No segfault + + async def test_error_in_stream_stream(self): + stream_stream_call = self._channel.stream_stream( + _ERROR_IN_STREAM_STREAM) + call = stream_stream_call() + + # Don't half close the RPC yet, keep it alive. + await call.write(_REQUEST) + + # Don't segfault here + self.assertEqual(grpc.StatusCode.UNKNOWN, await call.code()) + + async def test_error_without_raise_in_unary_unary(self): + call = self._channel.unary_unary(_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY)( + _REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.INTERNAL, rpc_error.code()) + + async def test_error_without_raise_in_stream_stream(self): + call = self._channel.stream_stream( + _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM)() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(_REQUEST) + await call.done_writing() + + self.assertEqual(grpc.StatusCode.INTERNAL, await call.code()) + + async def test_port_binding_exception(self): + server = aio.server(options=(('grpc.so_reuseport', 0),)) + port = server.add_insecure_port('localhost:0') + bind_address = "localhost:%d" % port + + with self.assertRaises(RuntimeError): + server.add_insecure_port(bind_address) + + server_credentials = grpc.ssl_server_credentials([ + (resources.private_key(), resources.certificate_chain()) + ]) + with self.assertRaises(RuntimeError): + server.add_secure_port(bind_address, server_credentials) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/timeout_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/timeout_test.py new file mode 100644 index 00000000000..b5bcc027ec1 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/timeout_test.py @@ -0,0 +1,178 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior of the timeout mechanism on client side.""" + +import asyncio +import logging +import platform +import random +import unittest +import datetime + +import grpc +from grpc.experimental import aio + +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit import _common + +_SLEEP_TIME_UNIT_S = datetime.timedelta(seconds=1).total_seconds() + +_TEST_SLEEPY_UNARY_UNARY = '/test/Test/SleepyUnaryUnary' +_TEST_SLEEPY_UNARY_STREAM = '/test/Test/SleepyUnaryStream' +_TEST_SLEEPY_STREAM_UNARY = '/test/Test/SleepyStreamUnary' +_TEST_SLEEPY_STREAM_STREAM = '/test/Test/SleepyStreamStream' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + + +async def _test_sleepy_unary_unary(unused_request, unused_context): + await asyncio.sleep(_SLEEP_TIME_UNIT_S) + return _RESPONSE + + +async def _test_sleepy_unary_stream(unused_request, unused_context): + yield _RESPONSE + await asyncio.sleep(_SLEEP_TIME_UNIT_S) + yield _RESPONSE + + +async def _test_sleepy_stream_unary(unused_request_iterator, context): + assert _REQUEST == await context.read() + await asyncio.sleep(_SLEEP_TIME_UNIT_S) + assert _REQUEST == await context.read() + return _RESPONSE + + +async def _test_sleepy_stream_stream(unused_request_iterator, context): + assert _REQUEST == await context.read() + await asyncio.sleep(_SLEEP_TIME_UNIT_S) + await context.write(_RESPONSE) + + +_ROUTING_TABLE = { + _TEST_SLEEPY_UNARY_UNARY: + grpc.unary_unary_rpc_method_handler(_test_sleepy_unary_unary), + _TEST_SLEEPY_UNARY_STREAM: + grpc.unary_stream_rpc_method_handler(_test_sleepy_unary_stream), + _TEST_SLEEPY_STREAM_UNARY: + grpc.stream_unary_rpc_method_handler(_test_sleepy_stream_unary), + _TEST_SLEEPY_STREAM_STREAM: + grpc.stream_stream_rpc_method_handler(_test_sleepy_stream_stream) +} + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + return _ROUTING_TABLE.get(handler_call_details.method) + + +async def _start_test_server(): + server = aio.server() + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((_GenericHandler(),)) + await server.start() + return f'localhost:{port}', server + + +class TestTimeout(AioTestBase): + + async def setUp(self): + address, self._server = await _start_test_server() + self._client = aio.insecure_channel(address) + self.assertEqual(grpc.ChannelConnectivity.IDLE, + self._client.get_state(True)) + await _common.block_until_certain_state(self._client, + grpc.ChannelConnectivity.READY) + + async def tearDown(self): + await self._client.close() + await self._server.stop(None) + + async def test_unary_unary_success_with_timeout(self): + multicallable = self._client.unary_unary(_TEST_SLEEPY_UNARY_UNARY) + call = multicallable(_REQUEST, timeout=2 * _SLEEP_TIME_UNIT_S) + self.assertEqual(_RESPONSE, await call) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_unary_unary_deadline_exceeded(self): + multicallable = self._client.unary_unary(_TEST_SLEEPY_UNARY_UNARY) + call = multicallable(_REQUEST, timeout=0.5 * _SLEEP_TIME_UNIT_S) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code()) + + async def test_unary_stream_success_with_timeout(self): + multicallable = self._client.unary_stream(_TEST_SLEEPY_UNARY_STREAM) + call = multicallable(_REQUEST, timeout=2 * _SLEEP_TIME_UNIT_S) + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_unary_stream_deadline_exceeded(self): + multicallable = self._client.unary_stream(_TEST_SLEEPY_UNARY_STREAM) + call = multicallable(_REQUEST, timeout=0.5 * _SLEEP_TIME_UNIT_S) + self.assertEqual(_RESPONSE, await call.read()) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.read() + + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code()) + + async def test_stream_unary_success_with_timeout(self): + multicallable = self._client.stream_unary(_TEST_SLEEPY_STREAM_UNARY) + call = multicallable(timeout=2 * _SLEEP_TIME_UNIT_S) + await call.write(_REQUEST) + await call.write(_REQUEST) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_stream_unary_deadline_exceeded(self): + multicallable = self._client.stream_unary(_TEST_SLEEPY_STREAM_UNARY) + call = multicallable(timeout=0.5 * _SLEEP_TIME_UNIT_S) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.write(_REQUEST) + await call.write(_REQUEST) + await call + + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code()) + + async def test_stream_stream_success_with_timeout(self): + multicallable = self._client.stream_stream(_TEST_SLEEPY_STREAM_STREAM) + call = multicallable(timeout=2 * _SLEEP_TIME_UNIT_S) + await call.write(_REQUEST) + self.assertEqual(_RESPONSE, await call.read()) + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_stream_stream_deadline_exceeded(self): + multicallable = self._client.stream_stream(_TEST_SLEEPY_STREAM_STREAM) + call = multicallable(timeout=0.5 * _SLEEP_TIME_UNIT_S) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.write(_REQUEST) + await call.read() + + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py new file mode 100644 index 00000000000..cb6f7985290 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py @@ -0,0 +1,159 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior of the wait for connection API on client side.""" + +import asyncio +import logging +import unittest +import datetime +from typing import Callable, Tuple + +import grpc +from grpc.experimental import aio + +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit import _common +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._constants import UNREACHABLE_TARGET + +_REQUEST = b'\x01\x02\x03' +_TEST_METHOD = '/test/Test' + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + + +class TestWaitForConnection(AioTestBase): + """Tests if wait_for_connection raises connectivity issue.""" + + async def setUp(self): + address, self._server = await start_test_server() + self._channel = aio.insecure_channel(address) + self._dummy_channel = aio.insecure_channel(UNREACHABLE_TARGET) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._dummy_channel.close() + await self._channel.close() + await self._server.stop(None) + + async def test_unary_unary_ok(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + response = await call + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + async def test_unary_stream_ok(self): + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + call = self._stub.StreamingOutputCall(request) + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_unary_ok(self): + call = self._stub.StreamingInputCall() + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + await call.done_writing() + + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_stream_stream_ok(self): + call = self._stub.FullDuplexCall() + + # No exception raised and no message swallowed. + await call.wait_for_connection() + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_unary_unary_error(self): + call = self._dummy_channel.unary_unary(_TEST_METHOD)(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + async def test_unary_stream_error(self): + call = self._dummy_channel.unary_stream(_TEST_METHOD)(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + async def test_stream_unary_error(self): + call = self._dummy_channel.stream_unary(_TEST_METHOD)() + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + async def test_stream_stream_error(self): + call = self._dummy_channel.stream_stream(_TEST_METHOD)() + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call.wait_for_connection() + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py new file mode 100644 index 00000000000..5bcfd54856b --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py @@ -0,0 +1,146 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing the done callbacks mechanism.""" + +import asyncio +import logging +import unittest +import time +import gc + +import grpc +from grpc.experimental import aio +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import get_socket +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit import _common + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + + +async def _perform_unary_unary(stub, wait_for_ready): + await stub.UnaryCall(messages_pb2.SimpleRequest(), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +async def _perform_unary_stream(stub, wait_for_ready): + request = messages_pb2.StreamingOutputCallRequest() + for _ in range(_NUM_STREAM_RESPONSES): + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + call = stub.StreamingOutputCall(request, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.read() + assert await call.code() == grpc.StatusCode.OK + + +async def _perform_stream_unary(stub, wait_for_ready): + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def gen(): + for _ in range(_NUM_STREAM_RESPONSES): + yield request + + await stub.StreamingInputCall(gen(), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + +async def _perform_stream_stream(stub, wait_for_ready): + call = stub.FullDuplexCall(timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready) + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + assert _RESPONSE_PAYLOAD_SIZE == len(response.payload.body) + + await call.done_writing() + assert await call.code() == grpc.StatusCode.OK + + +_RPC_ACTIONS = ( + _perform_unary_unary, + _perform_unary_stream, + _perform_stream_unary, + _perform_stream_stream, +) + + +class TestWaitForReady(AioTestBase): + + async def setUp(self): + address, self._port, self._socket = get_socket(listen=False) + self._channel = aio.insecure_channel(f"{address}:{self._port}") + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + self._socket.close() + + async def tearDown(self): + await self._channel.close() + + async def _connection_fails_fast(self, wait_for_ready): + for action in _RPC_ACTIONS: + with self.subTest(name=action): + with self.assertRaises(aio.AioRpcError) as exception_context: + await action(self._stub, wait_for_ready) + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) + + async def test_call_wait_for_ready_default(self): + """RPC should fail immediately after connection failed.""" + await self._connection_fails_fast(None) + + async def test_call_wait_for_ready_disabled(self): + """RPC should fail immediately after connection failed.""" + await self._connection_fails_fast(False) + + async def test_call_wait_for_ready_enabled(self): + """RPC will wait until the connection is ready.""" + for action in _RPC_ACTIONS: + with self.subTest(name=action.__name__): + # Starts the RPC + action_task = self.loop.create_task(action(self._stub, True)) + + # Wait for TRANSIENT_FAILURE, and RPC is not aborting + await _common.block_until_certain_state( + self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE) + + try: + # Start the server + _, server = await start_test_server(port=self._port) + + # The RPC should recover itself + await action_task + finally: + if server is not None: + await server.stop(None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/__init__.py new file mode 100644 index 00000000000..6732ae8cbb5 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2020 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +from tests import _loader +from tests import _runner + +Loader = _loader.Loader +Runner = _runner.Runner diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py new file mode 100644 index 00000000000..21277a98cf2 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py @@ -0,0 +1,348 @@ +# Copyright 2020 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import signal +import threading +import time +import sys + +from typing import DefaultDict, Dict, List, Mapping, Set, Sequence, Tuple +import collections + +from concurrent import futures + +import grpc + +from src.proto.grpc.testing import test_pb2 +from src.proto.grpc.testing import test_pb2_grpc +from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import empty_pb2 + +logger = logging.getLogger() +console_handler = logging.StreamHandler() +formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s') +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + +_SUPPORTED_METHODS = ( + "UnaryCall", + "EmptyCall", +) + +PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]] + + +class _StatsWatcher: + _start: int + _end: int + _rpcs_needed: int + _rpcs_by_peer: DefaultDict[str, int] + _rpcs_by_method: DefaultDict[str, DefaultDict[str, int]] + _no_remote_peer: int + _lock: threading.Lock + _condition: threading.Condition + + def __init__(self, start: int, end: int): + self._start = start + self._end = end + self._rpcs_needed = end - start + self._rpcs_by_peer = collections.defaultdict(int) + self._rpcs_by_method = collections.defaultdict( + lambda: collections.defaultdict(int)) + self._condition = threading.Condition() + self._no_remote_peer = 0 + + def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None: + """Records statistics for a single RPC.""" + if self._start <= request_id < self._end: + with self._condition: + if not peer: + self._no_remote_peer += 1 + else: + self._rpcs_by_peer[peer] += 1 + self._rpcs_by_method[method][peer] += 1 + self._rpcs_needed -= 1 + self._condition.notify() + + def await_rpc_stats_response(self, timeout_sec: int + ) -> messages_pb2.LoadBalancerStatsResponse: + """Blocks until a full response has been collected.""" + with self._condition: + self._condition.wait_for(lambda: not self._rpcs_needed, + timeout=float(timeout_sec)) + response = messages_pb2.LoadBalancerStatsResponse() + for peer, count in self._rpcs_by_peer.items(): + response.rpcs_by_peer[peer] = count + for method, count_by_peer in self._rpcs_by_method.items(): + for peer, count in count_by_peer.items(): + response.rpcs_by_method[method].rpcs_by_peer[peer] = count + response.num_failures = self._no_remote_peer + self._rpcs_needed + return response + + +_global_lock = threading.Lock() +_stop_event = threading.Event() +_global_rpc_id: int = 0 +_watchers: Set[_StatsWatcher] = set() +_global_server = None + + +def _handle_sigint(sig, frame): + _stop_event.set() + _global_server.stop(None) + + +class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer + ): + + def __init__(self): + super(_LoadBalancerStatsServicer).__init__() + + def GetClientStats(self, request: messages_pb2.LoadBalancerStatsRequest, + context: grpc.ServicerContext + ) -> messages_pb2.LoadBalancerStatsResponse: + logger.info("Received stats request.") + start = None + end = None + watcher = None + with _global_lock: + start = _global_rpc_id + 1 + end = start + request.num_rpcs + watcher = _StatsWatcher(start, end) + _watchers.add(watcher) + response = watcher.await_rpc_stats_response(request.timeout_sec) + with _global_lock: + _watchers.remove(watcher) + logger.info("Returning stats response: {}".format(response)) + return response + + +def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]], + request_id: int, stub: test_pb2_grpc.TestServiceStub, + timeout: float, + futures: Mapping[int, Tuple[grpc.Future, str]]) -> None: + logger.info(f"Sending {method} request to backend: {request_id}") + if method == "UnaryCall": + future = stub.UnaryCall.future(messages_pb2.SimpleRequest(), + metadata=metadata, + timeout=timeout) + elif method == "EmptyCall": + future = stub.EmptyCall.future(empty_pb2.Empty(), + metadata=metadata, + timeout=timeout) + else: + raise ValueError(f"Unrecognized method '{method}'.") + futures[request_id] = (future, method) + + +def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str, + print_response: bool) -> None: + exception = future.exception() + hostname = "" + if exception is not None: + if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + logger.error(f"RPC {rpc_id} timed out") + else: + logger.error(exception) + else: + response = future.result() + hostname = None + for metadatum in future.initial_metadata(): + if metadatum[0] == "hostname": + hostname = metadatum[1] + break + else: + hostname = response.hostname + if print_response: + if future.code() == grpc.StatusCode.OK: + logger.info("Successful response.") + else: + logger.info(f"RPC failed: {call}") + with _global_lock: + for watcher in _watchers: + watcher.on_rpc_complete(rpc_id, hostname, method) + + +def _remove_completed_rpcs(futures: Mapping[int, grpc.Future], + print_response: bool) -> None: + logger.debug("Removing completed RPCs") + done = [] + for future_id, (future, method) in futures.items(): + if future.done(): + _on_rpc_done(future_id, future, method, args.print_response) + done.append(future_id) + for rpc_id in done: + del futures[rpc_id] + + +def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None: + logger.info("Cancelling all remaining RPCs") + for future, _ in futures.values(): + future.cancel() + + +def _run_single_channel(method: str, metadata: Sequence[Tuple[str, str]], + qps: int, server: str, rpc_timeout_sec: int, + print_response: bool): + global _global_rpc_id # pylint: disable=global-statement + duration_per_query = 1.0 / float(qps) + with grpc.insecure_channel(server) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + futures: Dict[int, Tuple[grpc.Future, str]] = {} + while not _stop_event.is_set(): + request_id = None + with _global_lock: + request_id = _global_rpc_id + _global_rpc_id += 1 + start = time.time() + end = start + duration_per_query + _start_rpc(method, metadata, request_id, stub, + float(rpc_timeout_sec), futures) + _remove_completed_rpcs(futures, print_response) + logger.debug(f"Currently {len(futures)} in-flight RPCs") + now = time.time() + while now < end: + time.sleep(end - now) + now = time.time() + _cancel_all_rpcs(futures) + + +class _MethodHandle: + """An object grouping together threads driving RPCs for a method.""" + + _channel_threads: List[threading.Thread] + + def __init__(self, method: str, metadata: Sequence[Tuple[str, str]], + num_channels: int, qps: int, server: str, rpc_timeout_sec: int, + print_response: bool): + """Creates and starts a group of threads running the indicated method.""" + self._channel_threads = [] + for i in range(num_channels): + thread = threading.Thread(target=_run_single_channel, + args=( + method, + metadata, + qps, + server, + rpc_timeout_sec, + print_response, + )) + thread.start() + self._channel_threads.append(thread) + + def stop(self): + """Joins all threads referenced by the handle.""" + for channel_thread in self._channel_threads: + channel_thread.join() + + +def _run(args: argparse.Namespace, methods: Sequence[str], + per_method_metadata: PerMethodMetadataType) -> None: + logger.info("Starting python xDS Interop Client.") + global _global_server # pylint: disable=global-statement + method_handles = [] + for method in methods: + method_handles.append( + _MethodHandle(method, per_method_metadata.get(method, []), + args.num_channels, args.qps, args.server, + args.rpc_timeout_sec, args.print_response)) + _global_server = grpc.server(futures.ThreadPoolExecutor()) + _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}") + test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server( + _LoadBalancerStatsServicer(), _global_server) + _global_server.start() + _global_server.wait_for_termination() + for method_handle in method_handles: + method_handle.stop() + + +def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType: + metadata = metadata_arg.split(",") if args.metadata else [] + per_method_metadata = collections.defaultdict(list) + for metadatum in metadata: + elems = metadatum.split(":") + if len(elems) != 3: + raise ValueError( + f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'") + if elems[0] not in _SUPPORTED_METHODS: + raise ValueError(f"Unrecognized method '{elems[0]}'") + per_method_metadata[elems[0]].append((elems[1], elems[2])) + return per_method_metadata + + +def parse_rpc_arg(rpc_arg: str) -> Sequence[str]: + methods = rpc_arg.split(",") + if set(methods) - set(_SUPPORTED_METHODS): + raise ValueError("--rpc supported methods: {}".format( + ", ".join(_SUPPORTED_METHODS))) + return methods + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Run Python XDS interop client.') + parser.add_argument( + "--num_channels", + default=1, + type=int, + help="The number of channels from which to send requests.") + parser.add_argument("--print_response", + default=False, + action="store_true", + help="Write RPC response to STDOUT.") + parser.add_argument( + "--qps", + default=1, + type=int, + help="The number of queries to send from each channel per second.") + parser.add_argument("--rpc_timeout_sec", + default=30, + type=int, + help="The per-RPC timeout in seconds.") + parser.add_argument("--server", + default="localhost:50051", + help="The address of the server.") + parser.add_argument( + "--stats_port", + default=50052, + type=int, + help="The port on which to expose the peer distribution stats service.") + parser.add_argument('--verbose', + help='verbose log output', + default=False, + action='store_true') + parser.add_argument("--log_file", + default=None, + type=str, + help="A file to log to.") + rpc_help = "A comma-delimited list of RPC methods to run. Must be one of " + rpc_help += ", ".join(_SUPPORTED_METHODS) + rpc_help += "." + parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help) + metadata_help = ( + "A comma-delimited list of 3-tuples of the form " + + "METHOD:KEY:VALUE, e.g. " + + "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3") + parser.add_argument("--metadata", default="", type=str, help=metadata_help) + args = parser.parse_args() + signal.signal(signal.SIGINT, _handle_sigint) + if args.verbose: + logger.setLevel(logging.DEBUG) + if args.log_file: + file_handler = logging.FileHandler(args.log_file, mode='a') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + _run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata)) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/__init__.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/__init__.py new file mode 100644 index 00000000000..f4b321fc5b2 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py new file mode 100644 index 00000000000..3b3f12fa1f9 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py @@ -0,0 +1,98 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A smoke test for memory leaks on short-lived channels without close. + +This test doesn't guarantee all resources are cleaned if `Channel.close` is not +explicitly invoked. The recommended way of using Channel object is using `with` +clause, and let context manager automatically close the channel. +""" + +import logging +import os +import resource +import sys +import unittest +from concurrent.futures import ThreadPoolExecutor + +import grpc + +_TEST_METHOD = '/test/Test' +_REQUEST = b'\x23\x33' +_LARGE_NUM_OF_ITERATIONS = 5000 + +# If MAX_RSS inflated more than this size, the test is failed. +_FAIL_THRESHOLD = 25 * 1024 * 1024 # 25 MiB + + +def _get_max_rss(): + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + +def _pretty_print_bytes(x): + if x > 1024 * 1024 * 1024: + return "%.2f GiB" % (x / 1024.0 / 1024 / 1024) + elif x > 1024 * 1024: + return "%.2f MiB" % (x / 1024.0 / 1024) + elif x > 1024: + return "%.2f KiB" % (x / 1024.0) + else: + return "%d B" % x + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _TEST_METHOD: + return grpc.unary_unary_rpc_method_handler(lambda x, _: x) + + +def _start_a_test_server(): + server = grpc.server(ThreadPoolExecutor(max_workers=1), + options=(('grpc.so_reuseport', 0),)) + server.add_generic_rpc_handlers((_GenericHandler(),)) + port = server.add_insecure_port('localhost:0') + server.start() + return 'localhost:%d' % port, server + + +def _perform_an_rpc(address): + channel = grpc.insecure_channel(address) + multicallable = channel.unary_unary(_TEST_METHOD) + response = multicallable(_REQUEST) + assert _REQUEST == response + + +class TestLeak(unittest.TestCase): + + def test_leak_with_single_shot_rpcs(self): + address, server = _start_a_test_server() + + # Records memory before experiment. + before = _get_max_rss() + + # Amplifies the leak. + for n in range(_LARGE_NUM_OF_ITERATIONS): + _perform_an_rpc(address) + + # Fails the test if memory leak detected. + diff = _get_max_rss() - before + if diff > _FAIL_THRESHOLD: + self.fail("Max RSS inflated {} > {}".format( + _pretty_print_bytes(diff), + _pretty_print_bytes(_FAIL_THRESHOLD))) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py new file mode 100644 index 00000000000..08d5a882eb9 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py @@ -0,0 +1,415 @@ +# Copyright 2020 The gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Simple Stubs.""" + +# TODO(https://github.com/grpc/grpc/issues/21965): Run under setuptools. + +import os + +_MAXIMUM_CHANNELS = 10 + +_DEFAULT_TIMEOUT = 1.0 + +os.environ["GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"] = "2" +os.environ["GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM"] = str(_MAXIMUM_CHANNELS) +os.environ["GRPC_PYTHON_DEFAULT_TIMEOUT_SECONDS"] = str(_DEFAULT_TIMEOUT) + +import contextlib +import datetime +import inspect +import logging +import threading +import unittest +import sys +import time +from typing import Callable, Optional + +from tests.unit import test_common +from tests.unit.framework.common import get_socket +from tests.unit import resources +import grpc +import grpc.experimental + +_REQUEST = b"0000" + +_CACHE_EPOCHS = 8 +_CACHE_TRIALS = 6 + +_SERVER_RESPONSE_COUNT = 10 +_CLIENT_REQUEST_COUNT = _SERVER_RESPONSE_COUNT + +_STRESS_EPOCHS = _MAXIMUM_CHANNELS * 10 + +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" +_BLACK_HOLE = "/test/BlackHole" + + +def _env(key: str, value: str): + os.environ[key] = value + yield + del os.environ[key] + + +def _unary_unary_handler(request, context): + return request + + +def _unary_stream_handler(request, context): + for _ in range(_SERVER_RESPONSE_COUNT): + yield request + + +def _stream_unary_handler(request_iterator, context): + request = None + for single_request in request_iterator: + request = single_request + return request + + +def _stream_stream_handler(request_iterator, context): + for request in request_iterator: + yield request + + +def _black_hole_handler(request, context): + event = threading.Event() + + def _on_done(): + event.set() + + context.add_callback(_on_done) + while not event.is_set(): + time.sleep(0.1) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) + elif handler_call_details.method == _UNARY_STREAM: + return grpc.unary_stream_rpc_method_handler(_unary_stream_handler) + elif handler_call_details.method == _STREAM_UNARY: + return grpc.stream_unary_rpc_method_handler(_stream_unary_handler) + elif handler_call_details.method == _STREAM_STREAM: + return grpc.stream_stream_rpc_method_handler(_stream_stream_handler) + elif handler_call_details.method == _BLACK_HOLE: + return grpc.unary_unary_rpc_method_handler(_black_hole_handler) + else: + raise NotImplementedError() + + +def _time_invocation(to_time: Callable[[], None]) -> datetime.timedelta: + start = datetime.datetime.now() + to_time() + return datetime.datetime.now() - start + + +def _server(credentials: Optional[grpc.ServerCredentials]): + try: + server = test_common.test_server() + target = '[::]:0' + if credentials is None: + port = server.add_insecure_port(target) + else: + port = server.add_secure_port(target, credentials) + server.add_generic_rpc_handlers((_GenericHandler(),)) + server.start() + yield port + finally: + server.stop(None) + + +class SimpleStubsTest(unittest.TestCase): + + def assert_cached(self, to_check: Callable[[str], None]) -> None: + """Asserts that a function caches intermediate data/state. + + To be specific, given a function whose caching behavior is + deterministic in the value of a supplied string, this function asserts + that, on average, subsequent invocations of the function for a specific + string are faster than first invocations with that same string. + + Args: + to_check: A function returning nothing, that caches values based on + an arbitrary supplied string. + """ + initial_runs = [] + cached_runs = [] + for epoch in range(_CACHE_EPOCHS): + runs = [] + text = str(epoch) + for trial in range(_CACHE_TRIALS): + runs.append(_time_invocation(lambda: to_check(text))) + initial_runs.append(runs[0]) + cached_runs.extend(runs[1:]) + average_cold = sum((run for run in initial_runs), + datetime.timedelta()) / len(initial_runs) + average_warm = sum((run for run in cached_runs), + datetime.timedelta()) / len(cached_runs) + self.assertLess(average_warm, average_cold) + + def assert_eventually(self, + predicate: Callable[[], bool], + *, + timeout: Optional[datetime.timedelta] = None, + message: Optional[Callable[[], str]] = None) -> None: + message = message or (lambda: "Proposition did not evaluate to true") + timeout = timeout or datetime.timedelta(seconds=10) + end = datetime.datetime.now() + timeout + while datetime.datetime.now() < end: + if predicate(): + break + time.sleep(0.5) + else: + self.fail(message() + " after " + str(timeout)) + + def test_unary_unary_insecure(self): + with _server(None) as port: + target = f'localhost:{port}' + response = grpc.experimental.unary_unary( + _REQUEST, + target, + _UNARY_UNARY, + channel_credentials=grpc.experimental. + insecure_channel_credentials(), + timeout=None) + self.assertEqual(_REQUEST, response) + + def test_unary_unary_secure(self): + with _server(grpc.local_server_credentials()) as port: + target = f'localhost:{port}' + response = grpc.experimental.unary_unary( + _REQUEST, + target, + _UNARY_UNARY, + channel_credentials=grpc.local_channel_credentials(), + timeout=None) + self.assertEqual(_REQUEST, response) + + def test_channels_cached(self): + with _server(grpc.local_server_credentials()) as port: + target = f'localhost:{port}' + test_name = inspect.stack()[0][3] + args = (_REQUEST, target, _UNARY_UNARY) + kwargs = {"channel_credentials": grpc.local_channel_credentials()} + + def _invoke(seed: str): + run_kwargs = dict(kwargs) + run_kwargs["options"] = ((test_name + seed, ""),) + grpc.experimental.unary_unary(*args, **run_kwargs) + + self.assert_cached(_invoke) + + def test_channels_evicted(self): + with _server(grpc.local_server_credentials()) as port: + target = f'localhost:{port}' + response = grpc.experimental.unary_unary( + _REQUEST, + target, + _UNARY_UNARY, + channel_credentials=grpc.local_channel_credentials()) + self.assert_eventually( + lambda: grpc._simple_stubs.ChannelCache.get( + )._test_only_channel_count() == 0, + message=lambda: + f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain" + ) + + def test_total_channels_enforced(self): + with _server(grpc.local_server_credentials()) as port: + target = f'localhost:{port}' + for i in range(_STRESS_EPOCHS): + # Ensure we get a new channel each time. + options = (("foo", str(i)),) + # Send messages at full blast. + grpc.experimental.unary_unary( + _REQUEST, + target, + _UNARY_UNARY, + options=options, + channel_credentials=grpc.local_channel_credentials()) + self.assert_eventually( + lambda: grpc._simple_stubs.ChannelCache.get( + )._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1, + message=lambda: + f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain" + ) + + def test_unary_stream(self): + with _server(grpc.local_server_credentials()) as port: + target = f'localhost:{port}' + for response in grpc.experimental.unary_stream( + _REQUEST, + target, + _UNARY_STREAM, + channel_credentials=grpc.local_channel_credentials()): + self.assertEqual(_REQUEST, response) + + def test_stream_unary(self): + + def request_iter(): + for _ in range(_CLIENT_REQUEST_COUNT): + yield _REQUEST + + with _server(grpc.local_server_credentials()) as port: + target = f'localhost:{port}' + response = grpc.experimental.stream_unary( + request_iter(), + target, + _STREAM_UNARY, + channel_credentials=grpc.local_channel_credentials()) + self.assertEqual(_REQUEST, response) + + def test_stream_stream(self): + + def request_iter(): + for _ in range(_CLIENT_REQUEST_COUNT): + yield _REQUEST + + with _server(grpc.local_server_credentials()) as port: + target = f'localhost:{port}' + for response in grpc.experimental.stream_stream( + request_iter(), + target, + _STREAM_STREAM, + channel_credentials=grpc.local_channel_credentials()): + self.assertEqual(_REQUEST, response) + + def test_default_ssl(self): + _private_key = resources.private_key() + _certificate_chain = resources.certificate_chain() + _server_certs = ((_private_key, _certificate_chain),) + _server_host_override = 'foo.test.google.fr' + _test_root_certificates = resources.test_root_certificates() + _property_options = (( + 'grpc.ssl_target_name_override', + _server_host_override, + ),) + cert_dir = os.path.join(os.path.dirname(resources.__file__), + "credentials") + cert_file = os.path.join(cert_dir, "ca.pem") + with _env("GRPC_DEFAULT_SSL_ROOTS_FILE_PATH", cert_file): + server_creds = grpc.ssl_server_credentials(_server_certs) + with _server(server_creds) as port: + target = f'localhost:{port}' + response = grpc.experimental.unary_unary( + _REQUEST, target, _UNARY_UNARY, options=_property_options) + + def test_insecure_sugar(self): + with _server(None) as port: + target = f'localhost:{port}' + response = grpc.experimental.unary_unary(_REQUEST, + target, + _UNARY_UNARY, + insecure=True) + self.assertEqual(_REQUEST, response) + + def test_insecure_sugar_mutually_exclusive(self): + with _server(None) as port: + target = f'localhost:{port}' + with self.assertRaises(ValueError): + response = grpc.experimental.unary_unary( + _REQUEST, + target, + _UNARY_UNARY, + insecure=True, + channel_credentials=grpc.local_channel_credentials()) + + def test_default_wait_for_ready(self): + addr, port, sock = get_socket() + sock.close() + target = f'{addr}:{port}' + channel = grpc._simple_stubs.ChannelCache.get().get_channel( + target, (), None, True, None) + rpc_finished_event = threading.Event() + rpc_failed_event = threading.Event() + server = None + + def _on_connectivity_changed(connectivity): + nonlocal server + if connectivity is grpc.ChannelConnectivity.TRANSIENT_FAILURE: + self.assertFalse(rpc_finished_event.is_set()) + self.assertFalse(rpc_failed_event.is_set()) + server = test_common.test_server() + server.add_insecure_port(target) + server.add_generic_rpc_handlers((_GenericHandler(),)) + server.start() + channel.unsubscribe(_on_connectivity_changed) + elif connectivity in (grpc.ChannelConnectivity.IDLE, + grpc.ChannelConnectivity.CONNECTING): + pass + else: + self.fail("Encountered unknown state.") + + channel.subscribe(_on_connectivity_changed) + + def _send_rpc(): + try: + response = grpc.experimental.unary_unary(_REQUEST, + target, + _UNARY_UNARY, + timeout=None, + insecure=True) + rpc_finished_event.set() + except Exception as e: + rpc_failed_event.set() + + t = threading.Thread(target=_send_rpc) + t.start() + t.join() + self.assertFalse(rpc_failed_event.is_set()) + self.assertTrue(rpc_finished_event.is_set()) + if server is not None: + server.stop(None) + + def assert_times_out(self, invocation_args): + with _server(None) as port: + target = f'localhost:{port}' + with self.assertRaises(grpc.RpcError) as cm: + response = grpc.experimental.unary_unary(_REQUEST, + target, + _BLACK_HOLE, + insecure=True, + **invocation_args) + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, + cm.exception.code()) + + def test_default_timeout(self): + not_present = object() + wait_for_ready_values = [True, not_present] + timeout_values = [0.5, not_present] + cases = [] + for wait_for_ready in wait_for_ready_values: + for timeout in timeout_values: + case = {} + if timeout is not not_present: + case["timeout"] = timeout + if wait_for_ready is not not_present: + case["wait_for_ready"] = wait_for_ready + cases.append(case) + + for case in cases: + with self.subTest(**case): + self.assert_times_out(case) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + unittest.main(verbosity=2) diff --git a/contrib/libs/grpc/src/python/grpcio_tests/ya.make b/contrib/libs/grpc/src/python/grpcio_tests/ya.make new file mode 100644 index 00000000000..b0642eae345 --- /dev/null +++ b/contrib/libs/grpc/src/python/grpcio_tests/ya.make @@ -0,0 +1,141 @@ +PY3TEST() + +LICENSE(Apache-2.0) + +LICENSE_TEXTS(.yandex_meta/licenses.list.txt) + +PEERDIR( + contrib/libs/grpc/python +) + +NO_LINT() + +PY_SRCS( + TOP_LEVEL + # tests/_sanity/__init__.py + # tests/testing/proto/__init__.py + # tests/testing/__init__.py + # tests/testing/_application_common.py + # tests/testing/_application_testing_common.py + # tests/testing/_client_application.py + # tests/testing/_client_test.py + # tests/testing/_server_application.py + # tests/testing/_server_test.py + # tests/testing/_time_test.py + tests/unit/__init__.py + tests/unit/_cython/__init__.py + tests/unit/_cython/_common.py + tests/unit/_cython/test_utilities.py + tests/unit/_exit_scenarios.py + tests/unit/_from_grpc_import_star.py + tests/unit/_rpc_test_helpers.py + tests/unit/_server_shutdown_scenarios.py + tests/unit/_signal_client.py + tests/unit/_tcp_proxy.py + tests/unit/beta/__init__.py + tests/unit/beta/test_utilities.py + tests/unit/framework/__init__.py + tests/unit/framework/common/__init__.py + tests/unit/framework/common/test_constants.py + tests/unit/framework/common/test_control.py + tests/unit/framework/common/test_coverage.py + tests/unit/framework/foundation/__init__.py + tests/unit/resources.py + tests/unit/test_common.py + tests/unit/thread_pool.py + # protofiles + # tests/interop/__init__.py + # tests/interop/_intraop_test_case.py + # tests/interop/client.py + # tests/interop/methods.py + # tests/interop/resources.py + # tests/interop/server.py + # tests/interop/service.py + # protofiles + # tests/fork/__init__.py + # tests/fork/client.py + # tests/fork/methods.py + # protofiles + # tests/__init__.py + # tests/_loader.py + # tests/_result.py + # tests/_runner.py +) + +TEST_SRCS( + # coverage + # tests/_sanity/_sanity_test.py + tests/unit/_api_test.py + tests/unit/_abort_test.py + # CRASH + # tests/unit/_auth_context_test.py + tests/unit/_auth_test.py + tests/unit/_channel_args_test.py + tests/unit/_channel_close_test.py + tests/unit/_channel_connectivity_test.py + tests/unit/_channel_ready_future_test.py + # FLAKY + # tests/unit/_compression_test.py + tests/unit/_contextvars_propagation_test.py + tests/unit/_credentials_test.py + tests/unit/_cython/_cancel_many_calls_test.py + tests/unit/_cython/_channel_test.py + tests/unit/_cython/_fork_test.py + tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py + tests/unit/_cython/_no_messages_single_server_completion_queue_test.py + tests/unit/_cython/_read_some_but_not_all_responses_test.py + tests/unit/_cython/_server_test.py + tests/unit/_cython/cygrpc_test.py + tests/unit/_dns_resolver_test.py + tests/unit/_dynamic_stubs_test.py + tests/unit/_empty_message_test.py + tests/unit/_error_message_encoding_test.py + tests/unit/_exit_test.py + tests/unit/_grpc_shutdown_test.py + tests/unit/_interceptor_test.py + tests/unit/_invalid_metadata_test.py + tests/unit/_invocation_defects_test.py + tests/unit/_local_credentials_test.py + tests/unit/_logging_test.py + tests/unit/_metadata_code_details_test.py + tests/unit/_metadata_flags_test.py + tests/unit/_metadata_test.py + tests/unit/_reconnect_test.py + tests/unit/_resource_exhausted_test.py + tests/unit/_rpc_part_1_test.py + tests/unit/_rpc_part_2_test.py + tests/unit/_server_shutdown_test.py + # tests.testing + # tests/unit/_server_ssl_cert_config_test.py + tests/unit/_server_test.py + tests/unit/_server_wait_for_termination_test.py + # CRASH + # tests/unit/_session_cache_test.py + tests/unit/_signal_handling_test.py + tests/unit/_version_test.py + tests/unit/beta/_beta_features_test.py + tests/unit/beta/_connectivity_channel_test.py + # oauth2client + # tests/unit/beta/_implementations_test.py + tests/unit/beta/_not_found_test.py + tests/unit/beta/_utilities_test.py + tests/unit/framework/foundation/_logging_pool_test.py + tests/unit/framework/foundation/stream_testing.py + # protofiles + # tests/interop/_insecure_intraop_test.py + # tests/interop/_secure_intraop_test.py + # tests/fork/_fork_interop_test.py +) + +SIZE(MEDIUM) + +RESOURCE_FILES( + PREFIX contrib/libs/grpc/src/python/grpcio_tests/ + tests/unit/credentials/ca.pem + tests/unit/credentials/server1.key + tests/unit/credentials/server1.pem +) + +REQUIREMENTS(network:full) + +END() |