aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/grpc/test/cpp
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /contrib/libs/grpc/test/cpp
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'contrib/libs/grpc/test/cpp')
-rw-r--r--contrib/libs/grpc/test/cpp/README-iOS.md52
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/.yandex_meta/licenses.list.txt36
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/async_end2end_test.cc1952
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/cfstream_test.cc496
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/channelz_service_test.cc767
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/client_callback_end2end_test.cc1565
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/client_crash_test.cc147
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/client_crash_test_server.cc80
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/client_interceptors_end2end_test.cc1194
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/client_lb_end2end_test.cc1990
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/delegating_channel_test.cc100
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/end2end_test.cc2357
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/exception_test.cc123
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/filter_end2end_test.cc346
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/flaky_network_test.cc558
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/generic_end2end_test.cc430
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/grpclb_end2end_test.cc2029
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/health/ya.make33
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/health_service_end2end_test.cc374
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/hybrid_end2end_test.cc987
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/interceptors_util.cc214
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/interceptors_util.h317
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/message_allocator_end2end_test.cc438
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/mock_test.cc434
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/nonblocking_test.cc214
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/port_sharing_end2end_test.cc374
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/proto_server_reflection_test.cc150
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/raw_end2end_test.cc370
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/server_builder_plugin_test.cc265
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/server_crash_test.cc160
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/server_crash_test_client.cc72
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/server_early_return_test.cc232
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/server_interceptors/ya.make32
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/server_interceptors_end2end_test.cc708
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/server_load_reporting_end2end_test.cc192
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/service_config_end2end_test.cc613
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/shutdown_test.cc170
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/streaming_throughput_test.cc193
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.cc98
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.h58
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/test_service_impl.cc638
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/test_service_impl.h495
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/thread/ya.make_31
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/thread_stress_test.cc442
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/time_change_test.cc367
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/xds_end2end_test.cc5832
-rw-r--r--contrib/libs/grpc/test/cpp/end2end/ya.make67
-rw-r--r--contrib/libs/grpc/test/cpp/util/.yandex_meta/licenses.list.txt32
-rw-r--r--contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.cc57
-rw-r--r--contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.h42
-rw-r--r--contrib/libs/grpc/test/cpp/util/byte_buffer_test.cc134
-rw-r--r--contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.cc115
-rw-r--r--contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.h37
-rw-r--r--contrib/libs/grpc/test/cpp/util/channelz_sampler.cc588
-rw-r--r--contrib/libs/grpc/test/cpp/util/channelz_sampler_test.cc176
-rw-r--r--contrib/libs/grpc/test/cpp/util/cli_call.cc229
-rw-r--r--contrib/libs/grpc/test/cpp/util/cli_call.h109
-rw-r--r--contrib/libs/grpc/test/cpp/util/cli_call_test.cc128
-rw-r--r--contrib/libs/grpc/test/cpp/util/cli_credentials.cc245
-rw-r--r--contrib/libs/grpc/test/cpp/util/cli_credentials.h55
-rw-r--r--contrib/libs/grpc/test/cpp/util/config_grpc_cli.h70
-rw-r--r--contrib/libs/grpc/test/cpp/util/create_test_channel.cc252
-rw-r--r--contrib/libs/grpc/test/cpp/util/create_test_channel.h99
-rw-r--r--contrib/libs/grpc/test/cpp/util/error_details_test.cc125
-rw-r--r--contrib/libs/grpc/test/cpp/util/grpc_cli.cc90
-rw-r--r--contrib/libs/grpc/test/cpp/util/grpc_tool.cc985
-rw-r--r--contrib/libs/grpc/test/cpp/util/grpc_tool.h39
-rw-r--r--contrib/libs/grpc/test/cpp/util/grpc_tool_test.cc1344
-rw-r--r--contrib/libs/grpc/test/cpp/util/metrics_server.cc117
-rw-r--r--contrib/libs/grpc/test/cpp/util/metrics_server.h98
-rw-r--r--contrib/libs/grpc/test/cpp/util/proto_file_parser.cc323
-rw-r--r--contrib/libs/grpc/test/cpp/util/proto_file_parser.h129
-rw-r--r--contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.cc333
-rw-r--r--contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.h111
-rw-r--r--contrib/libs/grpc/test/cpp/util/service_describer.cc92
-rw-r--r--contrib/libs/grpc/test/cpp/util/service_describer.h42
-rw-r--r--contrib/libs/grpc/test/cpp/util/slice_test.cc144
-rw-r--r--contrib/libs/grpc/test/cpp/util/string_ref_helper.cc29
-rw-r--r--contrib/libs/grpc/test/cpp/util/string_ref_helper.h32
-rw-r--r--contrib/libs/grpc/test/cpp/util/string_ref_test.cc205
-rw-r--r--contrib/libs/grpc/test/cpp/util/subprocess.cc44
-rw-r--r--contrib/libs/grpc/test/cpp/util/subprocess.h47
-rw-r--r--contrib/libs/grpc/test/cpp/util/test_config.h30
-rw-r--r--contrib/libs/grpc/test/cpp/util/test_config_cc.cc37
-rw-r--r--contrib/libs/grpc/test/cpp/util/test_credentials_provider.cc181
-rw-r--r--contrib/libs/grpc/test/cpp/util/test_credentials_provider.h85
-rw-r--r--contrib/libs/grpc/test/cpp/util/time_test.cc72
-rw-r--r--contrib/libs/grpc/test/cpp/util/ya.make39
88 files changed, 35933 insertions, 0 deletions
diff --git a/contrib/libs/grpc/test/cpp/README-iOS.md b/contrib/libs/grpc/test/cpp/README-iOS.md
new file mode 100644
index 0000000000..898931085b
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/README-iOS.md
@@ -0,0 +1,52 @@
+## C++ tests on iOS
+
+[GTMGoogleTestRunner](https://github.com/google/google-toolbox-for-mac/blob/master/UnitTesting/GTMGoogleTestRunner.mm) is used to convert googletest cases to XCTest that can be run on iOS. GTMGoogleTestRunner doesn't execute the `main` function, so we can't have any test logic in `main`.
+However, it's ok to call `::testing::InitGoogleTest` in `main`, as `GTMGoogleTestRunner` [calls InitGoogleTest](https://github.com/google/google-toolbox-for-mac/blob/master/UnitTesting/GTMGoogleTestRunner.mm#L151).
+`grpc::testing::TestEnvironment` can also be called from `main`, as it does some test initialization (install crash handler, seed RNG) that's not strictly required to run testcases on iOS.
+
+
+## Porting exising C++ tests to run on iOS
+
+Please follow these guidelines when porting tests to run on iOS:
+
+- Tests need to use the googletest framework
+- Any setup/teardown code in `main` needs to be moved to `SetUpTestCase`/`TearDownTestCase`, and `TEST` needs to be changed to `TEST_F`.
+- [Death tests](https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#death-tests) are not supported on iOS, so use the `*_IF_SUPPORTED()` macros to ensure that your code compiles on iOS.
+
+For example, the following test
+```c++
+TEST(MyTest, TestOne) {
+ ASSERT_DEATH(ThisShouldDie(), "");
+}
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ grpc_init();
+ return RUN_ALL_TESTS();
+ grpc_shutdown();
+}
+```
+
+should be changed to
+```c++
+class MyTest : public ::testing::Test {
+ protected:
+ static void SetUpTestCase() { grpc_init(); }
+ static void TearDownTestCase() { grpc_shutdown(); }
+};
+
+TEST_F(MyTest, TestOne) {
+ ASSERT_DEATH_IF_SUPPORTED(ThisShouldDie(), "");
+}
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+```
+
+## Limitations
+
+Due to a [limitation](https://github.com/google/google-toolbox-for-mac/blob/master/UnitTesting/GTMGoogleTestRunner.mm#L48-L56) in GTMGoogleTestRunner, `SetUpTestCase`/`TeardownTestCase` will be called before/after *every* individual test case, similar to `SetUp`/`TearDown`.
diff --git a/contrib/libs/grpc/test/cpp/end2end/.yandex_meta/licenses.list.txt b/contrib/libs/grpc/test/cpp/end2end/.yandex_meta/licenses.list.txt
new file mode 100644
index 0000000000..a07ea0849d
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/.yandex_meta/licenses.list.txt
@@ -0,0 +1,36 @@
+====================Apache-2.0====================
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+
+
+====================COPYRIGHT====================
+ * Copyright 2015 gRPC authors.
+
+
+====================COPYRIGHT====================
+ * Copyright 2016 gRPC authors.
+
+
+====================COPYRIGHT====================
+ * Copyright 2017 gRPC authors.
+
+
+====================COPYRIGHT====================
+ * Copyright 2018 gRPC authors.
+
+
+====================COPYRIGHT====================
+# Copyright 2019 gRPC authors.
+
+
+====================COPYRIGHT====================
+// Copyright 2019 The gRPC Authors
diff --git a/contrib/libs/grpc/test/cpp/end2end/async_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/async_end2end_test.cc
new file mode 100644
index 0000000000..45df8718f9
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/async_end2end_test.cc
@@ -0,0 +1,1952 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <cinttypes>
+#include <memory>
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/ext/health_check_service_server_builder_option.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/core/ext/filters/client_channel/backup_poller.h"
+#include "src/core/lib/gpr/tls.h"
+#include "src/core/lib/iomgr/port.h"
+#include "src/proto/grpc/health/v1/health.grpc.pb.h"
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/string_ref_helper.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+#ifdef GRPC_POSIX_SOCKET_EV
+#include "src/core/lib/iomgr/ev_posix.h"
+#endif // GRPC_POSIX_SOCKET_EV
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using grpc::testing::kTlsCredentialsType;
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace testing {
+
+namespace {
+
+void* tag(int i) { return (void*)static_cast<intptr_t>(i); }
+int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); }
+
+class Verifier {
+ public:
+ Verifier() : lambda_run_(false) {}
+ // Expect sets the expected ok value for a specific tag
+ Verifier& Expect(int i, bool expect_ok) {
+ return ExpectUnless(i, expect_ok, false);
+ }
+ // ExpectUnless sets the expected ok value for a specific tag
+ // unless the tag was already marked seen (as a result of ExpectMaybe)
+ Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
+ if (!seen) {
+ expectations_[tag(i)] = expect_ok;
+ }
+ return *this;
+ }
+ // ExpectMaybe sets the expected ok value for a specific tag, but does not
+ // require it to appear
+ // If it does, sets *seen to true
+ Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
+ if (!*seen) {
+ maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
+ }
+ return *this;
+ }
+
+ // Next waits for 1 async tag to complete, checks its
+ // expectations, and returns the tag
+ int Next(CompletionQueue* cq, bool ignore_ok) {
+ bool ok;
+ void* got_tag;
+ EXPECT_TRUE(cq->Next(&got_tag, &ok));
+ GotTag(got_tag, ok, ignore_ok);
+ return detag(got_tag);
+ }
+
+ template <typename T>
+ CompletionQueue::NextStatus DoOnceThenAsyncNext(
+ CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
+ std::function<void(void)> lambda) {
+ if (lambda_run_) {
+ return cq->AsyncNext(got_tag, ok, deadline);
+ } else {
+ lambda_run_ = true;
+ return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
+ }
+ }
+
+ // Verify keeps calling Next until all currently set
+ // expected tags are complete
+ void Verify(CompletionQueue* cq) { Verify(cq, false); }
+
+ // This version of Verify allows optionally ignoring the
+ // outcome of the expectation
+ void Verify(CompletionQueue* cq, bool ignore_ok) {
+ GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
+ while (!expectations_.empty()) {
+ Next(cq, ignore_ok);
+ }
+ maybe_expectations_.clear();
+ }
+
+ // This version of Verify stops after a certain deadline
+ void Verify(CompletionQueue* cq,
+ std::chrono::system_clock::time_point deadline) {
+ if (expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(cq->AsyncNext(&got_tag, &ok, deadline),
+ CompletionQueue::TIMEOUT);
+ } else {
+ while (!expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(cq->AsyncNext(&got_tag, &ok, deadline),
+ CompletionQueue::GOT_EVENT);
+ GotTag(got_tag, ok, false);
+ }
+ }
+ maybe_expectations_.clear();
+ }
+
+ // This version of Verify stops after a certain deadline, and uses the
+ // DoThenAsyncNext API
+ // to call the lambda
+ void Verify(CompletionQueue* cq,
+ std::chrono::system_clock::time_point deadline,
+ const std::function<void(void)>& lambda) {
+ if (expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
+ CompletionQueue::TIMEOUT);
+ } else {
+ while (!expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
+ CompletionQueue::GOT_EVENT);
+ GotTag(got_tag, ok, false);
+ }
+ }
+ maybe_expectations_.clear();
+ }
+
+ private:
+ void GotTag(void* got_tag, bool ok, bool ignore_ok) {
+ auto it = expectations_.find(got_tag);
+ if (it != expectations_.end()) {
+ if (!ignore_ok) {
+ EXPECT_EQ(it->second, ok);
+ }
+ expectations_.erase(it);
+ } else {
+ auto it2 = maybe_expectations_.find(got_tag);
+ if (it2 != maybe_expectations_.end()) {
+ if (it2->second.seen != nullptr) {
+ EXPECT_FALSE(*it2->second.seen);
+ *it2->second.seen = true;
+ }
+ if (!ignore_ok) {
+ EXPECT_EQ(it2->second.ok, ok);
+ }
+ maybe_expectations_.erase(it2);
+ } else {
+ gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag);
+ abort();
+ }
+ }
+ }
+
+ struct MaybeExpect {
+ bool ok;
+ bool* seen;
+ };
+
+ std::map<void*, bool> expectations_;
+ std::map<void*, MaybeExpect> maybe_expectations_;
+ bool lambda_run_;
+};
+
+bool plugin_has_sync_methods(std::unique_ptr<ServerBuilderPlugin>& plugin) {
+ return plugin->has_sync_methods();
+}
+
+// This class disables the server builder plugins that may add sync services to
+// the server. If there are sync services, UnimplementedRpc test will triger
+// the sync unknown rpc routine on the server side, rather than the async one
+// that needs to be tested here.
+class ServerBuilderSyncPluginDisabler : public ::grpc::ServerBuilderOption {
+ public:
+ void UpdateArguments(ChannelArguments* /*arg*/) override {}
+
+ void UpdatePlugins(
+ std::vector<std::unique_ptr<ServerBuilderPlugin>>* plugins) override {
+ plugins->erase(std::remove_if(plugins->begin(), plugins->end(),
+ plugin_has_sync_methods),
+ plugins->end());
+ }
+};
+
+class TestScenario {
+ public:
+ TestScenario(bool inproc_stub, const TString& creds_type, bool hcs,
+ const TString& content)
+ : inproc(inproc_stub),
+ health_check_service(hcs),
+ credentials_type(creds_type),
+ message_content(content) {}
+ void Log() const;
+ bool inproc;
+ bool health_check_service;
+ const TString credentials_type;
+ const TString message_content;
+};
+
+static std::ostream& operator<<(std::ostream& out,
+ const TestScenario& scenario) {
+ return out << "TestScenario{inproc=" << (scenario.inproc ? "true" : "false")
+ << ", credentials='" << scenario.credentials_type
+ << ", health_check_service="
+ << (scenario.health_check_service ? "true" : "false")
+ << "', message_size=" << scenario.message_content.size() << "}";
+}
+
+void TestScenario::Log() const {
+ std::ostringstream out;
+ out << *this;
+ gpr_log(GPR_DEBUG, "%s", out.str().c_str());
+}
+
+class HealthCheck : public health::v1::Health::Service {};
+
+class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> {
+ protected:
+ AsyncEnd2endTest() { GetParam().Log(); }
+
+ void SetUp() override {
+ port_ = grpc_pick_unused_port_or_die();
+ server_address_ << "localhost:" << port_;
+
+ // Setup server
+ BuildAndStartServer();
+ }
+
+ void TearDown() override {
+ server_->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ cq_->Shutdown();
+ while (cq_->Next(&ignored_tag, &ignored_ok))
+ ;
+ stub_.reset();
+ grpc_recycle_unused_port(port_);
+ }
+
+ void BuildAndStartServer() {
+ ServerBuilder builder;
+ auto server_creds = GetCredentialsProvider()->GetServerCredentials(
+ GetParam().credentials_type);
+ builder.AddListeningPort(server_address_.str(), server_creds);
+ service_.reset(new grpc::testing::EchoTestService::AsyncService());
+ builder.RegisterService(service_.get());
+ if (GetParam().health_check_service) {
+ builder.RegisterService(&health_check_);
+ }
+ cq_ = builder.AddCompletionQueue();
+
+ // TODO(zyc): make a test option to choose wheather sync plugins should be
+ // deleted
+ std::unique_ptr<ServerBuilderOption> sync_plugin_disabler(
+ new ServerBuilderSyncPluginDisabler());
+ builder.SetOption(move(sync_plugin_disabler));
+ server_ = builder.BuildAndStart();
+ }
+
+ void ResetStub() {
+ ChannelArguments args;
+ auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
+ GetParam().credentials_type, &args);
+ std::shared_ptr<Channel> channel =
+ !(GetParam().inproc) ? ::grpc::CreateCustomChannel(
+ server_address_.str(), channel_creds, args)
+ : server_->InProcessChannel(args);
+ stub_ = grpc::testing::EchoTestService::NewStub(channel);
+ }
+
+ void SendRpc(int num_rpcs) {
+ for (int i = 0; i < num_rpcs; i++) {
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer,
+ cq_.get(), cq_.get(), tag(2));
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ }
+ }
+
+ std::unique_ptr<ServerCompletionQueue> cq_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<Server> server_;
+ std::unique_ptr<grpc::testing::EchoTestService::AsyncService> service_;
+ HealthCheck health_check_;
+ std::ostringstream server_address_;
+ int port_;
+};
+
+TEST_P(AsyncEnd2endTest, SimpleRpc) {
+ ResetStub();
+ SendRpc(1);
+}
+
+TEST_P(AsyncEnd2endTest, SimpleRpcWithExpectedError) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+ ErrorStatus error_status;
+
+ send_request.set_message(GetParam().message_content);
+ error_status.set_code(1); // CANCELLED
+ error_status.set_error_message("cancel error message");
+ *send_request.mutable_param()->mutable_expected_error() = error_status;
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ srv_ctx.AsyncNotifyWhenDone(tag(5));
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(
+ send_response,
+ Status(
+ static_cast<StatusCode>(recv_request.param().expected_error().code()),
+ recv_request.param().expected_error().error_message()),
+ tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Expect(5, true).Verify(cq_.get());
+
+ EXPECT_EQ(recv_response.message(), "");
+ EXPECT_EQ(recv_status.error_code(), error_status.code());
+ EXPECT_EQ(recv_status.error_message(), error_status.error_message());
+ EXPECT_FALSE(srv_ctx.IsCancelled());
+}
+
+TEST_P(AsyncEnd2endTest, SequentialRpcs) {
+ ResetStub();
+ SendRpc(10);
+}
+
+TEST_P(AsyncEnd2endTest, ReconnectChannel) {
+ // GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS is set to 100ms in main()
+ if (GetParam().inproc) {
+ return;
+ }
+ int poller_slowdown_factor = 1;
+#ifdef GRPC_POSIX_SOCKET_EV
+ // It needs 2 pollset_works to reconnect the channel with polling engine
+ // "poll"
+ grpc_core::UniquePtr<char> poller = GPR_GLOBAL_CONFIG_GET(grpc_poll_strategy);
+ if (0 == strcmp(poller.get(), "poll")) {
+ poller_slowdown_factor = 2;
+ }
+#endif // GRPC_POSIX_SOCKET_EV
+ ResetStub();
+ SendRpc(1);
+ server_->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ cq_->Shutdown();
+ while (cq_->Next(&ignored_tag, &ignored_ok))
+ ;
+ BuildAndStartServer();
+ // It needs more than GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS time to
+ // reconnect the channel.
+ gpr_sleep_until(gpr_time_add(
+ gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_millis(
+ 300 * poller_slowdown_factor * grpc_test_slowdown_factor(),
+ GPR_TIMESPAN)));
+ SendRpc(1);
+}
+
+// We do not need to protect notify because the use is synchronized.
+void ServerWait(Server* server, int* notify) {
+ server->Wait();
+ *notify = 1;
+}
+TEST_P(AsyncEnd2endTest, WaitAndShutdownTest) {
+ int notify = 0;
+ std::thread wait_thread(&ServerWait, server_.get(), &notify);
+ ResetStub();
+ SendRpc(1);
+ EXPECT_EQ(0, notify);
+ server_->Shutdown();
+ wait_thread.join();
+ EXPECT_EQ(1, notify);
+}
+
+TEST_P(AsyncEnd2endTest, ShutdownThenWait) {
+ ResetStub();
+ SendRpc(1);
+ std::thread t([this]() { server_->Shutdown(); });
+ server_->Wait();
+ t.join();
+}
+
+// Test a simple RPC using the async version of Next
+TEST_P(AsyncEnd2endTest, AsyncNextRpc) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ std::chrono::system_clock::time_point time_now(
+ std::chrono::system_clock::now());
+ std::chrono::system_clock::time_point time_limit(
+ std::chrono::system_clock::now() + std::chrono::seconds(10));
+ Verifier().Verify(cq_.get(), time_now);
+ Verifier().Verify(cq_.get(), time_now);
+
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ Verifier().Expect(2, true).Verify(cq_.get(), time_limit);
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(
+ cq_.get(), std::chrono::system_clock::time_point::max());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// Test a simple RPC using the async version of Next
+TEST_P(AsyncEnd2endTest, DoThenAsyncNextRpc) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ std::chrono::system_clock::time_point time_now(
+ std::chrono::system_clock::now());
+ std::chrono::system_clock::time_point time_limit(
+ std::chrono::system_clock::now() + std::chrono::seconds(10));
+ Verifier().Verify(cq_.get(), time_now);
+ Verifier().Verify(cq_.get(), time_now);
+
+ auto resp_writer_ptr = &response_writer;
+ auto lambda_2 = [&, this, resp_writer_ptr]() {
+ service_->RequestEcho(&srv_ctx, &recv_request, resp_writer_ptr, cq_.get(),
+ cq_.get(), tag(2));
+ };
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ Verifier().Expect(2, true).Verify(cq_.get(), time_limit, lambda_2);
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ auto lambda_3 = [resp_writer_ptr, send_response]() {
+ resp_writer_ptr->Finish(send_response, Status::OK, tag(3));
+ };
+ Verifier().Expect(3, true).Expect(4, true).Verify(
+ cq_.get(), std::chrono::system_clock::time_point::max(), lambda_3);
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// Two pings and a final pong.
+TEST_P(AsyncEnd2endTest, SimpleClientStreaming) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream(
+ stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1)));
+
+ service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+
+ Verifier().Expect(2, true).Expect(1, true).Verify(cq_.get());
+
+ cli_stream->Write(send_request, tag(3));
+ srv_stream.Read(&recv_request, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ cli_stream->Write(send_request, tag(5));
+ srv_stream.Read(&recv_request, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ cli_stream->WritesDone(tag(7));
+ srv_stream.Read(&recv_request, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Finish(send_response, Status::OK, tag(9));
+ cli_stream->Finish(&recv_status, tag(10));
+ Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// Two pings and a final pong.
+TEST_P(AsyncEnd2endTest, SimpleClientStreamingWithCoalescingApi) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ cli_ctx.set_initial_metadata_corked(true);
+ // tag:1 never comes up since no op is performed
+ std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream(
+ stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1)));
+
+ service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+
+ cli_stream->Write(send_request, tag(3));
+
+ bool seen3 = false;
+
+ Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get());
+
+ srv_stream.Read(&recv_request, tag(4));
+
+ Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ cli_stream->WriteLast(send_request, WriteOptions(), tag(5));
+ srv_stream.Read(&recv_request, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ srv_stream.Read(&recv_request, tag(7));
+ Verifier().Expect(7, false).Verify(cq_.get());
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Finish(send_response, Status::OK, tag(8));
+ cli_stream->Finish(&recv_status, tag(9));
+ Verifier().Expect(8, true).Expect(9, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// One ping, two pongs.
+TEST_P(AsyncEnd2endTest, SimpleServerStreaming) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ // Client starts the server-streaming call (tag 1); server asks to be
+ // handed the next ResponseStream call (tag 2). Both complete together.
+ std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+ stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1)));
+
+ service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream,
+ cq_.get(), cq_.get(), tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ // First pong: server Write (tag 3) paired with client Read (tag 4).
+ send_response.set_message(recv_request.message());
+ srv_stream.Write(send_response, tag(3));
+ cli_stream->Read(&recv_response, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ // Second pong (tags 5/6).
+ srv_stream.Write(send_response, tag(5));
+ cli_stream->Read(&recv_response, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ // Server finishes the stream (tag 7); the client's pending Read (tag 8)
+ // then completes with ok=false, signalling no more messages.
+ srv_stream.Finish(Status::OK, tag(7));
+ cli_stream->Read(&recv_response, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+ cli_stream->Finish(&recv_status, tag(9));
+ Verifier().Expect(9, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// One ping, two pongs. Using WriteAndFinish API
+TEST_P(AsyncEnd2endTest, SimpleServerStreamingWithCoalescingApiWAF) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+ stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1)));
+
+ service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream,
+ cq_.get(), cq_.get(), tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Write(send_response, tag(3));
+ cli_stream->Read(&recv_response, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ // WriteAndFinish coalesces the final message and the OK status into a
+ // single server operation (tag 5); the client still sees it as a normal
+ // successful Read (tag 6).
+ srv_stream.WriteAndFinish(send_response, WriteOptions(), Status::OK, tag(5));
+ cli_stream->Read(&recv_response, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ // The next Read fails (ok=false) because the stream was finished above.
+ cli_stream->Read(&recv_response, tag(7));
+ Verifier().Expect(7, false).Verify(cq_.get());
+
+ cli_stream->Finish(&recv_status, tag(8));
+ Verifier().Expect(8, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// One ping, two pongs. Using WriteLast API
+TEST_P(AsyncEnd2endTest, SimpleServerStreamingWithCoalescingApiWL) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+ stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1)));
+
+ service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream,
+ cq_.get(), cq_.get(), tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Write(send_response, tag(3));
+ cli_stream->Read(&recv_response, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ // WriteLast (tag 5) marks this as the last server message; Finish (tag 7)
+ // is issued immediately after, so tags 5, 6 (client Read) and 7 are all
+ // verified in one batch.
+ srv_stream.WriteLast(send_response, WriteOptions(), tag(5));
+ cli_stream->Read(&recv_response, tag(6));
+ srv_stream.Finish(Status::OK, tag(7));
+ Verifier().Expect(5, true).Expect(6, true).Expect(7, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ // Stream is closed: a further client Read completes with ok=false.
+ cli_stream->Read(&recv_response, tag(8));
+ Verifier().Expect(8, false).Verify(cq_.get());
+
+ cli_stream->Finish(&recv_status, tag(9));
+ Verifier().Expect(9, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// One ping, one pong.
+TEST_P(AsyncEnd2endTest, SimpleBidiStreaming) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ // Start the bidi call (tag 1) and have the server accept it (tag 2).
+ std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>>
+ cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1)));
+
+ service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+
+ // Ping: client Write (tag 3) matched by server Read (tag 4).
+ cli_stream->Write(send_request, tag(3));
+ srv_stream.Read(&recv_request, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ // Pong: server Write (tag 5) matched by client Read (tag 6).
+ send_response.set_message(recv_request.message());
+ srv_stream.Write(send_response, tag(5));
+ cli_stream->Read(&recv_response, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ // Client half-closes (tag 7); the server's next Read (tag 8) then
+ // completes with ok=false, indicating the client is done writing.
+ cli_stream->WritesDone(tag(7));
+ srv_stream.Read(&recv_request, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+ srv_stream.Finish(Status::OK, tag(9));
+ cli_stream->Finish(&recv_status, tag(10));
+ Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get())
+
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// One ping, one pong. Using server:WriteAndFinish api
+TEST_P(AsyncEnd2endTest, SimpleBidiStreamingWithCoalescingApiWAF) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ // Corked mode: initial metadata is held back and coalesced with the
+ // first write, so tag 1 never surfaces on the completion queue.
+ cli_ctx.set_initial_metadata_corked(true);
+ std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>>
+ cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1)));
+
+ service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+
+ cli_stream->WriteLast(send_request, WriteOptions(), tag(3));
+
+ bool seen3 = false;
+
+ // With corking, whether tag 3 completes before or after tag 2 is not
+ // deterministic; ExpectMaybe records whether it was seen here ...
+ Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get());
+
+ srv_stream.Read(&recv_request, tag(4));
+
+ // ... and ExpectUnless waits for tag 3 only if it was not seen above.
+ Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ // Client used WriteLast, so the server's second Read fails (ok=false).
+ srv_stream.Read(&recv_request, tag(5));
+ Verifier().Expect(5, false).Verify(cq_.get())
+
+ // Server coalesces its only response with the final OK status.
+ send_response.set_message(recv_request.message());
+ srv_stream.WriteAndFinish(send_response, WriteOptions(), Status::OK, tag(6));
+ cli_stream->Read(&recv_response, tag(7));
+ Verifier().Expect(6, true).Expect(7, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ cli_stream->Finish(&recv_status, tag(8));
+ Verifier().Expect(8, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// One ping, one pong. Using server:WriteLast api
+TEST_P(AsyncEnd2endTest, SimpleBidiStreamingWithCoalescingApiWL) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ // Corked mode: client initial metadata is coalesced with the first
+ // write, so tag 1 never surfaces on the completion queue.
+ cli_ctx.set_initial_metadata_corked(true);
+ std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>>
+ cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1)));
+
+ service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+
+ cli_stream->WriteLast(send_request, WriteOptions(), tag(3));
+
+ bool seen3 = false;
+
+ // Tag 3's ordering relative to tag 2 is non-deterministic under
+ // corking; ExpectMaybe/ExpectUnless handle both interleavings.
+ Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get());
+
+ srv_stream.Read(&recv_request, tag(4));
+
+ Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ // Client used WriteLast, so the server's second Read fails (ok=false).
+ srv_stream.Read(&recv_request, tag(5));
+ Verifier().Expect(5, false).Verify(cq_.get());
+
+ // Server sends its last message via WriteLast (tag 6), then Finish
+ // (tag 7); the client reads the response (tag 8).
+ send_response.set_message(recv_request.message());
+ srv_stream.WriteLast(send_response, WriteOptions(), tag(6));
+ srv_stream.Finish(Status::OK, tag(7));
+ cli_stream->Read(&recv_response, tag(8));
+ Verifier().Expect(6, true).Expect(7, true).Expect(8, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ cli_stream->Finish(&recv_status, tag(9));
+ Verifier().Expect(9, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// Metadata tests
+TEST_P(AsyncEnd2endTest, ClientInitialMetadataRpc) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ // Three client-supplied metadata entries; the "-bin" suffixed key
+ // exercises binary-valued metadata.
+ std::pair<TString, TString> meta1("key1", "val1");
+ std::pair<TString, TString> meta2("key2", "val2");
+ std::pair<TString, TString> meta3("g.r.d-bin", "xyz");
+ cli_ctx.AddMetadata(meta1.first, meta1.second);
+ cli_ctx.AddMetadata(meta2.first, meta2.second);
+ cli_ctx.AddMetadata(meta3.first, meta3.second);
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ // The server must see all three entries with their original values.
+ const auto& client_initial_metadata = srv_ctx.client_metadata();
+ EXPECT_EQ(meta1.second,
+ ToString(client_initial_metadata.find(meta1.first)->second));
+ EXPECT_EQ(meta2.second,
+ ToString(client_initial_metadata.find(meta2.first)->second));
+ EXPECT_EQ(meta3.second,
+ ToString(client_initial_metadata.find(meta3.first)->second));
+ // GE rather than EQ: additional metadata entries may be present besides
+ // the ones added above.
+ EXPECT_GE(client_initial_metadata.size(), static_cast<size_t>(2));
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+TEST_P(AsyncEnd2endTest, ServerInitialMetadataRpc) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::pair<TString, TString> meta1("key1", "val1");
+ std::pair<TString, TString> meta2("key2", "val2");
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+ // Tag 4 completes when the server's initial metadata arrives.
+ response_reader->ReadInitialMetadata(tag(4));
+
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ // Server attaches its initial metadata and sends it explicitly (tag 3);
+ // the client's pending ReadInitialMetadata (tag 4) completes with it.
+ srv_ctx.AddInitialMetadata(meta1.first, meta1.second);
+ srv_ctx.AddInitialMetadata(meta2.first, meta2.second);
+ response_writer.SendInitialMetadata(tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ const auto& server_initial_metadata = cli_ctx.GetServerInitialMetadata();
+ EXPECT_EQ(meta1.second,
+ ToString(server_initial_metadata.find(meta1.first)->second));
+ EXPECT_EQ(meta2.second,
+ ToString(server_initial_metadata.find(meta2.first)->second));
+ EXPECT_EQ(static_cast<size_t>(2), server_initial_metadata.size());
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(5));
+ response_reader->Finish(&recv_response, &recv_status, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// 1 ping, 2 pongs.
+TEST_P(AsyncEnd2endTest, ServerInitialMetadataServerStreaming) {
+ ResetStub();
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+
+ std::pair<::TString, ::TString> meta1("key1", "val1");
+ std::pair<::TString, ::TString> meta2("key2", "val2");
+
+ std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+ stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1)));
+ // Client waits for server initial metadata on tag 11.
+ cli_stream->ReadInitialMetadata(tag(11));
+ service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream,
+ cq_.get(), cq_.get(), tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+
+ // Server sends its initial metadata explicitly (tag 10) before any
+ // response message; the client's tag 11 completes with it.
+ srv_ctx.AddInitialMetadata(meta1.first, meta1.second);
+ srv_ctx.AddInitialMetadata(meta2.first, meta2.second);
+ srv_stream.SendInitialMetadata(tag(10));
+ Verifier().Expect(10, true).Expect(11, true).Verify(cq_.get());
+ auto server_initial_metadata = cli_ctx.GetServerInitialMetadata();
+ EXPECT_EQ(meta1.second,
+ ToString(server_initial_metadata.find(meta1.first)->second));
+ EXPECT_EQ(meta2.second,
+ ToString(server_initial_metadata.find(meta2.first)->second));
+ EXPECT_EQ(static_cast<size_t>(2), server_initial_metadata.size());
+
+ // Two pongs after the metadata exchange (tags 3/4 and 5/6).
+ srv_stream.Write(send_response, tag(3));
+
+ cli_stream->Read(&recv_response, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+
+ srv_stream.Write(send_response, tag(5));
+ cli_stream->Read(&recv_response, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+
+ // Server finishes (tag 7); client's pending Read (tag 8) fails with
+ // ok=false at end of stream.
+ srv_stream.Finish(Status::OK, tag(7));
+ cli_stream->Read(&recv_response, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+ cli_stream->Finish(&recv_status, tag(9));
+ Verifier().Expect(9, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status.ok());
+}
+
+// 1 ping, 2 pongs.
+// Test for server initial metadata being sent implicitly
+TEST_P(AsyncEnd2endTest, ServerInitialMetadataServerStreamingImplicit) {
+ ResetStub();
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::pair<::TString, ::TString> meta1("key1", "val1");
+ std::pair<::TString, ::TString> meta2("key2", "val2");
+
+ std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+ stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1)));
+ service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream,
+ cq_.get(), cq_.get(), tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ // No explicit SendInitialMetadata here: the first Write (tag 3) is
+ // expected to carry the initial metadata implicitly.
+ srv_ctx.AddInitialMetadata(meta1.first, meta1.second);
+ srv_ctx.AddInitialMetadata(meta2.first, meta2.second);
+ send_response.set_message(recv_request.message());
+ srv_stream.Write(send_response, tag(3));
+
+ cli_stream->Read(&recv_response, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ // After the first read, the client must already have the metadata.
+ auto server_initial_metadata = cli_ctx.GetServerInitialMetadata();
+ EXPECT_EQ(meta1.second,
+ ToString(server_initial_metadata.find(meta1.first)->second));
+ EXPECT_EQ(meta2.second,
+ ToString(server_initial_metadata.find(meta2.first)->second));
+ EXPECT_EQ(static_cast<size_t>(2), server_initial_metadata.size());
+
+ srv_stream.Write(send_response, tag(5));
+ cli_stream->Read(&recv_response, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+
+ // End of stream: the Read after Finish completes with ok=false.
+ srv_stream.Finish(Status::OK, tag(7));
+ cli_stream->Read(&recv_response, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+ cli_stream->Finish(&recv_status, tag(9));
+ Verifier().Expect(9, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status.ok());
+}
+
+TEST_P(AsyncEnd2endTest, ServerTrailingMetadataRpc) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::pair<TString, TString> meta1("key1", "val1");
+ std::pair<TString, TString> meta2("key2", "val2");
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+ response_reader->Finish(&recv_response, &recv_status, tag(5));
+
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ response_writer.SendInitialMetadata(tag(3));
+ Verifier().Expect(3, true).Verify(cq_.get());
+
+ // Trailing metadata is attached before Finish and travels with the
+ // status (tag 4); the client's Finish (tag 5) completes with both.
+ send_response.set_message(recv_request.message());
+ srv_ctx.AddTrailingMetadata(meta1.first, meta1.second);
+ srv_ctx.AddTrailingMetadata(meta2.first, meta2.second);
+ response_writer.Finish(send_response, Status::OK, tag(4));
+
+ Verifier().Expect(4, true).Expect(5, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ // The client sees exactly the two trailing entries added above.
+ const auto& server_trailing_metadata = cli_ctx.GetServerTrailingMetadata();
+ EXPECT_EQ(meta1.second,
+ ToString(server_trailing_metadata.find(meta1.first)->second));
+ EXPECT_EQ(meta2.second,
+ ToString(server_trailing_metadata.find(meta2.first)->second));
+ EXPECT_EQ(static_cast<size_t>(2), server_trailing_metadata.size());
+}
+
+TEST_P(AsyncEnd2endTest, MetadataRpc) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ // Mix of text keys and "-bin" keys carrying raw binary payloads of
+ // explicit lengths; exercises client-initial, server-initial and
+ // server-trailing metadata in a single RPC.
+ std::pair<TString, TString> meta1("key1", "val1");
+ std::pair<TString, TString> meta2(
+ "key2-bin",
+ TString("\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc", 13));
+ std::pair<TString, TString> meta3("key3", "val3");
+ std::pair<TString, TString> meta6(
+ "key4-bin",
+ TString("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d",
+ 14));
+ std::pair<TString, TString> meta5("key5", "val5");
+ std::pair<TString, TString> meta4(
+ "key6-bin",
+ TString(
+ "\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee", 15));
+
+ // Client initial metadata: meta1 + meta2.
+ cli_ctx.AddMetadata(meta1.first, meta1.second);
+ cli_ctx.AddMetadata(meta2.first, meta2.second);
+
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+ response_reader->ReadInitialMetadata(tag(4));
+
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ const auto& client_initial_metadata = srv_ctx.client_metadata();
+ EXPECT_EQ(meta1.second,
+ ToString(client_initial_metadata.find(meta1.first)->second));
+ EXPECT_EQ(meta2.second,
+ ToString(client_initial_metadata.find(meta2.first)->second));
+ // GE: extra metadata entries may be present beyond the two added above.
+ EXPECT_GE(client_initial_metadata.size(), static_cast<size_t>(2));
+
+ // Server initial metadata: meta3 + meta4, sent explicitly (tag 3) and
+ // received by the client's ReadInitialMetadata (tag 4).
+ srv_ctx.AddInitialMetadata(meta3.first, meta3.second);
+ srv_ctx.AddInitialMetadata(meta4.first, meta4.second);
+ response_writer.SendInitialMetadata(tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ const auto& server_initial_metadata = cli_ctx.GetServerInitialMetadata();
+ EXPECT_EQ(meta3.second,
+ ToString(server_initial_metadata.find(meta3.first)->second));
+ EXPECT_EQ(meta4.second,
+ ToString(server_initial_metadata.find(meta4.first)->second));
+ EXPECT_GE(server_initial_metadata.size(), static_cast<size_t>(2));
+
+ // Server trailing metadata: meta5 + meta6, delivered with the status.
+ send_response.set_message(recv_request.message());
+ srv_ctx.AddTrailingMetadata(meta5.first, meta5.second);
+ srv_ctx.AddTrailingMetadata(meta6.first, meta6.second);
+ response_writer.Finish(send_response, Status::OK, tag(5));
+ response_reader->Finish(&recv_response, &recv_status, tag(6));
+
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ const auto& server_trailing_metadata = cli_ctx.GetServerTrailingMetadata();
+ EXPECT_EQ(meta5.second,
+ ToString(server_trailing_metadata.find(meta5.first)->second));
+ EXPECT_EQ(meta6.second,
+ ToString(server_trailing_metadata.find(meta6.first)->second));
+ EXPECT_GE(server_trailing_metadata.size(), static_cast<size_t>(2));
+}
+
+// Server uses AsyncNotifyWhenDone API to check for cancellation
+TEST_P(AsyncEnd2endTest, ServerCheckCancellation) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ // Tag 5 fires when the RPC is done (here: via client cancellation).
+ srv_ctx.AsyncNotifyWhenDone(tag(5));
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ // Client cancels; the server's done tag (5) and the client's Finish
+ // (tag 4) both complete, and the server observes IsCancelled().
+ cli_ctx.TryCancel();
+ Verifier().Expect(5, true).Expect(4, true).Verify(cq_.get());
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+
+ EXPECT_EQ(StatusCode::CANCELLED, recv_status.error_code());
+}
+
+// Server uses AsyncNotifyWhenDone API to check for normal finish
+TEST_P(AsyncEnd2endTest, ServerCheckDone) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx, send_request, cq_.get()));
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ // Tag 5 fires when the RPC completes normally (no cancellation here).
+ srv_ctx.AsyncNotifyWhenDone(tag(5));
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ // Normal finish: server Finish (3), client Finish (4) and done tag (5)
+ // all complete, and IsCancelled() must report false.
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Expect(5, true).Verify(cq_.get());
+ EXPECT_FALSE(srv_ctx.IsCancelled());
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+TEST_P(AsyncEnd2endTest, UnimplementedRpc) {
+ ChannelArguments args;
+ // Build a fresh channel/stub for a service the server does not
+ // implement; in-process or network channel depending on the test param.
+ const auto& channel_creds = GetCredentialsProvider()->GetChannelCredentials(
+ GetParam().credentials_type, &args);
+ std::shared_ptr<Channel> channel =
+ !(GetParam().inproc) ? ::grpc::CreateCustomChannel(server_address_.str(),
+ channel_creds, args)
+ : server_->InProcessChannel(args);
+ std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub;
+ stub = grpc::testing::UnimplementedEchoService::NewStub(channel);
+ EchoRequest send_request;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ send_request.set_message(GetParam().message_content);
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub->AsyncUnimplemented(&cli_ctx, send_request, cq_.get()));
+
+ // The call itself completes (ok=true) but carries UNIMPLEMENTED status
+ // with an empty error message.
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+ Verifier().Expect(4, true).Verify(cq_.get());
+
+ EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code());
+ EXPECT_EQ("", recv_status.error_message());
+}
+
+// This class is for testing scenarios where RPCs are cancelled on the server
+// by calling ServerContext::TryCancel(). Server uses AsyncNotifyWhenDone
+// API to check for cancellation
+class AsyncEnd2endServerTryCancelTest : public AsyncEnd2endTest {
+ protected:
+ typedef enum {
+ DO_NOT_CANCEL = 0,
+ CANCEL_BEFORE_PROCESSING,
+ CANCEL_DURING_PROCESSING,
+ CANCEL_AFTER_PROCESSING
+ } ServerTryCancelRequestPhase;
+
+ // Helper for testing client-streaming RPCs which are cancelled on the server.
+ // Depending on the value of server_try_cancel parameter, this will test one
+ // of the following three scenarios:
+ // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading
+ // any messages from the client
+ //
+ // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading
+ // messages from the client
+ //
+ // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading all
+ // messages from the client (but before sending any status back to the
+ // client)
+ void TestClientStreamingServerCancel(
+ ServerTryCancelRequestPhase server_try_cancel) {
+ ResetStub();
+
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ // Initiate the 'RequestStream' call on client
+ // Note: the client uses its own completion queue (cli_cq) so client and
+ // server events can be driven from separate threads below.
+ CompletionQueue cli_cq;
+
+ std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream(
+ stub_->AsyncRequestStream(&cli_ctx, &recv_response, &cli_cq, tag(1)));
+
+ // On the server, request to be notified of 'RequestStream' calls
+ // and receive the 'RequestStream' call just made by the client
+ srv_ctx.AsyncNotifyWhenDone(tag(11));
+ service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+ std::thread t1([&cli_cq] { Verifier().Expect(1, true).Verify(&cli_cq); });
+ Verifier().Expect(2, true).Verify(cq_.get());
+ t1.join();
+
+ bool expected_server_cq_result = true;
+ bool expected_client_cq_result = true;
+
+ if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+ srv_ctx.TryCancel();
+ // Done tag (11) fires immediately on cancellation.
+ Verifier().Expect(11, true).Verify(cq_.get());
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+
+ // Since cancellation is done before server reads any results, we know
+ // for sure that all server cq results will return false from this
+ // point forward
+ expected_server_cq_result = false;
+ expected_client_cq_result = false;
+ }
+
+ // Client-side results are non-deterministic whenever the cancel races
+ // with the client's writes, so they are ignored in those phases.
+ bool ignore_client_cq_result =
+ (server_try_cancel == CANCEL_DURING_PROCESSING) ||
+ (server_try_cancel == CANCEL_BEFORE_PROCESSING);
+
+ std::thread cli_thread([&cli_cq, &cli_stream, &expected_client_cq_result,
+ &ignore_client_cq_result] {
+ EchoRequest send_request;
+ // Client sends 3 messages (tags 3, 4 and 5)
+ for (int tag_idx = 3; tag_idx <= 5; tag_idx++) {
+ send_request.set_message("Ping " + ToString(tag_idx));
+ cli_stream->Write(send_request, tag(tag_idx));
+ Verifier()
+ .Expect(tag_idx, expected_client_cq_result)
+ .Verify(&cli_cq, ignore_client_cq_result);
+ }
+ cli_stream->WritesDone(tag(6));
+ // Ignore ok on WritesDone since cancel can affect it
+ Verifier()
+ .Expect(6, expected_client_cq_result)
+ .Verify(&cli_cq, ignore_client_cq_result);
+ });
+
+ bool ignore_cq_result = false;
+ bool want_done_tag = false;
+ std::thread* server_try_cancel_thd = nullptr;
+
+ auto verif = Verifier();
+
+ if (server_try_cancel == CANCEL_DURING_PROCESSING) {
+ server_try_cancel_thd =
+ new std::thread([&srv_ctx] { srv_ctx.TryCancel(); });
+ // Server will cancel the RPC in a parallel thread while reading the
+ // requests from the client. Since the cancellation can happen at anytime,
+ // some of the cq results (i.e those until cancellation) might be true but
+ // its non deterministic. So better to ignore the cq results
+ ignore_cq_result = true;
+ // Expect that we might possibly see the done tag that
+ // indicates cancellation completion in this case
+ want_done_tag = true;
+ verif.Expect(11, true);
+ }
+
+ // Server reads 3 messages (tags 6, 7 and 8)
+ // But if want_done_tag is true, we might also see tag 11
+ for (int tag_idx = 6; tag_idx <= 8; tag_idx++) {
+ srv_stream.Read(&recv_request, tag(tag_idx));
+ // Note that we'll add something to the verifier and verify that
+ // something was seen, but it might be tag 11 and not what we
+ // just added
+ int got_tag = verif.Expect(tag_idx, expected_server_cq_result)
+ .Next(cq_.get(), ignore_cq_result)
+ GPR_ASSERT((got_tag == tag_idx) || (got_tag == 11 && want_done_tag));
+ if (got_tag == 11) {
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ // Now get the other entry that we were waiting on
+ EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), tag_idx);
+ }
+ }
+
+ cli_thread.join();
+
+ if (server_try_cancel_thd != nullptr) {
+ server_try_cancel_thd->join();
+ delete server_try_cancel_thd;
+ }
+
+ if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
+ srv_ctx.TryCancel();
+ want_done_tag = true;
+ verif.Expect(11, true);
+ }
+
+ // Drain the done tag (11) if it has not been consumed yet.
+ if (want_done_tag) {
+ verif.Verify(cq_.get());
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ }
+
+ // The RPC has been cancelled at this point for sure (i.e irrespective of
+ // the value of `server_try_cancel` is). So, from this point forward, we
+ // know that cq results are supposed to return false on server.
+
+ // Server sends the final message and cancelled status (but the RPC is
+ // already cancelled at this point. So we expect the operation to fail)
+ srv_stream.Finish(send_response, Status::CANCELLED, tag(9));
+ Verifier().Expect(9, false).Verify(cq_.get());
+
+ // Client will see the cancellation
+ cli_stream->Finish(&recv_status, tag(10));
+ Verifier().Expect(10, true).Verify(&cli_cq);
+ EXPECT_FALSE(recv_status.ok());
+ EXPECT_EQ(::grpc::StatusCode::CANCELLED, recv_status.error_code());
+
+ // Shut down the client's private cq and drain any remaining events
+ // before it goes out of scope.
+ cli_cq.Shutdown();
+ void* dummy_tag;
+ bool dummy_ok;
+ while (cli_cq.Next(&dummy_tag, &dummy_ok)) {
+ }
+ }
+
+ // Helper for testing server-streaming RPCs which are cancelled on the server.
+ // Depending on the value of server_try_cancel parameter, this will test one
+ // of the following three scenarios:
+ // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before sending
+ // any messages to the client
+ //
+ // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while sending
+ // messages to the client
+ //
+ // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after sending all
+ // messages to the client (but before sending any status back to the
+ // client)
+ void TestServerStreamingServerCancel(
+ ServerTryCancelRequestPhase server_try_cancel) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+
+ send_request.set_message("Ping");
+ // Initiate the 'ResponseStream' call on the client
+ CompletionQueue cli_cq;
+ std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+ stub_->AsyncResponseStream(&cli_ctx, send_request, &cli_cq, tag(1)));
+ // On the server, request to be notified of 'ResponseStream' calls and
+ // receive the call just made by the client
+ srv_ctx.AsyncNotifyWhenDone(tag(11));
+ service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream,
+ cq_.get(), cq_.get(), tag(2));
+
+ std::thread t1([&cli_cq] { Verifier().Expect(1, true).Verify(&cli_cq); });
+ Verifier().Expect(2, true).Verify(cq_.get());
+ t1.join();
+
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ bool expected_cq_result = true;
+ bool ignore_cq_result = false;
+ bool want_done_tag = false;
+ bool expected_client_cq_result = true;
+ bool ignore_client_cq_result =
+ (server_try_cancel != CANCEL_BEFORE_PROCESSING);
+
+ if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+ srv_ctx.TryCancel();
+ Verifier().Expect(11, true).Verify(cq_.get());
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+
+ // We know for sure that all cq results will be false from this point
+ // since the server cancelled the RPC
+ expected_cq_result = false;
+ expected_client_cq_result = false;
+ }
+
+ std::thread cli_thread([&cli_cq, &cli_stream, &expected_client_cq_result,
+ &ignore_client_cq_result] {
+ // Client attempts to read the three messages from the server
+ for (int tag_idx = 6; tag_idx <= 8; tag_idx++) {
+ EchoResponse recv_response;
+ cli_stream->Read(&recv_response, tag(tag_idx));
+ Verifier()
+ .Expect(tag_idx, expected_client_cq_result)
+ .Verify(&cli_cq, ignore_client_cq_result);
+ }
+ });
+
+ std::thread* server_try_cancel_thd = nullptr;
+
+ auto verif = Verifier();
+
+ if (server_try_cancel == CANCEL_DURING_PROCESSING) {
+ server_try_cancel_thd =
+ new std::thread([&srv_ctx] { srv_ctx.TryCancel(); });
+
+ // Server will cancel the RPC in a parallel thread while writing responses
+      // to the client. Since the cancellation can happen at any time, some of
+      // the cq results (i.e. those until cancellation) might be true but it is
+      // non-deterministic. So better to ignore the cq results
+ ignore_cq_result = true;
+ // Expect that we might possibly see the done tag that
+ // indicates cancellation completion in this case
+ want_done_tag = true;
+ verif.Expect(11, true);
+ }
+
+ // Server sends three messages (tags 3, 4 and 5)
+ // But if want_done tag is true, we might also see tag 11
+ for (int tag_idx = 3; tag_idx <= 5; tag_idx++) {
+ send_response.set_message("Pong " + ToString(tag_idx));
+ srv_stream.Write(send_response, tag(tag_idx));
+ // Note that we'll add something to the verifier and verify that
+ // something was seen, but it might be tag 11 and not what we
+ // just added
+ int got_tag = verif.Expect(tag_idx, expected_cq_result)
+ .Next(cq_.get(), ignore_cq_result);
+ GPR_ASSERT((got_tag == tag_idx) || (got_tag == 11 && want_done_tag));
+ if (got_tag == 11) {
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ // Now get the other entry that we were waiting on
+ EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), tag_idx);
+ }
+ }
+
+ if (server_try_cancel_thd != nullptr) {
+ server_try_cancel_thd->join();
+ delete server_try_cancel_thd;
+ }
+
+ if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
+ srv_ctx.TryCancel();
+ want_done_tag = true;
+ verif.Expect(11, true);
+ }
+
+ if (want_done_tag) {
+ verif.Verify(cq_.get());
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ }
+
+ cli_thread.join();
+
+    // The RPC has been cancelled at this point for sure (i.e. irrespective of
+    // the value of `server_try_cancel`). So, from this point forward, we
+ // know that cq results are supposed to return false on server.
+
+ // Server finishes the stream (but the RPC is already cancelled)
+ srv_stream.Finish(Status::CANCELLED, tag(9));
+ Verifier().Expect(9, false).Verify(cq_.get());
+
+ // Client will see the cancellation
+ cli_stream->Finish(&recv_status, tag(10));
+ Verifier().Expect(10, true).Verify(&cli_cq);
+ EXPECT_FALSE(recv_status.ok());
+ EXPECT_EQ(::grpc::StatusCode::CANCELLED, recv_status.error_code());
+
+ cli_cq.Shutdown();
+ void* dummy_tag;
+ bool dummy_ok;
+ while (cli_cq.Next(&dummy_tag, &dummy_ok)) {
+ }
+ }
+
+  // Helper for testing bidirectional-streaming RPCs which are cancelled on
+  // the server.
+ //
+ // Depending on the value of server_try_cancel parameter, this will
+ // test one of the following three scenarios:
+ // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading/
+ // writing any messages from/to the client
+ //
+ // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading
+ // messages from the client
+ //
+  // CANCEL_AFTER_PROCESSING: Rpc is cancelled by server after reading all
+  // messages from the client (but before sending any status back to the
+  // client)
+ void TestBidiStreamingServerCancel(
+ ServerTryCancelRequestPhase server_try_cancel) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ // Initiate the call from the client side
+ std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>>
+ cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1)));
+
+ // On the server, request to be notified of the 'BidiStream' call and
+ // receive the call just made by the client
+ srv_ctx.AsyncNotifyWhenDone(tag(11));
+ service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+
+ auto verif = Verifier();
+
+ // Client sends the first and the only message
+ send_request.set_message("Ping");
+ cli_stream->Write(send_request, tag(3));
+ verif.Expect(3, true);
+
+ bool expected_cq_result = true;
+ bool ignore_cq_result = false;
+ bool want_done_tag = false;
+
+ int got_tag, got_tag2;
+ bool tag_3_done = false;
+
+ if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+ srv_ctx.TryCancel();
+ verif.Expect(11, true);
+ // We know for sure that all server cq results will be false from
+ // this point since the server cancelled the RPC. However, we can't
+ // say for sure about the client
+ expected_cq_result = false;
+ ignore_cq_result = true;
+
+ do {
+ got_tag = verif.Next(cq_.get(), ignore_cq_result);
+ GPR_ASSERT(((got_tag == 3) && !tag_3_done) || (got_tag == 11));
+ if (got_tag == 3) {
+ tag_3_done = true;
+ }
+ } while (got_tag != 11);
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ }
+
+ std::thread* server_try_cancel_thd = nullptr;
+
+ if (server_try_cancel == CANCEL_DURING_PROCESSING) {
+ server_try_cancel_thd =
+ new std::thread([&srv_ctx] { srv_ctx.TryCancel(); });
+
+ // Since server is going to cancel the RPC in a parallel thread, some of
+      // the cq results (i.e. those until the cancellation) might be true. Since
+      // that number is non-deterministic, it is better to ignore the cq results
+ ignore_cq_result = true;
+ // Expect that we might possibly see the done tag that
+ // indicates cancellation completion in this case
+ want_done_tag = true;
+ verif.Expect(11, true);
+ }
+
+ srv_stream.Read(&recv_request, tag(4));
+ verif.Expect(4, expected_cq_result);
+ got_tag = tag_3_done ? 3 : verif.Next(cq_.get(), ignore_cq_result);
+ got_tag2 = verif.Next(cq_.get(), ignore_cq_result);
+ GPR_ASSERT((got_tag == 3) || (got_tag == 4) ||
+ (got_tag == 11 && want_done_tag));
+ GPR_ASSERT((got_tag2 == 3) || (got_tag2 == 4) ||
+ (got_tag2 == 11 && want_done_tag));
+ // If we get 3 and 4, we don't need to wait for 11, but if
+ // we get 11, we should also clear 3 and 4
+ if (got_tag + got_tag2 != 7) {
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ got_tag = verif.Next(cq_.get(), ignore_cq_result);
+ GPR_ASSERT((got_tag == 3) || (got_tag == 4));
+ }
+
+ send_response.set_message("Pong");
+ srv_stream.Write(send_response, tag(5));
+ verif.Expect(5, expected_cq_result);
+
+ cli_stream->Read(&recv_response, tag(6));
+ verif.Expect(6, expected_cq_result);
+ got_tag = verif.Next(cq_.get(), ignore_cq_result);
+ got_tag2 = verif.Next(cq_.get(), ignore_cq_result);
+ GPR_ASSERT((got_tag == 5) || (got_tag == 6) ||
+ (got_tag == 11 && want_done_tag));
+ GPR_ASSERT((got_tag2 == 5) || (got_tag2 == 6) ||
+ (got_tag2 == 11 && want_done_tag));
+ // If we get 5 and 6, we don't need to wait for 11, but if
+ // we get 11, we should also clear 5 and 6
+ if (got_tag + got_tag2 != 11) {
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ got_tag = verif.Next(cq_.get(), ignore_cq_result);
+ GPR_ASSERT((got_tag == 5) || (got_tag == 6));
+ }
+
+ // This is expected to succeed in all cases
+ cli_stream->WritesDone(tag(7));
+ verif.Expect(7, true);
+ // TODO(vjpai): Consider whether the following is too flexible
+ // or whether it should just be reset to ignore_cq_result
+ bool ignore_cq_wd_result =
+ ignore_cq_result || (server_try_cancel == CANCEL_BEFORE_PROCESSING);
+ got_tag = verif.Next(cq_.get(), ignore_cq_wd_result);
+ GPR_ASSERT((got_tag == 7) || (got_tag == 11 && want_done_tag));
+ if (got_tag == 11) {
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ // Now get the other entry that we were waiting on
+ EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_wd_result), 7);
+ }
+
+ // This is expected to fail in all cases i.e for all values of
+ // server_try_cancel. This is because at this point, either there are no
+ // more msgs from the client (because client called WritesDone) or the RPC
+ // is cancelled on the server
+ srv_stream.Read(&recv_request, tag(8));
+ verif.Expect(8, false);
+ got_tag = verif.Next(cq_.get(), ignore_cq_result);
+ GPR_ASSERT((got_tag == 8) || (got_tag == 11 && want_done_tag));
+ if (got_tag == 11) {
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ // Now get the other entry that we were waiting on
+ EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), 8);
+ }
+
+ if (server_try_cancel_thd != nullptr) {
+ server_try_cancel_thd->join();
+ delete server_try_cancel_thd;
+ }
+
+ if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
+ srv_ctx.TryCancel();
+ want_done_tag = true;
+ verif.Expect(11, true);
+ }
+
+ if (want_done_tag) {
+ verif.Verify(cq_.get());
+ EXPECT_TRUE(srv_ctx.IsCancelled());
+ want_done_tag = false;
+ }
+
+    // The RPC has been cancelled at this point for sure (i.e. irrespective of
+    // the value of `server_try_cancel`). So, from this point forward, we
+ // know that cq results are supposed to return false on server.
+
+ srv_stream.Finish(Status::CANCELLED, tag(9));
+ Verifier().Expect(9, false).Verify(cq_.get());
+
+ cli_stream->Finish(&recv_status, tag(10));
+ Verifier().Expect(10, true).Verify(cq_.get());
+ EXPECT_FALSE(recv_status.ok());
+ EXPECT_EQ(grpc::StatusCode::CANCELLED, recv_status.error_code());
+ }
+};
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelBefore) {
+ TestClientStreamingServerCancel(CANCEL_BEFORE_PROCESSING);
+}
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelDuring) {
+ TestClientStreamingServerCancel(CANCEL_DURING_PROCESSING);
+}
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelAfter) {
+ TestClientStreamingServerCancel(CANCEL_AFTER_PROCESSING);
+}
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelBefore) {
+ TestServerStreamingServerCancel(CANCEL_BEFORE_PROCESSING);
+}
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelDuring) {
+ TestServerStreamingServerCancel(CANCEL_DURING_PROCESSING);
+}
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelAfter) {
+ TestServerStreamingServerCancel(CANCEL_AFTER_PROCESSING);
+}
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelBefore) {
+ TestBidiStreamingServerCancel(CANCEL_BEFORE_PROCESSING);
+}
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelDuring) {
+ TestBidiStreamingServerCancel(CANCEL_DURING_PROCESSING);
+}
+
+TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelAfter) {
+ TestBidiStreamingServerCancel(CANCEL_AFTER_PROCESSING);
+}
+
+std::vector<TestScenario> CreateTestScenarios(bool /*test_secure*/,
+ bool test_message_size_limit) {
+ std::vector<TestScenario> scenarios;
+ std::vector<TString> credentials_types;
+ std::vector<TString> messages;
+
+ auto insec_ok = [] {
+ // Only allow insecure credentials type when it is registered with the
+ // provider. User may create providers that do not have insecure.
+ return GetCredentialsProvider()->GetChannelCredentials(
+ kInsecureCredentialsType, nullptr) != nullptr;
+ };
+
+ if (insec_ok()) {
+ credentials_types.push_back(kInsecureCredentialsType);
+ }
+ auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList();
+ for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) {
+ credentials_types.push_back(*sec);
+ }
+ GPR_ASSERT(!credentials_types.empty());
+
+ messages.push_back("Hello");
+ if (test_message_size_limit) {
+ for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024;
+ k *= 32) {
+ TString big_msg;
+ for (size_t i = 0; i < k * 1024; ++i) {
+ char c = 'a' + (i % 26);
+ big_msg += c;
+ }
+ messages.push_back(big_msg);
+ }
+ if (!BuiltUnderMsan()) {
+ // 4MB message processing with SSL is very slow under msan
+ // (causes timeouts) and doesn't really increase the signal from tests.
+ // Reserve 100 bytes for other fields of the message proto.
+ messages.push_back(
+ TString(GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH - 100, 'a'));
+ }
+ }
+
+  // TODO (sreek) Re-enable tests with health check service after the issue
+  // https://github.com/grpc/grpc/issues/11223 is resolved
+ for (auto health_check_service : {false}) {
+ for (auto msg = messages.begin(); msg != messages.end(); msg++) {
+ for (auto cred = credentials_types.begin();
+ cred != credentials_types.end(); ++cred) {
+ scenarios.emplace_back(false, *cred, health_check_service, *msg);
+ }
+ if (insec_ok()) {
+ scenarios.emplace_back(true, kInsecureCredentialsType,
+ health_check_service, *msg);
+ }
+ }
+ }
+ return scenarios;
+}
+
+INSTANTIATE_TEST_SUITE_P(AsyncEnd2end, AsyncEnd2endTest,
+ ::testing::ValuesIn(CreateTestScenarios(true, true)));
+INSTANTIATE_TEST_SUITE_P(AsyncEnd2endServerTryCancel,
+ AsyncEnd2endServerTryCancelTest,
+ ::testing::ValuesIn(CreateTestScenarios(false,
+ false)));
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ // Change the backup poll interval from 5s to 100ms to speed up the
+ // ReconnectChannel test
+ GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 100);
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ int ret = RUN_ALL_TESTS();
+ return ret;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/cfstream_test.cc b/contrib/libs/grpc/test/cpp/end2end/cfstream_test.cc
new file mode 100644
index 0000000000..e6695982bd
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/cfstream_test.cc
@@ -0,0 +1,496 @@
+/*
+ *
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "src/core/lib/iomgr/port.h"
+
+#include <algorithm>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/atm.h>
+#include <grpc/support/log.h>
+#include <grpc/support/string_util.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/health_check_service_interface.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <gtest/gtest.h>
+
+#include "src/core/lib/backoff/backoff.h"
+#include "src/core/lib/gpr/env.h"
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/debugger_macros.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+#ifdef GRPC_CFSTREAM
+using grpc::ClientAsyncResponseReader;
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using grpc::testing::RequestParams;
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace testing {
+namespace {
+
+struct TestScenario {
+ TestScenario(const TString& creds_type, const TString& content)
+ : credentials_type(creds_type), message_content(content) {}
+ const TString credentials_type;
+ const TString message_content;
+};
+
+class CFStreamTest : public ::testing::TestWithParam<TestScenario> {
+ protected:
+ CFStreamTest()
+ : server_host_("grpctest"),
+ interface_("lo0"),
+ ipv4_address_("10.0.0.1") {}
+
+ void DNSUp() {
+ std::ostringstream cmd;
+ // Add DNS entry for server_host_ in /etc/hosts
+ cmd << "echo '" << ipv4_address_ << " " << server_host_
+ << " ' | sudo tee -a /etc/hosts";
+ std::system(cmd.str().c_str());
+ }
+
+ void DNSDown() {
+ std::ostringstream cmd;
+ // Remove DNS entry for server_host_ in /etc/hosts
+ cmd << "sudo sed -i '.bak' '/" << server_host_ << "/d' /etc/hosts";
+ std::system(cmd.str().c_str());
+ }
+
+ void InterfaceUp() {
+ std::ostringstream cmd;
+ cmd << "sudo /sbin/ifconfig " << interface_ << " alias " << ipv4_address_;
+ std::system(cmd.str().c_str());
+ }
+
+ void InterfaceDown() {
+ std::ostringstream cmd;
+ cmd << "sudo /sbin/ifconfig " << interface_ << " -alias " << ipv4_address_;
+ std::system(cmd.str().c_str());
+ }
+
+ void NetworkUp() {
+ gpr_log(GPR_DEBUG, "Bringing network up");
+ InterfaceUp();
+ DNSUp();
+ }
+
+ void NetworkDown() {
+ gpr_log(GPR_DEBUG, "Bringing network down");
+ InterfaceDown();
+ DNSDown();
+ }
+
+ void SetUp() override {
+ NetworkUp();
+ grpc_init();
+ StartServer();
+ }
+
+ void TearDown() override {
+ NetworkDown();
+ StopServer();
+ grpc_shutdown();
+ }
+
+ void StartServer() {
+ port_ = grpc_pick_unused_port_or_die();
+ server_.reset(new ServerData(port_, GetParam().credentials_type));
+ server_->Start(server_host_);
+ }
+ void StopServer() { server_->Shutdown(); }
+
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> BuildStub(
+ const std::shared_ptr<Channel>& channel) {
+ return grpc::testing::EchoTestService::NewStub(channel);
+ }
+
+ std::shared_ptr<Channel> BuildChannel() {
+ std::ostringstream server_address;
+ server_address << server_host_ << ":" << port_;
+ ChannelArguments args;
+ auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
+ GetParam().credentials_type, &args);
+ return CreateCustomChannel(server_address.str(), channel_creds, args);
+ }
+
+ int GetStreamID(ClientContext& context) {
+ int stream_id = 0;
+ grpc_call* call = context.c_call();
+ if (call) {
+ grpc_chttp2_stream* stream = grpc_chttp2_stream_from_call(call);
+ if (stream) {
+ stream_id = stream->id;
+ }
+ }
+ return stream_id;
+ }
+
+ void SendRpc(
+ const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+ bool expect_success = false) {
+ auto response = std::unique_ptr<EchoResponse>(new EchoResponse());
+ EchoRequest request;
+ auto& msg = GetParam().message_content;
+ request.set_message(msg);
+ ClientContext context;
+ Status status = stub->Echo(&context, request, response.get());
+ int stream_id = GetStreamID(context);
+ if (status.ok()) {
+ gpr_log(GPR_DEBUG, "RPC with stream_id %d succeeded", stream_id);
+ EXPECT_EQ(msg, response->message());
+ } else {
+ gpr_log(GPR_DEBUG, "RPC with stream_id %d failed: %s", stream_id,
+ status.error_message().c_str());
+ }
+ if (expect_success) {
+ EXPECT_TRUE(status.ok());
+ }
+ }
+ void SendAsyncRpc(
+ const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+ RequestParams param = RequestParams()) {
+ EchoRequest request;
+ request.set_message(GetParam().message_content);
+ *request.mutable_param() = std::move(param);
+ AsyncClientCall* call = new AsyncClientCall;
+
+ call->response_reader =
+ stub->PrepareAsyncEcho(&call->context, request, &cq_);
+
+ call->response_reader->StartCall();
+ call->response_reader->Finish(&call->reply, &call->status, (void*)call);
+ }
+
+ void ShutdownCQ() { cq_.Shutdown(); }
+
+ bool CQNext(void** tag, bool* ok) {
+ auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10);
+ auto ret = cq_.AsyncNext(tag, ok, deadline);
+ if (ret == grpc::CompletionQueue::GOT_EVENT) {
+ return true;
+ } else if (ret == grpc::CompletionQueue::SHUTDOWN) {
+ return false;
+ } else {
+ GPR_ASSERT(ret == grpc::CompletionQueue::TIMEOUT);
+      // This can happen if we hit the Apple CFStream bug which results in the
+      // read stream hanging. We are ignoring hangs and timeouts, but these
+      // tests are still useful as they can catch memory corruptions,
+      // crashes and other bugs that don't result in test hang/timeout.
+ return false;
+ }
+ }
+
+ bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) {
+ const gpr_timespec deadline =
+ grpc_timeout_seconds_to_deadline(timeout_seconds);
+ grpc_connectivity_state state;
+ while ((state = channel->GetState(false /* try_to_connect */)) ==
+ GRPC_CHANNEL_READY) {
+ if (!channel->WaitForStateChange(state, deadline)) return false;
+ }
+ return true;
+ }
+
+ bool WaitForChannelReady(Channel* channel, int timeout_seconds = 10) {
+ const gpr_timespec deadline =
+ grpc_timeout_seconds_to_deadline(timeout_seconds);
+ grpc_connectivity_state state;
+ while ((state = channel->GetState(true /* try_to_connect */)) !=
+ GRPC_CHANNEL_READY) {
+ if (!channel->WaitForStateChange(state, deadline)) return false;
+ }
+ return true;
+ }
+
+ struct AsyncClientCall {
+ EchoResponse reply;
+ ClientContext context;
+ Status status;
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader;
+ };
+
+ private:
+ struct ServerData {
+ int port_;
+ const TString creds_;
+ std::unique_ptr<Server> server_;
+ TestServiceImpl service_;
+ std::unique_ptr<std::thread> thread_;
+ bool server_ready_ = false;
+
+ ServerData(int port, const TString& creds)
+ : port_(port), creds_(creds) {}
+
+ void Start(const TString& server_host) {
+ gpr_log(GPR_INFO, "starting server on port %d", port_);
+ std::mutex mu;
+ std::unique_lock<std::mutex> lock(mu);
+ std::condition_variable cond;
+ thread_.reset(new std::thread(
+ std::bind(&ServerData::Serve, this, server_host, &mu, &cond)));
+ cond.wait(lock, [this] { return server_ready_; });
+ server_ready_ = false;
+ gpr_log(GPR_INFO, "server startup complete");
+ }
+
+ void Serve(const TString& server_host, std::mutex* mu,
+ std::condition_variable* cond) {
+ std::ostringstream server_address;
+ server_address << server_host << ":" << port_;
+ ServerBuilder builder;
+ auto server_creds =
+ GetCredentialsProvider()->GetServerCredentials(creds_);
+ builder.AddListeningPort(server_address.str(), server_creds);
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ std::lock_guard<std::mutex> lock(*mu);
+ server_ready_ = true;
+ cond->notify_one();
+ }
+
+ void Shutdown(bool join = true) {
+ server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
+ if (join) thread_->join();
+ }
+ };
+
+ CompletionQueue cq_;
+ const TString server_host_;
+ const TString interface_;
+ const TString ipv4_address_;
+ std::unique_ptr<ServerData> server_;
+ int port_;
+};
+
+std::vector<TestScenario> CreateTestScenarios() {
+ std::vector<TestScenario> scenarios;
+ std::vector<TString> credentials_types;
+ std::vector<TString> messages;
+
+ credentials_types.push_back(kInsecureCredentialsType);
+ auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList();
+ for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) {
+ credentials_types.push_back(*sec);
+ }
+
+ messages.push_back("🖖");
+ for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024; k *= 32) {
+ TString big_msg;
+ for (size_t i = 0; i < k * 1024; ++i) {
+ char c = 'a' + (i % 26);
+ big_msg += c;
+ }
+ messages.push_back(big_msg);
+ }
+ for (auto cred = credentials_types.begin(); cred != credentials_types.end();
+ ++cred) {
+ for (auto msg = messages.begin(); msg != messages.end(); msg++) {
+ scenarios.emplace_back(*cred, *msg);
+ }
+ }
+
+ return scenarios;
+}
+
+INSTANTIATE_TEST_SUITE_P(CFStreamTest, CFStreamTest,
+ ::testing::ValuesIn(CreateTestScenarios()));
+
+// gRPC should automatically detect network flaps (without enabling keepalives)
+// when CFStream is enabled
+TEST_P(CFStreamTest, NetworkTransition) {
+ auto channel = BuildChannel();
+ auto stub = BuildStub(channel);
+ // Channel should be in READY state after we send an RPC
+ SendRpc(stub, /*expect_success=*/true);
+ EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
+
+ std::atomic_bool shutdown{false};
+ std::thread sender = std::thread([this, &stub, &shutdown]() {
+ while (true) {
+ if (shutdown.load()) {
+ return;
+ }
+ SendRpc(stub);
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ }
+ });
+
+ // bring down network
+ NetworkDown();
+
+ // network going down should be detected by cfstream
+ EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
+
+ // bring network interface back up
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ NetworkUp();
+
+ // channel should reconnect
+ EXPECT_TRUE(WaitForChannelReady(channel.get()));
+ EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
+ shutdown.store(true);
+ sender.join();
+}
+
+// Network flaps while RPCs are in flight
+TEST_P(CFStreamTest, NetworkFlapRpcsInFlight) {
+ auto channel = BuildChannel();
+ auto stub = BuildStub(channel);
+ std::atomic_int rpcs_sent{0};
+
+ // Channel should be in READY state after we send some RPCs
+ for (int i = 0; i < 10; ++i) {
+ RequestParams param;
+ param.set_skip_cancelled_check(true);
+ SendAsyncRpc(stub, param);
+ ++rpcs_sent;
+ }
+ EXPECT_TRUE(WaitForChannelReady(channel.get()));
+
+ // Bring down the network
+ NetworkDown();
+
+ std::thread thd = std::thread([this, &rpcs_sent]() {
+ void* got_tag;
+ bool ok = false;
+ bool network_down = true;
+ int total_completions = 0;
+
+ while (CQNext(&got_tag, &ok)) {
+ ++total_completions;
+ GPR_ASSERT(ok);
+ AsyncClientCall* call = static_cast<AsyncClientCall*>(got_tag);
+ int stream_id = GetStreamID(call->context);
+ if (!call->status.ok()) {
+ gpr_log(GPR_DEBUG, "RPC with stream_id %d failed with error: %s",
+ stream_id, call->status.error_message().c_str());
+ // Bring network up when RPCs start failing
+ if (network_down) {
+ NetworkUp();
+ network_down = false;
+ }
+ } else {
+ gpr_log(GPR_DEBUG, "RPC with stream_id %d succeeded", stream_id);
+ }
+ delete call;
+ }
+ // Remove line below and uncomment the following line after Apple CFStream
+ // bug has been fixed.
+ (void)rpcs_sent;
+ // EXPECT_EQ(total_completions, rpcs_sent);
+ });
+
+ for (int i = 0; i < 100; ++i) {
+ RequestParams param;
+ param.set_skip_cancelled_check(true);
+ SendAsyncRpc(stub, param);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+ ++rpcs_sent;
+ }
+
+ ShutdownCQ();
+
+ thd.join();
+}
+
+// Send a bunch of RPCs, some of which are expected to fail.
+// We should get back a response for all RPCs
+TEST_P(CFStreamTest, ConcurrentRpc) {
+ auto channel = BuildChannel();
+ auto stub = BuildStub(channel);
+ std::atomic_int rpcs_sent{0};
+ std::thread thd = std::thread([this, &rpcs_sent]() {
+ void* got_tag;
+ bool ok = false;
+ int total_completions = 0;
+
+ while (CQNext(&got_tag, &ok)) {
+ ++total_completions;
+ GPR_ASSERT(ok);
+ AsyncClientCall* call = static_cast<AsyncClientCall*>(got_tag);
+ int stream_id = GetStreamID(call->context);
+ if (!call->status.ok()) {
+ gpr_log(GPR_DEBUG, "RPC with stream_id %d failed with error: %s",
+ stream_id, call->status.error_message().c_str());
+ // Bring network up when RPCs start failing
+ } else {
+ gpr_log(GPR_DEBUG, "RPC with stream_id %d succeeded", stream_id);
+ }
+ delete call;
+ }
+ // Remove line below and uncomment the following line after Apple CFStream
+ // bug has been fixed.
+ (void)rpcs_sent;
+ // EXPECT_EQ(total_completions, rpcs_sent);
+ });
+
+ for (int i = 0; i < 10; ++i) {
+ if (i % 3 == 0) {
+ RequestParams param;
+ ErrorStatus* error = param.mutable_expected_error();
+ error->set_code(StatusCode::INTERNAL);
+ error->set_error_message("internal error");
+ SendAsyncRpc(stub, param);
+ } else if (i % 5 == 0) {
+ RequestParams param;
+ param.set_echo_metadata(true);
+ DebugInfo* info = param.mutable_debug_info();
+ info->add_stack_entries("stack_entry1");
+ info->add_stack_entries("stack_entry2");
+ info->set_detail("detailed debug info");
+ SendAsyncRpc(stub, param);
+ } else {
+ SendAsyncRpc(stub);
+ }
+ ++rpcs_sent;
+ }
+
+ ShutdownCQ();
+
+ thd.join();
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+#endif // GRPC_CFSTREAM
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ grpc::testing::TestEnvironment env(argc, argv);
+ gpr_setenv("grpc_cfstream", "1");
+ const auto result = RUN_ALL_TESTS();
+ return result;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/channelz_service_test.cc b/contrib/libs/grpc/test/cpp/end2end/channelz_service_test.cc
new file mode 100644
index 0000000000..9c723bebb6
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/channelz_service_test.cc
@@ -0,0 +1,767 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/support/port_platform.h>
+
+#include <grpc/grpc.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/security/server_credentials.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include <grpcpp/ext/channelz_service_plugin.h>
+#include "src/core/lib/gpr/env.h"
+#include "src/proto/grpc/channelz/channelz.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+
+#include <gtest/gtest.h>
+
+using grpc::channelz::v1::GetChannelRequest;
+using grpc::channelz::v1::GetChannelResponse;
+using grpc::channelz::v1::GetServerRequest;
+using grpc::channelz::v1::GetServerResponse;
+using grpc::channelz::v1::GetServerSocketsRequest;
+using grpc::channelz::v1::GetServerSocketsResponse;
+using grpc::channelz::v1::GetServersRequest;
+using grpc::channelz::v1::GetServersResponse;
+using grpc::channelz::v1::GetSocketRequest;
+using grpc::channelz::v1::GetSocketResponse;
+using grpc::channelz::v1::GetSubchannelRequest;
+using grpc::channelz::v1::GetSubchannelResponse;
+using grpc::channelz::v1::GetTopChannelsRequest;
+using grpc::channelz::v1::GetTopChannelsResponse;
+
+namespace grpc {
+namespace testing {
+namespace {
+
+// Proxy service supports N backends. Sends RPC to backend dictated by
+// request->backend_channel_idx().
+class Proxy : public ::grpc::testing::EchoTestService::Service {
+ public:
+  Proxy() {}
+
+  // Registers a stub for one more backend. Backends are addressed by the
+  // index order in which they were added.
+  void AddChannelToBackend(const std::shared_ptr<Channel>& channel) {
+    stubs_.push_back(grpc::testing::EchoTestService::NewStub(channel));
+  }
+
+  // Unary echo: forwards the request to the backend selected by
+  // request->param().backend_channel_idx(), propagating the server context
+  // (deadline, metadata) to the outbound client context. Aborts the process
+  // if the index is out of range.
+  Status Echo(ServerContext* server_context, const EchoRequest* request,
+              EchoResponse* response) override {
+    std::unique_ptr<ClientContext> client_context =
+        ClientContext::FromServerContext(*server_context);
+    size_t idx = request->param().backend_channel_idx();
+    GPR_ASSERT(idx < stubs_.size());
+    return stubs_[idx]->Echo(client_context.get(), *request, response);
+  }
+
+  // Bidirectional echo stream. Unlike Echo, this ignores
+  // backend_channel_idx() and always forwards to stubs_[0]; there is no
+  // bounds check here, so at least one backend must have been added.
+  Status BidiStream(ServerContext* server_context,
+                    ServerReaderWriter<EchoResponse, EchoRequest>*
+                        stream_from_client) override {
+    EchoRequest request;
+    EchoResponse response;
+    std::unique_ptr<ClientContext> client_context =
+        ClientContext::FromServerContext(*server_context);
+
+    // always use the first proxy for streaming
+    auto stream_to_backend = stubs_[0]->BidiStream(client_context.get());
+    // Relay one request/response pair at a time: assumes the backend echoes
+    // exactly one response per request (lock-step ping-pong).
+    while (stream_from_client->Read(&request)) {
+      stream_to_backend->Write(request);
+      stream_to_backend->Read(&response);
+      stream_from_client->Write(response);
+    }
+
+    stream_to_backend->WritesDone();
+    return stream_to_backend->Finish();
+  }
+
+ private:
+  std::vector<std::unique_ptr<::grpc::testing::EchoTestService::Stub>> stubs_;
+};
+
+} // namespace
+
+// Test fixture: builds a channelz-enabled proxy server in SetUp(); tests then
+// add backends with ConfigureProxy() and query the channelz service through
+// stubs created by ResetStubs().
+class ChannelzServerTest : public ::testing::Test {
+ public:
+  ChannelzServerTest() {}
+  static void SetUpTestCase() {
+#if TARGET_OS_IPHONE
+    // Workaround Apple CFStream bug
+    gpr_setenv("grpc_cfstream", "0");
+#endif
+  }
+  void SetUp() override {
+    // ensure the channelz service is brought up on all servers we build.
+    ::grpc::channelz::experimental::InitChannelzService();
+
+    // We set up a proxy server with channelz enabled.
+    proxy_port_ = grpc_pick_unused_port_or_die();
+    ServerBuilder proxy_builder;
+    TString proxy_server_address = "localhost:" + to_string(proxy_port_);
+    proxy_builder.AddListeningPort(proxy_server_address,
+                                   InsecureServerCredentials());
+    // forces channelz and channel tracing to be enabled.
+    proxy_builder.AddChannelArgument(GRPC_ARG_ENABLE_CHANNELZ, 1);
+    proxy_builder.AddChannelArgument(
+        GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, 1024);
+    proxy_builder.RegisterService(&proxy_service_);
+    proxy_server_ = proxy_builder.BuildAndStart();
+  }
+
+  // Sets the proxy up to have an arbitrary number of backends. Each backend
+  // has channelz disabled; only the proxy->backend channels (which the tests
+  // validate) have channelz enabled.
+  void ConfigureProxy(size_t num_backends) {
+    backends_.resize(num_backends);
+    for (size_t i = 0; i < num_backends; ++i) {
+      // create a new backend.
+      backends_[i].port = grpc_pick_unused_port_or_die();
+      ServerBuilder backend_builder;
+      TString backend_server_address =
+          "localhost:" + to_string(backends_[i].port);
+      backend_builder.AddListeningPort(backend_server_address,
+                                       InsecureServerCredentials());
+      backends_[i].service.reset(new TestServiceImpl);
+      // ensure that the backend itself has channelz disabled.
+      backend_builder.AddChannelArgument(GRPC_ARG_ENABLE_CHANNELZ, 0);
+      backend_builder.RegisterService(backends_[i].service.get());
+      backends_[i].server = backend_builder.BuildAndStart();
+      // set up a channel to the backend. We ensure that this channel has
+      // channelz enabled since these channels (proxy outbound to backends)
+      // are the ones that our test will actually be validating.
+      ChannelArguments args;
+      args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 1);
+      args.SetInt(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, 1024);
+      std::shared_ptr<Channel> channel_to_backend = ::grpc::CreateCustomChannel(
+          backend_server_address, InsecureChannelCredentials(), args);
+      proxy_service_.AddChannelToBackend(channel_to_backend);
+    }
+  }
+
+  // (Re)creates the channelz and echo stubs pointed at the proxy. The
+  // client-side channel has channelz disabled so that it does not pollute the
+  // channelz data under test.
+  void ResetStubs() {
+    string target = "dns:localhost:" + to_string(proxy_port_);
+    ChannelArguments args;
+    // disable channelz. We only want to focus on proxy to backend outbound.
+    args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 0);
+    std::shared_ptr<Channel> channel =
+        ::grpc::CreateCustomChannel(target, InsecureChannelCredentials(), args);
+    channelz_stub_ = grpc::channelz::v1::Channelz::NewStub(channel);
+    echo_stub_ = grpc::testing::EchoTestService::NewStub(channel);
+  }
+
+  // Creates an independent echo stub on its own channel (no subchannel
+  // sharing), so each stub produces a distinct server-side socket.
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> NewEchoStub() {
+    string target = "dns:localhost:" + to_string(proxy_port_);
+    ChannelArguments args;
+    // disable channelz. We only want to focus on proxy to backend outbound.
+    args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 0);
+    // This ensures that gRPC will not do connection sharing.
+    args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
+    std::shared_ptr<Channel> channel =
+        ::grpc::CreateCustomChannel(target, InsecureChannelCredentials(), args);
+    return grpc::testing::EchoTestService::NewStub(channel);
+  }
+
+  // Sends one unary echo through the proxy to backend `channel_idx` and
+  // expects it to succeed.
+  void SendSuccessfulEcho(int channel_idx) {
+    EchoRequest request;
+    EchoResponse response;
+    request.set_message("Hello channelz");
+    request.mutable_param()->set_backend_channel_idx(channel_idx);
+    ClientContext context;
+    Status s = echo_stub_->Echo(&context, request, &response);
+    EXPECT_EQ(response.message(), request.message());
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  }
+
+  // Runs one bidi stream of `num_messages` lock-step request/response pairs
+  // through the proxy (which always streams to backend 0).
+  void SendSuccessfulStream(int num_messages) {
+    EchoRequest request;
+    EchoResponse response;
+    request.set_message("Hello channelz");
+    ClientContext context;
+    auto stream_to_proxy = echo_stub_->BidiStream(&context);
+    for (int i = 0; i < num_messages; ++i) {
+      EXPECT_TRUE(stream_to_proxy->Write(request));
+      EXPECT_TRUE(stream_to_proxy->Read(&response));
+    }
+    stream_to_proxy->WritesDone();
+    Status s = stream_to_proxy->Finish();
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  }
+
+  // Sends one unary echo that the backend is instructed to fail with
+  // INTERNAL, and expects the failure.
+  void SendFailedEcho(int channel_idx) {
+    EchoRequest request;
+    EchoResponse response;
+    request.set_message("Hello channelz");
+    request.mutable_param()->set_backend_channel_idx(channel_idx);
+    auto* error = request.mutable_param()->mutable_expected_error();
+    error->set_code(13);  // INTERNAL
+    error->set_error_message("error");
+    ClientContext context;
+    Status s = echo_stub_->Echo(&context, request, &response);
+    EXPECT_FALSE(s.ok());
+  }
+
+  // Uses GetTopChannels to return the channel_id of a particular channel,
+  // so that the unit tests may test GetChannel call.
+  intptr_t GetChannelId(int channel_idx) {
+    GetTopChannelsRequest request;
+    GetTopChannelsResponse response;
+    request.set_start_channel_id(0);
+    ClientContext context;
+    Status s = channelz_stub_->GetTopChannels(&context, request, &response);
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+    EXPECT_GT(response.channel_size(), channel_idx);
+    return response.channel(channel_idx).ref().channel_id();
+  }
+
+  static string to_string(const int number) {
+    std::stringstream strs;
+    strs << number;
+    return strs.str();
+  }
+
+ protected:
+  // package of data needed for each backend server.
+  struct BackendData {
+    std::unique_ptr<Server> server;
+    int port;
+    std::unique_ptr<TestServiceImpl> service;
+  };
+
+  std::unique_ptr<grpc::channelz::v1::Channelz::Stub> channelz_stub_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> echo_stub_;
+
+  // proxy server to ping with channelz requests.
+  std::unique_ptr<Server> proxy_server_;
+  int proxy_port_;
+  Proxy proxy_service_;
+
+  // backends. All implement the echo service.
+  std::vector<BackendData> backends_;
+};
+
+// One configured backend => GetTopChannels reports exactly one channel.
+TEST_F(ChannelzServerTest, BasicTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  GetTopChannelsRequest request;
+  GetTopChannelsResponse response;
+  request.set_start_channel_id(0);
+  ClientContext context;
+  Status s = channelz_stub_->GetTopChannels(&context, request, &response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(response.channel_size(), 1);
+}
+
+// A start_channel_id above every existing channel id yields an empty page.
+TEST_F(ChannelzServerTest, HighStartId) {
+  ResetStubs();
+  ConfigureProxy(1);
+  GetTopChannelsRequest request;
+  GetTopChannelsResponse response;
+  request.set_start_channel_id(10000);
+  ClientContext context;
+  Status s = channelz_stub_->GetTopChannels(&context, request, &response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(response.channel_size(), 0);
+}
+
+// One successful RPC is reflected in the channel's call counters.
+TEST_F(ChannelzServerTest, SuccessfulRequestTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  SendSuccessfulEcho(0);
+  GetChannelRequest request;
+  GetChannelResponse response;
+  request.set_channel_id(GetChannelId(0));
+  ClientContext context;
+  Status s = channelz_stub_->GetChannel(&context, request, &response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(response.channel().data().calls_started(), 1);
+  EXPECT_EQ(response.channel().data().calls_succeeded(), 1);
+  EXPECT_EQ(response.channel().data().calls_failed(), 0);
+}
+
+// One failed RPC is reflected in the channel's call counters.
+TEST_F(ChannelzServerTest, FailedRequestTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  SendFailedEcho(0);
+  GetChannelRequest request;
+  GetChannelResponse response;
+  request.set_channel_id(GetChannelId(0));
+  ClientContext context;
+  Status s = channelz_stub_->GetChannel(&context, request, &response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(response.channel().data().calls_started(), 1);
+  EXPECT_EQ(response.channel().data().calls_succeeded(), 0);
+  EXPECT_EQ(response.channel().data().calls_failed(), 1);
+}
+
+// Mixed successes and failures aggregate correctly on a single channel.
+TEST_F(ChannelzServerTest, ManyRequestsTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  // send some RPCs
+  const int kNumSuccess = 10;
+  const int kNumFailed = 11;
+  for (int i = 0; i < kNumSuccess; ++i) {
+    SendSuccessfulEcho(0);
+  }
+  for (int i = 0; i < kNumFailed; ++i) {
+    SendFailedEcho(0);
+  }
+  GetChannelRequest request;
+  GetChannelResponse response;
+  request.set_channel_id(GetChannelId(0));
+  ClientContext context;
+  Status s = channelz_stub_->GetChannel(&context, request, &response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(response.channel().data().calls_started(),
+            kNumSuccess + kNumFailed);
+  EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess);
+  EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed);
+}
+
+// GetTopChannels reports one entry per configured backend channel.
+TEST_F(ChannelzServerTest, ManyChannels) {
+  ResetStubs();
+  const int kNumChannels = 4;
+  ConfigureProxy(kNumChannels);
+  GetTopChannelsRequest request;
+  GetTopChannelsResponse response;
+  request.set_start_channel_id(0);
+  ClientContext context;
+  Status s = channelz_stub_->GetTopChannels(&context, request, &response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(response.channel_size(), kNumChannels);
+}
+
+// Call counters are tracked per channel: channel 0 sees only successes,
+// channel 1 only failures, channel 2 both, channel 3 nothing.
+TEST_F(ChannelzServerTest, ManyRequestsManyChannels) {
+  ResetStubs();
+  const int kNumChannels = 4;
+  ConfigureProxy(kNumChannels);
+  const int kNumSuccess = 10;
+  const int kNumFailed = 11;
+  for (int i = 0; i < kNumSuccess; ++i) {
+    SendSuccessfulEcho(0);
+    SendSuccessfulEcho(2);
+  }
+  for (int i = 0; i < kNumFailed; ++i) {
+    SendFailedEcho(1);
+    SendFailedEcho(2);
+  }
+
+  // the first channel saw only successes
+  {
+    GetChannelRequest request;
+    GetChannelResponse response;
+    request.set_channel_id(GetChannelId(0));
+    ClientContext context;
+    Status s = channelz_stub_->GetChannel(&context, request, &response);
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+    EXPECT_EQ(response.channel().data().calls_started(), kNumSuccess);
+    EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess);
+    EXPECT_EQ(response.channel().data().calls_failed(), 0);
+  }
+
+  // the second channel saw only failures
+  {
+    GetChannelRequest request;
+    GetChannelResponse response;
+    request.set_channel_id(GetChannelId(1));
+    ClientContext context;
+    Status s = channelz_stub_->GetChannel(&context, request, &response);
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+    EXPECT_EQ(response.channel().data().calls_started(), kNumFailed);
+    EXPECT_EQ(response.channel().data().calls_succeeded(), 0);
+    EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed);
+  }
+
+  // the third channel saw both
+  {
+    GetChannelRequest request;
+    GetChannelResponse response;
+    request.set_channel_id(GetChannelId(2));
+    ClientContext context;
+    Status s = channelz_stub_->GetChannel(&context, request, &response);
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+    EXPECT_EQ(response.channel().data().calls_started(),
+              kNumSuccess + kNumFailed);
+    EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess);
+    EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed);
+  }
+
+  // the fourth channel saw nothing
+  {
+    GetChannelRequest request;
+    GetChannelResponse response;
+    request.set_channel_id(GetChannelId(3));
+    ClientContext context;
+    Status s = channelz_stub_->GetChannel(&context, request, &response);
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+    EXPECT_EQ(response.channel().data().calls_started(), 0);
+    EXPECT_EQ(response.channel().data().calls_succeeded(), 0);
+    EXPECT_EQ(response.channel().data().calls_failed(), 0);
+  }
+}
+
+// For every channel that carried traffic, its (first) subchannel's call
+// counters must mirror the parent channel's counters; idle channels must
+// have no subchannels at all.
+TEST_F(ChannelzServerTest, ManySubchannels) {
+  ResetStubs();
+  const int kNumChannels = 4;
+  ConfigureProxy(kNumChannels);
+  const int kNumSuccess = 10;
+  const int kNumFailed = 11;
+  for (int i = 0; i < kNumSuccess; ++i) {
+    SendSuccessfulEcho(0);
+    SendSuccessfulEcho(2);
+  }
+  for (int i = 0; i < kNumFailed; ++i) {
+    SendFailedEcho(1);
+    SendFailedEcho(2);
+  }
+  GetTopChannelsRequest gtc_request;
+  GetTopChannelsResponse gtc_response;
+  gtc_request.set_start_channel_id(0);
+  ClientContext context;
+  Status s =
+      channelz_stub_->GetTopChannels(&context, gtc_request, &gtc_response);
+  EXPECT_TRUE(s.ok()) << s.error_message();
+  EXPECT_EQ(gtc_response.channel_size(), kNumChannels);
+  for (int i = 0; i < gtc_response.channel_size(); ++i) {
+    // if the channel sent no RPCs, then expect no subchannels to have been
+    // created.
+    if (gtc_response.channel(i).data().calls_started() == 0) {
+      EXPECT_EQ(gtc_response.channel(i).subchannel_ref_size(), 0);
+      continue;
+    }
+    // The resolver must return at least one address.
+    ASSERT_GT(gtc_response.channel(i).subchannel_ref_size(), 0);
+    GetSubchannelRequest gsc_request;
+    GetSubchannelResponse gsc_response;
+    gsc_request.set_subchannel_id(
+        gtc_response.channel(i).subchannel_ref(0).subchannel_id());
+    ClientContext context;
+    Status s =
+        channelz_stub_->GetSubchannel(&context, gsc_request, &gsc_response);
+    EXPECT_TRUE(s.ok()) << s.error_message();
+    EXPECT_EQ(gtc_response.channel(i).data().calls_started(),
+              gsc_response.subchannel().data().calls_started());
+    EXPECT_EQ(gtc_response.channel(i).data().calls_succeeded(),
+              gsc_response.subchannel().data().calls_succeeded());
+    EXPECT_EQ(gtc_response.channel(i).data().calls_failed(),
+              gsc_response.subchannel().data().calls_failed());
+  }
+}
+
+// Exactly one server (the proxy) is registered with channelz; the backends
+// have channelz disabled and must not appear.
+TEST_F(ChannelzServerTest, BasicServerTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  GetServersRequest request;
+  GetServersResponse response;
+  request.set_start_server_id(0);
+  ClientContext context;
+  Status s = channelz_stub_->GetServers(&context, request, &response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(response.server_size(), 1);
+}
+
+// GetServer by id returns the same server that GetServers listed.
+TEST_F(ChannelzServerTest, BasicGetServerTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  GetServersRequest get_servers_request;
+  GetServersResponse get_servers_response;
+  get_servers_request.set_start_server_id(0);
+  ClientContext get_servers_context;
+  Status s = channelz_stub_->GetServers(
+      &get_servers_context, get_servers_request, &get_servers_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_servers_response.server_size(), 1);
+  GetServerRequest get_server_request;
+  GetServerResponse get_server_response;
+  get_server_request.set_server_id(
+      get_servers_response.server(0).ref().server_id());
+  ClientContext get_server_context;
+  s = channelz_stub_->GetServer(&get_server_context, get_server_request,
+                                &get_server_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_servers_response.server(0).ref().server_id(),
+            get_server_response.server().ref().server_id());
+}
+
+// Server-side call counters on the proxy reflect the RPCs it served.
+TEST_F(ChannelzServerTest, ServerCallTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  const int kNumSuccess = 10;
+  const int kNumFailed = 11;
+  for (int i = 0; i < kNumSuccess; ++i) {
+    SendSuccessfulEcho(0);
+  }
+  for (int i = 0; i < kNumFailed; ++i) {
+    SendFailedEcho(0);
+  }
+  GetServersRequest request;
+  GetServersResponse response;
+  request.set_start_server_id(0);
+  ClientContext context;
+  Status s = channelz_stub_->GetServers(&context, request, &response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(response.server_size(), 1);
+  EXPECT_EQ(response.server(0).data().calls_succeeded(), kNumSuccess);
+  EXPECT_EQ(response.server(0).data().calls_failed(), kNumFailed);
+  // This is success+failure+1 because the call that retrieved this information
+  // will be counted as started. It will not track success/failure until after
+  // it has returned, so that is not included in the response.
+  EXPECT_EQ(response.server(0).data().calls_started(),
+            kNumSuccess + kNumFailed + 1);
+}
+
+// Drills down channel -> subchannel -> socket for every channel that carried
+// traffic, and checks that socket-level stream/message counters line up with
+// the subchannel's call counters.
+TEST_F(ChannelzServerTest, ManySubchannelsAndSockets) {
+  ResetStubs();
+  const int kNumChannels = 4;
+  ConfigureProxy(kNumChannels);
+  const int kNumSuccess = 10;
+  const int kNumFailed = 11;
+  for (int i = 0; i < kNumSuccess; ++i) {
+    SendSuccessfulEcho(0);
+    SendSuccessfulEcho(2);
+  }
+  for (int i = 0; i < kNumFailed; ++i) {
+    SendFailedEcho(1);
+    SendFailedEcho(2);
+  }
+  GetTopChannelsRequest gtc_request;
+  GetTopChannelsResponse gtc_response;
+  gtc_request.set_start_channel_id(0);
+  ClientContext context;
+  Status s =
+      channelz_stub_->GetTopChannels(&context, gtc_request, &gtc_response);
+  EXPECT_TRUE(s.ok()) << s.error_message();
+  EXPECT_EQ(gtc_response.channel_size(), kNumChannels);
+  for (int i = 0; i < gtc_response.channel_size(); ++i) {
+    // if the channel sent no RPCs, then expect no subchannels to have been
+    // created.
+    if (gtc_response.channel(i).data().calls_started() == 0) {
+      EXPECT_EQ(gtc_response.channel(i).subchannel_ref_size(), 0);
+      continue;
+    }
+    // The resolver must return at least one address.
+    ASSERT_GT(gtc_response.channel(i).subchannel_ref_size(), 0);
+    // First grab the subchannel
+    GetSubchannelRequest get_subchannel_req;
+    GetSubchannelResponse get_subchannel_resp;
+    get_subchannel_req.set_subchannel_id(
+        gtc_response.channel(i).subchannel_ref(0).subchannel_id());
+    ClientContext get_subchannel_ctx;
+    Status s = channelz_stub_->GetSubchannel(
+        &get_subchannel_ctx, get_subchannel_req, &get_subchannel_resp);
+    EXPECT_TRUE(s.ok()) << s.error_message();
+    EXPECT_EQ(get_subchannel_resp.subchannel().socket_ref_size(), 1);
+    // Now grab the socket.
+    GetSocketRequest get_socket_req;
+    GetSocketResponse get_socket_resp;
+    ClientContext get_socket_ctx;
+    get_socket_req.set_socket_id(
+        get_subchannel_resp.subchannel().socket_ref(0).socket_id());
+    s = channelz_stub_->GetSocket(&get_socket_ctx, get_socket_req,
+                                  &get_socket_resp);
+    // The socket name should mention the transport ("http"). Note that
+    // find() returns npos (truthy) on failure and 0 (falsy) for a match at
+    // position 0, so the result must be compared against npos — a bare
+    // EXPECT_TRUE(find(...)) would be meaningless.
+    EXPECT_NE(
+        get_subchannel_resp.subchannel().socket_ref(0).name().find("http"),
+        TString::npos);
+    EXPECT_TRUE(s.ok()) << s.error_message();
+    // calls started == streams started AND streams succeeded. Since none of
+    // these RPCs were canceled, all of the streams will have succeeded even
+    // though the RPCs they represent might have failed.
+    EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(),
+              get_socket_resp.socket().data().streams_started());
+    EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(),
+              get_socket_resp.socket().data().streams_succeeded());
+    // All of the calls were unary, so calls started == messages sent.
+    EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(),
+              get_socket_resp.socket().data().messages_sent());
+    // We only get responses when the RPC was successful, so
+    // calls succeeded == messages received.
+    EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_succeeded(),
+              get_socket_resp.socket().data().messages_received());
+  }
+}
+
+// A single streaming RPC counts as one call at every level (channel,
+// subchannel, socket/stream) but as kNumMessages messages on the socket.
+TEST_F(ChannelzServerTest, StreamingRPC) {
+  ResetStubs();
+  ConfigureProxy(1);
+  const int kNumMessages = 5;
+  SendSuccessfulStream(kNumMessages);
+  // Get the channel
+  GetChannelRequest get_channel_request;
+  GetChannelResponse get_channel_response;
+  get_channel_request.set_channel_id(GetChannelId(0));
+  ClientContext get_channel_context;
+  Status s = channelz_stub_->GetChannel(
+      &get_channel_context, get_channel_request, &get_channel_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_channel_response.channel().data().calls_started(), 1);
+  EXPECT_EQ(get_channel_response.channel().data().calls_succeeded(), 1);
+  EXPECT_EQ(get_channel_response.channel().data().calls_failed(), 0);
+  // Get the subchannel
+  ASSERT_GT(get_channel_response.channel().subchannel_ref_size(), 0);
+  GetSubchannelRequest get_subchannel_request;
+  GetSubchannelResponse get_subchannel_response;
+  ClientContext get_subchannel_context;
+  get_subchannel_request.set_subchannel_id(
+      get_channel_response.channel().subchannel_ref(0).subchannel_id());
+  s = channelz_stub_->GetSubchannel(&get_subchannel_context,
+                                    get_subchannel_request,
+                                    &get_subchannel_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_subchannel_response.subchannel().data().calls_started(), 1);
+  EXPECT_EQ(get_subchannel_response.subchannel().data().calls_succeeded(), 1);
+  EXPECT_EQ(get_subchannel_response.subchannel().data().calls_failed(), 0);
+  // Get the socket
+  ASSERT_GT(get_subchannel_response.subchannel().socket_ref_size(), 0);
+  GetSocketRequest get_socket_request;
+  GetSocketResponse get_socket_response;
+  ClientContext get_socket_context;
+  get_socket_request.set_socket_id(
+      get_subchannel_response.subchannel().socket_ref(0).socket_id());
+  // The socket name should mention the transport ("http"). find() returns
+  // npos (truthy) on failure, so compare against npos explicitly; a bare
+  // EXPECT_TRUE(find(...)) would be meaningless.
+  EXPECT_NE(
+      get_subchannel_response.subchannel().socket_ref(0).name().find("http"),
+      TString::npos);
+  s = channelz_stub_->GetSocket(&get_socket_context, get_socket_request,
+                                &get_socket_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_socket_response.socket().data().streams_started(), 1);
+  EXPECT_EQ(get_socket_response.socket().data().streams_succeeded(), 1);
+  EXPECT_EQ(get_socket_response.socket().data().streams_failed(), 0);
+  EXPECT_EQ(get_socket_response.socket().data().messages_sent(), kNumMessages);
+  EXPECT_EQ(get_socket_response.socket().data().messages_received(),
+            kNumMessages);
+}
+
+// GetServerSockets on the proxy reports the single client connection (the
+// shared channel created by ResetStubs()).
+TEST_F(ChannelzServerTest, GetServerSocketsTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  GetServersRequest get_server_request;
+  GetServersResponse get_server_response;
+  get_server_request.set_start_server_id(0);
+  ClientContext get_server_context;
+  Status s = channelz_stub_->GetServers(&get_server_context, get_server_request,
+                                        &get_server_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_server_response.server_size(), 1);
+  GetServerSocketsRequest get_server_sockets_request;
+  GetServerSocketsResponse get_server_sockets_response;
+  get_server_sockets_request.set_server_id(
+      get_server_response.server(0).ref().server_id());
+  get_server_sockets_request.set_start_socket_id(0);
+  ClientContext get_server_sockets_context;
+  s = channelz_stub_->GetServerSockets(&get_server_sockets_context,
+                                       get_server_sockets_request,
+                                       &get_server_sockets_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_server_sockets_response.socket_ref_size(), 1);
+  // The socket name should mention the transport ("http"). find() returns
+  // npos (truthy) on failure, so compare against npos explicitly; a bare
+  // EXPECT_TRUE(find(...)) would be meaningless.
+  EXPECT_NE(get_server_sockets_response.socket_ref(0).name().find("http"),
+            TString::npos);
+}
+
+// Creates many distinct client connections (one per NewEchoStub(), which
+// disables subchannel sharing) and verifies that GetServerSockets both
+// returns them all unpaginated and honors max_results when paginating.
+TEST_F(ChannelzServerTest, GetServerSocketsPaginationTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  std::vector<std::unique_ptr<grpc::testing::EchoTestService::Stub>> stubs;
+  const int kNumServerSocketsCreated = 20;
+  for (int i = 0; i < kNumServerSocketsCreated; ++i) {
+    stubs.push_back(NewEchoStub());
+    EchoRequest request;
+    EchoResponse response;
+    request.set_message("Hello channelz");
+    request.mutable_param()->set_backend_channel_idx(0);
+    ClientContext context;
+    Status s = stubs.back()->Echo(&context, request, &response);
+    EXPECT_EQ(response.message(), request.message());
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  }
+  GetServersRequest get_server_request;
+  GetServersResponse get_server_response;
+  get_server_request.set_start_server_id(0);
+  ClientContext get_server_context;
+  Status s = channelz_stub_->GetServers(&get_server_context, get_server_request,
+                                        &get_server_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_server_response.server_size(), 1);
+  // Make a request that gets all of the serversockets
+  {
+    GetServerSocketsRequest get_server_sockets_request;
+    GetServerSocketsResponse get_server_sockets_response;
+    get_server_sockets_request.set_server_id(
+        get_server_response.server(0).ref().server_id());
+    get_server_sockets_request.set_start_socket_id(0);
+    ClientContext get_server_sockets_context;
+    s = channelz_stub_->GetServerSockets(&get_server_sockets_context,
+                                         get_server_sockets_request,
+                                         &get_server_sockets_response);
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+    // We add one to account the channelz stub that will end up creating
+    // a serversocket.
+    EXPECT_EQ(get_server_sockets_response.socket_ref_size(),
+              kNumServerSocketsCreated + 1);
+    EXPECT_TRUE(get_server_sockets_response.end());
+  }
+  // Now we make a request that exercises pagination.
+  {
+    GetServerSocketsRequest get_server_sockets_request;
+    GetServerSocketsResponse get_server_sockets_response;
+    get_server_sockets_request.set_server_id(
+        get_server_response.server(0).ref().server_id());
+    get_server_sockets_request.set_start_socket_id(0);
+    const int kMaxResults = 10;
+    get_server_sockets_request.set_max_results(kMaxResults);
+    ClientContext get_server_sockets_context;
+    s = channelz_stub_->GetServerSockets(&get_server_sockets_context,
+                                         get_server_sockets_request,
+                                         &get_server_sockets_response);
+    EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+    EXPECT_EQ(get_server_sockets_response.socket_ref_size(), kMaxResults);
+    // end() must be false because more results remain beyond this page.
+    EXPECT_FALSE(get_server_sockets_response.end());
+  }
+}
+
+// The proxy server exposes exactly one listen socket, and GetSocket can
+// resolve it by id.
+TEST_F(ChannelzServerTest, GetServerListenSocketsTest) {
+  ResetStubs();
+  ConfigureProxy(1);
+  GetServersRequest get_server_request;
+  GetServersResponse get_server_response;
+  get_server_request.set_start_server_id(0);
+  ClientContext get_server_context;
+  Status s = channelz_stub_->GetServers(&get_server_context, get_server_request,
+                                        &get_server_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+  EXPECT_EQ(get_server_response.server_size(), 1);
+  EXPECT_EQ(get_server_response.server(0).listen_socket_size(), 1);
+  GetSocketRequest get_socket_request;
+  GetSocketResponse get_socket_response;
+  get_socket_request.set_socket_id(
+      get_server_response.server(0).listen_socket(0).socket_id());
+  // The listen socket name should mention the transport ("http"). find()
+  // returns npos (truthy) on failure, so compare against npos explicitly; a
+  // bare EXPECT_TRUE(find(...)) would be meaningless.
+  EXPECT_NE(get_server_response.server(0).listen_socket(0).name().find("http"),
+            TString::npos);
+  ClientContext get_socket_context;
+  s = channelz_stub_->GetSocket(&get_socket_context, get_socket_request,
+                                &get_socket_response);
+  EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message();
+}
+
+} // namespace testing
+} // namespace grpc
+
+// Test entry point for the channelz service tests.
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/client_callback_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/client_callback_end2end_test.cc
new file mode 100644
index 0000000000..12cb40a953
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/client_callback_end2end_test.cc
@@ -0,0 +1,1565 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/client_callback.h>
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+#include <sstream>
+#include <thread>
+
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/iomgr/iomgr.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/interceptors_util.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+#include "test/cpp/util/string_ref_helper.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+// MAYBE_SKIP_TEST is a macro to determine if this particular test configuration
+// should be skipped based on a decision made at SetUp time. In particular, any
+// callback tests can only be run if the iomgr can run in the background or if
+// the transport is in-process.
+#define MAYBE_SKIP_TEST \
+ do { \
+ if (do_not_test_) { \
+ return; \
+ } \
+ } while (0)
+
+namespace grpc {
+namespace testing {
+namespace {
+
+enum class Protocol { INPROC, TCP };
+
+class TestScenario {
+ public:
+ TestScenario(bool serve_callback, Protocol protocol, bool intercept,
+ const TString& creds_type)
+ : callback_server(serve_callback),
+ protocol(protocol),
+ use_interceptors(intercept),
+ credentials_type(creds_type) {}
+ void Log() const;
+ bool callback_server;
+ Protocol protocol;
+ bool use_interceptors;
+ const TString credentials_type;
+};
+
+static std::ostream& operator<<(std::ostream& out,
+ const TestScenario& scenario) {
+ return out << "TestScenario{callback_server="
+ << (scenario.callback_server ? "true" : "false") << ",protocol="
+ << (scenario.protocol == Protocol::INPROC ? "INPROC" : "TCP")
+ << ",intercept=" << (scenario.use_interceptors ? "true" : "false")
+ << ",creds=" << scenario.credentials_type << "}";
+}
+
+void TestScenario::Log() const {
+ std::ostringstream out;
+ out << *this;
+ gpr_log(GPR_DEBUG, "%s", out.str().c_str());
+}
+
+class ClientCallbackEnd2endTest
+ : public ::testing::TestWithParam<TestScenario> {
+ protected:
+ ClientCallbackEnd2endTest() { GetParam().Log(); }
+
+ void SetUp() override {
+ ServerBuilder builder;
+
+ auto server_creds = GetCredentialsProvider()->GetServerCredentials(
+ GetParam().credentials_type);
+ // TODO(vjpai): Support testing of AuthMetadataProcessor
+
+ if (GetParam().protocol == Protocol::TCP) {
+ picked_port_ = grpc_pick_unused_port_or_die();
+ server_address_ << "localhost:" << picked_port_;
+ builder.AddListeningPort(server_address_.str(), server_creds);
+ }
+ if (!GetParam().callback_server) {
+ builder.RegisterService(&service_);
+ } else {
+ builder.RegisterService(&callback_service_);
+ }
+
+ if (GetParam().use_interceptors) {
+ std::vector<
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ // Add 20 dummy server interceptors
+ creators.reserve(20);
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ }
+
+ server_ = builder.BuildAndStart();
+ is_server_started_ = true;
+ if (GetParam().protocol == Protocol::TCP &&
+ !grpc_iomgr_run_in_background()) {
+ do_not_test_ = true;
+ }
+ }
+
+ void ResetStub() {
+ ChannelArguments args;
+ auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
+ GetParam().credentials_type, &args);
+ switch (GetParam().protocol) {
+ case Protocol::TCP:
+ if (!GetParam().use_interceptors) {
+ channel_ = ::grpc::CreateCustomChannel(server_address_.str(),
+ channel_creds, args);
+ } else {
+ channel_ = CreateCustomChannelWithInterceptors(
+ server_address_.str(), channel_creds, args,
+ CreateDummyClientInterceptors());
+ }
+ break;
+ case Protocol::INPROC:
+ if (!GetParam().use_interceptors) {
+ channel_ = server_->InProcessChannel(args);
+ } else {
+ channel_ = server_->experimental().InProcessChannelWithInterceptors(
+ args, CreateDummyClientInterceptors());
+ }
+ break;
+ default:
+ assert(false);
+ }
+ stub_ = grpc::testing::EchoTestService::NewStub(channel_);
+ generic_stub_.reset(new GenericStub(channel_));
+ DummyInterceptor::Reset();
+ }
+
+ void TearDown() override {
+ if (is_server_started_) {
+ // Although we would normally do an explicit shutdown, the server
+ // should also work correctly with just a destructor call. The regular
+ // end2end test uses explicit shutdown, so let this one just do reset.
+ server_.reset();
+ }
+ if (picked_port_ > 0) {
+ grpc_recycle_unused_port(picked_port_);
+ }
+ }
+
+ void SendRpcs(int num_rpcs, bool with_binary_metadata) {
+ TString test_string("");
+ for (int i = 0; i < num_rpcs; i++) {
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext cli_ctx;
+
+ test_string += "Hello world. ";
+ request.set_message(test_string);
+ TString val;
+ if (with_binary_metadata) {
+ request.mutable_param()->set_echo_metadata(true);
+ char bytes[8] = {'\0', '\1', '\2', '\3',
+ '\4', '\5', '\6', static_cast<char>(i)};
+ val = TString(bytes, 8);
+ cli_ctx.AddMetadata("custom-bin", val);
+ }
+
+ cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
+
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ stub_->experimental_async()->Echo(
+ &cli_ctx, &request, &response,
+ [&cli_ctx, &request, &response, &done, &mu, &cv, val,
+ with_binary_metadata](Status s) {
+ GPR_ASSERT(s.ok());
+
+ EXPECT_EQ(request.message(), response.message());
+ if (with_binary_metadata) {
+ EXPECT_EQ(
+ 1u, cli_ctx.GetServerTrailingMetadata().count("custom-bin"));
+ EXPECT_EQ(val, ToString(cli_ctx.GetServerTrailingMetadata()
+ .find("custom-bin")
+ ->second));
+ }
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+ });
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+ }
+ }
+
+ void SendRpcsGeneric(int num_rpcs, bool maybe_except) {
+ const TString kMethodName("/grpc.testing.EchoTestService/Echo");
+ TString test_string("");
+ for (int i = 0; i < num_rpcs; i++) {
+ EchoRequest request;
+ std::unique_ptr<ByteBuffer> send_buf;
+ ByteBuffer recv_buf;
+ ClientContext cli_ctx;
+
+ test_string += "Hello world. ";
+ request.set_message(test_string);
+ send_buf = SerializeToByteBuffer(&request);
+
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ generic_stub_->experimental().UnaryCall(
+ &cli_ctx, kMethodName, send_buf.get(), &recv_buf,
+ [&request, &recv_buf, &done, &mu, &cv, maybe_except](Status s) {
+ GPR_ASSERT(s.ok());
+
+ EchoResponse response;
+ EXPECT_TRUE(ParseFromByteBuffer(&recv_buf, &response));
+ EXPECT_EQ(request.message(), response.message());
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+#if GRPC_ALLOW_EXCEPTIONS
+ if (maybe_except) {
+ throw - 1;
+ }
+#else
+ GPR_ASSERT(!maybe_except);
+#endif
+ });
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+ }
+ }
+
+ void SendGenericEchoAsBidi(int num_rpcs, int reuses, bool do_writes_done) {
+ const TString kMethodName("/grpc.testing.EchoTestService/Echo");
+ TString test_string("");
+ for (int i = 0; i < num_rpcs; i++) {
+ test_string += "Hello world. ";
+ class Client : public grpc::experimental::ClientBidiReactor<ByteBuffer,
+ ByteBuffer> {
+ public:
+ Client(ClientCallbackEnd2endTest* test, const TString& method_name,
+ const TString& test_str, int reuses, bool do_writes_done)
+ : reuses_remaining_(reuses), do_writes_done_(do_writes_done) {
+ activate_ = [this, test, method_name, test_str] {
+ if (reuses_remaining_ > 0) {
+ cli_ctx_.reset(new ClientContext);
+ reuses_remaining_--;
+ test->generic_stub_->experimental().PrepareBidiStreamingCall(
+ cli_ctx_.get(), method_name, this);
+ request_.set_message(test_str);
+ send_buf_ = SerializeToByteBuffer(&request_);
+ StartWrite(send_buf_.get());
+ StartRead(&recv_buf_);
+ StartCall();
+ } else {
+ std::unique_lock<std::mutex> l(mu_);
+ done_ = true;
+ cv_.notify_one();
+ }
+ };
+ activate_();
+ }
+ void OnWriteDone(bool /*ok*/) override {
+ if (do_writes_done_) {
+ StartWritesDone();
+ }
+ }
+ void OnReadDone(bool /*ok*/) override {
+ EchoResponse response;
+ EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response));
+ EXPECT_EQ(request_.message(), response.message());
+ };
+ void OnDone(const Status& s) override {
+ EXPECT_TRUE(s.ok());
+ activate_();
+ }
+ void Await() {
+ std::unique_lock<std::mutex> l(mu_);
+ while (!done_) {
+ cv_.wait(l);
+ }
+ }
+
+ EchoRequest request_;
+ std::unique_ptr<ByteBuffer> send_buf_;
+ ByteBuffer recv_buf_;
+ std::unique_ptr<ClientContext> cli_ctx_;
+ int reuses_remaining_;
+ std::function<void()> activate_;
+ std::mutex mu_;
+ std::condition_variable cv_;
+ bool done_ = false;
+ const bool do_writes_done_;
+ };
+
+ Client rpc(this, kMethodName, test_string, reuses, do_writes_done);
+
+ rpc.Await();
+ }
+ }
+ bool do_not_test_{false};
+ bool is_server_started_{false};
+ int picked_port_{0};
+ std::shared_ptr<Channel> channel_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<grpc::GenericStub> generic_stub_;
+ TestServiceImpl service_;
+ CallbackTestServiceImpl callback_service_;
+ std::unique_ptr<Server> server_;
+ std::ostringstream server_address_;
+};
+
+TEST_P(ClientCallbackEnd2endTest, SimpleRpc) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendRpcs(1, false);
+}
+
+TEST_P(ClientCallbackEnd2endTest, SimpleRpcExpectedError) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext cli_ctx;
+ ErrorStatus error_status;
+
+ request.set_message("Hello failure");
+ error_status.set_code(1); // CANCELLED
+ error_status.set_error_message("cancel error message");
+ *request.mutable_param()->mutable_expected_error() = error_status;
+
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+
+ stub_->experimental_async()->Echo(
+ &cli_ctx, &request, &response,
+ [&response, &done, &mu, &cv, &error_status](Status s) {
+ EXPECT_EQ("", response.message());
+ EXPECT_EQ(error_status.code(), s.error_code());
+ EXPECT_EQ(error_status.error_message(), s.error_message());
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+ });
+
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLockNested) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+
+ // The request/response state associated with an RPC and the synchronization
+ // variables needed to notify its completion.
+ struct RpcState {
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext cli_ctx;
+
+ RpcState() = default;
+ ~RpcState() {
+ // Grab the lock to prevent destruction while another is still holding
+ // lock
+ std::lock_guard<std::mutex> lock(mu);
+ }
+ };
+ std::vector<RpcState> rpc_state(3);
+ for (size_t i = 0; i < rpc_state.size(); i++) {
+ TString message = "Hello locked world";
+ message += ToString(i);
+ rpc_state[i].request.set_message(message);
+ }
+
+ // Grab a lock and then start an RPC whose callback grabs the same lock and
+ // then calls this function to start the next RPC under lock (up to a limit of
+ // the size of the rpc_state vector).
+ std::function<void(int)> nested_call = [this, &nested_call,
+ &rpc_state](int index) {
+ std::lock_guard<std::mutex> l(rpc_state[index].mu);
+ stub_->experimental_async()->Echo(
+ &rpc_state[index].cli_ctx, &rpc_state[index].request,
+ &rpc_state[index].response,
+ [index, &nested_call, &rpc_state](Status s) {
+ std::lock_guard<std::mutex> l1(rpc_state[index].mu);
+ EXPECT_TRUE(s.ok());
+ rpc_state[index].done = true;
+ rpc_state[index].cv.notify_all();
+ // Call the next level of nesting if possible
+ if (index + 1 < rpc_state.size()) {
+ nested_call(index + 1);
+ }
+ });
+ };
+
+ nested_call(0);
+
+ // Wait for completion notifications from all RPCs. Order doesn't matter.
+ for (RpcState& state : rpc_state) {
+ std::unique_lock<std::mutex> l(state.mu);
+ while (!state.done) {
+ state.cv.wait(l);
+ }
+ EXPECT_EQ(state.request.message(), state.response.message());
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLock) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ EchoRequest request;
+ request.set_message("Hello locked world.");
+ EchoResponse response;
+ ClientContext cli_ctx;
+ {
+ std::lock_guard<std::mutex> l(mu);
+ stub_->experimental_async()->Echo(
+ &cli_ctx, &request, &response,
+ [&mu, &cv, &done, &request, &response](Status s) {
+ std::lock_guard<std::mutex> l(mu);
+ EXPECT_TRUE(s.ok());
+ EXPECT_EQ(request.message(), response.message());
+ done = true;
+ cv.notify_one();
+ });
+ }
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, SequentialRpcs) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendRpcs(10, false);
+}
+
+TEST_P(ClientCallbackEnd2endTest, SendClientInitialMetadata) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SimpleRequest request;
+ SimpleResponse response;
+ ClientContext cli_ctx;
+
+ cli_ctx.AddMetadata(kCheckClientInitialMetadataKey,
+ kCheckClientInitialMetadataVal);
+
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ stub_->experimental_async()->CheckClientInitialMetadata(
+ &cli_ctx, &request, &response, [&done, &mu, &cv](Status s) {
+ GPR_ASSERT(s.ok());
+
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+ });
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, SimpleRpcWithBinaryMetadata) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendRpcs(1, true);
+}
+
+TEST_P(ClientCallbackEnd2endTest, SequentialRpcsWithVariedBinaryMetadataValue) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendRpcs(10, true);
+}
+
+TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcs) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendRpcsGeneric(10, false);
+}
+
+TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidi) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendGenericEchoAsBidi(10, 1, /*do_writes_done=*/true);
+}
+
+TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidiWithReactorReuse) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendGenericEchoAsBidi(10, 10, /*do_writes_done=*/true);
+}
+
+TEST_P(ClientCallbackEnd2endTest, GenericRpcNoWritesDone) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendGenericEchoAsBidi(1, 1, /*do_writes_done=*/false);
+}
+
+#if GRPC_ALLOW_EXCEPTIONS
+TEST_P(ClientCallbackEnd2endTest, ExceptingRpc) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ SendRpcsGeneric(10, true);
+}
+#endif
+
+TEST_P(ClientCallbackEnd2endTest, MultipleRpcsWithVariedBinaryMetadataValue) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ std::vector<std::thread> threads;
+ threads.reserve(10);
+ for (int i = 0; i < 10; ++i) {
+ threads.emplace_back([this] { SendRpcs(10, true); });
+ }
+ for (int i = 0; i < 10; ++i) {
+ threads[i].join();
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, MultipleRpcs) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ std::vector<std::thread> threads;
+ threads.reserve(10);
+ for (int i = 0; i < 10; ++i) {
+ threads.emplace_back([this] { SendRpcs(10, false); });
+ }
+ for (int i = 0; i < 10; ++i) {
+ threads[i].join();
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, CancelRpcBeforeStart) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+ request.set_message("hello");
+ context.TryCancel();
+
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ stub_->experimental_async()->Echo(
+ &context, &request, &response, [&response, &done, &mu, &cv](Status s) {
+ EXPECT_EQ("", response.message());
+ EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+ });
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, RequestEchoServerCancel) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+ request.set_message("hello");
+ context.AddMetadata(kServerTryCancelRequest,
+ ToString(CANCEL_BEFORE_PROCESSING));
+
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ stub_->experimental_async()->Echo(
+ &context, &request, &response, [&done, &mu, &cv](Status s) {
+ EXPECT_FALSE(s.ok());
+ EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+ });
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+}
+
+struct ClientCancelInfo {
+ bool cancel{false};
+ int ops_before_cancel;
+
+ ClientCancelInfo() : cancel{false} {}
+ explicit ClientCancelInfo(int ops) : cancel{true}, ops_before_cancel{ops} {}
+};
+
+class WriteClient : public grpc::experimental::ClientWriteReactor<EchoRequest> {
+ public:
+ WriteClient(grpc::testing::EchoTestService::Stub* stub,
+ ServerTryCancelRequestPhase server_try_cancel,
+ int num_msgs_to_send, ClientCancelInfo client_cancel = {})
+ : server_try_cancel_(server_try_cancel),
+ num_msgs_to_send_(num_msgs_to_send),
+ client_cancel_{client_cancel} {
+ TString msg{"Hello server."};
+ for (int i = 0; i < num_msgs_to_send; i++) {
+ desired_ += msg;
+ }
+ if (server_try_cancel != DO_NOT_CANCEL) {
+ // Send server_try_cancel value in the client metadata
+ context_.AddMetadata(kServerTryCancelRequest,
+ ToString(server_try_cancel));
+ }
+ context_.set_initial_metadata_corked(true);
+ stub->experimental_async()->RequestStream(&context_, &response_, this);
+ StartCall();
+ request_.set_message(msg);
+ MaybeWrite();
+ }
+ void OnWriteDone(bool ok) override {
+ if (ok) {
+ num_msgs_sent_++;
+ MaybeWrite();
+ }
+ }
+ void OnDone(const Status& s) override {
+ gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent_);
+ int num_to_send =
+ (client_cancel_.cancel)
+ ? std::min(num_msgs_to_send_, client_cancel_.ops_before_cancel)
+ : num_msgs_to_send_;
+ switch (server_try_cancel_) {
+ case CANCEL_BEFORE_PROCESSING:
+ case CANCEL_DURING_PROCESSING:
+ // If the RPC is canceled by server before / during messages from the
+ // client, it means that the client most likely did not get a chance to
+ // send all the messages it wanted to send. i.e num_msgs_sent <=
+ // num_msgs_to_send
+ EXPECT_LE(num_msgs_sent_, num_to_send);
+ break;
+ case DO_NOT_CANCEL:
+ case CANCEL_AFTER_PROCESSING:
+ // If the RPC was not canceled or canceled after all messages were read
+ // by the server, the client did get a chance to send all its messages
+ EXPECT_EQ(num_msgs_sent_, num_to_send);
+ break;
+ default:
+ assert(false);
+ break;
+ }
+ if ((server_try_cancel_ == DO_NOT_CANCEL) && !client_cancel_.cancel) {
+ EXPECT_TRUE(s.ok());
+ EXPECT_EQ(response_.message(), desired_);
+ } else {
+ EXPECT_FALSE(s.ok());
+ EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+ }
+ std::unique_lock<std::mutex> l(mu_);
+ done_ = true;
+ cv_.notify_one();
+ }
+ void Await() {
+ std::unique_lock<std::mutex> l(mu_);
+ while (!done_) {
+ cv_.wait(l);
+ }
+ }
+
+ private:
+ void MaybeWrite() {
+ if (client_cancel_.cancel &&
+ num_msgs_sent_ == client_cancel_.ops_before_cancel) {
+ context_.TryCancel();
+ } else if (num_msgs_to_send_ > num_msgs_sent_ + 1) {
+ StartWrite(&request_);
+ } else if (num_msgs_to_send_ == num_msgs_sent_ + 1) {
+ StartWriteLast(&request_, WriteOptions());
+ }
+ }
+ EchoRequest request_;
+ EchoResponse response_;
+ ClientContext context_;
+ const ServerTryCancelRequestPhase server_try_cancel_;
+ int num_msgs_sent_{0};
+ const int num_msgs_to_send_;
+ TString desired_;
+ const ClientCancelInfo client_cancel_;
+ std::mutex mu_;
+ std::condition_variable cv_;
+ bool done_ = false;
+};
+
+TEST_P(ClientCallbackEnd2endTest, RequestStream) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ WriteClient test{stub_.get(), DO_NOT_CANCEL, 3};
+ test.Await();
+ // Make sure that the server interceptors were not notified to cancel
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, ClientCancelsRequestStream) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ WriteClient test{stub_.get(), DO_NOT_CANCEL, 3, ClientCancelInfo{2}};
+ test.Await();
+ // Make sure that the server interceptors got the cancel
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+// Server to cancel before doing reading the request
+TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelBeforeReads) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ WriteClient test{stub_.get(), CANCEL_BEFORE_PROCESSING, 1};
+ test.Await();
+ // Make sure that the server interceptors were notified
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+// Server to cancel while reading a request from the stream in parallel
+TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelDuringRead) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ WriteClient test{stub_.get(), CANCEL_DURING_PROCESSING, 10};
+ test.Await();
+ // Make sure that the server interceptors were notified
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+// Server to cancel after reading all the requests but before returning to the
+// client
+TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelAfterReads) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ WriteClient test{stub_.get(), CANCEL_AFTER_PROCESSING, 4};
+ test.Await();
+ // Make sure that the server interceptors were notified
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, UnaryReactor) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ class UnaryClient : public grpc::experimental::ClientUnaryReactor {
+ public:
+ UnaryClient(grpc::testing::EchoTestService::Stub* stub) {
+ cli_ctx_.AddMetadata("key1", "val1");
+ cli_ctx_.AddMetadata("key2", "val2");
+ request_.mutable_param()->set_echo_metadata_initially(true);
+ request_.set_message("Hello metadata");
+ stub->experimental_async()->Echo(&cli_ctx_, &request_, &response_, this);
+ StartCall();
+ }
+ void OnReadInitialMetadataDone(bool ok) override {
+ EXPECT_TRUE(ok);
+ EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1"));
+ EXPECT_EQ(
+ "val1",
+ ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second));
+ EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2"));
+ EXPECT_EQ(
+ "val2",
+ ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second));
+ initial_metadata_done_ = true;
+ }
+ void OnDone(const Status& s) override {
+ EXPECT_TRUE(initial_metadata_done_);
+ EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size());
+ EXPECT_TRUE(s.ok());
+ EXPECT_EQ(request_.message(), response_.message());
+ std::unique_lock<std::mutex> l(mu_);
+ done_ = true;
+ cv_.notify_one();
+ }
+ void Await() {
+ std::unique_lock<std::mutex> l(mu_);
+ while (!done_) {
+ cv_.wait(l);
+ }
+ }
+
+ private:
+ EchoRequest request_;
+ EchoResponse response_;
+ ClientContext cli_ctx_;
+ std::mutex mu_;
+ std::condition_variable cv_;
+ bool done_{false};
+ bool initial_metadata_done_{false};
+ };
+
+ UnaryClient test{stub_.get()};
+ test.Await();
+ // Make sure that the server interceptors were not notified of a cancel
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, GenericUnaryReactor) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ const TString kMethodName("/grpc.testing.EchoTestService/Echo");
+ class UnaryClient : public grpc::experimental::ClientUnaryReactor {
+ public:
+ UnaryClient(grpc::GenericStub* stub, const TString& method_name) {
+ cli_ctx_.AddMetadata("key1", "val1");
+ cli_ctx_.AddMetadata("key2", "val2");
+ request_.mutable_param()->set_echo_metadata_initially(true);
+ request_.set_message("Hello metadata");
+ send_buf_ = SerializeToByteBuffer(&request_);
+
+ stub->experimental().PrepareUnaryCall(&cli_ctx_, method_name,
+ send_buf_.get(), &recv_buf_, this);
+ StartCall();
+ }
+ void OnReadInitialMetadataDone(bool ok) override {
+ EXPECT_TRUE(ok);
+ EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1"));
+ EXPECT_EQ(
+ "val1",
+ ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second));
+ EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2"));
+ EXPECT_EQ(
+ "val2",
+ ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second));
+ initial_metadata_done_ = true;
+ }
+ void OnDone(const Status& s) override {
+ EXPECT_TRUE(initial_metadata_done_);
+ EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size());
+ EXPECT_TRUE(s.ok());
+ EchoResponse response;
+ EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response));
+ EXPECT_EQ(request_.message(), response.message());
+ std::unique_lock<std::mutex> l(mu_);
+ done_ = true;
+ cv_.notify_one();
+ }
+ void Await() {
+ std::unique_lock<std::mutex> l(mu_);
+ while (!done_) {
+ cv_.wait(l);
+ }
+ }
+
+ private:
+ EchoRequest request_;
+ std::unique_ptr<ByteBuffer> send_buf_;
+ ByteBuffer recv_buf_;
+ ClientContext cli_ctx_;
+ std::mutex mu_;
+ std::condition_variable cv_;
+ bool done_{false};
+ bool initial_metadata_done_{false};
+ };
+
+ UnaryClient test{generic_stub_.get(), kMethodName};
+ test.Await();
+ // Make sure that the server interceptors were not notified of a cancel
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+class ReadClient : public grpc::experimental::ClientReadReactor<EchoResponse> {
+ public:
+ ReadClient(grpc::testing::EchoTestService::Stub* stub,
+ ServerTryCancelRequestPhase server_try_cancel,
+ ClientCancelInfo client_cancel = {})
+ : server_try_cancel_(server_try_cancel), client_cancel_{client_cancel} {
+ if (server_try_cancel_ != DO_NOT_CANCEL) {
+ // Send server_try_cancel value in the client metadata
+ context_.AddMetadata(kServerTryCancelRequest,
+ ToString(server_try_cancel));
+ }
+ request_.set_message("Hello client ");
+ stub->experimental_async()->ResponseStream(&context_, &request_, this);
+ if (client_cancel_.cancel &&
+ reads_complete_ == client_cancel_.ops_before_cancel) {
+ context_.TryCancel();
+ }
+ // Even if we cancel, read until failure because there might be responses
+ // pending
+ StartRead(&response_);
+ StartCall();
+ }
+ void OnReadDone(bool ok) override {
+ if (!ok) {
+ if (server_try_cancel_ == DO_NOT_CANCEL && !client_cancel_.cancel) {
+ EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend);
+ }
+ } else {
+ EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
+ EXPECT_EQ(response_.message(),
+ request_.message() + ToString(reads_complete_));
+ reads_complete_++;
+ if (client_cancel_.cancel &&
+ reads_complete_ == client_cancel_.ops_before_cancel) {
+ context_.TryCancel();
+ }
+ // Even if we cancel, read until failure because there might be responses
+ // pending
+ StartRead(&response_);
+ }
+ }
+ void OnDone(const Status& s) override {
+ gpr_log(GPR_INFO, "Read %d messages", reads_complete_);
+ switch (server_try_cancel_) {
+ case DO_NOT_CANCEL:
+ if (!client_cancel_.cancel || client_cancel_.ops_before_cancel >
+ kServerDefaultResponseStreamsToSend) {
+ EXPECT_TRUE(s.ok());
+ EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend);
+ } else {
+ EXPECT_GE(reads_complete_, client_cancel_.ops_before_cancel);
+ EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
+ // Status might be ok or cancelled depending on whether server
+ // sent status before client cancel went through
+ if (!s.ok()) {
+ EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+ }
+ }
+ break;
+ case CANCEL_BEFORE_PROCESSING:
+ EXPECT_FALSE(s.ok());
+ EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+ EXPECT_EQ(reads_complete_, 0);
+ break;
+ case CANCEL_DURING_PROCESSING:
+ case CANCEL_AFTER_PROCESSING:
+ // If server canceled while writing messages, client must have read
+ // less than or equal to the expected number of messages. Even if the
+ // server canceled after writing all messages, the RPC may be canceled
+ // before the Client got a chance to read all the messages.
+ EXPECT_FALSE(s.ok());
+ EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+ EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
+ break;
+ default:
+ assert(false);
+ }
+ std::unique_lock<std::mutex> l(mu_);
+ done_ = true;
+ cv_.notify_one();
+ }
+ void Await() {
+ std::unique_lock<std::mutex> l(mu_);
+ while (!done_) {
+ cv_.wait(l);
+ }
+ }
+
+ private:
+ EchoRequest request_;
+ EchoResponse response_;
+ ClientContext context_;
+ const ServerTryCancelRequestPhase server_try_cancel_;
+ int reads_complete_{0};
+ const ClientCancelInfo client_cancel_;
+ std::mutex mu_;
+ std::condition_variable cv_;
+ bool done_ = false;
+};
+
+TEST_P(ClientCallbackEnd2endTest, ResponseStream) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ ReadClient test{stub_.get(), DO_NOT_CANCEL};
+ test.Await();
+ // Make sure that the server interceptors were not notified of a cancel
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+TEST_P(ClientCallbackEnd2endTest, ClientCancelsResponseStream) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ ReadClient test{stub_.get(), DO_NOT_CANCEL, ClientCancelInfo{2}};
+ test.Await();
+ // Because cancel in this case races with server finish, we can't be sure that
+ // server interceptors even see cancellation
+}
+
+// Server to cancel before sending any response messages
+TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelBefore) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ ReadClient test{stub_.get(), CANCEL_BEFORE_PROCESSING};
+ test.Await();
+ // Make sure that the server interceptors were notified
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+// Server to cancel while writing a response to the stream in parallel
+TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelDuring) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ ReadClient test{stub_.get(), CANCEL_DURING_PROCESSING};
+ test.Await();
+ // Make sure that the server interceptors were notified
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+// Server to cancel after writing all the respones to the stream but before
+// returning to the client
+TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelAfter) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+ ReadClient test{stub_.get(), CANCEL_AFTER_PROCESSING};
+ test.Await();
+ // Make sure that the server interceptors were notified
+ if (GetParam().use_interceptors) {
+ EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+ }
+}
+
+// Bidi-streaming test reactor. Writes msgs_to_send_ copies of a fixed request,
+// reads the echoed responses back, and in OnDone verifies the final status and
+// the read/write counts against the requested server- or client-cancellation
+// mode. Optionally corks initial metadata and/or issues the first write from a
+// separate thread (guarded by an extra hold).
+class BidiClient
+    : public grpc::experimental::ClientBidiReactor<EchoRequest, EchoResponse> {
+ public:
+  BidiClient(grpc::testing::EchoTestService::Stub* stub,
+             ServerTryCancelRequestPhase server_try_cancel,
+             int num_msgs_to_send, bool cork_metadata, bool first_write_async,
+             ClientCancelInfo client_cancel = {})
+      : server_try_cancel_(server_try_cancel),
+        msgs_to_send_{num_msgs_to_send},
+        client_cancel_{client_cancel} {
+    if (server_try_cancel_ != DO_NOT_CANCEL) {
+      // Send server_try_cancel value in the client metadata
+      context_.AddMetadata(kServerTryCancelRequest,
+                           ToString(server_try_cancel));
+    }
+    request_.set_message("Hello fren ");
+    context_.set_initial_metadata_corked(cork_metadata);
+    stub->experimental_async()->BidiStream(&context_, this);
+    // First write may be issued from another thread; subsequent writes are
+    // chained from OnWriteDone.
+    MaybeAsyncWrite(first_write_async);
+    StartRead(&response_);
+    StartCall();
+  }
+  // Keeps one read outstanding until the stream ends; counts successful
+  // reads. A failed read is validated only in the no-server-cancel case.
+  void OnReadDone(bool ok) override {
+    if (!ok) {
+      if (server_try_cancel_ == DO_NOT_CANCEL) {
+        if (!client_cancel_.cancel) {
+          EXPECT_EQ(reads_complete_, msgs_to_send_);
+        } else {
+          EXPECT_LE(reads_complete_, writes_complete_);
+        }
+      }
+    } else {
+      EXPECT_LE(reads_complete_, msgs_to_send_);
+      EXPECT_EQ(response_.message(), request_.message());
+      reads_complete_++;
+      StartRead(&response_);
+    }
+  }
+  // Joins the async first-write thread (if any) and releases its hold, then
+  // chains the next write. Write failures are tolerated when the server is
+  // expected to cancel.
+  void OnWriteDone(bool ok) override {
+    if (async_write_thread_.joinable()) {
+      async_write_thread_.join();
+      RemoveHold();
+    }
+    if (server_try_cancel_ == DO_NOT_CANCEL) {
+      EXPECT_TRUE(ok);
+    } else if (!ok) {
+      return;
+    }
+    writes_complete_++;
+    MaybeWrite();
+  }
+  // Final verification point: status and op counts depend on which side
+  // canceled (if any) and at which phase.
+  void OnDone(const Status& s) override {
+    gpr_log(GPR_INFO, "Sent %d messages", writes_complete_);
+    gpr_log(GPR_INFO, "Read %d messages", reads_complete_);
+    switch (server_try_cancel_) {
+      case DO_NOT_CANCEL:
+        if (!client_cancel_.cancel ||
+            client_cancel_.ops_before_cancel > msgs_to_send_) {
+          EXPECT_TRUE(s.ok());
+          EXPECT_EQ(writes_complete_, msgs_to_send_);
+          EXPECT_EQ(reads_complete_, writes_complete_);
+        } else {
+          EXPECT_FALSE(s.ok());
+          EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+          EXPECT_EQ(writes_complete_, client_cancel_.ops_before_cancel);
+          EXPECT_LE(reads_complete_, writes_complete_);
+        }
+        break;
+      case CANCEL_BEFORE_PROCESSING:
+        EXPECT_FALSE(s.ok());
+        EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+        // The RPC is canceled before the server did any work or returned any
+        // reads, but it's possible that some writes took place first from the
+        // client
+        EXPECT_LE(writes_complete_, msgs_to_send_);
+        EXPECT_EQ(reads_complete_, 0);
+        break;
+      case CANCEL_DURING_PROCESSING:
+        EXPECT_FALSE(s.ok());
+        EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+        EXPECT_LE(writes_complete_, msgs_to_send_);
+        EXPECT_LE(reads_complete_, writes_complete_);
+        break;
+      case CANCEL_AFTER_PROCESSING:
+        EXPECT_FALSE(s.ok());
+        EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+        EXPECT_EQ(writes_complete_, msgs_to_send_);
+        // The Server canceled after reading the last message and after writing
+        // the message to the client. However, the RPC cancellation might have
+        // taken effect before the client actually read the response.
+        EXPECT_LE(reads_complete_, writes_complete_);
+        break;
+      default:
+        assert(false);
+    }
+    std::unique_lock<std::mutex> l(mu_);
+    done_ = true;
+    cv_.notify_one();
+  }
+  // Blocks the test thread until OnDone has run.
+  void Await() {
+    std::unique_lock<std::mutex> l(mu_);
+    while (!done_) {
+      cv_.wait(l);
+    }
+  }
+
+ private:
+  // Issues the first write, either inline or from a dedicated thread whose
+  // lifetime is covered by an extra hold (released in OnWriteDone).
+  void MaybeAsyncWrite(bool first_write_async) {
+    if (first_write_async) {
+      // Make sure that we have a write to issue.
+      // TODO(vjpai): Make this work with 0 writes case as well.
+      assert(msgs_to_send_ >= 1);
+
+      AddHold();
+      async_write_thread_ = std::thread([this] {
+        std::unique_lock<std::mutex> lock(async_write_thread_mu_);
+        async_write_thread_cv_.wait(
+            lock, [this] { return async_write_thread_start_; });
+        MaybeWrite();
+      });
+      std::lock_guard<std::mutex> lock(async_write_thread_mu_);
+      async_write_thread_start_ = true;
+      async_write_thread_cv_.notify_one();
+      return;
+    }
+    MaybeWrite();
+  }
+  // Next step in the write chain: client cancel, half-close, or another write.
+  void MaybeWrite() {
+    if (client_cancel_.cancel &&
+        writes_complete_ == client_cancel_.ops_before_cancel) {
+      context_.TryCancel();
+    } else if (writes_complete_ == msgs_to_send_) {
+      StartWritesDone();
+    } else {
+      StartWrite(&request_);
+    }
+  }
+  EchoRequest request_;
+  EchoResponse response_;
+  ClientContext context_;
+  const ServerTryCancelRequestPhase server_try_cancel_;
+  int reads_complete_{0};
+  int writes_complete_{0};
+  const int msgs_to_send_;
+  const ClientCancelInfo client_cancel_;
+  std::mutex mu_;               // guards done_
+  std::condition_variable cv_;  // signals OnDone completion
+  bool done_ = false;
+  std::thread async_write_thread_;  // only used when first_write_async
+  bool async_write_thread_start_ = false;
+  std::mutex async_write_thread_mu_;
+  std::condition_variable async_write_thread_cv_;
+};
+
+TEST_P(ClientCallbackEnd2endTest, BidiStream) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  BidiClient bidi_client(stub_.get(), DO_NOT_CANCEL,
+                         kServerDefaultResponseStreamsToSend,
+                         /*cork_metadata=*/false, /*first_write_async=*/false);
+  bidi_client.Await();
+  if (GetParam().use_interceptors) {
+    // A clean run must not trigger any cancel notifications on the server
+    // interceptors.
+    EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+TEST_P(ClientCallbackEnd2endTest, BidiStreamFirstWriteAsync) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  BidiClient bidi_client(stub_.get(), DO_NOT_CANCEL,
+                         kServerDefaultResponseStreamsToSend,
+                         /*cork_metadata=*/false, /*first_write_async=*/true);
+  bidi_client.Await();
+  if (GetParam().use_interceptors) {
+    // A clean run must not trigger any cancel notifications on the server
+    // interceptors.
+    EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+TEST_P(ClientCallbackEnd2endTest, BidiStreamCorked) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  BidiClient bidi_client(stub_.get(), DO_NOT_CANCEL,
+                         kServerDefaultResponseStreamsToSend,
+                         /*cork_metadata=*/true, /*first_write_async=*/false);
+  bidi_client.Await();
+  if (GetParam().use_interceptors) {
+    // A clean run must not trigger any cancel notifications on the server
+    // interceptors.
+    EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+TEST_P(ClientCallbackEnd2endTest, BidiStreamCorkedFirstWriteAsync) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  BidiClient bidi_client(stub_.get(), DO_NOT_CANCEL,
+                         kServerDefaultResponseStreamsToSend,
+                         /*cork_metadata=*/true, /*first_write_async=*/true);
+  bidi_client.Await();
+  if (GetParam().use_interceptors) {
+    // A clean run must not trigger any cancel notifications on the server
+    // interceptors.
+    EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+TEST_P(ClientCallbackEnd2endTest, ClientCancelsBidiStream) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  // Client cancels the bidi stream after two successful writes.
+  BidiClient bidi_client(stub_.get(), DO_NOT_CANCEL,
+                         kServerDefaultResponseStreamsToSend,
+                         /*cork_metadata=*/false, /*first_write_async=*/false,
+                         ClientCancelInfo(2));
+  bidi_client.Await();
+  if (GetParam().use_interceptors) {
+    // All server interceptors in the chain should have seen the cancellation.
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Server to cancel before reading/writing any requests/responses on the stream
+TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelBefore) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  BidiClient bidi_client(stub_.get(), CANCEL_BEFORE_PROCESSING,
+                         /*num_msgs_to_send=*/2,
+                         /*cork_metadata=*/false, /*first_write_async=*/false);
+  bidi_client.Await();
+  if (GetParam().use_interceptors) {
+    // All server interceptors in the chain should have seen the cancellation.
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Server to cancel while reading/writing requests/responses on the stream in
+// parallel
+TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelDuring) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  BidiClient bidi_client(stub_.get(), CANCEL_DURING_PROCESSING,
+                         /*num_msgs_to_send=*/10, /*cork_metadata=*/false,
+                         /*first_write_async=*/false);
+  bidi_client.Await();
+  if (GetParam().use_interceptors) {
+    // All server interceptors in the chain should have seen the cancellation.
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Server to cancel after reading/writing all requests/responses on the stream
+// but before returning to the client
+TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelAfter) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  BidiClient bidi_client(stub_.get(), CANCEL_AFTER_PROCESSING,
+                         /*num_msgs_to_send=*/5,
+                         /*cork_metadata=*/false, /*first_write_async=*/false);
+  bidi_client.Await();
+  if (GetParam().use_interceptors) {
+    // All server interceptors in the chain should have seen the cancellation.
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Verifies that a reactor may issue StartWritesDone and StartRead
+// back-to-back from the same reaction (here, OnWriteDone) without losing
+// either operation.
+TEST_P(ClientCallbackEnd2endTest, SimultaneousReadAndWritesDone) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  class Client : public grpc::experimental::ClientBidiReactor<EchoRequest,
+                                                              EchoResponse> {
+   public:
+    Client(grpc::testing::EchoTestService::Stub* stub) {
+      request_.set_message("Hello bidi ");
+      stub->experimental_async()->BidiStream(&context_, this);
+      StartWrite(&request_);
+      StartCall();
+    }
+    // The single read must succeed and echo the request.
+    void OnReadDone(bool ok) override {
+      EXPECT_TRUE(ok);
+      EXPECT_EQ(response_.message(), request_.message());
+    }
+    void OnWriteDone(bool ok) override {
+      EXPECT_TRUE(ok);
+      // Now send out the simultaneous Read and WritesDone
+      StartWritesDone();
+      StartRead(&response_);
+    }
+    // Final check plus wake-up of the waiting test thread.
+    void OnDone(const Status& s) override {
+      EXPECT_TRUE(s.ok());
+      EXPECT_EQ(response_.message(), request_.message());
+      std::unique_lock<std::mutex> l(mu_);
+      done_ = true;
+      cv_.notify_one();
+    }
+    // Blocks until OnDone has run.
+    void Await() {
+      std::unique_lock<std::mutex> l(mu_);
+      while (!done_) {
+        cv_.wait(l);
+      }
+    }
+
+   private:
+    EchoRequest request_;
+    EchoResponse response_;
+    ClientContext context_;
+    std::mutex mu_;
+    std::condition_variable cv_;
+    bool done_ = false;
+  } test{stub_.get()};
+
+  test.Await();
+}
+
+TEST_P(ClientCallbackEnd2endTest, UnimplementedRpc) {
+  MAYBE_SKIP_TEST;
+  ChannelArguments args;
+  const auto& channel_creds = GetCredentialsProvider()->GetChannelCredentials(
+      GetParam().credentials_type, &args);
+  // Pick the channel flavor matching the scenario under test.
+  std::shared_ptr<Channel> channel =
+      (GetParam().protocol == Protocol::TCP)
+          ? ::grpc::CreateCustomChannel(server_address_.str(), channel_creds,
+                                        args)
+          : server_->InProcessChannel(args);
+  auto stub = grpc::testing::UnimplementedEchoService::NewStub(channel);
+  EchoRequest request;
+  request.set_message("Hello world.");
+  EchoResponse response;
+  ClientContext cli_ctx;
+  std::mutex mu;
+  std::condition_variable cv;
+  bool rpc_done = false;
+  stub->experimental_async()->Unimplemented(
+      &cli_ctx, &request, &response, [&rpc_done, &mu, &cv](Status s) {
+        // The service has no Unimplemented handler, so the library must
+        // synthesize UNIMPLEMENTED with an empty message.
+        EXPECT_EQ(StatusCode::UNIMPLEMENTED, s.error_code());
+        EXPECT_EQ("", s.error_message());
+
+        std::lock_guard<std::mutex> l(mu);
+        rpc_done = true;
+        cv.notify_one();
+      });
+  std::unique_lock<std::mutex> l(mu);
+  cv.wait(l, [&rpc_done] { return rpc_done; });
+}
+
+// Drives a server-streaming RPC from the test thread: a hold keeps the RPC
+// from finishing while the test issues one StartRead at a time and waits for
+// each OnReadDone, counting successes until a read fails.
+TEST_P(ClientCallbackEnd2endTest,
+       ResponseStreamExtraReactionFlowReadsUntilDone) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  class ReadAllIncomingDataClient
+      : public grpc::experimental::ClientReadReactor<EchoResponse> {
+   public:
+    ReadAllIncomingDataClient(grpc::testing::EchoTestService::Stub* stub) {
+      request_.set_message("Hello client ");
+      stub->experimental_async()->ResponseStream(&context_, &request_, this);
+    }
+    // Blocks until the next OnReadDone fires, consumes the event (resets
+    // read_done_), and reports whether that read succeeded.
+    bool WaitForReadDone() {
+      std::unique_lock<std::mutex> l(mu_);
+      while (!read_done_) {
+        read_cv_.wait(l);
+      }
+      read_done_ = false;
+      return read_ok_;
+    }
+    // Blocks until OnDone has run.
+    void Await() {
+      std::unique_lock<std::mutex> l(mu_);
+      while (!done_) {
+        done_cv_.wait(l);
+      }
+    }
+    // RemoveHold under the same lock used for OnDone to make sure that we don't
+    // call OnDone directly or indirectly from the RemoveHold function.
+    void RemoveHoldUnderLock() {
+      std::unique_lock<std::mutex> l(mu_);
+      RemoveHold();
+    }
+    // NOTE(review): returns a reference to the member while the lock is only
+    // held for the duration of this call — callers should use it only after
+    // Await(), as this test does.
+    const Status& status() {
+      std::unique_lock<std::mutex> l(mu_);
+      return status_;
+    }
+
+   private:
+    // Records the read outcome and wakes WaitForReadDone.
+    void OnReadDone(bool ok) override {
+      std::unique_lock<std::mutex> l(mu_);
+      read_ok_ = ok;
+      read_done_ = true;
+      read_cv_.notify_one();
+    }
+    // Records the final status and wakes Await.
+    void OnDone(const Status& s) override {
+      std::unique_lock<std::mutex> l(mu_);
+      done_ = true;
+      status_ = s;
+      done_cv_.notify_one();
+    }
+
+    EchoRequest request_;
+    EchoResponse response_;
+    ClientContext context_;
+    bool read_ok_ = false;
+    bool read_done_ = false;
+    std::mutex mu_;
+    std::condition_variable read_cv_;
+    std::condition_variable done_cv_;
+    bool done_ = false;
+    Status status_;
+  } client{stub_.get()};
+
+  int reads_complete = 0;
+  // Add the hold before StartCall so the RPC can't complete while the test
+  // thread is still issuing reads.
+  client.AddHold();
+  client.StartCall();
+
+  EchoResponse response;
+  bool read_ok = true;
+  while (read_ok) {
+    client.StartRead(&response);
+    read_ok = client.WaitForReadDone();
+    if (read_ok) {
+      ++reads_complete;
+    }
+  }
+  client.RemoveHoldUnderLock();
+  client.Await();
+
+  EXPECT_EQ(kServerDefaultResponseStreamsToSend, reads_complete);
+  EXPECT_EQ(client.status().error_code(), grpc::StatusCode::OK);
+}
+
+// Builds the cross product of protocol x credentials x server-type x
+// interceptor-use scenarios, honoring which credential types the provider
+// actually registers.
+std::vector<TestScenario> CreateTestScenarios(bool test_insecure) {
+#if TARGET_OS_IPHONE
+  // Workaround Apple CFStream bug
+  gpr_setenv("grpc_cfstream", "0");
+#endif
+
+  std::vector<TestScenario> scenarios;
+  std::vector<TString> credentials_types{
+      GetCredentialsProvider()->GetSecureCredentialsTypeList()};
+  // Insecure credentials participate only when the provider registers them.
+  auto insecure_allowed = [] {
+    return GetCredentialsProvider()->GetChannelCredentials(
+               kInsecureCredentialsType, nullptr) != nullptr;
+  };
+  if (test_insecure && insecure_allowed()) {
+    credentials_types.push_back(kInsecureCredentialsType);
+  }
+  GPR_ASSERT(!credentials_types.empty());
+
+  for (Protocol protocol : {Protocol::INPROC, Protocol::TCP}) {
+    for (const auto& cred : credentials_types) {
+      // TODO(vjpai): Test inproc with secure credentials when feasible
+      if (protocol == Protocol::INPROC &&
+          (cred != kInsecureCredentialsType || !insecure_allowed())) {
+        continue;
+      }
+      for (bool callback_server : {false, true}) {
+        for (bool use_interceptors : {false, true}) {
+          scenarios.emplace_back(callback_server, protocol, use_interceptors,
+                                 cred);
+        }
+      }
+    }
+  }
+  return scenarios;
+}
+
+INSTANTIATE_TEST_SUITE_P(ClientCallbackEnd2endTest, ClientCallbackEnd2endTest,
+ ::testing::ValuesIn(CreateTestScenarios(true)));
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  grpc::testing::TestEnvironment env(argc, argv);
+  grpc_init();
+  // Capture the result so grpc_shutdown always runs before exiting.
+  const int rc = RUN_ALL_TESTS();
+  grpc_shutdown();
+  return rc;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/client_crash_test.cc b/contrib/libs/grpc/test/cpp/end2end/client_crash_test.cc
new file mode 100644
index 0000000000..80e1869396
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/client_crash_test.cc
@@ -0,0 +1,147 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/subprocess.h"
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+static TString g_root;
+
+namespace grpc {
+namespace testing {
+
+namespace {
+
+class CrashTest : public ::testing::Test {
+ protected:
+  CrashTest() {}
+
+  // Spawns the crash-test server binary on an unused port and returns a stub
+  // connected to it over an insecure channel.
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> CreateServerAndStub() {
+    const int port = grpc_pick_unused_port_or_die();
+    std::ostringstream addr_stream;
+    addr_stream << "localhost:" << port;
+    const auto addr = addr_stream.str();
+    server_.reset(new SubProcess({
+        g_root + "/client_crash_test_server",
+        "--address=" + addr,
+    }));
+    GPR_ASSERT(server_);
+    return grpc::testing::EchoTestService::NewStub(
+        grpc::CreateChannel(addr, InsecureChannelCredentials()));
+  }
+
+  // Destroying the subprocess kills the server.
+  void KillServer() { server_.reset(); }
+
+ private:
+  std::unique_ptr<SubProcess> server_;
+};
+
+TEST_F(CrashTest, KillBeforeWrite) {
+  auto echo_stub = CreateServerAndStub();
+
+  EchoRequest req;
+  EchoResponse resp;
+  ClientContext ctx;
+  ctx.set_wait_for_ready(true);
+
+  auto bidi = echo_stub->BidiStream(&ctx);
+
+  // Prove the connection works before killing the server.
+  req.set_message("Hello");
+  EXPECT_TRUE(bidi->Write(req));
+  EXPECT_TRUE(bidi->Read(&resp));
+  EXPECT_EQ(resp.message(), req.message());
+
+  KillServer();
+
+  req.set_message("You should be dead");
+  // Depending on the TCP connection state this write may still succeed...
+  bidi->Write(req);
+  // ...but a subsequent read must fail.
+  EXPECT_FALSE(bidi->Read(&resp));
+
+  EXPECT_FALSE(bidi->Finish().ok());
+}
+
+TEST_F(CrashTest, KillAfterWrite) {
+  auto echo_stub = CreateServerAndStub();
+
+  EchoRequest req;
+  EchoResponse resp;
+  ClientContext ctx;
+  ctx.set_wait_for_ready(true);
+
+  auto bidi = echo_stub->BidiStream(&ctx);
+
+  // Prove the connection works before killing the server.
+  req.set_message("Hello");
+  EXPECT_TRUE(bidi->Write(req));
+  EXPECT_TRUE(bidi->Read(&resp));
+  EXPECT_EQ(resp.message(), req.message());
+
+  req.set_message("I'm going to kill you");
+  EXPECT_TRUE(bidi->Write(req));
+
+  KillServer();
+
+  // This may succeed or fail depending on how quick the server was
+  bidi->Read(&resp);
+
+  EXPECT_FALSE(bidi->Finish().ok());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+  // The server binary is expected next to this test binary; derive its
+  // directory from argv[0].
+  TString me = argv[0];
+  auto lslash = me.rfind('/');
+  g_root = (lslash != TString::npos) ? me.substr(0, lslash) : TString(".");
+
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  // Order seems to matter on these tests: run three times to eliminate that
+  for (int i = 0; i < 3; i++) {
+    if (RUN_ALL_TESTS() != 0) {
+      return 1;
+    }
+  }
+  return 0;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/client_crash_test_server.cc b/contrib/libs/grpc/test/cpp/end2end/client_crash_test_server.cc
new file mode 100644
index 0000000000..2d5be420f2
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/client_crash_test_server.cc
@@ -0,0 +1,80 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <gflags/gflags.h>
+#include <iostream>
+#include <memory>
+#include <util/generic/string.h>
+
+#include <grpc/support/log.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/cpp/util/test_config.h"
+
+DEFINE_string(address, "", "Address to bind to");
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+// In some distros, gflags is in the namespace google, and in some others,
+// in gflags. This hack is enabling us to find both.
+namespace google {}
+namespace gflags {}
+using namespace google;
+using namespace gflags;
+
+namespace grpc {
+namespace testing {
+
+class ServiceImpl final : public ::grpc::testing::EchoTestService::Service {
+  // Echoes every request message back on the stream until the client
+  // half-closes.
+  Status BidiStream(
+      ServerContext* /*context*/,
+      ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
+    EchoRequest req;
+    EchoResponse resp;
+    while (stream->Read(&req)) {
+      gpr_log(GPR_INFO, "recv msg %s", req.message().c_str());
+      resp.set_message(req.message());
+      stream->Write(resp);
+    }
+    return Status::OK;
+  }
+};
+
+// Binds the echo service to the address given by --address and blocks
+// serving requests until the process is killed by the test driver.
+void RunServer() {
+  ServiceImpl service;
+
+  ServerBuilder builder;
+  builder.AddListeningPort(FLAGS_address, grpc::InsecureServerCredentials());
+  builder.RegisterService(&service);
+  std::unique_ptr<Server> server(builder.BuildAndStart());
+  std::cout << "Server listening on " << FLAGS_address << std::endl;
+  server->Wait();
+}
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+  // Parses --address (and other test flags) before serving.
+  grpc::testing::InitTest(&argc, &argv, true);
+  grpc::testing::RunServer();
+
+  return 0;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/client_interceptors_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/client_interceptors_end2end_test.cc
new file mode 100644
index 0000000000..956876d9f6
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -0,0 +1,1194 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <vector>
+
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/client_interceptor.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/interceptors_util.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+namespace {
+
+// RPC flavors exercised by the interceptor tests: the four method types,
+// each via both the sync API and the CQ-based async API.
+enum class RPCType {
+  kSyncUnary,
+  kSyncClientStreaming,
+  kSyncServerStreaming,
+  kSyncBidiStreaming,
+  kAsyncCQUnary,
+  kAsyncCQClientStreaming,
+  kAsyncCQServerStreaming,
+  kAsyncCQBidiStreaming,
+};
+
+/* Hijacks Echo RPC and fills in the expected values */
+// Intercepts the unary Echo RPC, hijacks it at PRE_SEND_INITIAL_METADATA
+// (so it never reaches the wire), and fills in the response, trailing
+// metadata, and status the test expects.
+class HijackingInterceptor : public experimental::Interceptor {
+ public:
+  explicit HijackingInterceptor(experimental::ClientRpcInfo* info) {
+    info_ = info;
+    // Make sure it is the right method
+    EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
+    EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
+  }
+
+  void Intercept(experimental::InterceptorBatchMethods* methods) override {
+    bool hijack = false;
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+      auto* map = methods->GetSendInitialMetadata();
+      // Check that we can see the test metadata
+      ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+      auto iterator = map->begin();
+      EXPECT_EQ("testkey", iterator->first);
+      EXPECT_EQ("testvalue", iterator->second);
+      hijack = true;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+      EchoRequest req;
+      auto* buffer = methods->GetSerializedSendMessage();
+      auto copied_buffer = *buffer;
+      EXPECT_TRUE(
+          SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+              .ok());
+      EXPECT_EQ(req.message(), "Hello");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+      // Got nothing to do here for now
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+      auto* map = methods->GetRecvInitialMetadata();
+      // Got nothing better to do here for now
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      // Check that we got the hijacked message, and re-insert the expected
+      // message
+      EXPECT_EQ(resp->message(), "Hello1");
+      resp->set_message("Hello");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      bool found = false;
+      // Check that we received the metadata as an echo
+      for (const auto& pair : *map) {
+        found = pair.first.starts_with("testkey") &&
+                pair.second.starts_with("testvalue");
+        if (found) break;
+      }
+      EXPECT_EQ(found, true);
+      auto* status = methods->GetRecvStatus();
+      EXPECT_EQ(status->ok(), true);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
+      auto* map = methods->GetRecvInitialMetadata();
+      // Got nothing better to do here at the moment
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+      // Insert a different message than expected
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      resp->set_message("Hello1");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      // insert the metadata that we want
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+      map->insert(std::make_pair("testkey", "testvalue"));
+      auto* status = methods->GetRecvStatus();
+      *status = Status(StatusCode::OK, "");
+    }
+    // Hijack (terminate the interceptor chain) once we have seen the
+    // initial-metadata hook point; otherwise keep the batch moving.
+    if (hijack) {
+      methods->Hijack();
+    } else {
+      methods->Proceed();
+    }
+  }
+
+ private:
+  experimental::ClientRpcInfo* info_;
+};
+
+class HijackingInterceptorFactory
+    : public experimental::ClientInterceptorFactoryInterface {
+ public:
+  // Returns a new interceptor; ownership passes to the gRPC library.
+  experimental::Interceptor* CreateClientInterceptor(
+      experimental::ClientRpcInfo* info) override {
+    return new HijackingInterceptor(info);
+  }
+};
+
+// Hijacks the unary Echo RPC by issuing a second Echo RPC over the
+// intercepted channel from PRE_SEND_MESSAGE; the original call is completed
+// (Hijack) only when the nested call's callback fires, and its result is
+// spliced into the hijacked call's response/metadata/status.
+class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
+ public:
+  explicit HijackingInterceptorMakesAnotherCall(
+      experimental::ClientRpcInfo* info) {
+    info_ = info;
+    // Make sure it is the right method
+    EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
+  }
+
+  void Intercept(experimental::InterceptorBatchMethods* methods) override {
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+      auto* map = methods->GetSendInitialMetadata();
+      // Check that we can see the test metadata
+      ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+      auto iterator = map->begin();
+      EXPECT_EQ("testkey", iterator->first);
+      EXPECT_EQ("testvalue", iterator->second);
+      // Make a copy of the map
+      metadata_map_ = *map;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+      EchoRequest req;
+      auto* buffer = methods->GetSerializedSendMessage();
+      auto copied_buffer = *buffer;
+      EXPECT_TRUE(
+          SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+              .ok());
+      EXPECT_EQ(req.message(), "Hello");
+      req_ = req;
+      stub_ = grpc::testing::EchoTestService::NewStub(
+          methods->GetInterceptedChannel());
+      ctx_.AddMetadata(metadata_map_.begin()->first,
+                       metadata_map_.begin()->second);
+      stub_->experimental_async()->Echo(&ctx_, &req_, &resp_,
+                                        [this, methods](Status s) {
+                                          EXPECT_EQ(s.ok(), true);
+                                          EXPECT_EQ(resp_.message(), "Hello");
+                                          methods->Hijack();
+                                        });
+      // This is a Unary RPC and we have got nothing interesting to do in the
+      // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
+      // return here. (We do not want to call methods->Proceed(). When the new
+      // RPC returns, we will call methods->Hijack() instead.)
+      return;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+      // Got nothing to do here for now
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+      auto* map = methods->GetRecvInitialMetadata();
+      // Got nothing better to do here for now
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      // Check that we got the hijacked message, and re-insert the expected
+      // message
+      EXPECT_EQ(resp->message(), "Hello");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      bool found = false;
+      // Check that we received the metadata as an echo
+      for (const auto& pair : *map) {
+        found = pair.first.starts_with("testkey") &&
+                pair.second.starts_with("testvalue");
+        if (found) break;
+      }
+      EXPECT_EQ(found, true);
+      auto* status = methods->GetRecvStatus();
+      EXPECT_EQ(status->ok(), true);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
+      auto* map = methods->GetRecvInitialMetadata();
+      // Got nothing better to do here at the moment
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+      // Insert a different message than expected
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      resp->set_message(resp_.message());
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      // insert the metadata that we want
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+      map->insert(std::make_pair("testkey", "testvalue"));
+      auto* status = methods->GetRecvStatus();
+      *status = Status(StatusCode::OK, "");
+    }
+
+    methods->Proceed();
+  }
+
+ private:
+  experimental::ClientRpcInfo* info_;
+  std::multimap<TString, TString> metadata_map_;  // copy of sent metadata
+  ClientContext ctx_;                             // context of the nested RPC
+  EchoRequest req_;
+  EchoResponse resp_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+};
+
+class HijackingInterceptorMakesAnotherCallFactory
+    : public experimental::ClientInterceptorFactoryInterface {
+ public:
+  // Returns a new interceptor; ownership passes to the gRPC library.
+  experimental::Interceptor* CreateClientInterceptor(
+      experimental::ClientRpcInfo* info) override {
+    return new HijackingInterceptorMakesAnotherCall(info);
+  }
+};
+
+// Hijacks a bidi-streaming RPC: remembers the last sent message and replays
+// it as the received message, and fabricates the trailing metadata and OK
+// status the test expects.
+class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
+ public:
+  explicit BidiStreamingRpcHijackingInterceptor(
+      experimental::ClientRpcInfo* info) {
+    info_ = info;
+  }
+
+  void Intercept(experimental::InterceptorBatchMethods* methods) override {
+    bool hijack = false;
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+      CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
+      hijack = true;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+      EchoRequest req;
+      auto* buffer = methods->GetSerializedSendMessage();
+      auto copied_buffer = *buffer;
+      EXPECT_TRUE(
+          SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+              .ok());
+      EXPECT_EQ(req.message().find("Hello"), 0u);
+      // Remember the outgoing message so PRE_RECV_MESSAGE can echo it back.
+      msg_ = req.message();
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+      // Got nothing to do here for now
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+      CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
+                    "testvalue");
+      auto* status = methods->GetRecvStatus();
+      EXPECT_EQ(status->ok(), true);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      resp->set_message(msg_);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+      EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
+                    ->message()
+                    .find("Hello"),
+                0u);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      // insert the metadata that we want
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+      map->insert(std::make_pair("testkey", "testvalue"));
+      auto* status = methods->GetRecvStatus();
+      *status = Status(StatusCode::OK, "");
+    }
+    if (hijack) {
+      methods->Hijack();
+    } else {
+      methods->Proceed();
+    }
+  }
+
+ private:
+  experimental::ClientRpcInfo* info_;
+  TString msg_;  // last message sent by the client (trailing-underscore per
+                 // the file's member-naming convention)
+};
+
+// Hijacks a client-streaming RPC: lets ten messages through, fails every send
+// after that, and reports UNAVAILABLE as the final status of the hijacked
+// call.
+class ClientStreamingRpcHijackingInterceptor
+    : public experimental::Interceptor {
+ public:
+  explicit ClientStreamingRpcHijackingInterceptor(
+      experimental::ClientRpcInfo* info)
+      : info_(info) {}
+
+  void Intercept(experimental::InterceptorBatchMethods* methods) override {
+    bool hijack = false;
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+      hijack = true;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+      // Reject every message after the tenth.
+      if (++count_ > 10) {
+        methods->FailHijackedSendMessage();
+      }
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
+      // At most one send may fail; record it for GotFailedSend().
+      EXPECT_FALSE(got_failed_send_);
+      got_failed_send_ = !methods->GetSendMessageStatus();
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+      auto* status = methods->GetRecvStatus();
+      *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
+    }
+    if (hijack) {
+      methods->Hijack();
+    } else {
+      methods->Proceed();
+    }
+  }
+
+  // True once a hijacked send was observed to fail.
+  static bool GotFailedSend() { return got_failed_send_; }
+
+ private:
+  experimental::ClientRpcInfo* info_;
+  int count_ = 0;
+  // Static so the test can read it after the interceptor is destroyed.
+  static bool got_failed_send_;
+};
+
+bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
+
+// Factory producing a ClientStreamingRpcHijackingInterceptor per RPC.
+class ClientStreamingRpcHijackingInterceptorFactory
+    : public experimental::ClientInterceptorFactoryInterface {
+ public:
+  experimental::Interceptor* CreateClientInterceptor(
+      experimental::ClientRpcInfo* rpc_info) override {
+    return new ClientStreamingRpcHijackingInterceptor(rpc_info);
+  }
+};
+
+// Hijacks a server-streaming RPC: validates the request, fabricates ten
+// "Hello" responses, fails the eleventh read, and synthesizes an OK status
+// with the expected trailing metadata.
+class ServerStreamingRpcHijackingInterceptor
+    : public experimental::Interceptor {
+ public:
+  explicit ServerStreamingRpcHijackingInterceptor(
+      experimental::ClientRpcInfo* info)
+      : info_(info) {
+    // Static flag is shared across instances; reset it for the new RPC.
+    got_failed_message_ = false;
+  }
+
+  void Intercept(experimental::InterceptorBatchMethods* methods) override {
+    bool hijack = false;
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+      auto* map = methods->GetSendInitialMetadata();
+      // Check that we can see the test metadata
+      ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+      auto iterator = map->begin();
+      EXPECT_EQ("testkey", iterator->first);
+      EXPECT_EQ("testvalue", iterator->second);
+      // Take over the RPC; nothing goes out on the wire after this batch.
+      hijack = true;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+      // Deserialize a copy of the serialized request to validate it.
+      EchoRequest req;
+      auto* buffer = methods->GetSerializedSendMessage();
+      auto copied_buffer = *buffer;
+      EXPECT_TRUE(
+          SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+              .ok());
+      EXPECT_EQ(req.message(), "Hello");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+      // Got nothing to do here for now
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      bool found = false;
+      // Check that we received the metadata as an echo
+      for (const auto& pair : *map) {
+        found = pair.first.starts_with("testkey") &&
+                pair.second.starts_with("testvalue");
+        if (found) break;
+      }
+      EXPECT_EQ(found, true);
+      auto* status = methods->GetRecvStatus();
+      EXPECT_EQ(status->ok(), true);
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+      // Fabricate ten responses, then fail the eleventh read.
+      if (++count_ > 10) {
+        methods->FailHijackedRecvMessage();
+      }
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      resp->set_message("Hello");
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+      // Only the last message will be a failure
+      EXPECT_FALSE(got_failed_message_);
+      got_failed_message_ = methods->GetRecvMessage() == nullptr;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      // No server produced a status, so the map must be empty; fabricate the
+      // trailing metadata and an OK status for the hijacked call.
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+      map->insert(std::make_pair("testkey", "testvalue"));
+      auto* status = methods->GetRecvStatus();
+      *status = Status(StatusCode::OK, "");
+    }
+    if (hijack) {
+      methods->Hijack();
+    } else {
+      methods->Proceed();
+    }
+  }
+
+  // True once a hijacked read was observed to fail.
+  static bool GotFailedMessage() { return got_failed_message_; }
+
+ private:
+  experimental::ClientRpcInfo* info_;
+  // Static so the test can read it after the interceptor is destroyed.
+  static bool got_failed_message_;
+  int count_ = 0;
+};
+
+bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
+
+// Factory producing a ServerStreamingRpcHijackingInterceptor per RPC.
+class ServerStreamingRpcHijackingInterceptorFactory
+    : public experimental::ClientInterceptorFactoryInterface {
+ public:
+  experimental::Interceptor* CreateClientInterceptor(
+      experimental::ClientRpcInfo* rpc_info) override {
+    return new ServerStreamingRpcHijackingInterceptor(rpc_info);
+  }
+};
+
+// Factory producing a BidiStreamingRpcHijackingInterceptor per RPC.
+class BidiStreamingRpcHijackingInterceptorFactory
+    : public experimental::ClientInterceptorFactoryInterface {
+ public:
+  experimental::Interceptor* CreateClientInterceptor(
+      experimental::ClientRpcInfo* rpc_info) override {
+    return new BidiStreamingRpcHijackingInterceptor(rpc_info);
+  }
+};
+
+// The logging interceptor is for testing purposes only. It is used to verify
+// that all the appropriate hook points are invoked for an RPC. The counts are
+// reset each time a new object of LoggingInterceptor is created, so only a
+// single RPC should be made on the channel before calling the Verify methods.
+class LoggingInterceptor : public experimental::Interceptor {
+ public:
+  explicit LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
+    // Counters are static so the Verify* helpers can read them after the
+    // interceptor object is destroyed; reset them for the upcoming RPC.
+    pre_send_initial_metadata_ = false;
+    pre_send_message_count_ = 0;
+    pre_send_close_ = false;
+    post_recv_initial_metadata_ = false;
+    post_recv_message_count_ = 0;
+    post_recv_status_ = false;
+  }
+
+  void Intercept(experimental::InterceptorBatchMethods* methods) override {
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+      auto* map = methods->GetSendInitialMetadata();
+      // Check that we can see the test metadata
+      ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+      auto iterator = map->begin();
+      EXPECT_EQ("testkey", iterator->first);
+      EXPECT_EQ("testvalue", iterator->second);
+      ASSERT_FALSE(pre_send_initial_metadata_);
+      pre_send_initial_metadata_ = true;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+      EchoRequest req;
+      auto* send_msg = methods->GetSendMessage();
+      if (send_msg == nullptr) {
+        // We did not get the non-serialized form of the message. Get the
+        // serialized form. (Deserialize into the outer `req`; a shadowing
+        // local declaration here previously discarded this result.)
+        auto* buffer = methods->GetSerializedSendMessage();
+        auto copied_buffer = *buffer;
+        EXPECT_TRUE(
+            SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+                .ok());
+        EXPECT_EQ(req.message(), "Hello");
+      } else {
+        EXPECT_EQ(
+            static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
+            0u);
+      }
+      // Regardless of which branch ran, the serialized form must be
+      // available and must parse into the expected request.
+      auto* buffer = methods->GetSerializedSendMessage();
+      auto copied_buffer = *buffer;
+      EXPECT_TRUE(
+          SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+              .ok());
+      EXPECT_TRUE(req.message().find("Hello") == 0u);
+      pre_send_message_count_++;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+      // Got nothing to do here for now
+      pre_send_close_ = true;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+      auto* map = methods->GetRecvInitialMetadata();
+      // Got nothing better to do here for now
+      EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+      post_recv_initial_metadata_ = true;
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+      EchoResponse* resp =
+          static_cast<EchoResponse*>(methods->GetRecvMessage());
+      // resp is null when a hijacking interceptor failed the read.
+      if (resp != nullptr) {
+        EXPECT_TRUE(resp->message().find("Hello") == 0u);
+        post_recv_message_count_++;
+      }
+    }
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+      auto* map = methods->GetRecvTrailingMetadata();
+      bool found = false;
+      // Check that we received the metadata as an echo
+      for (const auto& pair : *map) {
+        found = pair.first.starts_with("testkey") &&
+                pair.second.starts_with("testvalue");
+        if (found) break;
+      }
+      EXPECT_EQ(found, true);
+      auto* status = methods->GetRecvStatus();
+      EXPECT_EQ(status->ok(), true);
+      post_recv_status_ = true;
+    }
+    methods->Proceed();
+  }
+
+  // Dispatches to the flavor-specific Verify* helper for the given RPC type.
+  static void VerifyCall(RPCType type) {
+    switch (type) {
+      case RPCType::kSyncUnary:
+      case RPCType::kAsyncCQUnary:
+        VerifyUnaryCall();
+        break;
+      case RPCType::kSyncClientStreaming:
+      case RPCType::kAsyncCQClientStreaming:
+        VerifyClientStreamingCall();
+        break;
+      case RPCType::kSyncServerStreaming:
+      case RPCType::kAsyncCQServerStreaming:
+        VerifyServerStreamingCall();
+        break;
+      case RPCType::kSyncBidiStreaming:
+      case RPCType::kAsyncCQBidiStreaming:
+        VerifyBidiStreamingCall();
+        break;
+    }
+  }
+
+  // Hook points every RPC flavor must have hit exactly once.
+  static void VerifyCallCommon() {
+    EXPECT_TRUE(pre_send_initial_metadata_);
+    EXPECT_TRUE(pre_send_close_);
+    EXPECT_TRUE(post_recv_initial_metadata_);
+    EXPECT_TRUE(post_recv_status_);
+  }
+
+  static void VerifyUnaryCall() {
+    VerifyCallCommon();
+    EXPECT_EQ(pre_send_message_count_, 1);
+    EXPECT_EQ(post_recv_message_count_, 1);
+  }
+
+  static void VerifyClientStreamingCall() {
+    VerifyCallCommon();
+    EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
+    EXPECT_EQ(post_recv_message_count_, 1);
+  }
+
+  static void VerifyServerStreamingCall() {
+    VerifyCallCommon();
+    EXPECT_EQ(pre_send_message_count_, 1);
+    EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
+  }
+
+  static void VerifyBidiStreamingCall() {
+    VerifyCallCommon();
+    EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
+    EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
+  }
+
+ private:
+  static bool pre_send_initial_metadata_;
+  static int pre_send_message_count_;
+  static bool pre_send_close_;
+  static bool post_recv_initial_metadata_;
+  static int post_recv_message_count_;
+  static bool post_recv_status_;
+};
+
+bool LoggingInterceptor::pre_send_initial_metadata_;
+int LoggingInterceptor::pre_send_message_count_;
+bool LoggingInterceptor::pre_send_close_;
+bool LoggingInterceptor::post_recv_initial_metadata_;
+int LoggingInterceptor::post_recv_message_count_;
+bool LoggingInterceptor::post_recv_status_;
+
+// Factory producing a LoggingInterceptor per RPC (also resets its counters).
+class LoggingInterceptorFactory
+    : public experimental::ClientInterceptorFactoryInterface {
+ public:
+  experimental::Interceptor* CreateClientInterceptor(
+      experimental::ClientRpcInfo* rpc_info) override {
+    return new LoggingInterceptor(rpc_info);
+  }
+};
+
+// Value type describing which RPC flavor a parameterized test run exercises.
+class TestScenario {
+ public:
+  // RPCType is a cheap enum, so take it by value rather than const reference.
+  explicit TestScenario(RPCType type) : type_(type) {}
+
+  RPCType type() const { return type_; }
+
+ private:
+  RPCType type_;
+};
+
+// Enumerates the RPC flavors covered by the parameterized interceptor tests.
+std::vector<TestScenario> CreateTestScenarios() {
+  std::vector<TestScenario> scenarios;
+  for (RPCType type :
+       {RPCType::kSyncUnary, RPCType::kSyncClientStreaming,
+        RPCType::kSyncServerStreaming, RPCType::kSyncBidiStreaming,
+        RPCType::kAsyncCQUnary, RPCType::kAsyncCQServerStreaming}) {
+    scenarios.emplace_back(type);
+  }
+  return scenarios;
+}
+
+// Fixture for the parameterized interceptor tests: starts an in-process
+// server with the streaming echo service on an unused port, and dispatches
+// one RPC of the flavor selected by the test parameter.
+class ParameterizedClientInterceptorsEnd2endTest
+    : public ::testing::TestWithParam<TestScenario> {
+ protected:
+  ParameterizedClientInterceptorsEnd2endTest() {
+    int port = grpc_pick_unused_port_or_die();
+
+    ServerBuilder builder;
+    server_address_ = "localhost:" + ToString(port);
+    builder.AddListeningPort(server_address_, InsecureServerCredentials());
+    builder.RegisterService(&service_);
+    server_ = builder.BuildAndStart();
+  }
+
+  ~ParameterizedClientInterceptorsEnd2endTest() { server_->Shutdown(); }
+
+  // Issues a single RPC of the parameterized flavor over `channel`.
+  void SendRPC(const std::shared_ptr<Channel>& channel) {
+    switch (GetParam().type()) {
+      case RPCType::kSyncUnary:
+        MakeCall(channel);
+        break;
+      case RPCType::kSyncClientStreaming:
+        MakeClientStreamingCall(channel);
+        break;
+      case RPCType::kSyncServerStreaming:
+        MakeServerStreamingCall(channel);
+        break;
+      case RPCType::kSyncBidiStreaming:
+        MakeBidiStreamingCall(channel);
+        break;
+      case RPCType::kAsyncCQUnary:
+        MakeAsyncCQCall(channel);
+        break;
+      case RPCType::kAsyncCQClientStreaming:
+        // TODO(yashykt) : Fill this out
+        break;
+      case RPCType::kAsyncCQServerStreaming:
+        MakeAsyncCQServerStreamingCall(channel);
+        break;
+      case RPCType::kAsyncCQBidiStreaming:
+        // TODO(yashykt) : Fill this out
+        break;
+    }
+  }
+
+  TString server_address_;
+  EchoTestServiceStreamingImpl service_;
+  std::unique_ptr<Server> server_;
+};
+
+TEST_P(ParameterizedClientInterceptorsEnd2endTest,
+       ClientInterceptorLoggingTest) {
+  ChannelArguments channel_args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      factories;
+  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
+      new LoggingInterceptorFactory()));
+  // Stack 20 dummy interceptors behind the logging interceptor.
+  for (int i = 0; i < 20; ++i) {
+    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), channel_args,
+      std::move(factories));
+  SendRPC(channel);
+  LoggingInterceptor::VerifyCall(GetParam().type());
+  // Every dummy interceptor should have run exactly once.
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+// Run each test in the fixture once per RPC flavor from CreateTestScenarios.
+INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
+                         ParameterizedClientInterceptorsEnd2endTest,
+                         ::testing::ValuesIn(CreateTestScenarios()));
+
+// Fixture for unary-RPC interceptor tests: starts an in-process server with
+// the plain (non-streaming) echo service on an unused port.
+class ClientInterceptorsEnd2endTest
+    : public ::testing::TestWithParam<TestScenario> {
+ protected:
+  ClientInterceptorsEnd2endTest() {
+    int port = grpc_pick_unused_port_or_die();
+
+    ServerBuilder builder;
+    server_address_ = "localhost:" + ToString(port);
+    builder.AddListeningPort(server_address_, InsecureServerCredentials());
+    builder.RegisterService(&service_);
+    server_ = builder.BuildAndStart();
+  }
+
+  ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); }
+
+  TString server_address_;
+  TestServiceImpl service_;
+  std::unique_ptr<Server> server_;
+};
+
+// The channel is created with null credentials, so it can never connect; the
+// hijacking interceptor must still be able to complete the RPC successfully
+// without the transport ever being used.
+TEST_F(ClientInterceptorsEnd2endTest,
+       LameChannelClientInterceptorHijackingTest) {
+  ChannelArguments args;
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
+      new HijackingInterceptorFactory()));
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, nullptr, args, std::move(creators));
+  MakeCall(channel);
+}
+
+// Interceptors registered after the hijacker never see the RPC, so only the
+// 20 dummy interceptors ahead of it should run.
+TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  // Add 20 dummy interceptors before hijacking interceptor
+  creators.reserve(20);
+  for (auto i = 0; i < 20; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
+      new HijackingInterceptorFactory()));
+  // Add 20 dummy interceptors after hijacking interceptor
+  for (auto i = 0; i < 20; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), args, std::move(creators));
+  MakeCall(channel);
+  // Make sure only 20 dummy interceptors were run
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
+  ChannelArguments channel_args;
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      factories;
+  // Register a logging interceptor ahead of a hijacking one: the logger must
+  // observe a complete-looking RPC even though the hijacker never contacts
+  // the server.
+  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
+      new LoggingInterceptorFactory()));
+  factories.push_back(std::unique_ptr<HijackingInterceptorFactory>(
+      new HijackingInterceptorFactory()));
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), channel_args,
+      std::move(factories));
+  MakeCall(channel);
+  LoggingInterceptor::VerifyUnaryCall();
+}
+
+// The hijacking interceptor issues a second RPC on the same channel, so the
+// 5 dummies before it run twice (once per RPC) and the 7 after it run for
+// neither the hijacked leg but once for the inner RPC's full chain: 12 total.
+TEST_F(ClientInterceptorsEnd2endTest,
+       ClientInterceptorHijackingMakesAnotherCallTest) {
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  // Add 5 dummy interceptors before hijacking interceptor
+  creators.reserve(5);
+  for (auto i = 0; i < 5; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  creators.push_back(
+      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>(
+          new HijackingInterceptorMakesAnotherCallFactory()));
+  // Add 7 dummy interceptors after hijacking interceptor
+  for (auto i = 0; i < 7; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = server_->experimental().InProcessChannelWithInterceptors(
+      args, std::move(creators));
+
+  MakeCall(channel);
+  // Make sure all interceptors were run once, since the hijacking interceptor
+  // makes an RPC on the intercepted channel
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12);
+}
+
+// Fixture for callback-API interceptor tests: starts an in-process server
+// with the plain echo service; tests use InProcessChannelWithInterceptors.
+class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
+ protected:
+  ClientInterceptorsCallbackEnd2endTest() {
+    int port = grpc_pick_unused_port_or_die();
+
+    ServerBuilder builder;
+    server_address_ = "localhost:" + ToString(port);
+    builder.AddListeningPort(server_address_, InsecureServerCredentials());
+    builder.RegisterService(&service_);
+    server_ = builder.BuildAndStart();
+  }
+
+  ~ClientInterceptorsCallbackEnd2endTest() { server_->Shutdown(); }
+
+  TString server_address_;
+  TestServiceImpl service_;
+  std::unique_ptr<Server> server_;
+};
+
+// Same logging/dummy stack as the sync test, but driven through the callback
+// (reactor) API over an in-process channel.
+TEST_F(ClientInterceptorsCallbackEnd2endTest,
+       ClientInterceptorLoggingTestWithCallback) {
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
+      new LoggingInterceptorFactory()));
+  // Add 20 dummy interceptors
+  for (auto i = 0; i < 20; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = server_->experimental().InProcessChannelWithInterceptors(
+      args, std::move(creators));
+  MakeCallbackCall(channel);
+  LoggingInterceptor::VerifyUnaryCall();
+  // Make sure all 20 dummy interceptors were run
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+// A factory may return nullptr to decline intercepting an RPC; interleaving
+// 20 null factories with 20 real dummies must not disturb the chain.
+TEST_F(ClientInterceptorsCallbackEnd2endTest,
+       ClientInterceptorFactoryAllowsNullptrReturn) {
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
+      new LoggingInterceptorFactory()));
+  // Add 20 dummy interceptors and 20 null interceptors
+  for (auto i = 0; i < 20; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+    creators.push_back(
+        std::unique_ptr<NullInterceptorFactory>(new NullInterceptorFactory()));
+  }
+  auto channel = server_->experimental().InProcessChannelWithInterceptors(
+      args, std::move(creators));
+  MakeCallbackCall(channel);
+  LoggingInterceptor::VerifyUnaryCall();
+  // Make sure all 20 dummy interceptors were run
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+// Fixture for streaming interceptor tests: starts an in-process server with
+// the streaming echo service on an unused port.
+class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
+ protected:
+  ClientInterceptorsStreamingEnd2endTest() {
+    int port = grpc_pick_unused_port_or_die();
+
+    ServerBuilder builder;
+    server_address_ = "localhost:" + ToString(port);
+    builder.AddListeningPort(server_address_, InsecureServerCredentials());
+    builder.RegisterService(&service_);
+    server_ = builder.BuildAndStart();
+  }
+
+  ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); }
+
+  TString server_address_;
+  EchoTestServiceStreamingImpl service_;
+  std::unique_ptr<Server> server_;
+};
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
+  ChannelArguments channel_args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      factories;
+  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
+      new LoggingInterceptorFactory()));
+  // Stack 20 dummy interceptors behind the logging interceptor.
+  for (int i = 0; i < 20; ++i) {
+    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), channel_args,
+      std::move(factories));
+  MakeClientStreamingCall(channel);
+  LoggingInterceptor::VerifyClientStreamingCall();
+  // Every dummy interceptor should have run exactly once.
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
+  ChannelArguments channel_args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      factories;
+  factories.push_back(std::unique_ptr<LoggingInterceptorFactory>(
+      new LoggingInterceptorFactory()));
+  // Stack 20 dummy interceptors behind the logging interceptor.
+  for (int i = 0; i < 20; ++i) {
+    factories.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), channel_args,
+      std::move(factories));
+  MakeServerStreamingCall(channel);
+  LoggingInterceptor::VerifyServerStreamingCall();
+  // Every dummy interceptor should have run exactly once.
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+// The hijacking interceptor accepts exactly 10 writes, fails the 11th, and
+// finishes the call with UNAVAILABLE; verify the client observes the failure.
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
+  ChannelArguments args;
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  creators.push_back(
+      std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
+          new ClientStreamingRpcHijackingInterceptorFactory()));
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), args, std::move(creators));
+
+  auto stub = grpc::testing::EchoTestService::NewStub(channel);
+  ClientContext ctx;
+  EchoRequest req;
+  EchoResponse resp;
+  req.mutable_param()->set_echo_metadata(true);
+  req.set_message("Hello");
+  string expected_resp = "";
+  auto writer = stub->RequestStream(&ctx, &resp);
+  for (int i = 0; i < 10; i++) {
+    EXPECT_TRUE(writer->Write(req));
+    expected_resp += "Hello";
+  }
+  // The interceptor will reject the 11th message
+  writer->Write(req);
+  Status s = writer->Finish();
+  EXPECT_EQ(s.ok(), false);
+  EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
+}
+
+// The hijacking interceptor fabricates 10 responses and fails the 11th read;
+// the interceptor records that failure, which we verify here.
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  creators.push_back(
+      std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
+          new ServerStreamingRpcHijackingInterceptorFactory()));
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), args, std::move(creators));
+  MakeServerStreamingCall(channel);
+  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
+}
+
+// Same as ServerStreamingHijackingTest, but drives the RPC through the async
+// completion-queue API instead of the sync API.
+TEST_F(ClientInterceptorsStreamingEnd2endTest,
+       AsyncCQServerStreamingHijackingTest) {
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  creators.push_back(
+      std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
+          new ServerStreamingRpcHijackingInterceptorFactory()));
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), args, std::move(creators));
+  MakeAsyncCQServerStreamingCall(channel);
+  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
+  ChannelArguments channel_args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      factories;
+  factories.push_back(
+      std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
+          new BidiStreamingRpcHijackingInterceptorFactory()));
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), channel_args,
+      std::move(factories));
+  // The hijacking interceptor performs all the assertions for this test.
+  MakeBidiStreamingCall(channel);
+}
+
+// Bidi-streaming RPC through a logging interceptor plus 20 dummies; verifies
+// the logger saw the expected per-hook-point counts for a bidi call.
+TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
+      new LoggingInterceptorFactory()));
+  // Add 20 dummy interceptors
+  for (auto i = 0; i < 20; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), args, std::move(creators));
+  MakeBidiStreamingCall(channel);
+  LoggingInterceptor::VerifyBidiStreamingCall();
+  // Make sure all 20 dummy interceptors were run
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+// Fixture for global-interceptor tests: starts an in-process server with the
+// plain echo service; each test registers (and later resets) a process-wide
+// client interceptor factory.
+class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
+ protected:
+  ClientGlobalInterceptorEnd2endTest() {
+    int port = grpc_pick_unused_port_or_die();
+
+    ServerBuilder builder;
+    server_address_ = "localhost:" + ToString(port);
+    builder.AddListeningPort(server_address_, InsecureServerCredentials());
+    builder.RegisterService(&service_);
+    server_ = builder.BuildAndStart();
+  }
+
+  ~ClientGlobalInterceptorEnd2endTest() { server_->Shutdown(); }
+
+  TString server_address_;
+  TestServiceImpl service_;
+  std::unique_ptr<Server> server_;
+};
+
+// A globally registered dummy interceptor runs in addition to the 20 per-
+// channel dummies, so the run count is 21.
+TEST_F(ClientGlobalInterceptorEnd2endTest, DummyGlobalInterceptor) {
+  // We should ideally be registering a global interceptor only once per
+  // process, but for the purposes of testing, it should be fine to modify the
+  // registered global interceptor when there are no ongoing gRPC operations
+  DummyInterceptorFactory global_factory;
+  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  // Add 20 dummy interceptors
+  creators.reserve(20);
+  for (auto i = 0; i < 20; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), args, std::move(creators));
+  MakeCall(channel);
+  // Make sure all 20 dummy interceptors were run with the global interceptor
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 21);
+  experimental::TestOnlyResetGlobalClientInterceptorFactory();
+}
+
+// The logging interceptor is registered globally instead of per-channel; its
+// Verify helper must still see a complete unary RPC.
+TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
+  // We should ideally be registering a global interceptor only once per
+  // process, but for the purposes of testing, it should be fine to modify the
+  // registered global interceptor when there are no ongoing gRPC operations
+  LoggingInterceptorFactory global_factory;
+  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  // Add 20 dummy interceptors
+  creators.reserve(20);
+  for (auto i = 0; i < 20; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), args, std::move(creators));
+  MakeCall(channel);
+  LoggingInterceptor::VerifyUnaryCall();
+  // Make sure all 20 dummy interceptors were run
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+  experimental::TestOnlyResetGlobalClientInterceptorFactory();
+}
+
+// The hijacker is registered globally, so it runs after the 20 per-channel
+// dummies; all 20 still execute before the RPC is hijacked.
+TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
+  // We should ideally be registering a global interceptor only once per
+  // process, but for the purposes of testing, it should be fine to modify the
+  // registered global interceptor when there are no ongoing gRPC operations
+  HijackingInterceptorFactory global_factory;
+  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
+  ChannelArguments args;
+  DummyInterceptor::Reset();
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      creators;
+  // Add 20 dummy interceptors
+  creators.reserve(20);
+  for (auto i = 0; i < 20; i++) {
+    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+        new DummyInterceptorFactory()));
+  }
+  auto channel = experimental::CreateCustomChannelWithInterceptors(
+      server_address_, InsecureChannelCredentials(), args, std::move(creators));
+  MakeCall(channel);
+  // Make sure all 20 dummy interceptors were run
+  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+  experimental::TestOnlyResetGlobalClientInterceptorFactory();
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+  // Bring up the gRPC test environment before handing control to googletest.
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  const int rc = RUN_ALL_TESTS();
+  return rc;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/client_lb_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/client_lb_end2end_test.cc
new file mode 100644
index 0000000000..fd08dd163d
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/client_lb_end2end_test.cc
@@ -0,0 +1,1990 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <algorithm>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <set>
+#include <util/generic/string.h>
+#include <thread>
+
+#include "y_absl/strings/str_cat.h"
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/atm.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/health_check_service_interface.h>
+#include <grpcpp/impl/codegen/sync.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+
+#include "src/core/ext/filters/client_channel/backup_poller.h"
+#include "src/core/ext/filters/client_channel/global_subchannel_pool.h"
+#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h"
+#include "src/core/ext/filters/client_channel/server_address.h"
+#include "src/core/ext/filters/client_channel/service_config.h"
+#include "src/core/lib/backoff/backoff.h"
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/gprpp/debug_location.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
+#include "src/core/lib/iomgr/parse_address.h"
+#include "src/core/lib/iomgr/tcp_client.h"
+#include "src/core/lib/security/credentials/fake/fake_credentials.h"
+#include "src/cpp/client/secure_credentials.h"
+#include "src/cpp/server/secure_server_credentials.h"
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/orca_load_report_for_test.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/core/util/test_lb_policies.h"
+#include "test/cpp/end2end/test_service_impl.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+// defined in tcp_client.cc
+extern grpc_tcp_client_vtable* grpc_tcp_client_impl;
+
+static grpc_tcp_client_vtable* default_client_impl;
+
+namespace grpc {
+namespace testing {
+namespace {
+
+gpr_atm g_connection_delay_ms;
+
+void tcp_client_connect_with_delay(grpc_closure* closure, grpc_endpoint** ep,
+ grpc_pollset_set* interested_parties,
+ const grpc_channel_args* channel_args,
+ const grpc_resolved_address* addr,
+ grpc_millis deadline) {
+ const int delay_ms = gpr_atm_acq_load(&g_connection_delay_ms);
+ if (delay_ms > 0) {
+ gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms));
+ }
+ default_client_impl->connect(closure, ep, interested_parties, channel_args,
+ addr, deadline + delay_ms);
+}
+
+grpc_tcp_client_vtable delayed_connect = {tcp_client_connect_with_delay};
+
+// Subclass of TestServiceImpl that increments a request counter for
+// every call to the Echo RPC.
+class MyTestServiceImpl : public TestServiceImpl {
+ public:
+ Status Echo(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ const udpa::data::orca::v1::OrcaLoadReport* load_report = nullptr;
+ {
+ grpc::internal::MutexLock lock(&mu_);
+ ++request_count_;
+ load_report = load_report_;
+ }
+ AddClient(context->peer().c_str());
+ if (load_report != nullptr) {
+ // TODO(roth): Once we provide a more standard server-side API for
+ // populating this data, use that API here.
+ context->AddTrailingMetadata("x-endpoint-load-metrics-bin",
+ load_report->SerializeAsString());
+ }
+ return TestServiceImpl::Echo(context, request, response);
+ }
+
+ int request_count() {
+ grpc::internal::MutexLock lock(&mu_);
+ return request_count_;
+ }
+
+ void ResetCounters() {
+ grpc::internal::MutexLock lock(&mu_);
+ request_count_ = 0;
+ }
+
+ std::set<TString> clients() {
+ grpc::internal::MutexLock lock(&clients_mu_);
+ return clients_;
+ }
+
+ void set_load_report(udpa::data::orca::v1::OrcaLoadReport* load_report) {
+ grpc::internal::MutexLock lock(&mu_);
+ load_report_ = load_report;
+ }
+
+ private:
+ void AddClient(const TString& client) {
+ grpc::internal::MutexLock lock(&clients_mu_);
+ clients_.insert(client);
+ }
+
+ grpc::internal::Mutex mu_;
+ int request_count_ = 0;
+ const udpa::data::orca::v1::OrcaLoadReport* load_report_ = nullptr;
+ grpc::internal::Mutex clients_mu_;
+ std::set<TString> clients_;
+};
+
+class FakeResolverResponseGeneratorWrapper {
+ public:
+ FakeResolverResponseGeneratorWrapper()
+ : response_generator_(grpc_core::MakeRefCounted<
+ grpc_core::FakeResolverResponseGenerator>()) {}
+
+ FakeResolverResponseGeneratorWrapper(
+ FakeResolverResponseGeneratorWrapper&& other) noexcept {
+ response_generator_ = std::move(other.response_generator_);
+ }
+
+ void SetNextResolution(
+ const std::vector<int>& ports, const char* service_config_json = nullptr,
+ const char* attribute_key = nullptr,
+ std::unique_ptr<grpc_core::ServerAddress::AttributeInterface> attribute =
+ nullptr) {
+ grpc_core::ExecCtx exec_ctx;
+ response_generator_->SetResponse(BuildFakeResults(
+ ports, service_config_json, attribute_key, std::move(attribute)));
+ }
+
+ void SetNextResolutionUponError(const std::vector<int>& ports) {
+ grpc_core::ExecCtx exec_ctx;
+ response_generator_->SetReresolutionResponse(BuildFakeResults(ports));
+ }
+
+ void SetFailureOnReresolution() {
+ grpc_core::ExecCtx exec_ctx;
+ response_generator_->SetFailureOnReresolution();
+ }
+
+ grpc_core::FakeResolverResponseGenerator* Get() const {
+ return response_generator_.get();
+ }
+
+ private:
+ static grpc_core::Resolver::Result BuildFakeResults(
+ const std::vector<int>& ports, const char* service_config_json = nullptr,
+ const char* attribute_key = nullptr,
+ std::unique_ptr<grpc_core::ServerAddress::AttributeInterface> attribute =
+ nullptr) {
+ grpc_core::Resolver::Result result;
+ for (const int& port : ports) {
+ TString lb_uri_str = y_absl::StrCat("ipv4:127.0.0.1:", port);
+ grpc_uri* lb_uri = grpc_uri_parse(lb_uri_str.c_str(), true);
+ GPR_ASSERT(lb_uri != nullptr);
+ grpc_resolved_address address;
+ GPR_ASSERT(grpc_parse_uri(lb_uri, &address));
+ std::map<const char*,
+ std::unique_ptr<grpc_core::ServerAddress::AttributeInterface>>
+ attributes;
+ if (attribute != nullptr) {
+ attributes[attribute_key] = attribute->Copy();
+ }
+ result.addresses.emplace_back(address.addr, address.len,
+ nullptr /* args */, std::move(attributes));
+ grpc_uri_destroy(lb_uri);
+ }
+ if (service_config_json != nullptr) {
+ result.service_config = grpc_core::ServiceConfig::Create(
+ nullptr, service_config_json, &result.service_config_error);
+ GPR_ASSERT(result.service_config != nullptr);
+ }
+ return result;
+ }
+
+ grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator>
+ response_generator_;
+};
+
+class ClientLbEnd2endTest : public ::testing::Test {
+ protected:
+ ClientLbEnd2endTest()
+ : server_host_("localhost"),
+ kRequestMessage_("Live long and prosper."),
+ creds_(new SecureChannelCredentials(
+ grpc_fake_transport_security_credentials_create())) {}
+
+ static void SetUpTestCase() {
+ // Make the backup poller poll very frequently in order to pick up
+ // updates from all the subchannels' FDs.
+ GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1);
+#if TARGET_OS_IPHONE
+ // Workaround Apple CFStream bug
+ gpr_setenv("grpc_cfstream", "0");
+#endif
+ }
+
+ void SetUp() override { grpc_init(); }
+
+ void TearDown() override {
+ for (size_t i = 0; i < servers_.size(); ++i) {
+ servers_[i]->Shutdown();
+ }
+ servers_.clear();
+ creds_.reset();
+ grpc_shutdown_blocking();
+ }
+
+ void CreateServers(size_t num_servers,
+ std::vector<int> ports = std::vector<int>()) {
+ servers_.clear();
+ for (size_t i = 0; i < num_servers; ++i) {
+ int port = 0;
+ if (ports.size() == num_servers) port = ports[i];
+ servers_.emplace_back(new ServerData(port));
+ }
+ }
+
+ void StartServer(size_t index) { servers_[index]->Start(server_host_); }
+
+ void StartServers(size_t num_servers,
+ std::vector<int> ports = std::vector<int>()) {
+ CreateServers(num_servers, std::move(ports));
+ for (size_t i = 0; i < num_servers; ++i) {
+ StartServer(i);
+ }
+ }
+
+ std::vector<int> GetServersPorts(size_t start_index = 0) {
+ std::vector<int> ports;
+ for (size_t i = start_index; i < servers_.size(); ++i) {
+ ports.push_back(servers_[i]->port_);
+ }
+ return ports;
+ }
+
+ FakeResolverResponseGeneratorWrapper BuildResolverResponseGenerator() {
+ return FakeResolverResponseGeneratorWrapper();
+ }
+
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> BuildStub(
+ const std::shared_ptr<Channel>& channel) {
+ return grpc::testing::EchoTestService::NewStub(channel);
+ }
+
+ std::shared_ptr<Channel> BuildChannel(
+ const TString& lb_policy_name,
+ const FakeResolverResponseGeneratorWrapper& response_generator,
+ ChannelArguments args = ChannelArguments()) {
+ if (lb_policy_name.size() > 0) {
+ args.SetLoadBalancingPolicyName(lb_policy_name);
+ } // else, default to pick first
+ args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR,
+ response_generator.Get());
+ return ::grpc::CreateCustomChannel("fake:///", creds_, args);
+ }
+
+ bool SendRpc(
+ const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+ EchoResponse* response = nullptr, int timeout_ms = 1000,
+ Status* result = nullptr, bool wait_for_ready = false) {
+ const bool local_response = (response == nullptr);
+ if (local_response) response = new EchoResponse;
+ EchoRequest request;
+ request.set_message(kRequestMessage_);
+ request.mutable_param()->set_echo_metadata(true);
+ ClientContext context;
+ context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms));
+ if (wait_for_ready) context.set_wait_for_ready(true);
+ context.AddMetadata("foo", "1");
+ context.AddMetadata("bar", "2");
+ context.AddMetadata("baz", "3");
+ Status status = stub->Echo(&context, request, response);
+ if (result != nullptr) *result = status;
+ if (local_response) delete response;
+ return status.ok();
+ }
+
+ void CheckRpcSendOk(
+ const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+ const grpc_core::DebugLocation& location, bool wait_for_ready = false) {
+ EchoResponse response;
+ Status status;
+ const bool success =
+ SendRpc(stub, &response, 2000, &status, wait_for_ready);
+ ASSERT_TRUE(success) << "From " << location.file() << ":" << location.line()
+ << "\n"
+ << "Error: " << status.error_message() << " "
+ << status.error_details();
+ ASSERT_EQ(response.message(), kRequestMessage_)
+ << "From " << location.file() << ":" << location.line();
+ if (!success) abort();
+ }
+
+ void CheckRpcSendFailure(
+ const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub) {
+ const bool success = SendRpc(stub);
+ EXPECT_FALSE(success);
+ }
+
+ struct ServerData {
+ int port_;
+ std::unique_ptr<Server> server_;
+ MyTestServiceImpl service_;
+ std::unique_ptr<std::thread> thread_;
+ bool server_ready_ = false;
+ bool started_ = false;
+
+ explicit ServerData(int port = 0) {
+ port_ = port > 0 ? port : 5100; // grpc_pick_unused_port_or_die();
+ }
+
+ void Start(const TString& server_host) {
+ gpr_log(GPR_INFO, "starting server on port %d", port_);
+ started_ = true;
+ grpc::internal::Mutex mu;
+ grpc::internal::MutexLock lock(&mu);
+ grpc::internal::CondVar cond;
+ thread_.reset(new std::thread(
+ std::bind(&ServerData::Serve, this, server_host, &mu, &cond)));
+ cond.WaitUntil(&mu, [this] { return server_ready_; });
+ server_ready_ = false;
+ gpr_log(GPR_INFO, "server startup complete");
+ }
+
+ void Serve(const TString& server_host, grpc::internal::Mutex* mu,
+ grpc::internal::CondVar* cond) {
+ std::ostringstream server_address;
+ server_address << server_host << ":" << port_;
+ ServerBuilder builder;
+ std::shared_ptr<ServerCredentials> creds(new SecureServerCredentials(
+ grpc_fake_transport_security_server_credentials_create()));
+ builder.AddListeningPort(server_address.str(), std::move(creds));
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ grpc::internal::MutexLock lock(mu);
+ server_ready_ = true;
+ cond->Signal();
+ }
+
+ void Shutdown() {
+ if (!started_) return;
+ server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
+ thread_->join();
+ started_ = false;
+ }
+
+ void SetServingStatus(const TString& service, bool serving) {
+ server_->GetHealthCheckService()->SetServingStatus(service, serving);
+ }
+ };
+
+ void ResetCounters() {
+ for (const auto& server : servers_) server->service_.ResetCounters();
+ }
+
+ void WaitForServer(
+ const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+ size_t server_idx, const grpc_core::DebugLocation& location,
+ bool ignore_failure = false) {
+ do {
+ if (ignore_failure) {
+ SendRpc(stub);
+ } else {
+ CheckRpcSendOk(stub, location, true);
+ }
+ } while (servers_[server_idx]->service_.request_count() == 0);
+ ResetCounters();
+ }
+
+ bool WaitForChannelState(
+ Channel* channel, std::function<bool(grpc_connectivity_state)> predicate,
+ bool try_to_connect = false, int timeout_seconds = 5) {
+ const gpr_timespec deadline =
+ grpc_timeout_seconds_to_deadline(timeout_seconds);
+ while (true) {
+ grpc_connectivity_state state = channel->GetState(try_to_connect);
+ if (predicate(state)) break;
+ if (!channel->WaitForStateChange(state, deadline)) return false;
+ }
+ return true;
+ }
+
+ bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) {
+ auto predicate = [](grpc_connectivity_state state) {
+ return state != GRPC_CHANNEL_READY;
+ };
+ return WaitForChannelState(channel, predicate, false, timeout_seconds);
+ }
+
+ bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) {
+ auto predicate = [](grpc_connectivity_state state) {
+ return state == GRPC_CHANNEL_READY;
+ };
+ return WaitForChannelState(channel, predicate, true, timeout_seconds);
+ }
+
+ bool SeenAllServers() {
+ for (const auto& server : servers_) {
+ if (server->service_.request_count() == 0) return false;
+ }
+ return true;
+ }
+
+ // Updates \a connection_order by appending to it the index of the newly
+ // connected server. Must be called after every single RPC.
+ void UpdateConnectionOrder(
+ const std::vector<std::unique_ptr<ServerData>>& servers,
+ std::vector<int>* connection_order) {
+ for (size_t i = 0; i < servers.size(); ++i) {
+ if (servers[i]->service_.request_count() == 1) {
+ // Was the server index known? If not, update connection_order.
+ const auto it =
+ std::find(connection_order->begin(), connection_order->end(), i);
+ if (it == connection_order->end()) {
+ connection_order->push_back(i);
+ return;
+ }
+ }
+ }
+ }
+
+ const TString server_host_;
+ std::vector<std::unique_ptr<ServerData>> servers_;
+ const TString kRequestMessage_;
+ std::shared_ptr<ChannelCredentials> creds_;
+};
+
+TEST_F(ClientLbEnd2endTest, ChannelStateConnectingWhenResolving) {
+ const int kNumServers = 3;
+ StartServers(kNumServers);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("", response_generator);
+ auto stub = BuildStub(channel);
+ // Initial state should be IDLE.
+ EXPECT_EQ(channel->GetState(false /* try_to_connect */), GRPC_CHANNEL_IDLE);
+ // Tell the channel to try to connect.
+ // Note that this call also returns IDLE, since the state change has
+ // not yet occurred; it just gets triggered by this call.
+ EXPECT_EQ(channel->GetState(true /* try_to_connect */), GRPC_CHANNEL_IDLE);
+ // Now that the channel is trying to connect, we should be in state
+ // CONNECTING.
+ EXPECT_EQ(channel->GetState(false /* try_to_connect */),
+ GRPC_CHANNEL_CONNECTING);
+ // Return a resolver result, which allows the connection attempt to proceed.
+ response_generator.SetNextResolution(GetServersPorts());
+ // We should eventually transition into state READY.
+ EXPECT_TRUE(WaitForChannelReady(channel.get()));
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirst) {
+ // Start servers and send one RPC per server.
+ const int kNumServers = 3;
+ StartServers(kNumServers);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel(
+ "", response_generator); // test that pick first is the default.
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(GetServersPorts());
+ for (size_t i = 0; i < servers_.size(); ++i) {
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ }
+ // All requests should have gone to a single server.
+ bool found = false;
+ for (size_t i = 0; i < servers_.size(); ++i) {
+ const int request_count = servers_[i]->service_.request_count();
+ if (request_count == kNumServers) {
+ found = true;
+ } else {
+ EXPECT_EQ(0, request_count);
+ }
+ }
+ EXPECT_TRUE(found);
+ // Check LB policy name for the channel.
+ EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName());
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstProcessPending) {
+ StartServers(1); // Single server
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel(
+ "", response_generator); // test that pick first is the default.
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution({servers_[0]->port_});
+ WaitForServer(stub, 0, DEBUG_LOCATION);
+ // Create a new channel and its corresponding PF LB policy, which will pick
+ // the subchannels in READY state from the previous RPC against the same
+ // target (even if it happened over a different channel, because subchannels
+ // are globally reused). Progress should happen without any transition from
+ // this READY state.
+ auto second_response_generator = BuildResolverResponseGenerator();
+ auto second_channel = BuildChannel("", second_response_generator);
+ auto second_stub = BuildStub(second_channel);
+ second_response_generator.SetNextResolution({servers_[0]->port_});
+ CheckRpcSendOk(second_stub, DEBUG_LOCATION);
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstSelectsReadyAtStartup) {
+ ChannelArguments args;
+ constexpr int kInitialBackOffMs = 5000;
+ args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs);
+ // Create 2 servers, but start only the second one.
+ std::vector<int> ports = { 5101, // grpc_pick_unused_port_or_die(),
+ 5102}; // grpc_pick_unused_port_or_die()};
+ CreateServers(2, ports);
+ StartServer(1);
+ auto response_generator1 = BuildResolverResponseGenerator();
+ auto channel1 = BuildChannel("pick_first", response_generator1, args);
+ auto stub1 = BuildStub(channel1);
+ response_generator1.SetNextResolution(ports);
+ // Wait for second server to be ready.
+ WaitForServer(stub1, 1, DEBUG_LOCATION);
+ // Create a second channel with the same addresses. Its PF instance
+ // should immediately pick the second subchannel, since it's already
+ // in READY state.
+ auto response_generator2 = BuildResolverResponseGenerator();
+ auto channel2 = BuildChannel("pick_first", response_generator2, args);
+ response_generator2.SetNextResolution(ports);
+ // Check that the channel reports READY without waiting for the
+ // initial backoff.
+ EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1 /* timeout_seconds */));
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstBackOffInitialReconnect) {
+ ChannelArguments args;
+ constexpr int kInitialBackOffMs = 100;
+ args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs);
+ const std::vector<int> ports = {5103}; // {grpc_pick_unused_port_or_die()};
+ const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("pick_first", response_generator, args);
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(ports);
+ // The channel won't become connected (there's no server).
+ ASSERT_FALSE(channel->WaitForConnected(
+ grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 2)));
+ // Bring up a server on the chosen port.
+ StartServers(1, ports);
+ // Now it will.
+ ASSERT_TRUE(channel->WaitForConnected(
+ grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 2)));
+ const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC);
+ const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0));
+ gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms);
+ // We should have waited at least kInitialBackOffMs. We subtract one to
+ // account for test and precision accuracy drift.
+ EXPECT_GE(waited_ms, kInitialBackOffMs - 1);
+ // But not much more.
+ EXPECT_GT(
+ gpr_time_cmp(
+ grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 1.10), t1),
+ 0);
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstBackOffMinReconnect) {
+ ChannelArguments args;
+ constexpr int kMinReconnectBackOffMs = 1000;
+ args.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, kMinReconnectBackOffMs);
+ const std::vector<int> ports = {5104}; // {grpc_pick_unused_port_or_die()};
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("pick_first", response_generator, args);
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(ports);
+ // Make connection delay 10% longer than it's willing to in order to make
+ // sure we are hitting the codepath that waits for the min reconnect backoff.
+ gpr_atm_rel_store(&g_connection_delay_ms, kMinReconnectBackOffMs * 1.10);
+ default_client_impl = grpc_tcp_client_impl;
+ grpc_set_tcp_client_impl(&delayed_connect);
+ const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC);
+ channel->WaitForConnected(
+ grpc_timeout_milliseconds_to_deadline(kMinReconnectBackOffMs * 2));
+ const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC);
+ const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0));
+ gpr_log(GPR_DEBUG, "Waited %" PRId64 " ms", waited_ms);
+ // We should have waited at least kMinReconnectBackOffMs. We subtract one to
+ // account for test and precision accuracy drift.
+ EXPECT_GE(waited_ms, kMinReconnectBackOffMs - 1);
+ gpr_atm_rel_store(&g_connection_delay_ms, 0);
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstResetConnectionBackoff) {
+ ChannelArguments args;
+ constexpr int kInitialBackOffMs = 1000;
+ args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs);
+ const std::vector<int> ports = {5105}; // {grpc_pick_unused_port_or_die()};
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("pick_first", response_generator, args);
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(ports);
+ // The channel won't become connected (there's no server).
+ EXPECT_FALSE(
+ channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10)));
+ // Bring up a server on the chosen port.
+ StartServers(1, ports);
+ const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC);
+ // Wait for connect, but not long enough. This proves that we're
+ // being throttled by initial backoff.
+ EXPECT_FALSE(
+ channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10)));
+ // Reset connection backoff.
+ experimental::ChannelResetConnectionBackoff(channel.get());
+ // Wait for connect. Should happen as soon as the client connects to
+ // the newly started server, which should be before the initial
+ // backoff timeout elapses.
+ EXPECT_TRUE(
+ channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(20)));
+ const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC);
+ const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0));
+ gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms);
+ // We should have waited less than kInitialBackOffMs.
+ EXPECT_LT(waited_ms, kInitialBackOffMs);
+}
+
+TEST_F(ClientLbEnd2endTest,
+ PickFirstResetConnectionBackoffNextAttemptStartsImmediately) {
+ ChannelArguments args;
+ constexpr int kInitialBackOffMs = 1000;
+ args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs);
+ const std::vector<int> ports = {5106}; // {grpc_pick_unused_port_or_die()};
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("pick_first", response_generator, args);
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(ports);
+ // Wait for connect, which should fail ~immediately, because the server
+ // is not up.
+ gpr_log(GPR_INFO, "=== INITIAL CONNECTION ATTEMPT");
+ EXPECT_FALSE(
+ channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10)));
+ // Reset connection backoff.
+ // Note that the time at which the third attempt will be started is
+ // actually computed at this point, so we record the start time here.
+ gpr_log(GPR_INFO, "=== RESETTING BACKOFF");
+ const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC);
+ experimental::ChannelResetConnectionBackoff(channel.get());
+ // Trigger a second connection attempt. This should also fail
+ // ~immediately, but the retry should be scheduled for
+ // kInitialBackOffMs instead of applying the multiplier.
+ gpr_log(GPR_INFO, "=== POLLING FOR SECOND CONNECTION ATTEMPT");
+ EXPECT_FALSE(
+ channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10)));
+ // Bring up a server on the chosen port.
+ gpr_log(GPR_INFO, "=== STARTING BACKEND");
+ StartServers(1, ports);
+ // Wait for connect. Should happen within kInitialBackOffMs.
+ // Give an extra 100ms to account for the time spent in the second and
+ // third connection attempts themselves (since what we really want to
+ // measure is the time between the two). As long as this is less than
+ // the 1.6x increase we would see if the backoff state was not reset
+ // properly, the test is still proving that the backoff was reset.
+ constexpr int kWaitMs = kInitialBackOffMs + 100;
+ gpr_log(GPR_INFO, "=== POLLING FOR THIRD CONNECTION ATTEMPT");
+ EXPECT_TRUE(channel->WaitForConnected(
+ grpc_timeout_milliseconds_to_deadline(kWaitMs)));
+ const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC);
+ const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0));
+ gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms);
+ EXPECT_LT(waited_ms, kWaitMs);
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstUpdates) {
+ // Start servers and send one RPC per server.
+ const int kNumServers = 3;
+ StartServers(kNumServers);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("pick_first", response_generator);
+ auto stub = BuildStub(channel);
+
+ std::vector<int> ports;
+
+ // Perform one RPC against the first server.
+ ports.emplace_back(servers_[0]->port_);
+ response_generator.SetNextResolution(ports);
+ gpr_log(GPR_INFO, "****** SET [0] *******");
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ EXPECT_EQ(servers_[0]->service_.request_count(), 1);
+
+ // An empty update will result in the channel going into TRANSIENT_FAILURE.
+ ports.clear();
+ response_generator.SetNextResolution(ports);
+ gpr_log(GPR_INFO, "****** SET none *******");
+ grpc_connectivity_state channel_state;
+ do {
+ channel_state = channel->GetState(true /* try to connect */);
+ } while (channel_state == GRPC_CHANNEL_READY);
+ ASSERT_NE(channel_state, GRPC_CHANNEL_READY);
+ servers_[0]->service_.ResetCounters();
+
+ // Next update introduces servers_[1], making the channel recover.
+ ports.clear();
+ ports.emplace_back(servers_[1]->port_);
+ response_generator.SetNextResolution(ports);
+ gpr_log(GPR_INFO, "****** SET [1] *******");
+ WaitForServer(stub, 1, DEBUG_LOCATION);
+ EXPECT_EQ(servers_[0]->service_.request_count(), 0);
+
+ // And again for servers_[2]
+ ports.clear();
+ ports.emplace_back(servers_[2]->port_);
+ response_generator.SetNextResolution(ports);
+ gpr_log(GPR_INFO, "****** SET [2] *******");
+ WaitForServer(stub, 2, DEBUG_LOCATION);
+ EXPECT_EQ(servers_[0]->service_.request_count(), 0);
+ EXPECT_EQ(servers_[1]->service_.request_count(), 0);
+
+ // Check LB policy name for the channel.
+ EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName());
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstUpdateSuperset) {
+ // Start servers and send one RPC per server.
+ const int kNumServers = 3;
+ StartServers(kNumServers);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("pick_first", response_generator);
+ auto stub = BuildStub(channel);
+
+ std::vector<int> ports;
+
+ // Perform one RPC against the first server.
+ ports.emplace_back(servers_[0]->port_);
+ response_generator.SetNextResolution(ports);
+ gpr_log(GPR_INFO, "****** SET [0] *******");
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ EXPECT_EQ(servers_[0]->service_.request_count(), 1);
+ servers_[0]->service_.ResetCounters();
+
+ // Send a superset update
+ ports.clear();
+ ports.emplace_back(servers_[1]->port_);
+ ports.emplace_back(servers_[0]->port_);
+ response_generator.SetNextResolution(ports);
+ gpr_log(GPR_INFO, "****** SET superset *******");
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ // We stick to the previously connected server.
+ WaitForServer(stub, 0, DEBUG_LOCATION);
+ EXPECT_EQ(0, servers_[1]->service_.request_count());
+
+ // Check LB policy name for the channel.
+ EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName());
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstGlobalSubchannelPool) {
+ // Start one server.
+ const int kNumServers = 1;
+ StartServers(kNumServers);
+ std::vector<int> ports = GetServersPorts();
+ // Create two channels that (by default) use the global subchannel pool.
+ auto response_generator1 = BuildResolverResponseGenerator();
+ auto channel1 = BuildChannel("pick_first", response_generator1);
+ auto stub1 = BuildStub(channel1);
+ response_generator1.SetNextResolution(ports);
+ auto response_generator2 = BuildResolverResponseGenerator();
+ auto channel2 = BuildChannel("pick_first", response_generator2);
+ auto stub2 = BuildStub(channel2);
+ response_generator2.SetNextResolution(ports);
+ WaitForServer(stub1, 0, DEBUG_LOCATION);
+ // Send one RPC on each channel.
+ CheckRpcSendOk(stub1, DEBUG_LOCATION);
+ CheckRpcSendOk(stub2, DEBUG_LOCATION);
+ // The server receives two requests.
+ EXPECT_EQ(2, servers_[0]->service_.request_count());
+ // The two requests are from the same client port, because the two channels
+ // share subchannels via the global subchannel pool.
+ EXPECT_EQ(1UL, servers_[0]->service_.clients().size());
+}
+
+TEST_F(ClientLbEnd2endTest, PickFirstLocalSubchannelPool) {
+ // Start one server.
+ const int kNumServers = 1;
+ StartServers(kNumServers);
+ std::vector<int> ports = GetServersPorts();
+ // Create two channels that use local subchannel pool.
+ ChannelArguments args;
+ args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1);
+ auto response_generator1 = BuildResolverResponseGenerator();
+ auto channel1 = BuildChannel("pick_first", response_generator1, args);
+ auto stub1 = BuildStub(channel1);
+ response_generator1.SetNextResolution(ports);
+ auto response_generator2 = BuildResolverResponseGenerator();
+ auto channel2 = BuildChannel("pick_first", response_generator2, args);
+ auto stub2 = BuildStub(channel2);
+ response_generator2.SetNextResolution(ports);
+ WaitForServer(stub1, 0, DEBUG_LOCATION);
+ // Send one RPC on each channel.
+ CheckRpcSendOk(stub1, DEBUG_LOCATION);
+ CheckRpcSendOk(stub2, DEBUG_LOCATION);
+ // The server receives two requests.
+ EXPECT_EQ(2, servers_[0]->service_.request_count());
+ // The two requests are from two client ports, because the two channels didn't
+ // share subchannels with each other.
+ EXPECT_EQ(2UL, servers_[0]->service_.clients().size());
+}
+
+// Stress test: push 1000 resolver updates (randomly shuffled address
+// orderings) at a pick_first channel, periodically sending an RPC so the
+// channel actually processes pending resolutions; the channel must survive
+// and keep reporting the pick_first policy.
+TEST_F(ClientLbEnd2endTest, PickFirstManyUpdates) {
+  const int kNumUpdates = 1000;
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("pick_first", response_generator);
+  auto stub = BuildStub(channel);
+  std::vector<int> ports = GetServersPorts();
+  for (size_t i = 0; i < kNumUpdates; ++i) {
+    // A fresh, randomly-seeded RNG per iteration; determinism is not needed.
+    std::shuffle(ports.begin(), ports.end(),
+                 std::mt19937(std::random_device()()));
+    response_generator.SetNextResolution(ports);
+    // We should re-enter core at the end of the loop to give the resolution
+    // setting closure a chance to run.
+    if ((i + 1) % 10 == 0) CheckRpcSendOk(stub, DEBUG_LOCATION);
+  }
+  // Check LB policy name for the channel.
+  EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName());
+}
+
+// Verifies that pick_first requests re-resolution when it never manages to
+// select a subchannel: the initial resolution contains only dead ports (so
+// RPCs fail), then the re-resolution result contains a live server and the
+// channel recovers.
+TEST_F(ClientLbEnd2endTest, PickFirstReresolutionNoSelected) {
+  // Prepare the ports for up servers and down servers.
+  const int kNumServers = 3;
+  const int kNumAliveServers = 1;
+  StartServers(kNumAliveServers);
+  std::vector<int> alive_ports, dead_ports;
+  for (size_t i = 0; i < kNumServers; ++i) {
+    if (i < kNumAliveServers) {
+      alive_ports.emplace_back(servers_[i]->port_);
+    } else {
+      // NOTE(review): hard-coded ports replace grpc_pick_unused_port_or_die()
+      // (see commented line below) — presumably patched for this vendored
+      // build; confirm these ports are free in the test environment.
+      dead_ports.emplace_back(5107 + i);
+      // dead_ports.emplace_back(grpc_pick_unused_port_or_die());
+    }
+  }
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("pick_first", response_generator);
+  auto stub = BuildStub(channel);
+  // The initial resolution only contains dead ports. There won't be any
+  // selected subchannel. Re-resolution will return the same result.
+  response_generator.SetNextResolution(dead_ports);
+  gpr_log(GPR_INFO, "****** INITIAL RESOLUTION SET *******");
+  for (size_t i = 0; i < 10; ++i) CheckRpcSendFailure(stub);
+  // Set a re-resolution result that contains reachable ports, so that the
+  // pick_first LB policy can recover soon.
+  response_generator.SetNextResolutionUponError(alive_ports);
+  gpr_log(GPR_INFO, "****** RE-RESOLUTION SET *******");
+  WaitForServer(stub, 0, DEBUG_LOCATION, true /* ignore_failure */);
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_EQ(servers_[0]->service_.request_count(), 1);
+  // Check LB policy name for the channel.
+  EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName());
+}
+
+// Verifies that pick_first reconnects to the same address after the server
+// restarts, without requiring a new resolver result in between.
+TEST_F(ClientLbEnd2endTest, PickFirstReconnectWithoutNewResolverResult) {
+  std::vector<int> ports = {5110};  // {grpc_pick_unused_port_or_die()};
+  StartServers(1, ports);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("pick_first", response_generator);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(ports);
+  gpr_log(GPR_INFO, "****** INITIAL CONNECTION *******");
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+  gpr_log(GPR_INFO, "****** STOPPING SERVER ******");
+  servers_[0]->Shutdown();
+  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
+  gpr_log(GPR_INFO, "****** RESTARTING SERVER ******");
+  // Restart on the same port; the channel should find it again.
+  StartServers(1, ports);
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+}
+
+// After the selected server dies and no new resolver result arrives,
+// pick_first must restart its connection attempts from the top of the address
+// list (server 0), not from the previously-selected entry.
+TEST_F(ClientLbEnd2endTest,
+       PickFirstReconnectWithoutNewResolverResultStartsFromTopOfList) {
+  std::vector<int> ports = {5111,   // grpc_pick_unused_port_or_die(),
+                            5112};  // grpc_pick_unused_port_or_die()};
+  CreateServers(2, ports);
+  StartServer(1);  // Only the second address is initially reachable.
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("pick_first", response_generator);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(ports);
+  gpr_log(GPR_INFO, "****** INITIAL CONNECTION *******");
+  WaitForServer(stub, 1, DEBUG_LOCATION);
+  gpr_log(GPR_INFO, "****** STOPPING SERVER ******");
+  servers_[1]->Shutdown();
+  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
+  gpr_log(GPR_INFO, "****** STARTING BOTH SERVERS ******");
+  StartServers(2, ports);
+  // Expect reconnection to index 0 — the head of the address list.
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+}
+
+// Regression test: when two channels end up holding refs to the same (failed)
+// subchannel, the LB policy must check the subchannel's current state before
+// starting a connectivity watch, or the first RPC after a server restart can
+// fail.
+TEST_F(ClientLbEnd2endTest, PickFirstCheckStateBeforeStartWatch) {
+  std::vector<int> ports = {5113};  // {grpc_pick_unused_port_or_die()};
+  StartServers(1, ports);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel_1 = BuildChannel("pick_first", response_generator);
+  auto stub_1 = BuildStub(channel_1);
+  response_generator.SetNextResolution(ports);
+  gpr_log(GPR_INFO, "****** RESOLUTION SET FOR CHANNEL 1 *******");
+  WaitForServer(stub_1, 0, DEBUG_LOCATION);
+  gpr_log(GPR_INFO, "****** CHANNEL 1 CONNECTED *******");
+  servers_[0]->Shutdown();
+  // Channel 1 will receive a re-resolution containing the same server. It will
+  // create a new subchannel and hold a ref to it.
+  StartServers(1, ports);
+  gpr_log(GPR_INFO, "****** SERVER RESTARTED *******");
+  auto response_generator_2 = BuildResolverResponseGenerator();
+  auto channel_2 = BuildChannel("pick_first", response_generator_2);
+  auto stub_2 = BuildStub(channel_2);
+  response_generator_2.SetNextResolution(ports);
+  gpr_log(GPR_INFO, "****** RESOLUTION SET FOR CHANNEL 2 *******");
+  WaitForServer(stub_2, 0, DEBUG_LOCATION, true);
+  gpr_log(GPR_INFO, "****** CHANNEL 2 CONNECTED *******");
+  servers_[0]->Shutdown();
+  // Wait until the disconnection has triggered the connectivity notification.
+  // Otherwise, the subchannel may be picked for next call but will fail soon.
+  EXPECT_TRUE(WaitForChannelNotReady(channel_2.get()));
+  // Channel 2 will also receive a re-resolution containing the same server.
+  // Both channels will ref the same subchannel that failed.
+  StartServers(1, ports);
+  gpr_log(GPR_INFO, "****** SERVER RESTARTED AGAIN *******");
+  gpr_log(GPR_INFO, "****** CHANNEL 2 STARTING A CALL *******");
+  // The first call after the server restart will succeed.
+  CheckRpcSendOk(stub_2, DEBUG_LOCATION);
+  gpr_log(GPR_INFO, "****** CHANNEL 2 FINISHED A CALL *******");
+  // Check LB policy name for the channel.
+  EXPECT_EQ("pick_first", channel_1->GetLoadBalancingPolicyName());
+  // Check LB policy name for the channel.
+  EXPECT_EQ("pick_first", channel_2->GetLoadBalancingPolicyName());
+}
+
+// Verifies that a pick_first channel transitions to IDLE (not
+// TRANSIENT_FAILURE) when its selected server disconnects and re-resolution
+// fails.
+TEST_F(ClientLbEnd2endTest, PickFirstIdleOnDisconnect) {
+  // Start server, send RPC, and make sure channel is READY.
+  const int kNumServers = 1;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel =
+      BuildChannel("", response_generator);  // pick_first is the default.
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
+  // Stop server. Channel should go into state IDLE.
+  response_generator.SetFailureOnReresolution();
+  servers_[0]->Shutdown();
+  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
+  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE);
+  servers_.clear();
+}
+
+// Verifies that when the selected subchannel fails while a resolver update is
+// still pending (none of its addresses connected yet), pick_first immediately
+// swaps to the pending list and eventually connects once one of its servers
+// comes up.
+TEST_F(ClientLbEnd2endTest, PickFirstPendingUpdateAndSelectedSubchannelFails) {
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel =
+      BuildChannel("", response_generator);  // pick_first is the default.
+  auto stub = BuildStub(channel);
+  // Create a number of servers, but only start 1 of them.
+  CreateServers(10);
+  StartServer(0);
+  // Initially resolve to first server and make sure it connects.
+  gpr_log(GPR_INFO, "Phase 1: Connect to first server.");
+  response_generator.SetNextResolution({servers_[0]->port_});
+  CheckRpcSendOk(stub, DEBUG_LOCATION, true /* wait_for_ready */);
+  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
+  // Send a resolution update with the remaining servers, none of which are
+  // running yet, so the update will stay pending. Note that it's important
+  // to have multiple servers here, or else the test will be flaky; with only
+  // one server, the pending subchannel list has already gone into
+  // TRANSIENT_FAILURE due to hitting the end of the list by the time we
+  // check the state.
+  gpr_log(GPR_INFO,
+          "Phase 2: Resolver update pointing to remaining "
+          "(not started) servers.");
+  response_generator.SetNextResolution(GetServersPorts(1 /* start_index */));
+  // RPCs will continue to be sent to the first server.
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  // Now stop the first server, so that the current subchannel list
+  // fails. This should cause us to immediately swap over to the
+  // pending list, even though it's not yet connected. The state should
+  // be set to CONNECTING, since that's what the pending subchannel list
+  // was doing when we swapped over.
+  gpr_log(GPR_INFO, "Phase 3: Stopping first server.");
+  servers_[0]->Shutdown();
+  WaitForChannelNotReady(channel.get());
+  // TODO(roth): This should always return CONNECTING, but it's flaky
+  // between that and TRANSIENT_FAILURE. I suspect that this problem
+  // will go away once we move the backoff code out of the subchannel
+  // and into the LB policies.
+  EXPECT_THAT(channel->GetState(false),
+              ::testing::AnyOf(GRPC_CHANNEL_CONNECTING,
+                               GRPC_CHANNEL_TRANSIENT_FAILURE));
+  // Now start the second server.
+  gpr_log(GPR_INFO, "Phase 4: Starting second server.");
+  StartServer(1);
+  // The channel should go to READY state and RPCs should go to the
+  // second server.
+  WaitForChannelReady(channel.get());
+  WaitForServer(stub, 1, DEBUG_LOCATION, true /* ignore_failure */);
+}
+
+// Verifies that an IDLE pick_first channel stays IDLE when the resolver
+// returns an empty address list, and recovers to READY once a non-empty
+// update arrives and an RPC is sent.
+TEST_F(ClientLbEnd2endTest, PickFirstStaysIdleUponEmptyUpdate) {
+  // Start server, send RPC, and make sure channel is READY.
+  const int kNumServers = 1;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel =
+      BuildChannel("", response_generator);  // pick_first is the default.
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
+  // Stop server. Channel should go into state IDLE.
+  servers_[0]->Shutdown();
+  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
+  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE);
+  // Now send resolver update that includes no addresses. Channel
+  // should stay in state IDLE.
+  response_generator.SetNextResolution({});
+  // No state change is expected within the 3-second window.
+  EXPECT_FALSE(channel->WaitForStateChange(
+      GRPC_CHANNEL_IDLE, grpc_timeout_seconds_to_deadline(3)));
+  // Now bring the backend back up and send a non-empty resolver update,
+  // and then try to send an RPC. Channel should go back into state READY.
+  StartServer(0);
+  response_generator.SetNextResolution(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
+}
+
+// Basic round_robin behavior: once all backends are connected, consecutive
+// RPCs visit the backends in the order their addresses were resolved.
+TEST_F(ClientLbEnd2endTest, RoundRobin) {
+  // Start servers and send one RPC per server.
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(GetServersPorts());
+  // Wait until all backends are ready.
+  do {
+    CheckRpcSendOk(stub, DEBUG_LOCATION);
+  } while (!SeenAllServers());
+  ResetCounters();
+  // "Sync" to the end of the list. Next sequence of picks will start at the
+  // first server (index 0).
+  WaitForServer(stub, servers_.size() - 1, DEBUG_LOCATION);
+  std::vector<int> connection_order;
+  for (size_t i = 0; i < servers_.size(); ++i) {
+    CheckRpcSendOk(stub, DEBUG_LOCATION);
+    UpdateConnectionOrder(servers_, &connection_order);
+  }
+  // Backends should be iterated over in the order in which the addresses were
+  // given.
+  const auto expected = std::vector<int>{0, 1, 2};
+  EXPECT_EQ(expected, connection_order);
+  // Check LB policy name for the channel.
+  EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName());
+}
+
+// Verifies that a round_robin channel makes progress when it picks up an
+// already-READY subchannel from the global pool (created by a prior channel
+// to the same target), without waiting for a fresh state transition.
+TEST_F(ClientLbEnd2endTest, RoundRobinProcessPending) {
+  StartServers(1);  // Single server
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution({servers_[0]->port_});
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+  // Create a new channel and its corresponding RR LB policy, which will pick
+  // the subchannels in READY state from the previous RPC against the same
+  // target (even if it happened over a different channel, because subchannels
+  // are globally reused). Progress should happen without any transition from
+  // this READY state.
+  auto second_response_generator = BuildResolverResponseGenerator();
+  auto second_channel = BuildChannel("round_robin", second_response_generator);
+  auto second_stub = BuildStub(second_channel);
+  second_response_generator.SetNextResolution({servers_[0]->port_});
+  CheckRpcSendOk(second_stub, DEBUG_LOCATION);
+}
+
+// Walks a round_robin channel through a sequence of resolver updates:
+// single-backend updates (traffic follows the update), an all-backends update
+// (traffic is spread evenly), an empty update (channel leaves READY), and a
+// recovery update.
+TEST_F(ClientLbEnd2endTest, RoundRobinUpdates) {
+  // Start servers and send one RPC per server.
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  std::vector<int> ports;
+  // Start with a single server.
+  gpr_log(GPR_INFO, "*** FIRST BACKEND ***");
+  ports.emplace_back(servers_[0]->port_);
+  response_generator.SetNextResolution(ports);
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+  // Send RPCs. They should all go to servers_[0]
+  for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_EQ(10, servers_[0]->service_.request_count());
+  EXPECT_EQ(0, servers_[1]->service_.request_count());
+  EXPECT_EQ(0, servers_[2]->service_.request_count());
+  servers_[0]->service_.ResetCounters();
+  // And now for the second server.
+  gpr_log(GPR_INFO, "*** SECOND BACKEND ***");
+  ports.clear();
+  ports.emplace_back(servers_[1]->port_);
+  response_generator.SetNextResolution(ports);
+  // Wait until update has been processed, as signaled by the second backend
+  // receiving a request.
+  EXPECT_EQ(0, servers_[1]->service_.request_count());
+  WaitForServer(stub, 1, DEBUG_LOCATION);
+  for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_EQ(0, servers_[0]->service_.request_count());
+  EXPECT_EQ(10, servers_[1]->service_.request_count());
+  EXPECT_EQ(0, servers_[2]->service_.request_count());
+  servers_[1]->service_.ResetCounters();
+  // ... and for the last server.
+  gpr_log(GPR_INFO, "*** THIRD BACKEND ***");
+  ports.clear();
+  ports.emplace_back(servers_[2]->port_);
+  response_generator.SetNextResolution(ports);
+  WaitForServer(stub, 2, DEBUG_LOCATION);
+  for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_EQ(0, servers_[0]->service_.request_count());
+  EXPECT_EQ(0, servers_[1]->service_.request_count());
+  EXPECT_EQ(10, servers_[2]->service_.request_count());
+  servers_[2]->service_.ResetCounters();
+  // Back to all servers.
+  gpr_log(GPR_INFO, "*** ALL BACKENDS ***");
+  ports.clear();
+  ports.emplace_back(servers_[0]->port_);
+  ports.emplace_back(servers_[1]->port_);
+  ports.emplace_back(servers_[2]->port_);
+  response_generator.SetNextResolution(ports);
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+  WaitForServer(stub, 1, DEBUG_LOCATION);
+  WaitForServer(stub, 2, DEBUG_LOCATION);
+  // Send three RPCs, one per server.
+  for (size_t i = 0; i < 3; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_EQ(1, servers_[0]->service_.request_count());
+  EXPECT_EQ(1, servers_[1]->service_.request_count());
+  EXPECT_EQ(1, servers_[2]->service_.request_count());
+  // An empty update will result in the channel going into TRANSIENT_FAILURE.
+  gpr_log(GPR_INFO, "*** NO BACKENDS ***");
+  ports.clear();
+  response_generator.SetNextResolution(ports);
+  grpc_connectivity_state channel_state;
+  // Poll (with try-to-connect) until the channel leaves READY.
+  do {
+    channel_state = channel->GetState(true /* try to connect */);
+  } while (channel_state == GRPC_CHANNEL_READY);
+  ASSERT_NE(channel_state, GRPC_CHANNEL_READY);
+  servers_[0]->service_.ResetCounters();
+  // Next update introduces servers_[1], making the channel recover.
+  gpr_log(GPR_INFO, "*** BACK TO SECOND BACKEND ***");
+  ports.clear();
+  ports.emplace_back(servers_[1]->port_);
+  response_generator.SetNextResolution(ports);
+  WaitForServer(stub, 1, DEBUG_LOCATION);
+  channel_state = channel->GetState(false /* try to connect */);
+  ASSERT_EQ(channel_state, GRPC_CHANNEL_READY);
+  // Check LB policy name for the channel.
+  EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName());
+}
+
+// Verifies that when a resolver update includes a backend that is already
+// down, round_robin keeps serving from the reachable backends and sends no
+// traffic to the dead one.
+TEST_F(ClientLbEnd2endTest, RoundRobinUpdateInError) {
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  std::vector<int> ports;
+  // Start with a single server.
+  ports.emplace_back(servers_[0]->port_);
+  response_generator.SetNextResolution(ports);
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+  // Send RPCs. They should all go to servers_[0]
+  for (size_t i = 0; i < 10; ++i) SendRpc(stub);
+  EXPECT_EQ(10, servers_[0]->service_.request_count());
+  EXPECT_EQ(0, servers_[1]->service_.request_count());
+  EXPECT_EQ(0, servers_[2]->service_.request_count());
+  servers_[0]->service_.ResetCounters();
+  // Shutdown one of the servers to be sent in the update.
+  servers_[1]->Shutdown();
+  ports.emplace_back(servers_[1]->port_);
+  ports.emplace_back(servers_[2]->port_);
+  response_generator.SetNextResolution(ports);
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+  WaitForServer(stub, 2, DEBUG_LOCATION);
+  // Send three RPCs, one per server.
+  for (size_t i = 0; i < kNumServers; ++i) SendRpc(stub);
+  // The server in shutdown shouldn't receive any.
+  EXPECT_EQ(0, servers_[1]->service_.request_count());
+}
+
+// Stress test: push 1000 shuffled resolver updates at a round_robin channel,
+// periodically sending an RPC; the channel must survive and keep reporting
+// the round_robin policy.
+TEST_F(ClientLbEnd2endTest, RoundRobinManyUpdates) {
+  // Start servers and send one RPC per server.
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  std::vector<int> ports = GetServersPorts();
+  for (size_t i = 0; i < 1000; ++i) {
+    // Fresh, randomly-seeded RNG per iteration; determinism not needed.
+    std::shuffle(ports.begin(), ports.end(),
+                 std::mt19937(std::random_device()()));
+    response_generator.SetNextResolution(ports);
+    if (i % 10 == 0) CheckRpcSendOk(stub, DEBUG_LOCATION);
+  }
+  // Check LB policy name for the channel.
+  EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName());
+}
+
+// Intentionally-empty placeholder; see TODO below.
+TEST_F(ClientLbEnd2endTest, RoundRobinConcurrentUpdates) {
+  // TODO(dgq): replicate the way internal testing exercises the concurrent
+  // update provisions of RR.
+}
+
+// Verifies that round_robin triggers re-resolution when all of its
+// subchannels fail: servers are restarted on *different* ports, so the only
+// way the client can recover is via an actual re-resolution (not subchannel
+// reconnection).
+TEST_F(ClientLbEnd2endTest, RoundRobinReresolve) {
+  // Start servers and send one RPC per server.
+  const int kNumServers = 3;
+  std::vector<int> first_ports;
+  std::vector<int> second_ports;
+  first_ports.reserve(kNumServers);
+  for (int i = 0; i < kNumServers; ++i) {
+    // NOTE(review): hard-coded ports replace grpc_pick_unused_port_or_die()
+    // (see commented lines) — confirm 5114-5119 are free in the test env.
+    // first_ports.push_back(grpc_pick_unused_port_or_die());
+    first_ports.push_back(5114 + i);
+  }
+  second_ports.reserve(kNumServers);
+  for (int i = 0; i < kNumServers; ++i) {
+    // second_ports.push_back(grpc_pick_unused_port_or_die());
+    second_ports.push_back(5117 + i);
+  }
+  StartServers(kNumServers, first_ports);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(first_ports);
+  // Send a number of RPCs, which succeed.
+  for (size_t i = 0; i < 100; ++i) {
+    CheckRpcSendOk(stub, DEBUG_LOCATION);
+  }
+  // Kill all servers
+  gpr_log(GPR_INFO, "****** ABOUT TO KILL SERVERS *******");
+  for (size_t i = 0; i < servers_.size(); ++i) {
+    servers_[i]->Shutdown();
+  }
+  gpr_log(GPR_INFO, "****** SERVERS KILLED *******");
+  gpr_log(GPR_INFO, "****** SENDING DOOMED REQUESTS *******");
+  // Client requests should fail. Send enough to tickle all subchannels.
+  for (size_t i = 0; i < servers_.size(); ++i) CheckRpcSendFailure(stub);
+  gpr_log(GPR_INFO, "****** DOOMED REQUESTS SENT *******");
+  // Bring servers back up on a different set of ports. We need to do this to be
+  // sure that the eventual success is *not* due to subchannel reconnection
+  // attempts and that an actual re-resolution has happened as a result of the
+  // RR policy going into transient failure when all its subchannels become
+  // unavailable (in transient failure as well).
+  gpr_log(GPR_INFO, "****** RESTARTING SERVERS *******");
+  StartServers(kNumServers, second_ports);
+  // Don't notify of the update. Wait for the LB policy's re-resolution to
+  // "pull" the new ports.
+  response_generator.SetNextResolutionUponError(second_ports);
+  gpr_log(GPR_INFO, "****** SERVERS RESTARTED *******");
+  gpr_log(GPR_INFO, "****** SENDING REQUEST TO SUCCEED *******");
+  // Client request should eventually (but still fairly soon) succeed.
+  const gpr_timespec deadline = grpc_timeout_seconds_to_deadline(5);
+  gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC);
+  while (gpr_time_cmp(deadline, now) > 0) {
+    if (SendRpc(stub)) break;
+    now = gpr_now(GPR_CLOCK_MONOTONIC);
+  }
+  // If the deadline passed, the break above never fired -> fail.
+  ASSERT_GT(gpr_time_cmp(deadline, now), 0);
+}
+
+// Verifies that a READY round_robin channel reports TRANSIENT_FAILURE after
+// all of its backends are shut down.
+TEST_F(ClientLbEnd2endTest, RoundRobinTransientFailure) {
+  // Start servers and create channel. Channel should go to READY state.
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(GetServersPorts());
+  EXPECT_TRUE(WaitForChannelReady(channel.get()));
+  // Now kill the servers. The channel should transition to TRANSIENT_FAILURE.
+  // TODO(roth): This test should ideally check that even when the
+  // subchannels are in state CONNECTING for an extended period of time,
+  // we will still report TRANSIENT_FAILURE. Unfortunately, we don't
+  // currently have a good way to get a subchannel to report CONNECTING
+  // for a long period of time, since the servers in this test framework
+  // are on the loopback interface, which will immediately return a
+  // "Connection refused" error, so the subchannels will only be in
+  // CONNECTING state very briefly. When we have time, see if we can
+  // find a way to fix this.
+  for (size_t i = 0; i < servers_.size(); ++i) {
+    servers_[i]->Shutdown();
+  }
+  auto predicate = [](grpc_connectivity_state state) {
+    return state == GRPC_CHANNEL_TRANSIENT_FAILURE;
+  };
+  EXPECT_TRUE(WaitForChannelState(channel.get(), predicate));
+}
+
+// Verifies that a round_robin channel resolving only to unreachable ports
+// reports TRANSIENT_FAILURE from startup, without ever becoming READY.
+TEST_F(ClientLbEnd2endTest, RoundRobinTransientFailureAtStartup) {
+  // Create channel and return servers that don't exist. Channel should
+  // quickly transition into TRANSIENT_FAILURE.
+  // TODO(roth): This test should ideally check that even when the
+  // subchannels are in state CONNECTING for an extended period of time,
+  // we will still report TRANSIENT_FAILURE. Unfortunately, we don't
+  // currently have a good way to get a subchannel to report CONNECTING
+  // for a long period of time, since the servers in this test framework
+  // are on the loopback interface, which will immediately return a
+  // "Connection refused" error, so the subchannels will only be in
+  // CONNECTING state very briefly. When we have time, see if we can
+  // find a way to fix this.
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution({
+      grpc_pick_unused_port_or_die(),
+      grpc_pick_unused_port_or_die(),
+      grpc_pick_unused_port_or_die(),
+  });
+  // NOTE(review): no servers were created in this test, so servers_ is empty
+  // and this loop is a no-op — looks copied from the previous test; confirm
+  // whether it can simply be removed.
+  for (size_t i = 0; i < servers_.size(); ++i) {
+    servers_[i]->Shutdown();
+  }
+  auto predicate = [](grpc_connectivity_state state) {
+    return state == GRPC_CHANNEL_TRANSIENT_FAILURE;
+  };
+  // Final arg requests try-to-connect while waiting for the state.
+  EXPECT_TRUE(WaitForChannelState(channel.get(), predicate, true));
+}
+
+// Verifies that when a single backend dies, round_robin keeps serving from
+// the remaining backends (sending nothing to the dead one), and resumes
+// sending to the backend once it comes back up.
+TEST_F(ClientLbEnd2endTest, RoundRobinSingleReconnect) {
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  const auto ports = GetServersPorts();
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(ports);
+  for (size_t i = 0; i < kNumServers; ++i) {
+    WaitForServer(stub, i, DEBUG_LOCATION);
+  }
+  for (size_t i = 0; i < servers_.size(); ++i) {
+    CheckRpcSendOk(stub, DEBUG_LOCATION);
+    EXPECT_EQ(1, servers_[i]->service_.request_count()) << "for backend #" << i;
+  }
+  // One request should have gone to each server.
+  // NOTE(review): this re-checks the same counters as the loop above;
+  // presumably kept for clarity — it is redundant.
+  for (size_t i = 0; i < servers_.size(); ++i) {
+    EXPECT_EQ(1, servers_[i]->service_.request_count());
+  }
+  const auto pre_death = servers_[0]->service_.request_count();
+  // Kill the first server.
+  servers_[0]->Shutdown();
+  // Client request still succeed. May need retrying if RR had returned a pick
+  // before noticing the change in the server's connectivity.
+  while (!SendRpc(stub)) {
+  }  // Retry until success.
+  // Send a bunch of RPCs that should succeed.
+  for (int i = 0; i < 10 * kNumServers; ++i) {
+    CheckRpcSendOk(stub, DEBUG_LOCATION);
+  }
+  const auto post_death = servers_[0]->service_.request_count();
+  // No requests have gone to the deceased server.
+  EXPECT_EQ(pre_death, post_death);
+  // Bring the first server back up.
+  StartServer(0);
+  // Requests should start arriving at the first server either right away (if
+  // the server managed to start before the RR policy retried the subchannel) or
+  // after the subchannel retry delay otherwise (RR's subchannel retried before
+  // the server was fully back up).
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+}
+
+// If health checking is required by client but health checking service
+// is not running on the server, the channel should be treated as healthy.
+TEST_F(ClientLbEnd2endTest,
+       RoundRobinServersHealthCheckingUnimplementedTreatedAsHealthy) {
+  StartServers(1);  // Single server
+  // Enable client-side health checking via service config; the server does
+  // NOT run the health-checking service in this test.
+  ChannelArguments args;
+  args.SetServiceConfigJSON(
+      "{\"healthCheckConfig\": "
+      "{\"serviceName\": \"health_check_service_name\"}}");
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator, args);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution({servers_[0]->port_});
+  // Channel becomes READY and RPCs succeed despite UNIMPLEMENTED health checks.
+  EXPECT_TRUE(WaitForChannelReady(channel.get()));
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+}
+
+// End-to-end client-side health checking with round_robin: traffic goes only
+// to backends whose health service reports SERVING, follows health-status
+// changes, and the channel leaves READY when no backend is healthy.
+TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthChecking) {
+  EnableDefaultHealthCheckService(true);
+  // Start servers.
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  ChannelArguments args;
+  args.SetServiceConfigJSON(
+      "{\"healthCheckConfig\": "
+      "{\"serviceName\": \"health_check_service_name\"}}");
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator, args);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(GetServersPorts());
+  // Channel should not become READY, because health checks should be failing.
+  gpr_log(GPR_INFO,
+          "*** initial state: unknown health check service name for "
+          "all servers");
+  EXPECT_FALSE(WaitForChannelReady(channel.get(), 1));
+  // Now set one of the servers to be healthy.
+  // The channel should become healthy and all requests should go to
+  // the healthy server.
+  gpr_log(GPR_INFO, "*** server 0 healthy");
+  servers_[0]->SetServingStatus("health_check_service_name", true);
+  EXPECT_TRUE(WaitForChannelReady(channel.get()));
+  for (int i = 0; i < 10; ++i) {
+    CheckRpcSendOk(stub, DEBUG_LOCATION);
+  }
+  EXPECT_EQ(10, servers_[0]->service_.request_count());
+  EXPECT_EQ(0, servers_[1]->service_.request_count());
+  EXPECT_EQ(0, servers_[2]->service_.request_count());
+  // Now set a second server to be healthy.
+  gpr_log(GPR_INFO, "*** server 2 healthy");
+  servers_[2]->SetServingStatus("health_check_service_name", true);
+  WaitForServer(stub, 2, DEBUG_LOCATION);
+  // 10 RPCs split evenly across the two healthy backends (0 and 2).
+  for (int i = 0; i < 10; ++i) {
+    CheckRpcSendOk(stub, DEBUG_LOCATION);
+  }
+  EXPECT_EQ(5, servers_[0]->service_.request_count());
+  EXPECT_EQ(0, servers_[1]->service_.request_count());
+  EXPECT_EQ(5, servers_[2]->service_.request_count());
+  // Now set the remaining server to be healthy.
+  gpr_log(GPR_INFO, "*** server 1 healthy");
+  servers_[1]->SetServingStatus("health_check_service_name", true);
+  WaitForServer(stub, 1, DEBUG_LOCATION);
+  // 9 RPCs split three ways across all healthy backends.
+  for (int i = 0; i < 9; ++i) {
+    CheckRpcSendOk(stub, DEBUG_LOCATION);
+  }
+  EXPECT_EQ(3, servers_[0]->service_.request_count());
+  EXPECT_EQ(3, servers_[1]->service_.request_count());
+  EXPECT_EQ(3, servers_[2]->service_.request_count());
+  // Now set one server to be unhealthy again. Then wait until the
+  // unhealthiness has hit the client. We know that the client will see
+  // this when we send kNumServers requests and one of the remaining servers
+  // sees two of the requests.
+  gpr_log(GPR_INFO, "*** server 0 unhealthy");
+  servers_[0]->SetServingStatus("health_check_service_name", false);
+  do {
+    ResetCounters();
+    for (int i = 0; i < kNumServers; ++i) {
+      CheckRpcSendOk(stub, DEBUG_LOCATION);
+    }
+  } while (servers_[1]->service_.request_count() != 2 &&
+           servers_[2]->service_.request_count() != 2);
+  // Now set the remaining two servers to be unhealthy. Make sure the
+  // channel leaves READY state and that RPCs fail.
+  gpr_log(GPR_INFO, "*** all servers unhealthy");
+  servers_[1]->SetServingStatus("health_check_service_name", false);
+  servers_[2]->SetServingStatus("health_check_service_name", false);
+  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
+  CheckRpcSendFailure(stub);
+  // Clean up.
+  EnableDefaultHealthCheckService(false);
+}
+
+// Regression test: with health checking enabled, round_robin must correctly
+// re-examine each subchannel's state when processing a resolver update that
+// arrives after one of the (previously healthy) backends has gone down.
+TEST_F(ClientLbEnd2endTest,
+       RoundRobinWithHealthCheckingHandlesSubchannelFailure) {
+  EnableDefaultHealthCheckService(true);
+  // Start servers.
+  const int kNumServers = 3;
+  StartServers(kNumServers);
+  servers_[0]->SetServingStatus("health_check_service_name", true);
+  servers_[1]->SetServingStatus("health_check_service_name", true);
+  servers_[2]->SetServingStatus("health_check_service_name", true);
+  ChannelArguments args;
+  args.SetServiceConfigJSON(
+      "{\"healthCheckConfig\": "
+      "{\"serviceName\": \"health_check_service_name\"}}");
+  auto response_generator = BuildResolverResponseGenerator();
+  auto channel = BuildChannel("round_robin", response_generator, args);
+  auto stub = BuildStub(channel);
+  response_generator.SetNextResolution(GetServersPorts());
+  WaitForServer(stub, 0, DEBUG_LOCATION);
+  // Stop server 0 and send a new resolver result to ensure that RR
+  // checks each subchannel's state.
+  servers_[0]->Shutdown();
+  response_generator.SetNextResolution(GetServersPorts());
+  // Send a bunch more RPCs.
+  // Note: results deliberately not asserted — the test passes as long as
+  // nothing crashes/hangs while the update is processed.
+  for (size_t i = 0; i < 100; i++) {
+    SendRpc(stub);
+  }
+}
+
+// Verifies GRPC_ARG_INHIBIT_HEALTH_CHECKING: a channel with health checking
+// inhibited becomes READY even while the backend reports unhealthy, a
+// non-inhibited channel to the same backend does not — and both still share
+// one subchannel.
+TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthCheckingInhibitPerChannel) {
+  EnableDefaultHealthCheckService(true);
+  // Start server.
+  const int kNumServers = 1;
+  StartServers(kNumServers);
+  // Create a channel with health-checking enabled.
+  ChannelArguments args;
+  args.SetServiceConfigJSON(
+      "{\"healthCheckConfig\": "
+      "{\"serviceName\": \"health_check_service_name\"}}");
+  auto response_generator1 = BuildResolverResponseGenerator();
+  auto channel1 = BuildChannel("round_robin", response_generator1, args);
+  auto stub1 = BuildStub(channel1);
+  std::vector<int> ports = GetServersPorts();
+  response_generator1.SetNextResolution(ports);
+  // Create a channel with health checking enabled but inhibited.
+  args.SetInt(GRPC_ARG_INHIBIT_HEALTH_CHECKING, 1);
+  auto response_generator2 = BuildResolverResponseGenerator();
+  auto channel2 = BuildChannel("round_robin", response_generator2, args);
+  auto stub2 = BuildStub(channel2);
+  response_generator2.SetNextResolution(ports);
+  // First channel should not become READY, because health checks should be
+  // failing.
+  EXPECT_FALSE(WaitForChannelReady(channel1.get(), 1));
+  CheckRpcSendFailure(stub1);
+  // Second channel should be READY.
+  EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1));
+  CheckRpcSendOk(stub2, DEBUG_LOCATION);
+  // Enable health checks on the backend and wait for channel 1 to succeed.
+  servers_[0]->SetServingStatus("health_check_service_name", true);
+  CheckRpcSendOk(stub1, DEBUG_LOCATION, true /* wait_for_ready */);
+  // Check that we created only one subchannel to the backend.
+  EXPECT_EQ(1UL, servers_[0]->service_.clients().size());
+  // Clean up.
+  EnableDefaultHealthCheckService(false);
+}
+
+TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthCheckingServiceNamePerChannel) {
+ EnableDefaultHealthCheckService(true);
+ // Start server.
+ const int kNumServers = 1;
+ StartServers(kNumServers);
+ // Channel 1: health checking with service name "health_check_service_name".
+ ChannelArguments args;
+ args.SetServiceConfigJSON(
+ "{\"healthCheckConfig\": "
+ "{\"serviceName\": \"health_check_service_name\"}}");
+ auto response_generator1 = BuildResolverResponseGenerator();
+ auto channel1 = BuildChannel("round_robin", response_generator1, args);
+ auto stub1 = BuildStub(channel1);
+ std::vector<int> ports = GetServersPorts();
+ response_generator1.SetNextResolution(ports);
+ // Channel 2: health checking with a different service name,
+ // "health_check_service_name2".
+ ChannelArguments args2;
+ args2.SetServiceConfigJSON(
+ "{\"healthCheckConfig\": "
+ "{\"serviceName\": \"health_check_service_name2\"}}");
+ auto response_generator2 = BuildResolverResponseGenerator();
+ auto channel2 = BuildChannel("round_robin", response_generator2, args2);
+ auto stub2 = BuildStub(channel2);
+ response_generator2.SetNextResolution(ports);
+ // Allow health checks from channel 2 to succeed.
+ servers_[0]->SetServingStatus("health_check_service_name2", true);
+ // First channel should not become READY, because health checks should be
+ // failing.
+ EXPECT_FALSE(WaitForChannelReady(channel1.get(), 1));
+ CheckRpcSendFailure(stub1);
+ // Second channel should be READY.
+ EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1));
+ CheckRpcSendOk(stub2, DEBUG_LOCATION);
+ // Enable health checks for channel 1 and wait for it to succeed.
+ servers_[0]->SetServingStatus("health_check_service_name", true);
+ CheckRpcSendOk(stub1, DEBUG_LOCATION, true /* wait_for_ready */);
+ // Check that both channels shared a single subchannel to the backend.
+ EXPECT_EQ(1UL, servers_[0]->service_.clients().size());
+ // Clean up.
+ EnableDefaultHealthCheckService(false);
+}
+
+TEST_F(ClientLbEnd2endTest,
+ RoundRobinWithHealthCheckingServiceNameChangesAfterSubchannelsCreated) {
+ EnableDefaultHealthCheckService(true);
+ // Start server.
+ const int kNumServers = 1;
+ StartServers(kNumServers);
+ // Create a channel with health-checking enabled.
+ const char* kServiceConfigJson =
+ "{\"healthCheckConfig\": "
+ "{\"serviceName\": \"health_check_service_name\"}}";
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("round_robin", response_generator);
+ auto stub = BuildStub(channel);
+ std::vector<int> ports = GetServersPorts();
+ response_generator.SetNextResolution(ports, kServiceConfigJson);
+ servers_[0]->SetServingStatus("health_check_service_name", true);
+ EXPECT_TRUE(WaitForChannelReady(channel.get(), 1 /* timeout_seconds */));
+ // Push a resolver update that switches the channel to a health-check
+ // service name that is not currently being reported as healthy.
+ const char* kServiceConfigJson2 =
+ "{\"healthCheckConfig\": "
+ "{\"serviceName\": \"health_check_service_name2\"}}";
+ response_generator.SetNextResolution(ports, kServiceConfigJson2);
+ EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
+ // Clean up.
+ EnableDefaultHealthCheckService(false);
+}
+
+TEST_F(ClientLbEnd2endTest, ChannelIdleness) {
+ // Start server.
+ const int kNumServers = 1;
+ StartServers(kNumServers);
+ // Set max idle time and build the channel.
+ ChannelArguments args;
+ args.SetInt(GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS, 1000);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("", response_generator, args);
+ auto stub = BuildStub(channel);
+ // The initial channel state should be IDLE.
+ EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE);
+ // After sending RPC, channel state should be READY.
+ response_generator.SetNextResolution(GetServersPorts());
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
+ // After a period of inactivity longer than the 1000ms idle timeout, the
+ // channel state should switch back to IDLE.
+ gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(1200));
+ EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE);
+ // Sending a new RPC should awake the IDLE channel.
+ response_generator.SetNextResolution(GetServersPorts());
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
+}
+
+class ClientLbPickArgsTest : public ClientLbEnd2endTest {
+ protected:
+ void SetUp() override {
+ ClientLbEnd2endTest::SetUp();
+ current_test_instance_ = this;  // let the static callback reach this test
+ }
+
+ static void SetUpTestCase() {
+ grpc_init();
+ grpc_core::RegisterTestPickArgsLoadBalancingPolicy(SavePickArgs);
+ }
+
+ static void TearDownTestCase() { grpc_shutdown_blocking(); }
+
+ const std::vector<grpc_core::PickArgsSeen>& args_seen_list() {
+ grpc::internal::MutexLock lock(&mu_);
+ return args_seen_list_;
+ }
+
+ private:
+ static void SavePickArgs(const grpc_core::PickArgsSeen& args_seen) {
+ ClientLbPickArgsTest* self = current_test_instance_;  // set in SetUp()
+ grpc::internal::MutexLock lock(&self->mu_);
+ self->args_seen_list_.emplace_back(args_seen);
+ }
+
+ static ClientLbPickArgsTest* current_test_instance_;
+ grpc::internal::Mutex mu_;  // guards args_seen_list_
+ std::vector<grpc_core::PickArgsSeen> args_seen_list_;
+};
+
+ClientLbPickArgsTest* ClientLbPickArgsTest::current_test_instance_ = nullptr;
+
+TEST_F(ClientLbPickArgsTest, Basic) {
+ const int kNumServers = 1;
+ StartServers(kNumServers);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("test_pick_args_lb", response_generator);
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(GetServersPorts());
+ CheckRpcSendOk(stub, DEBUG_LOCATION, /*wait_for_ready=*/true);
+ // Verify the channel is using the test_pick_args_lb policy.
+ EXPECT_EQ("test_pick_args_lb", channel->GetLoadBalancingPolicyName());
+ // There will be two entries, one for the pick tried in state
+ // CONNECTING and another for the pick tried in state READY.
+ EXPECT_THAT(args_seen_list(),
+ ::testing::ElementsAre(
+ ::testing::AllOf(
+ ::testing::Field(&grpc_core::PickArgsSeen::path,
+ "/grpc.testing.EchoTestService/Echo"),
+ ::testing::Field(&grpc_core::PickArgsSeen::metadata,
+ ::testing::UnorderedElementsAre(
+ ::testing::Pair("foo", "1"),
+ ::testing::Pair("bar", "2"),
+ ::testing::Pair("baz", "3")))),
+ ::testing::AllOf(
+ ::testing::Field(&grpc_core::PickArgsSeen::path,
+ "/grpc.testing.EchoTestService/Echo"),
+ ::testing::Field(&grpc_core::PickArgsSeen::metadata,
+ ::testing::UnorderedElementsAre(
+ ::testing::Pair("foo", "1"),
+ ::testing::Pair("bar", "2"),
+ ::testing::Pair("baz", "3"))))));
+}
+
+class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest {
+ protected:
+ void SetUp() override {
+ ClientLbEnd2endTest::SetUp();
+ current_test_instance_ = this;  // expose this instance to the static callback
+ }
+
+ static void SetUpTestCase() {
+ grpc_init();
+ grpc_core::RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(
+ ReportTrailerIntercepted);
+ }
+
+ static void TearDownTestCase() { grpc_shutdown_blocking(); }
+
+ int trailers_intercepted() {
+ grpc::internal::MutexLock lock(&mu_);
+ return trailers_intercepted_;
+ }
+
+ const grpc_core::MetadataVector& trailing_metadata() {
+ grpc::internal::MutexLock lock(&mu_);
+ return trailing_metadata_;
+ }
+
+ const udpa::data::orca::v1::OrcaLoadReport* backend_load_report() {
+ grpc::internal::MutexLock lock(&mu_);
+ return load_report_.get();  // may be nullptr if no report was received
+ }
+
+ private:
+ static void ReportTrailerIntercepted(
+ const grpc_core::TrailingMetadataArgsSeen& args_seen) {
+ const auto* backend_metric_data = args_seen.backend_metric_data;
+ ClientLbInterceptTrailingMetadataTest* self = current_test_instance_;
+ grpc::internal::MutexLock lock(&self->mu_);
+ self->trailers_intercepted_++;
+ self->trailing_metadata_ = args_seen.metadata;
+ if (backend_metric_data != nullptr) {  // backend sent a load report
+ self->load_report_.reset(new udpa::data::orca::v1::OrcaLoadReport);
+ self->load_report_->set_cpu_utilization(
+ backend_metric_data->cpu_utilization);
+ self->load_report_->set_mem_utilization(
+ backend_metric_data->mem_utilization);
+ self->load_report_->set_rps(backend_metric_data->requests_per_second);
+ for (const auto& p : backend_metric_data->request_cost) {
+ TString name = TString(p.first);
+ (*self->load_report_->mutable_request_cost())[std::move(name)] =
+ p.second;
+ }
+ for (const auto& p : backend_metric_data->utilization) {
+ TString name = TString(p.first);
+ (*self->load_report_->mutable_utilization())[std::move(name)] =
+ p.second;
+ }
+ }
+ }
+
+ static ClientLbInterceptTrailingMetadataTest* current_test_instance_;
+ grpc::internal::Mutex mu_;  // guards the fields below
+ int trailers_intercepted_ = 0;
+ grpc_core::MetadataVector trailing_metadata_;
+ std::unique_ptr<udpa::data::orca::v1::OrcaLoadReport> load_report_;
+};
+
+ClientLbInterceptTrailingMetadataTest*
+ ClientLbInterceptTrailingMetadataTest::current_test_instance_ = nullptr;
+
+TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesDisabled) {
+ const int kNumServers = 1;
+ const int kNumRpcs = 10;
+ StartServers(kNumServers);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel =
+ BuildChannel("intercept_trailing_metadata_lb", response_generator);
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(GetServersPorts());
+ for (size_t i = 0; i < kNumRpcs; ++i) {
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ }
+ // Verify the channel is using the interception LB policy.
+ EXPECT_EQ("intercept_trailing_metadata_lb",
+ channel->GetLoadBalancingPolicyName());
+ EXPECT_EQ(kNumRpcs, trailers_intercepted());
+ EXPECT_THAT(trailing_metadata(),
+ ::testing::UnorderedElementsAre(
+ // TODO(roth): Should grpc-status be visible here?
+ ::testing::Pair("grpc-status", "0"),
+ ::testing::Pair("user-agent", ::testing::_),
+ ::testing::Pair("foo", "1"), ::testing::Pair("bar", "2"),
+ ::testing::Pair("baz", "3")));
+ EXPECT_EQ(nullptr, backend_load_report());
+}
+
+TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesEnabled) {
+ const int kNumServers = 1;
+ const int kNumRpcs = 10;
+ StartServers(kNumServers);
+ ChannelArguments args;  // service config below enables retries
+ args.SetServiceConfigJSON(
+ "{\n"
+ " \"methodConfig\": [ {\n"
+ " \"name\": [\n"
+ " { \"service\": \"grpc.testing.EchoTestService\" }\n"
+ " ],\n"
+ " \"retryPolicy\": {\n"
+ " \"maxAttempts\": 3,\n"
+ " \"initialBackoff\": \"1s\",\n"
+ " \"maxBackoff\": \"120s\",\n"
+ " \"backoffMultiplier\": 1.6,\n"
+ " \"retryableStatusCodes\": [ \"ABORTED\" ]\n"
+ " }\n"
+ " } ]\n"
+ "}");
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel =
+ BuildChannel("intercept_trailing_metadata_lb", response_generator, args);
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(GetServersPorts());
+ for (size_t i = 0; i < kNumRpcs; ++i) {
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ }
+ // Verify the channel is using the interception LB policy.
+ EXPECT_EQ("intercept_trailing_metadata_lb",
+ channel->GetLoadBalancingPolicyName());
+ EXPECT_EQ(kNumRpcs, trailers_intercepted());
+ EXPECT_THAT(trailing_metadata(),
+ ::testing::UnorderedElementsAre(
+ // TODO(roth): Should grpc-status be visible here?
+ ::testing::Pair("grpc-status", "0"),
+ ::testing::Pair("user-agent", ::testing::_),
+ ::testing::Pair("foo", "1"), ::testing::Pair("bar", "2"),
+ ::testing::Pair("baz", "3")));
+ EXPECT_EQ(nullptr, backend_load_report());
+}
+
+TEST_F(ClientLbInterceptTrailingMetadataTest, BackendMetricData) {
+ const int kNumServers = 1;
+ const int kNumRpcs = 10;
+ StartServers(kNumServers);
+ udpa::data::orca::v1::OrcaLoadReport load_report;
+ load_report.set_cpu_utilization(0.5);
+ load_report.set_mem_utilization(0.75);
+ load_report.set_rps(25);
+ auto* request_cost = load_report.mutable_request_cost();
+ (*request_cost)["foo"] = 0.8;
+ (*request_cost)["bar"] = 1.4;
+ auto* utilization = load_report.mutable_utilization();
+ (*utilization)["baz"] = 1.1;
+ (*utilization)["quux"] = 0.9;
+ for (const auto& server : servers_) {
+ server->service_.set_load_report(&load_report);  // attached to trailers
+ }
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel =
+ BuildChannel("intercept_trailing_metadata_lb", response_generator);
+ auto stub = BuildStub(channel);
+ response_generator.SetNextResolution(GetServersPorts());
+ for (size_t i = 0; i < kNumRpcs; ++i) {
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ auto* actual = backend_load_report();
+ ASSERT_NE(actual, nullptr);
+ // TODO(roth): Change this to use EqualsProto() once that becomes
+ // available in OSS.
+ EXPECT_EQ(actual->cpu_utilization(), load_report.cpu_utilization());
+ EXPECT_EQ(actual->mem_utilization(), load_report.mem_utilization());
+ EXPECT_EQ(actual->rps(), load_report.rps());
+ EXPECT_EQ(actual->request_cost().size(), load_report.request_cost().size());
+ for (const auto& p : actual->request_cost()) {
+ auto it = load_report.request_cost().find(p.first);
+ ASSERT_NE(it, load_report.request_cost().end());
+ EXPECT_EQ(it->second, p.second);
+ }
+ EXPECT_EQ(actual->utilization().size(), load_report.utilization().size());
+ for (const auto& p : actual->utilization()) {
+ auto it = load_report.utilization().find(p.first);
+ ASSERT_NE(it, load_report.utilization().end());
+ EXPECT_EQ(it->second, p.second);
+ }
+ }
+ // Verify the channel is using the interception LB policy.
+ EXPECT_EQ("intercept_trailing_metadata_lb",
+ channel->GetLoadBalancingPolicyName());
+ EXPECT_EQ(kNumRpcs, trailers_intercepted());
+}
+
+class ClientLbAddressTest : public ClientLbEnd2endTest {
+ protected:
+ static const char* kAttributeKey;
+
+ class Attribute : public grpc_core::ServerAddress::AttributeInterface {
+ public:
+ explicit Attribute(const TString& str) : str_(str) {}  // opaque string attribute
+
+ std::unique_ptr<AttributeInterface> Copy() const override {
+ return y_absl::make_unique<Attribute>(str_);
+ }
+
+ int Cmp(const AttributeInterface* other) const override {
+ return str_.compare(static_cast<const Attribute*>(other)->str_);
+ }
+
+ TString ToString() const override { return str_; }
+
+ private:
+ TString str_;
+ };
+
+ void SetUp() override {
+ ClientLbEnd2endTest::SetUp();
+ current_test_instance_ = this;  // expose this instance to SaveAddress()
+ }
+
+ static void SetUpTestCase() {
+ grpc_init();
+ grpc_core::RegisterAddressTestLoadBalancingPolicy(SaveAddress);
+ }
+
+ static void TearDownTestCase() { grpc_shutdown_blocking(); }
+
+ const std::vector<TString>& addresses_seen() {
+ grpc::internal::MutexLock lock(&mu_);
+ return addresses_seen_;
+ }
+
+ private:
+ static void SaveAddress(const grpc_core::ServerAddress& address) {
+ ClientLbAddressTest* self = current_test_instance_;
+ grpc::internal::MutexLock lock(&self->mu_);
+ self->addresses_seen_.emplace_back(address.ToString());
+ }
+
+ static ClientLbAddressTest* current_test_instance_;
+ grpc::internal::Mutex mu_;  // guards addresses_seen_
+ std::vector<TString> addresses_seen_;
+};
+
+const char* ClientLbAddressTest::kAttributeKey = "attribute_key";
+
+ClientLbAddressTest* ClientLbAddressTest::current_test_instance_ = nullptr;
+
+TEST_F(ClientLbAddressTest, Basic) {
+ const int kNumServers = 1;
+ StartServers(kNumServers);
+ auto response_generator = BuildResolverResponseGenerator();
+ auto channel = BuildChannel("address_test_lb", response_generator);
+ auto stub = BuildStub(channel);
+ // Addresses returned by the resolver will have attached attributes.
+ response_generator.SetNextResolution(GetServersPorts(), nullptr,
+ kAttributeKey,
+ y_absl::make_unique<Attribute>("foo"));
+ CheckRpcSendOk(stub, DEBUG_LOCATION);
+ // Verify the channel is using the address_test_lb policy.
+ EXPECT_EQ("address_test_lb", channel->GetLoadBalancingPolicyName());
+ // Make sure that the attributes wind up on the subchannels.
+ std::vector<TString> expected;
+ for (const int port : GetServersPorts()) {
+ expected.emplace_back(y_absl::StrCat(
+ "127.0.0.1:", port, " args={} attributes={", kAttributeKey, "=foo}"));
+ }
+ EXPECT_EQ(addresses_seen(), expected);
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ grpc::testing::TestEnvironment env(argc, argv);
+ const auto result = RUN_ALL_TESTS();  // run while TestEnvironment is alive
+ return result;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/delegating_channel_test.cc b/contrib/libs/grpc/test/cpp/end2end/delegating_channel_test.cc
new file mode 100644
index 0000000000..5d025ecb94
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/delegating_channel_test.cc
@@ -0,0 +1,100 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <vector>
+
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/impl/codegen/delegating_channel.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/client_interceptor.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+namespace {
+
+class TestChannel : public experimental::DelegatingChannel {
+ public:
+ TestChannel(const std::shared_ptr<ChannelInterface>& delegate_channel)
+ : experimental::DelegatingChannel(delegate_channel) {}
+ // Always reports READY, regardless of the delegate channel's real state.
+ grpc_connectivity_state GetState(bool /*try_to_connect*/) override {
+ return GRPC_CHANNEL_READY;
+ }
+};
+
+class DelegatingChannelTest : public ::testing::Test {
+ protected:
+ DelegatingChannelTest() {
+ int port = grpc_pick_unused_port_or_die();  // avoid port clashes
+ ServerBuilder builder;
+ server_address_ = "localhost:" + ToString(port);
+ builder.AddListeningPort(server_address_, InsecureServerCredentials());
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ }
+
+ ~DelegatingChannelTest() { server_->Shutdown(); }  // stop serving before members die
+
+ TString server_address_;
+ TestServiceImpl service_;
+ std::unique_ptr<Server> server_;
+};
+
+TEST_F(DelegatingChannelTest, SimpleTest) {
+ auto channel = CreateChannel(server_address_, InsecureChannelCredentials());
+ std::shared_ptr<TestChannel> test_channel =
+ std::make_shared<TestChannel>(channel);
+ // The underlying gRPC channel should still be IDLE at this point, but the
+ // delegating test channel always reports READY.
+ EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE);
+ EXPECT_EQ(test_channel->GetState(false), GRPC_CHANNEL_READY);
+ auto stub = grpc::testing::EchoTestService::NewStub(test_channel);
+ ClientContext ctx;
+ EchoRequest req;
+ req.set_message("Hello");
+ EchoResponse resp;
+ Status s = stub->Echo(&ctx, req, &resp);  // RPC goes through the delegate
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), "Hello");
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);  // alive for the whole run
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/end2end_test.cc
new file mode 100644
index 0000000000..ad2ddb7e84
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/end2end_test.cc
@@ -0,0 +1,2357 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/impl/codegen/status_code_enum.h>
+#include <grpcpp/resource_quota.h>
+#include <grpcpp/security/auth_metadata_processor.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/security/server_credentials.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/string_ref.h>
+#include <grpcpp/test/channel_test_peer.h>
+
+#include <mutex>
+#include <thread>
+
+#include "y_absl/strings/str_format.h"
+#include "src/core/ext/filters/client_channel/backup_poller.h"
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/iomgr/iomgr.h"
+#include "src/core/lib/security/credentials/credentials.h"
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/interceptors_util.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/string_ref_helper.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+#ifdef GRPC_POSIX_SOCKET_EV
+#include "src/core/lib/iomgr/ev_posix.h"
+#endif // GRPC_POSIX_SOCKET_EV
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using grpc::testing::kTlsCredentialsType;
+using std::chrono::system_clock;
+
+// MAYBE_SKIP_TEST returns early from (i.e. skips) the current test when the
+// fixture's SetUp decided this configuration cannot run. In particular,
+// tests that use the callback server can only be run if the iomgr can run in
+// the background or if the transport is in-process.
+#define MAYBE_SKIP_TEST \
+ do { \
+ if (do_not_test_) { \
+ return; \
+ } \
+ } while (0)
+
+namespace grpc {
+namespace testing {
+namespace {
+
+bool CheckIsLocalhost(const TString& addr) {  // true iff addr is a loopback URI
+ const TString kIpv6("ipv6:[::1]:");
+ const TString kIpv4MappedIpv6("ipv6:[::ffff:127.0.0.1]:");
+ const TString kIpv4("ipv4:127.0.0.1:");
+ return addr.substr(0, kIpv4.size()) == kIpv4 ||
+ addr.substr(0, kIpv4MappedIpv6.size()) == kIpv4MappedIpv6 ||
+ addr.substr(0, kIpv6.size()) == kIpv6;
+}
+
+const int kClientChannelBackupPollIntervalMs = 200;  // short so tests react quickly — TODO confirm
+
+const char kTestCredsPluginErrorMsg[] = "Could not find plugin metadata.";
+
+const char kFakeToken[] = "fake_token";
+const char kFakeSelector[] = "fake_selector";
+const char kExpectedFakeCredsDebugString[] =
+ "SecureCallCredentials{GoogleIAMCredentials{Token:present,"
+ "AuthoritySelector:fake_selector}}";
+
+const char kWrongToken[] = "wrong_token";
+const char kWrongSelector[] = "wrong_selector";
+const char kExpectedWrongCredsDebugString[] =
+ "SecureCallCredentials{GoogleIAMCredentials{Token:present,"
+ "AuthoritySelector:wrong_selector}}";
+
+const char kFakeToken1[] = "fake_token1";
+const char kFakeSelector1[] = "fake_selector1";
+const char kExpectedFakeCreds1DebugString[] =
+ "SecureCallCredentials{GoogleIAMCredentials{Token:present,"
+ "AuthoritySelector:fake_selector1}}";
+
+const char kFakeToken2[] = "fake_token2";
+const char kFakeSelector2[] = "fake_selector2";
+const char kExpectedFakeCreds2DebugString[] =
+ "SecureCallCredentials{GoogleIAMCredentials{Token:present,"
+ "AuthoritySelector:fake_selector2}}";
+
+const char kExpectedAuthMetadataPluginKeyFailureCredsDebugString[] =
+ "SecureCallCredentials{TestMetadataCredentials{key:TestPluginMetadata,"
+ "value:Does not matter, will fail the key is invalid.}}";
+const char kExpectedAuthMetadataPluginValueFailureCredsDebugString[] =
+ "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata,"
+ "value:With illegal \n value.}}";
+const char kExpectedAuthMetadataPluginWithDeadlineCredsDebugString[] =
+ "SecureCallCredentials{TestMetadataCredentials{key:meta_key,value:Does not "
+ "matter}}";
+const char kExpectedNonBlockingAuthMetadataPluginFailureCredsDebugString[] =
+ "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata,"
+ "value:Does not matter, will fail anyway (see 3rd param)}}";
+const char
+ kExpectedNonBlockingAuthMetadataPluginAndProcessorSuccessCredsDebugString
+ [] = "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-"
+ "metadata,value:Dr Jekyll}}";
+const char
+ kExpectedNonBlockingAuthMetadataPluginAndProcessorFailureCredsDebugString
+ [] = "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-"
+ "metadata,value:Mr Hyde}}";
+const char kExpectedBlockingAuthMetadataPluginFailureCredsDebugString[] =
+ "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata,"
+ "value:Does not matter, will fail anyway (see 3rd param)}}";
+const char kExpectedCompositeCallCredsDebugString[] =
+ "SecureCallCredentials{CompositeCallCredentials{TestMetadataCredentials{"
+ "key:call-creds-key1,value:call-creds-val1},TestMetadataCredentials{key:"
+ "call-creds-key2,value:call-creds-val2}}}";
+
+class TestMetadataCredentialsPlugin : public MetadataCredentialsPlugin {
+ public:
+ static const char kGoodMetadataKey[];
+ static const char kBadMetadataKey[];
+
+ TestMetadataCredentialsPlugin(const grpc::string_ref& metadata_key,
+ const grpc::string_ref& metadata_value,
+ bool is_blocking, bool is_successful,
+ int delay_ms)
+ : metadata_key_(metadata_key.data(), metadata_key.length()),
+ metadata_value_(metadata_value.data(), metadata_value.length()),
+ is_blocking_(is_blocking),
+ is_successful_(is_successful),
+ delay_ms_(delay_ms) {}
+
+ bool IsBlocking() const override { return is_blocking_; }
+
+ Status GetMetadata(
+ grpc::string_ref service_url, grpc::string_ref method_name,
+ const grpc::AuthContext& channel_auth_context,
+ std::multimap<TString, TString>* metadata) override {
+ if (delay_ms_ != 0) {  // simulate a slow credential fetch
+ gpr_sleep_until(
+ gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_millis(delay_ms_, GPR_TIMESPAN)));
+ }
+ EXPECT_GT(service_url.length(), 0UL);
+ EXPECT_GT(method_name.length(), 0UL);
+ EXPECT_TRUE(channel_auth_context.IsPeerAuthenticated());
+ EXPECT_TRUE(metadata != nullptr);
+ if (is_successful_) {
+ metadata->insert(std::make_pair(metadata_key_, metadata_value_));
+ return Status::OK;
+ } else {
+ return Status(StatusCode::NOT_FOUND, kTestCredsPluginErrorMsg);
+ }
+ }
+
+ TString DebugString() override {
+ return y_absl::StrFormat("TestMetadataCredentials{key:%s,value:%s}",
+ metadata_key_.c_str(), metadata_value_.c_str());
+ }
+
+ private:
+ TString metadata_key_;
+ TString metadata_value_;
+ bool is_blocking_;
+ bool is_successful_;
+ int delay_ms_;
+};
+
+const char TestMetadataCredentialsPlugin::kBadMetadataKey[] =  // uppercase => invalid key
+ "TestPluginMetadata";
+const char TestMetadataCredentialsPlugin::kGoodMetadataKey[] =
+ "test-plugin-metadata";
+
+class TestAuthMetadataProcessor : public AuthMetadataProcessor {
+ public:
+ static const char kGoodGuy[];
+
+ TestAuthMetadataProcessor(bool is_blocking) : is_blocking_(is_blocking) {}
+
+ std::shared_ptr<CallCredentials> GetCompatibleClientCreds() {
+ return grpc::MetadataCredentialsFromPlugin(
+ std::unique_ptr<MetadataCredentialsPlugin>(
+ new TestMetadataCredentialsPlugin(
+ TestMetadataCredentialsPlugin::kGoodMetadataKey, kGoodGuy,
+ is_blocking_, true, 0)));
+ }
+
+ std::shared_ptr<CallCredentials> GetIncompatibleClientCreds() {
+ return grpc::MetadataCredentialsFromPlugin(
+ std::unique_ptr<MetadataCredentialsPlugin>(
+ new TestMetadataCredentialsPlugin(
+ TestMetadataCredentialsPlugin::kGoodMetadataKey, "Mr Hyde",
+ is_blocking_, true, 0)));
+ }
+
+ // AuthMetadataProcessor interface implementation.
+ bool IsBlocking() const override { return is_blocking_; }
+
+ Status Process(const InputMetadata& auth_metadata, AuthContext* context,
+ OutputMetadata* consumed_auth_metadata,
+ OutputMetadata* response_metadata) override {
+ EXPECT_TRUE(consumed_auth_metadata != nullptr);
+ EXPECT_TRUE(context != nullptr);
+ EXPECT_TRUE(response_metadata != nullptr);
+ auto auth_md =
+ auth_metadata.find(TestMetadataCredentialsPlugin::kGoodMetadataKey);
+ EXPECT_NE(auth_md, auth_metadata.end());
+ string_ref auth_md_value = auth_md->second;
+ if (auth_md_value == kGoodGuy) {  // authorize only the good guy
+ context->AddProperty(kIdentityPropName, kGoodGuy);
+ context->SetPeerIdentityPropertyName(kIdentityPropName);
+ consumed_auth_metadata->insert(std::make_pair(
+ string(auth_md->first.data(), auth_md->first.length()),
+ string(auth_md->second.data(), auth_md->second.length())));
+ return Status::OK;
+ } else {
+ return Status(StatusCode::UNAUTHENTICATED,
+ string("Invalid principal: ") +
+ string(auth_md_value.data(), auth_md_value.length()));
+ }
+ }
+
+ private:
+ static const char kIdentityPropName[];
+ bool is_blocking_;
+};
+
+const char TestAuthMetadataProcessor::kGoodGuy[] = "Dr Jekyll";
+const char TestAuthMetadataProcessor::kIdentityPropName[] = "novel identity";
+
+class Proxy : public ::grpc::testing::EchoTestService::Service {  // forwards Echo to a backend stub
+ public:
+ Proxy(const std::shared_ptr<Channel>& channel)
+ : stub_(grpc::testing::EchoTestService::NewStub(channel)) {}
+
+ Status Echo(ServerContext* server_context, const EchoRequest* request,
+ EchoResponse* response) override {
+ std::unique_ptr<ClientContext> client_context =
+ ClientContext::FromServerContext(*server_context);
+ return stub_->Echo(client_context.get(), *request, response);
+ }
+
+ private:
+ std::unique_ptr<::grpc::testing::EchoTestService::Stub> stub_;
+};
+
+class TestServiceImplDupPkg
+ : public ::grpc::testing::duplicate::EchoTestService::Service {
+ public:
+ Status Echo(ServerContext* /*context*/, const EchoRequest* /*request*/,
+ EchoResponse* response) override {
+ response->set_message("no package");  // marker: duplicate-pkg service replied
+ return Status::OK;
+ }
+};
+
+class TestScenario {
+ public:
+ TestScenario(bool interceptors, bool proxy, bool inproc_stub,
+ const TString& creds_type, bool use_callback_server)
+ : use_interceptors(interceptors),
+ use_proxy(proxy),
+ inproc(inproc_stub),
+ credentials_type(creds_type),
+ callback_server(use_callback_server) {}
+ void Log() const;  // logs this scenario at debug level
+ bool use_interceptors;
+ bool use_proxy;
+ bool inproc;
+ const TString credentials_type;
+ bool callback_server;
+};
+
+static std::ostream& operator<<(std::ostream& out,
+ const TestScenario& scenario) {  // human-readable scenario description
+ return out << "TestScenario{use_interceptors="
+ << (scenario.use_interceptors ? "true" : "false")
+ << ", use_proxy=" << (scenario.use_proxy ? "true" : "false")
+ << ", inproc=" << (scenario.inproc ? "true" : "false")
+ << ", server_type="
+ << (scenario.callback_server ? "callback" : "sync")
+ << ", credentials='" << scenario.credentials_type << "'}";
+}
+
+void TestScenario::Log() const {
+ std::ostringstream out;
+ out << *this;  // formatted by operator<< above
+ gpr_log(GPR_DEBUG, "%s", out.str().c_str());
+}
+
+// Parameterized end-to-end fixture: owns the server (sync or callback
+// flavor per TestScenario), an optional proxy server, and the client
+// channel/stub, all built from the TestScenario parameter.
+class End2endTest : public ::testing::TestWithParam<TestScenario> {
+ protected:
+  static void SetUpTestCase() { grpc_init(); }
+  static void TearDownTestCase() { grpc_shutdown(); }
+  End2endTest()
+      : is_server_started_(false),
+        kMaxMessageSize_(8192),
+        special_service_("special"),
+        first_picked_port_(0) {
+    GetParam().Log();
+  }
+
+  void SetUp() override {
+    // A callback server over a real (non-inproc) transport needs the iomgr
+    // background poller; mark the test to be skipped when it is unavailable.
+    if (GetParam().callback_server && !GetParam().inproc &&
+        !grpc_iomgr_run_in_background()) {
+      do_not_test_ = true;
+      return;
+    }
+  }
+
+  void TearDown() override {
+    if (is_server_started_) {
+      server_->Shutdown();
+      if (proxy_server_) proxy_server_->Shutdown();
+    }
+    // Hand the reserved port back to the port pool for reuse.
+    if (first_picked_port_ > 0) {
+      grpc_recycle_unused_port(first_picked_port_);
+    }
+  }
+
+  // Picks a fresh port, remembers it for recycling in TearDown, and starts
+  // the server listening on 127.0.0.1:<port>.
+  void StartServer(const std::shared_ptr<AuthMetadataProcessor>& processor) {
+    int port = grpc_pick_unused_port_or_die();
+    first_picked_port_ = port;
+    server_address_ << "127.0.0.1:" << port;
+    // Setup server
+    BuildAndStartServer(processor);
+  }
+
+  // Shuts down the running server and restarts it on the same address.
+  // No-op if the server was never started.
+  void RestartServer(const std::shared_ptr<AuthMetadataProcessor>& processor) {
+    if (is_server_started_) {
+      server_->Shutdown();
+      BuildAndStartServer(processor);
+    }
+  }
+
+  void BuildAndStartServer(
+      const std::shared_ptr<AuthMetadataProcessor>& processor) {
+    ServerBuilder builder;
+    ConfigureServerBuilder(&builder);
+    auto server_creds = GetCredentialsProvider()->GetServerCredentials(
+        GetParam().credentials_type);
+    if (GetParam().credentials_type != kInsecureCredentialsType) {
+      server_creds->SetAuthMetadataProcessor(processor);
+    }
+    if (GetParam().use_interceptors) {
+      std::vector<
+          std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+          creators;
+      // Add 20 dummy server interceptors (tests assert on this count when
+      // counting cancellation notifications).
+      creators.reserve(20);
+      for (auto i = 0; i < 20; i++) {
+        creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+            new DummyInterceptorFactory()));
+      }
+      builder.experimental().SetInterceptorCreators(std::move(creators));
+    }
+    builder.AddListeningPort(server_address_.str(), server_creds);
+    if (!GetParam().callback_server) {
+      builder.RegisterService(&service_);
+    } else {
+      builder.RegisterService(&callback_service_);
+    }
+    // The "special" service is only reachable via this host override.
+    builder.RegisterService("foo.test.youtube.com", &special_service_);
+    builder.RegisterService(&dup_pkg_service_);
+
+    builder.SetSyncServerOption(ServerBuilder::SyncServerOption::NUM_CQS, 4);
+    builder.SetSyncServerOption(
+        ServerBuilder::SyncServerOption::CQ_TIMEOUT_MSEC, 10);
+
+    server_ = builder.BuildAndStart();
+    is_server_started_ = true;
+  }
+
+  // Overridable hook for subclasses that need extra builder configuration.
+  virtual void ConfigureServerBuilder(ServerBuilder* builder) {
+    builder->SetMaxMessageSize(
+        kMaxMessageSize_);  // For testing max message size.
+  }
+
+  // (Re)creates channel_ according to the scenario: secure or insecure
+  // credentials, TCP or in-process transport, with or without client-side
+  // interceptors. Starts the server first if it is not yet running.
+  void ResetChannel(
+      std::vector<
+          std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+          interceptor_creators = {}) {
+    if (!is_server_started_) {
+      StartServer(std::shared_ptr<AuthMetadataProcessor>());
+    }
+    EXPECT_TRUE(is_server_started_);
+    ChannelArguments args;
+    auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
+        GetParam().credentials_type, &args);
+    if (!user_agent_prefix_.empty()) {
+      args.SetUserAgentPrefix(user_agent_prefix_);
+    }
+    args.SetString(GRPC_ARG_SECONDARY_USER_AGENT_STRING, "end2end_test");
+
+    if (!GetParam().inproc) {
+      if (!GetParam().use_interceptors) {
+        channel_ = ::grpc::CreateCustomChannel(server_address_.str(),
+                                               channel_creds, args);
+      } else {
+        channel_ = CreateCustomChannelWithInterceptors(
+            server_address_.str(), channel_creds, args,
+            interceptor_creators.empty() ? CreateDummyClientInterceptors()
+                                         : std::move(interceptor_creators));
+      }
+    } else {
+      if (!GetParam().use_interceptors) {
+        channel_ = server_->InProcessChannel(args);
+      } else {
+        channel_ = server_->experimental().InProcessChannelWithInterceptors(
+            args, interceptor_creators.empty()
+                      ? CreateDummyClientInterceptors()
+                      : std::move(interceptor_creators));
+      }
+    }
+  }
+
+  // Rebuilds channel_ and stub_. When the scenario uses a proxy, also
+  // starts a fresh proxy server forwarding to channel_, then points the
+  // channel (and thus the stub) at the proxy instead.
+  void ResetStub(
+      std::vector<
+          std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+          interceptor_creators = {}) {
+    ResetChannel(std::move(interceptor_creators));
+    if (GetParam().use_proxy) {
+      proxy_service_.reset(new Proxy(channel_));
+      int port = grpc_pick_unused_port_or_die();
+      std::ostringstream proxyaddr;
+      proxyaddr << "localhost:" << port;
+      ServerBuilder builder;
+      builder.AddListeningPort(proxyaddr.str(), InsecureServerCredentials());
+      builder.RegisterService(proxy_service_.get());
+
+      builder.SetSyncServerOption(ServerBuilder::SyncServerOption::NUM_CQS, 4);
+      builder.SetSyncServerOption(
+          ServerBuilder::SyncServerOption::CQ_TIMEOUT_MSEC, 10);
+
+      proxy_server_ = builder.BuildAndStart();
+
+      channel_ =
+          grpc::CreateChannel(proxyaddr.str(), InsecureChannelCredentials());
+    }
+
+    stub_ = grpc::testing::EchoTestService::NewStub(channel_);
+    DummyInterceptor::Reset();
+  }
+
+  bool do_not_test_{false};  // set in SetUp when the scenario cannot run here
+  bool is_server_started_;
+  std::shared_ptr<Channel> channel_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+  std::unique_ptr<Server> server_;
+  std::unique_ptr<Server> proxy_server_;
+  std::unique_ptr<Proxy> proxy_service_;
+  std::ostringstream server_address_;
+  const int kMaxMessageSize_;
+  TestServiceImpl service_;
+  CallbackTestServiceImpl callback_service_;
+  TestServiceImpl special_service_;
+  TestServiceImplDupPkg dup_pkg_service_;
+  TString user_agent_prefix_;
+  int first_picked_port_;  // recycled in TearDown when > 0
+};
+
+// Issues num_rpcs gzip-compressed Echo RPCs on stub, expecting each reply
+// to echo the request. When with_binary_metadata is set, attaches an
+// 8-byte "custom-bin" metadata value whose last byte varies per call.
+static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs,
+                    bool with_binary_metadata) {
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello hello hello hello");
+
+  int calls_made = 0;
+  while (calls_made < num_rpcs) {
+    ClientContext context;
+    if (with_binary_metadata) {
+      char bytes[8] = {'\0', '\1', '\2', '\3',
+                       '\4', '\5', '\6', static_cast<char>(calls_made)};
+      context.AddMetadata("custom-bin", TString(bytes, 8));
+    }
+    context.set_compression_algorithm(GRPC_COMPRESS_GZIP);
+    Status result = stub->Echo(&context, request, &response);
+    EXPECT_EQ(response.message(), request.message());
+    EXPECT_TRUE(result.ok());
+    ++calls_made;
+  }
+}
+
+// This class is for testing scenarios where RPCs are cancelled on the server
+// by calling ServerContext::TryCancel()
+class End2endServerTryCancelTest : public End2endTest {
+ protected:
+  // Helper for testing client-streaming RPCs which are cancelled on the server.
+  // Depending on the value of server_try_cancel parameter, this will test one
+  // of the following three scenarios:
+  //   CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading
+  //   any messages from the client
+  //
+  //   CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading
+  //   messages from the client
+  //
+  //   CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading all
+  //   the messages from the client
+  //
+  // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL.
+  void TestRequestStreamServerCancel(
+      ServerTryCancelRequestPhase server_try_cancel, int num_msgs_to_send) {
+    MAYBE_SKIP_TEST;
+    RestartServer(std::shared_ptr<AuthMetadataProcessor>());
+    ResetStub();
+    EchoRequest request;
+    EchoResponse response;
+    ClientContext context;
+
+    // Send server_try_cancel value in the client metadata
+    context.AddMetadata(kServerTryCancelRequest,
+                        ToString(server_try_cancel));
+
+    auto stream = stub_->RequestStream(&context, &response);
+
+    // Write until either all messages are sent or the stream breaks due to
+    // the server-side cancellation.
+    int num_msgs_sent = 0;
+    while (num_msgs_sent < num_msgs_to_send) {
+      request.set_message("hello");
+      if (!stream->Write(request)) {
+        break;
+      }
+      num_msgs_sent++;
+    }
+    gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent);
+
+    stream->WritesDone();
+    Status s = stream->Finish();
+
+    // At this point, we know for sure that RPC was cancelled by the server
+    // since we passed server_try_cancel value in the metadata. Depending on the
+    // value of server_try_cancel, the RPC might have been cancelled by the
+    // server at different stages. The following validates our expectations of
+    // number of messages sent in various cancellation scenarios:
+
+    switch (server_try_cancel) {
+      case CANCEL_BEFORE_PROCESSING:
+      case CANCEL_DURING_PROCESSING:
+        // If the RPC is cancelled by server before / during messages from the
+        // client, it means that the client most likely did not get a chance to
+        // send all the messages it wanted to send. i.e num_msgs_sent <=
+        // num_msgs_to_send
+        EXPECT_LE(num_msgs_sent, num_msgs_to_send);
+        break;
+
+      case CANCEL_AFTER_PROCESSING:
+        // If the RPC was cancelled after all messages were read by the server,
+        // the client did get a chance to send all its messages
+        EXPECT_EQ(num_msgs_sent, num_msgs_to_send);
+        break;
+
+      default:
+        gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d",
+                server_try_cancel);
+        EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL &&
+                    server_try_cancel <= CANCEL_AFTER_PROCESSING);
+        break;
+    }
+
+    EXPECT_FALSE(s.ok());
+    EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+    // Make sure that the server interceptors were notified
+    // (one cancellation notification per installed dummy interceptor).
+    if (GetParam().use_interceptors) {
+      EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+    }
+  }
+
+  // Helper for testing server-streaming RPCs which are cancelled on the server.
+  // Depending on the value of server_try_cancel parameter, this will test one
+  // of the following three scenarios:
+  //   CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before writing
+  //   any messages to the client
+  //
+  //   CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while writing
+  //   messages to the client
+  //
+  //   CANCEL_AFTER PROCESSING: Rpc is cancelled by server after writing all
+  //   the messages to the client
+  //
+  // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL.
+  void TestResponseStreamServerCancel(
+      ServerTryCancelRequestPhase server_try_cancel) {
+    MAYBE_SKIP_TEST;
+    RestartServer(std::shared_ptr<AuthMetadataProcessor>());
+    ResetStub();
+    EchoRequest request;
+    EchoResponse response;
+    ClientContext context;
+
+    // Send server_try_cancel in the client metadata
+    context.AddMetadata(kServerTryCancelRequest,
+                        ToString(server_try_cancel));
+
+    request.set_message("hello");
+    auto stream = stub_->ResponseStream(&context, request);
+
+    // Read until either the expected message count is reached or the stream
+    // breaks due to the server-side cancellation.
+    int num_msgs_read = 0;
+    while (num_msgs_read < kServerDefaultResponseStreamsToSend) {
+      if (!stream->Read(&response)) {
+        break;
+      }
+      EXPECT_EQ(response.message(),
+                request.message() + ToString(num_msgs_read));
+      num_msgs_read++;
+    }
+    gpr_log(GPR_INFO, "Read %d messages", num_msgs_read);
+
+    Status s = stream->Finish();
+
+    // Depending on the value of server_try_cancel, the RPC might have been
+    // cancelled by the server at different stages. The following validates our
+    // expectations of number of messages read in various cancellation
+    // scenarios:
+    switch (server_try_cancel) {
+      case CANCEL_BEFORE_PROCESSING:
+        // Server cancelled before sending any messages. Which means the client
+        // wouldn't have read any
+        EXPECT_EQ(num_msgs_read, 0);
+        break;
+
+      case CANCEL_DURING_PROCESSING:
+        // Server cancelled while writing messages. Client must have read less
+        // than or equal to the expected number of messages
+        EXPECT_LE(num_msgs_read, kServerDefaultResponseStreamsToSend);
+        break;
+
+      case CANCEL_AFTER_PROCESSING:
+        // Even though the Server cancelled after writing all messages, the RPC
+        // may be cancelled before the Client got a chance to read all the
+        // messages.
+        EXPECT_LE(num_msgs_read, kServerDefaultResponseStreamsToSend);
+        break;
+
+      default: {
+        gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d",
+                server_try_cancel);
+        EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL &&
+                    server_try_cancel <= CANCEL_AFTER_PROCESSING);
+        break;
+      }
+    }
+
+    EXPECT_FALSE(s.ok());
+    // Make sure that the server interceptors were notified
+    // (one cancellation notification per installed dummy interceptor).
+    if (GetParam().use_interceptors) {
+      EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+    }
+  }
+
+  // Helper for testing bidirectional-streaming RPCs which are cancelled on the
+  // server. Depending on the value of server_try_cancel parameter, this will
+  // test one of the following three scenarios:
+  //   CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading/
+  //   writing any messages from/to the client
+  //
+  //   CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading/
+  //   writing messages from/to the client
+  //
+  //   CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading/writing
+  //   all the messages from/to the client
+  //
+  // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL.
+  void TestBidiStreamServerCancel(ServerTryCancelRequestPhase server_try_cancel,
+                                  int num_messages) {
+    MAYBE_SKIP_TEST;
+    RestartServer(std::shared_ptr<AuthMetadataProcessor>());
+    ResetStub();
+    EchoRequest request;
+    EchoResponse response;
+    ClientContext context;
+
+    // Send server_try_cancel in the client metadata
+    context.AddMetadata(kServerTryCancelRequest,
+                        ToString(server_try_cancel));
+
+    auto stream = stub_->BidiStream(&context);
+
+    // Ping-pong until either all messages are exchanged or the stream breaks
+    // due to the server-side cancellation.
+    int num_msgs_read = 0;
+    int num_msgs_sent = 0;
+    while (num_msgs_sent < num_messages) {
+      request.set_message("hello " + ToString(num_msgs_sent));
+      if (!stream->Write(request)) {
+        break;
+      }
+      num_msgs_sent++;
+
+      if (!stream->Read(&response)) {
+        break;
+      }
+      num_msgs_read++;
+
+      EXPECT_EQ(response.message(), request.message());
+    }
+    gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent);
+    gpr_log(GPR_INFO, "Read %d messages", num_msgs_read);
+
+    stream->WritesDone();
+    Status s = stream->Finish();
+
+    // Depending on the value of server_try_cancel, the RPC might have been
+    // cancelled by the server at different stages. The following validates our
+    // expectations of number of messages read in various cancellation
+    // scenarios:
+    switch (server_try_cancel) {
+      case CANCEL_BEFORE_PROCESSING:
+        EXPECT_EQ(num_msgs_read, 0);
+        break;
+
+      case CANCEL_DURING_PROCESSING:
+        EXPECT_LE(num_msgs_sent, num_messages);
+        EXPECT_LE(num_msgs_read, num_msgs_sent);
+        break;
+
+      case CANCEL_AFTER_PROCESSING:
+        EXPECT_EQ(num_msgs_sent, num_messages);
+
+        // The Server cancelled after reading the last message and after writing
+        // the message to the client. However, the RPC cancellation might have
+        // taken effect before the client actually read the response.
+        EXPECT_LE(num_msgs_read, num_msgs_sent);
+        break;
+
+      default:
+        gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d",
+                server_try_cancel);
+        EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL &&
+                    server_try_cancel <= CANCEL_AFTER_PROCESSING);
+        break;
+    }
+
+    EXPECT_FALSE(s.ok());
+    EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+    // Make sure that the server interceptors were notified
+    // (one cancellation notification per installed dummy interceptor).
+    if (GetParam().use_interceptors) {
+      EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+    }
+  }
+};
+
+// Server cancels a unary Echo RPC (CANCEL_BEFORE_PROCESSING, requested via
+// metadata); the client must observe CANCELLED.
+TEST_P(End2endServerTryCancelTest, RequestEchoServerCancel) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+
+  context.AddMetadata(kServerTryCancelRequest,
+                      ToString(CANCEL_BEFORE_PROCESSING));
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+}
+
+// Server cancels before reading any request from the stream
+TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelBeforeReads) {
+  TestRequestStreamServerCancel(CANCEL_BEFORE_PROCESSING, 1);
+}
+
+// Server cancels while reading requests from the stream in parallel
+TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelDuringRead) {
+  TestRequestStreamServerCancel(CANCEL_DURING_PROCESSING, 10);
+}
+
+// Server cancels after reading all the requests but before returning to the
+// client
+TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelAfterReads) {
+  TestRequestStreamServerCancel(CANCEL_AFTER_PROCESSING, 4);
+}
+
+// Server cancels before sending any response messages
+TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelBefore) {
+  TestResponseStreamServerCancel(CANCEL_BEFORE_PROCESSING);
+}
+
+// Server cancels while writing responses to the stream in parallel
+TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelDuring) {
+  TestResponseStreamServerCancel(CANCEL_DURING_PROCESSING);
+}
+
+// Server cancels after writing all the responses to the stream but before
+// returning to the client
+TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelAfter) {
+  TestResponseStreamServerCancel(CANCEL_AFTER_PROCESSING);
+}
+
+// Server cancels before reading/writing any requests/responses on the stream
+TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelBefore) {
+  TestBidiStreamServerCancel(CANCEL_BEFORE_PROCESSING, 2);
+}
+
+// Server cancels while reading/writing requests/responses on the stream in
+// parallel
+TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelDuring) {
+  TestBidiStreamServerCancel(CANCEL_DURING_PROCESSING, 10);
+}
+
+// Server cancels after reading/writing all requests/responses on the stream
+// but before returning to the client
+TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelAfter) {
+  TestBidiStreamServerCancel(CANCEL_AFTER_PROCESSING, 5);
+}
+
+// Sets a custom user-agent prefix on the channel and verifies the
+// server-echoed "user-agent" trailing metadata starts with that prefix.
+TEST_P(End2endTest, SimpleRpcWithCustomUserAgentPrefix) {
+  MAYBE_SKIP_TEST;
+  // User-Agent is an HTTP header for HTTP transports only
+  if (GetParam().inproc) {
+    return;
+  }
+  user_agent_prefix_ = "custom_prefix";
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello hello hello hello");
+  // Ask the server to echo the request metadata back in the trailers.
+  request.mutable_param()->set_echo_metadata(true);
+
+  ClientContext context;
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+  const auto& trailing_metadata = context.GetServerTrailingMetadata();
+  auto iter = trailing_metadata.find("user-agent");
+  EXPECT_TRUE(iter != trailing_metadata.end());
+  TString expected_prefix = user_agent_prefix_ + " grpc-c++/";
+  EXPECT_TRUE(iter->second.starts_with(expected_prefix)) << iter->second;
+}
+
+// Ten concurrent threads each issue ten echo RPCs carrying binary metadata
+// whose value varies per call.
+TEST_P(End2endTest, MultipleRpcsWithVariedBinaryMetadataValue) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  constexpr int kNumThreads = 10;
+  std::vector<std::thread> workers;
+  workers.reserve(kNumThreads);
+  for (int t = 0; t < kNumThreads; ++t) {
+    workers.emplace_back(SendRpc, stub_.get(), 10, true);
+  }
+  for (auto& worker : workers) {
+    worker.join();
+  }
+}
+
+// Ten concurrent threads each issue ten plain echo RPCs (no binary
+// metadata).
+TEST_P(End2endTest, MultipleRpcs) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  constexpr int kNumThreads = 10;
+  std::vector<std::thread> workers;
+  workers.reserve(kNumThreads);
+  for (int t = 0; t < kNumThreads; ++t) {
+    workers.emplace_back(SendRpc, stub_.get(), 10, false);
+  }
+  for (auto& worker : workers) {
+    worker.join();
+  }
+}
+
+// Creating many stubs on one channel must not grow the channel's set of
+// registered calls (registered_calls stays constant), even though each stub
+// attempts registration.
+TEST_P(End2endTest, ManyStubs) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  ChannelTestPeer peer(channel_.get());
+  int registered_calls_pre = peer.registered_calls();
+  int registration_attempts_pre = peer.registration_attempts();
+  for (int i = 0; i < 1000; ++i) {
+    grpc::testing::EchoTestService::NewStub(channel_);
+  }
+  EXPECT_EQ(peer.registered_calls(), registered_calls_pre);
+  EXPECT_GT(peer.registration_attempts(), registration_attempts_pre);
+}
+
+// A "-bin" metadata key with an empty value must not break the RPC.
+TEST_P(End2endTest, EmptyBinaryMetadata) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello hello hello hello");
+  ClientContext context;
+  context.AddMetadata("custom-bin", "");
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+}
+
+// Verifies the channel transparently reconnects after the server restarts:
+// one RPC, restart, wait out the backup-poll interval, another RPC.
+TEST_P(End2endTest, ReconnectChannel) {
+  MAYBE_SKIP_TEST;
+  // No reconnect to exercise over the in-process transport.
+  if (GetParam().inproc) {
+    return;
+  }
+  int poller_slowdown_factor = 1;
+  // It needs 2 pollset_works to reconnect the channel with polling engine
+  // "poll"
+#ifdef GRPC_POSIX_SOCKET_EV
+  grpc_core::UniquePtr<char> poller = GPR_GLOBAL_CONFIG_GET(grpc_poll_strategy);
+  if (0 == strcmp(poller.get(), "poll")) {
+    poller_slowdown_factor = 2;
+  }
+#endif  // GRPC_POSIX_SOCKET_EV
+  ResetStub();
+  SendRpc(stub_.get(), 1, false);
+  RestartServer(std::shared_ptr<AuthMetadataProcessor>());
+  // It needs more than GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS time to
+  // reconnect the channel. Make it a factor of 5x
+  gpr_sleep_until(
+      gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+                   gpr_time_from_millis(kClientChannelBackupPollIntervalMs * 5 *
+                                            poller_slowdown_factor *
+                                            grpc_test_slowdown_factor(),
+                                        GPR_TIMESPAN)));
+  SendRpc(stub_.get(), 1, false);
+}
+
+// Client-streaming RPC with a single message; the echoed response must
+// match and the context must carry no debug error string on success.
+TEST_P(End2endTest, RequestStreamOneRequest) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+
+  auto stream = stub_->RequestStream(&context, &response);
+  request.set_message("hello");
+  EXPECT_TRUE(stream->Write(request));
+  stream->WritesDone();
+  Status s = stream->Finish();
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+  EXPECT_TRUE(context.debug_error_string().empty());
+}
+
+// Single-message client stream using the coalescing API: corked initial
+// metadata plus WriteLast batches metadata, message, and half-close.
+TEST_P(End2endTest, RequestStreamOneRequestWithCoalescingApi) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+
+  context.set_initial_metadata_corked(true);
+  auto stream = stub_->RequestStream(&context, &response);
+  request.set_message("hello");
+  stream->WriteLast(request, WriteOptions());
+  Status s = stream->Finish();
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+}
+
+// Two client-stream writes; the server's reply is the concatenation of both
+// messages ("hellohello").
+TEST_P(End2endTest, RequestStreamTwoRequests) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+
+  auto stream = stub_->RequestStream(&context, &response);
+  request.set_message("hello");
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->Write(request));
+  stream->WritesDone();
+  Status s = stream->Finish();
+  EXPECT_EQ(response.message(), "hellohello");
+  EXPECT_TRUE(s.ok());
+}
+
+// Same two-message client stream, but each write uses
+// WriteOptions().set_write_through(); the result must be unchanged.
+TEST_P(End2endTest, RequestStreamTwoRequestsWithWriteThrough) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+
+  auto stream = stub_->RequestStream(&context, &response);
+  request.set_message("hello");
+  EXPECT_TRUE(stream->Write(request, WriteOptions().set_write_through()));
+  EXPECT_TRUE(stream->Write(request, WriteOptions().set_write_through()));
+  stream->WritesDone();
+  Status s = stream->Finish();
+  EXPECT_EQ(response.message(), "hellohello");
+  EXPECT_TRUE(s.ok());
+}
+
+// Two-message client stream with corked metadata; the final message goes
+// out via WriteLast, coalescing it with the half-close.
+TEST_P(End2endTest, RequestStreamTwoRequestsWithCoalescingApi) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+
+  context.set_initial_metadata_corked(true);
+  auto stream = stub_->RequestStream(&context, &response);
+  request.set_message("hello");
+  EXPECT_TRUE(stream->Write(request));
+  stream->WriteLast(request, WriteOptions());
+  Status s = stream->Finish();
+  EXPECT_EQ(response.message(), "hellohello");
+  EXPECT_TRUE(s.ok());
+}
+
+// Server-streaming RPC: server sends kServerDefaultResponseStreamsToSend
+// messages "hello0", "hello1", ...; the client reads until end of stream.
+TEST_P(End2endTest, ResponseStream) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  request.set_message("hello");
+
+  auto stream = stub_->ResponseStream(&context, request);
+  for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) {
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message() + ToString(i));
+  }
+  EXPECT_FALSE(stream->Read(&response));
+
+  Status s = stream->Finish();
+  EXPECT_TRUE(s.ok());
+}
+
+// Same server-streaming RPC, but the server is asked (via metadata) to use
+// its coalescing write API; client-visible behavior must be identical.
+TEST_P(End2endTest, ResponseStreamWithCoalescingApi) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  request.set_message("hello");
+  context.AddMetadata(kServerUseCoalescingApi, "1");
+
+  auto stream = stub_->ResponseStream(&context, request);
+  for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) {
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message() + ToString(i));
+  }
+  EXPECT_FALSE(stream->Read(&response));
+
+  Status s = stream->Finish();
+  EXPECT_TRUE(s.ok());
+}
+
+// This was added to prevent regression from issue:
+// https://github.com/grpc/grpc/issues/11546
+TEST_P(End2endTest, ResponseStreamWithEverythingCoalesced) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  request.set_message("hello");
+  context.AddMetadata(kServerUseCoalescingApi, "1");
+  // We will only send one message, forcing everything (init metadata, message,
+  // trailing) to be coalesced together.
+  context.AddMetadata(kServerResponseStreamsToSend, "1");
+
+  auto stream = stub_->ResponseStream(&context, request);
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(response.message(), request.message() + "0");
+
+  EXPECT_FALSE(stream->Read(&response));
+
+  Status s = stream->Finish();
+  EXPECT_TRUE(s.ok());
+}
+
+// Bidirectional stream ping-pong: each written message is echoed back;
+// after WritesDone the stream must yield no further reads.
+TEST_P(End2endTest, BidiStream) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  TString msg("hello");
+
+  auto stream = stub_->BidiStream(&context);
+
+  for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) {
+    request.set_message(msg + ToString(i));
+    EXPECT_TRUE(stream->Write(request));
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message());
+  }
+
+  stream->WritesDone();
+  // Reading past end-of-stream must fail consistently.
+  EXPECT_FALSE(stream->Read(&response));
+  EXPECT_FALSE(stream->Read(&response));
+
+  Status s = stream->Finish();
+  EXPECT_TRUE(s.ok());
+}
+
+// Bidi ping-pong with corked initial metadata; the server finishes after
+// three reads (kServerFinishAfterNReads), and the last client write uses
+// WriteLast to coalesce the message with the half-close.
+TEST_P(End2endTest, BidiStreamWithCoalescingApi) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.AddMetadata(kServerFinishAfterNReads, "3");
+  context.set_initial_metadata_corked(true);
+  TString msg("hello");
+
+  auto stream = stub_->BidiStream(&context);
+
+  request.set_message(msg + "0");
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(response.message(), request.message());
+
+  request.set_message(msg + "1");
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(response.message(), request.message());
+
+  request.set_message(msg + "2");
+  stream->WriteLast(request, WriteOptions());
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(response.message(), request.message());
+
+  // Reading past end-of-stream must fail consistently.
+  EXPECT_FALSE(stream->Read(&response));
+  EXPECT_FALSE(stream->Read(&response));
+
+  Status s = stream->Finish();
+  EXPECT_TRUE(s.ok());
+}
+
+// This was added to prevent regression from issue:
+// https://github.com/grpc/grpc/issues/11546
+TEST_P(End2endTest, BidiStreamWithEverythingCoalesced) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  // Server finishes after a single read; combined with corked metadata and
+  // WriteLast, the entire exchange is coalesced.
+  context.AddMetadata(kServerFinishAfterNReads, "1");
+  context.set_initial_metadata_corked(true);
+  TString msg("hello");
+
+  auto stream = stub_->BidiStream(&context);
+
+  request.set_message(msg + "0");
+  stream->WriteLast(request, WriteOptions());
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(response.message(), request.message());
+
+  EXPECT_FALSE(stream->Read(&response));
+  EXPECT_FALSE(stream->Read(&response));
+
+  Status s = stream->Finish();
+  EXPECT_TRUE(s.ok());
+}
+
+// Talk to the two services with the same name but different package names.
+// The two stubs are created on the same channel.
+TEST_P(End2endTest, DiffPackageServices) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello");
+
+  ClientContext context;
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+
+  // The duplicate-package stub reaches TestServiceImplDupPkg, which always
+  // replies "no package".
+  std::unique_ptr<grpc::testing::duplicate::EchoTestService::Stub> dup_pkg_stub(
+      grpc::testing::duplicate::EchoTestService::NewStub(channel_));
+  ClientContext context2;
+  s = dup_pkg_stub->Echo(&context2, request, &response);
+  EXPECT_EQ("no package", response.message());
+  EXPECT_TRUE(s.ok());
+}
+
+// Sleeps for delay_us, spin-waits until the service reports the client RPC
+// has reached the server (signal_client()), then cancels from the client
+// side via TryCancel().
+template <class ServiceType>
+void CancelRpc(ClientContext* context, int delay_us, ServiceType* service) {
+  gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+                               gpr_time_from_micros(delay_us, GPR_TIMESPAN)));
+  // Busy-wait: acceptable in tests, keeps cancellation close to the signal.
+  while (!service->signal_client()) {
+  }
+  context->TryCancel();
+}
+
+// TryCancel before issuing the RPC: the call must fail with CANCELLED and
+// produce no response payload.
+TEST_P(End2endTest, CancelRpcBeforeStart) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  request.set_message("hello");
+  context.TryCancel();
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ("", response.message());
+  EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+  // One cancellation notification per installed dummy interceptor.
+  if (GetParam().use_interceptors) {
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Client cancels a unary RPC after the server has started processing it:
+// the server signals RPC start, the client cancels, then the server is
+// released; the call must end CANCELLED.
+TEST_P(End2endTest, CancelRpcAfterStart) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  request.set_message("hello");
+  request.mutable_param()->set_server_notify_client_when_started(true);
+  request.mutable_param()->set_skip_cancelled_check(true);
+  Status s;
+  // The blocking Echo runs in its own thread so the test body can cancel it.
+  std::thread echo_thread([this, &s, &context, &request, &response] {
+    s = stub_->Echo(&context, request, &response);
+    EXPECT_EQ(StatusCode::CANCELLED, s.error_code());
+  });
+  if (!GetParam().callback_server) {
+    service_.ClientWaitUntilRpcStarted();
+  } else {
+    callback_service_.ClientWaitUntilRpcStarted();
+  }
+
+  context.TryCancel();
+
+  if (!GetParam().callback_server) {
+    service_.SignalServerToContinue();
+  } else {
+    callback_service_.SignalServerToContinue();
+  }
+
+  echo_thread.join();
+  EXPECT_EQ("", response.message());
+  EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+  // One cancellation notification per installed dummy interceptor.
+  if (GetParam().use_interceptors) {
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Client cancels request stream after sending two messages
+TEST_P(End2endTest, ClientCancelsRequestStream) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  request.set_message("hello");
+
+  auto stream = stub_->RequestStream(&context, &response);
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->Write(request));
+
+  context.TryCancel();
+
+  Status s = stream->Finish();
+  EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+
+  // The server never completed, so no response payload was produced.
+  EXPECT_EQ(response.message(), "");
+  // One cancellation notification per installed dummy interceptor.
+  if (GetParam().use_interceptors) {
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Client cancels server stream after sending some messages
+TEST_P(End2endTest, ClientCancelsResponseStream) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  request.set_message("hello");
+
+  auto stream = stub_->ResponseStream(&context, request);
+
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(response.message(), request.message() + "0");
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(response.message(), request.message() + "1");
+
+  context.TryCancel();
+
+  // The cancellation races with responses, so there might be zero or
+  // one responses pending, read till failure
+
+  if (stream->Read(&response)) {
+    EXPECT_EQ(response.message(), request.message() + "2");
+    // Since we have cancelled, we expect the next attempt to read to fail
+    EXPECT_FALSE(stream->Read(&response));
+  }
+
+  Status s = stream->Finish();
+  // The final status could be either of CANCELLED or OK depending on
+  // who won the race.
+  // (EXPECT_GE accepts any code numerically <= CANCELLED, i.e. OK (0) or
+  // CANCELLED (1).)
+  EXPECT_GE(grpc::StatusCode::CANCELLED, s.error_code());
+  // One cancellation notification per installed dummy interceptor.
+  if (GetParam().use_interceptors) {
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Client cancels bidi stream after sending some messages
+TEST_P(End2endTest, ClientCancelsBidi) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  TString msg("hello");
+
+  auto stream = stub_->BidiStream(&context);
+
+  request.set_message(msg + "0");
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(response.message(), request.message());
+
+  request.set_message(msg + "1");
+  EXPECT_TRUE(stream->Write(request));
+
+  context.TryCancel();
+
+  // The cancellation races with responses, so there might be zero or
+  // one responses pending, read till failure
+
+  if (stream->Read(&response)) {
+    EXPECT_EQ(response.message(), request.message());
+    // Since we have cancelled, we expect the next attempt to read to fail
+    EXPECT_FALSE(stream->Read(&response));
+  }
+
+  Status s = stream->Finish();
+  EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
+  // One cancellation notification per installed dummy interceptor.
+  if (GetParam().use_interceptors) {
+    EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
+  }
+}
+
+// Sends a request twice the server's configured max message size; the RPC
+// must fail.
+TEST_P(End2endTest, RpcMaxMessageSize) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message(string(kMaxMessageSize_ * 2, 'a'));
+  // NOTE(review): presumably server_die would abort the server if the
+  // oversized message actually reached the handler — confirm against
+  // TestServiceImpl.
+  request.mutable_param()->set_server_die(true);
+
+  ClientContext context;
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_FALSE(s.ok());
+}
+
+// Signals ev once running, then drains the stream until Read() fails,
+// logging each message received.
+void ReaderThreadFunc(ClientReaderWriter<EchoRequest, EchoResponse>* stream,
+                      gpr_event* ev) {
+  EchoResponse resp;
+  // Let the spawning thread know the reader has started before blocking.
+  gpr_event_set(ev, (void*)1);
+  for (;;) {
+    if (!stream->Read(&resp)) {
+      break;
+    }
+    gpr_log(GPR_INFO, "Read message");
+  }
+}
+
+// Run a Read and a WritesDone simultaneously.
+TEST_P(End2endTest, SimultaneousReadWritesDone) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  ClientContext context;
+  gpr_event ev;
+  gpr_event_init(&ev);
+  auto stream = stub_->BidiStream(&context);
+  // The reader thread sets `ev` before its first Read, so waiting on the
+  // event guarantees the Read is concurrent with the WritesDone below.
+  std::thread reader_thread(ReaderThreadFunc, stream.get(), &ev);
+  gpr_event_wait(&ev, gpr_inf_future(GPR_CLOCK_REALTIME));
+  stream->WritesDone();
+  reader_thread.join();
+  Status s = stream->Finish();
+  EXPECT_TRUE(s.ok());
+}
+
+TEST_P(End2endTest, ChannelState) {
+  MAYBE_SKIP_TEST;
+  // Skipped for inproc — presumably connectivity-state semantics do not
+  // apply to the in-process transport.
+  if (GetParam().inproc) {
+    return;
+  }
+
+  ResetStub();
+  // Start IDLE
+  EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(false));
+
+  // Did not ask to connect, no state change.
+  CompletionQueue cq;
+  std::chrono::system_clock::time_point deadline =
+      std::chrono::system_clock::now() + std::chrono::milliseconds(10);
+  channel_->NotifyOnStateChange(GRPC_CHANNEL_IDLE, deadline, &cq, nullptr);
+  void* tag;
+  bool ok = true;
+  cq.Next(&tag, &ok);
+  // ok == false here means the notification timed out, i.e. the state
+  // never left IDLE.
+  EXPECT_FALSE(ok);
+
+  // GetState(true) requests a connection attempt, so a transition away from
+  // IDLE must follow.
+  EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(true));
+  EXPECT_TRUE(channel_->WaitForStateChange(GRPC_CHANNEL_IDLE,
+                                           gpr_inf_future(GPR_CLOCK_REALTIME)));
+  auto state = channel_->GetState(false);
+  EXPECT_TRUE(state == GRPC_CHANNEL_CONNECTING || state == GRPC_CHANNEL_READY);
+}
+
+// Takes 10s.
+TEST_P(End2endTest, ChannelStateTimeout) {
+  // Only meaningful for a real insecure TCP channel; secure and inproc
+  // scenarios are excluded.
+  if ((GetParam().credentials_type != kInsecureCredentialsType) ||
+      GetParam().inproc) {
+    return;
+  }
+  int port = grpc_pick_unused_port_or_die();
+  std::ostringstream server_address;
+  server_address << "127.0.0.1:" << port;
+  // Channel to non-existing server
+  auto channel =
+      grpc::CreateChannel(server_address.str(), InsecureChannelCredentials());
+  // Start IDLE
+  EXPECT_EQ(GRPC_CHANNEL_IDLE, channel->GetState(true));
+
+  // Poll state for ~10 x 1s. There is no server, so no terminal state is
+  // asserted — the test only checks that the timed waits keep returning.
+  auto state = GRPC_CHANNEL_IDLE;
+  for (int i = 0; i < 10; i++) {
+    channel->WaitForStateChange(
+        state, std::chrono::system_clock::now() + std::chrono::seconds(1));
+    state = channel->GetState(false);
+  }
+}
+
+// Talking to a service the server does not implement must produce
+// UNIMPLEMENTED with an empty error message.
+TEST_P(End2endTest, NonExistingService) {
+  MAYBE_SKIP_TEST;
+  ResetChannel();
+  auto unimplemented_stub =
+      grpc::testing::UnimplementedEchoService::NewStub(channel_);
+
+  EchoRequest request;
+  request.set_message("Hello");
+  EchoResponse response;
+
+  ClientContext context;
+  const Status status =
+      unimplemented_stub->Unimplemented(&context, request, &response);
+  EXPECT_EQ(StatusCode::UNIMPLEMENTED, status.error_code());
+  EXPECT_EQ("", status.error_message());
+}
+
+// Ask the server to send back a serialized proto in trailer.
+// This is an example of setting error details.
+TEST_P(End2endTest, BinaryTrailerTest) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+
+  request.mutable_param()->set_echo_metadata(true);
+  DebugInfo* info = request.mutable_param()->mutable_debug_info();
+  info->add_stack_entries("stack_entry_1");
+  info->add_stack_entries("stack_entry_2");
+  info->add_stack_entries("stack_entry_3");
+  info->set_detail("detailed debug info");
+  TString expected_string = info->SerializeAsString();
+  request.set_message("Hello");
+
+  Status s = stub_->Echo(&context, request, &response);
+  // NOTE(review): the RPC is expected to fail — presumably the server fails
+  // requests carrying debug_info; confirm against the echo service impl.
+  EXPECT_FALSE(s.ok());
+  auto trailers = context.GetServerTrailingMetadata();
+  // Exactly one debug-info trailer, byte-identical to what we serialized.
+  EXPECT_EQ(1u, trailers.count(kDebugInfoTrailerKey));
+  auto iter = trailers.find(kDebugInfoTrailerKey);
+  EXPECT_EQ(expected_string, iter->second);
+  // Parse the returned trailer into a DebugInfo proto.
+  DebugInfo returned_info;
+  EXPECT_TRUE(returned_info.ParseFromString(ToString(iter->second)));
+}
+
+TEST_P(End2endTest, ExpectErrorTest) {
+ MAYBE_SKIP_TEST;
+ ResetStub();
+
+ std::vector<ErrorStatus> expected_status;
+ expected_status.emplace_back();
+ expected_status.back().set_code(13); // INTERNAL
+ // No Error message or details
+
+ expected_status.emplace_back();
+ expected_status.back().set_code(13); // INTERNAL
+ expected_status.back().set_error_message("text error message");
+ expected_status.back().set_binary_error_details("text error details");
+
+ expected_status.emplace_back();
+ expected_status.back().set_code(13); // INTERNAL
+ expected_status.back().set_error_message("text error message");
+ expected_status.back().set_binary_error_details(
+ "\x0\x1\x2\x3\x4\x5\x6\x8\x9\xA\xB");
+
+ for (auto iter = expected_status.begin(); iter != expected_status.end();
+ ++iter) {
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+ request.set_message("Hello");
+ auto* error = request.mutable_param()->mutable_expected_error();
+ error->set_code(iter->code());
+ error->set_error_message(iter->error_message());
+ error->set_binary_error_details(iter->binary_error_details());
+
+ Status s = stub_->Echo(&context, request, &response);
+ EXPECT_FALSE(s.ok());
+ EXPECT_EQ(iter->code(), s.error_code());
+ EXPECT_EQ(iter->error_message(), s.error_message());
+ EXPECT_EQ(iter->binary_error_details(), s.error_details());
+ EXPECT_TRUE(context.debug_error_string().find("created") !=
+ TString::npos);
+ EXPECT_TRUE(context.debug_error_string().find("file") != TString::npos);
+ EXPECT_TRUE(context.debug_error_string().find("line") != TString::npos);
+ EXPECT_TRUE(context.debug_error_string().find("status") !=
+ TString::npos);
+ EXPECT_TRUE(context.debug_error_string().find("13") != TString::npos);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Test with and without a proxy.
+// Adds no members of its own; it exists as a distinct fixture so it can be
+// instantiated with proxy-enabled scenarios (see CreateTestScenarios).
+class ProxyEnd2endTest : public End2endTest {
+ protected:
+};
+
+// Single unary echo RPC through the (possibly proxied) channel.
+TEST_P(ProxyEnd2endTest, SimpleRpc) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  SendRpc(stub_.get(), 1, false);
+}
+
+// An echo RPC with default-constructed (empty) messages must still succeed.
+TEST_P(ProxyEnd2endTest, SimpleRpcWithEmptyMessages) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest empty_request;
+  EchoResponse empty_response;
+
+  ClientContext context;
+  const Status status = stub_->Echo(&context, empty_request, &empty_response);
+  EXPECT_TRUE(status.ok());
+}
+
+// Ten client threads, each issuing ten echo RPCs concurrently.
+TEST_P(ProxyEnd2endTest, MultipleRpcs) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  constexpr int kNumThreads = 10;
+  std::vector<std::thread> workers;
+  workers.reserve(kNumThreads);
+  for (int i = 0; i < kNumThreads; ++i) {
+    workers.emplace_back(SendRpc, stub_.get(), 10, false);
+  }
+  for (auto& worker : workers) {
+    worker.join();
+  }
+}
+
+// Set a very short (1ms) deadline and make sure the proper
+// DEADLINE_EXCEEDED error is returned.
+TEST_P(ProxyEnd2endTest, RpcDeadlineExpires) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello");
+  request.mutable_param()->set_skip_cancelled_check(true);
+  // Let server sleep for 40 ms first to guarantee expiry.
+  // 40 ms might seem a bit extreme but the timer manager would have been just
+  // initialized (when ResetStub() was called) and there are some warmup costs
+  // i.e the timer thread many not have even started. There might also be other
+  // delays in the timer manager thread (in acquiring locks, timer data
+  // structure manipulations, starting backup timer threads) that add to the
+  // delays. 40ms is still not enough in some cases but this significantly
+  // reduces the test flakes
+  request.mutable_param()->set_server_sleep_us(40 * 1000);
+
+  ClientContext context;
+  std::chrono::system_clock::time_point deadline =
+      std::chrono::system_clock::now() + std::chrono::milliseconds(1);
+  context.set_deadline(deadline);
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(StatusCode::DEADLINE_EXCEEDED, s.error_code());
+}
+
+// Set a long but finite deadline; it must not interfere with a fast RPC.
+TEST_P(ProxyEnd2endTest, RpcLongDeadline) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  request.set_message("Hello");
+  EchoResponse response;
+
+  ClientContext context;
+  context.set_deadline(std::chrono::system_clock::now() +
+                       std::chrono::hours(1));
+  const Status status = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(status.ok());
+}
+
+// Ask server to echo back the deadline it sees.
+TEST_P(ProxyEnd2endTest, EchoDeadline) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_deadline(true);
+
+  ClientContext context;
+  std::chrono::system_clock::time_point deadline =
+      std::chrono::system_clock::now() + std::chrono::seconds(100);
+  context.set_deadline(deadline);
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+  // Compare the server-observed deadline with the one we sent, allowing for
+  // resolution loss in transit.
+  gpr_timespec sent_deadline;
+  Timepoint2Timespec(deadline, &sent_deadline);
+  // We want to allow some reasonable error given:
+  // - request_deadline() only has 1sec resolution so the best we can do is +-1
+  // - if sent_deadline.tv_nsec is very close to the next second's boundary we
+  // can end up being off by 2 in one direction.
+  EXPECT_LE(response.param().request_deadline() - sent_deadline.tv_sec, 2);
+  EXPECT_GE(response.param().request_deadline() - sent_deadline.tv_sec, -1);
+}
+
+// Ask server to echo back the deadline it sees. With no deadline set on the
+// rpc, the server should observe an infinite-future deadline.
+TEST_P(ProxyEnd2endTest, EchoDeadlineForNoDeadlineRpc) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_deadline(true);
+  EchoResponse response;
+
+  ClientContext context;
+  const Status status = stub_->Echo(&context, request, &response);
+  EXPECT_TRUE(status.ok());
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_EQ(response.param().request_deadline(),
+            gpr_inf_future(GPR_CLOCK_REALTIME).tv_sec);
+}
+
+// Calling an unimplemented method on a live service must yield UNIMPLEMENTED
+// with no message payload.
+TEST_P(ProxyEnd2endTest, UnimplementedRpc) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  request.set_message("Hello");
+  EchoResponse response;
+
+  ClientContext context;
+  const Status status = stub_->Unimplemented(&context, request, &response);
+  EXPECT_FALSE(status.ok());
+  EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED);
+  EXPECT_EQ(status.error_message(), "");
+  EXPECT_EQ(response.message(), "");
+}
+
+// Client cancels rpc after 10ms
+TEST_P(ProxyEnd2endTest, ClientCancelsRpc) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello");
+  const int kCancelDelayUs = 10 * 1000;
+  request.mutable_param()->set_client_cancel_after_us(kCancelDelayUs);
+
+  ClientContext context;
+  // The cancel runs on a helper thread against whichever service variant
+  // (sync or callback) the scenario started.
+  std::thread cancel_thread;
+  if (!GetParam().callback_server) {
+    cancel_thread = std::thread(
+        [&context, this](int delay) { CancelRpc(&context, delay, &service_); },
+        kCancelDelayUs);
+    // Note: the unusual pattern above (and below) is caused by a conflict
+    // between two sets of compiler expectations. clang allows const to be
+    // captured without mention, so there is no need to capture kCancelDelayUs
+    // (and indeed clang-tidy complains if you do so). OTOH, a Windows compiler
+    // in our tests requires an explicit capture even for const. We square this
+    // circle by passing the const value in as an argument to the lambda.
+  } else {
+    cancel_thread = std::thread(
+        [&context, this](int delay) {
+          CancelRpc(&context, delay, &callback_service_);
+        },
+        kCancelDelayUs);
+  }
+  Status s = stub_->Echo(&context, request, &response);
+  cancel_thread.join();
+  EXPECT_EQ(StatusCode::CANCELLED, s.error_code());
+  EXPECT_EQ(s.error_message(), "Cancelled");
+}
+
+// Server cancels rpc after 1ms
+TEST_P(ProxyEnd2endTest, ServerCancelsRpc) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  constexpr int kServerCancelAfterUs = 1000;
+  EchoRequest request;
+  request.set_message("Hello");
+  request.mutable_param()->set_server_cancel_after_us(kServerCancelAfterUs);
+  EchoResponse response;
+
+  ClientContext context;
+  const Status status = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(StatusCode::CANCELLED, status.error_code());
+  // A server-initiated cancel carries no error message.
+  EXPECT_TRUE(status.error_message().empty());
+}
+
+// Make the response larger than the flow control window.
+TEST_P(ProxyEnd2endTest, HugeResponse) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("huge response");
+  // ~1GiB+10MiB response, large enough to force multiple flow-control
+  // window refills.
+  const size_t kResponseSize = 1024 * (1024 + 10);
+  request.mutable_param()->set_response_message_length(kResponseSize);
+
+  ClientContext context;
+  // Generous deadline: transferring the huge payload can legitimately
+  // take a while on slow test machines.
+  std::chrono::system_clock::time_point deadline =
+      std::chrono::system_clock::now() + std::chrono::seconds(20);
+  context.set_deadline(deadline);
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(kResponseSize, response.message().size());
+  EXPECT_TRUE(s.ok());
+}
+
+// Both the server-observed peer (echoed back in the response) and the
+// client-observed peer must resolve to localhost.
+TEST_P(ProxyEnd2endTest, Peer) {
+  MAYBE_SKIP_TEST;
+  // Peer is not meaningful for inproc
+  if (GetParam().inproc) {
+    return;
+  }
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("hello");
+  request.mutable_param()->set_echo_peer(true);
+
+  ClientContext context;
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+  EXPECT_TRUE(CheckIsLocalhost(response.param().peer()));
+  EXPECT_TRUE(CheckIsLocalhost(context.peer()));
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Fixture for scenarios that must use secure credentials and no proxy;
+// the constructor asserts those preconditions on the scenario parameters.
+class SecureEnd2endTest : public End2endTest {
+ protected:
+  SecureEnd2endTest() {
+    GPR_ASSERT(!GetParam().use_proxy);
+    GPR_ASSERT(GetParam().credentials_type != kInsecureCredentialsType);
+  }
+};
+
+TEST_P(SecureEnd2endTest, SimpleRpcWithHost) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello");
+
+  ClientContext context;
+  // Override the :authority for this call.
+  context.set_authority("foo.test.youtube.com");
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(response.has_param());
+  // NOTE(review): "special" is presumably what the test server reports when
+  // it sees this authority — confirm against the server fixture setup.
+  EXPECT_EQ("special", response.param().host());
+  EXPECT_TRUE(s.ok());
+}
+
+// Returns true iff the (key, value) pair occurs exactly once in `metadata`.
+bool MetadataContains(
+    const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+    const TString& key, const TString& value) {
+  int matches = 0;
+  for (const auto& entry : metadata) {
+    if (ToString(entry.first) == key && ToString(entry.second) == value) {
+      ++matches;
+    }
+  }
+  return matches == 1;
+}
+
+// Blocking (synchronous) metadata plugin plus a server-side auth metadata
+// processor: happy path with compatible credentials.
+TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginAndProcessorSuccess) {
+  MAYBE_SKIP_TEST;
+  auto* processor = new TestAuthMetadataProcessor(true);
+  StartServer(std::shared_ptr<AuthMetadataProcessor>(processor));
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_credentials(processor->GetCompatibleClientCreds());
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_metadata(true);
+  request.mutable_param()->set_expected_client_identity(
+      TestAuthMetadataProcessor::kGoodGuy);
+  request.mutable_param()->set_expected_transport_security_type(
+      GetParam().credentials_type);
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(request.message(), response.message());
+  EXPECT_TRUE(s.ok());
+
+  // Metadata should have been consumed by the processor.
+  EXPECT_FALSE(MetadataContains(
+      context.GetServerTrailingMetadata(), GRPC_AUTHORIZATION_METADATA_KEY,
+      TString("Bearer ") + TestAuthMetadataProcessor::kGoodGuy));
+}
+
+// Blocking plugin + processor: incompatible credentials must be rejected
+// by the server-side processor with UNAUTHENTICATED.
+TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginAndProcessorFailure) {
+  MAYBE_SKIP_TEST;
+  auto* processor = new TestAuthMetadataProcessor(true);
+  StartServer(std::shared_ptr<AuthMetadataProcessor>(processor));
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_credentials(processor->GetIncompatibleClientCreds());
+  request.set_message("Hello");
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_code(), StatusCode::UNAUTHENTICATED);
+}
+
+// Per-call IAM credentials must surface as metadata echoed back by the
+// server in the trailing metadata.
+TEST_P(SecureEnd2endTest, SetPerCallCredentials) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  std::shared_ptr<CallCredentials> creds =
+      GoogleIAMCredentials(kFakeToken, kFakeSelector);
+  context.set_credentials(creds);
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_metadata(true);
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(request.message(), response.message());
+  EXPECT_TRUE(s.ok());
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY,
+                               kFakeToken));
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY,
+                               kFakeSelector));
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedFakeCredsDebugString);
+}
+
+// Client interceptor that injects fake IAM call credentials into the client
+// context just before initial metadata is sent.
+class CredentialsInterceptor : public experimental::Interceptor {
+ public:
+  explicit CredentialsInterceptor(experimental::ClientRpcInfo* info)
+      : info_(info) {}
+
+  void Intercept(experimental::InterceptorBatchMethods* methods) override {
+    if (methods->QueryInterceptionHookPoint(
+            experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+      std::shared_ptr<CallCredentials> creds =
+          GoogleIAMCredentials(kFakeToken, kFakeSelector);
+      info_->client_context()->set_credentials(creds);
+    }
+    methods->Proceed();
+  }
+
+ private:
+  experimental::ClientRpcInfo* info_ = nullptr;
+};
+
+// Factory producing a CredentialsInterceptor per client RPC. The caller
+// (interception framework) takes ownership of the returned pointer.
+class CredentialsInterceptorFactory
+    : public experimental::ClientInterceptorFactoryInterface {
+  CredentialsInterceptor* CreateClientInterceptor(
+      experimental::ClientRpcInfo* info) override {
+    return new CredentialsInterceptor(info);
+  }
+};
+
+// An interceptor can set call credentials on a context that had none; the
+// injected fake IAM metadata must reach the server.
+TEST_P(SecureEnd2endTest, CallCredentialsInterception) {
+  MAYBE_SKIP_TEST;
+  if (!GetParam().use_interceptors) {
+    return;
+  }
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      interceptor_creators;
+  interceptor_creators.push_back(std::unique_ptr<CredentialsInterceptorFactory>(
+      new CredentialsInterceptorFactory()));
+  ResetStub(std::move(interceptor_creators));
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_metadata(true);
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(request.message(), response.message());
+  EXPECT_TRUE(s.ok());
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY,
+                               kFakeToken));
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY,
+                               kFakeSelector));
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedFakeCredsDebugString);
+}
+
+// An interceptor can replace call credentials the application already set;
+// the server must see the interceptor's (fake) creds, not the wrong ones.
+TEST_P(SecureEnd2endTest, CallCredentialsInterceptionWithSetCredentials) {
+  MAYBE_SKIP_TEST;
+  if (!GetParam().use_interceptors) {
+    return;
+  }
+  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+      interceptor_creators;
+  interceptor_creators.push_back(std::unique_ptr<CredentialsInterceptorFactory>(
+      new CredentialsInterceptorFactory()));
+  ResetStub(std::move(interceptor_creators));
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  // Application-level credentials that the interceptor should override.
+  std::shared_ptr<CallCredentials> creds1 =
+      GoogleIAMCredentials(kWrongToken, kWrongSelector);
+  context.set_credentials(creds1);
+  EXPECT_EQ(context.credentials(), creds1);
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedWrongCredsDebugString);
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_metadata(true);
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(request.message(), response.message());
+  EXPECT_TRUE(s.ok());
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY,
+                               kFakeToken));
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY,
+                               kFakeSelector));
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedFakeCredsDebugString);
+}
+
+// Setting call credentials twice: the second set_credentials replaces the
+// first, so only creds2 metadata must reach the server.
+TEST_P(SecureEnd2endTest, OverridePerCallCredentials) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  std::shared_ptr<CallCredentials> creds1 =
+      GoogleIAMCredentials(kFakeToken1, kFakeSelector1);
+  context.set_credentials(creds1);
+  EXPECT_EQ(context.credentials(), creds1);
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedFakeCreds1DebugString);
+  std::shared_ptr<CallCredentials> creds2 =
+      GoogleIAMCredentials(kFakeToken2, kFakeSelector2);
+  context.set_credentials(creds2);
+  EXPECT_EQ(context.credentials(), creds2);
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_metadata(true);
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY,
+                               kFakeToken2));
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY,
+                               kFakeSelector2));
+  // The replaced credentials must leave no trace in the metadata.
+  EXPECT_FALSE(MetadataContains(context.GetServerTrailingMetadata(),
+                                GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY,
+                                kFakeToken1));
+  EXPECT_FALSE(MetadataContains(context.GetServerTrailingMetadata(),
+                                GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY,
+                                kFakeSelector1));
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedFakeCreds2DebugString);
+  EXPECT_EQ(request.message(), response.message());
+  EXPECT_TRUE(s.ok());
+}
+
+// A metadata plugin that produces an invalid metadata key must fail the
+// call with UNAVAILABLE.
+TEST_P(SecureEnd2endTest, AuthMetadataPluginKeyFailure) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_credentials(grpc::MetadataCredentialsFromPlugin(
+      std::unique_ptr<MetadataCredentialsPlugin>(
+          new TestMetadataCredentialsPlugin(
+              TestMetadataCredentialsPlugin::kBadMetadataKey,
+              "Does not matter, will fail the key is invalid.", false, true,
+              0))));
+  request.set_message("Hello");
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE);
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedAuthMetadataPluginKeyFailureCredsDebugString);
+}
+
+// A metadata plugin that produces an illegal metadata value (embedded
+// newline) must fail the call with UNAVAILABLE.
+TEST_P(SecureEnd2endTest, AuthMetadataPluginValueFailure) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_credentials(grpc::MetadataCredentialsFromPlugin(
+      std::unique_ptr<MetadataCredentialsPlugin>(
+          new TestMetadataCredentialsPlugin(
+              TestMetadataCredentialsPlugin::kGoodMetadataKey,
+              "With illegal \n value.", false, true, 0))));
+  request.set_message("Hello");
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE);
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedAuthMetadataPluginValueFailureCredsDebugString);
+}
+
+// A slow metadata plugin racing against the call deadline: either the call
+// succeeds in time, or it fails with DEADLINE_EXCEEDED/UNAVAILABLE.
+TEST_P(SecureEnd2endTest, AuthMetadataPluginWithDeadline) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  request.mutable_param()->set_skip_cancelled_check(true);
+  EchoResponse response;
+  ClientContext context;
+  // Plugin delay equals the call deadline, making the race likely.
+  const int delay = 100;
+  std::chrono::system_clock::time_point deadline =
+      std::chrono::system_clock::now() + std::chrono::milliseconds(delay);
+  context.set_deadline(deadline);
+  context.set_credentials(grpc::MetadataCredentialsFromPlugin(
+      std::unique_ptr<MetadataCredentialsPlugin>(
+          new TestMetadataCredentialsPlugin("meta_key", "Does not matter", true,
+                                            true, delay))));
+  request.set_message("Hello");
+
+  Status s = stub_->Echo(&context, request, &response);
+  if (!s.ok()) {
+    EXPECT_TRUE(s.error_code() == StatusCode::DEADLINE_EXCEEDED ||
+                s.error_code() == StatusCode::UNAVAILABLE);
+  }
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedAuthMetadataPluginWithDeadlineCredsDebugString);
+}
+
+// A slow metadata plugin racing against a client TryCancel: either the call
+// succeeds first, or it fails with CANCELLED/UNAVAILABLE.
+TEST_P(SecureEnd2endTest, AuthMetadataPluginWithCancel) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  request.mutable_param()->set_skip_cancelled_check(true);
+  EchoResponse response;
+  ClientContext context;
+  const int delay = 100;
+  context.set_credentials(grpc::MetadataCredentialsFromPlugin(
+      std::unique_ptr<MetadataCredentialsPlugin>(
+          new TestMetadataCredentialsPlugin("meta_key", "Does not matter", true,
+                                            true, delay))));
+  request.set_message("Hello");
+
+  // Cancel from another thread after roughly the plugin's delay.
+  std::thread cancel_thread([&] {
+    gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+                                 gpr_time_from_millis(delay, GPR_TIMESPAN)));
+    context.TryCancel();
+  });
+  Status s = stub_->Echo(&context, request, &response);
+  if (!s.ok()) {
+    EXPECT_TRUE(s.error_code() == StatusCode::CANCELLED ||
+                s.error_code() == StatusCode::UNAVAILABLE);
+  }
+  cancel_thread.join();
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedAuthMetadataPluginWithDeadlineCredsDebugString);
+}
+
+// A non-blocking (async) metadata plugin that reports failure must fail the
+// call with UNAVAILABLE and surface the plugin's error message.
+TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginFailure) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_credentials(grpc::MetadataCredentialsFromPlugin(
+      std::unique_ptr<MetadataCredentialsPlugin>(
+          new TestMetadataCredentialsPlugin(
+              TestMetadataCredentialsPlugin::kGoodMetadataKey,
+              "Does not matter, will fail anyway (see 3rd param)", false, false,
+              0))));
+  request.set_message("Hello");
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE);
+  EXPECT_EQ(s.error_message(),
+            TString("Getting metadata from plugin failed with error: ") +
+                kTestCredsPluginErrorMsg);
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedNonBlockingAuthMetadataPluginFailureCredsDebugString);
+}
+
+// Non-blocking (async) metadata plugin plus a server-side auth processor:
+// happy path with compatible credentials.
+TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginAndProcessorSuccess) {
+  MAYBE_SKIP_TEST;
+  auto* processor = new TestAuthMetadataProcessor(false);
+  StartServer(std::shared_ptr<AuthMetadataProcessor>(processor));
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_credentials(processor->GetCompatibleClientCreds());
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_metadata(true);
+  request.mutable_param()->set_expected_client_identity(
+      TestAuthMetadataProcessor::kGoodGuy);
+  request.mutable_param()->set_expected_transport_security_type(
+      GetParam().credentials_type);
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(request.message(), response.message());
+  EXPECT_TRUE(s.ok());
+
+  // Metadata should have been consumed by the processor.
+  EXPECT_FALSE(MetadataContains(
+      context.GetServerTrailingMetadata(), GRPC_AUTHORIZATION_METADATA_KEY,
+      TString("Bearer ") + TestAuthMetadataProcessor::kGoodGuy));
+  EXPECT_EQ(
+      context.credentials()->DebugString(),
+      kExpectedNonBlockingAuthMetadataPluginAndProcessorSuccessCredsDebugString);
+}
+
+// Non-blocking plugin + processor: incompatible credentials must be
+// rejected by the server-side processor with UNAUTHENTICATED.
+TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginAndProcessorFailure) {
+  MAYBE_SKIP_TEST;
+  auto* processor = new TestAuthMetadataProcessor(false);
+  StartServer(std::shared_ptr<AuthMetadataProcessor>(processor));
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_credentials(processor->GetIncompatibleClientCreds());
+  request.set_message("Hello");
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_code(), StatusCode::UNAUTHENTICATED);
+  EXPECT_EQ(
+      context.credentials()->DebugString(),
+      kExpectedNonBlockingAuthMetadataPluginAndProcessorFailureCredsDebugString);
+}
+
+// A blocking (synchronous) metadata plugin that reports failure must fail
+// the call with UNAVAILABLE and surface the plugin's error message.
+TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginFailure) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_credentials(grpc::MetadataCredentialsFromPlugin(
+      std::unique_ptr<MetadataCredentialsPlugin>(
+          new TestMetadataCredentialsPlugin(
+              TestMetadataCredentialsPlugin::kGoodMetadataKey,
+              "Does not matter, will fail anyway (see 3rd param)", true, false,
+              0))));
+  request.set_message("Hello");
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE);
+  EXPECT_EQ(s.error_message(),
+            TString("Getting metadata from plugin failed with error: ") +
+                kTestCredsPluginErrorMsg);
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedBlockingAuthMetadataPluginFailureCredsDebugString);
+}
+
+// Two plugin-backed call credentials composed together: metadata from both
+// plugins must reach the server.
+TEST_P(SecureEnd2endTest, CompositeCallCreds) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  const char kMetadataKey1[] = "call-creds-key1";
+  const char kMetadataKey2[] = "call-creds-key2";
+  const char kMetadataVal1[] = "call-creds-val1";
+  const char kMetadataVal2[] = "call-creds-val2";
+
+  context.set_credentials(grpc::CompositeCallCredentials(
+      grpc::MetadataCredentialsFromPlugin(
+          std::unique_ptr<MetadataCredentialsPlugin>(
+              new TestMetadataCredentialsPlugin(kMetadataKey1, kMetadataVal1,
+                                                true, true, 0))),
+      grpc::MetadataCredentialsFromPlugin(
+          std::unique_ptr<MetadataCredentialsPlugin>(
+              new TestMetadataCredentialsPlugin(kMetadataKey2, kMetadataVal2,
+                                                true, true, 0)))));
+  request.set_message("Hello");
+  request.mutable_param()->set_echo_metadata(true);
+
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_TRUE(s.ok());
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               kMetadataKey1, kMetadataVal1));
+  EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(),
+                               kMetadataKey2, kMetadataVal2));
+  EXPECT_EQ(context.credentials()->DebugString(),
+            kExpectedCompositeCallCredsDebugString);
+}
+
+// Verifies the client-side AuthContext: transport security type always, and
+// peer identity properties when running over TLS.
+TEST_P(SecureEnd2endTest, ClientAuthContext) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello");
+  request.mutable_param()->set_check_auth_context(GetParam().credentials_type ==
+                                                  kTlsCredentialsType);
+  request.mutable_param()->set_expected_transport_security_type(
+      GetParam().credentials_type);
+  ClientContext context;
+  Status s = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(s.ok());
+
+  std::shared_ptr<const AuthContext> auth_ctx = context.auth_context();
+  std::vector<grpc::string_ref> tst =
+      auth_ctx->FindPropertyValues("transport_security_type");
+  ASSERT_EQ(1u, tst.size());
+  EXPECT_EQ(GetParam().credentials_type, ToString(tst[0]));
+  if (GetParam().credentials_type == kTlsCredentialsType) {
+    // NOTE(review): the four SANs below presumably come from the test server
+    // certificate used by the fixture — confirm against the test credentials.
+    EXPECT_EQ("x509_subject_alternative_name",
+              auth_ctx->GetPeerIdentityPropertyName());
+    EXPECT_EQ(4u, auth_ctx->GetPeerIdentity().size());
+    EXPECT_EQ("*.test.google.fr", ToString(auth_ctx->GetPeerIdentity()[0]));
+    EXPECT_EQ("waterzooi.test.google.be",
+              ToString(auth_ctx->GetPeerIdentity()[1]));
+    EXPECT_EQ("*.test.youtube.com", ToString(auth_ctx->GetPeerIdentity()[2]));
+    EXPECT_EQ("192.168.1.3", ToString(auth_ctx->GetPeerIdentity()[3]));
+  }
+}
+
+// Fixture that installs a named ResourceQuota on the server builder so the
+// end2end tests run against a quota-constrained server.
+class ResourceQuotaEnd2endTest : public End2endTest {
+ public:
+  ResourceQuotaEnd2endTest()
+      : server_resource_quota_("server_resource_quota") {}
+
+  // `virtual` is redundant with `override` (modernize-use-override).
+  void ConfigureServerBuilder(ServerBuilder* builder) override {
+    builder->SetResourceQuota(server_resource_quota_);
+  }
+
+ private:
+  ResourceQuota server_resource_quota_;
+};
+
+// A plain echo RPC succeeds on a server configured with a resource quota.
+TEST_P(ResourceQuotaEnd2endTest, SimpleRequest) {
+  MAYBE_SKIP_TEST;
+  ResetStub();
+
+  EchoRequest request;
+  request.set_message("Hello");
+  EchoResponse response;
+
+  ClientContext context;
+  const Status status = stub_->Echo(&context, request, &response);
+  EXPECT_EQ(response.message(), request.message());
+  EXPECT_TRUE(status.ok());
+}
+
+// TODO(vjpai): refactor arguments into a struct if it makes sense
+// Builds the TestScenario parameter list consumed by the
+// INSTANTIATE_TEST_SUITE_P calls below; each boolean argument enables one
+// axis of scenarios (proxy, insecure creds, secure creds, inproc transport,
+// callback server).
+std::vector<TestScenario> CreateTestScenarios(bool use_proxy,
+                                              bool test_insecure,
+                                              bool test_secure,
+                                              bool test_inproc,
+                                              bool test_callback_server) {
+  std::vector<TestScenario> scenarios;
+  std::vector<TString> credentials_types;
+
+  GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms,
+                        kClientChannelBackupPollIntervalMs);
+#if TARGET_OS_IPHONE
+  // Workaround Apple CFStream bug
+  gpr_setenv("grpc_cfstream", "0");
+#endif
+
+  if (test_secure) {
+    credentials_types =
+        GetCredentialsProvider()->GetSecureCredentialsTypeList();
+  }
+  auto insec_ok = [] {
+    // Only allow insecure credentials type when it is registered with the
+    // provider. User may create providers that do not have insecure.
+    return GetCredentialsProvider()->GetChannelCredentials(
+               kInsecureCredentialsType, nullptr) != nullptr;
+  };
+  if (test_insecure && insec_ok()) {
+    credentials_types.push_back(kInsecureCredentialsType);
+  }
+
+  // Test callback with inproc or if the event-engine allows it
+  GPR_ASSERT(!credentials_types.empty());
+  for (const auto& cred : credentials_types) {
+    // NOTE(review): the positional flags presumably map to (use_interceptors,
+    // use_proxy, inproc, credentials_type, callback_server) — confirm against
+    // the TestScenario constructor.
+    scenarios.emplace_back(false, false, false, cred, false);
+    scenarios.emplace_back(true, false, false, cred, false);
+    if (test_callback_server) {
+      // Note that these scenarios will be dynamically disabled if the event
+      // engine doesn't run in the background
+      scenarios.emplace_back(false, false, false, cred, true);
+      scenarios.emplace_back(true, false, false, cred, true);
+    }
+    if (use_proxy) {
+      scenarios.emplace_back(false, true, false, cred, false);
+      scenarios.emplace_back(true, true, false, cred, false);
+    }
+  }
+  if (test_inproc && insec_ok()) {
+    scenarios.emplace_back(false, false, true, kInsecureCredentialsType, false);
+    scenarios.emplace_back(true, false, true, kInsecureCredentialsType, false);
+    if (test_callback_server) {
+      scenarios.emplace_back(false, false, true, kInsecureCredentialsType,
+                             true);
+      scenarios.emplace_back(true, false, true, kInsecureCredentialsType, true);
+    }
+  }
+  return scenarios;
+}
+
+// Instantiations: CreateTestScenarios(use_proxy, test_insecure, test_secure,
+// test_inproc, test_callback_server). Only ProxyEnd2end enables the proxy
+// axis; SecureEnd2end restricts to secure, non-inproc scenarios.
+INSTANTIATE_TEST_SUITE_P(
+    End2end, End2endTest,
+    ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true)));
+
+INSTANTIATE_TEST_SUITE_P(
+    End2endServerTryCancel, End2endServerTryCancelTest,
+    ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true)));
+
+INSTANTIATE_TEST_SUITE_P(
+    ProxyEnd2end, ProxyEnd2endTest,
+    ::testing::ValuesIn(CreateTestScenarios(true, true, true, true, true)));
+
+INSTANTIATE_TEST_SUITE_P(
+    SecureEnd2end, SecureEnd2endTest,
+    ::testing::ValuesIn(CreateTestScenarios(false, false, true, false, true)));
+
+INSTANTIATE_TEST_SUITE_P(
+    ResourceQuotaEnd2end, ResourceQuotaEnd2endTest,
+    ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true)));
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+// Test entry point: set up the gRPC test environment, then run gtest.
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/exception_test.cc b/contrib/libs/grpc/test/cpp/end2end/exception_test.cc
new file mode 100644
index 0000000000..cd29eb8a10
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/exception_test.cc
@@ -0,0 +1,123 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <exception>
+#include <memory>
+
+#include <grpc/impl/codegen/port_platform.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/test_config.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+
// Error text surfaced by ServiceException::what().
const char* kErrorMessage = "This service caused an exception";

#if GRPC_ALLOW_EXCEPTIONS
// Service whose handlers deliberately throw, to verify that the server turns
// uncaught handler exceptions into failed statuses instead of crashing.
class ExceptingServiceImpl : public ::grpc::testing::EchoTestService::Service {
 public:
  // Unary handler: throws a value that is NOT derived from std::exception.
  Status Echo(ServerContext* /*server_context*/, const EchoRequest* /*request*/,
              EchoResponse* /*response*/) override {
    throw - 1;
  }
  // Client-streaming handler: throws a std::exception-derived type.
  Status RequestStream(ServerContext* /*context*/,
                       ServerReader<EchoRequest>* /*reader*/,
                       EchoResponse* /*response*/) override {
    throw ServiceException();
  }

 private:
  class ServiceException final : public std::exception {
   public:
    ServiceException() {}

   private:
    const char* what() const noexcept override { return kErrorMessage; }
  };
};
+
// Fixture: starts an in-process server running ExceptingServiceImpl and
// reaches it over an in-process channel (no listening port needed).
class ExceptionTest : public ::testing::Test {
 protected:
  ExceptionTest() {}

  void SetUp() override {
    ServerBuilder builder;
    builder.RegisterService(&service_);
    server_ = builder.BuildAndStart();
  }

  void TearDown() override { server_->Shutdown(); }

  // Creates a fresh stub over the server's in-process channel.
  void ResetStub() {
    channel_ = server_->InProcessChannel(ChannelArguments());
    stub_ = grpc::testing::EchoTestService::NewStub(channel_);
  }

  std::shared_ptr<Channel> channel_;
  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
  std::unique_ptr<Server> server_;
  ExceptingServiceImpl service_;
};
+
+TEST_F(ExceptionTest, Unary) {
+ ResetStub();
+ EchoRequest request;
+ EchoResponse response;
+ request.set_message("test");
+
+ for (int i = 0; i < 10; i++) {
+ ClientContext context;
+ Status s = stub_->Echo(&context, request, &response);
+ EXPECT_FALSE(s.ok());
+ EXPECT_EQ(s.error_code(), StatusCode::UNKNOWN);
+ }
+}
+
+TEST_F(ExceptionTest, RequestStream) {
+ ResetStub();
+ EchoResponse response;
+
+ for (int i = 0; i < 10; i++) {
+ ClientContext context;
+ auto stream = stub_->RequestStream(&context, &response);
+ stream->WritesDone();
+ Status s = stream->Finish();
+
+ EXPECT_FALSE(s.ok());
+ EXPECT_EQ(s.error_code(), StatusCode::UNKNOWN);
+ }
+}
+
+#endif // GRPC_ALLOW_EXCEPTIONS
+
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/filter_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/filter_end2end_test.cc
new file mode 100644
index 0000000000..2f26d0716c
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/filter_end2end_test.cc
@@ -0,0 +1,346 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <mutex>
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/async_generic_service.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/config.h>
+#include <grpcpp/support/slice.h>
+
+#include "src/cpp/common/channel_filter.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace testing {
+namespace {
+
// Packs a small integer into the opaque void* tag used by the completion
// queue API. Uses a named cast instead of the original C-style cast.
void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
+
+void verify_ok(CompletionQueue* cq, int i, bool expect_ok) {
+ bool ok;
+ void* got_tag;
+ EXPECT_TRUE(cq->Next(&got_tag, &ok));
+ EXPECT_EQ(expect_ok, ok);
+ EXPECT_EQ(tag(i), got_tag);
+}
+
namespace {

// Counters bumped by the test filter below; all access goes through the
// accessor functions, which serialize on global_mu.
int global_num_connections = 0;
int global_num_calls = 0;
std::mutex global_mu;

void IncrementConnectionCounter() {
  std::lock_guard<std::mutex> lock(global_mu);
  ++global_num_connections;
}

void ResetConnectionCounter() {
  std::lock_guard<std::mutex> lock(global_mu);
  global_num_connections = 0;
}

int GetConnectionCounterValue() {
  std::lock_guard<std::mutex> lock(global_mu);
  return global_num_connections;
}

void IncrementCallCounter() {
  std::lock_guard<std::mutex> lock(global_mu);
  ++global_num_calls;
}

void ResetCallCounter() {
  std::lock_guard<std::mutex> lock(global_mu);
  global_num_calls = 0;
}

int GetCallCounterValue() {
  std::lock_guard<std::mutex> lock(global_mu);
  return global_num_calls;
}

}  // namespace
+
// Channel-level half of the test filter: counts one connection per channel
// element initialization.
class ChannelDataImpl : public ChannelData {
 public:
  grpc_error* Init(grpc_channel_element* /*elem*/,
                   grpc_channel_element_args* /*args*/) {
    IncrementConnectionCounter();
    return GRPC_ERROR_NONE;
  }
};
+
// Call-level half of the test filter: counts one call per received initial
// metadata, then forwards the batch down the filter stack.
class CallDataImpl : public CallData {
 public:
  void StartTransportStreamOpBatch(grpc_call_element* elem,
                                   TransportStreamOpBatch* op) override {
    // Incrementing the counter could be done from Init(), but we want
    // to test that the individual methods are actually called correctly.
    if (op->recv_initial_metadata() != nullptr) IncrementCallCounter();
    grpc_call_next_op(elem, op->op());
  }
};
+
// End-to-end fixture exercising the registered "test-filter" over generic
// (untyped) async calls. The tests verify, via the global counters, that
// ChannelDataImpl::Init fires once per connection and
// CallDataImpl::StartTransportStreamOpBatch once per call.
class FilterEnd2endTest : public ::testing::Test {
 protected:
  FilterEnd2endTest() : server_host_("localhost") {}

  static void SetUpTestCase() {
    // Workaround for
    // https://github.com/google/google-toolbox-for-mac/issues/242
    static bool setup_done = false;
    if (!setup_done) {
      setup_done = true;
      // Register the counting filter once for all server channels.
      grpc::RegisterChannelFilter<ChannelDataImpl, CallDataImpl>(
          "test-filter", GRPC_SERVER_CHANNEL, INT_MAX, nullptr);
    }
  }

  void SetUp() override {
    int port = grpc_pick_unused_port_or_die();
    server_address_ << server_host_ << ":" << port;
    // Setup server
    ServerBuilder builder;
    builder.AddListeningPort(server_address_.str(),
                             InsecureServerCredentials());
    builder.RegisterAsyncGenericService(&generic_service_);
    srv_cq_ = builder.AddCompletionQueue();
    server_ = builder.BuildAndStart();
  }

  void TearDown() override {
    server_->Shutdown();
    void* ignored_tag;
    bool ignored_ok;
    cli_cq_.Shutdown();
    srv_cq_->Shutdown();
    // Drain both completion queues before the fixture is destroyed.
    while (cli_cq_.Next(&ignored_tag, &ignored_ok))
      ;
    while (srv_cq_->Next(&ignored_tag, &ignored_ok))
      ;
  }

  // Rebuilds the generic stub on a fresh channel and zeroes both counters.
  void ResetStub() {
    std::shared_ptr<Channel> channel = grpc::CreateChannel(
        server_address_.str(), InsecureChannelCredentials());
    generic_stub_.reset(new GenericStub(channel));
    ResetConnectionCounter();
    ResetCallCounter();
  }

  // Completion-queue expectations keyed by tag number.
  void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); }
  void client_ok(int i) { verify_ok(&cli_cq_, i, true); }
  void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); }
  void client_fail(int i) { verify_ok(&cli_cq_, i, false); }

  // Drives num_rpcs unary echo RPCs through the generic API, playing both
  // the client and the server role in-line. Tag numbers (1..9) encode the
  // expected event order on the two completion queues; do not reorder.
  void SendRpc(int num_rpcs) {
    const TString kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
    for (int i = 0; i < num_rpcs; i++) {
      EchoRequest send_request;
      EchoRequest recv_request;
      EchoResponse send_response;
      EchoResponse recv_response;
      Status recv_status;

      ClientContext cli_ctx;
      GenericServerContext srv_ctx;
      GenericServerAsyncReaderWriter stream(&srv_ctx);

      // The string needs to be long enough to test heap-based slice.
      send_request.set_message("Hello world. Hello world. Hello world.");
      // Waits for tag 4 (the server-side RequestCall completion) on a
      // separate thread while the client drives tags 1-3.
      std::thread request_call([this]() { server_ok(4); });
      std::unique_ptr<GenericClientAsyncReaderWriter> call =
          generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
      call->StartCall(tag(1));
      client_ok(1);
      std::unique_ptr<ByteBuffer> send_buffer =
          SerializeToByteBuffer(&send_request);
      call->Write(*send_buffer, tag(2));
      // Send ByteBuffer can be destroyed after calling Write.
      send_buffer.reset();
      client_ok(2);
      call->WritesDone(tag(3));
      client_ok(3);

      generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(),
                                   srv_cq_.get(), tag(4));

      request_call.join();
      EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
      EXPECT_EQ(kMethodName, srv_ctx.method());
      ByteBuffer recv_buffer;
      stream.Read(&recv_buffer, tag(5));
      server_ok(5);
      EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
      EXPECT_EQ(send_request.message(), recv_request.message());

      send_response.set_message(recv_request.message());
      send_buffer = SerializeToByteBuffer(&send_response);
      stream.Write(*send_buffer, tag(6));
      send_buffer.reset();
      server_ok(6);

      stream.Finish(Status::OK, tag(7));
      server_ok(7);

      recv_buffer.Clear();
      call->Read(&recv_buffer, tag(8));
      client_ok(8);
      EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));

      call->Finish(&recv_status, tag(9));
      client_ok(9);

      EXPECT_EQ(send_response.message(), recv_response.message());
      EXPECT_TRUE(recv_status.ok());
    }
  }

  CompletionQueue cli_cq_;
  std::unique_ptr<ServerCompletionQueue> srv_cq_;
  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
  std::unique_ptr<grpc::GenericStub> generic_stub_;
  std::unique_ptr<Server> server_;
  AsyncGenericService generic_service_;
  const TString server_host_;
  std::ostringstream server_address_;
};
+
// One RPC on a fresh channel: exactly one connection and one call counted.
TEST_F(FilterEnd2endTest, SimpleRpc) {
  ResetStub();
  EXPECT_EQ(0, GetConnectionCounterValue());
  EXPECT_EQ(0, GetCallCounterValue());
  SendRpc(1);
  EXPECT_EQ(1, GetConnectionCounterValue());
  EXPECT_EQ(1, GetCallCounterValue());
}
+
// Ten RPCs share one connection: the connection counter stays at 1 while the
// call counter reaches 10.
TEST_F(FilterEnd2endTest, SequentialRpcs) {
  ResetStub();
  EXPECT_EQ(0, GetConnectionCounterValue());
  EXPECT_EQ(0, GetCallCounterValue());
  SendRpc(10);
  EXPECT_EQ(1, GetConnectionCounterValue());
  EXPECT_EQ(10, GetCallCounterValue());
}
+
// One ping, one pong over a generic bidi stream. Tag numbers (1..10) encode
// the strict expected event order across the two completion queues.
TEST_F(FilterEnd2endTest, SimpleBidiStreaming) {
  ResetStub();
  EXPECT_EQ(0, GetConnectionCounterValue());
  EXPECT_EQ(0, GetCallCounterValue());

  const TString kMethodName(
      "/grpc.cpp.test.util.EchoTestService/BidiStream");
  EchoRequest send_request;
  EchoRequest recv_request;
  EchoResponse send_response;
  EchoResponse recv_response;
  Status recv_status;
  ClientContext cli_ctx;
  GenericServerContext srv_ctx;
  GenericServerAsyncReaderWriter srv_stream(&srv_ctx);

  cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
  send_request.set_message("Hello");
  // Waits for the server-side RequestCall completion (tag 2) concurrently
  // with the client's StartCall (tag 1).
  std::thread request_call([this]() { server_ok(2); });
  std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream =
      generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
  cli_stream->StartCall(tag(1));
  client_ok(1);

  generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(),
                               srv_cq_.get(), tag(2));

  request_call.join();
  EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
  EXPECT_EQ(kMethodName, srv_ctx.method());

  std::unique_ptr<ByteBuffer> send_buffer =
      SerializeToByteBuffer(&send_request);
  cli_stream->Write(*send_buffer, tag(3));
  send_buffer.reset();
  client_ok(3);

  ByteBuffer recv_buffer;
  srv_stream.Read(&recv_buffer, tag(4));
  server_ok(4);
  EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
  EXPECT_EQ(send_request.message(), recv_request.message());

  send_response.set_message(recv_request.message());
  send_buffer = SerializeToByteBuffer(&send_response);
  srv_stream.Write(*send_buffer, tag(5));
  send_buffer.reset();
  server_ok(5);

  cli_stream->Read(&recv_buffer, tag(6));
  client_ok(6);
  EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
  EXPECT_EQ(send_response.message(), recv_response.message());

  cli_stream->WritesDone(tag(7));
  client_ok(7);

  // After WritesDone the server's next Read completes with ok == false.
  srv_stream.Read(&recv_buffer, tag(8));
  server_fail(8);

  srv_stream.Finish(Status::OK, tag(9));
  server_ok(9);

  cli_stream->Finish(&recv_status, tag(10));
  client_ok(10);

  EXPECT_EQ(send_response.message(), recv_response.message());
  EXPECT_TRUE(recv_status.ok());

  EXPECT_EQ(1, GetCallCounterValue());
  EXPECT_EQ(1, GetConnectionCounterValue());
}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/flaky_network_test.cc b/contrib/libs/grpc/test/cpp/end2end/flaky_network_test.cc
new file mode 100644
index 0000000000..3ee75952c0
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/flaky_network_test.cc
@@ -0,0 +1,558 @@
+/*
+ *
+ * Copyright 2019 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/atm.h>
+#include <grpc/support/log.h>
+#include <grpc/support/port_platform.h>
+#include <grpc/support/string_util.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/health_check_service_interface.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <thread>
+
+#include "src/core/lib/backoff/backoff.h"
+#include "src/core/lib/gpr/env.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/debugger_macros.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+#ifdef GPR_LINUX
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+namespace grpc {
+namespace testing {
+namespace {
+
// One test parameterization: which channel credentials to use and which
// message payload to echo.
struct TestScenario {
  TestScenario(const TString& creds_type, const TString& content)
      : credentials_type(creds_type), message_content(content) {}
  const TString credentials_type;
  const TString message_content;
};
+
+class FlakyNetworkTest : public ::testing::TestWithParam<TestScenario> {
+ protected:
+ FlakyNetworkTest()
+ : server_host_("grpctest"),
+ interface_("lo:1"),
+ ipv4_address_("10.0.0.1"),
+ netmask_("/32") {}
+
+ void InterfaceUp() {
+ std::ostringstream cmd;
+ // create interface_ with address ipv4_address_
+ cmd << "ip addr add " << ipv4_address_ << netmask_ << " dev " << interface_;
+ std::system(cmd.str().c_str());
+ }
+
+ void InterfaceDown() {
+ std::ostringstream cmd;
+ // remove interface_
+ cmd << "ip addr del " << ipv4_address_ << netmask_ << " dev " << interface_;
+ std::system(cmd.str().c_str());
+ }
+
+ void DNSUp() {
+ std::ostringstream cmd;
+ // Add DNS entry for server_host_ in /etc/hosts
+ cmd << "echo '" << ipv4_address_ << " " << server_host_
+ << "' >> /etc/hosts";
+ std::system(cmd.str().c_str());
+ }
+
+ void DNSDown() {
+ std::ostringstream cmd;
+ // Remove DNS entry for server_host_ from /etc/hosts
+ // NOTE: we can't do this in one step with sed -i because when we are
+ // running under docker, the file is mounted by docker so we can't change
+ // its inode from within the container (sed -i creates a new file and
+ // replaces the old file, which changes the inode)
+ cmd << "sed '/" << server_host_ << "/d' /etc/hosts > /etc/hosts.orig";
+ std::system(cmd.str().c_str());
+
+ // clear the stream
+ cmd.str("");
+
+ cmd << "cat /etc/hosts.orig > /etc/hosts";
+ std::system(cmd.str().c_str());
+ }
+
+ void DropPackets() {
+ std::ostringstream cmd;
+ // drop packets with src IP = ipv4_address_
+ cmd << "iptables -A INPUT -s " << ipv4_address_ << " -j DROP";
+
+ std::system(cmd.str().c_str());
+ // clear the stream
+ cmd.str("");
+
+ // drop packets with dst IP = ipv4_address_
+ cmd << "iptables -A INPUT -d " << ipv4_address_ << " -j DROP";
+ }
+
+ void RestoreNetwork() {
+ std::ostringstream cmd;
+ // remove iptables rule to drop packets with src IP = ipv4_address_
+ cmd << "iptables -D INPUT -s " << ipv4_address_ << " -j DROP";
+ std::system(cmd.str().c_str());
+ // clear the stream
+ cmd.str("");
+ // remove iptables rule to drop packets with dest IP = ipv4_address_
+ cmd << "iptables -D INPUT -d " << ipv4_address_ << " -j DROP";
+ }
+
+ void FlakeNetwork() {
+ std::ostringstream cmd;
+ // Emulate a flaky network connection over interface_. Add a delay of 100ms
+ // +/- 20ms, 0.1% packet loss, 1% duplicates and 0.01% corrupt packets.
+ cmd << "tc qdisc replace dev " << interface_
+ << " root netem delay 100ms 20ms distribution normal loss 0.1% "
+ "duplicate "
+ "0.1% corrupt 0.01% ";
+ std::system(cmd.str().c_str());
+ }
+
+ void UnflakeNetwork() {
+ // Remove simulated network flake on interface_
+ std::ostringstream cmd;
+ cmd << "tc qdisc del dev " << interface_ << " root netem";
+ std::system(cmd.str().c_str());
+ }
+
+ void NetworkUp() {
+ InterfaceUp();
+ DNSUp();
+ }
+
+ void NetworkDown() {
+ InterfaceDown();
+ DNSDown();
+ }
+
+ void SetUp() override {
+ NetworkUp();
+ grpc_init();
+ StartServer();
+ }
+
+ void TearDown() override {
+ NetworkDown();
+ StopServer();
+ grpc_shutdown();
+ }
+
+ void StartServer() {
+ // TODO (pjaikumar): Ideally, we should allocate the port dynamically using
+ // grpc_pick_unused_port_or_die(). That doesn't work inside some docker
+ // containers because port_server listens on localhost which maps to
+ // ip6-looopback, but ipv6 support is not enabled by default in docker.
+ port_ = SERVER_PORT;
+
+ server_.reset(new ServerData(port_, GetParam().credentials_type));
+ server_->Start(server_host_);
+ }
+ void StopServer() { server_->Shutdown(); }
+
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> BuildStub(
+ const std::shared_ptr<Channel>& channel) {
+ return grpc::testing::EchoTestService::NewStub(channel);
+ }
+
+ std::shared_ptr<Channel> BuildChannel(
+ const TString& lb_policy_name,
+ ChannelArguments args = ChannelArguments()) {
+ if (lb_policy_name.size() > 0) {
+ args.SetLoadBalancingPolicyName(lb_policy_name);
+ } // else, default to pick first
+ auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
+ GetParam().credentials_type, &args);
+ std::ostringstream server_address;
+ server_address << server_host_ << ":" << port_;
+ return CreateCustomChannel(server_address.str(), channel_creds, args);
+ }
+
+ bool SendRpc(
+ const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+ int timeout_ms = 0, bool wait_for_ready = false) {
+ auto response = std::unique_ptr<EchoResponse>(new EchoResponse());
+ EchoRequest request;
+ auto& msg = GetParam().message_content;
+ request.set_message(msg);
+ ClientContext context;
+ if (timeout_ms > 0) {
+ context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms));
+ // Allow an RPC to be canceled (for deadline exceeded) after it has
+ // reached the server.
+ request.mutable_param()->set_skip_cancelled_check(true);
+ }
+ // See https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md for
+ // details of wait-for-ready semantics
+ if (wait_for_ready) {
+ context.set_wait_for_ready(true);
+ }
+ Status status = stub->Echo(&context, request, response.get());
+ auto ok = status.ok();
+ int stream_id = 0;
+ grpc_call* call = context.c_call();
+ if (call) {
+ grpc_chttp2_stream* stream = grpc_chttp2_stream_from_call(call);
+ if (stream) {
+ stream_id = stream->id;
+ }
+ }
+ if (ok) {
+ gpr_log(GPR_DEBUG, "RPC with stream_id %d succeeded", stream_id);
+ } else {
+ gpr_log(GPR_DEBUG, "RPC with stream_id %d failed: %s", stream_id,
+ status.error_message().c_str());
+ }
+ return ok;
+ }
+
+ struct ServerData {
+ int port_;
+ const TString creds_;
+ std::unique_ptr<Server> server_;
+ TestServiceImpl service_;
+ std::unique_ptr<std::thread> thread_;
+ bool server_ready_ = false;
+
+ ServerData(int port, const TString& creds)
+ : port_(port), creds_(creds) {}
+
+ void Start(const TString& server_host) {
+ gpr_log(GPR_INFO, "starting server on port %d", port_);
+ std::mutex mu;
+ std::unique_lock<std::mutex> lock(mu);
+ std::condition_variable cond;
+ thread_.reset(new std::thread(
+ std::bind(&ServerData::Serve, this, server_host, &mu, &cond)));
+ cond.wait(lock, [this] { return server_ready_; });
+ server_ready_ = false;
+ gpr_log(GPR_INFO, "server startup complete");
+ }
+
+ void Serve(const TString& server_host, std::mutex* mu,
+ std::condition_variable* cond) {
+ std::ostringstream server_address;
+ server_address << server_host << ":" << port_;
+ ServerBuilder builder;
+ auto server_creds =
+ GetCredentialsProvider()->GetServerCredentials(creds_);
+ builder.AddListeningPort(server_address.str(), server_creds);
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ std::lock_guard<std::mutex> lock(*mu);
+ server_ready_ = true;
+ cond->notify_one();
+ }
+
+ void Shutdown() {
+ server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
+ thread_->join();
+ }
+ };
+
+ bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) {
+ const gpr_timespec deadline =
+ grpc_timeout_seconds_to_deadline(timeout_seconds);
+ grpc_connectivity_state state;
+ while ((state = channel->GetState(false /* try_to_connect */)) ==
+ GRPC_CHANNEL_READY) {
+ if (!channel->WaitForStateChange(state, deadline)) return false;
+ }
+ return true;
+ }
+
+ bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) {
+ const gpr_timespec deadline =
+ grpc_timeout_seconds_to_deadline(timeout_seconds);
+ grpc_connectivity_state state;
+ while ((state = channel->GetState(true /* try_to_connect */)) !=
+ GRPC_CHANNEL_READY) {
+ if (!channel->WaitForStateChange(state, deadline)) return false;
+ }
+ return true;
+ }
+
+ private:
+ const TString server_host_;
+ const TString interface_;
+ const TString ipv4_address_;
+ const TString netmask_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<ServerData> server_;
+ const int SERVER_PORT = 32750;
+ int port_;
+};
+
+std::vector<TestScenario> CreateTestScenarios() {
+ std::vector<TestScenario> scenarios;
+ std::vector<TString> credentials_types;
+ std::vector<TString> messages;
+
+ credentials_types.push_back(kInsecureCredentialsType);
+ auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList();
+ for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) {
+ credentials_types.push_back(*sec);
+ }
+
+ messages.push_back("🖖");
+ for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024; k *= 32) {
+ TString big_msg;
+ for (size_t i = 0; i < k * 1024; ++i) {
+ char c = 'a' + (i % 26);
+ big_msg += c;
+ }
+ messages.push_back(big_msg);
+ }
+ for (auto cred = credentials_types.begin(); cred != credentials_types.end();
+ ++cred) {
+ for (auto msg = messages.begin(); msg != messages.end(); msg++) {
+ scenarios.emplace_back(*cred, *msg);
+ }
+ }
+
+ return scenarios;
+}
+
// Run every FlakyNetworkTest case once per credentials/message scenario.
INSTANTIATE_TEST_SUITE_P(FlakyNetworkTest, FlakyNetworkTest,
                         ::testing::ValuesIn(CreateTestScenarios()));
+
// Network interface connected to the server flaps: the interface and its DNS
// entry are torn down and restored while a background thread keeps issuing
// RPCs. Aggressive keepalive settings let the client detect the outage.
TEST_P(FlakyNetworkTest, NetworkTransition) {
  const int kKeepAliveTimeMs = 1000;
  const int kKeepAliveTimeoutMs = 1000;
  ChannelArguments args;
  args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs);
  args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs);
  args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1);
  args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);

  auto channel = BuildChannel("pick_first", args);
  auto stub = BuildStub(channel);
  // Channel should be in READY state after we send an RPC
  EXPECT_TRUE(SendRpc(stub));
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);

  // Background sender keeps traffic flowing until the test finishes.
  std::atomic_bool shutdown{false};
  std::thread sender = std::thread([this, &stub, &shutdown]() {
    while (true) {
      if (shutdown.load()) {
        return;
      }
      SendRpc(stub);
      std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
  });

  // bring down network
  NetworkDown();
  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
  // bring network interface back up
  InterfaceUp();
  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  // Restore DNS entry for server
  DNSUp();
  EXPECT_TRUE(WaitForChannelReady(channel.get()));
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
  shutdown.store(true);
  sender.join();
}
+
// Traffic to the server is temporarily blackholed, with keepalives enabled:
// the client should detect the outage via keepalive pings and recover once
// the iptables rules are removed.
TEST_P(FlakyNetworkTest, ServerUnreachableWithKeepalive) {
  const int kKeepAliveTimeMs = 1000;
  const int kKeepAliveTimeoutMs = 1000;
  const int kReconnectBackoffMs = 1000;
  ChannelArguments args;
  args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs);
  args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs);
  args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1);
  args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);
  // max time for a connection attempt
  args.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, kReconnectBackoffMs);
  // max time between reconnect attempts
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, kReconnectBackoffMs);

  gpr_log(GPR_DEBUG, "FlakyNetworkTest.ServerUnreachableWithKeepalive start");
  auto channel = BuildChannel("pick_first", args);
  auto stub = BuildStub(channel);
  // Channel should be in READY state after we send an RPC
  EXPECT_TRUE(SendRpc(stub));
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);

  // Background sender keeps traffic flowing until the test finishes.
  std::atomic_bool shutdown{false};
  std::thread sender = std::thread([this, &stub, &shutdown]() {
    while (true) {
      if (shutdown.load()) {
        return;
      }
      SendRpc(stub);
      std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
  });

  // break network connectivity
  gpr_log(GPR_DEBUG, "Adding iptables rule to drop packets");
  DropPackets();
  std::this_thread::sleep_for(std::chrono::milliseconds(10000));
  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
  // bring network interface back up
  RestoreNetwork();
  gpr_log(GPR_DEBUG, "Removed iptables rule to drop packets");
  EXPECT_TRUE(WaitForChannelReady(channel.get()));
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
  shutdown.store(true);
  sender.join();
  gpr_log(GPR_DEBUG, "FlakyNetworkTest.ServerUnreachableWithKeepalive end");
}
+
// Traffic to the server is temporarily blackholed, with keepalives disabled:
// only RPC deadlines can detect the outage.
TEST_P(FlakyNetworkTest, ServerUnreachableNoKeepalive) {
  auto channel = BuildChannel("pick_first", ChannelArguments());
  auto stub = BuildStub(channel);
  // Channel should be in READY state after we send an RPC
  EXPECT_TRUE(SendRpc(stub));
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);

  // break network connectivity
  DropPackets();

  std::thread sender = std::thread([this, &stub]() {
    // RPC with deadline should timeout
    EXPECT_FALSE(SendRpc(stub, /*timeout_ms=*/500, /*wait_for_ready=*/true));
    // RPC without a deadline blocks until connectivity is restored, then
    // completes successfully.
    EXPECT_TRUE(SendRpc(stub, /*timeout_ms=*/0, /*wait_for_ready=*/true));
  });

  std::this_thread::sleep_for(std::chrono::milliseconds(2000));
  // bring network interface back up
  RestoreNetwork();

  // wait for RPC to finish
  sender.join();
}
+
// Send RPCs over a flaky network connection (delays, loss, duplication and
// corruption injected via tc/netem); the channel should stay READY.
TEST_P(FlakyNetworkTest, FlakyNetwork) {
  const int kKeepAliveTimeMs = 1000;
  const int kKeepAliveTimeoutMs = 1000;
  const int kMessageCount = 100;
  ChannelArguments args;
  args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs);
  args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs);
  args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1);
  args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);

  auto channel = BuildChannel("pick_first", args);
  auto stub = BuildStub(channel);
  // Channel should be in READY state after we send an RPC
  EXPECT_TRUE(SendRpc(stub));
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);

  // simulate flaky network (packet loss, corruption and delays)
  FlakeNetwork();
  for (int i = 0; i < kMessageCount; ++i) {
    SendRpc(stub);
  }
  // remove network flakiness
  UnflakeNetwork();
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);
}
+
// Server is shutdown gracefully and restarted. Client keepalives are enabled
// so the client should notice the server going down.
TEST_P(FlakyNetworkTest, ServerRestartKeepaliveEnabled) {
  const int kKeepAliveTimeMs = 1000;
  const int kKeepAliveTimeoutMs = 1000;
  ChannelArguments args;
  args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs);
  args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs);
  args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1);
  args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);

  auto channel = BuildChannel("pick_first", args);
  auto stub = BuildStub(channel);
  // Channel should be in READY state after we send an RPC
  EXPECT_TRUE(SendRpc(stub));
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);

  // server goes down, client should detect server going down and calls should
  // fail
  StopServer();
  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));
  EXPECT_FALSE(SendRpc(stub));

  std::this_thread::sleep_for(std::chrono::milliseconds(1000));

  // server restarts, calls succeed
  StartServer();
  EXPECT_TRUE(WaitForChannelReady(channel.get()));
  // NOTE(review): the final RPC check below is disabled -- presumably flaky
  // after restart; confirm before re-enabling.
  // EXPECT_TRUE(SendRpc(stub));
}
+
// Server is shutdown gracefully and restarted. Client keepalives are
// disabled (default channel arguments), so reconnection relies on the GOAWAY
// the server sends at shutdown.
TEST_P(FlakyNetworkTest, ServerRestartKeepaliveDisabled) {
  auto channel = BuildChannel("pick_first", ChannelArguments());
  auto stub = BuildStub(channel);
  // Channel should be in READY state after we send an RPC
  EXPECT_TRUE(SendRpc(stub));
  EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY);

  // server sends GOAWAY when it's shutdown, so client attempts to reconnect
  StopServer();
  std::this_thread::sleep_for(std::chrono::milliseconds(1000));

  EXPECT_TRUE(WaitForChannelNotReady(channel.get()));

  std::this_thread::sleep_for(std::chrono::milliseconds(1000));

  // server restarts, calls succeed
  StartServer();
  EXPECT_TRUE(WaitForChannelReady(channel.get()));
}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+#endif // GPR_LINUX
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ grpc::testing::TestEnvironment env(argc, argv);
+ auto result = RUN_ALL_TESTS();
+ return result;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/generic_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/generic_end2end_test.cc
new file mode 100644
index 0000000000..59eec49fb2
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/generic_end2end_test.cc
@@ -0,0 +1,430 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/async_generic_service.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/slice.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace testing {
+namespace {
+
// Converts a small integer into an opaque completion-queue tag pointer.
// Uses named casts instead of the original C-style `(void*)` cast.
void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
+
// Blocks until one event is available on |cq| and verifies that it carries
// the expected tag |i| and the expected success flag |expect_ok|.
void verify_ok(CompletionQueue* cq, int i, bool expect_ok) {
  bool ok;
  void* got_tag;
  EXPECT_TRUE(cq->Next(&got_tag, &ok));
  EXPECT_EQ(expect_ok, ok);
  EXPECT_EQ(tag(i), got_tag);
}
+
+class GenericEnd2endTest : public ::testing::Test {
+ protected:
+ GenericEnd2endTest() : server_host_("localhost") {}
+
+ void SetUp() override {
+ shut_down_ = false;
+ int port = grpc_pick_unused_port_or_die();
+ server_address_ << server_host_ << ":" << port;
+ // Setup server
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address_.str(),
+ InsecureServerCredentials());
+ builder.RegisterAsyncGenericService(&generic_service_);
+ // Include a second call to RegisterAsyncGenericService to make sure that
+ // we get an error in the log, since it is not allowed to have 2 async
+ // generic services
+ builder.RegisterAsyncGenericService(&generic_service_);
+ srv_cq_ = builder.AddCompletionQueue();
+ server_ = builder.BuildAndStart();
+ }
+
+ void ShutDownServerAndCQs() {
+ if (!shut_down_) {
+ server_->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ cli_cq_.Shutdown();
+ srv_cq_->Shutdown();
+ while (cli_cq_.Next(&ignored_tag, &ignored_ok))
+ ;
+ while (srv_cq_->Next(&ignored_tag, &ignored_ok))
+ ;
+ shut_down_ = true;
+ }
+ }
+ void TearDown() override { ShutDownServerAndCQs(); }
+
+ void ResetStub() {
+ std::shared_ptr<Channel> channel = grpc::CreateChannel(
+ server_address_.str(), InsecureChannelCredentials());
+ stub_ = grpc::testing::EchoTestService::NewStub(channel);
+ generic_stub_.reset(new GenericStub(channel));
+ }
+
+ void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); }
+ void client_ok(int i) { verify_ok(&cli_cq_, i, true); }
+ void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); }
+ void client_fail(int i) { verify_ok(&cli_cq_, i, false); }
+
+ void SendRpc(int num_rpcs) {
+ SendRpc(num_rpcs, false, gpr_inf_future(GPR_CLOCK_MONOTONIC));
+ }
+
+ void SendRpc(int num_rpcs, bool check_deadline, gpr_timespec deadline) {
+ const TString kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
+ for (int i = 0; i < num_rpcs; i++) {
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ GenericServerContext srv_ctx;
+ GenericServerAsyncReaderWriter stream(&srv_ctx);
+
+ // The string needs to be long enough to test heap-based slice.
+ send_request.set_message("Hello world. Hello world. Hello world.");
+
+ if (check_deadline) {
+ cli_ctx.set_deadline(deadline);
+ }
+
+ // Rather than using the original kMethodName, make a short-lived
+ // copy to also confirm that we don't refer to this object beyond
+ // the initial call preparation
+ const TString* method_name = new TString(kMethodName);
+
+ std::unique_ptr<GenericClientAsyncReaderWriter> call =
+ generic_stub_->PrepareCall(&cli_ctx, *method_name, &cli_cq_);
+
+ delete method_name; // Make sure that this is not needed after invocation
+
+ std::thread request_call([this]() { server_ok(4); });
+ call->StartCall(tag(1));
+ client_ok(1);
+ std::unique_ptr<ByteBuffer> send_buffer =
+ SerializeToByteBuffer(&send_request);
+ call->Write(*send_buffer, tag(2));
+ // Send ByteBuffer can be destroyed after calling Write.
+ send_buffer.reset();
+ client_ok(2);
+ call->WritesDone(tag(3));
+ client_ok(3);
+
+ generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(),
+ srv_cq_.get(), tag(4));
+
+ request_call.join();
+ EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
+ EXPECT_EQ(kMethodName, srv_ctx.method());
+
+ if (check_deadline) {
+ EXPECT_TRUE(gpr_time_similar(deadline, srv_ctx.raw_deadline(),
+ gpr_time_from_millis(1000, GPR_TIMESPAN)));
+ }
+
+ ByteBuffer recv_buffer;
+ stream.Read(&recv_buffer, tag(5));
+ server_ok(5);
+ EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ send_buffer = SerializeToByteBuffer(&send_response);
+ stream.Write(*send_buffer, tag(6));
+ send_buffer.reset();
+ server_ok(6);
+
+ stream.Finish(Status::OK, tag(7));
+ server_ok(7);
+
+ recv_buffer.Clear();
+ call->Read(&recv_buffer, tag(8));
+ client_ok(8);
+ EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
+
+ call->Finish(&recv_status, tag(9));
+ client_ok(9);
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ }
+ }
+
+ // Return errors to up to one call that comes in on the supplied completion
+ // queue, until the CQ is being shut down (and therefore we can no longer
+ // enqueue further events).
+ void DriveCompletionQueue() {
+ enum class Event : uintptr_t {
+ kCallReceived,
+ kResponseSent,
+ };
+ // Request the call, but only if the main thread hasn't beaten us to
+ // shutting down the CQ.
+ grpc::GenericServerContext server_context;
+ grpc::GenericServerAsyncReaderWriter reader_writer(&server_context);
+
+ {
+ std::lock_guard<std::mutex> lock(shutting_down_mu_);
+ if (!shutting_down_) {
+ generic_service_.RequestCall(
+ &server_context, &reader_writer, srv_cq_.get(), srv_cq_.get(),
+ reinterpret_cast<void*>(Event::kCallReceived));
+ }
+ }
+ // Process events.
+ {
+ Event event;
+ bool ok;
+ while (srv_cq_->Next(reinterpret_cast<void**>(&event), &ok)) {
+ std::lock_guard<std::mutex> lock(shutting_down_mu_);
+ if (shutting_down_) {
+ // The main thread has started shutting down. Simply continue to drain
+ // events.
+ continue;
+ }
+
+ switch (event) {
+ case Event::kCallReceived:
+ reader_writer.Finish(
+ ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "go away"),
+ reinterpret_cast<void*>(Event::kResponseSent));
+ break;
+
+ case Event::kResponseSent:
+ // We are done.
+ break;
+ }
+ }
+ }
+ }
+
+ CompletionQueue cli_cq_;
+ std::unique_ptr<ServerCompletionQueue> srv_cq_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<grpc::GenericStub> generic_stub_;
+ std::unique_ptr<Server> server_;
+ AsyncGenericService generic_service_;
+ const TString server_host_;
+ std::ostringstream server_address_;
+ bool shutting_down_;
+ bool shut_down_;
+ std::mutex shutting_down_mu_;
+};
+
// Single round trip through the generic (method-agnostic) async API.
TEST_F(GenericEnd2endTest, SimpleRpc) {
  ResetStub();
  SendRpc(1);
}
+
// Ten back-to-back generic RPCs on the same stub/channel.
TEST_F(GenericEnd2endTest, SequentialRpcs) {
  ResetStub();
  SendRpc(10);
}
+
// Exercises PrepareUnaryCall: the client's write/writes-done/read/finish are
// collapsed into a single Finish(), while the server still sees a generic
// bidi stream. Tags 4/5/6/7 sequence the server side; tag 1 the client side.
TEST_F(GenericEnd2endTest, SequentialUnaryRpcs) {
  ResetStub();
  const int num_rpcs = 10;
  const TString kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
  for (int i = 0; i < num_rpcs; i++) {
    EchoRequest send_request;
    EchoRequest recv_request;
    EchoResponse send_response;
    EchoResponse recv_response;
    Status recv_status;

    ClientContext cli_ctx;
    GenericServerContext srv_ctx;
    GenericServerAsyncReaderWriter stream(&srv_ctx);

    // The string needs to be long enough to test heap-based slice.
    send_request.set_message("Hello world. Hello world. Hello world.");

    std::unique_ptr<ByteBuffer> cli_send_buffer =
        SerializeToByteBuffer(&send_request);
    // Wait for the server-side RequestCall (tag 4) on a helper thread; the
    // CQ blocks until the RequestCall issued below matches an incoming call.
    std::thread request_call([this]() { server_ok(4); });
    std::unique_ptr<GenericClientAsyncResponseReader> call =
        generic_stub_->PrepareUnaryCall(&cli_ctx, kMethodName,
                                        *cli_send_buffer.get(), &cli_cq_);
    call->StartCall();
    ByteBuffer cli_recv_buffer;
    call->Finish(&cli_recv_buffer, &recv_status, tag(1));
    // The client completion (tag 1) only fires after the server finishes, so
    // wait for it on its own thread while this thread drives the server.
    std::thread client_check([this] { client_ok(1); });

    generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(),
                                 srv_cq_.get(), tag(4));
    request_call.join();
    EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
    EXPECT_EQ(kMethodName, srv_ctx.method());

    ByteBuffer srv_recv_buffer;
    stream.Read(&srv_recv_buffer, tag(5));
    server_ok(5);
    EXPECT_TRUE(ParseFromByteBuffer(&srv_recv_buffer, &recv_request));
    EXPECT_EQ(send_request.message(), recv_request.message());

    send_response.set_message(recv_request.message());
    std::unique_ptr<ByteBuffer> srv_send_buffer =
        SerializeToByteBuffer(&send_response);
    stream.Write(*srv_send_buffer, tag(6));
    server_ok(6);

    stream.Finish(Status::OK, tag(7));
    server_ok(7);

    client_check.join();
    EXPECT_TRUE(ParseFromByteBuffer(&cli_recv_buffer, &recv_response));
    EXPECT_EQ(send_response.message(), recv_response.message());
    EXPECT_TRUE(recv_status.ok());
  }
}
+
+// One ping, one pong.
// One ping, one pong.
TEST_F(GenericEnd2endTest, SimpleBidiStreaming) {
  ResetStub();

  const TString kMethodName(
      "/grpc.cpp.test.util.EchoTestService/BidiStream");
  EchoRequest send_request;
  EchoRequest recv_request;
  EchoResponse send_response;
  EchoResponse recv_response;
  Status recv_status;
  ClientContext cli_ctx;
  GenericServerContext srv_ctx;
  GenericServerAsyncReaderWriter srv_stream(&srv_ctx);

  cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
  send_request.set_message("Hello");
  // Wait for the server-side RequestCall (tag 2) on a helper thread.
  std::thread request_call([this]() { server_ok(2); });
  std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream =
      generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
  cli_stream->StartCall(tag(1));
  client_ok(1);

  generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(),
                               srv_cq_.get(), tag(2));
  request_call.join();

  EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
  EXPECT_EQ(kMethodName, srv_ctx.method());

  // Client writes the request (tag 3); server reads it (tag 4).
  std::unique_ptr<ByteBuffer> send_buffer =
      SerializeToByteBuffer(&send_request);
  cli_stream->Write(*send_buffer, tag(3));
  send_buffer.reset();
  client_ok(3);

  ByteBuffer recv_buffer;
  srv_stream.Read(&recv_buffer, tag(4));
  server_ok(4);
  EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
  EXPECT_EQ(send_request.message(), recv_request.message());

  // Server echoes the message back (tag 5); client reads it (tag 6).
  send_response.set_message(recv_request.message());
  send_buffer = SerializeToByteBuffer(&send_response);
  srv_stream.Write(*send_buffer, tag(5));
  send_buffer.reset();
  server_ok(5);

  cli_stream->Read(&recv_buffer, tag(6));
  client_ok(6);
  EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
  EXPECT_EQ(send_response.message(), recv_response.message());

  cli_stream->WritesDone(tag(7));
  client_ok(7);

  // After WritesDone, the server's next Read must fail (tag 8, ok == false).
  srv_stream.Read(&recv_buffer, tag(8));
  server_fail(8);

  srv_stream.Finish(Status::OK, tag(9));
  server_ok(9);

  cli_stream->Finish(&recv_status, tag(10));
  client_ok(10);

  EXPECT_EQ(send_response.message(), recv_response.message());
  EXPECT_TRUE(recv_status.ok());
}
+
// Verifies that a client-set deadline propagates to the generic server
// context (checked inside SendRpc via raw_deadline()).
TEST_F(GenericEnd2endTest, Deadline) {
  ResetStub();
  SendRpc(1, true,
          gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
                       gpr_time_from_seconds(10, GPR_TIMESPAN)));
}
+
// An RPC with a 500us deadline must fail. DriveCompletionQueue runs on a
// helper thread so the server CQ keeps making progress while this thread
// blocks in the synchronous Echo call.
TEST_F(GenericEnd2endTest, ShortDeadline) {
  ResetStub();

  ClientContext cli_ctx;
  EchoRequest request;
  EchoResponse response;

  shutting_down_ = false;
  std::thread driver([this] { DriveCompletionQueue(); });

  request.set_message("");
  cli_ctx.set_deadline(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
                                    gpr_time_from_micros(500, GPR_TIMESPAN)));
  Status s = stub_->Echo(&cli_ctx, request, &response);
  EXPECT_FALSE(s.ok());
  // Signal the driver under the lock before shutting the CQs down so it
  // switches to drain-only mode.
  {
    std::lock_guard<std::mutex> lock(shutting_down_mu_);
    shutting_down_ = true;
  }
  ShutDownServerAndCQs();
  driver.join();
}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/grpclb_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/grpclb_end2end_test.cc
new file mode 100644
index 0000000000..6208dc2535
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/grpclb_end2end_test.cc
@@ -0,0 +1,2029 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <deque>
+#include <memory>
+#include <mutex>
+#include <set>
+#include <sstream>
+#include <util/generic/string.h>
+#include <thread>
+
+#include "y_absl/strings/str_cat.h"
+#include "y_absl/strings/str_format.h"
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/impl/codegen/sync.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+
+#include "src/core/ext/filters/client_channel/backup_poller.h"
+#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h"
+#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h"
+#include "src/core/ext/filters/client_channel/server_address.h"
+#include "src/core/ext/filters/client_channel/service_config.h"
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
+#include "src/core/lib/iomgr/parse_address.h"
+#include "src/core/lib/iomgr/sockaddr.h"
+#include "src/core/lib/security/credentials/fake/fake_credentials.h"
+#include "src/core/lib/transport/authority_override.h"
+#include "src/cpp/client/secure_credentials.h"
+#include "src/cpp/server/secure_server_credentials.h"
+
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+
+#include "src/proto/grpc/lb/v1/load_balancer.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+// TODO(dgq): Other scenarios in need of testing:
+// - Send a serverlist with faulty ip:port addresses (port > 2^16, etc).
+// - Test reception of invalid serverlist
+// - Test against a non-LB server.
+// - Random LB server closing the stream unexpectedly.
+//
+// Findings from end to end testing to be covered here:
+// - Handling of LB servers restart, including reconnection after backing-off
+// retries.
+// - Destruction of load balanced channel (and therefore of grpclb instance)
+// while:
+// 1) the internal LB call is still active. This should work by virtue
+// of the weak reference the LB call holds. The call should be terminated as
+// part of the grpclb shutdown process.
+// 2) the retry timer is active. Again, the weak reference it holds should
+// prevent a premature call to \a glb_destroy.
+
+using std::chrono::system_clock;
+
+using grpc::lb::v1::LoadBalancer;
+using grpc::lb::v1::LoadBalanceRequest;
+using grpc::lb::v1::LoadBalanceResponse;
+
+namespace grpc {
+namespace testing {
+namespace {
+
// Service config selecting the grpclb LB policy; installed by default on
// every fake resolver response produced by these tests.
constexpr char kDefaultServiceConfig[] =
    "{\n"
    "  \"loadBalancingConfig\":[\n"
    "    { \"grpclb\":{} }\n"
    "  ]\n"
    "}";
+
+template <typename ServiceType>
+class CountedService : public ServiceType {
+ public:
+ size_t request_count() {
+ grpc::internal::MutexLock lock(&mu_);
+ return request_count_;
+ }
+
+ size_t response_count() {
+ grpc::internal::MutexLock lock(&mu_);
+ return response_count_;
+ }
+
+ void IncreaseResponseCount() {
+ grpc::internal::MutexLock lock(&mu_);
+ ++response_count_;
+ }
+ void IncreaseRequestCount() {
+ grpc::internal::MutexLock lock(&mu_);
+ ++request_count_;
+ }
+
+ void ResetCounters() {
+ grpc::internal::MutexLock lock(&mu_);
+ request_count_ = 0;
+ response_count_ = 0;
+ }
+
+ protected:
+ grpc::internal::Mutex mu_;
+
+ private:
+ size_t request_count_ = 0;
+ size_t response_count_ = 0;
+};
+
using BackendService = CountedService<TestServiceImpl>;
using BalancerService = CountedService<LoadBalancer::Service>;

// Per-call credentials metadata: backends must observe it, balancers must
// not (verified in BackendServiceImpl::Echo / BalancerServiceImpl).
const char g_kCallCredsMdKey[] = "Balancer should not ...";
const char g_kCallCredsMdValue[] = "... receive me";
+
+class BackendServiceImpl : public BackendService {
+ public:
+ BackendServiceImpl() {}
+
+ Status Echo(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ // Backend should receive the call credentials metadata.
+ auto call_credentials_entry =
+ context->client_metadata().find(g_kCallCredsMdKey);
+ EXPECT_NE(call_credentials_entry, context->client_metadata().end());
+ if (call_credentials_entry != context->client_metadata().end()) {
+ EXPECT_EQ(call_credentials_entry->second, g_kCallCredsMdValue);
+ }
+ IncreaseRequestCount();
+ const auto status = TestServiceImpl::Echo(context, request, response);
+ IncreaseResponseCount();
+ AddClient(context->peer().c_str());
+ return status;
+ }
+
+ void Start() {}
+
+ void Shutdown() {}
+
+ std::set<TString> clients() {
+ grpc::internal::MutexLock lock(&clients_mu_);
+ return clients_;
+ }
+
+ private:
+ void AddClient(const TString& client) {
+ grpc::internal::MutexLock lock(&clients_mu_);
+ clients_.insert(client);
+ }
+
+ grpc::internal::Mutex mu_;
+ grpc::internal::Mutex clients_mu_;
+ std::set<TString> clients_;
+};
+
+TString Ip4ToPackedString(const char* ip_str) {
+ struct in_addr ip4;
+ GPR_ASSERT(inet_pton(AF_INET, ip_str, &ip4) == 1);
+ return TString(reinterpret_cast<const char*>(&ip4), sizeof(ip4));
+}
+
+struct ClientStats {
+ size_t num_calls_started = 0;
+ size_t num_calls_finished = 0;
+ size_t num_calls_finished_with_client_failed_to_send = 0;
+ size_t num_calls_finished_known_received = 0;
+ std::map<TString, size_t> drop_token_counts;
+
+ ClientStats& operator+=(const ClientStats& other) {
+ num_calls_started += other.num_calls_started;
+ num_calls_finished += other.num_calls_finished;
+ num_calls_finished_with_client_failed_to_send +=
+ other.num_calls_finished_with_client_failed_to_send;
+ num_calls_finished_known_received +=
+ other.num_calls_finished_known_received;
+ for (const auto& p : other.drop_token_counts) {
+ drop_token_counts[p.first] += p.second;
+ }
+ return *this;
+ }
+
+ void Reset() {
+ num_calls_started = 0;
+ num_calls_finished = 0;
+ num_calls_finished_with_client_failed_to_send = 0;
+ num_calls_finished_known_received = 0;
+ drop_token_counts.clear();
+ }
+};
+
// Fake grpclb balancer. Each BalanceLoad stream replays a scripted list of
// (response, delay) pairs, then stays open until the test calls
// NotifyDoneWithServerlists(), optionally collecting client load reports.
class BalancerServiceImpl : public BalancerService {
 public:
  using Stream = ServerReaderWriter<LoadBalanceResponse, LoadBalanceRequest>;
  using ResponseDelayPair = std::pair<LoadBalanceResponse, int>;

  // A non-positive |client_load_reporting_interval_seconds| disables load
  // reporting entirely.
  explicit BalancerServiceImpl(int client_load_reporting_interval_seconds)
      : client_load_reporting_interval_seconds_(
            client_load_reporting_interval_seconds) {}

  Status BalanceLoad(ServerContext* context, Stream* stream) override {
    gpr_log(GPR_INFO, "LB[%p]: BalanceLoad", this);
    {
      grpc::internal::MutexLock lock(&mu_);
      // Bail out immediately if the test already declared serverlists done.
      if (serverlist_done_) goto done;
    }
    {
      // Balancer shouldn't receive the call credentials metadata.
      EXPECT_EQ(context->client_metadata().find(g_kCallCredsMdKey),
                context->client_metadata().end());
      LoadBalanceRequest request;
      std::vector<ResponseDelayPair> responses_and_delays;

      if (!stream->Read(&request)) {
        goto done;
      } else {
        if (request.has_initial_request()) {
          grpc::internal::MutexLock lock(&mu_);
          service_names_.push_back(request.initial_request().name());
        }
      }
      IncreaseRequestCount();
      gpr_log(GPR_INFO, "LB[%p]: received initial message '%s'", this,
              request.DebugString().c_str());

      // TODO(juanlishen): Initial response should always be the first response.
      if (client_load_reporting_interval_seconds_ > 0) {
        LoadBalanceResponse initial_response;
        initial_response.mutable_initial_response()
            ->mutable_client_stats_report_interval()
            ->set_seconds(client_load_reporting_interval_seconds_);
        stream->Write(initial_response);
      }

      // Copy the scripted responses under the lock, send them without it
      // (SendResponse may sleep).
      {
        grpc::internal::MutexLock lock(&mu_);
        responses_and_delays = responses_and_delays_;
      }
      for (const auto& response_and_delay : responses_and_delays) {
        SendResponse(stream, response_and_delay.first,
                     response_and_delay.second);
      }
      // Keep the stream open until NotifyDoneWithServerlists() is called.
      {
        grpc::internal::MutexLock lock(&mu_);
        serverlist_cond_.WaitUntil(&mu_, [this] { return serverlist_done_; });
      }

      if (client_load_reporting_interval_seconds_ > 0) {
        request.Clear();
        // Collect load reports until the client closes the stream.
        while (stream->Read(&request)) {
          gpr_log(GPR_INFO, "LB[%p]: received client load report message '%s'",
                  this, request.DebugString().c_str());
          GPR_ASSERT(request.has_client_stats());
          ClientStats load_report;
          load_report.num_calls_started =
              request.client_stats().num_calls_started();
          load_report.num_calls_finished =
              request.client_stats().num_calls_finished();
          load_report.num_calls_finished_with_client_failed_to_send =
              request.client_stats()
                  .num_calls_finished_with_client_failed_to_send();
          load_report.num_calls_finished_known_received =
              request.client_stats().num_calls_finished_known_received();
          for (const auto& drop_token_count :
               request.client_stats().calls_finished_with_drop()) {
            load_report
                .drop_token_counts[drop_token_count.load_balance_token()] =
                drop_token_count.num_calls();
          }
          // We need to acquire the lock here in order to prevent the notify_one
          // below from firing before its corresponding wait is executed.
          grpc::internal::MutexLock lock(&mu_);
          load_report_queue_.emplace_back(std::move(load_report));
          if (load_report_cond_ != nullptr) load_report_cond_->Signal();
        }
      }
    }
  done:
    gpr_log(GPR_INFO, "LB[%p]: done", this);
    return Status::OK;
  }

  // Queues |response| to be sent |send_after_ms| ms into the next stream.
  void add_response(const LoadBalanceResponse& response, int send_after_ms) {
    grpc::internal::MutexLock lock(&mu_);
    responses_and_delays_.push_back(std::make_pair(response, send_after_ms));
  }

  // Resets the scripted responses and report queue for a new test phase.
  void Start() {
    grpc::internal::MutexLock lock(&mu_);
    serverlist_done_ = false;
    responses_and_delays_.clear();
    load_report_queue_.clear();
  }

  void Shutdown() {
    NotifyDoneWithServerlists();
    gpr_log(GPR_INFO, "LB[%p]: shut down", this);
  }

  // Builds a serverlist response with |backend_ports| plus synthetic drop
  // entries described by |drop_token_counts| (token -> number of drops).
  static LoadBalanceResponse BuildResponseForBackends(
      const std::vector<int>& backend_ports,
      const std::map<TString, size_t>& drop_token_counts) {
    LoadBalanceResponse response;
    for (const auto& drop_token_count : drop_token_counts) {
      for (size_t i = 0; i < drop_token_count.second; ++i) {
        auto* server = response.mutable_server_list()->add_servers();
        server->set_drop(true);
        server->set_load_balance_token(drop_token_count.first);
      }
    }
    for (const int& backend_port : backend_ports) {
      auto* server = response.mutable_server_list()->add_servers();
      server->set_ip_address(Ip4ToPackedString("127.0.0.1"));
      server->set_port(backend_port);
      static int token_count = 0;
      server->set_load_balance_token(
          y_absl::StrFormat("token%03d", ++token_count));
    }
    return response;
  }

  // Blocks until at least one load report has arrived and pops it.
  ClientStats WaitForLoadReport() {
    grpc::internal::MutexLock lock(&mu_);
    grpc::internal::CondVar cv;
    if (load_report_queue_.empty()) {
      load_report_cond_ = &cv;
      load_report_cond_->WaitUntil(
          &mu_, [this] { return !load_report_queue_.empty(); });
      load_report_cond_ = nullptr;
    }
    ClientStats load_report = std::move(load_report_queue_.front());
    load_report_queue_.pop_front();
    return load_report;
  }

  // Unblocks any BalanceLoad stream waiting for further serverlists.
  void NotifyDoneWithServerlists() {
    grpc::internal::MutexLock lock(&mu_);
    if (!serverlist_done_) {
      serverlist_done_ = true;
      serverlist_cond_.Broadcast();
    }
  }

  // Names reported by clients in their initial requests.
  std::vector<TString> service_names() {
    grpc::internal::MutexLock lock(&mu_);
    return service_names_;
  }

 private:
  // Sleeps |delay_ms| (if positive), then writes |response| to |stream|.
  void SendResponse(Stream* stream, const LoadBalanceResponse& response,
                    int delay_ms) {
    gpr_log(GPR_INFO, "LB[%p]: sleeping for %d ms...", this, delay_ms);
    if (delay_ms > 0) {
      gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms));
    }
    gpr_log(GPR_INFO, "LB[%p]: Woke up! Sending response '%s'", this,
            response.DebugString().c_str());
    IncreaseResponseCount();
    stream->Write(response);
  }

  const int client_load_reporting_interval_seconds_;
  std::vector<ResponseDelayPair> responses_and_delays_;
  std::vector<TString> service_names_;

  // NOTE: this mu_ shadows the protected CountedService::mu_; it guards the
  // balancer-specific state below.
  grpc::internal::Mutex mu_;
  grpc::internal::CondVar serverlist_cond_;
  bool serverlist_done_ = false;
  grpc::internal::CondVar* load_report_cond_ = nullptr;
  std::deque<ClientStats> load_report_queue_;
};
+
+class GrpclbEnd2endTest : public ::testing::Test {
+ protected:
  // |num_backends|/|num_balancers| control how many test servers the fixture
  // spins up; |client_load_reporting_interval_seconds| == 0 disables load
  // reporting on the fake balancers.
  GrpclbEnd2endTest(size_t num_backends, size_t num_balancers,
                    int client_load_reporting_interval_seconds)
      : server_host_("localhost"),
        num_backends_(num_backends),
        num_balancers_(num_balancers),
        client_load_reporting_interval_seconds_(
            client_load_reporting_interval_seconds) {}
+
  // One-time process-wide setup for the whole suite.
  static void SetUpTestCase() {
    // Make the backup poller poll very frequently in order to pick up
    // updates from all the subchannels's FDs.
    GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1);
#if TARGET_OS_IPHONE
    // Workaround Apple CFStream bug
    gpr_setenv("grpc_cfstream", "0");
#endif
    grpc_init();
  }
+
  // Balances the grpc_init() in SetUpTestCase once the suite finishes.
  static void TearDownTestCase() { grpc_shutdown(); }
+
  // Starts all backends and balancers, then builds the client channel/stub.
  void SetUp() override {
    response_generator_ =
        grpc_core::MakeRefCounted<grpc_core::FakeResolverResponseGenerator>();
    // Start the backends.
    for (size_t i = 0; i < num_backends_; ++i) {
      backends_.emplace_back(new ServerThread<BackendServiceImpl>("backend"));
      backends_.back()->Start(server_host_);
    }
    // Start the load balancers.
    for (size_t i = 0; i < num_balancers_; ++i) {
      balancers_.emplace_back(new ServerThread<BalancerServiceImpl>(
          "balancer", client_load_reporting_interval_seconds_));
      balancers_.back()->Start(server_host_);
    }
    ResetStub();
  }
+
  // Stops every backend and balancer started by SetUp().
  void TearDown() override {
    ShutdownAllBackends();
    for (auto& balancer : balancers_) balancer->Shutdown();
  }
+
  // (Re)starts every backend server thread.
  void StartAllBackends() {
    for (auto& backend : backends_) backend->Start(server_host_);
  }
+
  // (Re)starts the backend at |index|.
  void StartBackend(size_t index) { backends_[index]->Start(server_host_); }
+
  // Shuts down every backend server thread.
  void ShutdownAllBackends() {
    for (auto& backend : backends_) backend->Shutdown();
  }
+
  // Shuts down the backend at |index|.
  void ShutdownBackend(size_t index) { backends_[index]->Shutdown(); }
+
  // Rebuilds the client channel against the fake resolver, using fake
  // transport security composed with test-only call credentials (so the
  // tests can check which servers see the call-creds metadata).
  void ResetStub(int fallback_timeout = 0,
                 const TString& expected_targets = "") {
    ChannelArguments args;
    if (fallback_timeout > 0) args.SetGrpclbFallbackTimeout(fallback_timeout);
    args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR,
                    response_generator_.get());
    if (!expected_targets.empty()) {
      args.SetString(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS, expected_targets);
    }
    std::ostringstream uri;
    uri << "fake:///" << kApplicationTargetName_;
    // TODO(dgq): templatize tests to run everything using both secure and
    // insecure channel credentials.
    grpc_channel_credentials* channel_creds =
        grpc_fake_transport_security_credentials_create();
    grpc_call_credentials* call_creds = grpc_md_only_test_credentials_create(
        g_kCallCredsMdKey, g_kCallCredsMdValue, false);
    std::shared_ptr<ChannelCredentials> creds(
        new SecureChannelCredentials(grpc_composite_channel_credentials_create(
            channel_creds, call_creds, nullptr)));
    // The composite credentials hold their own refs.
    call_creds->Unref();
    channel_creds->Unref();
    channel_ = ::grpc::CreateCustomChannel(uri.str(), creds, args);
    stub_ = grpc::testing::EchoTestService::NewStub(channel_);
  }
+
  // Zeroes the request/response counters on every backend.
  void ResetBackendCounters() {
    for (auto& backend : backends_) backend->service_.ResetCounters();
  }
+
  // Blocks until every balancer has produced one load report and returns
  // their sum.
  ClientStats WaitForLoadReports() {
    ClientStats client_stats;
    for (auto& balancer : balancers_) {
      client_stats += balancer->service_.WaitForLoadReport();
    }
    return client_stats;
  }
+
  // True iff every backend in [start_index, stop_index) has served at least
  // one request. stop_index == 0 means "through the last backend".
  bool SeenAllBackends(size_t start_index = 0, size_t stop_index = 0) {
    if (stop_index == 0) stop_index = backends_.size();
    for (size_t i = start_index; i < stop_index; ++i) {
      if (backends_[i]->service_.request_count() == 0) return false;
    }
    return true;
  }
+
+ void SendRpcAndCount(int* num_total, int* num_ok, int* num_failure,
+ int* num_drops) {
+ const Status status = SendRpc();
+ if (status.ok()) {
+ ++*num_ok;
+ } else {
+ if (status.error_message() == "Call dropped by load balancing policy") {
+ ++*num_drops;
+ } else {
+ ++*num_failure;
+ }
+ }
+ ++*num_total;
+ }
+
  // Sends warm-up RPCs until every backend in [start_index, stop_index) has
  // been hit, then pads until the total is a multiple of
  // |num_requests_multiple_of|. Resets backend counters before returning
  // (ok, failure, drop) counts.
  std::tuple<int, int, int> WaitForAllBackends(int num_requests_multiple_of = 1,
                                               size_t start_index = 0,
                                               size_t stop_index = 0) {
    int num_ok = 0;
    int num_failure = 0;
    int num_drops = 0;
    int num_total = 0;
    while (!SeenAllBackends(start_index, stop_index)) {
      SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops);
    }
    while (num_total % num_requests_multiple_of != 0) {
      SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops);
    }
    ResetBackendCounters();
    gpr_log(GPR_INFO,
            "Performed %d warm up requests (a multiple of %d) against the "
            "backends. %d succeeded, %d failed, %d dropped.",
            num_total, num_requests_multiple_of, num_ok, num_failure,
            num_drops);
    return std::make_tuple(num_ok, num_failure, num_drops);
  }
+
  // Sends RPCs (at least one — note the do-while) until |backend_idx| has
  // served a request, then resets the counters.
  void WaitForBackend(size_t backend_idx) {
    do {
      (void)SendRpc();
    } while (backends_[backend_idx]->service_.request_count() == 0);
    ResetBackendCounters();
  }
+
  // A server port plus the authority/balancer name to advertise for it.
  struct AddressData {
    int port;
    TString balancer_name;
  };
+
  // Converts test AddressData entries into resolver ServerAddress objects,
  // attaching each balancer_name as an authority-override channel arg.
  static grpc_core::ServerAddressList CreateLbAddressesFromAddressDataList(
      const std::vector<AddressData>& address_data) {
    grpc_core::ServerAddressList addresses;
    for (const auto& addr : address_data) {
      TString lb_uri_str = y_absl::StrCat("ipv4:127.0.0.1:", addr.port);
      grpc_uri* lb_uri = grpc_uri_parse(lb_uri_str.c_str(), true);
      GPR_ASSERT(lb_uri != nullptr);
      grpc_resolved_address address;
      GPR_ASSERT(grpc_parse_uri(lb_uri, &address));
      grpc_arg arg = grpc_core::CreateAuthorityOverrideChannelArg(
          addr.balancer_name.c_str());
      grpc_channel_args* args =
          grpc_channel_args_copy_and_add(nullptr, &arg, 1);
      addresses.emplace_back(address.addr, address.len, args);
      grpc_uri_destroy(lb_uri);
    }
    return addresses;
  }
+
  // Builds a fake resolver result: |backend_address_data| as the plain
  // addresses, |balancer_address_data| as grpclb balancer addresses (via a
  // channel arg), and the given JSON service config.
  static grpc_core::Resolver::Result MakeResolverResult(
      const std::vector<AddressData>& balancer_address_data,
      const std::vector<AddressData>& backend_address_data = {},
      const char* service_config_json = kDefaultServiceConfig) {
    grpc_core::Resolver::Result result;
    result.addresses =
        CreateLbAddressesFromAddressDataList(backend_address_data);
    grpc_error* error = GRPC_ERROR_NONE;
    result.service_config =
        grpc_core::ServiceConfig::Create(nullptr, service_config_json, &error);
    GPR_ASSERT(error == GRPC_ERROR_NONE);
    grpc_core::ServerAddressList balancer_addresses =
        CreateLbAddressesFromAddressDataList(balancer_address_data);
    grpc_arg arg = CreateGrpclbBalancerAddressesArg(&balancer_addresses);
    result.args = grpc_channel_args_copy_and_add(nullptr, &arg, 1);
    return result;
  }
+
+ void SetNextResolutionAllBalancers(
+ const char* service_config_json = kDefaultServiceConfig) {
+ std::vector<AddressData> addresses;
+ for (size_t i = 0; i < balancers_.size(); ++i) {
+ addresses.emplace_back(AddressData{balancers_[i]->port_, ""});
+ }
+ SetNextResolution(addresses, {}, service_config_json);
+ }
+
+ void SetNextResolution(
+ const std::vector<AddressData>& balancer_address_data,
+ const std::vector<AddressData>& backend_address_data = {},
+ const char* service_config_json = kDefaultServiceConfig) {
+ grpc_core::ExecCtx exec_ctx;
+ grpc_core::Resolver::Result result = MakeResolverResult(
+ balancer_address_data, backend_address_data, service_config_json);
+ response_generator_->SetResponse(std::move(result));
+ }
+
+ void SetNextReresolutionResponse(
+ const std::vector<AddressData>& balancer_address_data,
+ const std::vector<AddressData>& backend_address_data = {},
+ const char* service_config_json = kDefaultServiceConfig) {
+ grpc_core::ExecCtx exec_ctx;
+ grpc_core::Resolver::Result result = MakeResolverResult(
+ balancer_address_data, backend_address_data, service_config_json);
+ response_generator_->SetReresolutionResponse(std::move(result));
+ }
+
+ const std::vector<int> GetBackendPorts(size_t start_index = 0,
+ size_t stop_index = 0) const {
+ if (stop_index == 0) stop_index = backends_.size();
+ std::vector<int> backend_ports;
+ for (size_t i = start_index; i < stop_index; ++i) {
+ backend_ports.push_back(backends_[i]->port_);
+ }
+ return backend_ports;
+ }
+
  // Enqueues |response| to be sent by balancer |i| after |delay_ms|
  // milliseconds.
  void ScheduleResponseForBalancer(size_t i,
                                   const LoadBalanceResponse& response,
                                   int delay_ms) {
    balancers_[i]->service_.add_response(response, delay_ms);
  }
+
+ Status SendRpc(EchoResponse* response = nullptr, int timeout_ms = 1000,
+ bool wait_for_ready = false,
+ const Status& expected_status = Status::OK) {
+ const bool local_response = (response == nullptr);
+ if (local_response) response = new EchoResponse;
+ EchoRequest request;
+ request.set_message(kRequestMessage_);
+ if (!expected_status.ok()) {
+ auto* error = request.mutable_param()->mutable_expected_error();
+ error->set_code(expected_status.error_code());
+ error->set_error_message(expected_status.error_message());
+ }
+ ClientContext context;
+ context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms));
+ if (wait_for_ready) context.set_wait_for_ready(true);
+ Status status = stub_->Echo(&context, request, response);
+ if (local_response) delete response;
+ return status;
+ }
+
+ void CheckRpcSendOk(const size_t times = 1, const int timeout_ms = 1000,
+ bool wait_for_ready = false) {
+ for (size_t i = 0; i < times; ++i) {
+ EchoResponse response;
+ const Status status = SendRpc(&response, timeout_ms, wait_for_ready);
+ EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+ << " message=" << status.error_message();
+ EXPECT_EQ(response.message(), kRequestMessage_);
+ }
+ }
+
+ void CheckRpcSendFailure() {
+ const Status status = SendRpc();
+ EXPECT_FALSE(status.ok());
+ }
+
  // Runs one service implementation (backend or balancer) on its own server
  // in a dedicated thread. Start()/Shutdown() may be called repeatedly to
  // simulate servers going down and coming back up.
  template <typename T>
  struct ServerThread {
    template <typename... Args>
    explicit ServerThread(const TString& type, Args&&... args)
        : port_(grpc_pick_unused_port_or_die()),
          type_(type),
          service_(std::forward<Args>(args)...) {}

    // Starts the server and blocks until it is serving.
    void Start(const TString& server_host) {
      gpr_log(GPR_INFO, "starting %s server on port %d", type_.c_str(), port_);
      GPR_ASSERT(!running_);
      running_ = true;
      service_.Start();
      grpc::internal::Mutex mu;
      // We need to acquire the lock here in order to prevent the notify_one
      // by ServerThread::Serve from firing before the wait below is hit.
      grpc::internal::MutexLock lock(&mu);
      grpc::internal::CondVar cond;
      thread_.reset(new std::thread(
          std::bind(&ServerThread::Serve, this, server_host, &mu, &cond)));
      cond.Wait(&mu);
      gpr_log(GPR_INFO, "%s server startup complete", type_.c_str());
    }

    // Thread body: builds and starts the server, then signals Start().
    void Serve(const TString& server_host, grpc::internal::Mutex* mu,
               grpc::internal::CondVar* cond) {
      // We need to acquire the lock here in order to prevent the notify_one
      // below from firing before its corresponding wait is executed.
      grpc::internal::MutexLock lock(mu);
      std::ostringstream server_address;
      server_address << server_host << ":" << port_;
      ServerBuilder builder;
      std::shared_ptr<ServerCredentials> creds(new SecureServerCredentials(
          grpc_fake_transport_security_server_credentials_create()));
      builder.AddListeningPort(server_address.str(), creds);
      builder.RegisterService(&service_);
      server_ = builder.BuildAndStart();
      cond->Signal();
    }

    // Stops the service, shuts the server down immediately (zero deadline)
    // and joins the serving thread. No-op if not running.
    void Shutdown() {
      if (!running_) return;
      gpr_log(GPR_INFO, "%s about to shutdown", type_.c_str());
      service_.Shutdown();
      server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
      thread_->join();
      gpr_log(GPR_INFO, "%s shutdown completed", type_.c_str());
      running_ = false;
    }

    const int port_;
    TString type_;
    T service_;
    std::unique_ptr<Server> server_;
    std::unique_ptr<std::thread> thread_;
    bool running_ = false;
  };
+
  // Fixture state (initialized by the fixture's constructor/SetUp, which are
  // outside this chunk).
  const TString server_host_;
  const size_t num_backends_;
  const size_t num_balancers_;
  const int client_load_reporting_interval_seconds_;
  std::shared_ptr<Channel> channel_;
  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
  // One server thread per backend / per balancer.
  std::vector<std::unique_ptr<ServerThread<BackendServiceImpl>>> backends_;
  std::vector<std::unique_ptr<ServerThread<BalancerServiceImpl>>> balancers_;
  // Fake resolver used to inject balancer/backend addresses into the channel.
  grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator>
      response_generator_;
  const TString kRequestMessage_ = "Live long and prosper.";
  const TString kApplicationTargetName_ = "application_target_name";
};
+
// Fixture with 4 backends, 1 balancer, and client load reporting disabled
// (interval of 0 seconds).
class SingleBalancerTest : public GrpclbEnd2endTest {
 public:
  SingleBalancerTest() : GrpclbEnd2endTest(4, 1, 0) {}
};
+
// Basic round-robin over the balancer-provided serverlist: every backend
// gets an equal share of RPCs and the balancer is contacted exactly once.
TEST_F(SingleBalancerTest, Vanilla) {
  SetNextResolutionAllBalancers();
  const size_t kNumRpcsPerAddress = 100;
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
      0);
  // Make sure that trying to connect works without a call.
  channel_->GetState(true /* try_to_connect */);
  // We need to wait for all backends to come online.
  WaitForAllBackends();
  // Send kNumRpcsPerAddress RPCs per server.
  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);

  // Each backend should have gotten 100 requests.
  for (size_t i = 0; i < backends_.size(); ++i) {
    EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count());
  }
  balancers_[0]->service_.NotifyDoneWithServerlists();
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());

  // Check LB policy name for the channel.
  EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName());
}
+
// An application-level error produced by the backend must be surfaced to the
// client unchanged (code and message).
TEST_F(SingleBalancerTest, ReturnServerStatus) {
  SetNextResolutionAllBalancers();
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
      0);
  // We need to wait for all backends to come online.
  WaitForAllBackends();
  // Send a request that the backend will fail, and make sure we get
  // back the right status.
  Status expected(StatusCode::INVALID_ARGUMENT, "He's dead, Jim!");
  Status actual = SendRpc(/*response=*/nullptr, /*timeout_ms=*/1000,
                          /*wait_for_ready=*/false, expected);
  EXPECT_EQ(actual.error_code(), expected.error_code());
  EXPECT_EQ(actual.error_message(), expected.error_message());
}
+
// With a service config listing an unknown policy before grpclb, the client
// must skip the unknown entry and select grpclb.
TEST_F(SingleBalancerTest, SelectGrpclbWithMigrationServiceConfig) {
  SetNextResolutionAllBalancers(
      "{\n"
      " \"loadBalancingConfig\":[\n"
      " { \"does_not_exist\":{} },\n"
      " { \"grpclb\":{} }\n"
      " ]\n"
      "}");
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
      0);
  CheckRpcSendOk(1, 1000 /* timeout_ms */, true /* wait_for_ready */);
  balancers_[0]->service_.NotifyDoneWithServerlists();
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());
  // Check LB policy name for the channel.
  EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName());
}
+
// Same migration-style config, but the resolver returns no addresses at all:
// grpclb is still selected and the channel enters TRANSIENT_FAILURE once the
// (shortened) fallback timeout fires with nothing to fall back to.
TEST_F(SingleBalancerTest,
       SelectGrpclbWithMigrationServiceConfigAndNoAddresses) {
  const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor();
  ResetStub(kFallbackTimeoutMs);
  SetNextResolution({}, {},
                    "{\n"
                    " \"loadBalancingConfig\":[\n"
                    " { \"does_not_exist\":{} },\n"
                    " { \"grpclb\":{} }\n"
                    " ]\n"
                    "}");
  // Try to connect.
  EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(true));
  // Should go into state TRANSIENT_FAILURE when we enter fallback mode.
  const gpr_timespec deadline = grpc_timeout_seconds_to_deadline(1);
  grpc_connectivity_state state;
  while ((state = channel_->GetState(false)) !=
         GRPC_CHANNEL_TRANSIENT_FAILURE) {
    ASSERT_TRUE(channel_->WaitForStateChange(state, deadline));
  }
  // Check LB policy name for the channel.
  EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName());
}
+
// A grpclb childPolicy of pick_first must route every RPC to a single
// backend instead of round-robining.
TEST_F(SingleBalancerTest, UsePickFirstChildPolicy) {
  SetNextResolutionAllBalancers(
      "{\n"
      " \"loadBalancingConfig\":[\n"
      " { \"grpclb\":{\n"
      " \"childPolicy\":[\n"
      " { \"pick_first\":{} }\n"
      " ]\n"
      " } }\n"
      " ]\n"
      "}");
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
      0);
  const size_t kNumRpcs = num_backends_ * 2;
  CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */);
  balancers_[0]->service_.NotifyDoneWithServerlists();
  // Check that all requests went to the first backend. This verifies
  // that we used pick_first instead of round_robin as the child policy.
  EXPECT_EQ(backends_[0]->service_.request_count(), kNumRpcs);
  for (size_t i = 1; i < backends_.size(); ++i) {
    EXPECT_EQ(backends_[i]->service_.request_count(), 0UL);
  }
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());
  // Check LB policy name for the channel.
  EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName());
}
+
// Start with a pick_first child policy, then push a config update that drops
// it: traffic must switch from a single backend to round-robin.
TEST_F(SingleBalancerTest, SwapChildPolicy) {
  SetNextResolutionAllBalancers(
      "{\n"
      " \"loadBalancingConfig\":[\n"
      " { \"grpclb\":{\n"
      " \"childPolicy\":[\n"
      " { \"pick_first\":{} }\n"
      " ]\n"
      " } }\n"
      " ]\n"
      "}");
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
      0);
  const size_t kNumRpcs = num_backends_ * 2;
  CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */);
  // Check that all requests went to the first backend. This verifies
  // that we used pick_first instead of round_robin as the child policy.
  EXPECT_EQ(backends_[0]->service_.request_count(), kNumRpcs);
  for (size_t i = 1; i < backends_.size(); ++i) {
    EXPECT_EQ(backends_[i]->service_.request_count(), 0UL);
  }
  // Send new resolution that removes child policy from service config.
  SetNextResolutionAllBalancers();
  WaitForAllBackends();
  CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */);
  // Check that every backend saw the same number of requests. This verifies
  // that we used round_robin.
  for (size_t i = 0; i < backends_.size(); ++i) {
    EXPECT_EQ(backends_[i]->service_.request_count(), 2UL);
  }
  // Done.
  balancers_[0]->service_.NotifyDoneWithServerlists();
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());
  // Check LB policy name for the channel.
  EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName());
}
+
// A backend listed twice in the serverlist gets a double share of traffic,
// but only one client connection thanks to subchannel sharing.
TEST_F(SingleBalancerTest, SameBackendListedMultipleTimes) {
  SetNextResolutionAllBalancers();
  // Same backend listed twice.
  std::vector<int> ports;
  ports.push_back(backends_[0]->port_);
  ports.push_back(backends_[0]->port_);
  const size_t kNumRpcsPerAddress = 10;
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(ports, {}), 0);
  // We need to wait for the backend to come online.
  WaitForBackend(0);
  // Send kNumRpcsPerAddress RPCs per server.
  CheckRpcSendOk(kNumRpcsPerAddress * ports.size());
  // Backend should have gotten 20 requests.
  EXPECT_EQ(kNumRpcsPerAddress * 2, backends_[0]->service_.request_count());
  // And they should have come from a single client port, because of
  // subchannel sharing.
  EXPECT_EQ(1UL, backends_[0]->service_.clients().size());
  balancers_[0]->service_.NotifyDoneWithServerlists();
}
+
// When the balancer's authority matches the client's expected secure name
// ("lb"), everything works exactly like the Vanilla case.
TEST_F(SingleBalancerTest, SecureNaming) {
  ResetStub(0, kApplicationTargetName_ + ";lb");
  SetNextResolution({AddressData{balancers_[0]->port_, "lb"}});
  const size_t kNumRpcsPerAddress = 100;
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
      0);
  // Make sure that trying to connect works without a call.
  channel_->GetState(true /* try_to_connect */);
  // We need to wait for all backends to come online.
  WaitForAllBackends();
  // Send kNumRpcsPerAddress RPCs per server.
  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);

  // Each backend should have gotten 100 requests.
  for (size_t i = 0; i < backends_.size(); ++i) {
    EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count());
  }
  balancers_[0]->service_.NotifyDoneWithServerlists();
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());
  // Check LB policy name for the channel.
  EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName());
}
+
// A mismatched balancer authority ("woops" vs the expected "lb") must abort
// the process via the security connector; verified as a gtest death test.
TEST_F(SingleBalancerTest, SecureNamingDeathTest) {
  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
  // Make sure that we blow up (via abort() from the security connector) when
  // the name from the balancer doesn't match expectations.
  ASSERT_DEATH_IF_SUPPORTED(
      {
        ResetStub(0, kApplicationTargetName_ + ";lb");
        SetNextResolution({AddressData{balancers_[0]->port_, "woops"}});
        channel_->WaitForConnected(grpc_timeout_seconds_to_deadline(1));
      },
      "");
}
+
+TEST_F(SingleBalancerTest, InitiallyEmptyServerlist) {
+ SetNextResolutionAllBalancers();
+ const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor();
+ const int kCallDeadlineMs = kServerlistDelayMs * 2;
+ // First response is an empty serverlist, sent right away.
+ ScheduleResponseForBalancer(0, LoadBalanceResponse(), 0);
+ // Send non-empty serverlist only after kServerlistDelayMs
+ ScheduleResponseForBalancer(
+ 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
+ kServerlistDelayMs);
+ const auto t0 = system_clock::now();
+ // Client will block: LB will initially send empty serverlist.
+ CheckRpcSendOk(1, kCallDeadlineMs, true /* wait_for_ready */);
+ const auto ellapsed_ms =
+ std::chrono::duration_cast<std::chrono::milliseconds>(
+ system_clock::now() - t0);
+ // but eventually, the LB sends a serverlist update that allows the call to
+ // proceed. The call delay must be larger than the delay in sending the
+ // populated serverlist but under the call's deadline (which is enforced by
+ // the call's deadline).
+ EXPECT_GT(ellapsed_ms.count(), kServerlistDelayMs);
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+ // The balancer got a single request.
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ // and sent two responses.
+ EXPECT_EQ(2U, balancers_[0]->service_.response_count());
+}
+
// If every serverlist entry is unreachable, the RPC must fail fast with
// UNAVAILABLE rather than hanging until DEADLINE_EXCEEDED.
TEST_F(SingleBalancerTest, AllServersUnreachableFailFast) {
  SetNextResolutionAllBalancers();
  const size_t kNumUnreachableServers = 5;
  std::vector<int> ports;
  for (size_t i = 0; i < kNumUnreachableServers; ++i) {
    // Freshly picked unused ports — nothing is listening on them.
    ports.push_back(grpc_pick_unused_port_or_die());
  }
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(ports, {}), 0);
  const Status status = SendRpc();
  // The error shouldn't be DEADLINE_EXCEEDED.
  EXPECT_EQ(StatusCode::UNAVAILABLE, status.error_code());
  balancers_[0]->service_.NotifyDoneWithServerlists();
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());
}
+
// Fallback: until the balancer sends a serverlist, RPCs go to the
// resolver-provided (fallback) backends; once the delayed serverlist
// arrives, traffic moves to the balancer-provided backends.
TEST_F(SingleBalancerTest, Fallback) {
  SetNextResolutionAllBalancers();
  const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor();
  const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor();
  const size_t kNumBackendsInResolution = backends_.size() / 2;

  ResetStub(kFallbackTimeoutMs);
  std::vector<AddressData> balancer_addresses;
  balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
  std::vector<AddressData> backend_addresses;
  for (size_t i = 0; i < kNumBackendsInResolution; ++i) {
    backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""});
  }
  SetNextResolution(balancer_addresses, backend_addresses);

  // Send non-empty serverlist only after kServerlistDelayMs.
  ScheduleResponseForBalancer(
      0,
      BalancerServiceImpl::BuildResponseForBackends(
          GetBackendPorts(kNumBackendsInResolution /* start_index */), {}),
      kServerlistDelayMs);

  // Wait until all the fallback backends are reachable.
  for (size_t i = 0; i < kNumBackendsInResolution; ++i) {
    WaitForBackend(i);
  }

  // The first request.
  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
  CheckRpcSendOk(kNumBackendsInResolution);
  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");

  // Fallback is used: each backend returned by the resolver should have
  // gotten one request.
  for (size_t i = 0; i < kNumBackendsInResolution; ++i) {
    EXPECT_EQ(1U, backends_[i]->service_.request_count());
  }
  for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) {
    EXPECT_EQ(0U, backends_[i]->service_.request_count());
  }

  // Wait until the serverlist reception has been processed and all backends
  // in the serverlist are reachable.
  for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) {
    WaitForBackend(i);
  }

  // Send out the second request.
  gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
  CheckRpcSendOk(backends_.size() - kNumBackendsInResolution);
  gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");

  // Serverlist is used: each backend returned by the balancer should
  // have gotten one request.
  for (size_t i = 0; i < kNumBackendsInResolution; ++i) {
    EXPECT_EQ(0U, backends_[i]->service_.request_count());
  }
  for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) {
    EXPECT_EQ(1U, backends_[i]->service_.request_count());
  }

  balancers_[0]->service_.NotifyDoneWithServerlists();
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());
}
+
// Fallback with a resolution update: traffic follows the first fallback set,
// then the updated fallback set, and finally the balancer's serverlist once
// it (belatedly) arrives.
TEST_F(SingleBalancerTest, FallbackUpdate) {
  SetNextResolutionAllBalancers();
  const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor();
  const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor();
  const size_t kNumBackendsInResolution = backends_.size() / 3;
  const size_t kNumBackendsInResolutionUpdate = backends_.size() / 3;

  ResetStub(kFallbackTimeoutMs);
  std::vector<AddressData> balancer_addresses;
  balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
  std::vector<AddressData> backend_addresses;
  for (size_t i = 0; i < kNumBackendsInResolution; ++i) {
    backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""});
  }
  SetNextResolution(balancer_addresses, backend_addresses);

  // Send non-empty serverlist only after kServerlistDelayMs.
  ScheduleResponseForBalancer(
      0,
      BalancerServiceImpl::BuildResponseForBackends(
          GetBackendPorts(kNumBackendsInResolution +
                          kNumBackendsInResolutionUpdate /* start_index */),
          {}),
      kServerlistDelayMs);

  // Wait until all the fallback backends are reachable.
  for (size_t i = 0; i < kNumBackendsInResolution; ++i) {
    WaitForBackend(i);
  }

  // The first request.
  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
  CheckRpcSendOk(kNumBackendsInResolution);
  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");

  // Fallback is used: each backend returned by the resolver should have
  // gotten one request.
  for (size_t i = 0; i < kNumBackendsInResolution; ++i) {
    EXPECT_EQ(1U, backends_[i]->service_.request_count());
  }
  for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) {
    EXPECT_EQ(0U, backends_[i]->service_.request_count());
  }

  // Push a resolution update that swaps in the second third of the backends
  // as the new fallback set.
  balancer_addresses.clear();
  balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
  backend_addresses.clear();
  for (size_t i = kNumBackendsInResolution;
       i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) {
    backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""});
  }
  SetNextResolution(balancer_addresses, backend_addresses);

  // Wait until the resolution update has been processed and all the new
  // fallback backends are reachable.
  for (size_t i = kNumBackendsInResolution;
       i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) {
    WaitForBackend(i);
  }

  // Send out the second request.
  gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
  CheckRpcSendOk(kNumBackendsInResolutionUpdate);
  gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");

  // The resolution update is used: each backend in the resolution update should
  // have gotten one request.
  for (size_t i = 0; i < kNumBackendsInResolution; ++i) {
    EXPECT_EQ(0U, backends_[i]->service_.request_count());
  }
  for (size_t i = kNumBackendsInResolution;
       i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) {
    EXPECT_EQ(1U, backends_[i]->service_.request_count());
  }
  for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate;
       i < backends_.size(); ++i) {
    EXPECT_EQ(0U, backends_[i]->service_.request_count());
  }

  // Wait until the serverlist reception has been processed and all backends
  // in the serverlist are reachable.
  for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate;
       i < backends_.size(); ++i) {
    WaitForBackend(i);
  }

  // Send out the third request.
  gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH ==========");
  CheckRpcSendOk(backends_.size() - kNumBackendsInResolution -
                 kNumBackendsInResolutionUpdate);
  gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH ==========");

  // Serverlist is used: each backend returned by the balancer should
  // have gotten one request.
  for (size_t i = 0;
       i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) {
    EXPECT_EQ(0U, backends_[i]->service_.request_count());
  }
  for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate;
       i < backends_.size(); ++i) {
    EXPECT_EQ(1U, backends_[i]->service_.request_count());
  }

  balancers_[0]->service_.NotifyDoneWithServerlists();
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());
}
+
// After startup: losing the balancer alone keeps traffic on the serverlist
// backends; losing those backends too triggers fallback; restarting the
// balancer exits fallback.
TEST_F(SingleBalancerTest,
       FallbackAfterStartup_LoseContactWithBalancerThenBackends) {
  // First two backends are fallback, last two are pointed to by balancer.
  const size_t kNumFallbackBackends = 2;
  const size_t kNumBalancerBackends = backends_.size() - kNumFallbackBackends;
  std::vector<AddressData> backend_addresses;
  for (size_t i = 0; i < kNumFallbackBackends; ++i) {
    backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""});
  }
  std::vector<AddressData> balancer_addresses;
  for (size_t i = 0; i < balancers_.size(); ++i) {
    balancer_addresses.emplace_back(AddressData{balancers_[i]->port_, ""});
  }
  SetNextResolution(balancer_addresses, backend_addresses);
  ScheduleResponseForBalancer(0,
                              BalancerServiceImpl::BuildResponseForBackends(
                                  GetBackendPorts(kNumFallbackBackends), {}),
                              0);
  // Try to connect.
  channel_->GetState(true /* try_to_connect */);
  WaitForAllBackends(1 /* num_requests_multiple_of */,
                     kNumFallbackBackends /* start_index */);
  // Stop balancer. RPCs should continue going to backends from balancer.
  balancers_[0]->Shutdown();
  CheckRpcSendOk(100 * kNumBalancerBackends);
  for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) {
    EXPECT_EQ(100UL, backends_[i]->service_.request_count());
  }
  // Stop backends from balancer. This should put us in fallback mode.
  for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) {
    ShutdownBackend(i);
  }
  WaitForAllBackends(1 /* num_requests_multiple_of */, 0 /* start_index */,
                     kNumFallbackBackends /* stop_index */);
  // Restart the backends from the balancer. We should *not* start
  // sending traffic back to them at this point (although the behavior
  // in xds may be different).
  for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) {
    StartBackend(i);
  }
  CheckRpcSendOk(100 * kNumBalancerBackends);
  for (size_t i = 0; i < kNumFallbackBackends; ++i) {
    EXPECT_EQ(100UL, backends_[i]->service_.request_count());
  }
  // Now start the balancer again. This should cause us to exit
  // fallback mode.
  balancers_[0]->Start(server_host_);
  ScheduleResponseForBalancer(0,
                              BalancerServiceImpl::BuildResponseForBackends(
                                  GetBackendPorts(kNumFallbackBackends), {}),
                              0);
  WaitForAllBackends(1 /* num_requests_multiple_of */,
                     kNumFallbackBackends /* start_index */);
}
+
// After startup: losing the serverlist backends while the balancer is still
// up makes RPCs fail; subsequently losing the balancer triggers fallback;
// restarting the balancer exits fallback.
TEST_F(SingleBalancerTest,
       FallbackAfterStartup_LoseContactWithBackendsThenBalancer) {
  // First two backends are fallback, last two are pointed to by balancer.
  const size_t kNumFallbackBackends = 2;
  const size_t kNumBalancerBackends = backends_.size() - kNumFallbackBackends;
  std::vector<AddressData> backend_addresses;
  for (size_t i = 0; i < kNumFallbackBackends; ++i) {
    backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""});
  }
  std::vector<AddressData> balancer_addresses;
  for (size_t i = 0; i < balancers_.size(); ++i) {
    balancer_addresses.emplace_back(AddressData{balancers_[i]->port_, ""});
  }
  SetNextResolution(balancer_addresses, backend_addresses);
  ScheduleResponseForBalancer(0,
                              BalancerServiceImpl::BuildResponseForBackends(
                                  GetBackendPorts(kNumFallbackBackends), {}),
                              0);
  // Try to connect.
  channel_->GetState(true /* try_to_connect */);
  WaitForAllBackends(1 /* num_requests_multiple_of */,
                     kNumFallbackBackends /* start_index */);
  // Stop backends from balancer. Since we are still in contact with
  // the balancer at this point, RPCs should be failing.
  for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) {
    ShutdownBackend(i);
  }
  CheckRpcSendFailure();
  // Stop balancer. This should put us in fallback mode.
  balancers_[0]->Shutdown();
  WaitForAllBackends(1 /* num_requests_multiple_of */, 0 /* start_index */,
                     kNumFallbackBackends /* stop_index */);
  // Restart the backends from the balancer. We should *not* start
  // sending traffic back to them at this point (although the behavior
  // in xds may be different).
  for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) {
    StartBackend(i);
  }
  CheckRpcSendOk(100 * kNumBalancerBackends);
  for (size_t i = 0; i < kNumFallbackBackends; ++i) {
    EXPECT_EQ(100UL, backends_[i]->service_.request_count());
  }
  // Now start the balancer again. This should cause us to exit
  // fallback mode.
  balancers_[0]->Start(server_host_);
  ScheduleResponseForBalancer(0,
                              BalancerServiceImpl::BuildResponseForBackends(
                                  GetBackendPorts(kNumFallbackBackends), {}),
                              0);
  WaitForAllBackends(1 /* num_requests_multiple_of */,
                     kNumFallbackBackends /* start_index */);
}
+
// If the balancer channel fails outright, the client falls back immediately
// rather than waiting out the (deliberately huge) fallback timeout.
TEST_F(SingleBalancerTest, FallbackEarlyWhenBalancerChannelFails) {
  const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor();
  ResetStub(kFallbackTimeoutMs);
  // Return an unreachable balancer and one fallback backend.
  std::vector<AddressData> balancer_addresses;
  balancer_addresses.emplace_back(
      AddressData{grpc_pick_unused_port_or_die(), ""});
  std::vector<AddressData> backend_addresses;
  backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""});
  SetNextResolution(balancer_addresses, backend_addresses);
  // Send RPC with deadline less than the fallback timeout and make sure it
  // succeeds.
  CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000,
                 /* wait_for_ready */ false);
}
+
// If the balancer call terminates without ever sending a serverlist, the
// client falls back immediately rather than waiting out the fallback timeout.
TEST_F(SingleBalancerTest, FallbackEarlyWhenBalancerCallFails) {
  const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor();
  ResetStub(kFallbackTimeoutMs);
  // Return one balancer and one fallback backend.
  std::vector<AddressData> balancer_addresses;
  balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
  std::vector<AddressData> backend_addresses;
  backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""});
  SetNextResolution(balancer_addresses, backend_addresses);
  // Balancer drops call without sending a serverlist.
  balancers_[0]->service_.NotifyDoneWithServerlists();
  // Send RPC with deadline less than the fallback timeout and make sure it
  // succeeds.
  CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000,
                 /* wait_for_ready */ false);
}
+
// The balancer can explicitly order fallback before sending any serverlist;
// the client must obey without waiting for the fallback timeout.
TEST_F(SingleBalancerTest, FallbackControlledByBalancer_BeforeFirstServerlist) {
  const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor();
  ResetStub(kFallbackTimeoutMs);
  // Return one balancer and one fallback backend.
  std::vector<AddressData> balancer_addresses;
  balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
  std::vector<AddressData> backend_addresses;
  backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""});
  SetNextResolution(balancer_addresses, backend_addresses);
  // Balancer explicitly tells client to fallback.
  LoadBalanceResponse resp;
  resp.mutable_fallback_response();
  ScheduleResponseForBalancer(0, resp, 0);
  // Send RPC with deadline less than the fallback timeout and make sure it
  // succeeds.
  CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000,
                 /* wait_for_ready */ false);
}
+
// The balancer can order fallback after a serverlist was already in use, and
// then revoke it: traffic moves serverlist -> fallback -> serverlist.
TEST_F(SingleBalancerTest, FallbackControlledByBalancer_AfterFirstServerlist) {
  // Return one balancer and one fallback backend (backend 0).
  std::vector<AddressData> balancer_addresses;
  balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
  std::vector<AddressData> backend_addresses;
  backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""});
  SetNextResolution(balancer_addresses, backend_addresses);
  // Balancer initially sends serverlist, then tells client to fall back,
  // then sends the serverlist again.
  // The serverlist points to backend 1.
  LoadBalanceResponse serverlist_resp =
      BalancerServiceImpl::BuildResponseForBackends({backends_[1]->port_}, {});
  LoadBalanceResponse fallback_resp;
  fallback_resp.mutable_fallback_response();
  ScheduleResponseForBalancer(0, serverlist_resp, 0);
  ScheduleResponseForBalancer(0, fallback_resp, 100);
  ScheduleResponseForBalancer(0, serverlist_resp, 100);
  // Requests initially go to backend 1, then go to backend 0 in
  // fallback mode, then go back to backend 1 when we exit fallback.
  WaitForBackend(1);
  WaitForBackend(0);
  WaitForBackend(1);
}
+
// RPCs fail while all backends are down and succeed again once they restart,
// without re-contacting the balancer.
TEST_F(SingleBalancerTest, BackendsRestart) {
  SetNextResolutionAllBalancers();
  const size_t kNumRpcsPerAddress = 100;
  ScheduleResponseForBalancer(
      0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
      0);
  // Make sure that trying to connect works without a call.
  channel_->GetState(true /* try_to_connect */);
  // Send kNumRpcsPerAddress RPCs per server.
  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
  // Stop backends. RPCs should fail.
  ShutdownAllBackends();
  CheckRpcSendFailure();
  // Restart backends. RPCs should start succeeding again.
  StartAllBackends();
  CheckRpcSendOk(1 /* times */, 2000 /* timeout_ms */,
                 true /* wait_for_ready */);
  // The balancer got a single request.
  EXPECT_EQ(1U, balancers_[0]->service_.request_count());
  // and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->service_.response_count());
}
+
+TEST_F(SingleBalancerTest, ServiceNameFromLbPolicyConfig) {
+ constexpr char kServiceConfigWithTarget[] =
+ "{\n"
+ " \"loadBalancingConfig\":[\n"
+ " { \"grpclb\":{\n"
+ " \"serviceName\":\"test_service\"\n"
+ " }}\n"
+ " ]\n"
+ "}";
+
+ SetNextResolutionAllBalancers(kServiceConfigWithTarget);
+ ScheduleResponseForBalancer(
+ 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
+ 0);
+ // Make sure that trying to connect works without a call.
+ channel_->GetState(true /* try_to_connect */);
+ // We need to wait for all backends to come online.
+ WaitForAllBackends();
+ EXPECT_EQ(balancers_[0]->service_.service_names().back(), "test_service");
+}
+
+class UpdatesTest : public GrpclbEnd2endTest {
+ public:
+ UpdatesTest() : GrpclbEnd2endTest(4, 3, 0) {}
+};
+
+TEST_F(UpdatesTest, UpdateBalancersButKeepUsingOriginalBalancer) {
+ SetNextResolutionAllBalancers();
+ const std::vector<int> first_backend{GetBackendPorts()[0]};
+ const std::vector<int> second_backend{GetBackendPorts()[1]};
+ ScheduleResponseForBalancer(
+ 0, BalancerServiceImpl::BuildResponseForBackends(first_backend, {}), 0);
+ ScheduleResponseForBalancer(
+ 1, BalancerServiceImpl::BuildResponseForBackends(second_backend, {}), 0);
+
+ // Wait until the first backend is ready.
+ WaitForBackend(0);
+
+ // Send 10 requests.
+ gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+
+ // All 10 requests should have gone to the first backend.
+ EXPECT_EQ(10U, backends_[0]->service_.request_count());
+
+ // Balancer 0 got a single request.
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ // and sent a single response.
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.response_count());
+
+ std::vector<AddressData> addresses;
+ addresses.emplace_back(AddressData{balancers_[1]->port_, ""});
+ gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 ==========");
+ SetNextResolution(addresses);
+ gpr_log(GPR_INFO, "========= UPDATE 1 DONE ==========");
+
+ EXPECT_EQ(0U, backends_[1]->service_.request_count());
+ gpr_timespec deadline = gpr_time_add(
+ gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN));
+ // Send 10 seconds worth of RPCs
+ do {
+ CheckRpcSendOk();
+ } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+ // The current LB call is still working, so grpclb continued using it to the
+ // first balancer, which doesn't assign the second backend.
+ EXPECT_EQ(0U, backends_[1]->service_.request_count());
+
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.response_count());
+}
+
+// Send an update with the same set of LBs as the one in SetUp() in order to
+// verify that the LB channel inside grpclb keeps the initial connection (which
+// by definition is also present in the update).
+TEST_F(UpdatesTest, UpdateBalancersRepeated) {
+ SetNextResolutionAllBalancers();
+ const std::vector<int> first_backend{GetBackendPorts()[0]};
+ const std::vector<int> second_backend{GetBackendPorts()[0]};
+
+ ScheduleResponseForBalancer(
+ 0, BalancerServiceImpl::BuildResponseForBackends(first_backend, {}), 0);
+ ScheduleResponseForBalancer(
+ 1, BalancerServiceImpl::BuildResponseForBackends(second_backend, {}), 0);
+
+ // Wait until the first backend is ready.
+ WaitForBackend(0);
+
+ // Send 10 requests.
+ gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+
+ // All 10 requests should have gone to the first backend.
+ EXPECT_EQ(10U, backends_[0]->service_.request_count());
+
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+ // Balancer 0 got a single request.
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ // and sent a single response.
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.response_count());
+
+ std::vector<AddressData> addresses;
+ addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
+ addresses.emplace_back(AddressData{balancers_[1]->port_, ""});
+ addresses.emplace_back(AddressData{balancers_[2]->port_, ""});
+ gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 ==========");
+ SetNextResolution(addresses);
+ gpr_log(GPR_INFO, "========= UPDATE 1 DONE ==========");
+
+ EXPECT_EQ(0U, backends_[1]->service_.request_count());
+ gpr_timespec deadline = gpr_time_add(
+ gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN));
+ // Send 10 seconds worth of RPCs
+ do {
+ CheckRpcSendOk();
+ } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+ // grpclb continued using the original LB call to the first balancer, which
+ // doesn't assign the second backend.
+ EXPECT_EQ(0U, backends_[1]->service_.request_count());
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+
+ addresses.clear();
+ addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
+ addresses.emplace_back(AddressData{balancers_[1]->port_, ""});
+ gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 2 ==========");
+ SetNextResolution(addresses);
+ gpr_log(GPR_INFO, "========= UPDATE 2 DONE ==========");
+
+ EXPECT_EQ(0U, backends_[1]->service_.request_count());
+ deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_millis(10000, GPR_TIMESPAN));
+ // Send 10 seconds worth of RPCs
+ do {
+ CheckRpcSendOk();
+ } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+ // grpclb continued using the original LB call to the first balancer, which
+ // doesn't assign the second backend.
+ EXPECT_EQ(0U, backends_[1]->service_.request_count());
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+}
+
+TEST_F(UpdatesTest, UpdateBalancersDeadUpdate) {
+ std::vector<AddressData> addresses;
+ addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
+ SetNextResolution(addresses);
+ const std::vector<int> first_backend{GetBackendPorts()[0]};
+ const std::vector<int> second_backend{GetBackendPorts()[1]};
+
+ ScheduleResponseForBalancer(
+ 0, BalancerServiceImpl::BuildResponseForBackends(first_backend, {}), 0);
+ ScheduleResponseForBalancer(
+ 1, BalancerServiceImpl::BuildResponseForBackends(second_backend, {}), 0);
+
+ // Start servers and send 10 RPCs per server.
+ gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+ // All 10 requests should have gone to the first backend.
+ EXPECT_EQ(10U, backends_[0]->service_.request_count());
+
+ // Kill balancer 0
+ gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************");
+ balancers_[0]->Shutdown();
+ gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************");
+
+ // This is serviced by the existing RR policy
+ gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");
+ // All 10 requests should again have gone to the first backend.
+ EXPECT_EQ(20U, backends_[0]->service_.request_count());
+ EXPECT_EQ(0U, backends_[1]->service_.request_count());
+
+ // Balancer 0 got a single request.
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ // and sent a single response.
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.response_count());
+
+ addresses.clear();
+ addresses.emplace_back(AddressData{balancers_[1]->port_, ""});
+ gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 ==========");
+ SetNextResolution(addresses);
+ gpr_log(GPR_INFO, "========= UPDATE 1 DONE ==========");
+
+ // Wait until update has been processed, as signaled by the second backend
+ // receiving a request. In the meantime, the client continues to be serviced
+ // (by the first backend) without interruption.
+ EXPECT_EQ(0U, backends_[1]->service_.request_count());
+ WaitForBackend(1);
+
+ // This is serviced by the updated RR policy
+ backends_[1]->service_.ResetCounters();
+ gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH ==========");
+ // All 10 requests should have gone to the second backend.
+ EXPECT_EQ(10U, backends_[1]->service_.request_count());
+
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+ // The second balancer, published as part of the first update, may end up
+ // getting two requests (that is, 1 <= #req <= 2) if the LB call retry timer
+ // firing races with the arrival of the update containing the second
+ // balancer.
+ EXPECT_GE(balancers_[1]->service_.request_count(), 1U);
+ EXPECT_GE(balancers_[1]->service_.response_count(), 1U);
+ EXPECT_LE(balancers_[1]->service_.request_count(), 2U);
+ EXPECT_LE(balancers_[1]->service_.response_count(), 2U);
+ EXPECT_EQ(0U, balancers_[2]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.response_count());
+}
+
+TEST_F(UpdatesTest, ReresolveDeadBackend) {
+ ResetStub(500);
+ // The first resolution contains the addresses of a balancer that never
+ // responds, and a fallback backend.
+ std::vector<AddressData> balancer_addresses;
+ balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
+ std::vector<AddressData> backend_addresses;
+ backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""});
+ SetNextResolution(balancer_addresses, backend_addresses);
+ // Ask channel to connect to trigger resolver creation.
+ channel_->GetState(true);
+ // The re-resolution result will contain the addresses of the same balancer
+ // and a new fallback backend.
+ balancer_addresses.clear();
+ balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
+ backend_addresses.clear();
+ backend_addresses.emplace_back(AddressData{backends_[1]->port_, ""});
+ SetNextReresolutionResponse(balancer_addresses, backend_addresses);
+
+ // Start servers and send 10 RPCs per server.
+ gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+ // All 10 requests should have gone to the fallback backend.
+ EXPECT_EQ(10U, backends_[0]->service_.request_count());
+
+ // Kill backend 0.
+ gpr_log(GPR_INFO, "********** ABOUT TO KILL BACKEND 0 *************");
+ backends_[0]->Shutdown();
+ gpr_log(GPR_INFO, "********** KILLED BACKEND 0 *************");
+
+ // Wait until re-resolution has finished, as signaled by the second backend
+ // receiving a request.
+ WaitForBackend(1);
+
+ gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");
+ // All 10 requests should have gone to the second backend.
+ EXPECT_EQ(10U, backends_[1]->service_.request_count());
+
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+ balancers_[1]->service_.NotifyDoneWithServerlists();
+ balancers_[2]->service_.NotifyDoneWithServerlists();
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[0]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.response_count());
+}
+
+// TODO(juanlishen): Should be removed when the first response is always the
+// initial response. Currently, if client load reporting is not enabled, the
+// balancer doesn't send initial response. When the backend shuts down, an
+// unexpected re-resolution will happen. This test configuration is a workaround
+// for test ReresolveDeadBalancer.
+class UpdatesWithClientLoadReportingTest : public GrpclbEnd2endTest {
+ public:
+ UpdatesWithClientLoadReportingTest() : GrpclbEnd2endTest(4, 3, 2) {}
+};
+
+TEST_F(UpdatesWithClientLoadReportingTest, ReresolveDeadBalancer) {
+ const std::vector<int> first_backend{GetBackendPorts()[0]};
+ const std::vector<int> second_backend{GetBackendPorts()[1]};
+ ScheduleResponseForBalancer(
+ 0, BalancerServiceImpl::BuildResponseForBackends(first_backend, {}), 0);
+ ScheduleResponseForBalancer(
+ 1, BalancerServiceImpl::BuildResponseForBackends(second_backend, {}), 0);
+
+ // Ask channel to connect to trigger resolver creation.
+ channel_->GetState(true);
+ std::vector<AddressData> addresses;
+ addresses.emplace_back(AddressData{balancers_[0]->port_, ""});
+ SetNextResolution(addresses);
+ addresses.clear();
+ addresses.emplace_back(AddressData{balancers_[1]->port_, ""});
+ SetNextReresolutionResponse(addresses);
+
+ // Start servers and send 10 RPCs per server.
+ gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+ // All 10 requests should have gone to the first backend.
+ EXPECT_EQ(10U, backends_[0]->service_.request_count());
+
+ // Kill backend 0.
+ gpr_log(GPR_INFO, "********** ABOUT TO KILL BACKEND 0 *************");
+ backends_[0]->Shutdown();
+ gpr_log(GPR_INFO, "********** KILLED BACKEND 0 *************");
+
+ CheckRpcSendFailure();
+
+ // Balancer 0 got a single request.
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ // and sent a single response.
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[1]->service_.response_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.response_count());
+
+ // Kill balancer 0.
+ gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************");
+ balancers_[0]->Shutdown();
+ gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************");
+
+ // Wait until re-resolution has finished, as signaled by the second backend
+ // receiving a request.
+ WaitForBackend(1);
+
+ // This is serviced by the new serverlist.
+ gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
+ CheckRpcSendOk(10);
+ gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");
+ // All 10 requests should have gone to the second backend.
+ EXPECT_EQ(10U, backends_[1]->service_.request_count());
+
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+ // After balancer 0 is killed, we restart an LB call immediately (because we
+ // disconnect to a previously connected balancer). Although we will cancel
+ // this call when the re-resolution update is done and another LB call restart
+ // is needed, this old call may still succeed reaching the LB server if
+ // re-resolution is slow. So balancer 1 may have received 2 requests and sent
+ // 2 responses.
+ EXPECT_GE(balancers_[1]->service_.request_count(), 1U);
+ EXPECT_GE(balancers_[1]->service_.response_count(), 1U);
+ EXPECT_LE(balancers_[1]->service_.request_count(), 2U);
+ EXPECT_LE(balancers_[1]->service_.response_count(), 2U);
+ EXPECT_EQ(0U, balancers_[2]->service_.request_count());
+ EXPECT_EQ(0U, balancers_[2]->service_.response_count());
+}
+
+TEST_F(SingleBalancerTest, Drop) {
+ SetNextResolutionAllBalancers();
+ const size_t kNumRpcsPerAddress = 100;
+ const int num_of_drop_by_rate_limiting_addresses = 1;
+ const int num_of_drop_by_load_balancing_addresses = 2;
+ const int num_of_drop_addresses = num_of_drop_by_rate_limiting_addresses +
+ num_of_drop_by_load_balancing_addresses;
+ const int num_total_addresses = num_backends_ + num_of_drop_addresses;
+ ScheduleResponseForBalancer(
+ 0,
+ BalancerServiceImpl::BuildResponseForBackends(
+ GetBackendPorts(),
+ {{"rate_limiting", num_of_drop_by_rate_limiting_addresses},
+ {"load_balancing", num_of_drop_by_load_balancing_addresses}}),
+ 0);
+ // Wait until all backends are ready.
+ WaitForAllBackends();
+ // Send kNumRpcsPerAddress RPCs for each server and drop address.
+ size_t num_drops = 0;
+ for (size_t i = 0; i < kNumRpcsPerAddress * num_total_addresses; ++i) {
+ EchoResponse response;
+ const Status status = SendRpc(&response);
+ if (!status.ok() &&
+ status.error_message() == "Call dropped by load balancing policy") {
+ ++num_drops;
+ } else {
+ EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+ << " message=" << status.error_message();
+ EXPECT_EQ(response.message(), kRequestMessage_);
+ }
+ }
+ EXPECT_EQ(kNumRpcsPerAddress * num_of_drop_addresses, num_drops);
+ // Each backend should have gotten 100 requests.
+ for (size_t i = 0; i < backends_.size(); ++i) {
+ EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count());
+ }
+ // The balancer got a single request.
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ // and sent a single response.
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+}
+
+TEST_F(SingleBalancerTest, DropAllFirst) {
+ SetNextResolutionAllBalancers();
+ // All registered addresses are marked as "drop".
+ const int num_of_drop_by_rate_limiting_addresses = 1;
+ const int num_of_drop_by_load_balancing_addresses = 1;
+ ScheduleResponseForBalancer(
+ 0,
+ BalancerServiceImpl::BuildResponseForBackends(
+ {}, {{"rate_limiting", num_of_drop_by_rate_limiting_addresses},
+ {"load_balancing", num_of_drop_by_load_balancing_addresses}}),
+ 0);
+ const Status status = SendRpc(nullptr, 1000, true);
+ EXPECT_FALSE(status.ok());
+ EXPECT_EQ(status.error_message(), "Call dropped by load balancing policy");
+}
+
+TEST_F(SingleBalancerTest, DropAll) {
+ SetNextResolutionAllBalancers();
+ ScheduleResponseForBalancer(
+ 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
+ 0);
+ const int num_of_drop_by_rate_limiting_addresses = 1;
+ const int num_of_drop_by_load_balancing_addresses = 1;
+ ScheduleResponseForBalancer(
+ 0,
+ BalancerServiceImpl::BuildResponseForBackends(
+ {}, {{"rate_limiting", num_of_drop_by_rate_limiting_addresses},
+ {"load_balancing", num_of_drop_by_load_balancing_addresses}}),
+ 1000);
+
+ // First call succeeds.
+ CheckRpcSendOk();
+ // But eventually, the update with only dropped servers is processed and calls
+ // fail.
+ Status status;
+ do {
+ status = SendRpc(nullptr, 1000, true);
+ } while (status.ok());
+ EXPECT_FALSE(status.ok());
+ EXPECT_EQ(status.error_message(), "Call dropped by load balancing policy");
+}
+
+class SingleBalancerWithClientLoadReportingTest : public GrpclbEnd2endTest {
+ public:
+ SingleBalancerWithClientLoadReportingTest() : GrpclbEnd2endTest(4, 1, 3) {}
+};
+
+TEST_F(SingleBalancerWithClientLoadReportingTest, Vanilla) {
+ SetNextResolutionAllBalancers();
+ const size_t kNumRpcsPerAddress = 100;
+ ScheduleResponseForBalancer(
+ 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}),
+ 0);
+ // Wait until all backends are ready.
+ int num_ok = 0;
+ int num_failure = 0;
+ int num_drops = 0;
+ std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+ // Send kNumRpcsPerAddress RPCs per server.
+ CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
+ // Each backend should have gotten 100 requests.
+ for (size_t i = 0; i < backends_.size(); ++i) {
+ EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count());
+ }
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+ // The balancer got a single request.
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ // and sent a single response.
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+
+ ClientStats client_stats;
+ do {
+ client_stats += WaitForLoadReports();
+ } while (client_stats.num_calls_finished !=
+ kNumRpcsPerAddress * num_backends_ + num_ok);
+ EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok,
+ client_stats.num_calls_started);
+ EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok,
+ client_stats.num_calls_finished);
+ EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send);
+ EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + (num_ok + num_drops),
+ client_stats.num_calls_finished_known_received);
+ EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre());
+}
+
+TEST_F(SingleBalancerWithClientLoadReportingTest, BalancerRestart) {
+ SetNextResolutionAllBalancers();
+ const size_t kNumBackendsFirstPass = 2;
+ const size_t kNumBackendsSecondPass =
+ backends_.size() - kNumBackendsFirstPass;
+ // Balancer returns backends starting at index 1.
+ ScheduleResponseForBalancer(
+ 0,
+ BalancerServiceImpl::BuildResponseForBackends(
+ GetBackendPorts(0, kNumBackendsFirstPass), {}),
+ 0);
+ // Wait until all backends returned by the balancer are ready.
+ int num_ok = 0;
+ int num_failure = 0;
+ int num_drops = 0;
+ std::tie(num_ok, num_failure, num_drops) =
+ WaitForAllBackends(/* num_requests_multiple_of */ 1, /* start_index */ 0,
+ /* stop_index */ kNumBackendsFirstPass);
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+ ClientStats client_stats = WaitForLoadReports();
+ EXPECT_EQ(static_cast<size_t>(num_ok), client_stats.num_calls_started);
+ EXPECT_EQ(static_cast<size_t>(num_ok), client_stats.num_calls_finished);
+ EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send);
+ EXPECT_EQ(static_cast<size_t>(num_ok),
+ client_stats.num_calls_finished_known_received);
+ EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre());
+ // Shut down the balancer.
+ balancers_[0]->Shutdown();
+ // Send 10 more requests per backend. This will continue using the
+ // last serverlist we received from the balancer before it was shut down.
+ ResetBackendCounters();
+ CheckRpcSendOk(kNumBackendsFirstPass);
+ // Each backend should have gotten 1 request.
+ for (size_t i = 0; i < kNumBackendsFirstPass; ++i) {
+ EXPECT_EQ(1UL, backends_[i]->service_.request_count());
+ }
+ // Now restart the balancer, this time pointing to all backends.
+ balancers_[0]->Start(server_host_);
+ ScheduleResponseForBalancer(0,
+ BalancerServiceImpl::BuildResponseForBackends(
+ GetBackendPorts(kNumBackendsFirstPass), {}),
+ 0);
+ // Wait for queries to start going to one of the new backends.
+ // This tells us that we're now using the new serverlist.
+ do {
+ CheckRpcSendOk();
+ } while (backends_[2]->service_.request_count() == 0 &&
+ backends_[3]->service_.request_count() == 0);
+ // Send one RPC per backend.
+ CheckRpcSendOk(kNumBackendsSecondPass);
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+ // Check client stats.
+ client_stats = WaitForLoadReports();
+ EXPECT_EQ(kNumBackendsSecondPass + 1, client_stats.num_calls_started);
+ EXPECT_EQ(kNumBackendsSecondPass + 1, client_stats.num_calls_finished);
+ EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send);
+ EXPECT_EQ(kNumBackendsSecondPass + 1,
+ client_stats.num_calls_finished_known_received);
+ EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre());
+}
+
+TEST_F(SingleBalancerWithClientLoadReportingTest, Drop) {
+ SetNextResolutionAllBalancers();
+ const size_t kNumRpcsPerAddress = 3;
+ const int num_of_drop_by_rate_limiting_addresses = 2;
+ const int num_of_drop_by_load_balancing_addresses = 1;
+ const int num_of_drop_addresses = num_of_drop_by_rate_limiting_addresses +
+ num_of_drop_by_load_balancing_addresses;
+ const int num_total_addresses = num_backends_ + num_of_drop_addresses;
+ ScheduleResponseForBalancer(
+ 0,
+ BalancerServiceImpl::BuildResponseForBackends(
+ GetBackendPorts(),
+ {{"rate_limiting", num_of_drop_by_rate_limiting_addresses},
+ {"load_balancing", num_of_drop_by_load_balancing_addresses}}),
+ 0);
+ // Wait until all backends are ready.
+ int num_warmup_ok = 0;
+ int num_warmup_failure = 0;
+ int num_warmup_drops = 0;
+ std::tie(num_warmup_ok, num_warmup_failure, num_warmup_drops) =
+ WaitForAllBackends(num_total_addresses /* num_requests_multiple_of */);
+ const int num_total_warmup_requests =
+ num_warmup_ok + num_warmup_failure + num_warmup_drops;
+ size_t num_drops = 0;
+ for (size_t i = 0; i < kNumRpcsPerAddress * num_total_addresses; ++i) {
+ EchoResponse response;
+ const Status status = SendRpc(&response);
+ if (!status.ok() &&
+ status.error_message() == "Call dropped by load balancing policy") {
+ ++num_drops;
+ } else {
+ EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+ << " message=" << status.error_message();
+ EXPECT_EQ(response.message(), kRequestMessage_);
+ }
+ }
+ EXPECT_EQ(kNumRpcsPerAddress * num_of_drop_addresses, num_drops);
+ // Each backend should have gotten 100 requests.
+ for (size_t i = 0; i < backends_.size(); ++i) {
+ EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count());
+ }
+ balancers_[0]->service_.NotifyDoneWithServerlists();
+ // The balancer got a single request.
+ EXPECT_EQ(1U, balancers_[0]->service_.request_count());
+ // and sent a single response.
+ EXPECT_EQ(1U, balancers_[0]->service_.response_count());
+
+ const ClientStats client_stats = WaitForLoadReports();
+ EXPECT_EQ(
+ kNumRpcsPerAddress * num_total_addresses + num_total_warmup_requests,
+ client_stats.num_calls_started);
+ EXPECT_EQ(
+ kNumRpcsPerAddress * num_total_addresses + num_total_warmup_requests,
+ client_stats.num_calls_finished);
+ EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send);
+ EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_warmup_ok,
+ client_stats.num_calls_finished_known_received);
+ // The number of warmup request is a multiple of the number of addresses.
+ // Therefore, all addresses in the scheduled balancer response are hit the
+ // same number of times.
+ const int num_times_drop_addresses_hit =
+ num_warmup_drops / num_of_drop_addresses;
+ EXPECT_THAT(
+ client_stats.drop_token_counts,
+ ::testing::ElementsAre(
+ ::testing::Pair("load_balancing",
+ (kNumRpcsPerAddress + num_times_drop_addresses_hit)),
+ ::testing::Pair(
+ "rate_limiting",
+ (kNumRpcsPerAddress + num_times_drop_addresses_hit) * 2)));
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ const auto result = RUN_ALL_TESTS();
+ return result;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/health/ya.make b/contrib/libs/grpc/test/cpp/end2end/health/ya.make
new file mode 100644
index 0000000000..7330129b73
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/health/ya.make
@@ -0,0 +1,33 @@
+GTEST_UGLY()
+
+OWNER(
+ dvshkurko
+ g:ymake
+)
+
+ADDINCL(
+ ${ARCADIA_BUILD_ROOT}/contrib/libs/grpc
+ ${ARCADIA_ROOT}/contrib/libs/grpc
+)
+
+PEERDIR(
+ contrib/libs/grpc/src/proto/grpc/health/v1
+ contrib/libs/grpc/src/proto/grpc/core
+ contrib/libs/grpc/src/proto/grpc/testing
+ contrib/libs/grpc/src/proto/grpc/testing/duplicate
+ contrib/libs/grpc/test/core/util
+ contrib/libs/grpc/test/cpp/end2end
+ contrib/libs/grpc/test/cpp/util
+)
+
+NO_COMPILER_WARNINGS()
+
+SRCDIR(
+ contrib/libs/grpc/test/cpp/end2end
+)
+
+SRCS(
+ health_service_end2end_test.cc
+)
+
+END()
diff --git a/contrib/libs/grpc/test/cpp/end2end/health_service_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/health_service_end2end_test.cc
new file mode 100644
index 0000000000..516b3a4c81
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/health_service_end2end_test.cc
@@ -0,0 +1,374 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/ext/health_check_service_server_builder_option.h>
+#include <grpcpp/health_check_service_interface.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/health/v1/health.grpc.pb.h"
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_health_check_service_impl.h"
+#include "test/cpp/end2end/test_service_impl.h"
+
+#include <gtest/gtest.h>
+
+using grpc::health::v1::Health;
+using grpc::health::v1::HealthCheckRequest;
+using grpc::health::v1::HealthCheckResponse;
+
+namespace grpc {
+namespace testing {
+namespace {
+
+// A custom implementation of the health checking service interface. This is
+// used to test that it prevents the server from creating a default service and
+// also serves as an example of how to override the default service.
+class CustomHealthCheckService : public HealthCheckServiceInterface {
+ public:
+  // Wraps a non-owning pointer to the sync service impl and marks the
+  // default service ("") as SERVING up front.
+  explicit CustomHealthCheckService(HealthCheckServiceImpl* impl)
+      : impl_(impl) {
+    impl_->SetStatus("", HealthCheckResponse::SERVING);
+  }
+  // Forwards a per-service status change to the underlying impl.
+  void SetServingStatus(const TString& service_name,
+                        bool serving) override {
+    impl_->SetStatus(service_name, serving ? HealthCheckResponse::SERVING
+                                           : HealthCheckResponse::NOT_SERVING);
+  }
+
+  // Forwards a status change that applies to all registered services.
+  void SetServingStatus(bool serving) override {
+    impl_->SetAll(serving ? HealthCheckResponse::SERVING
+                          : HealthCheckResponse::NOT_SERVING);
+  }
+
+  // Delegates shutdown; after this, status changes are ignored by the impl.
+  void Shutdown() override { impl_->Shutdown(); }
+
+ private:
+  HealthCheckServiceImpl* impl_; // not owned
+};
+
+// End-to-end fixture for the gRPC health-checking service.  Exercises the
+// default implementation, an explicit override, and post-Shutdown behavior,
+// through both the unary Check RPC and the streaming Watch RPC.
+class HealthServiceEnd2endTest : public ::testing::Test {
+ protected:
+  HealthServiceEnd2endTest() {}
+
+  // Builds and starts the server under test.
+  //   register_sync_test_service: also register the sync echo test service.
+  //   add_async_cq: attach a server completion queue (drained in TearDown).
+  //   explicit_health_service: install `service` through a
+  //       HealthCheckServiceServerBuilderOption instead of the default.
+  //   service: the override implementation; null disables health checking.
+  void SetUpServer(bool register_sync_test_service, bool add_async_cq,
+                   bool explicit_health_service,
+                   std::unique_ptr<HealthCheckServiceInterface> service) {
+    // NOTE(review): fixed port instead of grpc_pick_unused_port_or_die();
+    // presumably deliberate for this build environment -- confirm, since a
+    // hard-coded port can collide when tests run concurrently.
+    int port = 5001; // grpc_pick_unused_port_or_die();
+    server_address_ << "localhost:" << port;
+
+    // The sync HealthCheckServiceImpl backing the override is only
+    // registered when a non-null explicit service was supplied.
+    bool register_sync_health_service_impl =
+        explicit_health_service && service != nullptr;
+
+    // Setup server
+    ServerBuilder builder;
+    if (explicit_health_service) {
+      std::unique_ptr<ServerBuilderOption> option(
+          new HealthCheckServiceServerBuilderOption(std::move(service)));
+      builder.SetOption(std::move(option));
+    }
+    builder.AddListeningPort(server_address_.str(),
+                             grpc::InsecureServerCredentials());
+    if (register_sync_test_service) {
+      // Register a sync service.
+      builder.RegisterService(&echo_test_service_);
+    }
+    if (register_sync_health_service_impl) {
+      builder.RegisterService(&health_check_service_impl_);
+    }
+    if (add_async_cq) {
+      cq_ = builder.AddCompletionQueue();
+    }
+    server_ = builder.BuildAndStart();
+  }
+
+  // Shuts the server down first so pending calls complete, then shuts the
+  // async cq (if any) and joins the cq thread (if one was started).
+  void TearDown() override {
+    if (server_) {
+      server_->Shutdown();
+      if (cq_ != nullptr) {
+        cq_->Shutdown();
+      }
+      // NOTE(review): cq_thread_ is never started in this file; the join is
+      // defensive in case a subclass or future test launches it.
+      if (cq_thread_.joinable()) {
+        cq_thread_.join();
+      }
+    }
+  }
+
+  // (Re)creates the health stub over an insecure channel to the server.
+  void ResetStubs() {
+    std::shared_ptr<Channel> channel = grpc::CreateChannel(
+        server_address_.str(), InsecureChannelCredentials());
+    hc_stub_ = grpc::health::v1::Health::NewStub(channel);
+  }
+
+  // When the expected_status is NOT OK, we do not care about the response.
+  void SendHealthCheckRpc(const TString& service_name,
+                          const Status& expected_status) {
+    EXPECT_FALSE(expected_status.ok());
+    SendHealthCheckRpc(service_name, expected_status,
+                       HealthCheckResponse::UNKNOWN);
+  }
+
+  // Issues one unary Check RPC and asserts the status code; the serving
+  // status is only checked when the RPC succeeded.
+  void SendHealthCheckRpc(
+      const TString& service_name, const Status& expected_status,
+      HealthCheckResponse::ServingStatus expected_serving_status) {
+    HealthCheckRequest request;
+    request.set_service(service_name);
+    HealthCheckResponse response;
+    ClientContext context;
+    Status s = hc_stub_->Check(&context, request, &response);
+    EXPECT_EQ(expected_status.error_code(), s.error_code());
+    if (s.ok()) {
+      EXPECT_EQ(expected_serving_status, response.status());
+    }
+  }
+
+  // Verifies unary Check behavior for serving/not-serving/unregistered
+  // services, then flips everything to NOT_SERVING and re-checks.
+  void VerifyHealthCheckService() {
+    HealthCheckServiceInterface* service = server_->GetHealthCheckService();
+    EXPECT_TRUE(service != nullptr);
+    const TString kHealthyService("healthy_service");
+    const TString kUnhealthyService("unhealthy_service");
+    const TString kNotRegisteredService("not_registered");
+    service->SetServingStatus(kHealthyService, true);
+    service->SetServingStatus(kUnhealthyService, false);
+
+    ResetStubs();
+
+    SendHealthCheckRpc("", Status::OK, HealthCheckResponse::SERVING);
+    SendHealthCheckRpc(kHealthyService, Status::OK,
+                       HealthCheckResponse::SERVING);
+    SendHealthCheckRpc(kUnhealthyService, Status::OK,
+                       HealthCheckResponse::NOT_SERVING);
+    SendHealthCheckRpc(kNotRegisteredService,
+                       Status(StatusCode::NOT_FOUND, ""));
+
+    // Bulk status change: every registered service becomes NOT_SERVING.
+    service->SetServingStatus(false);
+    SendHealthCheckRpc("", Status::OK, HealthCheckResponse::NOT_SERVING);
+    SendHealthCheckRpc(kHealthyService, Status::OK,
+                       HealthCheckResponse::NOT_SERVING);
+    SendHealthCheckRpc(kUnhealthyService, Status::OK,
+                       HealthCheckResponse::NOT_SERVING);
+    SendHealthCheckRpc(kNotRegisteredService,
+                       Status(StatusCode::NOT_FOUND, ""));
+  }
+
+  // Verifies the streaming Watch RPC: an unknown service starts as
+  // SERVICE_UNKNOWN and every subsequent status change produces an update.
+  void VerifyHealthCheckServiceStreaming() {
+    const TString kServiceName("service_name");
+    HealthCheckServiceInterface* service = server_->GetHealthCheckService();
+    // Start Watch for service.
+    ClientContext context;
+    HealthCheckRequest request;
+    request.set_service(kServiceName);
+    std::unique_ptr<::grpc::ClientReaderInterface<HealthCheckResponse>> reader =
+        hc_stub_->Watch(&context, request);
+    // Initial response will be SERVICE_UNKNOWN.
+    HealthCheckResponse response;
+    EXPECT_TRUE(reader->Read(&response));
+    EXPECT_EQ(response.SERVICE_UNKNOWN, response.status());
+    response.Clear();
+    // Now set service to NOT_SERVING and make sure we get an update.
+    service->SetServingStatus(kServiceName, false);
+    EXPECT_TRUE(reader->Read(&response));
+    EXPECT_EQ(response.NOT_SERVING, response.status());
+    response.Clear();
+    // Now set service to SERVING and make sure we get another update.
+    service->SetServingStatus(kServiceName, true);
+    EXPECT_TRUE(reader->Read(&response));
+    EXPECT_EQ(response.SERVING, response.status());
+    // Finish call.
+    context.TryCancel();
+  }
+
+  // Verify that after HealthCheckServiceInterface::Shutdown is called
+  // 1. unary client will see NOT_SERVING.
+  // 2. unary client still sees NOT_SERVING after a SetServing(true) is called.
+  // 3. streaming (Watch) client will see an update.
+  // 4. setting a new service to serving after shutdown will add the service
+  // name but return NOT_SERVING to client.
+  // This has to be called last.
+  void VerifyHealthCheckServiceShutdown() {
+    HealthCheckServiceInterface* service = server_->GetHealthCheckService();
+    EXPECT_TRUE(service != nullptr);
+    const TString kHealthyService("healthy_service");
+    const TString kUnhealthyService("unhealthy_service");
+    const TString kNotRegisteredService("not_registered");
+    const TString kNewService("add_after_shutdown");
+    service->SetServingStatus(kHealthyService, true);
+    service->SetServingStatus(kUnhealthyService, false);
+
+    ResetStubs();
+
+    // Start Watch for service.
+    ClientContext context;
+    HealthCheckRequest request;
+    request.set_service(kHealthyService);
+    std::unique_ptr<::grpc::ClientReaderInterface<HealthCheckResponse>> reader =
+        hc_stub_->Watch(&context, request);
+
+    HealthCheckResponse response;
+    EXPECT_TRUE(reader->Read(&response));
+    EXPECT_EQ(response.SERVING, response.status());
+
+    SendHealthCheckRpc("", Status::OK, HealthCheckResponse::SERVING);
+    SendHealthCheckRpc(kHealthyService, Status::OK,
+                       HealthCheckResponse::SERVING);
+    SendHealthCheckRpc(kUnhealthyService, Status::OK,
+                       HealthCheckResponse::NOT_SERVING);
+    SendHealthCheckRpc(kNotRegisteredService,
+                       Status(StatusCode::NOT_FOUND, ""));
+    SendHealthCheckRpc(kNewService, Status(StatusCode::NOT_FOUND, ""));
+
+    // Shutdown health check service.
+    service->Shutdown();
+
+    // Watch client gets another update.
+    EXPECT_TRUE(reader->Read(&response));
+    EXPECT_EQ(response.NOT_SERVING, response.status());
+    // Finish Watch call.
+    context.TryCancel();
+
+    SendHealthCheckRpc("", Status::OK, HealthCheckResponse::NOT_SERVING);
+    SendHealthCheckRpc(kHealthyService, Status::OK,
+                       HealthCheckResponse::NOT_SERVING);
+    SendHealthCheckRpc(kUnhealthyService, Status::OK,
+                       HealthCheckResponse::NOT_SERVING);
+    SendHealthCheckRpc(kNotRegisteredService,
+                       Status(StatusCode::NOT_FOUND, ""));
+
+    // Setting status after Shutdown has no effect.
+    service->SetServingStatus(kHealthyService, true);
+    SendHealthCheckRpc(kHealthyService, Status::OK,
+                       HealthCheckResponse::NOT_SERVING);
+
+    // Adding serving status for a new service after shutdown will return
+    // NOT_SERVING.
+    service->SetServingStatus(kNewService, true);
+    SendHealthCheckRpc(kNewService, Status::OK,
+                       HealthCheckResponse::NOT_SERVING);
+  }
+
+  TestServiceImpl echo_test_service_;
+  HealthCheckServiceImpl health_check_service_impl_;
+  std::unique_ptr<Health::Stub> hc_stub_;
+  std::unique_ptr<ServerCompletionQueue> cq_;
+  std::unique_ptr<Server> server_;
+  std::ostringstream server_address_;
+  std::thread cq_thread_;
+};
+
+// With the default service explicitly disabled, the server must expose no
+// health-check service and a Check RPC must come back UNIMPLEMENTED.
+TEST_F(HealthServiceEnd2endTest, DefaultHealthServiceDisabled) {
+  EnableDefaultHealthCheckService(false);
+  EXPECT_FALSE(DefaultHealthCheckServiceEnabled());
+  SetUpServer(true, false, false, nullptr);
+  EXPECT_EQ(nullptr, server_->GetHealthCheckService());
+
+  ResetStubs();
+
+  SendHealthCheckRpc("", Status(StatusCode::UNIMPLEMENTED, ""));
+}
+
+// Default health service: verify the unary and streaming behavior, plus the
+// built-in limit on the length of a service name.
+TEST_F(HealthServiceEnd2endTest, DefaultHealthService) {
+  EnableDefaultHealthCheckService(true);
+  EXPECT_TRUE(DefaultHealthCheckServiceEnabled());
+  SetUpServer(true, false, false, nullptr);
+  VerifyHealthCheckService();
+  VerifyHealthCheckServiceStreaming();
+
+  // The default service has a size limit of the service name.
+  const TString kOverlongName(201, 'x');
+  SendHealthCheckRpc(kOverlongName, Status(StatusCode::INVALID_ARGUMENT, ""));
+}
+
+// Shutdown semantics of the default health service (see
+// VerifyHealthCheckServiceShutdown for the exact expectations).
+TEST_F(HealthServiceEnd2endTest, DefaultHealthServiceShutdown) {
+  EnableDefaultHealthCheckService(true);
+  EXPECT_TRUE(DefaultHealthCheckServiceEnabled());
+  SetUpServer(true, false, false, nullptr);
+  VerifyHealthCheckServiceShutdown();
+}
+
+// Provide an empty (null) override to disable the default service even
+// though default health checking is globally enabled.
+TEST_F(HealthServiceEnd2endTest, ExplicitlyDisableViaOverride) {
+  EnableDefaultHealthCheckService(true);
+  EXPECT_TRUE(DefaultHealthCheckServiceEnabled());
+  SetUpServer(true, false, true,
+              std::unique_ptr<HealthCheckServiceInterface>(nullptr));
+  EXPECT_EQ(nullptr, server_->GetHealthCheckService());
+
+  ResetStubs();
+
+  SendHealthCheckRpc("", Status(StatusCode::UNIMPLEMENTED, ""));
+}
+
+// Provide an explicit override of health checking service interface.
+// The server must hand back exactly the override instance, and the override
+// must satisfy the same unary/streaming contract as the default service.
+TEST_F(HealthServiceEnd2endTest, ExplicitlyOverride) {
+  EnableDefaultHealthCheckService(true);
+  EXPECT_TRUE(DefaultHealthCheckServiceEnabled());
+  std::unique_ptr<HealthCheckServiceInterface> override_service(
+      new CustomHealthCheckService(&health_check_service_impl_));
+  HealthCheckServiceInterface* underlying_service = override_service.get();
+  SetUpServer(false, false, true, std::move(override_service));
+  HealthCheckServiceInterface* service = server_->GetHealthCheckService();
+  EXPECT_TRUE(service == underlying_service);
+
+  ResetStubs();
+
+  VerifyHealthCheckService();
+  VerifyHealthCheckServiceStreaming();
+}
+
+// Shutdown semantics when the health service is an explicit override backed
+// by the sync HealthCheckServiceImpl.
+TEST_F(HealthServiceEnd2endTest, ExplicitlyHealthServiceShutdown) {
+  EnableDefaultHealthCheckService(true);
+  EXPECT_TRUE(DefaultHealthCheckServiceEnabled());
+  std::unique_ptr<HealthCheckServiceInterface> override_service(
+      new CustomHealthCheckService(&health_check_service_impl_));
+  HealthCheckServiceInterface* underlying_service = override_service.get();
+  SetUpServer(false, false, true, std::move(override_service));
+  HealthCheckServiceInterface* service = server_->GetHealthCheckService();
+  EXPECT_TRUE(service == underlying_service);
+
+  ResetStubs();
+
+  VerifyHealthCheckServiceShutdown();
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+// Test entry point: environment first, then Google Test initialization.
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  const int rc = RUN_ALL_TESTS();
+  return rc;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/hybrid_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/hybrid_end2end_test.cc
new file mode 100644
index 0000000000..e4ebee8e93
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/hybrid_end2end_test.cc
@@ -0,0 +1,987 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/async_generic_service.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/iomgr/iomgr.h"
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+namespace {
+
+#ifndef GRPC_CALLBACK_API_NONEXPERIMENTAL
+using ::grpc::experimental::CallbackGenericService;
+using ::grpc::experimental::GenericCallbackServerContext;
+using ::grpc::experimental::ServerGenericBidiReactor;
+#endif
+
+// Converts a small integer into the opaque completion-queue tag it stands
+// for.  Named casts instead of the C-style (void*) cast: widen to intptr_t
+// first, then reinterpret as a pointer.
+void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
+
+// Blocks until the next event on `cq`, asserts it carries tag `i`, and
+// returns the event's success flag (false e.g. when a stream read fails).
+bool VerifyReturnSuccess(CompletionQueue* cq, int i) {
+  void* got_tag;
+  bool ok;
+  EXPECT_TRUE(cq->Next(&got_tag, &ok));
+  EXPECT_EQ(tag(i), got_tag);
+  return ok;
+}
+
+// Asserts that the next event on `cq` carries tag `i` and that its success
+// flag equals `expect_ok`.
+void Verify(CompletionQueue* cq, int i, bool expect_ok) {
+  const bool got_ok = VerifyReturnSuccess(cq, i);
+  EXPECT_EQ(expect_ok, got_ok);
+}
+
+// Handles one async unary Echo request at a server. To be run in a separate
+// thread.  `dup_service` appends "_dup" to mimic the duplicate package's
+// service behavior.
+template <class Service>
+void HandleEcho(Service* service, ServerCompletionQueue* cq, bool dup_service) {
+  ServerContext srv_ctx;
+  grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  // Tag 1: request arrival; tag 2: response finished.
+  service->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq, cq,
+                       tag(1));
+  Verify(cq, 1, true);
+  send_response.set_message(recv_request.message());
+  if (dup_service) {
+    send_response.mutable_message()->append("_dup");
+  }
+  response_writer.Finish(send_response, Status::OK, tag(2));
+  Verify(cq, 2, true);
+}
+
+// Handlers to handle raw request at a server. To be run in a
+// separate thread. Note that this is the same as the async version, except
+// that the req/resp are ByteBuffers
+template <class Service>
+void HandleRawEcho(Service* service, ServerCompletionQueue* cq,
+                   bool /*dup_service*/) {
+  ServerContext srv_ctx;
+  GenericServerAsyncResponseWriter response_writer(&srv_ctx);
+  ByteBuffer recv_buffer;
+  service->RequestEcho(&srv_ctx, &recv_buffer, &response_writer, cq, cq,
+                       tag(1));
+  Verify(cq, 1, true);
+  // Deserialize manually, echo, then serialize the response ourselves.
+  EchoRequest recv_request;
+  EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
+  EchoResponse send_response;
+  send_response.set_message(recv_request.message());
+  auto send_buffer = SerializeToByteBuffer(&send_response);
+  response_writer.Finish(*send_buffer, Status::OK, tag(2));
+  Verify(cq, 2, true);
+}
+
+// Handles one async client-streaming call: replies with the concatenation
+// of every request message.  To be run in a separate thread.
+template <class Service>
+void HandleClientStreaming(Service* service, ServerCompletionQueue* cq) {
+  ServerContext srv_ctx;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+  service->RequestRequestStream(&srv_ctx, &srv_stream, cq, cq, tag(1));
+  Verify(cq, 1, true);
+  int i = 1;
+  do {
+    i++;
+    // First pass appends the (still default/empty) message; afterwards this
+    // appends the message delivered by the previous successful Read.
+    send_response.mutable_message()->append(recv_request.message());
+    srv_stream.Read(&recv_request, tag(i));
+  } while (VerifyReturnSuccess(cq, i)); // a failed Read ends the stream
+  srv_stream.Finish(send_response, Status::OK, tag(100));
+  Verify(cq, 100, true);
+}
+
+// Raw (ByteBuffer) flavor of the client-streaming handler: deserializes each
+// request manually, concatenates the messages, and finishes with OK.
+template <class Service>
+void HandleRawClientStreaming(Service* service, ServerCompletionQueue* cq) {
+  ServerContext srv_ctx;
+  ByteBuffer recv_buffer;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  GenericServerAsyncReader srv_stream(&srv_ctx);
+  service->RequestRequestStream(&srv_ctx, &srv_stream, cq, cq, tag(1));
+  Verify(cq, 1, true);
+  // Tags 2..N mark the successive reads; the read that fails marks the end
+  // of the client's stream.
+  for (int event = 2;; ++event) {
+    srv_stream.Read(&recv_buffer, tag(event));
+    if (!VerifyReturnSuccess(cq, event)) {
+      break;
+    }
+    EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
+    send_response.mutable_message()->append(recv_request.message());
+  }
+  auto send_buffer = SerializeToByteBuffer(&send_response);
+  srv_stream.Finish(*send_buffer, Status::OK, tag(100));
+  Verify(cq, 100, true);
+}
+
+// Handles one async server-streaming call: writes three responses,
+// "<msg>0" through "<msg>2", then finishes with OK.
+template <class Service>
+void HandleServerStreaming(Service* service, ServerCompletionQueue* cq) {
+  ServerContext srv_ctx;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx);
+  service->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, cq, cq,
+                                 tag(1));
+  Verify(cq, 1, true);
+  // The three writes use tags 2, 3 and 4 in order.
+  static const char* const kSuffixes[] = {"0", "1", "2"};
+  int event = 1;
+  for (const char* suffix : kSuffixes) {
+    send_response.set_message(recv_request.message() + suffix);
+    srv_stream.Write(send_response, tag(++event));
+    Verify(cq, event, true);
+  }
+  srv_stream.Finish(Status::OK, tag(5));
+  Verify(cq, 5, true);
+}
+
+// Serves one Echo over an already-accepted generic stream: read one
+// serialized request (tag 2), echo it back (tag 3), finish OK (tag 4).
+void HandleGenericEcho(GenericServerAsyncReaderWriter* stream,
+                       CompletionQueue* cq) {
+  ByteBuffer recv_buffer;
+  stream->Read(&recv_buffer, tag(2));
+  Verify(cq, 2, true);
+  EchoRequest recv_request;
+  EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
+  EchoResponse send_response;
+  send_response.set_message(recv_request.message());
+  auto send_buffer = SerializeToByteBuffer(&send_response);
+  stream->Write(*send_buffer, tag(3));
+  Verify(cq, 3, true);
+  stream->Finish(Status::OK, tag(4));
+  Verify(cq, 4, true);
+}
+
+// Serves one client-streaming call over an already-accepted generic stream:
+// concatenates all request messages, writes the reply on tag 99, and
+// finishes OK on tag 100.
+void HandleGenericRequestStream(GenericServerAsyncReaderWriter* stream,
+                                CompletionQueue* cq) {
+  ByteBuffer recv_buffer;
+  EchoRequest recv_request;
+  EchoResponse send_response;
+  int i = 1;
+  while (true) {
+    i++;
+    stream->Read(&recv_buffer, tag(i));
+    // A failed Read means the client is done writing.
+    if (!VerifyReturnSuccess(cq, i)) {
+      break;
+    }
+    EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
+    send_response.mutable_message()->append(recv_request.message());
+  }
+  auto send_buffer = SerializeToByteBuffer(&send_response);
+  stream->Write(*send_buffer, tag(99));
+  Verify(cq, 99, true);
+  stream->Finish(Status::OK, tag(100));
+  Verify(cq, 100, true);
+}
+
+// Request and handle one generic call.  Dispatches on the full method name;
+// only Echo and RequestStream are supported here -- anything else aborts.
+void HandleGenericCall(AsyncGenericService* service,
+                       ServerCompletionQueue* cq) {
+  GenericServerContext srv_ctx;
+  GenericServerAsyncReaderWriter stream(&srv_ctx);
+  service->RequestCall(&srv_ctx, &stream, cq, cq, tag(1));
+  Verify(cq, 1, true);
+  if (srv_ctx.method() == "/grpc.testing.EchoTestService/Echo") {
+    HandleGenericEcho(&stream, cq);
+  } else if (srv_ctx.method() ==
+             "/grpc.testing.EchoTestService/RequestStream") {
+    HandleGenericRequestStream(&stream, cq);
+  } else { // other methods not handled yet.
+    gpr_log(GPR_ERROR, "method: %s", srv_ctx.method().c_str());
+    GPR_ASSERT(0);
+  }
+}
+
+// Sync implementation of the duplicate-package echo service.  Tags every
+// reply with "_dup" so tests can tell which service handled the call.
+class TestServiceImplDupPkg
+    : public ::grpc::testing::duplicate::EchoTestService::Service {
+ public:
+  Status Echo(ServerContext* /*context*/, const EchoRequest* request,
+              EchoResponse* response) override {
+    TString reply = request->message();
+    reply += "_dup";
+    response->set_message(reply);
+    return Status::OK;
+  }
+};
+
+// Fixture for mixing sync, async, raw, generic, and callback-generic
+// handlers in a single server.  Each async handler runs on its own
+// completion queue in its own thread.  Derives from TestWithParam<bool> so
+// the fixture can also back value-parameterized tests; the parameter (when
+// present) selects an in-process channel.
+class HybridEnd2endTest : public ::testing::TestWithParam<bool> {
+ protected:
+  HybridEnd2endTest() {}
+
+  static void SetUpTestCase() {
+#if TARGET_OS_IPHONE
+    // Workaround Apple CFStream bug
+    gpr_setenv("grpc_cfstream", "0");
+#endif
+  }
+
+  // TEST_F instances carry no value parameter, so default inproc_ to false.
+  void SetUp() override {
+    inproc_ = (::testing::UnitTest::GetInstance()
+                   ->current_test_info()
+                   ->value_param() != nullptr)
+                  ? GetParam()
+                  : false;
+  }
+
+  // Starts a server with `service1` (required), optionally `service2`, an
+  // async generic fallback, and/or a callback generic fallback.  Creates
+  // five completion queues so each handler thread can own one.  Returns
+  // false when a callback generic service cannot work in this iomgr setup.
+  bool SetUpServer(::grpc::Service* service1, ::grpc::Service* service2,
+                   AsyncGenericService* generic_service,
+                   CallbackGenericService* callback_generic_service,
+                   int max_message_size = 0) {
+    int port = grpc_pick_unused_port_or_die();
+    server_address_ << "localhost:" << port;
+
+    // Setup server
+    ServerBuilder builder;
+    builder.AddListeningPort(server_address_.str(),
+                             grpc::InsecureServerCredentials());
+    // Always add a sync unimplemented service: we rely on having at least one
+    // synchronous method to get a listening cq
+    builder.RegisterService(&unimplemented_service_);
+    builder.RegisterService(service1);
+    if (service2) {
+      builder.RegisterService(service2);
+    }
+    if (generic_service) {
+      builder.RegisterAsyncGenericService(generic_service);
+    }
+    if (callback_generic_service) {
+#ifdef GRPC_CALLBACK_API_NONEXPERIMENTAL
+      builder.RegisterCallbackGenericService(callback_generic_service);
+#else
+      builder.experimental().RegisterCallbackGenericService(
+          callback_generic_service);
+#endif
+    }
+
+    if (max_message_size != 0) {
+      builder.SetMaxMessageSize(max_message_size);
+    }
+
+    // Create a separate cq for each potential handler.
+    for (int i = 0; i < 5; i++) {
+      cqs_.push_back(builder.AddCompletionQueue(false));
+    }
+    server_ = builder.BuildAndStart();
+
+    // If there is a generic callback service, this setup is only successful if
+    // we have an iomgr that can run in the background or are inprocess
+    return !callback_generic_service || grpc_iomgr_run_in_background() ||
+           inproc_;
+  }
+
+  // Shuts the server down, then shuts down and fully drains every cq --
+  // required before the cqs may be destroyed.
+  void TearDown() override {
+    if (server_) {
+      server_->Shutdown();
+    }
+    void* ignored_tag;
+    bool ignored_ok;
+    for (auto it = cqs_.begin(); it != cqs_.end(); ++it) {
+      (*it)->Shutdown();
+      while ((*it)->Next(&ignored_tag, &ignored_ok))
+        ;
+    }
+  }
+
+  // (Re)creates the echo stub, in-process or over TCP depending on inproc_.
+  void ResetStub() {
+    std::shared_ptr<Channel> channel =
+        inproc_ ? server_->InProcessChannel(ChannelArguments())
+                : grpc::CreateChannel(server_address_.str(),
+                                      InsecureChannelCredentials());
+    stub_ = grpc::testing::EchoTestService::NewStub(channel);
+  }
+
+  // Test all rpc methods.
+  void TestAllMethods() {
+    SendEcho();
+    SendSimpleClientStreaming();
+    SendSimpleServerStreaming();
+    SendBidiStreaming();
+  }
+
+  // One unary echo; wait_for_ready tolerates a server that is still starting.
+  void SendEcho() {
+    EchoRequest send_request;
+    EchoResponse recv_response;
+    ClientContext cli_ctx;
+    cli_ctx.set_wait_for_ready(true);
+    send_request.set_message("Hello");
+    Status recv_status = stub_->Echo(&cli_ctx, send_request, &recv_response);
+    EXPECT_EQ(send_request.message(), recv_response.message());
+    EXPECT_TRUE(recv_status.ok());
+  }
+
+  // Unary echo against the duplicate-package service; expects the "_dup"
+  // suffix that service appends.
+  void SendEchoToDupService() {
+    std::shared_ptr<Channel> channel = grpc::CreateChannel(
+        server_address_.str(), InsecureChannelCredentials());
+    auto stub = grpc::testing::duplicate::EchoTestService::NewStub(channel);
+    EchoRequest send_request;
+    EchoResponse recv_response;
+    ClientContext cli_ctx;
+    cli_ctx.set_wait_for_ready(true);
+    send_request.set_message("Hello");
+    Status recv_status = stub->Echo(&cli_ctx, send_request, &recv_response);
+    EXPECT_EQ(send_request.message() + "_dup", recv_response.message());
+    EXPECT_TRUE(recv_status.ok());
+  }
+
+  // Client streaming: five identical writes; the reply must be their
+  // concatenation.
+  void SendSimpleClientStreaming() {
+    EchoRequest send_request;
+    EchoResponse recv_response;
+    TString expected_message;
+    ClientContext cli_ctx;
+    cli_ctx.set_wait_for_ready(true);
+    send_request.set_message("Hello");
+    auto stream = stub_->RequestStream(&cli_ctx, &recv_response);
+    for (int i = 0; i < 5; i++) {
+      EXPECT_TRUE(stream->Write(send_request));
+      expected_message.append(send_request.message());
+    }
+    stream->WritesDone();
+    Status recv_status = stream->Finish();
+    EXPECT_EQ(expected_message, recv_response.message());
+    EXPECT_TRUE(recv_status.ok());
+  }
+
+  // Server streaming: expects exactly "hello0", "hello1", "hello2".
+  void SendSimpleServerStreaming() {
+    EchoRequest request;
+    EchoResponse response;
+    ClientContext context;
+    context.set_wait_for_ready(true);
+    request.set_message("hello");
+
+    auto stream = stub_->ResponseStream(&context, request);
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message() + "0");
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message() + "1");
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message() + "2");
+    EXPECT_FALSE(stream->Read(&response));
+
+    Status s = stream->Finish();
+    EXPECT_TRUE(s.ok());
+  }
+
+  // Same as above but against the duplicate-package service, which appends
+  // "_dup" to every streamed response.
+  void SendSimpleServerStreamingToDupService() {
+    std::shared_ptr<Channel> channel = grpc::CreateChannel(
+        server_address_.str(), InsecureChannelCredentials());
+    auto stub = grpc::testing::duplicate::EchoTestService::NewStub(channel);
+    EchoRequest request;
+    EchoResponse response;
+    ClientContext context;
+    context.set_wait_for_ready(true);
+    request.set_message("hello");
+
+    auto stream = stub->ResponseStream(&context, request);
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message() + "0_dup");
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message() + "1_dup");
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message() + "2_dup");
+    EXPECT_FALSE(stream->Read(&response));
+
+    Status s = stream->Finish();
+    EXPECT_TRUE(s.ok());
+  }
+
+  // Bidirectional streaming: three write/read echo round trips, then
+  // WritesDone and two EXPECT_FALSE reads to confirm the stream is closed.
+  void SendBidiStreaming() {
+    EchoRequest request;
+    EchoResponse response;
+    ClientContext context;
+    context.set_wait_for_ready(true);
+    TString msg("hello");
+
+    auto stream = stub_->BidiStream(&context);
+
+    request.set_message(msg + "0");
+    EXPECT_TRUE(stream->Write(request));
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message());
+
+    request.set_message(msg + "1");
+    EXPECT_TRUE(stream->Write(request));
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message());
+
+    request.set_message(msg + "2");
+    EXPECT_TRUE(stream->Write(request));
+    EXPECT_TRUE(stream->Read(&response));
+    EXPECT_EQ(response.message(), request.message());
+
+    stream->WritesDone();
+    EXPECT_FALSE(stream->Read(&response));
+    EXPECT_FALSE(stream->Read(&response));
+
+    Status s = stream->Finish();
+    EXPECT_TRUE(s.ok());
+  }
+
+  grpc::testing::UnimplementedEchoService::Service unimplemented_service_;
+  std::vector<std::unique_ptr<ServerCompletionQueue>> cqs_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+  std::unique_ptr<Server> server_;
+  std::ostringstream server_address_;
+  bool inproc_;
+};
+
+// Echo is served asynchronously on its own cq/thread; the other methods
+// fall through to the sync TestServiceImpl implementations.
+TEST_F(HybridEnd2endTest, AsyncEcho) {
+  typedef EchoTestService::WithAsyncMethod_Echo<TestServiceImpl> SType;
+  SType service;
+  SetUpServer(&service, nullptr, nullptr, nullptr);
+  ResetStub();
+  std::thread echo_handler_thread(HandleEcho<SType>, &service, cqs_[0].get(),
+                                  false);
+  TestAllMethods();
+  echo_handler_thread.join();
+}
+
+// Echo is served as a raw (ByteBuffer) async method on its own cq/thread.
+TEST_F(HybridEnd2endTest, RawEcho) {
+  typedef EchoTestService::WithRawMethod_Echo<TestServiceImpl> SType;
+  SType service;
+  SetUpServer(&service, nullptr, nullptr, nullptr);
+  ResetStub();
+  std::thread echo_handler_thread(HandleRawEcho<SType>, &service, cqs_[0].get(),
+                                  false);
+  TestAllMethods();
+  echo_handler_thread.join();
+}
+
+// RequestStream is served as a raw (ByteBuffer) async method.
+TEST_F(HybridEnd2endTest, RawRequestStream) {
+  typedef EchoTestService::WithRawMethod_RequestStream<TestServiceImpl> SType;
+  SType service;
+  SetUpServer(&service, nullptr, nullptr, nullptr);
+  ResetStub();
+  std::thread request_stream_handler_thread(HandleRawClientStreaming<SType>,
+                                            &service, cqs_[0].get());
+  TestAllMethods();
+  request_stream_handler_thread.join();
+}
+
+// Echo is async and RequestStream is raw, each on its own cq/thread.
+TEST_F(HybridEnd2endTest, AsyncEchoRawRequestStream) {
+  typedef EchoTestService::WithRawMethod_RequestStream<
+      EchoTestService::WithAsyncMethod_Echo<TestServiceImpl>>
+      SType;
+  SType service;
+  SetUpServer(&service, nullptr, nullptr, nullptr);
+  ResetStub();
+  std::thread echo_handler_thread(HandleEcho<SType>, &service, cqs_[0].get(),
+                                  false);
+  std::thread request_stream_handler_thread(HandleRawClientStreaming<SType>,
+                                            &service, cqs_[1].get());
+  TestAllMethods();
+  request_stream_handler_thread.join();
+  echo_handler_thread.join();
+}
+
+// Echo falls through to the async generic service while RequestStream is a
+// raw method; each handler owns its own cq/thread.
+TEST_F(HybridEnd2endTest, GenericEchoRawRequestStream) {
+  typedef EchoTestService::WithRawMethod_RequestStream<
+      EchoTestService::WithGenericMethod_Echo<TestServiceImpl>>
+      SType;
+  SType service;
+  AsyncGenericService generic_service;
+  SetUpServer(&service, nullptr, &generic_service, nullptr);
+  ResetStub();
+  std::thread generic_handler_thread(HandleGenericCall, &generic_service,
+                                     cqs_[0].get());
+  std::thread request_stream_handler_thread(HandleRawClientStreaming<SType>,
+                                            &service, cqs_[1].get());
+  TestAllMethods();
+  generic_handler_thread.join();
+  request_stream_handler_thread.join();
+}
+
+// Both Echo and RequestStream are async, each on its own cq/thread.
+TEST_F(HybridEnd2endTest, AsyncEchoRequestStream) {
+  typedef EchoTestService::WithAsyncMethod_RequestStream<
+      EchoTestService::WithAsyncMethod_Echo<TestServiceImpl>>
+      SType;
+  SType service;
+  SetUpServer(&service, nullptr, nullptr, nullptr);
+  ResetStub();
+  std::thread echo_handler_thread(HandleEcho<SType>, &service, cqs_[0].get(),
+                                  false);
+  std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+                                            &service, cqs_[1].get());
+  TestAllMethods();
+  echo_handler_thread.join();
+  request_stream_handler_thread.join();
+}
+
+// RequestStream and ResponseStream are both async, each on its own
+// cq/thread; Echo and BidiStream stay synchronous.
+TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream) {
+  typedef EchoTestService::WithAsyncMethod_RequestStream<
+      EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>
+      SType;
+  SType service;
+  SetUpServer(&service, nullptr, nullptr, nullptr);
+  ResetStub();
+  std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+                                             &service, cqs_[0].get());
+  std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+                                            &service, cqs_[1].get());
+  TestAllMethods();
+  response_stream_handler_thread.join();
+  request_stream_handler_thread.join();
+}
+
+// Add a second service with one sync method.
+TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream_SyncDupService) {
+  typedef EchoTestService::WithAsyncMethod_RequestStream<
+      EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>
+      SType;
+  SType service;
+  TestServiceImplDupPkg dup_service;
+  SetUpServer(&service, &dup_service, nullptr, nullptr);
+  ResetStub();
+  std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+                                             &service, cqs_[0].get());
+  std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+                                            &service, cqs_[1].get());
+  TestAllMethods();
+  // The dup service is sync, so no extra handler thread is needed for it.
+  SendEchoToDupService();
+  response_stream_handler_thread.join();
+  request_stream_handler_thread.join();
+}
+
+// Add a second service with one sync streamed unary method.
+class StreamedUnaryDupPkg
+ : public duplicate::EchoTestService::WithStreamedUnaryMethod_Echo<
+ TestServiceImplDupPkg> {
+ public:
+ Status StreamedEcho(
+ ServerContext* /*context*/,
+ ServerUnaryStreamer<EchoRequest, EchoResponse>* stream) override {
+ EchoRequest req;
+ EchoResponse resp;
+ uint32_t next_msg_sz;
+ stream->NextMessageSize(&next_msg_sz);
+ gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz);
+ GPR_ASSERT(stream->Read(&req));
+ resp.set_message(req.message() + "_dup");
+ GPR_ASSERT(stream->Write(resp));
+ return Status::OK;
+ }
+};
+
+TEST_F(HybridEnd2endTest,
+ AsyncRequestStreamResponseStream_SyncStreamedUnaryDupService) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>
+ SType;
+ SType service;
+ StreamedUnaryDupPkg dup_service;
+ SetUpServer(&service, &dup_service, nullptr, nullptr, 8192);
+ ResetStub();
+ std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+ &service, cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ TestAllMethods();
+ SendEchoToDupService();
+ response_stream_handler_thread.join();
+ request_stream_handler_thread.join();
+}
+
+// Add a second service that is fully Streamed Unary
+class FullyStreamedUnaryDupPkg
+ : public duplicate::EchoTestService::StreamedUnaryService {
+ public:
+ Status StreamedEcho(
+ ServerContext* /*context*/,
+ ServerUnaryStreamer<EchoRequest, EchoResponse>* stream) override {
+ EchoRequest req;
+ EchoResponse resp;
+ uint32_t next_msg_sz;
+ stream->NextMessageSize(&next_msg_sz);
+ gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz);
+ GPR_ASSERT(stream->Read(&req));
+ resp.set_message(req.message() + "_dup");
+ GPR_ASSERT(stream->Write(resp));
+ return Status::OK;
+ }
+};
+
+TEST_F(HybridEnd2endTest,
+ AsyncRequestStreamResponseStream_SyncFullyStreamedUnaryDupService) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>
+ SType;
+ SType service;
+ FullyStreamedUnaryDupPkg dup_service;
+ SetUpServer(&service, &dup_service, nullptr, nullptr, 8192);
+ ResetStub();
+ std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+ &service, cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ TestAllMethods();
+ SendEchoToDupService();
+ response_stream_handler_thread.join();
+ request_stream_handler_thread.join();
+}
+
+// Add a second service with one sync split server streaming method.
+class SplitResponseStreamDupPkg
+ : public duplicate::EchoTestService::
+ WithSplitStreamingMethod_ResponseStream<TestServiceImplDupPkg> {
+ public:
+ Status StreamedResponseStream(
+ ServerContext* /*context*/,
+ ServerSplitStreamer<EchoRequest, EchoResponse>* stream) override {
+ EchoRequest req;
+ EchoResponse resp;
+ uint32_t next_msg_sz;
+ stream->NextMessageSize(&next_msg_sz);
+ gpr_log(GPR_INFO, "Split Streamed Next Message Size is %u", next_msg_sz);
+ GPR_ASSERT(stream->Read(&req));
+ for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) {
+ resp.set_message(req.message() + ToString(i) + "_dup");
+ GPR_ASSERT(stream->Write(resp));
+ }
+ return Status::OK;
+ }
+};
+
+TEST_F(HybridEnd2endTest,
+ AsyncRequestStreamResponseStream_SyncSplitStreamedDupService) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>
+ SType;
+ SType service;
+ SplitResponseStreamDupPkg dup_service;
+ SetUpServer(&service, &dup_service, nullptr, nullptr, 8192);
+ ResetStub();
+ std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+ &service, cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ TestAllMethods();
+ SendSimpleServerStreamingToDupService();
+ response_stream_handler_thread.join();
+ request_stream_handler_thread.join();
+}
+
+// Add a second service that is fully split server streamed
+class FullySplitStreamedDupPkg
+ : public duplicate::EchoTestService::SplitStreamedService {
+ public:
+ Status StreamedResponseStream(
+ ServerContext* /*context*/,
+ ServerSplitStreamer<EchoRequest, EchoResponse>* stream) override {
+ EchoRequest req;
+ EchoResponse resp;
+ uint32_t next_msg_sz;
+ stream->NextMessageSize(&next_msg_sz);
+ gpr_log(GPR_INFO, "Split Streamed Next Message Size is %u", next_msg_sz);
+ GPR_ASSERT(stream->Read(&req));
+ for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) {
+ resp.set_message(req.message() + ToString(i) + "_dup");
+ GPR_ASSERT(stream->Write(resp));
+ }
+ return Status::OK;
+ }
+};
+
+TEST_F(HybridEnd2endTest,
+ AsyncRequestStreamResponseStream_FullySplitStreamedDupService) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>
+ SType;
+ SType service;
+ FullySplitStreamedDupPkg dup_service;
+ SetUpServer(&service, &dup_service, nullptr, nullptr, 8192);
+ ResetStub();
+ std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+ &service, cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ TestAllMethods();
+ SendSimpleServerStreamingToDupService();
+ response_stream_handler_thread.join();
+ request_stream_handler_thread.join();
+}
+
+// Add a second service that is fully server streamed
+class FullyStreamedDupPkg : public duplicate::EchoTestService::StreamedService {
+ public:
+ Status StreamedEcho(
+ ServerContext* /*context*/,
+ ServerUnaryStreamer<EchoRequest, EchoResponse>* stream) override {
+ EchoRequest req;
+ EchoResponse resp;
+ uint32_t next_msg_sz;
+ stream->NextMessageSize(&next_msg_sz);
+ gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz);
+ GPR_ASSERT(stream->Read(&req));
+ resp.set_message(req.message() + "_dup");
+ GPR_ASSERT(stream->Write(resp));
+ return Status::OK;
+ }
+ Status StreamedResponseStream(
+ ServerContext* /*context*/,
+ ServerSplitStreamer<EchoRequest, EchoResponse>* stream) override {
+ EchoRequest req;
+ EchoResponse resp;
+ uint32_t next_msg_sz;
+ stream->NextMessageSize(&next_msg_sz);
+ gpr_log(GPR_INFO, "Split Streamed Next Message Size is %u", next_msg_sz);
+ GPR_ASSERT(stream->Read(&req));
+ for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) {
+ resp.set_message(req.message() + ToString(i) + "_dup");
+ GPR_ASSERT(stream->Write(resp));
+ }
+ return Status::OK;
+ }
+};
+
+TEST_F(HybridEnd2endTest,
+ AsyncRequestStreamResponseStream_FullyStreamedDupService) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>
+ SType;
+ SType service;
+ FullyStreamedDupPkg dup_service;
+ SetUpServer(&service, &dup_service, nullptr, nullptr, 8192);
+ ResetStub();
+ std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+ &service, cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ TestAllMethods();
+ SendEchoToDupService();
+ SendSimpleServerStreamingToDupService();
+ response_stream_handler_thread.join();
+ request_stream_handler_thread.join();
+}
+
+// Add a second service with one async method.
+TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream_AsyncDupService) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>
+ SType;
+ SType service;
+ duplicate::EchoTestService::AsyncService dup_service;
+ SetUpServer(&service, &dup_service, nullptr, nullptr);
+ ResetStub();
+ std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+ &service, cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ std::thread echo_handler_thread(
+ HandleEcho<duplicate::EchoTestService::AsyncService>, &dup_service,
+ cqs_[2].get(), true);
+ TestAllMethods();
+ SendEchoToDupService();
+ response_stream_handler_thread.join();
+ request_stream_handler_thread.join();
+ echo_handler_thread.join();
+}
+
+TEST_F(HybridEnd2endTest, GenericEcho) {
+ EchoTestService::WithGenericMethod_Echo<TestServiceImpl> service;
+ AsyncGenericService generic_service;
+ SetUpServer(&service, nullptr, &generic_service, nullptr);
+ ResetStub();
+ std::thread generic_handler_thread(HandleGenericCall, &generic_service,
+ cqs_[0].get());
+ TestAllMethods();
+ generic_handler_thread.join();
+}
+
+TEST_P(HybridEnd2endTest, CallbackGenericEcho) {
+ EchoTestService::WithGenericMethod_Echo<TestServiceImpl> service;
+ class GenericEchoService : public CallbackGenericService {
+ private:
+ ServerGenericBidiReactor* CreateReactor(
+ GenericCallbackServerContext* context) override {
+ EXPECT_EQ(context->method(), "/grpc.testing.EchoTestService/Echo");
+ gpr_log(GPR_DEBUG, "Constructor of generic service %d",
+ static_cast<int>(context->deadline().time_since_epoch().count()));
+
+ class Reactor : public ServerGenericBidiReactor {
+ public:
+ Reactor() { StartRead(&request_); }
+
+ private:
+ void OnDone() override { delete this; }
+ void OnReadDone(bool ok) override {
+ if (!ok) {
+ EXPECT_EQ(reads_complete_, 1);
+ } else {
+ EXPECT_EQ(reads_complete_++, 0);
+ response_ = request_;
+ StartWrite(&response_);
+ StartRead(&request_);
+ }
+ }
+ void OnWriteDone(bool ok) override {
+ Finish(ok ? Status::OK
+ : Status(StatusCode::UNKNOWN, "Unexpected failure"));
+ }
+ ByteBuffer request_;
+ ByteBuffer response_;
+ std::atomic_int reads_complete_{0};
+ };
+ return new Reactor;
+ }
+ } generic_service;
+
+ if (!SetUpServer(&service, nullptr, nullptr, &generic_service)) {
+ return;
+ }
+ ResetStub();
+ TestAllMethods();
+}
+
+TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithGenericMethod_Echo<TestServiceImpl>>
+ SType;
+ SType service;
+ AsyncGenericService generic_service;
+ SetUpServer(&service, nullptr, &generic_service, nullptr);
+ ResetStub();
+ std::thread generic_handler_thread(HandleGenericCall, &generic_service,
+ cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ TestAllMethods();
+ generic_handler_thread.join();
+ request_stream_handler_thread.join();
+}
+
+// Add a second service with one sync method.
+TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream_SyncDupService) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithGenericMethod_Echo<TestServiceImpl>>
+ SType;
+ SType service;
+ AsyncGenericService generic_service;
+ TestServiceImplDupPkg dup_service;
+ SetUpServer(&service, &dup_service, &generic_service, nullptr);
+ ResetStub();
+ std::thread generic_handler_thread(HandleGenericCall, &generic_service,
+ cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ TestAllMethods();
+ SendEchoToDupService();
+ generic_handler_thread.join();
+ request_stream_handler_thread.join();
+}
+
+// Add a second service with one async method.
+TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream_AsyncDupService) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithGenericMethod_Echo<TestServiceImpl>>
+ SType;
+ SType service;
+ AsyncGenericService generic_service;
+ duplicate::EchoTestService::AsyncService dup_service;
+ SetUpServer(&service, &dup_service, &generic_service, nullptr);
+ ResetStub();
+ std::thread generic_handler_thread(HandleGenericCall, &generic_service,
+ cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ std::thread echo_handler_thread(
+ HandleEcho<duplicate::EchoTestService::AsyncService>, &dup_service,
+ cqs_[2].get(), true);
+ TestAllMethods();
+ SendEchoToDupService();
+ generic_handler_thread.join();
+ request_stream_handler_thread.join();
+ echo_handler_thread.join();
+}
+
+TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStreamResponseStream) {
+ typedef EchoTestService::WithAsyncMethod_RequestStream<
+ EchoTestService::WithGenericMethod_Echo<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>>
+ SType;
+ SType service;
+ AsyncGenericService generic_service;
+ SetUpServer(&service, nullptr, &generic_service, nullptr);
+ ResetStub();
+ std::thread generic_handler_thread(HandleGenericCall, &generic_service,
+ cqs_[0].get());
+ std::thread request_stream_handler_thread(HandleClientStreaming<SType>,
+ &service, cqs_[1].get());
+ std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+ &service, cqs_[2].get());
+ TestAllMethods();
+ generic_handler_thread.join();
+ request_stream_handler_thread.join();
+ response_stream_handler_thread.join();
+}
+
+TEST_F(HybridEnd2endTest, GenericEchoRequestStreamAsyncResponseStream) {
+ typedef EchoTestService::WithGenericMethod_RequestStream<
+ EchoTestService::WithGenericMethod_Echo<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>>
+ SType;
+ SType service;
+ AsyncGenericService generic_service;
+ SetUpServer(&service, nullptr, &generic_service, nullptr);
+ ResetStub();
+ std::thread generic_handler_thread(HandleGenericCall, &generic_service,
+ cqs_[0].get());
+ std::thread generic_handler_thread2(HandleGenericCall, &generic_service,
+ cqs_[1].get());
+ std::thread response_stream_handler_thread(HandleServerStreaming<SType>,
+ &service, cqs_[2].get());
+ TestAllMethods();
+ generic_handler_thread.join();
+ generic_handler_thread2.join();
+ response_stream_handler_thread.join();
+}
+
+// If WithGenericMethod is called and no generic service is registered, the
+// server will fail to build.
+TEST_F(HybridEnd2endTest, GenericMethodWithoutGenericService) {
+ EchoTestService::WithGenericMethod_RequestStream<
+ EchoTestService::WithGenericMethod_Echo<
+ EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>>
+ service;
+ SetUpServer(&service, nullptr, nullptr, nullptr);
+ EXPECT_EQ(nullptr, server_.get());
+}
+
+INSTANTIATE_TEST_SUITE_P(HybridEnd2endTest, HybridEnd2endTest,
+ ::testing::Bool());
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/interceptors_util.cc b/contrib/libs/grpc/test/cpp/end2end/interceptors_util.cc
new file mode 100644
index 0000000000..ff88953651
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/interceptors_util.cc
@@ -0,0 +1,214 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/end2end/interceptors_util.h"
+#include <util/string/cast.h>
+
+namespace grpc {
+namespace testing {
+
+std::atomic<int> DummyInterceptor::num_times_run_;
+std::atomic<int> DummyInterceptor::num_times_run_reverse_;
+std::atomic<int> DummyInterceptor::num_times_cancel_;
+
+void MakeCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ Status s = stub->Echo(&ctx, req, &resp);
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), "Hello");
+}
+
+void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ string expected_resp = "";
+ auto writer = stub->RequestStream(&ctx, &resp);
+ for (int i = 0; i < kNumStreamingMessages; i++) {
+ writer->Write(req);
+ expected_resp += "Hello";
+ }
+ writer->WritesDone();
+ Status s = writer->Finish();
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), expected_resp);
+}
+
+void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ auto reader = stub->ResponseStream(&ctx, req);
+ int count = 0;
+ while (reader->Read(&resp)) {
+ EXPECT_EQ(resp.message(), "Hello");
+ count++;
+ }
+ ASSERT_EQ(count, kNumStreamingMessages);
+ Status s = reader->Finish();
+ EXPECT_EQ(s.ok(), true);
+}
+
+void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ ctx.AddMetadata("testkey", "testvalue");
+ req.mutable_param()->set_echo_metadata(true);
+ auto stream = stub->BidiStream(&ctx);
+ for (auto i = 0; i < kNumStreamingMessages; i++) {
+ req.set_message(TString("Hello") + ::ToString(i));
+ stream->Write(req);
+ stream->Read(&resp);
+ EXPECT_EQ(req.message(), resp.message());
+ }
+ ASSERT_TRUE(stream->WritesDone());
+ Status s = stream->Finish();
+ EXPECT_EQ(s.ok(), true);
+}
+
+void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ CompletionQueue cq;
+ EchoRequest send_request;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+
+ send_request.set_message("Hello");
+ cli_ctx.AddMetadata("testkey", "testvalue");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub->AsyncEcho(&cli_ctx, send_request, &cq));
+ response_reader->Finish(&recv_response, &recv_status, tag(1));
+ Verifier().Expect(1, true).Verify(&cq);
+ EXPECT_EQ(send_request.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+}
+
+void MakeAsyncCQClientStreamingCall(
+ const std::shared_ptr<Channel>& /*channel*/) {
+ // TODO(yashykt) : Fill this out
+}
+
+void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ CompletionQueue cq;
+ EchoRequest send_request;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+
+ cli_ctx.AddMetadata("testkey", "testvalue");
+ send_request.set_message("Hello");
+ std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+ stub->AsyncResponseStream(&cli_ctx, send_request, &cq, tag(1)));
+ Verifier().Expect(1, true).Verify(&cq);
+ // Read the expected number of messages
+ for (int i = 0; i < kNumStreamingMessages; i++) {
+ cli_stream->Read(&recv_response, tag(2));
+ Verifier().Expect(2, true).Verify(&cq);
+ ASSERT_EQ(recv_response.message(), send_request.message());
+ }
+ // The next read should fail
+ cli_stream->Read(&recv_response, tag(3));
+ Verifier().Expect(3, false).Verify(&cq);
+ // Get the status
+ cli_stream->Finish(&recv_status, tag(4));
+ Verifier().Expect(4, true).Verify(&cq);
+ EXPECT_TRUE(recv_status.ok());
+}
+
+void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& /*channel*/) {
+ // TODO(yashykt) : Fill this out
+}
+
+void MakeCallbackCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ req.mutable_param()->set_echo_metadata(true);
+ ctx.AddMetadata("testkey", "testvalue");
+ req.set_message("Hello");
+ EchoResponse resp;
+ stub->experimental_async()->Echo(&ctx, &req, &resp,
+ [&resp, &mu, &done, &cv](Status s) {
+ EXPECT_EQ(s.ok(), true);
+ EXPECT_EQ(resp.message(), "Hello");
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+ });
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+}
+
+bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
+ const string& key, const string& value) {
+ for (const auto& pair : map) {
+ if (pair.first.starts_with(key) && pair.second.starts_with(value)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool CheckMetadata(const std::multimap<TString, TString>& map,
+ const string& key, const string& value) {
+ for (const auto& pair : map) {
+ if (pair.first == key.c_str() && pair.second == value.c_str()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+CreateDummyClientInterceptors() {
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ // Add 20 dummy interceptors before hijacking interceptor
+ creators.reserve(20);
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ return creators;
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/end2end/interceptors_util.h b/contrib/libs/grpc/test/cpp/end2end/interceptors_util.h
new file mode 100644
index 0000000000..c95170bbbc
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/interceptors_util.h
@@ -0,0 +1,317 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <condition_variable>
+
+#include <grpcpp/channel.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+/* This interceptor does nothing. Just keeps a global count on the number of
+ * times it was invoked. */
+class DummyInterceptor : public experimental::Interceptor {
+ public:
+ DummyInterceptor() {}
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ num_times_run_++;
+ } else if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::
+ POST_RECV_INITIAL_METADATA)) {
+ num_times_run_reverse_++;
+ } else if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CANCEL)) {
+ num_times_cancel_++;
+ }
+ methods->Proceed();
+ }
+
+ static void Reset() {
+ num_times_run_.store(0);
+ num_times_run_reverse_.store(0);
+ num_times_cancel_.store(0);
+ }
+
+ static int GetNumTimesRun() {
+ EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load());
+ return num_times_run_.load();
+ }
+
+ static int GetNumTimesCancel() { return num_times_cancel_.load(); }
+
+ private:
+ static std::atomic<int> num_times_run_;
+ static std::atomic<int> num_times_run_reverse_;
+ static std::atomic<int> num_times_cancel_;
+};
+
+class DummyInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface,
+ public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* /*info*/) override {
+ return new DummyInterceptor();
+ }
+
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* /*info*/) override {
+ return new DummyInterceptor();
+ }
+};
+
+/* This interceptor factory returns nullptr on interceptor creation */
+class NullInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface,
+ public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* /*info*/) override {
+ return nullptr;
+ }
+
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* /*info*/) override {
+ return nullptr;
+ }
+};
+
+class EchoTestServiceStreamingImpl : public EchoTestService::Service {
+ public:
+ ~EchoTestServiceStreamingImpl() override {}
+
+ Status Echo(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ auto client_metadata = context->client_metadata();
+ for (const auto& pair : client_metadata) {
+ context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
+ }
+ response->set_message(request->message());
+ return Status::OK;
+ }
+
+ Status BidiStream(
+ ServerContext* context,
+ grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
+ EchoRequest req;
+ EchoResponse resp;
+ auto client_metadata = context->client_metadata();
+ for (const auto& pair : client_metadata) {
+ context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
+ }
+
+ while (stream->Read(&req)) {
+ resp.set_message(req.message());
+ EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions()));
+ }
+ return Status::OK;
+ }
+
+ Status RequestStream(ServerContext* context,
+ ServerReader<EchoRequest>* reader,
+ EchoResponse* resp) override {
+ auto client_metadata = context->client_metadata();
+ for (const auto& pair : client_metadata) {
+ context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
+ }
+
+ EchoRequest req;
+ string response_str = "";
+ while (reader->Read(&req)) {
+ response_str += req.message();
+ }
+ resp->set_message(response_str);
+ return Status::OK;
+ }
+
+ Status ResponseStream(ServerContext* context, const EchoRequest* req,
+ ServerWriter<EchoResponse>* writer) override {
+ auto client_metadata = context->client_metadata();
+ for (const auto& pair : client_metadata) {
+ context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
+ }
+
+ EchoResponse resp;
+ resp.set_message(req->message());
+ for (int i = 0; i < 10; i++) {
+ EXPECT_TRUE(writer->Write(resp));
+ }
+ return Status::OK;
+ }
+};
+
+constexpr int kNumStreamingMessages = 10;
+
+void MakeCall(const std::shared_ptr<Channel>& channel);
+
+void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel);
+
+void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel);
+
+void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel);
+
+void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel);
+
+void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel);
+
+void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel);
+
+void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel);
+
+void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
+
+bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
+ const string& key, const string& value);
+
+bool CheckMetadata(const std::multimap<TString, TString>& map,
+ const string& key, const string& value);
+
+std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+CreateDummyClientInterceptors();
+
+inline void* tag(int i) { return (void*)static_cast<intptr_t>(i); }
+inline int detag(void* p) {
+ return static_cast<int>(reinterpret_cast<intptr_t>(p));
+}
+
+class Verifier {
+ public:
+ Verifier() : lambda_run_(false) {}
+ // Expect sets the expected ok value for a specific tag
+ Verifier& Expect(int i, bool expect_ok) {
+ return ExpectUnless(i, expect_ok, false);
+ }
+ // ExpectUnless sets the expected ok value for a specific tag
+ // unless the tag was already marked seen (as a result of ExpectMaybe)
+ Verifier& ExpectUnless(int i, bool expect_ok, bool seen) {
+ if (!seen) {
+ expectations_[tag(i)] = expect_ok;
+ }
+ return *this;
+ }
+ // ExpectMaybe sets the expected ok value for a specific tag, but does not
+ // require it to appear
+ // If it does, sets *seen to true
+ Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) {
+ if (!*seen) {
+ maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen};
+ }
+ return *this;
+ }
+
+ // Next waits for 1 async tag to complete, checks its
+ // expectations, and returns the tag
+ int Next(CompletionQueue* cq, bool ignore_ok) {
+ bool ok;
+ void* got_tag;
+ EXPECT_TRUE(cq->Next(&got_tag, &ok));
+ GotTag(got_tag, ok, ignore_ok);
+ return detag(got_tag);
+ }
+
+ template <typename T>
+ CompletionQueue::NextStatus DoOnceThenAsyncNext(
+ CompletionQueue* cq, void** got_tag, bool* ok, T deadline,
+ std::function<void(void)> lambda) {
+ if (lambda_run_) {
+ return cq->AsyncNext(got_tag, ok, deadline);
+ } else {
+ lambda_run_ = true;
+ return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline);
+ }
+ }
+
+ // Verify keeps calling Next until all currently set
+ // expected tags are complete
+ void Verify(CompletionQueue* cq) { Verify(cq, false); }
+
+ // This version of Verify allows optionally ignoring the
+ // outcome of the expectation
+ void Verify(CompletionQueue* cq, bool ignore_ok) {
+ GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty());
+ while (!expectations_.empty()) {
+ Next(cq, ignore_ok);
+ }
+ }
+
+ // This version of Verify stops after a certain deadline, and uses the
+ // DoThenAsyncNext API
+ // to call the lambda
+ void Verify(CompletionQueue* cq,
+ std::chrono::system_clock::time_point deadline,
+ const std::function<void(void)>& lambda) {
+ if (expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
+ CompletionQueue::TIMEOUT);
+ } else {
+ while (!expectations_.empty()) {
+ bool ok;
+ void* got_tag;
+ EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda),
+ CompletionQueue::GOT_EVENT);
+ GotTag(got_tag, ok, false);
+ }
+ }
+ }
+
+ private:
+ void GotTag(void* got_tag, bool ok, bool ignore_ok) {
+ auto it = expectations_.find(got_tag);
+ if (it != expectations_.end()) {
+ if (!ignore_ok) {
+ EXPECT_EQ(it->second, ok);
+ }
+ expectations_.erase(it);
+ } else {
+ auto it2 = maybe_expectations_.find(got_tag);
+ if (it2 != maybe_expectations_.end()) {
+ if (it2->second.seen != nullptr) {
+ EXPECT_FALSE(*it2->second.seen);
+ *it2->second.seen = true;
+ }
+ if (!ignore_ok) {
+ EXPECT_EQ(it2->second.ok, ok);
+ }
+ } else {
+ gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag);
+ abort();
+ }
+ }
+ }
+
+ struct MaybeExpect {
+ bool ok;
+ bool* seen;
+ };
+
+ std::map<void*, bool> expectations_;
+ std::map<void*, MaybeExpect> maybe_expectations_;
+ bool lambda_run_;
+};
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/end2end/message_allocator_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/message_allocator_end2end_test.cc
new file mode 100644
index 0000000000..4bf755206e
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/message_allocator_end2end_test.cc
@@ -0,0 +1,438 @@
+/*
+ *
+ * Copyright 2019 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <algorithm>
+#include <atomic>
+#include <condition_variable>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <thread>
+
+#include <google/protobuf/arena.h>
+
+#include <grpc/impl/codegen/log.h>
+#include <gtest/gtest.h>
+
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/client_callback.h>
+#include <grpcpp/support/message_allocator.h>
+
+#include "src/core/lib/iomgr/iomgr.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+// MAYBE_SKIP_TEST is a macro to determine if this particular test configuration
+// should be skipped based on a decision made at SetUp time. In particular, any
+// callback tests can only be run if the iomgr can run in the background or if
+// the transport is in-process.
+#define MAYBE_SKIP_TEST \
+ do { \
+ if (do_not_test_) { \
+ return; \
+ } \
+ } while (0)
+
+namespace grpc {
+namespace testing {
+namespace {
+
+class CallbackTestServiceImpl
+ : public EchoTestService::ExperimentalCallbackService {
+ public:
+ explicit CallbackTestServiceImpl() {}
+
+ void SetAllocatorMutator(
+ std::function<void(experimental::RpcAllocatorState* allocator_state,
+ const EchoRequest* req, EchoResponse* resp)>
+ mutator) {
+ allocator_mutator_ = mutator;
+ }
+
+ experimental::ServerUnaryReactor* Echo(
+ experimental::CallbackServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ response->set_message(request->message());
+ if (allocator_mutator_) {
+ allocator_mutator_(context->GetRpcAllocatorState(), request, response);
+ }
+ auto* reactor = context->DefaultReactor();
+ reactor->Finish(Status::OK);
+ return reactor;
+ }
+
+ private:
+ std::function<void(experimental::RpcAllocatorState* allocator_state,
+ const EchoRequest* req, EchoResponse* resp)>
+ allocator_mutator_;
+};
+
+enum class Protocol { INPROC, TCP };
+
+class TestScenario {
+ public:
+ TestScenario(Protocol protocol, const TString& creds_type)
+ : protocol(protocol), credentials_type(creds_type) {}
+ void Log() const;
+ Protocol protocol;
+ const TString credentials_type;
+};
+
+static std::ostream& operator<<(std::ostream& out,
+ const TestScenario& scenario) {
+ return out << "TestScenario{protocol="
+ << (scenario.protocol == Protocol::INPROC ? "INPROC" : "TCP")
+ << "," << scenario.credentials_type << "}";
+}
+
+void TestScenario::Log() const {
+ std::ostringstream out;
+ out << *this;
+ gpr_log(GPR_INFO, "%s", out.str().c_str());
+}
+
+class MessageAllocatorEnd2endTestBase
+ : public ::testing::TestWithParam<TestScenario> {
+ protected:
+ MessageAllocatorEnd2endTestBase() {
+ GetParam().Log();
+ if (GetParam().protocol == Protocol::TCP) {
+ if (!grpc_iomgr_run_in_background()) {
+ do_not_test_ = true;
+ return;
+ }
+ }
+ }
+
+ ~MessageAllocatorEnd2endTestBase() = default;
+
+ void CreateServer(
+ experimental::MessageAllocator<EchoRequest, EchoResponse>* allocator) {
+ ServerBuilder builder;
+
+ auto server_creds = GetCredentialsProvider()->GetServerCredentials(
+ GetParam().credentials_type);
+ if (GetParam().protocol == Protocol::TCP) {
+ picked_port_ = grpc_pick_unused_port_or_die();
+ server_address_ << "localhost:" << picked_port_;
+ builder.AddListeningPort(server_address_.str(), server_creds);
+ }
+ callback_service_.SetMessageAllocatorFor_Echo(allocator);
+ builder.RegisterService(&callback_service_);
+
+ server_ = builder.BuildAndStart();
+ }
+
+ void DestroyServer() {
+ if (server_) {
+ server_->Shutdown();
+ server_.reset();
+ }
+ }
+
+ void ResetStub() {
+ ChannelArguments args;
+ auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
+ GetParam().credentials_type, &args);
+ switch (GetParam().protocol) {
+ case Protocol::TCP:
+ channel_ = ::grpc::CreateCustomChannel(server_address_.str(),
+ channel_creds, args);
+ break;
+ case Protocol::INPROC:
+ channel_ = server_->InProcessChannel(args);
+ break;
+ default:
+ assert(false);
+ }
+ stub_ = EchoTestService::NewStub(channel_);
+ }
+
+ void TearDown() override {
+ DestroyServer();
+ if (picked_port_ > 0) {
+ grpc_recycle_unused_port(picked_port_);
+ }
+ }
+
+ void SendRpcs(int num_rpcs) {
+ TString test_string("");
+ for (int i = 0; i < num_rpcs; i++) {
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext cli_ctx;
+
+ test_string += TString(1024, 'x');
+ request.set_message(test_string);
+ TString val;
+ cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
+
+ std::mutex mu;
+ std::condition_variable cv;
+ bool done = false;
+ stub_->experimental_async()->Echo(
+ &cli_ctx, &request, &response,
+ [&request, &response, &done, &mu, &cv, val](Status s) {
+ GPR_ASSERT(s.ok());
+
+ EXPECT_EQ(request.message(), response.message());
+ std::lock_guard<std::mutex> l(mu);
+ done = true;
+ cv.notify_one();
+ });
+ std::unique_lock<std::mutex> l(mu);
+ while (!done) {
+ cv.wait(l);
+ }
+ }
+ }
+
+ bool do_not_test_{false};
+ int picked_port_{0};
+ std::shared_ptr<Channel> channel_;
+ std::unique_ptr<EchoTestService::Stub> stub_;
+ CallbackTestServiceImpl callback_service_;
+ std::unique_ptr<Server> server_;
+ std::ostringstream server_address_;
+};
+
+class NullAllocatorTest : public MessageAllocatorEnd2endTestBase {};
+
+TEST_P(NullAllocatorTest, SimpleRpc) {
+ MAYBE_SKIP_TEST;
+ CreateServer(nullptr);
+ ResetStub();
+ SendRpcs(1);
+}
+
+class SimpleAllocatorTest : public MessageAllocatorEnd2endTestBase {
+ public:
+ class SimpleAllocator
+ : public experimental::MessageAllocator<EchoRequest, EchoResponse> {
+ public:
+ class MessageHolderImpl
+ : public experimental::MessageHolder<EchoRequest, EchoResponse> {
+ public:
+ MessageHolderImpl(std::atomic_int* request_deallocation_count,
+ std::atomic_int* messages_deallocation_count)
+ : request_deallocation_count_(request_deallocation_count),
+ messages_deallocation_count_(messages_deallocation_count) {
+ set_request(new EchoRequest);
+ set_response(new EchoResponse);
+ }
+ void Release() override {
+ (*messages_deallocation_count_)++;
+ delete request();
+ delete response();
+ delete this;
+ }
+ void FreeRequest() override {
+ (*request_deallocation_count_)++;
+ delete request();
+ set_request(nullptr);
+ }
+
+ EchoRequest* ReleaseRequest() {
+ auto* ret = request();
+ set_request(nullptr);
+ return ret;
+ }
+
+ private:
+ std::atomic_int* const request_deallocation_count_;
+ std::atomic_int* const messages_deallocation_count_;
+ };
+ experimental::MessageHolder<EchoRequest, EchoResponse>* AllocateMessages()
+ override {
+ allocation_count++;
+ return new MessageHolderImpl(&request_deallocation_count,
+ &messages_deallocation_count);
+ }
+ int allocation_count = 0;
+ std::atomic_int request_deallocation_count{0};
+ std::atomic_int messages_deallocation_count{0};
+ };
+};
+
+TEST_P(SimpleAllocatorTest, SimpleRpc) {
+ MAYBE_SKIP_TEST;
+ const int kRpcCount = 10;
+ std::unique_ptr<SimpleAllocator> allocator(new SimpleAllocator);
+ CreateServer(allocator.get());
+ ResetStub();
+ SendRpcs(kRpcCount);
+ // messages_deallocaton_count is updated in Release after server side OnDone.
+ // Destroy server to make sure it has been updated.
+ DestroyServer();
+ EXPECT_EQ(kRpcCount, allocator->allocation_count);
+ EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count);
+ EXPECT_EQ(0, allocator->request_deallocation_count);
+}
+
+TEST_P(SimpleAllocatorTest, RpcWithEarlyFreeRequest) {
+ MAYBE_SKIP_TEST;
+ const int kRpcCount = 10;
+ std::unique_ptr<SimpleAllocator> allocator(new SimpleAllocator);
+ auto mutator = [](experimental::RpcAllocatorState* allocator_state,
+ const EchoRequest* req, EchoResponse* resp) {
+ auto* info =
+ static_cast<SimpleAllocator::MessageHolderImpl*>(allocator_state);
+ EXPECT_EQ(req, info->request());
+ EXPECT_EQ(resp, info->response());
+ allocator_state->FreeRequest();
+ EXPECT_EQ(nullptr, info->request());
+ };
+ callback_service_.SetAllocatorMutator(mutator);
+ CreateServer(allocator.get());
+ ResetStub();
+ SendRpcs(kRpcCount);
+ // messages_deallocaton_count is updated in Release after server side OnDone.
+ // Destroy server to make sure it has been updated.
+ DestroyServer();
+ EXPECT_EQ(kRpcCount, allocator->allocation_count);
+ EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count);
+ EXPECT_EQ(kRpcCount, allocator->request_deallocation_count);
+}
+
+TEST_P(SimpleAllocatorTest, RpcWithReleaseRequest) {
+ MAYBE_SKIP_TEST;
+ const int kRpcCount = 10;
+ std::unique_ptr<SimpleAllocator> allocator(new SimpleAllocator);
+ std::vector<EchoRequest*> released_requests;
+ auto mutator = [&released_requests](
+ experimental::RpcAllocatorState* allocator_state,
+ const EchoRequest* req, EchoResponse* resp) {
+ auto* info =
+ static_cast<SimpleAllocator::MessageHolderImpl*>(allocator_state);
+ EXPECT_EQ(req, info->request());
+ EXPECT_EQ(resp, info->response());
+ released_requests.push_back(info->ReleaseRequest());
+ EXPECT_EQ(nullptr, info->request());
+ };
+ callback_service_.SetAllocatorMutator(mutator);
+ CreateServer(allocator.get());
+ ResetStub();
+ SendRpcs(kRpcCount);
+ // messages_deallocaton_count is updated in Release after server side OnDone.
+ // Destroy server to make sure it has been updated.
+ DestroyServer();
+ EXPECT_EQ(kRpcCount, allocator->allocation_count);
+ EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count);
+ EXPECT_EQ(0, allocator->request_deallocation_count);
+ EXPECT_EQ(static_cast<unsigned>(kRpcCount), released_requests.size());
+ for (auto* req : released_requests) {
+ delete req;
+ }
+}
+
+class ArenaAllocatorTest : public MessageAllocatorEnd2endTestBase {
+ public:
+ class ArenaAllocator
+ : public experimental::MessageAllocator<EchoRequest, EchoResponse> {
+ public:
+ class MessageHolderImpl
+ : public experimental::MessageHolder<EchoRequest, EchoResponse> {
+ public:
+ MessageHolderImpl() {
+ set_request(
+ google::protobuf::Arena::CreateMessage<EchoRequest>(&arena_));
+ set_response(
+ google::protobuf::Arena::CreateMessage<EchoResponse>(&arena_));
+ }
+ void Release() override { delete this; }
+ void FreeRequest() override { GPR_ASSERT(0); }
+
+ private:
+ google::protobuf::Arena arena_;
+ };
+ experimental::MessageHolder<EchoRequest, EchoResponse>* AllocateMessages()
+ override {
+ allocation_count++;
+ return new MessageHolderImpl;
+ }
+ int allocation_count = 0;
+ };
+};
+
+TEST_P(ArenaAllocatorTest, SimpleRpc) {
+ MAYBE_SKIP_TEST;
+ const int kRpcCount = 10;
+ std::unique_ptr<ArenaAllocator> allocator(new ArenaAllocator);
+ CreateServer(allocator.get());
+ ResetStub();
+ SendRpcs(kRpcCount);
+ EXPECT_EQ(kRpcCount, allocator->allocation_count);
+}
+
+std::vector<TestScenario> CreateTestScenarios(bool test_insecure) {
+ std::vector<TestScenario> scenarios;
+ std::vector<TString> credentials_types{
+ GetCredentialsProvider()->GetSecureCredentialsTypeList()};
+ auto insec_ok = [] {
+ // Only allow insecure credentials type when it is registered with the
+ // provider. User may create providers that do not have insecure.
+ return GetCredentialsProvider()->GetChannelCredentials(
+ kInsecureCredentialsType, nullptr) != nullptr;
+ };
+ if (test_insecure && insec_ok()) {
+ credentials_types.push_back(kInsecureCredentialsType);
+ }
+ GPR_ASSERT(!credentials_types.empty());
+
+ Protocol parr[]{Protocol::INPROC, Protocol::TCP};
+ for (Protocol p : parr) {
+ for (const auto& cred : credentials_types) {
+ // TODO(vjpai): Test inproc with secure credentials when feasible
+ if (p == Protocol::INPROC &&
+ (cred != kInsecureCredentialsType || !insec_ok())) {
+ continue;
+ }
+ scenarios.emplace_back(p, cred);
+ }
+ }
+ return scenarios;
+}
+
+INSTANTIATE_TEST_SUITE_P(NullAllocatorTest, NullAllocatorTest,
+ ::testing::ValuesIn(CreateTestScenarios(true)));
+INSTANTIATE_TEST_SUITE_P(SimpleAllocatorTest, SimpleAllocatorTest,
+ ::testing::ValuesIn(CreateTestScenarios(true)));
+INSTANTIATE_TEST_SUITE_P(ArenaAllocatorTest, ArenaAllocatorTest,
+ ::testing::ValuesIn(CreateTestScenarios(true)));
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ // The grpc_init is to cover the MAYBE_SKIP_TEST.
+ grpc_init();
+ ::testing::InitGoogleTest(&argc, argv);
+ int ret = RUN_ALL_TESTS();
+ grpc_shutdown();
+ return ret;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/mock_test.cc b/contrib/libs/grpc/test/cpp/end2end/mock_test.cc
new file mode 100644
index 0000000000..a3d61c4e98
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/mock_test.cc
@@ -0,0 +1,434 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <climits>
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/test/default_reactor_test_peer.h>
+
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "src/proto/grpc/testing/echo_mock.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+
+#include <grpcpp/test/mock_stream.h>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <iostream>
+
+using grpc::testing::DefaultReactorTestPeer;
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using grpc::testing::EchoTestService;
+using grpc::testing::MockClientReaderWriter;
+using std::vector;
+using std::chrono::system_clock;
+using ::testing::_;
+using ::testing::AtLeast;
+using ::testing::DoAll;
+using ::testing::Invoke;
+using ::testing::Return;
+using ::testing::SaveArg;
+using ::testing::SetArgPointee;
+using ::testing::WithArg;
+
+namespace grpc {
+namespace testing {
+
+namespace {
+class FakeClient {
+ public:
+ explicit FakeClient(EchoTestService::StubInterface* stub) : stub_(stub) {}
+
+ void DoEcho() {
+ ClientContext context;
+ EchoRequest request;
+ EchoResponse response;
+ request.set_message("hello world");
+ Status s = stub_->Echo(&context, request, &response);
+ EXPECT_EQ(request.message(), response.message());
+ EXPECT_TRUE(s.ok());
+ }
+
+ void DoRequestStream() {
+ EchoRequest request;
+ EchoResponse response;
+
+ ClientContext context;
+ TString msg("hello");
+ TString exp(msg);
+
+ std::unique_ptr<ClientWriterInterface<EchoRequest>> cstream =
+ stub_->RequestStream(&context, &response);
+
+ request.set_message(msg);
+ EXPECT_TRUE(cstream->Write(request));
+
+ msg = ", world";
+ request.set_message(msg);
+ exp.append(msg);
+ EXPECT_TRUE(cstream->Write(request));
+
+ cstream->WritesDone();
+ Status s = cstream->Finish();
+
+ EXPECT_EQ(exp, response.message());
+ EXPECT_TRUE(s.ok());
+ }
+
+ void DoResponseStream() {
+ EchoRequest request;
+ EchoResponse response;
+ request.set_message("hello world");
+
+ ClientContext context;
+ std::unique_ptr<ClientReaderInterface<EchoResponse>> cstream =
+ stub_->ResponseStream(&context, request);
+
+ TString exp = "";
+ EXPECT_TRUE(cstream->Read(&response));
+ exp.append(response.message() + " ");
+
+ EXPECT_TRUE(cstream->Read(&response));
+ exp.append(response.message());
+
+ EXPECT_FALSE(cstream->Read(&response));
+ EXPECT_EQ(request.message(), exp);
+
+ Status s = cstream->Finish();
+ EXPECT_TRUE(s.ok());
+ }
+
+ void DoBidiStream() {
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+ TString msg("hello");
+
+ std::unique_ptr<ClientReaderWriterInterface<EchoRequest, EchoResponse>>
+ stream = stub_->BidiStream(&context);
+
+ request.set_message(msg + "0");
+ EXPECT_TRUE(stream->Write(request));
+ EXPECT_TRUE(stream->Read(&response));
+ EXPECT_EQ(response.message(), request.message());
+
+ request.set_message(msg + "1");
+ EXPECT_TRUE(stream->Write(request));
+ EXPECT_TRUE(stream->Read(&response));
+ EXPECT_EQ(response.message(), request.message());
+
+ request.set_message(msg + "2");
+ EXPECT_TRUE(stream->Write(request));
+ EXPECT_TRUE(stream->Read(&response));
+ EXPECT_EQ(response.message(), request.message());
+
+ stream->WritesDone();
+ EXPECT_FALSE(stream->Read(&response));
+
+ Status s = stream->Finish();
+ EXPECT_TRUE(s.ok());
+ }
+
+ void ResetStub(EchoTestService::StubInterface* stub) { stub_ = stub; }
+
+ private:
+ EchoTestService::StubInterface* stub_;
+};
+
+class CallbackTestServiceImpl
+ : public EchoTestService::ExperimentalCallbackService {
+ public:
+ experimental::ServerUnaryReactor* Echo(
+ experimental::CallbackServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ // Make the mock service explicitly treat empty input messages as invalid
+ // arguments so that we can test various results of status. In general, a
+ // mocked service should just use the original service methods, but we are
+ // adding this variance in Status return value just to improve coverage in
+ // this test.
+ auto* reactor = context->DefaultReactor();
+ if (request->message().length() > 0) {
+ response->set_message(request->message());
+ reactor->Finish(Status::OK);
+ } else {
+ reactor->Finish(Status(StatusCode::INVALID_ARGUMENT, "Invalid request"));
+ }
+ return reactor;
+ }
+};
+
+class MockCallbackTest : public ::testing::Test {
+ protected:
+ CallbackTestServiceImpl service_;
+ ServerContext context_;
+};
+
+TEST_F(MockCallbackTest, MockedCallSucceedsWithWait) {
+ experimental::CallbackServerContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ grpc::internal::Mutex mu;
+ grpc::internal::CondVar cv;
+ grpc::Status status;
+ bool status_set = false;
+ DefaultReactorTestPeer peer(&ctx, [&](::grpc::Status s) {
+ grpc::internal::MutexLock l(&mu);
+ status_set = true;
+ status = std::move(s);
+ cv.Signal();
+ });
+
+ req.set_message("mock 1");
+ auto* reactor = service_.Echo(&ctx, &req, &resp);
+ cv.WaitUntil(&mu, [&] {
+ grpc::internal::MutexLock l(&mu);
+ return status_set;
+ });
+ EXPECT_EQ(reactor, peer.reactor());
+ EXPECT_TRUE(peer.test_status_set());
+ EXPECT_TRUE(peer.test_status().ok());
+ EXPECT_TRUE(status_set);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(req.message(), resp.message());
+}
+
+TEST_F(MockCallbackTest, MockedCallSucceeds) {
+ experimental::CallbackServerContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ DefaultReactorTestPeer peer(&ctx);
+
+ req.set_message("ha ha, consider yourself mocked.");
+ auto* reactor = service_.Echo(&ctx, &req, &resp);
+ EXPECT_EQ(reactor, peer.reactor());
+ EXPECT_TRUE(peer.test_status_set());
+ EXPECT_TRUE(peer.test_status().ok());
+}
+
+TEST_F(MockCallbackTest, MockedCallFails) {
+ experimental::CallbackServerContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ DefaultReactorTestPeer peer(&ctx);
+
+ auto* reactor = service_.Echo(&ctx, &req, &resp);
+ EXPECT_EQ(reactor, peer.reactor());
+ EXPECT_TRUE(peer.test_status_set());
+ EXPECT_EQ(peer.test_status().error_code(), StatusCode::INVALID_ARGUMENT);
+}
+
+class TestServiceImpl : public EchoTestService::Service {
+ public:
+ Status Echo(ServerContext* /*context*/, const EchoRequest* request,
+ EchoResponse* response) override {
+ response->set_message(request->message());
+ return Status::OK;
+ }
+
+ Status RequestStream(ServerContext* /*context*/,
+ ServerReader<EchoRequest>* reader,
+ EchoResponse* response) override {
+ EchoRequest request;
+ TString resp("");
+ while (reader->Read(&request)) {
+ gpr_log(GPR_INFO, "recv msg %s", request.message().c_str());
+ resp.append(request.message());
+ }
+ response->set_message(resp);
+ return Status::OK;
+ }
+
+ Status ResponseStream(ServerContext* /*context*/, const EchoRequest* request,
+ ServerWriter<EchoResponse>* writer) override {
+ EchoResponse response;
+ vector<TString> tokens = split(request->message());
+ for (const TString& token : tokens) {
+ response.set_message(token);
+ writer->Write(response);
+ }
+ return Status::OK;
+ }
+
+ Status BidiStream(
+ ServerContext* /*context*/,
+ ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
+ EchoRequest request;
+ EchoResponse response;
+ while (stream->Read(&request)) {
+ gpr_log(GPR_INFO, "recv msg %s", request.message().c_str());
+ response.set_message(request.message());
+ stream->Write(response);
+ }
+ return Status::OK;
+ }
+
+ private:
+ const vector<TString> split(const TString& input) {
+ TString buff("");
+ vector<TString> result;
+
+ for (auto n : input) {
+ if (n != ' ') {
+ buff += n;
+ continue;
+ }
+ if (buff == "") continue;
+ result.push_back(buff);
+ buff = "";
+ }
+ if (buff != "") result.push_back(buff);
+
+ return result;
+ }
+};
+
+class MockTest : public ::testing::Test {
+ protected:
+ MockTest() {}
+
+ void SetUp() override {
+ int port = grpc_pick_unused_port_or_die();
+ server_address_ << "localhost:" << port;
+ // Setup server
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address_.str(),
+ InsecureServerCredentials());
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ }
+
+ void TearDown() override { server_->Shutdown(); }
+
+ void ResetStub() {
+ std::shared_ptr<Channel> channel = grpc::CreateChannel(
+ server_address_.str(), InsecureChannelCredentials());
+ stub_ = grpc::testing::EchoTestService::NewStub(channel);
+ }
+
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<Server> server_;
+ std::ostringstream server_address_;
+ TestServiceImpl service_;
+};
+
+// Do one real rpc and one mocked one
+TEST_F(MockTest, SimpleRpc) {
+ ResetStub();
+ FakeClient client(stub_.get());
+ client.DoEcho();
+ MockEchoTestServiceStub stub;
+ EchoResponse resp;
+ resp.set_message("hello world");
+ EXPECT_CALL(stub, Echo(_, _, _))
+ .Times(AtLeast(1))
+ .WillOnce(DoAll(SetArgPointee<2>(resp), Return(Status::OK)));
+ client.ResetStub(&stub);
+ client.DoEcho();
+}
+
+TEST_F(MockTest, ClientStream) {
+ ResetStub();
+ FakeClient client(stub_.get());
+ client.DoRequestStream();
+
+ MockEchoTestServiceStub stub;
+ auto w = new MockClientWriter<EchoRequest>();
+ EchoResponse resp;
+ resp.set_message("hello, world");
+
+ EXPECT_CALL(*w, Write(_, _)).Times(2).WillRepeatedly(Return(true));
+ EXPECT_CALL(*w, WritesDone());
+ EXPECT_CALL(*w, Finish()).WillOnce(Return(Status::OK));
+
+ EXPECT_CALL(stub, RequestStreamRaw(_, _))
+ .WillOnce(DoAll(SetArgPointee<1>(resp), Return(w)));
+ client.ResetStub(&stub);
+ client.DoRequestStream();
+}
+
+TEST_F(MockTest, ServerStream) {
+ ResetStub();
+ FakeClient client(stub_.get());
+ client.DoResponseStream();
+
+ MockEchoTestServiceStub stub;
+ auto r = new MockClientReader<EchoResponse>();
+ EchoResponse resp1;
+ resp1.set_message("hello");
+ EchoResponse resp2;
+ resp2.set_message("world");
+
+ EXPECT_CALL(*r, Read(_))
+ .WillOnce(DoAll(SetArgPointee<0>(resp1), Return(true)))
+ .WillOnce(DoAll(SetArgPointee<0>(resp2), Return(true)))
+ .WillOnce(Return(false));
+ EXPECT_CALL(*r, Finish()).WillOnce(Return(Status::OK));
+
+ EXPECT_CALL(stub, ResponseStreamRaw(_, _)).WillOnce(Return(r));
+
+ client.ResetStub(&stub);
+ client.DoResponseStream();
+}
+
+ACTION_P(copy, msg) { arg0->set_message(msg->message()); }
+
+TEST_F(MockTest, BidiStream) {
+ ResetStub();
+ FakeClient client(stub_.get());
+ client.DoBidiStream();
+ MockEchoTestServiceStub stub;
+ auto rw = new MockClientReaderWriter<EchoRequest, EchoResponse>();
+ EchoRequest msg;
+
+ EXPECT_CALL(*rw, Write(_, _))
+ .Times(3)
+ .WillRepeatedly(DoAll(SaveArg<0>(&msg), Return(true)));
+ EXPECT_CALL(*rw, Read(_))
+ .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true)))
+ .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true)))
+ .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true)))
+ .WillOnce(Return(false));
+ EXPECT_CALL(*rw, WritesDone());
+ EXPECT_CALL(*rw, Finish()).WillOnce(Return(Status::OK));
+
+ EXPECT_CALL(stub, BidiStreamRaw(_)).WillOnce(Return(rw));
+ client.ResetStub(&stub);
+ client.DoBidiStream();
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/nonblocking_test.cc b/contrib/libs/grpc/test/cpp/end2end/nonblocking_test.cc
new file mode 100644
index 0000000000..4be070ec71
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/nonblocking_test.cc
@@ -0,0 +1,214 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/core/lib/gpr/tls.h"
+#include "src/core/lib/iomgr/port.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+
+#ifdef GRPC_POSIX_SOCKET
+#include "src/core/lib/iomgr/ev_posix.h"
+#endif // GRPC_POSIX_SOCKET
+
+#include <gtest/gtest.h>
+
+#ifdef GRPC_POSIX_SOCKET
+// Thread-local variable to so that only polls from this test assert
+// non-blocking (not polls from resolver, timer thread, etc), and only when the
+// thread is waiting on polls caused by CompletionQueue::AsyncNext (not for
+// picking a port or other reasons).
+GPR_TLS_DECL(g_is_nonblocking_poll);
+
+namespace {
+
+int maybe_assert_non_blocking_poll(struct pollfd* pfds, nfds_t nfds,
+ int timeout) {
+ // Only assert that this poll should have zero timeout if we're in the
+ // middle of a zero-timeout CQ Next.
+ if (gpr_tls_get(&g_is_nonblocking_poll)) {
+ GPR_ASSERT(timeout == 0);
+ }
+ return poll(pfds, nfds, timeout);
+}
+
+} // namespace
+
+namespace grpc {
+namespace testing {
+namespace {
+
+void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
+int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); }
+
+class NonblockingTest : public ::testing::Test {
+ protected:
+ NonblockingTest() {}
+
+ void SetUp() override {
+ port_ = grpc_pick_unused_port_or_die();
+ server_address_ << "localhost:" << port_;
+
+ // Setup server
+ BuildAndStartServer();
+ }
+
+ bool LoopForTag(void** tag, bool* ok) {
+ // Temporarily set the thread-local nonblocking poll flag so that the polls
+ // caused by this loop are indeed sent by the library with zero timeout.
+ intptr_t orig_val = gpr_tls_get(&g_is_nonblocking_poll);
+ gpr_tls_set(&g_is_nonblocking_poll, static_cast<intptr_t>(true));
+ for (;;) {
+ auto r = cq_->AsyncNext(tag, ok, gpr_time_0(GPR_CLOCK_REALTIME));
+ if (r == CompletionQueue::SHUTDOWN) {
+ gpr_tls_set(&g_is_nonblocking_poll, orig_val);
+ return false;
+ } else if (r == CompletionQueue::GOT_EVENT) {
+ gpr_tls_set(&g_is_nonblocking_poll, orig_val);
+ return true;
+ }
+ }
+ }
+
+ void TearDown() override {
+ server_->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ cq_->Shutdown();
+ while (LoopForTag(&ignored_tag, &ignored_ok))
+ ;
+ stub_.reset();
+ grpc_recycle_unused_port(port_);
+ }
+
+ void BuildAndStartServer() {
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address_.str(),
+ grpc::InsecureServerCredentials());
+ service_.reset(new grpc::testing::EchoTestService::AsyncService());
+ builder.RegisterService(service_.get());
+ cq_ = builder.AddCompletionQueue();
+ server_ = builder.BuildAndStart();
+ }
+
+ void ResetStub() {
+ std::shared_ptr<Channel> channel = grpc::CreateChannel(
+ server_address_.str(), grpc::InsecureChannelCredentials());
+ stub_ = grpc::testing::EchoTestService::NewStub(channel);
+ }
+
+ void SendRpc(int num_rpcs) {
+ for (int i = 0; i < num_rpcs; i++) {
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
+
+ send_request.set_message("hello non-blocking world");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->PrepareAsyncEcho(&cli_ctx, send_request, cq_.get()));
+
+ response_reader->StartCall();
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+
+ service_->RequestEcho(&srv_ctx, &recv_request, &response_writer,
+ cq_.get(), cq_.get(), tag(2));
+
+ void* got_tag;
+ bool ok;
+ EXPECT_TRUE(LoopForTag(&got_tag, &ok));
+ EXPECT_TRUE(ok);
+ EXPECT_EQ(detag(got_tag), 2);
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ response_writer.Finish(send_response, Status::OK, tag(3));
+
+ int tagsum = 0;
+ int tagprod = 1;
+ EXPECT_TRUE(LoopForTag(&got_tag, &ok));
+ EXPECT_TRUE(ok);
+ tagsum += detag(got_tag);
+ tagprod *= detag(got_tag);
+
+ EXPECT_TRUE(LoopForTag(&got_tag, &ok));
+ EXPECT_TRUE(ok);
+ tagsum += detag(got_tag);
+ tagprod *= detag(got_tag);
+
+ EXPECT_EQ(tagsum, 7);
+ EXPECT_EQ(tagprod, 12);
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.ok());
+ }
+ }
+
+ std::unique_ptr<ServerCompletionQueue> cq_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<Server> server_;
+ std::unique_ptr<grpc::testing::EchoTestService::AsyncService> service_;
+ std::ostringstream server_address_;
+ int port_;
+};
+
+TEST_F(NonblockingTest, SimpleRpc) {
+ ResetStub();
+ SendRpc(10);
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_POSIX_SOCKET
+
+int main(int argc, char** argv) {
+#ifdef GRPC_POSIX_SOCKET
+ // Override the poll function before anything else can happen
+ grpc_poll_function = maybe_assert_non_blocking_poll;
+
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ gpr_tls_init(&g_is_nonblocking_poll);
+
+ // Start the nonblocking poll thread-local variable as false because the
+ // thread that issues RPCs starts by picking a port (which has non-zero
+ // timeout).
+ gpr_tls_set(&g_is_nonblocking_poll, static_cast<intptr_t>(false));
+
+ int ret = RUN_ALL_TESTS();
+ gpr_tls_destroy(&g_is_nonblocking_poll);
+ return ret;
+#else // GRPC_POSIX_SOCKET
+ return 0;
+#endif // GRPC_POSIX_SOCKET
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/port_sharing_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/port_sharing_end2end_test.cc
new file mode 100644
index 0000000000..b69d1dd2be
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/port_sharing_end2end_test.cc
@@ -0,0 +1,374 @@
+/*
+ *
+ * Copyright 2019 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/security/server_credentials.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <gtest/gtest.h>
+
+#include <mutex>
+#include <thread>
+
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/iomgr/endpoint.h"
+#include "src/core/lib/iomgr/exec_ctx.h"
+#include "src/core/lib/iomgr/pollset.h"
+#include "src/core/lib/iomgr/port.h"
+#include "src/core/lib/iomgr/tcp_server.h"
+#include "src/core/lib/security/credentials/credentials.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/core/util/test_tcp_server.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+#ifdef GRPC_POSIX_SOCKET_TCP_SERVER
+
+#include "src/core/lib/iomgr/tcp_posix.h"
+
+namespace grpc {
+namespace testing {
+namespace {
+
// One parameterization of the port-sharing suite: whether the gRPC server
// also listens on a port of its own, whether the external TCP server reads
// pending bytes before handing the connection off, and which credentials
// type secures the channels.
class TestScenario {
 public:
  TestScenario(bool server_port, bool pending_data,
               const TString& creds_type)
      : server_has_port(server_port),
        queue_pending_data(pending_data),
        credentials_type(creds_type) {}
  // Writes a human-readable description of this scenario to the gRPC log.
  void Log() const;
  // server has its own port or not
  bool server_has_port;
  // whether tcp server should read some data before handoff
  bool queue_pending_data;
  const TString credentials_type;
};
+
+static std::ostream& operator<<(std::ostream& out,
+ const TestScenario& scenario) {
+ return out << "TestScenario{server_has_port="
+ << (scenario.server_has_port ? "true" : "false")
+ << ", queue_pending_data="
+ << (scenario.queue_pending_data ? "true" : "false")
+ << ", credentials='" << scenario.credentials_type << "'}";
+}
+
+void TestScenario::Log() const {
+ std::ostringstream out;
+ out << *this;
+ gpr_log(GPR_ERROR, "%s", out.str().c_str());
+}
+
+// Set up a test tcp server which is in charge of accepting connections and
+// handing off the connections as fds.
+class TestTcpServer {
+ public:
+ TestTcpServer()
+ : shutdown_(false),
+ queue_data_(false),
+ port_(grpc_pick_unused_port_or_die()) {
+ std::ostringstream server_address;
+ server_address << "localhost:" << port_;
+ address_ = server_address.str();
+ test_tcp_server_init(&tcp_server_, &TestTcpServer::OnConnect, this);
+ GRPC_CLOSURE_INIT(&on_fd_released_, &TestTcpServer::OnFdReleased, this,
+ grpc_schedule_on_exec_ctx);
+ }
+
+ ~TestTcpServer() {
+ running_thread_.join();
+ test_tcp_server_destroy(&tcp_server_);
+ grpc_recycle_unused_port(port_);
+ }
+
+ // Read some data before handing off the connection.
+ void SetQueueData() { queue_data_ = true; }
+
+ void Start() {
+ test_tcp_server_start(&tcp_server_, port_);
+ gpr_log(GPR_INFO, "Test TCP server started at %s", address_.c_str());
+ }
+
+ const TString& address() { return address_; }
+
+ void SetAcceptor(
+ std::unique_ptr<experimental::ExternalConnectionAcceptor> acceptor) {
+ connection_acceptor_ = std::move(acceptor);
+ }
+
+ void Run() {
+ running_thread_ = std::thread([this]() {
+ while (true) {
+ {
+ std::lock_guard<std::mutex> lock(mu_);
+ if (shutdown_) {
+ return;
+ }
+ }
+ test_tcp_server_poll(&tcp_server_, 1);
+ }
+ });
+ }
+
+ void Shutdown() {
+ std::lock_guard<std::mutex> lock(mu_);
+ shutdown_ = true;
+ }
+
+ static void OnConnect(void* arg, grpc_endpoint* tcp,
+ grpc_pollset* accepting_pollset,
+ grpc_tcp_server_acceptor* acceptor) {
+ auto* self = static_cast<TestTcpServer*>(arg);
+ self->OnConnect(tcp, accepting_pollset, acceptor);
+ }
+
+ static void OnFdReleased(void* arg, grpc_error* err) {
+ auto* self = static_cast<TestTcpServer*>(arg);
+ self->OnFdReleased(err);
+ }
+
+ private:
+ void OnConnect(grpc_endpoint* tcp, grpc_pollset* /*accepting_pollset*/,
+ grpc_tcp_server_acceptor* acceptor) {
+ TString peer(grpc_endpoint_get_peer(tcp));
+ gpr_log(GPR_INFO, "Got incoming connection! from %s", peer.c_str());
+ EXPECT_FALSE(acceptor->external_connection);
+ listener_fd_ = grpc_tcp_server_port_fd(
+ acceptor->from_server, acceptor->port_index, acceptor->fd_index);
+ gpr_free(acceptor);
+ grpc_tcp_destroy_and_release_fd(tcp, &fd_, &on_fd_released_);
+ }
+
+ void OnFdReleased(grpc_error* err) {
+ EXPECT_EQ(GRPC_ERROR_NONE, err);
+ experimental::ExternalConnectionAcceptor::NewConnectionParameters p;
+ p.listener_fd = listener_fd_;
+ p.fd = fd_;
+ if (queue_data_) {
+ char buf[1024];
+ ssize_t read_bytes = 0;
+ while (read_bytes <= 0) {
+ read_bytes = read(fd_, buf, 1024);
+ }
+ Slice data(buf, read_bytes);
+ p.read_buffer = ByteBuffer(&data, 1);
+ }
+ gpr_log(GPR_INFO, "Handing off fd %d with data size %d from listener fd %d",
+ fd_, static_cast<int>(p.read_buffer.Length()), listener_fd_);
+ connection_acceptor_->HandleNewConnection(&p);
+ }
+
+ std::mutex mu_;
+ bool shutdown_;
+
+ int listener_fd_ = -1;
+ int fd_ = -1;
+ bool queue_data_ = false;
+
+ grpc_closure on_fd_released_;
+ std::thread running_thread_;
+ int port_ = -1;
+ TString address_;
+ std::unique_ptr<experimental::ExternalConnectionAcceptor>
+ connection_acceptor_;
+ test_tcp_server tcp_server_;
+};
+
// Fixture: one gRPC server fed by two external TCP servers via fd handoff,
// plus (per scenario) an ordinary listening port of its own.
class PortSharingEnd2endTest : public ::testing::TestWithParam<TestScenario> {
 protected:
  PortSharingEnd2endTest() : is_server_started_(false), first_picked_port_(0) {
    GetParam().Log();
  }

  void SetUp() override {
    if (GetParam().queue_pending_data) {
      tcp_server1_.SetQueueData();
      tcp_server2_.SetQueueData();
    }
    tcp_server1_.Start();
    tcp_server2_.Start();
    ServerBuilder builder;
    if (GetParam().server_has_port) {
      int port = grpc_pick_unused_port_or_die();
      first_picked_port_ = port;
      server_address_ << "localhost:" << port;
      auto creds = GetCredentialsProvider()->GetServerCredentials(
          GetParam().credentials_type);
      builder.AddListeningPort(server_address_.str(), creds);
      gpr_log(GPR_INFO, "gRPC server listening on %s",
              server_address_.str().c_str());
    }
    auto server_creds = GetCredentialsProvider()->GetServerCredentials(
        GetParam().credentials_type);
    // Each tcp server gets its own acceptor so the gRPC server can take
    // handed-off fds from both.
    auto acceptor1 = builder.experimental().AddExternalConnectionAcceptor(
        ServerBuilder::experimental_type::ExternalConnectionType::FROM_FD,
        server_creds);
    tcp_server1_.SetAcceptor(std::move(acceptor1));
    auto acceptor2 = builder.experimental().AddExternalConnectionAcceptor(
        ServerBuilder::experimental_type::ExternalConnectionType::FROM_FD,
        server_creds);
    tcp_server2_.SetAcceptor(std::move(acceptor2));
    builder.RegisterService(&service_);
    server_ = builder.BuildAndStart();
    is_server_started_ = true;

    // Start polling only after the acceptors are wired up.
    tcp_server1_.Run();
    tcp_server2_.Run();
  }

  void TearDown() override {
    tcp_server1_.Shutdown();
    tcp_server2_.Shutdown();
    if (is_server_started_) {
      server_->Shutdown();
    }
    if (first_picked_port_ > 0) {
      grpc_recycle_unused_port(first_picked_port_);
    }
  }

  // Rebuilds all client channels/stubs. Local subchannel pools keep the
  // channels from sharing connections with each other.
  void ResetStubs() {
    EXPECT_TRUE(is_server_started_);
    ChannelArguments args;
    args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1);
    auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
        GetParam().credentials_type, &args);
    channel_handoff1_ =
        CreateCustomChannel(tcp_server1_.address(), channel_creds, args);
    stub_handoff1_ = EchoTestService::NewStub(channel_handoff1_);
    channel_handoff2_ =
        CreateCustomChannel(tcp_server2_.address(), channel_creds, args);
    stub_handoff2_ = EchoTestService::NewStub(channel_handoff2_);
    if (GetParam().server_has_port) {
      ChannelArguments direct_args;
      direct_args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1);
      auto direct_creds = GetCredentialsProvider()->GetChannelCredentials(
          GetParam().credentials_type, &direct_args);
      channel_direct_ =
          CreateCustomChannel(server_address_.str(), direct_creds, direct_args);
      stub_direct_ = EchoTestService::NewStub(channel_direct_);
    }
  }

  bool is_server_started_;
  // channel/stub to the test tcp server, the connection will be handed to the
  // grpc server.
  std::shared_ptr<Channel> channel_handoff1_;
  std::unique_ptr<EchoTestService::Stub> stub_handoff1_;
  std::shared_ptr<Channel> channel_handoff2_;
  std::unique_ptr<EchoTestService::Stub> stub_handoff2_;
  // channel/stub to talk to the grpc server directly, if applicable.
  std::shared_ptr<Channel> channel_direct_;
  std::unique_ptr<EchoTestService::Stub> stub_direct_;
  std::unique_ptr<Server> server_;
  std::ostringstream server_address_;
  TestServiceImpl service_;
  TestTcpServer tcp_server1_;
  TestTcpServer tcp_server2_;
  int first_picked_port_;  // recycled in TearDown when > 0
};
+
+static void SendRpc(EchoTestService::Stub* stub, int num_rpcs) {
+ EchoRequest request;
+ EchoResponse response;
+ request.set_message("Hello hello hello hello");
+
+ for (int i = 0; i < num_rpcs; ++i) {
+ ClientContext context;
+ Status s = stub->Echo(&context, request, &response);
+ EXPECT_EQ(response.message(), request.message());
+ EXPECT_TRUE(s.ok());
+ }
+}
+
+std::vector<TestScenario> CreateTestScenarios() {
+ std::vector<TestScenario> scenarios;
+ std::vector<TString> credentials_types;
+
+#if TARGET_OS_IPHONE
+ // Workaround Apple CFStream bug
+ gpr_setenv("grpc_cfstream", "0");
+#endif
+
+ credentials_types = GetCredentialsProvider()->GetSecureCredentialsTypeList();
+ // Only allow insecure credentials type when it is registered with the
+ // provider. User may create providers that do not have insecure.
+ if (GetCredentialsProvider()->GetChannelCredentials(kInsecureCredentialsType,
+ nullptr) != nullptr) {
+ credentials_types.push_back(kInsecureCredentialsType);
+ }
+
+ GPR_ASSERT(!credentials_types.empty());
+ for (const auto& cred : credentials_types) {
+ for (auto server_has_port : {true, false}) {
+ for (auto queue_pending_data : {true, false}) {
+ scenarios.emplace_back(server_has_port, queue_pending_data, cred);
+ }
+ }
+ }
+ return scenarios;
+}
+
// Exercises the handed-off path end to end and — when the scenario gives the
// server its own listening port — the direct path as well.
TEST_P(PortSharingEnd2endTest, HandoffAndDirectCalls) {
  ResetStubs();
  SendRpc(stub_handoff1_.get(), 5);
  if (GetParam().server_has_port) {
    SendRpc(stub_direct_.get(), 5);
  }
}
+
// Reconnects through the second tcp server repeatedly; each ResetStubs()
// builds fresh channels, so every iteration triggers a new fd handoff.
TEST_P(PortSharingEnd2endTest, MultipleHandoff) {
  for (int i = 0; i < 3; i++) {
    ResetStubs();
    SendRpc(stub_handoff2_.get(), 1);
  }
}
+
// Drives RPCs through both external tcp servers to verify that two handoff
// acceptors can coexist on one gRPC server.
TEST_P(PortSharingEnd2endTest, TwoHandoffPorts) {
  for (int i = 0; i < 3; i++) {
    ResetStubs();
    SendRpc(stub_handoff1_.get(), 5);
    SendRpc(stub_handoff2_.get(), 5);
  }
}
+
+INSTANTIATE_TEST_SUITE_P(PortSharingEnd2end, PortSharingEnd2endTest,
+ ::testing::ValuesIn(CreateTestScenarios()));
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_POSIX_SOCKET_TCP_SERVER
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/proto_server_reflection_test.cc b/contrib/libs/grpc/test/cpp/end2end/proto_server_reflection_test.cc
new file mode 100644
index 0000000000..d79b33da70
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/proto_server_reflection_test.cc
@@ -0,0 +1,150 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/grpc.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/ext/proto_server_reflection_plugin.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/security/server_credentials.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/proto_reflection_descriptor_database.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+
+class ProtoServerReflectionTest : public ::testing::Test {
+ public:
+ ProtoServerReflectionTest() {}
+
+ void SetUp() override {
+ port_ = grpc_pick_unused_port_or_die();
+ ref_desc_pool_ = protobuf::DescriptorPool::generated_pool();
+
+ ServerBuilder builder;
+ TString server_address = "localhost:" + to_string(port_);
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ server_ = builder.BuildAndStart();
+ }
+
+ void ResetStub() {
+ string target = "dns:localhost:" + to_string(port_);
+ std::shared_ptr<Channel> channel =
+ grpc::CreateChannel(target, InsecureChannelCredentials());
+ stub_ = grpc::testing::EchoTestService::NewStub(channel);
+ desc_db_.reset(new ProtoReflectionDescriptorDatabase(channel));
+ desc_pool_.reset(new protobuf::DescriptorPool(desc_db_.get()));
+ }
+
+ string to_string(const int number) {
+ std::stringstream strs;
+ strs << number;
+ return strs.str();
+ }
+
+ void CompareService(const TString& service) {
+ const protobuf::ServiceDescriptor* service_desc =
+ desc_pool_->FindServiceByName(service);
+ const protobuf::ServiceDescriptor* ref_service_desc =
+ ref_desc_pool_->FindServiceByName(service);
+ EXPECT_TRUE(service_desc != nullptr);
+ EXPECT_TRUE(ref_service_desc != nullptr);
+ EXPECT_EQ(service_desc->DebugString(), ref_service_desc->DebugString());
+
+ const protobuf::FileDescriptor* file_desc = service_desc->file();
+ if (known_files_.find(file_desc->package() + "/" + file_desc->name()) !=
+ known_files_.end()) {
+ EXPECT_EQ(file_desc->DebugString(),
+ ref_service_desc->file()->DebugString());
+ known_files_.insert(file_desc->package() + "/" + file_desc->name());
+ }
+
+ for (int i = 0; i < service_desc->method_count(); ++i) {
+ CompareMethod(service_desc->method(i)->full_name());
+ }
+ }
+
+ void CompareMethod(const TString& method) {
+ const protobuf::MethodDescriptor* method_desc =
+ desc_pool_->FindMethodByName(method);
+ const protobuf::MethodDescriptor* ref_method_desc =
+ ref_desc_pool_->FindMethodByName(method);
+ EXPECT_TRUE(method_desc != nullptr);
+ EXPECT_TRUE(ref_method_desc != nullptr);
+ EXPECT_EQ(method_desc->DebugString(), ref_method_desc->DebugString());
+
+ CompareType(method_desc->input_type()->full_name());
+ CompareType(method_desc->output_type()->full_name());
+ }
+
+ void CompareType(const TString& type) {
+ if (known_types_.find(type) != known_types_.end()) {
+ return;
+ }
+
+ const protobuf::Descriptor* desc = desc_pool_->FindMessageTypeByName(type);
+ const protobuf::Descriptor* ref_desc =
+ ref_desc_pool_->FindMessageTypeByName(type);
+ EXPECT_TRUE(desc != nullptr);
+ EXPECT_TRUE(ref_desc != nullptr);
+ EXPECT_EQ(desc->DebugString(), ref_desc->DebugString());
+ }
+
+ protected:
+ std::unique_ptr<Server> server_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<ProtoReflectionDescriptorDatabase> desc_db_;
+ std::unique_ptr<protobuf::DescriptorPool> desc_pool_;
+ std::unordered_set<string> known_files_;
+ std::unordered_set<string> known_types_;
+ const protobuf::DescriptorPool* ref_desc_pool_;
+ int port_;
+ reflection::ProtoServerReflectionPlugin plugin_;
+};
+
+TEST_F(ProtoServerReflectionTest, CheckResponseWithLocalDescriptorPool) {
+ ResetStub();
+
+ std::vector<TString> services;
+ desc_db_->GetServices(&services);
+ // The service list has at least one service (reflection servcie).
+ EXPECT_TRUE(services.size() > 0);
+
+ for (auto it = services.begin(); it != services.end(); ++it) {
+ CompareService(*it);
+ }
+}
+
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/raw_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/raw_end2end_test.cc
new file mode 100644
index 0000000000..184dc1e5f5
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/raw_end2end_test.cc
@@ -0,0 +1,370 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <cinttypes>
+#include <memory>
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/iomgr/port.h"
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+namespace grpc {
+namespace testing {
+
+namespace {
+
// Encodes a small integer as an opaque completion-queue tag pointer.
void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
// Recovers the integer previously encoded by tag().
int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); }
+
+class Verifier {
+ public:
+ Verifier() {}
+
+ // Expect sets the expected ok value for a specific tag
+ Verifier& Expect(int i, bool expect_ok) {
+ expectations_[tag(i)] = expect_ok;
+ return *this;
+ }
+
+ // Next waits for 1 async tag to complete, checks its
+ // expectations, and returns the tag
+ int Next(CompletionQueue* cq, bool ignore_ok) {
+ bool ok;
+ void* got_tag;
+ EXPECT_TRUE(cq->Next(&got_tag, &ok));
+ GotTag(got_tag, ok, ignore_ok);
+ return detag(got_tag);
+ }
+
+ // Verify keeps calling Next until all currently set
+ // expected tags are complete
+ void Verify(CompletionQueue* cq) {
+ GPR_ASSERT(!expectations_.empty());
+ while (!expectations_.empty()) {
+ Next(cq, false);
+ }
+ }
+
+ private:
+ void GotTag(void* got_tag, bool ok, bool ignore_ok) {
+ auto it = expectations_.find(got_tag);
+ if (it != expectations_.end()) {
+ if (!ignore_ok) {
+ EXPECT_EQ(it->second, ok);
+ }
+ expectations_.erase(it);
+ }
+ }
+
+ std::map<void*, bool> expectations_;
+};
+
+class RawEnd2EndTest : public ::testing::Test {
+ protected:
+ RawEnd2EndTest() {}
+
+ void SetUp() override {
+ port_ = grpc_pick_unused_port_or_die();
+ server_address_ << "localhost:" << port_;
+ }
+
+ void TearDown() override {
+ server_->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ cq_->Shutdown();
+ while (cq_->Next(&ignored_tag, &ignored_ok))
+ ;
+ stub_.reset();
+ grpc_recycle_unused_port(port_);
+ }
+
+ template <typename ServerType>
+ std::unique_ptr<ServerType> BuildAndStartServer() {
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address_.str(),
+ grpc::InsecureServerCredentials());
+ std::unique_ptr<ServerType> service(new ServerType());
+ builder.RegisterService(service.get());
+ cq_ = builder.AddCompletionQueue();
+ server_ = builder.BuildAndStart();
+ return service;
+ }
+
+ void ResetStub() {
+ ChannelArguments args;
+ std::shared_ptr<Channel> channel = grpc::CreateChannel(
+ server_address_.str(), grpc::InsecureChannelCredentials());
+ stub_ = grpc::testing::EchoTestService::NewStub(channel);
+ }
+
+ std::unique_ptr<ServerCompletionQueue> cq_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<Server> server_;
+ std::ostringstream server_address_;
+ int port_;
+
+ // For the client application to populate and send to server.
+ EchoRequest send_request_;
+ ::grpc::ByteBuffer send_request_buffer_;
+
+ // For the server to give to gRPC to be populated by incoming request
+ // from client.
+ EchoRequest recv_request_;
+ ::grpc::ByteBuffer recv_request_buffer_;
+
+ // For the server application to populate and send back to client.
+ EchoResponse send_response_;
+ ::grpc::ByteBuffer send_response_buffer_;
+
+ // For the client to give to gRPC to be populated by incoming response
+ // from server.
+ EchoResponse recv_response_;
+ ::grpc::ByteBuffer recv_response_buffer_;
+ Status recv_status_;
+
+ // Both sides need contexts
+ ClientContext cli_ctx_;
+ ServerContext srv_ctx_;
+};
+
+// Regular Async, both peers use proto
+TEST_F(RawEnd2EndTest, PureAsyncService) {
+ typedef grpc::testing::EchoTestService::AsyncService SType;
+ ResetStub();
+ auto service = BuildAndStartServer<SType>();
+ grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx_);
+
+ send_request_.set_message("hello");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx_, send_request_, cq_.get()));
+ service->RequestEcho(&srv_ctx_, &recv_request_, &response_writer, cq_.get(),
+ cq_.get(), tag(2));
+ response_reader->Finish(&recv_response_, &recv_status_, tag(4));
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_EQ(send_request_.message(), recv_request_.message());
+ send_response_.set_message(recv_request_.message());
+ response_writer.Finish(send_response_, Status::OK, tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response_.message(), recv_response_.message());
+ EXPECT_TRUE(recv_status_.ok());
+}
+
+// Client uses proto, server uses generic codegen, unary
+TEST_F(RawEnd2EndTest, RawServerUnary) {
+ typedef grpc::testing::EchoTestService::WithRawMethod_Echo<
+ grpc::testing::EchoTestService::Service>
+ SType;
+ ResetStub();
+ auto service = BuildAndStartServer<SType>();
+ grpc::GenericServerAsyncResponseWriter response_writer(&srv_ctx_);
+
+ send_request_.set_message("hello unary");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub_->AsyncEcho(&cli_ctx_, send_request_, cq_.get()));
+ service->RequestEcho(&srv_ctx_, &recv_request_buffer_, &response_writer,
+ cq_.get(), cq_.get(), tag(2));
+ response_reader->Finish(&recv_response_, &recv_status_, tag(4));
+ Verifier().Expect(2, true).Verify(cq_.get());
+ EXPECT_TRUE(ParseFromByteBuffer(&recv_request_buffer_, &recv_request_));
+ EXPECT_EQ(send_request_.message(), recv_request_.message());
+ send_response_.set_message(recv_request_.message());
+ EXPECT_TRUE(
+ SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_));
+ response_writer.Finish(send_response_buffer_, Status::OK, tag(3));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response_.message(), recv_response_.message());
+ EXPECT_TRUE(recv_status_.ok());
+}
+
+// Client uses proto, server uses generic codegen, client streaming
+TEST_F(RawEnd2EndTest, RawServerClientStreaming) {
+ typedef grpc::testing::EchoTestService::WithRawMethod_RequestStream<
+ grpc::testing::EchoTestService::Service>
+ SType;
+ ResetStub();
+ auto service = BuildAndStartServer<SType>();
+
+ grpc::GenericServerAsyncReader srv_stream(&srv_ctx_);
+
+ send_request_.set_message("hello client streaming");
+ std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream(
+ stub_->AsyncRequestStream(&cli_ctx_, &recv_response_, cq_.get(), tag(1)));
+
+ service->RequestRequestStream(&srv_ctx_, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+
+ Verifier().Expect(2, true).Expect(1, true).Verify(cq_.get());
+
+ cli_stream->Write(send_request_, tag(3));
+ srv_stream.Read(&recv_request_buffer_, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ ParseFromByteBuffer(&recv_request_buffer_, &recv_request_);
+ EXPECT_EQ(send_request_.message(), recv_request_.message());
+
+ cli_stream->Write(send_request_, tag(5));
+ srv_stream.Read(&recv_request_buffer_, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+
+ ParseFromByteBuffer(&recv_request_buffer_, &recv_request_);
+ EXPECT_EQ(send_request_.message(), recv_request_.message());
+ cli_stream->WritesDone(tag(7));
+ srv_stream.Read(&recv_request_buffer_, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+ ParseFromByteBuffer(&recv_request_buffer_, &recv_request_);
+ send_response_.set_message(recv_request_.message());
+ SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_);
+ srv_stream.Finish(send_response_buffer_, Status::OK, tag(9));
+ cli_stream->Finish(&recv_status_, tag(10));
+ Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get());
+
+ EXPECT_EQ(send_response_.message(), recv_response_.message());
+ EXPECT_TRUE(recv_status_.ok());
+}
+
+// Client uses proto, server uses generic codegen, server streaming
+TEST_F(RawEnd2EndTest, RawServerServerStreaming) {
+ typedef grpc::testing::EchoTestService::WithRawMethod_ResponseStream<
+ grpc::testing::EchoTestService::Service>
+ SType;
+ ResetStub();
+ auto service = BuildAndStartServer<SType>();
+ grpc::GenericServerAsyncWriter srv_stream(&srv_ctx_);
+
+ send_request_.set_message("hello server streaming");
+ std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
+ stub_->AsyncResponseStream(&cli_ctx_, send_request_, cq_.get(), tag(1)));
+
+ service->RequestResponseStream(&srv_ctx_, &recv_request_buffer_, &srv_stream,
+ cq_.get(), cq_.get(), tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+ ParseFromByteBuffer(&recv_request_buffer_, &recv_request_);
+ EXPECT_EQ(send_request_.message(), recv_request_.message());
+
+ send_response_.set_message(recv_request_.message());
+ SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_);
+ srv_stream.Write(send_response_buffer_, tag(3));
+ cli_stream->Read(&recv_response_, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ EXPECT_EQ(send_response_.message(), recv_response_.message());
+
+ srv_stream.Write(send_response_buffer_, tag(5));
+ cli_stream->Read(&recv_response_, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+ EXPECT_EQ(send_response_.message(), recv_response_.message());
+
+ srv_stream.Finish(Status::OK, tag(7));
+ cli_stream->Read(&recv_response_, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+ cli_stream->Finish(&recv_status_, tag(9));
+ Verifier().Expect(9, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status_.ok());
+}
+
+// Client uses proto, server uses generic codegen, bidi streaming
+TEST_F(RawEnd2EndTest, RawServerBidiStreaming) {
+ typedef grpc::testing::EchoTestService::WithRawMethod_BidiStream<
+ grpc::testing::EchoTestService::Service>
+ SType;
+ ResetStub();
+ auto service = BuildAndStartServer<SType>();
+
+ grpc::GenericServerAsyncReaderWriter srv_stream(&srv_ctx_);
+
+ send_request_.set_message("hello bidi streaming");
+ std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>>
+ cli_stream(stub_->AsyncBidiStream(&cli_ctx_, cq_.get(), tag(1)));
+
+ service->RequestBidiStream(&srv_ctx_, &srv_stream, cq_.get(), cq_.get(),
+ tag(2));
+
+ Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get());
+
+ cli_stream->Write(send_request_, tag(3));
+ srv_stream.Read(&recv_request_buffer_, tag(4));
+ Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get());
+ ParseFromByteBuffer(&recv_request_buffer_, &recv_request_);
+ EXPECT_EQ(send_request_.message(), recv_request_.message());
+
+ send_response_.set_message(recv_request_.message());
+ SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_);
+ srv_stream.Write(send_response_buffer_, tag(5));
+ cli_stream->Read(&recv_response_, tag(6));
+ Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get());
+ EXPECT_EQ(send_response_.message(), recv_response_.message());
+
+ cli_stream->WritesDone(tag(7));
+ srv_stream.Read(&recv_request_buffer_, tag(8));
+ Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get());
+
+ srv_stream.Finish(Status::OK, tag(9));
+ cli_stream->Finish(&recv_status_, tag(10));
+ Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get());
+
+ EXPECT_TRUE(recv_status_.ok());
+}
+
+// Testing that this pattern compiles
+TEST_F(RawEnd2EndTest, CompileTest) {
+ typedef grpc::testing::EchoTestService::WithRawMethod_Echo<
+ grpc::testing::EchoTestService::AsyncService>
+ SType;
+ ResetStub();
+ auto service = BuildAndStartServer<SType>();
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
int main(int argc, char** argv) {
  // NOTE(review): a previous comment here claimed the backup poll interval
  // was changed to speed up a ReconnectChannel test; nothing in this binary
  // does that — it was a copy/paste leftover from another test's main().
  grpc::testing::TestEnvironment env(argc, argv);
  ::testing::InitGoogleTest(&argc, argv);
  int ret = RUN_ALL_TESTS();
  return ret;
}
diff --git a/contrib/libs/grpc/test/cpp/end2end/server_builder_plugin_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_builder_plugin_test.cc
new file mode 100644
index 0000000000..004902cad3
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/server_builder_plugin_test.cc
@@ -0,0 +1,265 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/impl/server_builder_option.h>
+#include <grpcpp/impl/server_builder_plugin.h>
+#include <grpcpp/impl/server_initializer.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/security/server_credentials.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+
+#include <gtest/gtest.h>
+
+#define PLUGIN_NAME "TestServerBuilderPlugin"
+
+namespace grpc {
+namespace testing {
+
+class TestServerBuilderPlugin : public ServerBuilderPlugin {
+ public:
+ TestServerBuilderPlugin() : service_(new TestServiceImpl()) {
+ init_server_is_called_ = false;
+ finish_is_called_ = false;
+ change_arguments_is_called_ = false;
+ register_service_ = false;
+ }
+
+ TString name() override { return PLUGIN_NAME; }
+
+ void InitServer(ServerInitializer* si) override {
+ init_server_is_called_ = true;
+ if (register_service_) {
+ si->RegisterService(service_);
+ }
+ }
+
+ void Finish(ServerInitializer* /*si*/) override { finish_is_called_ = true; }
+
+ void ChangeArguments(const TString& /*name*/, void* /*value*/) override {
+ change_arguments_is_called_ = true;
+ }
+
+ bool has_async_methods() const override {
+ if (register_service_) {
+ return service_->has_async_methods();
+ }
+ return false;
+ }
+
+ bool has_sync_methods() const override {
+ if (register_service_) {
+ return service_->has_synchronous_methods();
+ }
+ return false;
+ }
+
+ void SetRegisterService() { register_service_ = true; }
+
+ bool init_server_is_called() { return init_server_is_called_; }
+ bool finish_is_called() { return finish_is_called_; }
+ bool change_arguments_is_called() { return change_arguments_is_called_; }
+
+ private:
+ bool init_server_is_called_;
+ bool finish_is_called_;
+ bool change_arguments_is_called_;
+ bool register_service_;
+ std::shared_ptr<TestServiceImpl> service_;
+};
+
+class InsertPluginServerBuilderOption : public ServerBuilderOption {
+ public:
+ InsertPluginServerBuilderOption() { register_service_ = false; }
+
+ void UpdateArguments(ChannelArguments* /*arg*/) override {}
+
+ void UpdatePlugins(
+ std::vector<std::unique_ptr<ServerBuilderPlugin>>* plugins) override {
+ plugins->clear();
+
+ std::unique_ptr<TestServerBuilderPlugin> plugin(
+ new TestServerBuilderPlugin());
+ if (register_service_) plugin->SetRegisterService();
+ plugins->emplace_back(std::move(plugin));
+ }
+
+ void SetRegisterService() { register_service_ = true; }
+
+ private:
+ bool register_service_;
+};
+
+std::unique_ptr<ServerBuilderPlugin> CreateTestServerBuilderPlugin() {
+ return std::unique_ptr<ServerBuilderPlugin>(new TestServerBuilderPlugin());
+}
+
+// Force AddServerBuilderPlugin() to be called at static initialization time.
+struct StaticTestPluginInitializer {
+ StaticTestPluginInitializer() {
+ ::grpc::ServerBuilder::InternalAddPluginFactory(
+ &CreateTestServerBuilderPlugin);
+ }
+} static_plugin_initializer_test_;
+
+// When the param boolean is true, the ServerBuilder plugin will be added at the
+// time of static initialization. When it's false, the ServerBuilder plugin will
+// be added using ServerBuilder::SetOption().
+class ServerBuilderPluginTest : public ::testing::TestWithParam<bool> {
+ public:
+ ServerBuilderPluginTest() {}
+
+ void SetUp() override {
+ port_ = grpc_pick_unused_port_or_die();
+ builder_.reset(new ServerBuilder());
+ }
+
+ void InsertPlugin() {
+ if (GetParam()) {
+ // Add ServerBuilder plugin in static initialization
+ CheckPresent();
+ } else {
+ // Add ServerBuilder plugin using ServerBuilder::SetOption()
+ builder_->SetOption(std::unique_ptr<ServerBuilderOption>(
+ new InsertPluginServerBuilderOption()));
+ }
+ }
+
+ void InsertPluginWithTestService() {
+ if (GetParam()) {
+ // Add ServerBuilder plugin in static initialization
+ auto plugin = CheckPresent();
+ EXPECT_TRUE(plugin);
+ plugin->SetRegisterService();
+ } else {
+ // Add ServerBuilder plugin using ServerBuilder::SetOption()
+ std::unique_ptr<InsertPluginServerBuilderOption> option(
+ new InsertPluginServerBuilderOption());
+ option->SetRegisterService();
+ builder_->SetOption(std::move(option));
+ }
+ }
+
+ void StartServer() {
+ TString server_address = "localhost:" + to_string(port_);
+ builder_->AddListeningPort(server_address, InsecureServerCredentials());
+ // we run some tests without a service, and for those we need to supply a
+ // frequently polled completion queue
+ cq_ = builder_->AddCompletionQueue();
+ cq_thread_ = new std::thread(&ServerBuilderPluginTest::RunCQ, this);
+ server_ = builder_->BuildAndStart();
+ EXPECT_TRUE(CheckPresent());
+ }
+
+ void ResetStub() {
+ string target = "dns:localhost:" + to_string(port_);
+ channel_ = grpc::CreateChannel(target, InsecureChannelCredentials());
+ stub_ = grpc::testing::EchoTestService::NewStub(channel_);
+ }
+
+ void TearDown() override {
+ auto plugin = CheckPresent();
+ EXPECT_TRUE(plugin);
+ EXPECT_TRUE(plugin->init_server_is_called());
+ EXPECT_TRUE(plugin->finish_is_called());
+ server_->Shutdown();
+ cq_->Shutdown();
+ cq_thread_->join();
+ delete cq_thread_;
+ }
+
+ string to_string(const int number) {
+ std::stringstream strs;
+ strs << number;
+ return strs.str();
+ }
+
+ protected:
+ std::shared_ptr<Channel> channel_;
+ std::unique_ptr<ServerBuilder> builder_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<ServerCompletionQueue> cq_;
+ std::unique_ptr<Server> server_;
+ std::thread* cq_thread_;
+ TestServiceImpl service_;
+ int port_;
+
+ private:
+ TestServerBuilderPlugin* CheckPresent() {
+ auto it = builder_->plugins_.begin();
+ for (; it != builder_->plugins_.end(); it++) {
+ if ((*it)->name() == PLUGIN_NAME) break;
+ }
+ if (it != builder_->plugins_.end()) {
+ return static_cast<TestServerBuilderPlugin*>(it->get());
+ } else {
+ return nullptr;
+ }
+ }
+
+ void RunCQ() {
+ void* tag;
+ bool ok;
+ while (cq_->Next(&tag, &ok))
+ ;
+ }
+};
+
+TEST_P(ServerBuilderPluginTest, PluginWithoutServiceTest) {
+ InsertPlugin();
+ StartServer();
+}
+
+TEST_P(ServerBuilderPluginTest, PluginWithServiceTest) {
+ InsertPluginWithTestService();
+ StartServer();
+ ResetStub();
+
+ EchoRequest request;
+ EchoResponse response;
+ request.set_message("Hello hello hello hello");
+ ClientContext context;
+ context.set_compression_algorithm(GRPC_COMPRESS_GZIP);
+ Status s = stub_->Echo(&context, request, &response);
+ EXPECT_EQ(response.message(), request.message());
+ EXPECT_TRUE(s.ok());
+}
+
+INSTANTIATE_TEST_SUITE_P(ServerBuilderPluginTest, ServerBuilderPluginTest,
+ ::testing::Values(false, true));
+
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/server_crash_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_crash_test.cc
new file mode 100644
index 0000000000..3616d680f9
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/server_crash_test.cc
@@ -0,0 +1,160 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/subprocess.h"
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+static TString g_root;
+
+namespace grpc {
+namespace testing {
+
+namespace {
+
+class ServiceImpl final : public ::grpc::testing::EchoTestService::Service {
+ public:
+ ServiceImpl() : bidi_stream_count_(0), response_stream_count_(0) {}
+
+ Status BidiStream(
+ ServerContext* /*context*/,
+ ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
+ bidi_stream_count_++;
+ EchoRequest request;
+ EchoResponse response;
+ while (stream->Read(&request)) {
+ gpr_log(GPR_INFO, "recv msg %s", request.message().c_str());
+ response.set_message(request.message());
+ stream->Write(response);
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_seconds(1, GPR_TIMESPAN)));
+ }
+ return Status::OK;
+ }
+
+ Status ResponseStream(ServerContext* /*context*/,
+ const EchoRequest* /*request*/,
+ ServerWriter<EchoResponse>* writer) override {
+ EchoResponse response;
+ response_stream_count_++;
+ for (int i = 0;; i++) {
+ std::ostringstream msg;
+ msg << "Hello " << i;
+ response.set_message(msg.str());
+ if (!writer->Write(response)) break;
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_seconds(1, GPR_TIMESPAN)));
+ }
+ return Status::OK;
+ }
+
+ int bidi_stream_count() { return bidi_stream_count_; }
+
+ int response_stream_count() { return response_stream_count_; }
+
+ private:
+ int bidi_stream_count_;
+ int response_stream_count_;
+};
+
+class CrashTest : public ::testing::Test {
+ protected:
+ CrashTest() {}
+
+ std::unique_ptr<Server> CreateServerAndClient(const TString& mode) {
+ auto port = grpc_pick_unused_port_or_die();
+ std::ostringstream addr_stream;
+ addr_stream << "localhost:" << port;
+ auto addr = addr_stream.str();
+ client_.reset(new SubProcess({g_root + "/server_crash_test_client",
+ "--address=" + addr, "--mode=" + mode}));
+ GPR_ASSERT(client_);
+
+ ServerBuilder builder;
+ builder.AddListeningPort(addr, grpc::InsecureServerCredentials());
+ builder.RegisterService(&service_);
+ return builder.BuildAndStart();
+ }
+
+ void KillClient() { client_.reset(); }
+
+ bool HadOneBidiStream() { return service_.bidi_stream_count() == 1; }
+
+ bool HadOneResponseStream() { return service_.response_stream_count() == 1; }
+
+ private:
+ std::unique_ptr<SubProcess> client_;
+ ServiceImpl service_;
+};
+
+TEST_F(CrashTest, ResponseStream) {
+ auto server = CreateServerAndClient("response");
+
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_seconds(60, GPR_TIMESPAN)));
+ KillClient();
+ server->Shutdown();
+ GPR_ASSERT(HadOneResponseStream());
+}
+
+TEST_F(CrashTest, BidiStream) {
+ auto server = CreateServerAndClient("bidi");
+
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_seconds(60, GPR_TIMESPAN)));
+ KillClient();
+ server->Shutdown();
+ GPR_ASSERT(HadOneBidiStream());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ TString me = argv[0];
+ auto lslash = me.rfind('/');
+ if (lslash != TString::npos) {
+ g_root = me.substr(0, lslash);
+ } else {
+ g_root = ".";
+ }
+
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/server_crash_test_client.cc b/contrib/libs/grpc/test/cpp/end2end/server_crash_test_client.cc
new file mode 100644
index 0000000000..202fb2836c
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/server_crash_test_client.cc
@@ -0,0 +1,72 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <gflags/gflags.h>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <util/generic/string.h>
+
+#include <grpc/support/log.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/cpp/util/test_config.h"
+
+DEFINE_string(address, "", "Address to connect to");
+DEFINE_string(mode, "", "Test mode to use");
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+int main(int argc, char** argv) {
+ grpc::testing::InitTest(&argc, &argv, true);
+ auto stub = grpc::testing::EchoTestService::NewStub(
+ grpc::CreateChannel(FLAGS_address, grpc::InsecureChannelCredentials()));
+
+ EchoRequest request;
+ EchoResponse response;
+ grpc::ClientContext context;
+ context.set_wait_for_ready(true);
+
+ if (FLAGS_mode == "bidi") {
+ auto stream = stub->BidiStream(&context);
+ for (int i = 0;; i++) {
+ std::ostringstream msg;
+ msg << "Hello " << i;
+ request.set_message(msg.str());
+ GPR_ASSERT(stream->Write(request));
+ GPR_ASSERT(stream->Read(&response));
+ GPR_ASSERT(response.message() == request.message());
+ }
+ } else if (FLAGS_mode == "response") {
+ EchoRequest request;
+ request.set_message("Hello");
+ auto stream = stub->ResponseStream(&context, request);
+ for (;;) {
+ GPR_ASSERT(stream->Read(&response));
+ }
+ } else {
+ gpr_log(GPR_ERROR, "invalid test mode '%s'", FLAGS_mode.c_str());
+ return 1;
+ }
+
+ return 0;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/server_early_return_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_early_return_test.cc
new file mode 100644
index 0000000000..0f340516b0
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/server_early_return_test.cc
@@ -0,0 +1,232 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/security/server_credentials.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+namespace {
+
+const char kServerReturnStatusCode[] = "server_return_status_code";
+const char kServerDelayBeforeReturnUs[] = "server_delay_before_return_us";
+const char kServerReturnAfterNReads[] = "server_return_after_n_reads";
+
+class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {
+ public:
+ // Unused methods are not implemented.
+
+ Status RequestStream(ServerContext* context,
+ ServerReader<EchoRequest>* reader,
+ EchoResponse* response) override {
+ int server_return_status_code =
+ GetIntValueFromMetadata(context, kServerReturnStatusCode, 0);
+ int server_delay_before_return_us =
+ GetIntValueFromMetadata(context, kServerDelayBeforeReturnUs, 0);
+ int server_return_after_n_reads =
+ GetIntValueFromMetadata(context, kServerReturnAfterNReads, 0);
+
+ EchoRequest request;
+ while (server_return_after_n_reads--) {
+ EXPECT_TRUE(reader->Read(&request));
+ }
+
+ response->set_message("response msg");
+
+ gpr_sleep_until(gpr_time_add(
+ gpr_now(GPR_CLOCK_MONOTONIC),
+ gpr_time_from_micros(server_delay_before_return_us, GPR_TIMESPAN)));
+
+ return Status(static_cast<StatusCode>(server_return_status_code), "");
+ }
+
+ Status BidiStream(
+ ServerContext* context,
+ ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
+ int server_return_status_code =
+ GetIntValueFromMetadata(context, kServerReturnStatusCode, 0);
+ int server_delay_before_return_us =
+ GetIntValueFromMetadata(context, kServerDelayBeforeReturnUs, 0);
+ int server_return_after_n_reads =
+ GetIntValueFromMetadata(context, kServerReturnAfterNReads, 0);
+
+ EchoRequest request;
+ EchoResponse response;
+ while (server_return_after_n_reads--) {
+ EXPECT_TRUE(stream->Read(&request));
+ response.set_message(request.message());
+ EXPECT_TRUE(stream->Write(response));
+ }
+
+ gpr_sleep_until(gpr_time_add(
+ gpr_now(GPR_CLOCK_MONOTONIC),
+ gpr_time_from_micros(server_delay_before_return_us, GPR_TIMESPAN)));
+
+ return Status(static_cast<StatusCode>(server_return_status_code), "");
+ }
+
+ int GetIntValueFromMetadata(ServerContext* context, const char* key,
+ int default_value) {
+ auto metadata = context->client_metadata();
+ if (metadata.find(key) != metadata.end()) {
+ std::istringstream iss(ToString(metadata.find(key)->second));
+ iss >> default_value;
+ }
+ return default_value;
+ }
+};
+
+class ServerEarlyReturnTest : public ::testing::Test {
+ protected:
+ ServerEarlyReturnTest() : picked_port_(0) {}
+
+ void SetUp() override {
+ int port = grpc_pick_unused_port_or_die();
+ picked_port_ = port;
+ server_address_ << "127.0.0.1:" << port;
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address_.str(),
+ InsecureServerCredentials());
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+
+ channel_ = grpc::CreateChannel(server_address_.str(),
+ InsecureChannelCredentials());
+ stub_ = grpc::testing::EchoTestService::NewStub(channel_);
+ }
+
+ void TearDown() override {
+ server_->Shutdown();
+ if (picked_port_ > 0) {
+ grpc_recycle_unused_port(picked_port_);
+ }
+ }
+
+ // Client sends 20 requests and the server returns after reading 10 requests.
+ // If return_cancel is true, server returns CANCELLED status. Otherwise it
+ // returns OK.
+ void DoBidiStream(bool return_cancelled) {
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+
+ context.AddMetadata(kServerReturnAfterNReads, "10");
+ if (return_cancelled) {
+ // "1" means CANCELLED
+ context.AddMetadata(kServerReturnStatusCode, "1");
+ }
+ context.AddMetadata(kServerDelayBeforeReturnUs, "10000");
+
+ auto stream = stub_->BidiStream(&context);
+
+ for (int i = 0; i < 20; i++) {
+ request.set_message(TString("hello") + ToString(i));
+ bool write_ok = stream->Write(request);
+ bool read_ok = stream->Read(&response);
+ if (i < 10) {
+ EXPECT_TRUE(write_ok);
+ EXPECT_TRUE(read_ok);
+ EXPECT_EQ(response.message(), request.message());
+ } else {
+ EXPECT_FALSE(read_ok);
+ }
+ }
+
+ stream->WritesDone();
+ EXPECT_FALSE(stream->Read(&response));
+
+ Status s = stream->Finish();
+ if (return_cancelled) {
+ EXPECT_EQ(s.error_code(), StatusCode::CANCELLED);
+ } else {
+ EXPECT_TRUE(s.ok());
+ }
+ }
+
+ void DoRequestStream(bool return_cancelled) {
+ EchoRequest request;
+ EchoResponse response;
+ ClientContext context;
+
+ context.AddMetadata(kServerReturnAfterNReads, "10");
+ if (return_cancelled) {
+ // "1" means CANCELLED
+ context.AddMetadata(kServerReturnStatusCode, "1");
+ }
+ context.AddMetadata(kServerDelayBeforeReturnUs, "10000");
+
+ auto stream = stub_->RequestStream(&context, &response);
+ for (int i = 0; i < 20; i++) {
+ request.set_message(TString("hello") + ToString(i));
+ bool written = stream->Write(request);
+ if (i < 10) {
+ EXPECT_TRUE(written);
+ }
+ }
+ stream->WritesDone();
+ Status s = stream->Finish();
+ if (return_cancelled) {
+ EXPECT_EQ(s.error_code(), StatusCode::CANCELLED);
+ } else {
+ EXPECT_TRUE(s.ok());
+ }
+ }
+
+ std::shared_ptr<Channel> channel_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<Server> server_;
+ std::ostringstream server_address_;
+ TestServiceImpl service_;
+ int picked_port_;
+};
+
+TEST_F(ServerEarlyReturnTest, BidiStreamEarlyOk) { DoBidiStream(false); }
+
+TEST_F(ServerEarlyReturnTest, BidiStreamEarlyCancel) { DoBidiStream(true); }
+
+TEST_F(ServerEarlyReturnTest, RequestStreamEarlyOK) { DoRequestStream(false); }
+TEST_F(ServerEarlyReturnTest, RequestStreamEarlyCancel) {
+ DoRequestStream(true);
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/server_interceptors/ya.make b/contrib/libs/grpc/test/cpp/end2end/server_interceptors/ya.make
new file mode 100644
index 0000000000..161176f141
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/server_interceptors/ya.make
@@ -0,0 +1,32 @@
+GTEST_UGLY()
+
+OWNER(
+ dvshkurko
+ g:ymake
+)
+
+ADDINCL(
+ ${ARCADIA_BUILD_ROOT}/contrib/libs/grpc
+ ${ARCADIA_ROOT}/contrib/libs/grpc
+)
+
+PEERDIR(
+ contrib/libs/grpc/src/proto/grpc/core
+ contrib/libs/grpc/src/proto/grpc/testing
+ contrib/libs/grpc/src/proto/grpc/testing/duplicate
+ contrib/libs/grpc/test/core/util
+ contrib/libs/grpc/test/cpp/end2end
+ contrib/libs/grpc/test/cpp/util
+)
+
+NO_COMPILER_WARNINGS()
+
+SRCDIR(
+ contrib/libs/grpc/test/cpp/end2end
+)
+
+SRCS(
+ server_interceptors_end2end_test.cc
+)
+
+END()
diff --git a/contrib/libs/grpc/test/cpp/end2end/server_interceptors_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_interceptors_end2end_test.cc
new file mode 100644
index 0000000000..6d2dc772ef
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/server_interceptors_end2end_test.cc
@@ -0,0 +1,708 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <memory>
+#include <vector>
+
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/server_interceptor.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/interceptors_util.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+
+#include <gtest/gtest.h>
+
+namespace grpc {
+namespace testing {
+namespace {
+
+class LoggingInterceptor : public experimental::Interceptor {
+ public:
+ LoggingInterceptor(experimental::ServerRpcInfo* info) {
+ info_ = info;
+
+ // Check the method name and compare to the type
+ const char* method = info->method();
+ experimental::ServerRpcInfo::Type type = info->type();
+
+ // Check that we use one of our standard methods with expected type.
+ // Also allow the health checking service.
+ // We accept BIDI_STREAMING for Echo in case it's an AsyncGenericService
+ // being tested (the GenericRpc test).
+ // The empty method is for the Unimplemented requests that arise
+ // when draining the CQ.
+ EXPECT_TRUE(
+ strstr(method, "/grpc.health") == method ||
+ (strcmp(method, "/grpc.testing.EchoTestService/Echo") == 0 &&
+ (type == experimental::ServerRpcInfo::Type::UNARY ||
+ type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)) ||
+ (strcmp(method, "/grpc.testing.EchoTestService/RequestStream") == 0 &&
+ type == experimental::ServerRpcInfo::Type::CLIENT_STREAMING) ||
+ (strcmp(method, "/grpc.testing.EchoTestService/ResponseStream") == 0 &&
+ type == experimental::ServerRpcInfo::Type::SERVER_STREAMING) ||
+ (strcmp(method, "/grpc.testing.EchoTestService/BidiStream") == 0 &&
+ type == experimental::ServerRpcInfo::Type::BIDI_STREAMING) ||
+ strcmp(method, "/grpc.testing.EchoTestService/Unimplemented") == 0 ||
+ (strcmp(method, "") == 0 &&
+ type == experimental::ServerRpcInfo::Type::BIDI_STREAMING));
+ }
+
+ void Intercept(experimental::InterceptorBatchMethods* methods) override {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ auto* map = methods->GetSendInitialMetadata();
+ // Got nothing better to do here for now
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSerializedSendMessage();
+ auto copied_buffer = *buffer;
+ EXPECT_TRUE(
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+ .ok());
+ EXPECT_TRUE(req.message().find("Hello") == 0);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_STATUS)) {
+ auto* map = methods->GetSendTrailingMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.find("testkey") == 0 &&
+ pair.second.find("testvalue") == 0;
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ auto status = methods->GetSendStatus();
+ EXPECT_EQ(status.ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+ auto* map = methods->GetRecvInitialMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.find("testkey") == 0 &&
+ pair.second.find("testvalue") == 0;
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ if (resp != nullptr) {
+ EXPECT_TRUE(resp->message().find("Hello") == 0);
+ }
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_CLOSE)) {
+ // Got nothing interesting to do here
+ }
+ methods->Proceed();
+ }
+
+ private:
+ experimental::ServerRpcInfo* info_;
+};
+
+class LoggingInterceptorFactory
+ : public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* info) override {
+ return new LoggingInterceptor(info);
+ }
+};
+
+// Test if SendMessage function family works as expected for sync/callback apis
+class SyncSendMessageTester : public experimental::Interceptor {
+ public:
+ SyncSendMessageTester(experimental::ServerRpcInfo* /*info*/) {}
+
+ void Intercept(experimental::InterceptorBatchMethods* methods) override {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ string old_msg =
+ static_cast<const EchoRequest*>(methods->GetSendMessage())->message();
+ EXPECT_EQ(old_msg.find("Hello"), 0u);
+ new_msg_.set_message(TString("World" + old_msg).c_str());
+ methods->ModifySendMessage(&new_msg_);
+ }
+ methods->Proceed();
+ }
+
+ private:
+ EchoRequest new_msg_;
+};
+
+class SyncSendMessageTesterFactory
+ : public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* info) override {
+ return new SyncSendMessageTester(info);
+ }
+};
+
+// Test if SendMessage function family works as expected for sync/callback apis
+class SyncSendMessageVerifier : public experimental::Interceptor {
+ public:
+ SyncSendMessageVerifier(experimental::ServerRpcInfo* /*info*/) {}
+
+ void Intercept(experimental::InterceptorBatchMethods* methods) override {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ // Make sure that the changes made in SyncSendMessageTester persisted
+ string old_msg =
+ static_cast<const EchoRequest*>(methods->GetSendMessage())->message();
+ EXPECT_EQ(old_msg.find("World"), 0u);
+
+ // Remove the "World" part of the string that we added earlier
+ new_msg_.set_message(old_msg.erase(0, 5));
+ methods->ModifySendMessage(&new_msg_);
+
+ // LoggingInterceptor verifies that changes got reverted
+ }
+ methods->Proceed();
+ }
+
+ private:
+ EchoRequest new_msg_;
+};
+
+class SyncSendMessageVerifierFactory
+ : public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* info) override {
+ return new SyncSendMessageVerifier(info);
+ }
+};
+
+void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ ctx.AddMetadata("testkey", "testvalue");
+ auto stream = stub->BidiStream(&ctx);
+ for (auto i = 0; i < 10; i++) {
+ req.set_message("Hello" + ::ToString(i));
+ stream->Write(req);
+ stream->Read(&resp);
+ EXPECT_EQ(req.message(), resp.message());
+ }
+ ASSERT_TRUE(stream->WritesDone());
+ Status s = stream->Finish();
+ EXPECT_EQ(s.ok(), true);
+}
+
+class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test {
+ protected:
+ ServerInterceptorsEnd2endSyncUnaryTest() {
+ int port = 5004; // grpc_pick_unused_port_or_die();
+
+ ServerBuilder builder;
+ server_address_ = "localhost:" + ::ToString(port);
+ builder.AddListeningPort(server_address_, InsecureServerCredentials());
+ builder.RegisterService(&service_);
+
+ std::vector<
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new SyncSendMessageTesterFactory()));
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new SyncSendMessageVerifierFactory()));
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new LoggingInterceptorFactory()));
+ // Add 20 dummy interceptor factories and null interceptor factories
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ creators.push_back(std::unique_ptr<NullInterceptorFactory>(
+ new NullInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ server_ = builder.BuildAndStart();
+ }
+ TString server_address_;
+ TestServiceImpl service_;
+ std::unique_ptr<Server> server_;
+};
+
+TEST_F(ServerInterceptorsEnd2endSyncUnaryTest, UnaryTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto channel =
+ grpc::CreateChannel(server_address_, InsecureChannelCredentials());
+ MakeCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
// Sync streaming fixture: like the unary fixture, but registers the
// streaming echo service and omits the null interceptor factories.
class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test {
 protected:
  ServerInterceptorsEnd2endSyncStreamingTest() {
    // Hard-coded port; upstream uses grpc_pick_unused_port_or_die().
    int port = 5005; // grpc_pick_unused_port_or_die();

    ServerBuilder builder;
    server_address_ = "localhost:" + ::ToString(port);
    builder.AddListeningPort(server_address_, InsecureServerCredentials());
    builder.RegisterService(&service_);

    // Interceptor chain: tester mutates outgoing messages, verifier checks
    // the mutation was reverted, logger records traffic.
    std::vector<
        std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
        creators;
    creators.push_back(
        std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
            new SyncSendMessageTesterFactory()));
    creators.push_back(
        std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
            new SyncSendMessageVerifierFactory()));
    creators.push_back(
        std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
            new LoggingInterceptorFactory()));
    // 20 dummy interceptors whose run count the tests assert on.
    for (auto i = 0; i < 20; i++) {
      creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
          new DummyInterceptorFactory()));
    }
    builder.experimental().SetInterceptorCreators(std::move(creators));
    server_ = builder.BuildAndStart();
  }
  TString server_address_;
  EchoTestServiceStreamingImpl service_;
  std::unique_ptr<Server> server_;
};
+
+TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ClientStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto channel =
+ grpc::CreateChannel(server_address_, InsecureChannelCredentials());
+ MakeClientStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ServerStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto channel =
+ grpc::CreateChannel(server_address_, InsecureChannelCredentials());
+ MakeServerStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
+TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, BidiStreamingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ auto channel =
+ grpc::CreateChannel(server_address_, InsecureChannelCredentials());
+ MakeBidiStreamingCall(channel);
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+}
+
// Async interceptor tests build their own server and completion queue per
// test, so the fixture itself carries no state.
class ServerInterceptorsAsyncEnd2endTest : public ::testing::Test {};
+
// Async unary RPC through a server interceptor chain (logger + 20 dummies).
// Tag ordering matters: the server must see the request (tag 2) before it
// can finish the response (tag 3), which unblocks the client Finish (tag 4).
TEST_F(ServerInterceptorsAsyncEnd2endTest, UnaryTest) {
  DummyInterceptor::Reset();
  int port = 5006; // grpc_pick_unused_port_or_die();
  string server_address = "localhost:" + ::ToString(port);
  ServerBuilder builder;
  EchoTestService::AsyncService service;
  builder.AddListeningPort(server_address, InsecureServerCredentials());
  builder.RegisterService(&service);
  std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
      creators;
  creators.push_back(
      std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
          new LoggingInterceptorFactory()));
  for (auto i = 0; i < 20; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  builder.experimental().SetInterceptorCreators(std::move(creators));
  // One completion queue is shared by client and server halves of this test.
  auto cq = builder.AddCompletionQueue();
  auto server = builder.BuildAndStart();

  ChannelArguments args;
  auto channel =
      grpc::CreateChannel(server_address, InsecureChannelCredentials());
  auto stub = grpc::testing::EchoTestService::NewStub(channel);

  EchoRequest send_request;
  EchoRequest recv_request;
  EchoResponse send_response;
  EchoResponse recv_response;
  Status recv_status;

  ClientContext cli_ctx;
  ServerContext srv_ctx;
  grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);

  send_request.set_message("Hello");
  cli_ctx.AddMetadata("testkey", "testvalue");
  std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
      stub->AsyncEcho(&cli_ctx, send_request, cq.get()));

  // Ask the server to surface the next Echo call as tag 2.
  service.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq.get(),
                      cq.get(), tag(2));

  response_reader->Finish(&recv_response, &recv_status, tag(4));

  // Server received the request.
  Verifier().Expect(2, true).Verify(cq.get());
  EXPECT_EQ(send_request.message(), recv_request.message());

  EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue"));
  srv_ctx.AddTrailingMetadata("testkey", "testvalue");

  // Server finishes (tag 3), which completes the client Finish (tag 4).
  send_response.set_message(recv_request.message());
  response_writer.Finish(send_response, Status::OK, tag(3));
  Verifier().Expect(3, true).Expect(4, true).Verify(cq.get());

  EXPECT_EQ(send_response.message(), recv_response.message());
  EXPECT_TRUE(recv_status.ok());
  EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey",
                            "testvalue"));

  // Make sure all 20 dummy interceptors were run
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);

  // Drain the completion queue before it is destroyed.
  server->Shutdown();
  cq->Shutdown();
  void* ignored_tag;
  bool ignored_ok;
  while (cq->Next(&ignored_tag, &ignored_ok))
    ;
  // grpc_recycle_unused_port(port);
}
+
// Async bidi-streaming RPC through a server interceptor chain (logger + 20
// dummies): one write/read round trip, half-close, then Finish on both ends.
TEST_F(ServerInterceptorsAsyncEnd2endTest, BidiStreamingTest) {
  DummyInterceptor::Reset();
  int port = 5007; // grpc_pick_unused_port_or_die();
  string server_address = "localhost:" + ::ToString(port);
  ServerBuilder builder;
  EchoTestService::AsyncService service;
  builder.AddListeningPort(server_address, InsecureServerCredentials());
  builder.RegisterService(&service);
  std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
      creators;
  creators.push_back(
      std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
          new LoggingInterceptorFactory()));
  for (auto i = 0; i < 20; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  builder.experimental().SetInterceptorCreators(std::move(creators));
  // One completion queue is shared by client and server halves of this test.
  auto cq = builder.AddCompletionQueue();
  auto server = builder.BuildAndStart();

  ChannelArguments args;
  auto channel =
      grpc::CreateChannel(server_address, InsecureChannelCredentials());
  auto stub = grpc::testing::EchoTestService::NewStub(channel);

  EchoRequest send_request;
  EchoRequest recv_request;
  EchoResponse send_response;
  EchoResponse recv_response;
  Status recv_status;

  ClientContext cli_ctx;
  ServerContext srv_ctx;
  grpc::ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);

  send_request.set_message("Hello");
  cli_ctx.AddMetadata("testkey", "testvalue");
  // Tag 1: client stream creation; tag 2: server accepts the call.
  std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>>
      cli_stream(stub->AsyncBidiStream(&cli_ctx, cq.get(), tag(1)));

  service.RequestBidiStream(&srv_ctx, &srv_stream, cq.get(), cq.get(), tag(2));

  Verifier().Expect(1, true).Expect(2, true).Verify(cq.get());

  EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue"));
  srv_ctx.AddTrailingMetadata("testkey", "testvalue");

  // One round trip: client write (3) / server read (4).
  cli_stream->Write(send_request, tag(3));
  srv_stream.Read(&recv_request, tag(4));
  Verifier().Expect(3, true).Expect(4, true).Verify(cq.get());
  EXPECT_EQ(send_request.message(), recv_request.message());

  // Echo back: server write (5) / client read (6).
  send_response.set_message(recv_request.message());
  srv_stream.Write(send_response, tag(5));
  cli_stream->Read(&recv_response, tag(6));
  Verifier().Expect(5, true).Expect(6, true).Verify(cq.get());
  EXPECT_EQ(send_response.message(), recv_response.message());

  // Client half-close (7); the server's next read fails (8, ok=false).
  cli_stream->WritesDone(tag(7));
  srv_stream.Read(&recv_request, tag(8));
  Verifier().Expect(7, true).Expect(8, false).Verify(cq.get());

  srv_stream.Finish(Status::OK, tag(9));
  cli_stream->Finish(&recv_status, tag(10));
  Verifier().Expect(9, true).Expect(10, true).Verify(cq.get());

  EXPECT_TRUE(recv_status.ok());
  EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey",
                            "testvalue"));

  // Make sure all 20 dummy interceptors were run
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);

  // Drain the completion queue before it is destroyed.
  server->Shutdown();
  cq->Shutdown();
  void* ignored_tag;
  bool ignored_ok;
  while (cq->Next(&ignored_tag, &ignored_ok))
    ;
  // grpc_recycle_unused_port(port);
}
+
// Generic (untyped, ByteBuffer-based) RPC through 20 dummy server
// interceptors, with separate client and server completion queues. Also
// exercises CQ shutdown while a tag is still outstanding (see below).
TEST_F(ServerInterceptorsAsyncEnd2endTest, GenericRPCTest) {
  DummyInterceptor::Reset();
  int port = 5008; // grpc_pick_unused_port_or_die();
  string server_address = "localhost:" + ::ToString(port);
  ServerBuilder builder;
  AsyncGenericService service;
  builder.AddListeningPort(server_address, InsecureServerCredentials());
  builder.RegisterAsyncGenericService(&service);
  std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
      creators;
  creators.reserve(20);
  for (auto i = 0; i < 20; i++) {
    creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
        new DummyInterceptorFactory()));
  }
  builder.experimental().SetInterceptorCreators(std::move(creators));
  auto srv_cq = builder.AddCompletionQueue();
  CompletionQueue cli_cq;
  auto server = builder.BuildAndStart();

  ChannelArguments args;
  auto channel =
      grpc::CreateChannel(server_address, InsecureChannelCredentials());
  GenericStub generic_stub(channel);

  const TString kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
  EchoRequest send_request;
  EchoRequest recv_request;
  EchoResponse send_response;
  EchoResponse recv_response;
  Status recv_status;

  ClientContext cli_ctx;
  GenericServerContext srv_ctx;
  GenericServerAsyncReaderWriter stream(&srv_ctx);

  // The string needs to be long enough to test heap-based slice.
  send_request.set_message("Hello");
  cli_ctx.AddMetadata("testkey", "testvalue");

  // Wait for the server to accept the call (tag 4) on a separate thread,
  // since RequestCall is issued only after the client has finished writing.
  CompletionQueue* cq = srv_cq.get();
  std::thread request_call([cq]() { Verifier().Expect(4, true).Verify(cq); });
  std::unique_ptr<GenericClientAsyncReaderWriter> call =
      generic_stub.PrepareCall(&cli_ctx, kMethodName, &cli_cq);
  call->StartCall(tag(1));
  Verifier().Expect(1, true).Verify(&cli_cq);
  std::unique_ptr<ByteBuffer> send_buffer =
      SerializeToByteBuffer(&send_request);
  call->Write(*send_buffer, tag(2));
  // Send ByteBuffer can be destroyed after calling Write.
  send_buffer.reset();
  Verifier().Expect(2, true).Verify(&cli_cq);
  call->WritesDone(tag(3));
  Verifier().Expect(3, true).Verify(&cli_cq);

  service.RequestCall(&srv_ctx, &stream, srv_cq.get(), srv_cq.get(), tag(4));

  request_call.join();
  EXPECT_EQ(kMethodName, srv_ctx.method());
  EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue"));
  srv_ctx.AddTrailingMetadata("testkey", "testvalue");

  ByteBuffer recv_buffer;
  stream.Read(&recv_buffer, tag(5));
  Verifier().Expect(5, true).Verify(srv_cq.get());
  EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
  EXPECT_EQ(send_request.message(), recv_request.message());

  send_response.set_message(recv_request.message());
  send_buffer = SerializeToByteBuffer(&send_response);
  stream.Write(*send_buffer, tag(6));
  send_buffer.reset();
  Verifier().Expect(6, true).Verify(srv_cq.get());

  stream.Finish(Status::OK, tag(7));
  // Shutdown srv_cq before we try to get the tag back, to verify that the
  // interception API handles completion queue shutdowns that take place before
  // all the tags are returned
  srv_cq->Shutdown();
  Verifier().Expect(7, true).Verify(srv_cq.get());

  recv_buffer.Clear();
  call->Read(&recv_buffer, tag(8));
  Verifier().Expect(8, true).Verify(&cli_cq);
  EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));

  call->Finish(&recv_status, tag(9));
  cli_cq.Shutdown();
  Verifier().Expect(9, true).Verify(&cli_cq);

  EXPECT_EQ(send_response.message(), recv_response.message());
  EXPECT_TRUE(recv_status.ok());
  EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey",
                            "testvalue"));

  // Make sure all 20 dummy interceptors were run
  EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);

  // Drain both queues before they are destroyed.
  server->Shutdown();
  void* ignored_tag;
  bool ignored_ok;
  while (cli_cq.Next(&ignored_tag, &ignored_ok))
    ;
  while (srv_cq->Next(&ignored_tag, &ignored_ok))
    ;
  // grpc_recycle_unused_port(port);
}
+
+TEST_F(ServerInterceptorsAsyncEnd2endTest, UnimplementedRpcTest) {
+ DummyInterceptor::Reset();
+ int port = 5009; // grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + ::ToString(port);
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.reserve(20);
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto cq = builder.AddCompletionQueue();
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ std::shared_ptr<Channel> channel =
+ grpc::CreateChannel(server_address, InsecureChannelCredentials());
+ std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub;
+ stub = grpc::testing::UnimplementedEchoService::NewStub(channel);
+ EchoRequest send_request;
+ EchoResponse recv_response;
+ Status recv_status;
+
+ ClientContext cli_ctx;
+ send_request.set_message("Hello");
+ std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
+ stub->AsyncUnimplemented(&cli_ctx, send_request, cq.get()));
+
+ response_reader->Finish(&recv_response, &recv_status, tag(4));
+ Verifier().Expect(4, true).Verify(cq.get());
+
+ EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code());
+ EXPECT_EQ("", recv_status.error_message());
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ cq->Shutdown();
+ void* ignored_tag;
+ bool ignored_ok;
+ while (cq->Next(&ignored_tag, &ignored_ok))
+ ;
+ // grpc_recycle_unused_port(port);
+}
+
// Fixture for testing interceptors on a sync server when the requested RPC
// method is not implemented by the registered service.
class ServerInterceptorsSyncUnimplementedEnd2endTest : public ::testing::Test {
};
+
+TEST_F(ServerInterceptorsSyncUnimplementedEnd2endTest, UnimplementedRpcTest) {
+ DummyInterceptor::Reset();
+ int port = 5010; // grpc_pick_unused_port_or_die();
+ string server_address = "localhost:" + ::ToString(port);
+ ServerBuilder builder;
+ TestServiceImpl service;
+ builder.RegisterService(&service);
+ builder.AddListeningPort(server_address, InsecureServerCredentials());
+ std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
+ creators;
+ creators.reserve(20);
+ for (auto i = 0; i < 20; i++) {
+ creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
+ new DummyInterceptorFactory()));
+ }
+ builder.experimental().SetInterceptorCreators(std::move(creators));
+ auto server = builder.BuildAndStart();
+
+ ChannelArguments args;
+ std::shared_ptr<Channel> channel =
+ grpc::CreateChannel(server_address, InsecureChannelCredentials());
+ std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub;
+ stub = grpc::testing::UnimplementedEchoService::NewStub(channel);
+ EchoRequest send_request;
+ EchoResponse recv_response;
+
+ ClientContext cli_ctx;
+ send_request.set_message("Hello");
+ Status recv_status =
+ stub->Unimplemented(&cli_ctx, send_request, &recv_response);
+
+ EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code());
+ EXPECT_EQ("", recv_status.error_message());
+
+ // Make sure all 20 dummy interceptors were run
+ EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
+
+ server->Shutdown();
+ // grpc_recycle_unused_port(port);
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
// Test entry point: set up the gRPC test environment, then run all tests.
int main(int argc, char** argv) {
  grpc::testing::TestEnvironment env(argc, argv);
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
diff --git a/contrib/libs/grpc/test/cpp/end2end/server_load_reporting_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_load_reporting_end2end_test.cc
new file mode 100644
index 0000000000..13833cf66c
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/server_load_reporting_end2end_test.cc
@@ -0,0 +1,192 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/support/port_platform.h>
+
+#include <thread>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <grpc++/grpc++.h>
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/string_util.h>
+#include <grpcpp/ext/server_load_reporting.h>
+#include <grpcpp/server_builder.h>
+
+#include "src/proto/grpc/lb/v1/load_reporter.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+
+namespace grpc {
+namespace testing {
+namespace {
+
// Metric attached to successful calls; BasicReport asserts it comes back in
// the load report.
constexpr double kMetricValue = 3.1415;
constexpr char kMetricName[] = "METRIC_PI";

// Different messages result in different response statuses. For simplicity in
// computing request bytes, the message sizes should be the same.
// (All three are exactly 5 characters.)
const char kOkMessage[] = "hello";
const char kServerErrorMessage[] = "sverr";
const char kClientErrorMessage[] = "clerr";
+
+class EchoTestServiceImpl : public EchoTestService::Service {
+ public:
+ ~EchoTestServiceImpl() override {}
+
+ Status Echo(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ if (request->message() == kServerErrorMessage) {
+ return Status(StatusCode::UNKNOWN, "Server error requested");
+ }
+ if (request->message() == kClientErrorMessage) {
+ return Status(StatusCode::FAILED_PRECONDITION, "Client error requested");
+ }
+ response->set_message(request->message());
+ ::grpc::load_reporter::experimental::AddLoadReportingCost(
+ context, kMetricName, kMetricValue);
+ return Status::OK;
+ }
+};
+
+class ServerLoadReportingEnd2endTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ server_address_ =
+ "localhost:" + ToString(grpc_pick_unused_port_or_die());
+ server_ =
+ ServerBuilder()
+ .AddListeningPort(server_address_, InsecureServerCredentials())
+ .RegisterService(&echo_service_)
+ .SetOption(std::unique_ptr<::grpc::ServerBuilderOption>(
+ new ::grpc::load_reporter::experimental::
+ LoadReportingServiceServerBuilderOption()))
+ .BuildAndStart();
+ server_thread_ =
+ std::thread(&ServerLoadReportingEnd2endTest::RunServerLoop, this);
+ }
+
+ void RunServerLoop() { server_->Wait(); }
+
+ void TearDown() override {
+ server_->Shutdown();
+ server_thread_.join();
+ }
+
+ void ClientMakeEchoCalls(const TString& lb_id, const TString& lb_tag,
+ const TString& message, size_t num_requests) {
+ auto stub = EchoTestService::NewStub(
+ grpc::CreateChannel(server_address_, InsecureChannelCredentials()));
+ TString lb_token = lb_id + lb_tag;
+ for (int i = 0; i < num_requests; ++i) {
+ ClientContext ctx;
+ if (!lb_token.empty()) ctx.AddMetadata(GRPC_LB_TOKEN_MD_KEY, lb_token);
+ EchoRequest request;
+ EchoResponse response;
+ request.set_message(message);
+ Status status = stub->Echo(&ctx, request, &response);
+ if (message == kOkMessage) {
+ ASSERT_EQ(status.error_code(), StatusCode::OK);
+ ASSERT_EQ(request.message(), response.message());
+ } else if (message == kServerErrorMessage) {
+ ASSERT_EQ(status.error_code(), StatusCode::UNKNOWN);
+ } else if (message == kClientErrorMessage) {
+ ASSERT_EQ(status.error_code(), StatusCode::FAILED_PRECONDITION);
+ }
+ }
+ }
+
+ TString server_address_;
+ std::unique_ptr<Server> server_;
+ std::thread server_thread_;
+ EchoTestServiceImpl echo_service_;
+};
+
// Server (with load reporting installed) starts and shuts down cleanly even
// when no RPC is ever issued.
TEST_F(ServerLoadReportingEnd2endTest, NoCall) {}
+
// Opens a ReportLoad stream, issues one OK echo call, then waits until a
// load report arrives and checks its three records: in-progress count, the
// balancer's own (token-less) call, and the client call with the metric.
TEST_F(ServerLoadReportingEnd2endTest, BasicReport) {
  auto channel =
      grpc::CreateChannel(server_address_, InsecureChannelCredentials());
  auto stub = ::grpc::lb::v1::LoadReporter::NewStub(channel);
  ClientContext ctx;
  auto stream = stub->ReportLoad(&ctx);
  ::grpc::lb::v1::LoadReportRequest request;
  request.mutable_initial_request()->set_load_balanced_hostname(
      server_address_);
  request.mutable_initial_request()->set_load_key("LOAD_KEY");
  request.mutable_initial_request()
      ->mutable_load_report_interval()
      ->set_seconds(5);
  stream->Write(request);
  gpr_log(GPR_INFO, "Initial request sent.");
  ::grpc::lb::v1::LoadReportResponse response;
  stream->Read(&response);
  const TString& lb_id = response.initial_response().load_balancer_id();
  gpr_log(GPR_INFO, "Initial response received (lb_id: %s).", lb_id.c_str());
  ClientMakeEchoCalls(lb_id, "LB_TAG", kOkMessage, 1);
  // Keep reading reports until one with load records arrives.
  while (true) {
    stream->Read(&response);
    if (!response.load().empty()) {
      ASSERT_EQ(response.load().size(), 3);
      for (const auto& load : response.load()) {
        if (load.in_progress_report_case()) {
          // The special load record that reports the number of in-progress
          // calls.
          ASSERT_EQ(load.num_calls_in_progress(), 1);
        } else if (load.orphaned_load_case()) {
          // The call from the balancer doesn't have any valid LB token.
          ASSERT_EQ(load.orphaned_load_case(), load.kLoadKeyUnknown);
          ASSERT_EQ(load.num_calls_started(), 1);
          ASSERT_EQ(load.num_calls_finished_without_error(), 0);
          ASSERT_EQ(load.num_calls_finished_with_error(), 0);
        } else {
          // This corresponds to the calls from the client.
          ASSERT_EQ(load.num_calls_started(), 1);
          ASSERT_EQ(load.num_calls_finished_without_error(), 1);
          ASSERT_EQ(load.num_calls_finished_with_error(), 0);
          ASSERT_GE(load.total_bytes_received(), sizeof(kOkMessage));
          ASSERT_GE(load.total_bytes_sent(), sizeof(kOkMessage));
          ASSERT_EQ(load.metric_data().size(), 1);
          ASSERT_EQ(load.metric_data().Get(0).metric_name(), kMetricName);
          ASSERT_EQ(load.metric_data().Get(0).num_calls_finished_with_metric(),
                    1);
          ASSERT_EQ(load.metric_data().Get(0).total_metric_value(),
                    kMetricValue);
        }
      }
      break;
    }
  }
  // Server cancels the report stream on shutdown, hence CANCELLED.
  stream->WritesDone();
  ASSERT_EQ(stream->Finish().error_code(), StatusCode::CANCELLED);
}
+
+// TODO(juanlishen): Add more tests.
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
// Test entry point: set up the gRPC test environment, then run all tests.
int main(int argc, char** argv) {
  grpc::testing::TestEnvironment env(argc, argv);
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
diff --git a/contrib/libs/grpc/test/cpp/end2end/service_config_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/service_config_end2end_test.cc
new file mode 100644
index 0000000000..cee33343c1
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/service_config_end2end_test.cc
@@ -0,0 +1,613 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <algorithm>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <set>
+#include <util/generic/string.h>
+#include <thread>
+
+#include "y_absl/strings/str_cat.h"
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/atm.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/health_check_service_interface.h>
+#include <grpcpp/impl/codegen/sync.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/support/validate_service_config.h>
+
+#include "src/core/ext/filters/client_channel/backup_poller.h"
+#include "src/core/ext/filters/client_channel/global_subchannel_pool.h"
+#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h"
+#include "src/core/ext/filters/client_channel/server_address.h"
+#include "src/core/lib/backoff/backoff.h"
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/gprpp/debug_location.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
+#include "src/core/lib/iomgr/parse_address.h"
+#include "src/core/lib/iomgr/tcp_client.h"
+#include "src/core/lib/security/credentials/fake/fake_credentials.h"
+#include "src/cpp/client/secure_credentials.h"
+#include "src/cpp/server/secure_server_credentials.h"
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace testing {
+namespace {
+
+// Subclass of TestServiceImpl that increments a request counter for
+// every call to the Echo RPC.
+class MyTestServiceImpl : public TestServiceImpl {
+ public:
+ MyTestServiceImpl() : request_count_(0) {}
+
+ Status Echo(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ {
+ grpc::internal::MutexLock lock(&mu_);
+ ++request_count_;
+ }
+ AddClient(context->peer());
+ return TestServiceImpl::Echo(context, request, response);
+ }
+
+ int request_count() {
+ grpc::internal::MutexLock lock(&mu_);
+ return request_count_;
+ }
+
+ void ResetCounters() {
+ grpc::internal::MutexLock lock(&mu_);
+ request_count_ = 0;
+ }
+
+ std::set<TString> clients() {
+ grpc::internal::MutexLock lock(&clients_mu_);
+ return clients_;
+ }
+
+ private:
+ void AddClient(const TString& client) {
+ grpc::internal::MutexLock lock(&clients_mu_);
+ clients_.insert(client);
+ }
+
+ grpc::internal::Mutex mu_;
+ int request_count_;
+ grpc::internal::Mutex clients_mu_;
+ std::set<TString> clients_;
+};
+
+class ServiceConfigEnd2endTest : public ::testing::Test {
+ protected:
  // Fixture-wide state: fake-transport credentials and a canned request
  // payload shared by every test.
  ServiceConfigEnd2endTest()
      : server_host_("localhost"),
        kRequestMessage_("Live long and prosper."),
        creds_(new SecureChannelCredentials(
            grpc_fake_transport_security_credentials_create())) {}
+
  // Runs once for the whole suite, before any fixture is constructed.
  static void SetUpTestCase() {
    // Make the backup poller poll very frequently in order to pick up
    // updates from all the subchannels's FDs.
    GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1);
  }
+
  // Per-test init: bring up gRPC and create a fresh fake-resolver response
  // generator that the tests use to inject addresses and service configs.
  void SetUp() override {
    grpc_init();
    response_generator_ =
        grpc_core::MakeRefCounted<grpc_core::FakeResolverResponseGenerator>();
  }
+
  // Per-test cleanup: stop all servers, release members, then block until
  // gRPC has fully shut down.
  void TearDown() override {
    for (size_t i = 0; i < servers_.size(); ++i) {
      servers_[i]->Shutdown();
    }
    // Explicitly destroy all the members so that we can make sure grpc_shutdown
    // has finished by the end of this function, and thus all the registered
    // LB policy factories are removed.
    stub_.reset();
    servers_.clear();
    creds_.reset();
    grpc_shutdown_blocking();
  }
+
+ void CreateServers(size_t num_servers,
+ std::vector<int> ports = std::vector<int>()) {
+ servers_.clear();
+ for (size_t i = 0; i < num_servers; ++i) {
+ int port = 0;
+ if (ports.size() == num_servers) port = ports[i];
+ servers_.emplace_back(new ServerData(port));
+ }
+ }
+
  // Starts the index-th previously created server and blocks until it is up.
  void StartServer(size_t index) { servers_[index]->Start(server_host_); }
+
  // Convenience: CreateServers followed by starting every server.
  void StartServers(size_t num_servers,
                    std::vector<int> ports = std::vector<int>()) {
    CreateServers(num_servers, std::move(ports));
    for (size_t i = 0; i < num_servers; ++i) {
      StartServer(i);
    }
  }
+
  // Converts a list of ports into a fake resolver result whose addresses all
  // point at ipv4:127.0.0.1:<port>. The temporary grpc_uri objects are freed
  // before returning.
  grpc_core::Resolver::Result BuildFakeResults(const std::vector<int>& ports) {
    grpc_core::Resolver::Result result;
    for (const int& port : ports) {
      TString lb_uri_str = y_absl::StrCat("ipv4:127.0.0.1:", port);
      grpc_uri* lb_uri = grpc_uri_parse(lb_uri_str.c_str(), true);
      GPR_ASSERT(lb_uri != nullptr);
      grpc_resolved_address address;
      GPR_ASSERT(grpc_parse_uri(lb_uri, &address));
      result.addresses.emplace_back(address.addr, address.len,
                                    nullptr /* args */);
      grpc_uri_destroy(lb_uri);
    }
    return result;
  }
+
  // Pushes a resolver update with the given addresses and no service config.
  void SetNextResolutionNoServiceConfig(const std::vector<int>& ports) {
    grpc_core::ExecCtx exec_ctx;
    grpc_core::Resolver::Result result = BuildFakeResults(ports);
    response_generator_->SetResponse(result);
  }
+
  // Pushes a resolver update carrying the minimal valid service config "{}".
  void SetNextResolutionValidServiceConfig(const std::vector<int>& ports) {
    grpc_core::ExecCtx exec_ctx;
    grpc_core::Resolver::Result result = BuildFakeResults(ports);
    result.service_config = grpc_core::ServiceConfig::Create(
        nullptr, "{}", &result.service_config_error);
    response_generator_->SetResponse(result);
  }
+
  // Pushes a resolver update carrying a malformed service config ("{"), so
  // tests can observe how the channel reacts to a config parse failure.
  void SetNextResolutionInvalidServiceConfig(const std::vector<int>& ports) {
    grpc_core::ExecCtx exec_ctx;
    grpc_core::Resolver::Result result = BuildFakeResults(ports);
    result.service_config = grpc_core::ServiceConfig::Create(
        nullptr, "{", &result.service_config_error);
    response_generator_->SetResponse(result);
  }
+
  // Pushes a resolver update carrying the caller-supplied service config
  // JSON.
  void SetNextResolutionWithServiceConfig(const std::vector<int>& ports,
                                          const char* svc_cfg) {
    grpc_core::ExecCtx exec_ctx;
    grpc_core::Resolver::Result result = BuildFakeResults(ports);
    result.service_config = grpc_core::ServiceConfig::Create(
        nullptr, svc_cfg, &result.service_config_error);
    response_generator_->SetResponse(result);
  }
+
+ std::vector<int> GetServersPorts(size_t start_index = 0) {
+ std::vector<int> ports;
+ for (size_t i = start_index; i < servers_.size(); ++i) {
+ ports.push_back(servers_[i]->port_);
+ }
+ return ports;
+ }
+
  // Creates an EchoTestService stub over the given channel.
  std::unique_ptr<grpc::testing::EchoTestService::Stub> BuildStub(
      const std::shared_ptr<Channel>& channel) {
    return grpc::testing::EchoTestService::NewStub(channel);
  }
+
  // Creates a channel wired to the fake resolver, with no default service
  // config.
  std::shared_ptr<Channel> BuildChannel() {
    ChannelArguments args;
    args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR,
                    response_generator_.get());
    return ::grpc::CreateCustomChannel("fake:///", creds_, args);
  }
+
  // Creates a channel whose default service config is the (verified-valid)
  // ValidDefaultServiceConfig(); used when the resolver supplies none.
  std::shared_ptr<Channel> BuildChannelWithDefaultServiceConfig() {
    ChannelArguments args;
    EXPECT_THAT(grpc::experimental::ValidateServiceConfigJSON(
                    ValidDefaultServiceConfig()),
                ::testing::StrEq(""));
    args.SetServiceConfigJSON(ValidDefaultServiceConfig());
    args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR,
                    response_generator_.get());
    return ::grpc::CreateCustomChannel("fake:///", creds_, args);
  }
+
  // Creates a channel whose default service config is deliberately invalid
  // JSON (the EXPECT confirms the validator rejects it).
  std::shared_ptr<Channel> BuildChannelWithInvalidDefaultServiceConfig() {
    ChannelArguments args;
    EXPECT_THAT(grpc::experimental::ValidateServiceConfigJSON(
                    InvalidDefaultServiceConfig()),
                ::testing::HasSubstr("JSON parse error"));
    args.SetServiceConfigJSON(InvalidDefaultServiceConfig());
    args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR,
                    response_generator_.get());
    return ::grpc::CreateCustomChannel("fake:///", creds_, args);
  }
+
+  bool SendRpc(
+      const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+      EchoResponse* response = nullptr, int timeout_ms = 1000,
+      Status* result = nullptr, bool wait_for_ready = false) {  // one Echo RPC; returns status.ok(); response/result are optional out-params
+    const bool local_response = (response == nullptr);
+    if (local_response) response = new EchoResponse;  // caller doesn't want the response; use a throwaway
+    EchoRequest request;
+    request.set_message(kRequestMessage_);
+    ClientContext context;
+    context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms));
+    if (wait_for_ready) context.set_wait_for_ready(true);
+    Status status = stub->Echo(&context, request, response);
+    if (result != nullptr) *result = status;
+    if (local_response) delete response;
+    return status.ok();
+  }
+
+  void CheckRpcSendOk(
+      const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+      const grpc_core::DebugLocation& location, bool wait_for_ready = false) {  // assert one RPC (2s deadline) succeeds and echoes the request
+    EchoResponse response;
+    Status status;
+    const bool success =
+        SendRpc(stub, &response, 2000, &status, wait_for_ready);
+    ASSERT_TRUE(success) << "From " << location.file() << ":" << location.line()
+                         << "\n"
+                         << "Error: " << status.error_message() << " "
+                         << status.error_details();
+    ASSERT_EQ(response.message(), kRequestMessage_)
+        << "From " << location.file() << ":" << location.line();
+    if (!success) abort();  // belt-and-braces: ASSERT_TRUE already returned on failure
+  }
+
+  void CheckRpcSendFailure(
+      const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub) {  // expect one RPC (default 1s deadline) to fail
+    const bool success = SendRpc(stub);
+    EXPECT_FALSE(success);
+  }
+
+  struct ServerData {  // one test backend: port, server instance, service impl, serving thread
+    int port_;
+    std::unique_ptr<Server> server_;
+    MyTestServiceImpl service_;
+    std::unique_ptr<std::thread> thread_;
+    bool server_ready_ = false;  // guarded by the Mutex handed to Serve()
+    bool started_ = false;
+
+    explicit ServerData(int port = 0) {
+      port_ = port > 0 ? port : grpc_pick_unused_port_or_die();  // 0 means "pick a free port"
+    }
+
+    void Start(const TString& server_host) {  // spawn Serve() and block until the server is listening
+      gpr_log(GPR_INFO, "starting server on port %d", port_);
+      started_ = true;
+      grpc::internal::Mutex mu;
+      grpc::internal::MutexLock lock(&mu);
+      grpc::internal::CondVar cond;
+      thread_.reset(new std::thread(
+          std::bind(&ServerData::Serve, this, server_host, &mu, &cond)));
+      cond.WaitUntil(&mu, [this] { return server_ready_; });  // wait for Serve() to signal readiness
+      server_ready_ = false;
+      gpr_log(GPR_INFO, "server startup complete");
+    }
+
+    void Serve(const TString& server_host, grpc::internal::Mutex* mu,
+               grpc::internal::CondVar* cond) {  // thread body: build/start the server, then wake Start()
+      std::ostringstream server_address;
+      server_address << server_host << ":" << port_;
+      ServerBuilder builder;
+      std::shared_ptr<ServerCredentials> creds(new SecureServerCredentials(
+          grpc_fake_transport_security_server_credentials_create()));
+      builder.AddListeningPort(server_address.str(), std::move(creds));
+      builder.RegisterService(&service_);
+      server_ = builder.BuildAndStart();
+      grpc::internal::MutexLock lock(mu);
+      server_ready_ = true;
+      cond->Signal();
+    }
+
+    void Shutdown() {  // immediate shutdown (zero deadline); joins the serving thread
+      if (!started_) return;
+      server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
+      thread_->join();
+      started_ = false;
+    }
+
+    void SetServingStatus(const TString& service, bool serving) {
+      server_->GetHealthCheckService()->SetServingStatus(service, serving);
+    }
+  };
+
+  void ResetCounters() {  // zero every backend's request counter
+    for (const auto& server : servers_) server->service_.ResetCounters();
+  }
+
+  void WaitForServer(
+      const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub,
+      size_t server_idx, const grpc_core::DebugLocation& location,
+      bool ignore_failure = false) {  // issue RPCs until servers_[server_idx] has handled at least one
+    do {
+      if (ignore_failure) {
+        SendRpc(stub);
+      } else {
+        CheckRpcSendOk(stub, location, true);
+      }
+    } while (servers_[server_idx]->service_.request_count() == 0);
+    ResetCounters();
+  }
+
+  bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) {  // true once state leaves READY; false on timeout
+    const gpr_timespec deadline =
+        grpc_timeout_seconds_to_deadline(timeout_seconds);
+    grpc_connectivity_state state;
+    while ((state = channel->GetState(false /* try_to_connect */)) ==
+           GRPC_CHANNEL_READY) {
+      if (!channel->WaitForStateChange(state, deadline)) return false;
+    }
+    return true;
+  }
+
+  bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) {  // true once state reaches READY (triggers connect); false on timeout
+    const gpr_timespec deadline =
+        grpc_timeout_seconds_to_deadline(timeout_seconds);
+    grpc_connectivity_state state;
+    while ((state = channel->GetState(true /* try_to_connect */)) !=
+           GRPC_CHANNEL_READY) {
+      if (!channel->WaitForStateChange(state, deadline)) return false;
+    }
+    return true;
+  }
+
+  bool SeenAllServers() {  // every backend has handled at least one RPC
+    for (const auto& server : servers_) {
+      if (server->service_.request_count() == 0) return false;
+    }
+    return true;
+  }
+
+  // Updates \a connection_order by appending to it the index of the newly
+  // connected server. Must be called after every single RPC.
+  void UpdateConnectionOrder(
+      const std::vector<std::unique_ptr<ServerData>>& servers,
+      std::vector<int>* connection_order) {
+    for (size_t i = 0; i < servers.size(); ++i) {
+      if (servers[i]->service_.request_count() == 1) {  // count==1 => this server just served its first RPC
+        // Was the server index known? If not, update connection_order.
+        const auto it =
+            std::find(connection_order->begin(), connection_order->end(), i);
+        if (it == connection_order->end()) {
+          connection_order->push_back(i);
+          return;
+        }
+      }
+    }
+  }
+
+  const char* ValidServiceConfigV1() { return "{\"version\": \"1\"}"; }  // resolver-supplied config, version 1
+
+  const char* ValidServiceConfigV2() { return "{\"version\": \"2\"}"; }  // resolver-supplied config, version 2
+
+  const char* ValidDefaultServiceConfig() {  // well-formed channel-arg default config
+    return "{\"version\": \"valid_default\"}";
+  }
+
+  const char* InvalidDefaultServiceConfig() {  // deliberately malformed: missing the closing '}'
+    return "{\"version\": \"invalid_default\"";
+  }
+
+  const TString server_host_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+  std::vector<std::unique_ptr<ServerData>> servers_;  // backends owned by the fixture
+  grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator>
+      response_generator_;  // feeds fake resolver results to channels built here
+  const TString kRequestMessage_;
+  std::shared_ptr<ChannelCredentials> creds_;
+};
+
+TEST_F(ServiceConfigEnd2endTest, NoServiceConfigTest) {  // no config anywhere -> channel reports "{}"
+  StartServers(1);
+  auto channel = BuildChannel();
+  auto stub = BuildStub(channel);
+  SetNextResolutionNoServiceConfig(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str());
+}
+
+TEST_F(ServiceConfigEnd2endTest, NoServiceConfigWithDefaultConfigTest) {  // default config used when resolver sends none
+  StartServers(1);
+  auto channel = BuildChannelWithDefaultServiceConfig();
+  auto stub = BuildStub(channel);
+  SetNextResolutionNoServiceConfig(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidDefaultServiceConfig(),
+               channel->GetServiceConfigJSON().c_str());
+}
+
+TEST_F(ServiceConfigEnd2endTest, InvalidServiceConfigTest) {  // invalid resolver config with no fallback -> RPCs fail
+  StartServers(1);
+  auto channel = BuildChannel();
+  auto stub = BuildStub(channel);
+  SetNextResolutionInvalidServiceConfig(GetServersPorts());
+  CheckRpcSendFailure(stub);
+}
+
+TEST_F(ServiceConfigEnd2endTest, ValidServiceConfigUpdatesTest) {  // channel picks up each successive valid config
+  StartServers(1);
+  auto channel = BuildChannel();
+  auto stub = BuildStub(channel);
+  SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str());
+  SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV2());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidServiceConfigV2(), channel->GetServiceConfigJSON().c_str());
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       NoServiceConfigUpdateAfterValidServiceConfigTest) {  // resolver dropping the config reverts to "{}"
+  StartServers(1);
+  auto channel = BuildChannel();
+  auto stub = BuildStub(channel);
+  SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str());
+  SetNextResolutionNoServiceConfig(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str());
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       NoServiceConfigUpdateAfterValidServiceConfigWithDefaultConfigTest) {  // dropping the config falls back to the default
+  StartServers(1);
+  auto channel = BuildChannelWithDefaultServiceConfig();
+  auto stub = BuildStub(channel);
+  SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str());
+  SetNextResolutionNoServiceConfig(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidDefaultServiceConfig(),
+               channel->GetServiceConfigJSON().c_str());
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       InvalidServiceConfigUpdateAfterValidServiceConfigTest) {  // bad update is ignored; previous config sticks
+  StartServers(1);
+  auto channel = BuildChannel();
+  auto stub = BuildStub(channel);
+  SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str());
+  SetNextResolutionInvalidServiceConfig(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str());
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       InvalidServiceConfigUpdateAfterValidServiceConfigWithDefaultConfigTest) {  // bad update ignored even when a default exists
+  StartServers(1);
+  auto channel = BuildChannelWithDefaultServiceConfig();
+  auto stub = BuildStub(channel);
+  SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str());
+  SetNextResolutionInvalidServiceConfig(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str());
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       ValidServiceConfigAfterInvalidServiceConfigTest) {  // a later valid config recovers the channel
+  StartServers(1);
+  auto channel = BuildChannel();
+  auto stub = BuildStub(channel);
+  SetNextResolutionInvalidServiceConfig(GetServersPorts());
+  CheckRpcSendFailure(stub);
+  SetNextResolutionValidServiceConfig(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+}
+
+TEST_F(ServiceConfigEnd2endTest, NoServiceConfigAfterInvalidServiceConfigTest) {  // config-less resolution also recovers ("{}" reported)
+  StartServers(1);
+  auto channel = BuildChannel();
+  auto stub = BuildStub(channel);
+  SetNextResolutionInvalidServiceConfig(GetServersPorts());
+  CheckRpcSendFailure(stub);
+  SetNextResolutionNoServiceConfig(GetServersPorts());
+  CheckRpcSendOk(stub, DEBUG_LOCATION);
+  EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str());
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       AnotherInvalidServiceConfigAfterInvalidServiceConfigTest) {  // repeated invalid configs keep failing
+  StartServers(1);
+  auto channel = BuildChannel();
+  auto stub = BuildStub(channel);
+  SetNextResolutionInvalidServiceConfig(GetServersPorts());
+  CheckRpcSendFailure(stub);
+  SetNextResolutionInvalidServiceConfig(GetServersPorts());
+  CheckRpcSendFailure(stub);
+}
+
+TEST_F(ServiceConfigEnd2endTest, InvalidDefaultServiceConfigTest) {
+  StartServers(1);
+  auto channel = BuildChannelWithInvalidDefaultServiceConfig();
+  auto stub = BuildStub(channel);
+  // An invalid default service config results in a lame channel which fails all
+  // RPCs
+  CheckRpcSendFailure(stub);
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       InvalidDefaultServiceConfigTestWithValidServiceConfig) {  // lame channel stays lame even after a valid resolver config
+  StartServers(1);
+  auto channel = BuildChannelWithInvalidDefaultServiceConfig();
+  auto stub = BuildStub(channel);
+  CheckRpcSendFailure(stub);
+  // An invalid default service config results in a lame channel which fails all
+  // RPCs
+  SetNextResolutionValidServiceConfig(GetServersPorts());
+  CheckRpcSendFailure(stub);
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       InvalidDefaultServiceConfigTestWithInvalidServiceConfig) {  // lame channel + invalid resolver config still fails
+  StartServers(1);
+  auto channel = BuildChannelWithInvalidDefaultServiceConfig();
+  auto stub = BuildStub(channel);
+  CheckRpcSendFailure(stub);
+  // An invalid default service config results in a lame channel which fails all
+  // RPCs
+  SetNextResolutionInvalidServiceConfig(GetServersPorts());
+  CheckRpcSendFailure(stub);
+}
+
+TEST_F(ServiceConfigEnd2endTest,
+       InvalidDefaultServiceConfigTestWithNoServiceConfig) {  // lame channel + config-less resolution still fails
+  StartServers(1);
+  auto channel = BuildChannelWithInvalidDefaultServiceConfig();
+  auto stub = BuildStub(channel);
+  CheckRpcSendFailure(stub);
+  // An invalid default service config results in a lame channel which fails all
+  // RPCs
+  SetNextResolutionNoServiceConfig(GetServersPorts());
+  CheckRpcSendFailure(stub);
+}
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {  // gtest entry point for service_config_end2end_test
+  ::testing::InitGoogleTest(&argc, argv);
+  grpc::testing::TestEnvironment env(argc, argv);  // gRPC test env must be alive during RUN_ALL_TESTS
+  const auto result = RUN_ALL_TESTS();
+  return result;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/shutdown_test.cc b/contrib/libs/grpc/test/cpp/end2end/shutdown_test.cc
new file mode 100644
index 0000000000..3aa7a766c4
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/shutdown_test.cc
@@ -0,0 +1,170 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/sync.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/core/lib/gpr/env.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+namespace grpc {
+namespace testing {
+
+class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {  // Echo blocks until the call is cancelled
+ public:
+  explicit TestServiceImpl(gpr_event* ev) : ev_(ev) {}
+
+  Status Echo(ServerContext* context, const EchoRequest* /*request*/,
+              EchoResponse* /*response*/) override {
+    gpr_event_set(ev_, (void*)1);  // tell the test the server has entered the RPC
+    while (!context->IsCancelled()) {  // busy-wait; only cancellation (e.g. server shutdown) ends the call
+    }
+    return Status::OK;
+  }
+
+ private:
+  gpr_event* ev_;
+};
+
+class ShutdownTest : public ::testing::TestWithParam<string> {  // param = credentials type name
+ public:
+  ShutdownTest() : shutdown_(false), service_(&ev_) { gpr_event_init(&ev_); }
+
+  void SetUp() override {
+    port_ = grpc_pick_unused_port_or_die();
+    server_ = SetUpServer(port_);
+  }
+
+  std::unique_ptr<Server> SetUpServer(const int port) {  // start a localhost server using the param's server credentials
+    TString server_address = "localhost:" + to_string(port);
+
+    ServerBuilder builder;
+    auto server_creds =
+        GetCredentialsProvider()->GetServerCredentials(GetParam());
+    builder.AddListeningPort(server_address, server_creds);
+    builder.RegisterService(&service_);
+    std::unique_ptr<Server> server = builder.BuildAndStart();
+    return server;
+  }
+
+  void TearDown() override { GPR_ASSERT(shutdown_); }  // the test body must have set shutdown_
+
+  void ResetStub() {  // build a channel + Echo stub to the test server
+    string target = "dns:localhost:" + to_string(port_);
+    ChannelArguments args;
+    auto channel_creds =
+        GetCredentialsProvider()->GetChannelCredentials(GetParam(), &args);
+    channel_ = ::grpc::CreateCustomChannel(target, channel_creds, args);
+    stub_ = grpc::testing::EchoTestService::NewStub(channel_);
+  }
+
+  string to_string(const int number) {  // int -> decimal string
+    std::stringstream strs;
+    strs << number;
+    return strs.str();
+  }
+
+  void SendRequest() {  // background-thread body; Echo blocks until shutdown cancels it
+    EchoRequest request;
+    EchoResponse response;
+    request.set_message("Hello");
+    ClientContext context;
+    GPR_ASSERT(!shutdown_);
+    Status s = stub_->Echo(&context, request, &response);
+    GPR_ASSERT(shutdown_);  // Echo must not return before shutdown_ was set
+  }
+
+ protected:
+  std::shared_ptr<Channel> channel_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+  std::unique_ptr<Server> server_;
+  bool shutdown_;
+  int port_;
+  gpr_event ev_;  // set by the service once Echo is entered
+  TestServiceImpl service_;
+};
+
+std::vector<string> GetAllCredentialsTypeList() {  // insecure (when available) plus every secure credentials type
+  std::vector<TString> credentials_types;
+  if (GetCredentialsProvider()->GetChannelCredentials(kInsecureCredentialsType,
+                                                      nullptr) != nullptr) {
+    credentials_types.push_back(kInsecureCredentialsType);
+  }
+  auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList();
+  for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) {
+    credentials_types.push_back(*sec);
+  }
+  GPR_ASSERT(!credentials_types.empty());
+
+  TString credentials_type_list("credentials types:");
+  for (const string& type : credentials_types) {
+    credentials_type_list.append(" " + type);
+  }
+  gpr_log(GPR_INFO, "%s", credentials_type_list.c_str());  // log the full list for debugging
+  return credentials_types;
+}
+
+INSTANTIATE_TEST_SUITE_P(End2EndShutdown, ShutdownTest,
+                         ::testing::ValuesIn(GetAllCredentialsTypeList()));  // run once per credentials type
+
+// TODO(ctiller): leaked objects in this test
+TEST_P(ShutdownTest, ShutdownTest) {  // Shutdown(deadline) must cancel the in-flight blocking RPC
+  ResetStub();
+
+  // send the request in a background thread
+  std::thread thr(std::bind(&ShutdownTest::SendRequest, this));
+
+  // wait for the server to get the event
+  gpr_event_wait(&ev_, gpr_inf_future(GPR_CLOCK_MONOTONIC));
+
+  shutdown_ = true;
+
+  // shutdown should trigger cancellation causing everything to shutdown
+  auto deadline =
+      std::chrono::system_clock::now() + std::chrono::microseconds(100);
+  server_->Shutdown(deadline);
+  EXPECT_GE(std::chrono::system_clock::now(), deadline);  // the short deadline must have elapsed (forced-cancel path taken)
+
+  thr.join();
+}
+
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {  // gtest entry point for shutdown_test
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/streaming_throughput_test.cc b/contrib/libs/grpc/test/cpp/end2end/streaming_throughput_test.cc
new file mode 100644
index 0000000000..f2252063fb
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/streaming_throughput_test.cc
@@ -0,0 +1,193 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <time.h>
+#include <mutex>
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpc/support/atm.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/security/server_credentials.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+const char* kLargeString =  // multi-kilobyte payload (Hamlet soliloquy) used as the streamed message body
+    "("
+    "To be, or not to be- that is the question:"
+    "Whether 'tis nobler in the mind to suffer"
+    "The slings and arrows of outrageous fortune"
+    "Or to take arms against a sea of troubles,"
+    "And by opposing end them. To die- to sleep-"
+    "No more; and by a sleep to say we end"
+    "The heartache, and the thousand natural shock"
+    "That flesh is heir to. 'Tis a consummation"
+    "Devoutly to be wish'd. To die- to sleep."
+    "To sleep- perchance to dream: ay, there's the rub!"
+    "For in that sleep of death what dreams may come"
+    "When we have shuffled off this mortal coil,"
+    "Must give us pause. There's the respect"
+    "That makes calamity of so long life."
+    "For who would bear the whips and scorns of time,"
+    "Th' oppressor's wrong, the proud man's contumely,"
+    "The pangs of despis'd love, the law's delay,"
+    "The insolence of office, and the spurns"
+    "That patient merit of th' unworthy takes,"
+    "When he himself might his quietus make"
+    "With a bare bodkin? Who would these fardels bear,"
+    "To grunt and sweat under a weary life,"
+    "But that the dread of something after death-"
+    "The undiscover'd country, from whose bourn"
+    "No traveller returns- puzzles the will,"
+    "And makes us rather bear those ills we have"
+    "Than fly to others that we know not of?"
+    "Thus conscience does make cowards of us all,"
+    "And thus the native hue of resolution"
+    "Is sicklied o'er with the pale cast of thought,"
+    "And enterprises of great pith and moment"
+    "With this regard their currents turn awry"
+    "And lose the name of action.- Soft you now!"
+    "The fair Ophelia!- Nymph, in thy orisons"
+    "Be all my sins rememb'red.";
+
+namespace grpc {
+namespace testing {
+
+class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {  // bidi-stream echo server driving the throughput test
+ public:
+  static void BidiStream_Sender(
+      ServerReaderWriter<EchoResponse, EchoRequest>* stream,
+      gpr_atm* should_exit) {  // sender-thread body: write one large response per ~1ms until told to stop
+    EchoResponse response;
+    response.set_message(kLargeString);
+    while (gpr_atm_acq_load(should_exit) == static_cast<gpr_atm>(0)) {
+      struct timespec tv = {0, 1000000};  // 1 ms
+      struct timespec rem;
+      // TODO (vpai): Mark this blocking
+      while (nanosleep(&tv, &rem) != 0) {
+        tv = rem;  // interrupted: sleep the remainder
+      };
+
+      stream->Write(response);
+    }
+  }
+
+  // Only implement the one method we will be calling for brevity.
+  Status BidiStream(
+      ServerContext* /*context*/,
+      ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
+    EchoRequest request;
+    gpr_atm should_exit;
+    gpr_atm_rel_store(&should_exit, static_cast<gpr_atm>(0));
+
+    std::thread sender(
+        std::bind(&TestServiceImpl::BidiStream_Sender, stream, &should_exit));
+
+    while (stream->Read(&request)) {  // drain client writes, throttled ~3ms each
+      struct timespec tv = {0, 3000000};  // 3 ms
+      struct timespec rem;
+      // TODO (vpai): Mark this blocking
+      while (nanosleep(&tv, &rem) != 0) {
+        tv = rem;
+      };
+    }
+    gpr_atm_rel_store(&should_exit, static_cast<gpr_atm>(1));  // client half-closed: stop the sender thread
+    sender.join();
+    return Status::OK;
+  }
+};
+
+class End2endTest : public ::testing::Test {  // starts an insecure localhost server per test
+ protected:
+  void SetUp() override {
+    int port = grpc_pick_unused_port_or_die();
+    server_address_ << "localhost:" << port;
+    // Setup server
+    ServerBuilder builder;
+    builder.AddListeningPort(server_address_.str(),
+                             InsecureServerCredentials());
+    builder.RegisterService(&service_);
+    server_ = builder.BuildAndStart();
+  }
+
+  void TearDown() override { server_->Shutdown(); }
+
+  void ResetStub() {  // fresh channel + Echo stub to the test server
+    std::shared_ptr<Channel> channel = grpc::CreateChannel(
+        server_address_.str(), InsecureChannelCredentials());
+    stub_ = grpc::testing::EchoTestService::NewStub(channel);
+  }
+
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+  std::unique_ptr<Server> server_;
+  std::ostringstream server_address_;
+  TestServiceImpl service_;
+};
+
+static void Drainer(ClientReaderWriter<EchoRequest, EchoResponse>* reader) {  // read-and-discard until the stream ends
+  EchoResponse response;
+  while (reader->Read(&response)) {
+    // Just drain out the responses as fast as possible.
+  }
+}
+
+TEST_F(End2endTest, StreamingThroughput) {  // pump 10k large writes while a thread drains responses
+  ResetStub();
+  grpc::ClientContext context;
+  auto stream = stub_->BidiStream(&context);
+
+  auto reader = stream.get();
+  std::thread receiver(std::bind(Drainer, reader));
+
+  for (int i = 0; i < 10000; i++) {
+    EchoRequest request;
+    request.set_message(kLargeString);
+    ASSERT_TRUE(stream->Write(request));
+    if (i % 1000 == 0) {
+      gpr_log(GPR_INFO, "Send count = %d", i);  // progress marker every 1000 writes
+    }
+  }
+  stream->WritesDone();  // half-close; the drain thread then sees end-of-stream
+  receiver.join();
+}
+
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {  // gtest entry point for streaming_throughput_test
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.cc b/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.cc
new file mode 100644
index 0000000000..5b212cba31
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.cc
@@ -0,0 +1,98 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/end2end/test_health_check_service_impl.h"
+
+#include <grpc/grpc.h>
+
+using grpc::health::v1::HealthCheckRequest;
+using grpc::health::v1::HealthCheckResponse;
+
+namespace grpc {
+namespace testing {
+
+Status HealthCheckServiceImpl::Check(ServerContext* /*context*/,
+                                     const HealthCheckRequest* request,
+                                     HealthCheckResponse* response) {
+  std::lock_guard<std::mutex> lock(mu_);  // status_map_ is shared with SetStatus/SetAll/Shutdown
+  auto iter = status_map_.find(request->service());
+  if (iter == status_map_.end()) {
+    return Status(StatusCode::NOT_FOUND, "");  // unknown service name
+  }
+  response->set_status(iter->second);
+  return Status::OK;
+}
+
+Status HealthCheckServiceImpl::Watch(
+    ServerContext* context, const HealthCheckRequest* request,
+    ::grpc::ServerWriter<HealthCheckResponse>* writer) {  // stream status changes, polling once per second until cancelled
+  auto last_state = HealthCheckResponse::UNKNOWN;
+  while (!context->IsCancelled()) {
+    {
+      std::lock_guard<std::mutex> lock(mu_);
+      HealthCheckResponse response;
+      auto iter = status_map_.find(request->service());
+      if (iter == status_map_.end()) {
+        response.set_status(response.SERVICE_UNKNOWN);  // unknown services report SERVICE_UNKNOWN, not NOT_FOUND
+      } else {
+        response.set_status(iter->second);
+      }
+      if (response.status() != last_state) {  // only write on state transitions
+        writer->Write(response, ::grpc::WriteOptions());
+        last_state = response.status();
+      }
+    }
+    gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
+                                 gpr_time_from_millis(1000, GPR_TIMESPAN)));
+  }
+  return Status::OK;
+}
+
+void HealthCheckServiceImpl::SetStatus(
+    const TString& service_name,
+    HealthCheckResponse::ServingStatus status) {  // set one service's status (forced to NOT_SERVING once shut down)
+  std::lock_guard<std::mutex> lock(mu_);
+  if (shutdown_) {
+    status = HealthCheckResponse::NOT_SERVING;
+  }
+  status_map_[service_name] = status;
+}
+
+void HealthCheckServiceImpl::SetAll(HealthCheckResponse::ServingStatus status) {  // set every known service's status; no-op after Shutdown
+  std::lock_guard<std::mutex> lock(mu_);
+  if (shutdown_) {
+    return;
+  }
+  for (auto iter = status_map_.begin(); iter != status_map_.end(); ++iter) {
+    iter->second = status;
+  }
+}
+
+void HealthCheckServiceImpl::Shutdown() {  // idempotent; marks every service NOT_SERVING permanently
+  std::lock_guard<std::mutex> lock(mu_);
+  if (shutdown_) {
+    return;
+  }
+  shutdown_ = true;
+  for (auto iter = status_map_.begin(); iter != status_map_.end(); ++iter) {
+    iter->second = HealthCheckResponse::NOT_SERVING;
+  }
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.h b/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.h
new file mode 100644
index 0000000000..d370e4693a
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.h
@@ -0,0 +1,58 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+#ifndef GRPC_TEST_CPP_END2END_TEST_HEALTH_CHECK_SERVICE_IMPL_H
+#define GRPC_TEST_CPP_END2END_TEST_HEALTH_CHECK_SERVICE_IMPL_H
+
+#include <map>
+#include <mutex>
+
+#include <grpcpp/server_context.h>
+#include <grpcpp/support/status.h>
+
+#include "src/proto/grpc/health/v1/health.grpc.pb.h"
+
+namespace grpc {
+namespace testing {
+
+// A sample sync implementation of the health checking service. This does the
+// same thing as the default one.
+class HealthCheckServiceImpl : public health::v1::Health::Service {
+ public:
+  Status Check(ServerContext* context,
+               const health::v1::HealthCheckRequest* request,
+               health::v1::HealthCheckResponse* response) override;  // unary lookup; NOT_FOUND for unknown services
+  Status Watch(ServerContext* context,
+               const health::v1::HealthCheckRequest* request,
+               ServerWriter<health::v1::HealthCheckResponse>* writer) override;  // streams status changes until cancelled
+  void SetStatus(const TString& service_name,
+                 health::v1::HealthCheckResponse::ServingStatus status);
+  void SetAll(health::v1::HealthCheckResponse::ServingStatus status);
+
+  void Shutdown();  // permanently force NOT_SERVING for all services
+
+ private:
+  std::mutex mu_;  // guards shutdown_ and status_map_
+  bool shutdown_ = false;
+  std::map<const TString, health::v1::HealthCheckResponse::ServingStatus>
+      status_map_;
+};
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_END2END_TEST_HEALTH_CHECK_SERVICE_IMPL_H
diff --git a/contrib/libs/grpc/test/cpp/end2end/test_service_impl.cc b/contrib/libs/grpc/test/cpp/end2end/test_service_impl.cc
new file mode 100644
index 0000000000..078977e824
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/test_service_impl.cc
@@ -0,0 +1,638 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/end2end/test_service_impl.h"
+
+#include <grpc/support/log.h>
+#include <grpcpp/alarm.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/server_context.h>
+#include <gtest/gtest.h>
+
+#include <util/generic/string.h>
+#include <thread>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace testing {
+namespace internal {
+
+// When echo_deadline is requested, deadline seen in the ServerContext is set in
+// the response in seconds.
+void MaybeEchoDeadline(experimental::ServerContextBase* context,
+ const EchoRequest* request, EchoResponse* response) {
+ if (request->has_param() && request->param().echo_deadline()) {
+ gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_REALTIME);
+ if (context->deadline() != system_clock::time_point::max()) {
+ Timepoint2Timespec(context->deadline(), &deadline);
+ }
+ response->mutable_param()->set_request_deadline(deadline.tv_sec);
+ }
+}
+
+void CheckServerAuthContext(const experimental::ServerContextBase* context,
+ const TString& expected_transport_security_type,
+ const TString& expected_client_identity) {
+ std::shared_ptr<const AuthContext> auth_ctx = context->auth_context();
+ std::vector<grpc::string_ref> tst =
+ auth_ctx->FindPropertyValues("transport_security_type");
+ EXPECT_EQ(1u, tst.size());
+ EXPECT_EQ(expected_transport_security_type.c_str(), ToString(tst[0]));
+ if (expected_client_identity.empty()) {
+ EXPECT_TRUE(auth_ctx->GetPeerIdentityPropertyName().empty());
+ EXPECT_TRUE(auth_ctx->GetPeerIdentity().empty());
+ EXPECT_FALSE(auth_ctx->IsPeerAuthenticated());
+ } else {
+ auto identity = auth_ctx->GetPeerIdentity();
+ EXPECT_TRUE(auth_ctx->IsPeerAuthenticated());
+ EXPECT_EQ(1u, identity.size());
+ EXPECT_EQ(expected_client_identity.c_str(), ToString(identity[0]));
+ }
+}
+
+// Returns the number of pairs in metadata that exactly match the given
+// key-value pair. Returns -1 if the pair wasn't found.
+int MetadataMatchCount(
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+ const TString& key, const TString& value) {
+ int count = 0;
+ for (const auto& metadatum : metadata) {
+ if (ToString(metadatum.first) == key &&
+ ToString(metadatum.second) == value) {
+ count++;
+ }
+ }
+ return count;
+}
+
+int GetIntValueFromMetadataHelper(
+ const char* key,
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+ int default_value) {
+ if (metadata.find(key) != metadata.end()) {
+ std::istringstream iss(ToString(metadata.find(key)->second));
+ iss >> default_value;
+ gpr_log(GPR_INFO, "%s : %d", key, default_value);
+ }
+
+ return default_value;
+}
+
+int GetIntValueFromMetadata(
+ const char* key,
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+ int default_value) {
+ return GetIntValueFromMetadataHelper(key, metadata, default_value);
+}
+
+void ServerTryCancel(ServerContext* context) {
+ EXPECT_FALSE(context->IsCancelled());
+ context->TryCancel();
+ gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
+ // Now wait until it's really canceled
+ while (!context->IsCancelled()) {
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_micros(1000, GPR_TIMESPAN)));
+ }
+}
+
+void ServerTryCancelNonblocking(experimental::CallbackServerContext* context) {
+ EXPECT_FALSE(context->IsCancelled());
+ context->TryCancel();
+ gpr_log(GPR_INFO,
+ "Server called TryCancelNonblocking() to cancel the request");
+}
+
+} // namespace internal
+
+experimental::ServerUnaryReactor* CallbackTestServiceImpl::Echo(
+ experimental::CallbackServerContext* context, const EchoRequest* request,
+ EchoResponse* response) {
+ class Reactor : public ::grpc::experimental::ServerUnaryReactor {
+ public:
+ Reactor(CallbackTestServiceImpl* service,
+ experimental::CallbackServerContext* ctx,
+ const EchoRequest* request, EchoResponse* response)
+ : service_(service), ctx_(ctx), req_(request), resp_(response) {
+ // It should be safe to call IsCancelled here, even though we don't know
+ // the result. Call it asynchronously to see if we trigger any data races.
+ // Join it in OnDone (technically that could be blocking but shouldn't be
+ // for very long).
+ async_cancel_check_ = std::thread([this] { (void)ctx_->IsCancelled(); });
+
+ started_ = true;
+
+ if (request->has_param() &&
+ request->param().server_notify_client_when_started()) {
+ service->signaller_.SignalClientThatRpcStarted();
+ // Block on the "wait to continue" decision in a different thread since
+ // we can't tie up an EM thread with blocking events. We can join it in
+ // OnDone since it would definitely be done by then.
+ rpc_wait_thread_ = std::thread([this] {
+ service_->signaller_.ServerWaitToContinue();
+ StartRpc();
+ });
+ } else {
+ StartRpc();
+ }
+ }
+
+ void StartRpc() {
+ if (req_->has_param() && req_->param().server_sleep_us() > 0) {
+ // Set an alarm for that much time
+ alarm_.experimental().Set(
+ gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
+ gpr_time_from_micros(req_->param().server_sleep_us(),
+ GPR_TIMESPAN)),
+ [this](bool ok) { NonDelayed(ok); });
+ return;
+ }
+ NonDelayed(true);
+ }
+ void OnSendInitialMetadataDone(bool ok) override {
+ EXPECT_TRUE(ok);
+ initial_metadata_sent_ = true;
+ }
+ void OnCancel() override {
+ EXPECT_TRUE(started_);
+ EXPECT_TRUE(ctx_->IsCancelled());
+ on_cancel_invoked_ = true;
+ std::lock_guard<std::mutex> l(cancel_mu_);
+ cancel_cv_.notify_one();
+ }
+ void OnDone() override {
+ if (req_->has_param() && req_->param().echo_metadata_initially()) {
+ EXPECT_TRUE(initial_metadata_sent_);
+ }
+ EXPECT_EQ(ctx_->IsCancelled(), on_cancel_invoked_);
+ // Validate that finishing with a non-OK status doesn't cause cancellation
+ if (req_->has_param() && req_->param().has_expected_error()) {
+ EXPECT_FALSE(on_cancel_invoked_);
+ }
+ async_cancel_check_.join();
+ if (rpc_wait_thread_.joinable()) {
+ rpc_wait_thread_.join();
+ }
+ if (finish_when_cancelled_.joinable()) {
+ finish_when_cancelled_.join();
+ }
+ delete this;
+ }
+
+ private:
+ void NonDelayed(bool ok) {
+ if (!ok) {
+ EXPECT_TRUE(ctx_->IsCancelled());
+ Finish(Status::CANCELLED);
+ return;
+ }
+ if (req_->has_param() && req_->param().server_die()) {
+ gpr_log(GPR_ERROR, "The request should not reach application handler.");
+ GPR_ASSERT(0);
+ }
+ if (req_->has_param() && req_->param().has_expected_error()) {
+ const auto& error = req_->param().expected_error();
+ Finish(Status(static_cast<StatusCode>(error.code()),
+ error.error_message(), error.binary_error_details()));
+ return;
+ }
+ int server_try_cancel = internal::GetIntValueFromMetadata(
+ kServerTryCancelRequest, ctx_->client_metadata(), DO_NOT_CANCEL);
+ if (server_try_cancel != DO_NOT_CANCEL) {
+ // Since this is a unary RPC, by the time this server handler is called,
+ // the 'request' message is already read from the client. So the
+ // scenarios in server_try_cancel don't make much sense. Just cancel the
+ // RPC as long as server_try_cancel is not DO_NOT_CANCEL
+ EXPECT_FALSE(ctx_->IsCancelled());
+ ctx_->TryCancel();
+ gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
+ FinishWhenCancelledAsync();
+ return;
+ }
+ resp_->set_message(req_->message());
+ internal::MaybeEchoDeadline(ctx_, req_, resp_);
+ if (service_->host_) {
+ resp_->mutable_param()->set_host(*service_->host_);
+ }
+ if (req_->has_param() && req_->param().client_cancel_after_us()) {
+ {
+ std::unique_lock<std::mutex> lock(service_->mu_);
+ service_->signal_client_ = true;
+ }
+ FinishWhenCancelledAsync();
+ return;
+ } else if (req_->has_param() && req_->param().server_cancel_after_us()) {
+ alarm_.experimental().Set(
+ gpr_time_add(
+ gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_micros(req_->param().server_cancel_after_us(),
+ GPR_TIMESPAN)),
+ [this](bool) { Finish(Status::CANCELLED); });
+ return;
+ } else if (!req_->has_param() || !req_->param().skip_cancelled_check()) {
+ EXPECT_FALSE(ctx_->IsCancelled());
+ }
+
+ if (req_->has_param() && req_->param().echo_metadata_initially()) {
+ const std::multimap<grpc::string_ref, grpc::string_ref>&
+ client_metadata = ctx_->client_metadata();
+ for (const auto& metadatum : client_metadata) {
+ ctx_->AddInitialMetadata(ToString(metadatum.first),
+ ToString(metadatum.second));
+ }
+ StartSendInitialMetadata();
+ }
+
+ if (req_->has_param() && req_->param().echo_metadata()) {
+ const std::multimap<grpc::string_ref, grpc::string_ref>&
+ client_metadata = ctx_->client_metadata();
+ for (const auto& metadatum : client_metadata) {
+ ctx_->AddTrailingMetadata(ToString(metadatum.first),
+ ToString(metadatum.second));
+ }
+ // Terminate rpc with error and debug info in trailer.
+ if (req_->param().debug_info().stack_entries_size() ||
+ !req_->param().debug_info().detail().empty()) {
+ TString serialized_debug_info =
+ req_->param().debug_info().SerializeAsString();
+ ctx_->AddTrailingMetadata(kDebugInfoTrailerKey,
+ serialized_debug_info);
+ Finish(Status::CANCELLED);
+ return;
+ }
+ }
+ if (req_->has_param() &&
+ (req_->param().expected_client_identity().length() > 0 ||
+ req_->param().check_auth_context())) {
+ internal::CheckServerAuthContext(
+ ctx_, req_->param().expected_transport_security_type(),
+ req_->param().expected_client_identity());
+ }
+ if (req_->has_param() && req_->param().response_message_length() > 0) {
+ resp_->set_message(
+ TString(req_->param().response_message_length(), '\0'));
+ }
+ if (req_->has_param() && req_->param().echo_peer()) {
+ resp_->mutable_param()->set_peer(ctx_->peer().c_str());
+ }
+ Finish(Status::OK);
+ }
+ void FinishWhenCancelledAsync() {
+ finish_when_cancelled_ = std::thread([this] {
+ std::unique_lock<std::mutex> l(cancel_mu_);
+ cancel_cv_.wait(l, [this] { return ctx_->IsCancelled(); });
+ Finish(Status::CANCELLED);
+ });
+ }
+
+ CallbackTestServiceImpl* const service_;
+ experimental::CallbackServerContext* const ctx_;
+ const EchoRequest* const req_;
+ EchoResponse* const resp_;
+ Alarm alarm_;
+ std::mutex cancel_mu_;
+ std::condition_variable cancel_cv_;
+ bool initial_metadata_sent_ = false;
+ bool started_ = false;
+ bool on_cancel_invoked_ = false;
+ std::thread async_cancel_check_;
+ std::thread rpc_wait_thread_;
+ std::thread finish_when_cancelled_;
+ };
+
+ return new Reactor(this, context, request, response);
+}
+
+experimental::ServerUnaryReactor*
+CallbackTestServiceImpl::CheckClientInitialMetadata(
+ experimental::CallbackServerContext* context, const SimpleRequest42*,
+ SimpleResponse42*) {
+ class Reactor : public ::grpc::experimental::ServerUnaryReactor {
+ public:
+ explicit Reactor(experimental::CallbackServerContext* ctx) {
+ EXPECT_EQ(internal::MetadataMatchCount(ctx->client_metadata(),
+ kCheckClientInitialMetadataKey,
+ kCheckClientInitialMetadataVal),
+ 1);
+ EXPECT_EQ(ctx->client_metadata().count(kCheckClientInitialMetadataKey),
+ 1u);
+ Finish(Status::OK);
+ }
+ void OnDone() override { delete this; }
+ };
+
+ return new Reactor(context);
+}
+
+experimental::ServerReadReactor<EchoRequest>*
+CallbackTestServiceImpl::RequestStream(
+ experimental::CallbackServerContext* context, EchoResponse* response) {
+ // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
+ // the server by calling ServerContext::TryCancel() depending on the
+ // value:
+ // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
+ // reads any message from the client CANCEL_DURING_PROCESSING: The RPC
+ // is cancelled while the server is reading messages from the client
+ // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
+ // all the messages from the client
+ int server_try_cancel = internal::GetIntValueFromMetadata(
+ kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+ if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+ internal::ServerTryCancelNonblocking(context);
+ // Don't need to provide a reactor since the RPC is canceled
+ return nullptr;
+ }
+
+ class Reactor : public ::grpc::experimental::ServerReadReactor<EchoRequest> {
+ public:
+ Reactor(experimental::CallbackServerContext* ctx, EchoResponse* response,
+ int server_try_cancel)
+ : ctx_(ctx),
+ response_(response),
+ server_try_cancel_(server_try_cancel) {
+ EXPECT_NE(server_try_cancel, CANCEL_BEFORE_PROCESSING);
+ response->set_message("");
+
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ ctx->TryCancel();
+ // Don't wait for it here
+ }
+ StartRead(&request_);
+ setup_done_ = true;
+ }
+ void OnDone() override { delete this; }
+ void OnCancel() override {
+ EXPECT_TRUE(setup_done_);
+ EXPECT_TRUE(ctx_->IsCancelled());
+ FinishOnce(Status::CANCELLED);
+ }
+ void OnReadDone(bool ok) override {
+ if (ok) {
+ response_->mutable_message()->append(request_.message());
+ num_msgs_read_++;
+ StartRead(&request_);
+ } else {
+ gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read_);
+
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ // Let OnCancel recover this
+ return;
+ }
+ if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
+ internal::ServerTryCancelNonblocking(ctx_);
+ return;
+ }
+ FinishOnce(Status::OK);
+ }
+ }
+
+ private:
+ void FinishOnce(const Status& s) {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ Finish(s);
+ finished_ = true;
+ }
+ }
+
+ experimental::CallbackServerContext* const ctx_;
+ EchoResponse* const response_;
+ EchoRequest request_;
+ int num_msgs_read_{0};
+ int server_try_cancel_;
+ std::mutex finish_mu_;
+ bool finished_{false};
+ bool setup_done_{false};
+ };
+
+ return new Reactor(context, response, server_try_cancel);
+}
+
+// Return 'kNumResponseStreamMsgs' messages.
+// TODO(yangg) make it generic by adding a parameter into EchoRequest
+experimental::ServerWriteReactor<EchoResponse>*
+CallbackTestServiceImpl::ResponseStream(
+ experimental::CallbackServerContext* context, const EchoRequest* request) {
+ // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
+ // the server by calling ServerContext::TryCancel() depending on the
+ // value:
+ // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
+ // reads any message from the client CANCEL_DURING_PROCESSING: The RPC
+ // is cancelled while the server is reading messages from the client
+ // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
+ // all the messages from the client
+ int server_try_cancel = internal::GetIntValueFromMetadata(
+ kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+ if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+ internal::ServerTryCancelNonblocking(context);
+ }
+
+ class Reactor
+ : public ::grpc::experimental::ServerWriteReactor<EchoResponse> {
+ public:
+ Reactor(experimental::CallbackServerContext* ctx,
+ const EchoRequest* request, int server_try_cancel)
+ : ctx_(ctx), request_(request), server_try_cancel_(server_try_cancel) {
+ server_coalescing_api_ = internal::GetIntValueFromMetadata(
+ kServerUseCoalescingApi, ctx->client_metadata(), 0);
+ server_responses_to_send_ = internal::GetIntValueFromMetadata(
+ kServerResponseStreamsToSend, ctx->client_metadata(),
+ kServerDefaultResponseStreamsToSend);
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ ctx->TryCancel();
+ }
+ if (server_try_cancel_ != CANCEL_BEFORE_PROCESSING) {
+ if (num_msgs_sent_ < server_responses_to_send_) {
+ NextWrite();
+ }
+ }
+ setup_done_ = true;
+ }
+ void OnDone() override { delete this; }
+ void OnCancel() override {
+ EXPECT_TRUE(setup_done_);
+ EXPECT_TRUE(ctx_->IsCancelled());
+ FinishOnce(Status::CANCELLED);
+ }
+ void OnWriteDone(bool /*ok*/) override {
+ if (num_msgs_sent_ < server_responses_to_send_) {
+ NextWrite();
+ } else if (server_coalescing_api_ != 0) {
+ // We would have already done Finish just after the WriteLast
+ } else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ // Let OnCancel recover this
+ } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
+ internal::ServerTryCancelNonblocking(ctx_);
+ } else {
+ FinishOnce(Status::OK);
+ }
+ }
+
+ private:
+ void FinishOnce(const Status& s) {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ Finish(s);
+ finished_ = true;
+ }
+ }
+
+ void NextWrite() {
+ response_.set_message(request_->message() +
+ ::ToString(num_msgs_sent_));
+ if (num_msgs_sent_ == server_responses_to_send_ - 1 &&
+ server_coalescing_api_ != 0) {
+ {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ num_msgs_sent_++;
+ StartWriteLast(&response_, WriteOptions());
+ }
+ }
+ // If we use WriteLast, we shouldn't wait before attempting Finish
+ FinishOnce(Status::OK);
+ } else {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ num_msgs_sent_++;
+ StartWrite(&response_);
+ }
+ }
+ }
+ experimental::CallbackServerContext* const ctx_;
+ const EchoRequest* const request_;
+ EchoResponse response_;
+ int num_msgs_sent_{0};
+ int server_try_cancel_;
+ int server_coalescing_api_;
+ int server_responses_to_send_;
+ std::mutex finish_mu_;
+ bool finished_{false};
+ bool setup_done_{false};
+ };
+ return new Reactor(context, request, server_try_cancel);
+}
+
+experimental::ServerBidiReactor<EchoRequest, EchoResponse>*
+CallbackTestServiceImpl::BidiStream(
+ experimental::CallbackServerContext* context) {
+ class Reactor : public ::grpc::experimental::ServerBidiReactor<EchoRequest,
+ EchoResponse> {
+ public:
+ explicit Reactor(experimental::CallbackServerContext* ctx) : ctx_(ctx) {
+ // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
+ // the server by calling ServerContext::TryCancel() depending on the
+ // value:
+ // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
+ // reads any message from the client CANCEL_DURING_PROCESSING: The RPC
+ // is cancelled while the server is reading messages from the client
+ // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
+ // all the messages from the client
+ server_try_cancel_ = internal::GetIntValueFromMetadata(
+ kServerTryCancelRequest, ctx->client_metadata(), DO_NOT_CANCEL);
+ server_write_last_ = internal::GetIntValueFromMetadata(
+ kServerFinishAfterNReads, ctx->client_metadata(), 0);
+ if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
+ internal::ServerTryCancelNonblocking(ctx);
+ } else {
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ ctx->TryCancel();
+ }
+ StartRead(&request_);
+ }
+ setup_done_ = true;
+ }
+ void OnDone() override {
+ {
+ // Use the same lock as finish to make sure that OnDone isn't inlined.
+ std::lock_guard<std::mutex> l(finish_mu_);
+ EXPECT_TRUE(finished_);
+ finish_thread_.join();
+ }
+ delete this;
+ }
+ void OnCancel() override {
+ EXPECT_TRUE(setup_done_);
+ EXPECT_TRUE(ctx_->IsCancelled());
+ FinishOnce(Status::CANCELLED);
+ }
+ void OnReadDone(bool ok) override {
+ if (ok) {
+ num_msgs_read_++;
+ response_.set_message(request_.message());
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ if (num_msgs_read_ == server_write_last_) {
+ StartWriteLast(&response_, WriteOptions());
+ // If we use WriteLast, we shouldn't wait before attempting Finish
+ } else {
+ StartWrite(&response_);
+ return;
+ }
+ }
+ }
+
+ if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
+ // Let OnCancel handle this
+ } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
+ internal::ServerTryCancelNonblocking(ctx_);
+ } else {
+ FinishOnce(Status::OK);
+ }
+ }
+ void OnWriteDone(bool /*ok*/) override {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ StartRead(&request_);
+ }
+ }
+
+ private:
+ void FinishOnce(const Status& s) {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ if (!finished_) {
+ finished_ = true;
+ // Finish asynchronously to make sure that there are no deadlocks.
+ finish_thread_ = std::thread([this, s] {
+ std::lock_guard<std::mutex> l(finish_mu_);
+ Finish(s);
+ });
+ }
+ }
+
+ experimental::CallbackServerContext* const ctx_;
+ EchoRequest request_;
+ EchoResponse response_;
+ int num_msgs_read_{0};
+ int server_try_cancel_;
+ int server_write_last_;
+ std::mutex finish_mu_;
+ bool finished_{false};
+ bool setup_done_{false};
+ std::thread finish_thread_;
+ };
+
+ return new Reactor(context);
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/end2end/test_service_impl.h b/contrib/libs/grpc/test/cpp/end2end/test_service_impl.h
new file mode 100644
index 0000000000..5f207f1979
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/test_service_impl.h
@@ -0,0 +1,495 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_END2END_TEST_SERVICE_IMPL_H
+#define GRPC_TEST_CPP_END2END_TEST_SERVICE_IMPL_H
+
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpcpp/alarm.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/server_context.h>
+#include <gtest/gtest.h>
+
+#include <util/generic/string.h>
+#include <thread>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#include <util/string/cast.h>
+
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace testing {
+
+const int kServerDefaultResponseStreamsToSend = 3;
+const char* const kServerResponseStreamsToSend = "server_responses_to_send";
+const char* const kServerTryCancelRequest = "server_try_cancel";
+const char* const kDebugInfoTrailerKey = "debug-info-bin";
+const char* const kServerFinishAfterNReads = "server_finish_after_n_reads";
+const char* const kServerUseCoalescingApi = "server_use_coalescing_api";
+const char* const kCheckClientInitialMetadataKey = "custom_client_metadata";
+const char* const kCheckClientInitialMetadataVal = "Value for client metadata";
+
+typedef enum {
+ DO_NOT_CANCEL = 0,
+ CANCEL_BEFORE_PROCESSING,
+ CANCEL_DURING_PROCESSING,
+ CANCEL_AFTER_PROCESSING
+} ServerTryCancelRequestPhase;
+
+namespace internal {
+// When echo_deadline is requested, deadline seen in the ServerContext is set in
+// the response in seconds.
+void MaybeEchoDeadline(experimental::ServerContextBase* context,
+ const EchoRequest* request, EchoResponse* response);
+
+void CheckServerAuthContext(const experimental::ServerContextBase* context,
+ const TString& expected_transport_security_type,
+ const TString& expected_client_identity);
+
+// Returns the number of pairs in metadata that exactly match the given
+// key-value pair. Returns -1 if the pair wasn't found.
+int MetadataMatchCount(
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+ const TString& key, const TString& value);
+
+int GetIntValueFromMetadataHelper(
+ const char* key,
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+ int default_value);
+
+int GetIntValueFromMetadata(
+ const char* key,
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
+ int default_value);
+
+void ServerTryCancel(ServerContext* context);
+} // namespace internal
+
+class TestServiceSignaller {
+ public:
+ void ClientWaitUntilRpcStarted() {
+ std::unique_lock<std::mutex> lock(mu_);
+ cv_rpc_started_.wait(lock, [this] { return rpc_started_; });
+ }
+ void ServerWaitToContinue() {
+ std::unique_lock<std::mutex> lock(mu_);
+ cv_server_continue_.wait(lock, [this] { return server_should_continue_; });
+ }
+ void SignalClientThatRpcStarted() {
+ std::unique_lock<std::mutex> lock(mu_);
+ rpc_started_ = true;
+ cv_rpc_started_.notify_one();
+ }
+ void SignalServerToContinue() {
+ std::unique_lock<std::mutex> lock(mu_);
+ server_should_continue_ = true;
+ cv_server_continue_.notify_one();
+ }
+
+ private:
+ std::mutex mu_;
+ std::condition_variable cv_rpc_started_;
+ bool rpc_started_ /* GUARDED_BY(mu_) */ = false;
+ std::condition_variable cv_server_continue_;
+ bool server_should_continue_ /* GUARDED_BY(mu_) */ = false;
+};
+
+template <typename RpcService>
+class TestMultipleServiceImpl : public RpcService {
+ public:
+ TestMultipleServiceImpl() : signal_client_(false), host_() {}
+ explicit TestMultipleServiceImpl(const TString& host)
+ : signal_client_(false), host_(new TString(host)) {}
+
+ Status Echo(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) {
+ if (request->has_param() &&
+ request->param().server_notify_client_when_started()) {
+ signaller_.SignalClientThatRpcStarted();
+ signaller_.ServerWaitToContinue();
+ }
+
+ // A bit of sleep to make sure that short deadline tests fail
+ if (request->has_param() && request->param().server_sleep_us() > 0) {
+ gpr_sleep_until(
+ gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
+ gpr_time_from_micros(request->param().server_sleep_us(),
+ GPR_TIMESPAN)));
+ }
+
+ if (request->has_param() && request->param().server_die()) {
+ gpr_log(GPR_ERROR, "The request should not reach application handler.");
+ GPR_ASSERT(0);
+ }
+ if (request->has_param() && request->param().has_expected_error()) {
+ const auto& error = request->param().expected_error();
+ return Status(static_cast<StatusCode>(error.code()),
+ error.error_message(), error.binary_error_details());
+ }
+ int server_try_cancel = internal::GetIntValueFromMetadata(
+ kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+ if (server_try_cancel > DO_NOT_CANCEL) {
+ // Since this is a unary RPC, by the time this server handler is called,
+ // the 'request' message is already read from the client. So the scenarios
+ // in server_try_cancel don't make much sense. Just cancel the RPC as long
+ // as server_try_cancel is not DO_NOT_CANCEL
+ internal::ServerTryCancel(context);
+ return Status::CANCELLED;
+ }
+
+ response->set_message(request->message());
+ internal::MaybeEchoDeadline(context, request, response);
+ if (host_) {
+ response->mutable_param()->set_host(*host_);
+ }
+ if (request->has_param() && request->param().client_cancel_after_us()) {
+ {
+ std::unique_lock<std::mutex> lock(mu_);
+ signal_client_ = true;
+ ++rpcs_waiting_for_client_cancel_;
+ }
+ while (!context->IsCancelled()) {
+ gpr_sleep_until(gpr_time_add(
+ gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_micros(request->param().client_cancel_after_us(),
+ GPR_TIMESPAN)));
+ }
+ {
+ std::unique_lock<std::mutex> lock(mu_);
+ --rpcs_waiting_for_client_cancel_;
+ }
+ return Status::CANCELLED;
+ } else if (request->has_param() &&
+ request->param().server_cancel_after_us()) {
+ gpr_sleep_until(gpr_time_add(
+ gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_micros(request->param().server_cancel_after_us(),
+ GPR_TIMESPAN)));
+ return Status::CANCELLED;
+ } else if (!request->has_param() ||
+ !request->param().skip_cancelled_check()) {
+ EXPECT_FALSE(context->IsCancelled());
+ }
+
+ if (request->has_param() && request->param().echo_metadata_initially()) {
+ const std::multimap<grpc::string_ref, grpc::string_ref>& client_metadata =
+ context->client_metadata();
+ for (const auto& metadatum : client_metadata) {
+ context->AddInitialMetadata(::ToString(metadatum.first),
+ ::ToString(metadatum.second));
+ }
+ }
+
+ if (request->has_param() && request->param().echo_metadata()) {
+ const std::multimap<grpc::string_ref, grpc::string_ref>& client_metadata =
+ context->client_metadata();
+ for (const auto& metadatum : client_metadata) {
+ context->AddTrailingMetadata(::ToString(metadatum.first),
+ ::ToString(metadatum.second));
+ }
+ // Terminate rpc with error and debug info in trailer.
+ if (request->param().debug_info().stack_entries_size() ||
+ !request->param().debug_info().detail().empty()) {
+ TString serialized_debug_info =
+ request->param().debug_info().SerializeAsString();
+ context->AddTrailingMetadata(kDebugInfoTrailerKey,
+ serialized_debug_info);
+ return Status::CANCELLED;
+ }
+ }
+ if (request->has_param() &&
+ (request->param().expected_client_identity().length() > 0 ||
+ request->param().check_auth_context())) {
+ internal::CheckServerAuthContext(
+ context, request->param().expected_transport_security_type(),
+ request->param().expected_client_identity());
+ }
+ if (request->has_param() &&
+ request->param().response_message_length() > 0) {
+ response->set_message(
+ TString(request->param().response_message_length(), '\0'));
+ }
+ if (request->has_param() && request->param().echo_peer()) {
+ response->mutable_param()->set_peer(context->peer());
+ }
+ return Status::OK;
+ }
+
+ Status Echo1(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) {
+ return Echo(context, request, response);
+ }
+
+ Status Echo2(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) {
+ return Echo(context, request, response);
+ }
+
+ Status CheckClientInitialMetadata(ServerContext* context,
+ const SimpleRequest42* /*request*/,
+ SimpleResponse42* /*response*/) {
+ EXPECT_EQ(internal::MetadataMatchCount(context->client_metadata(),
+ kCheckClientInitialMetadataKey,
+ kCheckClientInitialMetadataVal),
+ 1);
+ EXPECT_EQ(1u,
+ context->client_metadata().count(kCheckClientInitialMetadataKey));
+ return Status::OK;
+ }
+
+ // Unimplemented is left unimplemented to test the returned error.
+
+ Status RequestStream(ServerContext* context,
+ ServerReader<EchoRequest>* reader,
+ EchoResponse* response) {
+ // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
+ // the server by calling ServerContext::TryCancel() depending on the value:
+ // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server reads
+ // any message from the client
+ // CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is
+ // reading messages from the client
+ // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
+ // all the messages from the client
+ int server_try_cancel = internal::GetIntValueFromMetadata(
+ kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+
+ EchoRequest request;
+ response->set_message("");
+
+ if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+ internal::ServerTryCancel(context);
+ return Status::CANCELLED;
+ }
+
+ std::thread* server_try_cancel_thd = nullptr;
+ if (server_try_cancel == CANCEL_DURING_PROCESSING) {
+ server_try_cancel_thd =
+ new std::thread([context] { internal::ServerTryCancel(context); });
+ }
+
+ int num_msgs_read = 0;
+ while (reader->Read(&request)) {
+ response->mutable_message()->append(request.message());
+ }
+ gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read);
+
+ if (server_try_cancel_thd != nullptr) {
+ server_try_cancel_thd->join();
+ delete server_try_cancel_thd;
+ return Status::CANCELLED;
+ }
+
+ if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
+ internal::ServerTryCancel(context);
+ return Status::CANCELLED;
+ }
+
+ return Status::OK;
+ }
+
+  // Server-streaming echo: writes 'server_responses_to_send' responses (from
+  // metadata, default kServerDefaultResponseStreamsToSend), each being the
+  // request message with the response index appended.
+  // TODO(yangg) make it generic by adding a parameter into EchoRequest
+  Status ResponseStream(ServerContext* context, const EchoRequest* request,
+                        ServerWriter<EchoResponse>* writer) {
+    // If server_try_cancel is set in the metadata, the RPC is cancelled by the
+    // server by calling ServerContext::TryCancel() depending on the value:
+    //   CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server writes
+    //   any messages to the client
+    //   CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is
+    //   writing messages to the client
+    //   CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server writes
+    //   all the messages to the client
+    int server_try_cancel = internal::GetIntValueFromMetadata(
+        kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+
+    // Non-zero => coalesce the final message with the status via WriteLast().
+    int server_coalescing_api = internal::GetIntValueFromMetadata(
+        kServerUseCoalescingApi, context->client_metadata(), 0);
+
+    int server_responses_to_send = internal::GetIntValueFromMetadata(
+        kServerResponseStreamsToSend, context->client_metadata(),
+        kServerDefaultResponseStreamsToSend);
+
+    if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+      internal::ServerTryCancel(context);
+      return Status::CANCELLED;
+    }
+
+    EchoResponse response;
+    // Cancel from a separate thread while this thread keeps writing.
+    std::thread* server_try_cancel_thd = nullptr;
+    if (server_try_cancel == CANCEL_DURING_PROCESSING) {
+      server_try_cancel_thd =
+          new std::thread([context] { internal::ServerTryCancel(context); });
+    }
+
+    for (int i = 0; i < server_responses_to_send; i++) {
+      response.set_message(request->message() + ::ToString(i));
+      if (i == server_responses_to_send - 1 && server_coalescing_api != 0) {
+        writer->WriteLast(response, WriteOptions());
+      } else {
+        writer->Write(response);
+      }
+    }
+
+    if (server_try_cancel_thd != nullptr) {
+      server_try_cancel_thd->join();
+      delete server_try_cancel_thd;
+      return Status::CANCELLED;
+    }
+
+    if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
+      internal::ServerTryCancel(context);
+      return Status::CANCELLED;
+    }
+
+    return Status::OK;
+  }
+
+  // Bidirectional streaming echo: echoes each client message back as it is
+  // read. Cancellation behavior is driven by request metadata (see below).
+  Status BidiStream(ServerContext* context,
+                    ServerReaderWriter<EchoResponse, EchoRequest>* stream) {
+    // If server_try_cancel is set in the metadata, the RPC is cancelled by the
+    // server by calling ServerContext::TryCancel() depending on the value:
+    //   CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server reads/
+    //   writes any messages from/to the client
+    //   CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is
+    //   reading/writing messages from/to the client
+    //   CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server
+    //   reads/writes all messages from/to the client
+    int server_try_cancel = internal::GetIntValueFromMetadata(
+        kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
+
+    EchoRequest request;
+    EchoResponse response;
+
+    if (server_try_cancel == CANCEL_BEFORE_PROCESSING) {
+      internal::ServerTryCancel(context);
+      return Status::CANCELLED;
+    }
+
+    // Cancel from a separate thread while this thread keeps echoing.
+    std::thread* server_try_cancel_thd = nullptr;
+    if (server_try_cancel == CANCEL_DURING_PROCESSING) {
+      server_try_cancel_thd =
+          new std::thread([context] { internal::ServerTryCancel(context); });
+    }
+
+    // kServerFinishAfterNReads suggests after how many reads, the server should
+    // write the last message and send status (coalesced using WriteLast)
+    int server_write_last = internal::GetIntValueFromMetadata(
+        kServerFinishAfterNReads, context->client_metadata(), 0);
+
+    int read_counts = 0;
+    while (stream->Read(&request)) {
+      read_counts++;
+      gpr_log(GPR_INFO, "recv msg %s", request.message().c_str());
+      response.set_message(request.message());
+      if (read_counts == server_write_last) {
+        stream->WriteLast(response, WriteOptions());
+      } else {
+        stream->Write(response);
+      }
+    }
+
+    if (server_try_cancel_thd != nullptr) {
+      server_try_cancel_thd->join();
+      delete server_try_cancel_thd;
+      return Status::CANCELLED;
+    }
+
+    if (server_try_cancel == CANCEL_AFTER_PROCESSING) {
+      internal::ServerTryCancel(context);
+      return Status::CANCELLED;
+    }
+
+    return Status::OK;
+  }
+
+  // Test-observation accessors (the "Unimplemented" RPC itself is deliberately
+  // left unimplemented elsewhere so clients can test the generated error).
+  // Returns true once an RPC handler has flagged the client (under mu_).
+  bool signal_client() {
+    std::unique_lock<std::mutex> lock(mu_);
+    return signal_client_;
+  }
+  // Blocks the test body until a server handler reports the RPC has started.
+  void ClientWaitUntilRpcStarted() { signaller_.ClientWaitUntilRpcStarted(); }
+  // Releases a server handler that is parked waiting for the test body.
+  void SignalServerToContinue() { signaller_.SignalServerToContinue(); }
+  // Number of in-flight RPCs currently waiting for client-side cancellation.
+  uint64_t RpcsWaitingForClientCancel() {
+    std::unique_lock<std::mutex> lock(mu_);
+    return rpcs_waiting_for_client_cancel_;
+  }
+
+ private:
+ bool signal_client_;
+ std::mutex mu_;
+ TestServiceSignaller signaller_;
+ std::unique_ptr<TString> host_;
+ uint64_t rpcs_waiting_for_client_cancel_ = 0;
+};
+
+// Callback-API flavor of the echo test service. Method bodies are defined in
+// the corresponding .cc file; this declares the reactor-returning overrides.
+class CallbackTestServiceImpl
+    : public ::grpc::testing::EchoTestService::ExperimentalCallbackService {
+ public:
+  CallbackTestServiceImpl() : signal_client_(false), host_() {}
+  explicit CallbackTestServiceImpl(const TString& host)
+      : signal_client_(false), host_(new TString(host)) {}
+
+  experimental::ServerUnaryReactor* Echo(
+      experimental::CallbackServerContext* context, const EchoRequest* request,
+      EchoResponse* response) override;
+
+  // BUGFIX: parameter types are SimpleRequest/SimpleResponse; the garbled
+  // names "SimpleRequest42"/"SimpleResponse42" do not exist in the echo proto,
+  // and an override with mismatched parameter types cannot match the
+  // generated base-class virtual.
+  experimental::ServerUnaryReactor* CheckClientInitialMetadata(
+      experimental::CallbackServerContext* context, const SimpleRequest*,
+      SimpleResponse*) override;
+
+  experimental::ServerReadReactor<EchoRequest>* RequestStream(
+      experimental::CallbackServerContext* context,
+      EchoResponse* response) override;
+
+  experimental::ServerWriteReactor<EchoResponse>* ResponseStream(
+      experimental::CallbackServerContext* context,
+      const EchoRequest* request) override;
+
+  experimental::ServerBidiReactor<EchoRequest, EchoResponse>* BidiStream(
+      experimental::CallbackServerContext* context) override;
+
+  // Unimplemented is left unimplemented to test the returned error.
+  // Returns true once an RPC handler has flagged the client (under mu_).
+  bool signal_client() {
+    std::unique_lock<std::mutex> lock(mu_);
+    return signal_client_;
+  }
+  void ClientWaitUntilRpcStarted() { signaller_.ClientWaitUntilRpcStarted(); }
+  void SignalServerToContinue() { signaller_.SignalServerToContinue(); }
+
+ private:
+  bool signal_client_;
+  std::mutex mu_;
+  TestServiceSignaller signaller_;
+  std::unique_ptr<TString> host_;
+};
+
+using TestServiceImpl =
+ TestMultipleServiceImpl<::grpc::testing::EchoTestService::Service>;
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_END2END_TEST_SERVICE_IMPL_H
diff --git a/contrib/libs/grpc/test/cpp/end2end/thread/ya.make_ b/contrib/libs/grpc/test/cpp/end2end/thread/ya.make_
new file mode 100644
index 0000000000..afabda1c8f
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/thread/ya.make_
@@ -0,0 +1,31 @@
+GTEST_UGLY()
+
+OWNER(
+ dvshkurko
+ g:ymake
+)
+
+ADDINCL(
+ ${ARCADIA_ROOT}/contrib/libs/grpc
+)
+
+PEERDIR(
+ contrib/libs/grpc/src/proto/grpc/core
+ contrib/libs/grpc/src/proto/grpc/testing
+ contrib/libs/grpc/src/proto/grpc/testing/duplicate
+ contrib/libs/grpc/test/core/util
+ contrib/libs/grpc/test/cpp/end2end
+ contrib/libs/grpc/test/cpp/util
+)
+
+NO_COMPILER_WARNINGS()
+
+SRCDIR(
+ contrib/libs/grpc/test/cpp/end2end
+)
+
+SRCS(
+ thread_stress_test.cc
+)
+
+END()
diff --git a/contrib/libs/grpc/test/cpp/end2end/thread_stress_test.cc b/contrib/libs/grpc/test/cpp/end2end/thread_stress_test.cc
new file mode 100644
index 0000000000..8acb953729
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/thread_stress_test.cc
@@ -0,0 +1,442 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <cinttypes>
+#include <mutex>
+#include <thread>
+
+#include <grpc/grpc.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/impl/codegen/sync.h>
+#include <grpcpp/resource_quota.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/surface/api_trace.h"
+#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+
+#include <gtest/gtest.h>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+using std::chrono::system_clock;
+
+const int kNumThreads = 100; // Number of threads
+const int kNumAsyncSendThreads = 2;
+const int kNumAsyncReceiveThreads = 50;
+const int kNumAsyncServerThreads = 50;
+const int kNumRpcs = 1000; // Number of RPCs per thread
+
+namespace grpc {
+namespace testing {
+
+// Minimal synchronous echo service used as the workload for the stress tests:
+// Echo simply copies the request message into the response.
+class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {
+ public:
+  TestServiceImpl() {}
+
+  Status Echo(ServerContext* /*context*/, const EchoRequest* request,
+              EchoResponse* response) override {
+    response->set_message(request->message());
+    return Status::OK;
+  }
+};
+
+// Base fixture for the stress tests: owns the server and the stub, and
+// defines the SetUp/TearDown skeleton that concrete transport/server
+// variants (insecure TCP, in-process, sync, async) fill in.
+template <class Service>
+class CommonStressTest {
+ public:
+  CommonStressTest() : kMaxMessageSize_(8192) {
+#if TARGET_OS_IPHONE
+    // Workaround Apple CFStream bug
+    gpr_setenv("grpc_cfstream", "0");
+#endif
+  }
+  virtual ~CommonStressTest() {}
+  virtual void SetUp() = 0;
+  virtual void TearDown() = 0;
+  virtual void ResetStub() = 0;
+  // True when RESOURCE_EXHAUSTED RPC failures are acceptable for the variant.
+  virtual bool AllowExhaustion() = 0;
+  grpc::testing::EchoTestService::Stub* GetStub() { return stub_.get(); }
+
+ protected:
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+  std::unique_ptr<Server> server_;
+
+  // Variant-specific server setup (listening port vs in-process channel).
+  virtual void SetUpStart(ServerBuilder* builder, Service* service) = 0;
+  void SetUpStartCommon(ServerBuilder* builder, Service* service) {
+    builder->RegisterService(service);
+    builder->SetMaxMessageSize(
+        kMaxMessageSize_);  // For testing max message size.
+  }
+  void SetUpEnd(ServerBuilder* builder) { server_ = builder->BuildAndStart(); }
+  void TearDownStart() { server_->Shutdown(); }
+  void TearDownEnd() {}
+
+ private:
+  const int kMaxMessageSize_;
+};
+
+// Variant that talks to the server over a real (insecure) TCP channel.
+template <class Service>
+class CommonStressTestInsecure : public CommonStressTest<Service> {
+ public:
+  void ResetStub() override {
+    std::shared_ptr<Channel> channel = grpc::CreateChannel(
+        server_address_.str(), InsecureChannelCredentials());
+    this->stub_ = grpc::testing::EchoTestService::NewStub(channel);
+  }
+  bool AllowExhaustion() override { return false; }
+
+ protected:
+  void SetUpStart(ServerBuilder* builder, Service* service) override {
+    // NOTE(review): upstream gRPC calls grpc_pick_unused_port_or_die() here;
+    // the hard-coded port 5003 can collide if tests run concurrently on one
+    // host — presumably intentional for this build environment, but confirm.
+    int port = 5003;  // grpc_pick_unused_port_or_die();
+    this->server_address_ << "localhost:" << port;
+    // Setup server
+    builder->AddListeningPort(server_address_.str(),
+                              InsecureServerCredentials());
+    this->SetUpStartCommon(builder, service);
+  }
+
+ private:
+  std::ostringstream server_address_;
+};
+
+// Variant that uses the server's in-process channel (no sockets). The
+// template flag controls whether RESOURCE_EXHAUSTED failures are tolerated
+// (used by the low-thread-count resource-quota test).
+template <class Service, bool allow_resource_exhaustion>
+class CommonStressTestInproc : public CommonStressTest<Service> {
+ public:
+  void ResetStub() override {
+    ChannelArguments args;
+    std::shared_ptr<Channel> channel = this->server_->InProcessChannel(args);
+    this->stub_ = grpc::testing::EchoTestService::NewStub(channel);
+  }
+  bool AllowExhaustion() override { return allow_resource_exhaustion; }
+
+ protected:
+  void SetUpStart(ServerBuilder* builder, Service* service) override {
+    this->SetUpStartCommon(builder, service);
+  }
+};
+
+// Variant that runs the synchronous server with default thread settings.
+template <class BaseClass>
+class CommonStressTestSyncServer : public BaseClass {
+ public:
+  void SetUp() override {
+    ServerBuilder builder;
+    this->SetUpStart(&builder, &service_);
+    this->SetUpEnd(&builder);
+  }
+  void TearDown() override {
+    this->TearDownStart();
+    this->TearDownEnd();
+  }
+
+ private:
+  TestServiceImpl service_;
+};
+
+// Variant that caps the server's resource quota at 4 threads so the stress
+// load is expected to trigger RESOURCE_EXHAUSTED failures.
+template <class BaseClass>
+class CommonStressTestSyncServerLowThreadCount : public BaseClass {
+ public:
+  void SetUp() override {
+    ServerBuilder builder;
+    ResourceQuota quota;
+    this->SetUpStart(&builder, &service_);
+    quota.SetMaxThreads(4);
+    builder.SetResourceQuota(quota);
+    this->SetUpEnd(&builder);
+  }
+  void TearDown() override {
+    this->TearDownStart();
+    this->TearDownEnd();
+  }
+
+ private:
+  TestServiceImpl service_;
+};
+
+// Variant that runs the completion-queue-based async server. A fixed pool of
+// per-call Contexts (100 per server thread) is cycled: each context is
+// registered with RequestEcho, serviced in ProcessRpcs, and re-armed by
+// RefreshContext. The CQ tag encodes the context's index in the vector.
+template <class BaseClass>
+class CommonStressTestAsyncServer : public BaseClass {
+ public:
+  CommonStressTestAsyncServer() : contexts_(kNumAsyncServerThreads * 100) {}
+  void SetUp() override {
+    shutting_down_ = false;
+    ServerBuilder builder;
+    this->SetUpStart(&builder, &service_);
+    cq_ = builder.AddCompletionQueue();
+    this->SetUpEnd(&builder);
+    for (int i = 0; i < kNumAsyncServerThreads * 100; i++) {
+      RefreshContext(i);
+    }
+    for (int i = 0; i < kNumAsyncServerThreads; i++) {
+      server_threads_.emplace_back(&CommonStressTestAsyncServer::ProcessRpcs,
+                                   this);
+    }
+  }
+  void TearDown() override {
+    {
+      // Shut down under the lock so RefreshContext can't re-arm a call after
+      // shutting_down_ is set.
+      grpc::internal::MutexLock l(&mu_);
+      this->TearDownStart();
+      shutting_down_ = true;
+      cq_->Shutdown();
+    }
+
+    for (int i = 0; i < kNumAsyncServerThreads; i++) {
+      server_threads_[i].join();
+    }
+
+    // Drain whatever events remain after shutdown.
+    void* ignored_tag;
+    bool ignored_ok;
+    while (cq_->Next(&ignored_tag, &ignored_ok))
+      ;
+    this->TearDownEnd();
+  }
+
+ private:
+  // Server worker loop: for each CQ event, either finish a ready call or
+  // re-arm a completed one.
+  void ProcessRpcs() {
+    void* tag;
+    bool ok;
+    while (cq_->Next(&tag, &ok)) {
+      if (ok) {
+        int i = static_cast<int>(reinterpret_cast<intptr_t>(tag));
+        switch (contexts_[i].state) {
+          case Context::READY: {
+            contexts_[i].state = Context::DONE;
+            EchoResponse send_response;
+            send_response.set_message(contexts_[i].recv_request.message());
+            contexts_[i].response_writer->Finish(send_response, Status::OK,
+                                                 tag);
+            break;
+          }
+          case Context::DONE:
+            RefreshContext(i);
+            break;
+        }
+      }
+    }
+  }
+  // Re-registers context slot i for a new incoming Echo call (no-op once
+  // shutdown has begun).
+  void RefreshContext(int i) {
+    grpc::internal::MutexLock l(&mu_);
+    if (!shutting_down_) {
+      contexts_[i].state = Context::READY;
+      contexts_[i].srv_ctx.reset(new ServerContext);
+      contexts_[i].response_writer.reset(
+          new grpc::ServerAsyncResponseWriter<EchoResponse>(
+              contexts_[i].srv_ctx.get()));
+      service_.RequestEcho(contexts_[i].srv_ctx.get(),
+                           &contexts_[i].recv_request,
+                           contexts_[i].response_writer.get(), cq_.get(),
+                           cq_.get(), (void*)static_cast<intptr_t>(i));
+    }
+  }
+  // Per-call state; `state` tracks whether the call is awaiting its first
+  // event (READY) or its Finish completion (DONE).
+  struct Context {
+    std::unique_ptr<ServerContext> srv_ctx;
+    std::unique_ptr<grpc::ServerAsyncResponseWriter<EchoResponse>>
+        response_writer;
+    EchoRequest recv_request;
+    enum { READY, DONE } state;
+  };
+  std::vector<Context> contexts_;
+  ::grpc::testing::EchoTestService::AsyncService service_;
+  std::unique_ptr<ServerCompletionQueue> cq_;
+  bool shutting_down_;
+  grpc::internal::Mutex mu_;
+  std::vector<std::thread> server_threads_;
+};
+
+// gtest fixture that forwards SetUp/TearDown to the chosen CommonStressTest
+// variant (selected via the typed-test parameter).
+template <class Common>
+class End2endTest : public ::testing::Test {
+ protected:
+  End2endTest() {}
+  void SetUp() override { common_.SetUp(); }
+  void TearDown() override { common_.TearDown(); }
+  void ResetStub() { common_.ResetStub(); }
+
+  Common common_;
+};
+
+// Worker-thread body: issues num_rpcs synchronous Echo calls and atomically
+// counts failures in *errors. RESOURCE_EXHAUSTED is tolerated (counted but
+// not logged) when allow_exhaustion is set.
+static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs,
+                    bool allow_exhaustion, gpr_atm* errors) {
+  EchoRequest request;
+  EchoResponse response;
+  request.set_message("Hello");
+
+  for (int i = 0; i < num_rpcs; ++i) {
+    ClientContext context;
+    Status s = stub->Echo(&context, request, &response);
+    EXPECT_TRUE(s.ok() || (allow_exhaustion &&
+                           s.error_code() == StatusCode::RESOURCE_EXHAUSTED));
+    if (!s.ok()) {
+      if (!(allow_exhaustion &&
+            s.error_code() == StatusCode::RESOURCE_EXHAUSTED)) {
+        gpr_log(GPR_ERROR, "RPC error: %d: %s", s.error_code(),
+                s.error_message().c_str());
+      }
+      gpr_atm_no_barrier_fetch_add(errors, static_cast<gpr_atm>(1));
+    } else {
+      EXPECT_EQ(response.message(), request.message());
+    }
+  }
+}
+
+// The transport/server combinations exercised by the typed stress tests.
+typedef ::testing::Types<
+    CommonStressTestSyncServer<CommonStressTestInsecure<TestServiceImpl>>,
+    CommonStressTestSyncServer<CommonStressTestInproc<TestServiceImpl, false>>,
+    CommonStressTestSyncServerLowThreadCount<
+        CommonStressTestInproc<TestServiceImpl, true>>,
+    CommonStressTestAsyncServer<
+        CommonStressTestInsecure<grpc::testing::EchoTestService::AsyncService>>,
+    CommonStressTestAsyncServer<CommonStressTestInproc<
+        grpc::testing::EchoTestService::AsyncService, false>>>
+    CommonTypes;
+TYPED_TEST_SUITE(End2endTest, CommonTypes);
+// Hammers the server with kNumThreads * kNumRpcs synchronous echoes and
+// checks the error count is consistent with the variant's exhaustion policy.
+TYPED_TEST(End2endTest, ThreadStress) {
+  this->common_.ResetStub();
+  std::vector<std::thread> threads;
+  gpr_atm errors;
+  gpr_atm_rel_store(&errors, static_cast<gpr_atm>(0));
+  threads.reserve(kNumThreads);
+  for (int i = 0; i < kNumThreads; ++i) {
+    threads.emplace_back(SendRpc, this->common_.GetStub(), kNumRpcs,
+                         this->common_.AllowExhaustion(), &errors);
+  }
+  for (int i = 0; i < kNumThreads; ++i) {
+    threads[i].join();
+  }
+  uint64_t error_cnt = static_cast<uint64_t>(gpr_atm_no_barrier_load(&errors));
+  if (error_cnt != 0) {
+    gpr_log(GPR_INFO, "RPC error count: %" PRIu64, error_cnt);
+  }
+  // If this test allows resource exhaustion, expect that it actually sees some
+  if (this->common_.AllowExhaustion()) {
+    EXPECT_GT(error_cnt, static_cast<uint64_t>(0));
+  }
+}
+
+// Fixture for async-client stress: sender threads fire AsyncEcho calls while
+// receiver threads drain the shared completion queue; rpcs_outstanding_ (under
+// mu_/cv_) tracks in-flight calls so Wait() knows when to shut the CQ down.
+template <class Common>
+class AsyncClientEnd2endTest : public ::testing::Test {
+ protected:
+  AsyncClientEnd2endTest() : rpcs_outstanding_(0) {}
+
+  void SetUp() override { common_.SetUp(); }
+  void TearDown() override {
+    // Drain any leftover CQ events before tearing the server down.
+    void* ignored_tag;
+    bool ignored_ok;
+    while (cq_.Next(&ignored_tag, &ignored_ok))
+      ;
+    common_.TearDown();
+  }
+
+  // Blocks until all outstanding RPCs complete, then shuts down the CQ so the
+  // receiver threads' Next() loops terminate.
+  void Wait() {
+    grpc::internal::MutexLock l(&mu_);
+    while (rpcs_outstanding_ != 0) {
+      cv_.Wait(&mu_);
+    }
+
+    cq_.Shutdown();
+  }
+
+  // Per-call state; the heap pointer itself is used as the CQ tag and is
+  // deleted by AsyncCompleteRpc.
+  struct AsyncClientCall {
+    EchoResponse response;
+    ClientContext context;
+    Status status;
+    std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader;
+  };
+
+  // Sender-thread body: starts num_rpcs async echoes.
+  void AsyncSendRpc(int num_rpcs) {
+    for (int i = 0; i < num_rpcs; ++i) {
+      AsyncClientCall* call = new AsyncClientCall;
+      EchoRequest request;
+      request.set_message(TString("Hello: " + grpc::to_string(i)).c_str());
+      call->response_reader =
+          common_.GetStub()->AsyncEcho(&call->context, request, &cq_);
+      call->response_reader->Finish(&call->response, &call->status,
+                                    (void*)call);
+
+      grpc::internal::MutexLock l(&mu_);
+      rpcs_outstanding_++;
+    }
+  }
+
+  // Receiver-thread body: reaps completions until the CQ is shut down,
+  // signalling Wait() when the last outstanding RPC finishes.
+  void AsyncCompleteRpc() {
+    while (true) {
+      void* got_tag;
+      bool ok = false;
+      if (!cq_.Next(&got_tag, &ok)) break;
+      AsyncClientCall* call = static_cast<AsyncClientCall*>(got_tag);
+      if (!ok) {
+        gpr_log(GPR_DEBUG, "Error: %d", call->status.error_code());
+      }
+      delete call;
+
+      bool notify;
+      {
+        grpc::internal::MutexLock l(&mu_);
+        rpcs_outstanding_--;
+        notify = (rpcs_outstanding_ == 0);
+      }
+      if (notify) {
+        cv_.Signal();
+      }
+    }
+  }
+
+  Common common_;
+  CompletionQueue cq_;
+  grpc::internal::Mutex mu_;
+  grpc::internal::CondVar cv_;
+  int rpcs_outstanding_;
+};
+
+TYPED_TEST_SUITE(AsyncClientEnd2endTest, CommonTypes);
+// Runs kNumAsyncSendThreads senders against kNumAsyncReceiveThreads
+// completion reapers; Wait() is only called after senders finish so the
+// outstanding count can no longer grow.
+TYPED_TEST(AsyncClientEnd2endTest, ThreadStress) {
+  this->common_.ResetStub();
+  std::vector<std::thread> send_threads, completion_threads;
+  for (int i = 0; i < kNumAsyncReceiveThreads; ++i) {
+    completion_threads.emplace_back(
+        &AsyncClientEnd2endTest_ThreadStress_Test<TypeParam>::AsyncCompleteRpc,
+        this);
+  }
+  for (int i = 0; i < kNumAsyncSendThreads; ++i) {
+    send_threads.emplace_back(
+        &AsyncClientEnd2endTest_ThreadStress_Test<TypeParam>::AsyncSendRpc,
+        this, kNumRpcs);
+  }
+  for (int i = 0; i < kNumAsyncSendThreads; ++i) {
+    send_threads[i].join();
+  }
+
+  this->Wait();
+  for (int i = 0; i < kNumAsyncReceiveThreads; ++i) {
+    completion_threads[i].join();
+  }
+}
+
+} // namespace testing
+} // namespace grpc
+
+// Standard gRPC test entry point: set up the test environment, then run gtest.
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/time_change_test.cc b/contrib/libs/grpc/test/cpp/end2end/time_change_test.cc
new file mode 100644
index 0000000000..48b9eace12
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/time_change_test.cc
@@ -0,0 +1,367 @@
+/*
+ *
+ * Copyright 2019 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+
+#include "src/core/lib/iomgr/timer.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+#include "test/cpp/util/subprocess.h"
+
+#include <gtest/gtest.h>
+#include <sys/time.h>
+#include <thread>
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+static TString g_root;
+
+static gpr_mu g_mu;
+extern gpr_timespec (*gpr_now_impl)(gpr_clock_type clock_type);
+gpr_timespec (*gpr_now_impl_orig)(gpr_clock_type clock_type) = gpr_now_impl;
+static int g_time_shift_sec = 0;
+static int g_time_shift_nsec = 0;
+// Replacement for gpr_now(): applies the test-controlled offset to the
+// realtime clock only (monotonic/precise clocks pass through), then
+// renormalizes tv_nsec into [0, GPR_NS_PER_SEC).
+static gpr_timespec now_impl(gpr_clock_type clock) {
+  auto ts = gpr_now_impl_orig(clock);
+  // We only manipulate the realtime clock to simulate changes in wall-clock
+  // time
+  if (clock != GPR_CLOCK_REALTIME) {
+    return ts;
+  }
+  GPR_ASSERT(ts.tv_nsec >= 0);
+  GPR_ASSERT(ts.tv_nsec < GPR_NS_PER_SEC);
+  // Read the shift under the lock; normalization below uses local values only.
+  gpr_mu_lock(&g_mu);
+  ts.tv_sec += g_time_shift_sec;
+  ts.tv_nsec += g_time_shift_nsec;
+  gpr_mu_unlock(&g_mu);
+  if (ts.tv_nsec >= GPR_NS_PER_SEC) {
+    ts.tv_nsec -= GPR_NS_PER_SEC;
+    ++ts.tv_sec;
+  } else if (ts.tv_nsec < 0) {
+    --ts.tv_sec;
+    ts.tv_nsec = GPR_NS_PER_SEC + ts.tv_nsec;
+  }
+  return ts;
+}
+
+// offset the value returned by gpr_now(GPR_CLOCK_REALTIME) by msecs
+// milliseconds (may be negative to simulate the clock jumping backwards)
+static void set_now_offset(int msecs) {
+  gpr_mu_lock(&g_mu);
+  g_time_shift_sec = msecs / 1000;
+  g_time_shift_nsec = (msecs % 1000) * 1e6;
+  gpr_mu_unlock(&g_mu);
+}
+
+// zero the time shift so gpr_now(GPR_CLOCK_REALTIME) reports true wall-clock
+// time again (note: this does not restore gpr_now_impl itself — now_impl
+// stays installed; it just applies a zero offset)
+static void reset_now_offset() {
+  gpr_mu_lock(&g_mu);
+  g_time_shift_sec = 0;
+  g_time_shift_nsec = 0;
+  gpr_mu_unlock(&g_mu);
+}
+
+namespace grpc {
+namespace testing {
+
+namespace {
+
+// gpr_now() is called with an invalid clock_type: expect the process to abort
+// (death test).
+TEST(TimespecTest, GprNowInvalidClockType) {
+  // initialize to some junk value
+  gpr_clock_type invalid_clock_type = (gpr_clock_type)32641;
+  EXPECT_DEATH(gpr_now(invalid_clock_type), ".*");
+}
+
+// Adding a timespan with negative nanoseconds violates the gpr_timespec
+// invariant (tv_nsec >= 0) and must abort (death test).
+TEST(TimespecTest, GprTimeAddNegativeNs) {
+  gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC);
+  gpr_timespec bad_ts = {1, -1000, GPR_TIMESPAN};
+  EXPECT_DEATH(gpr_time_add(now, bad_ts), ".*");
+}
+
+// Subtracting a timespan with negative nanoseconds must abort (death test).
+TEST(TimespecTest, GprTimeSubNegativeNs) {
+  // Nanoseconds must always be positive. Negative timestamps are represented by
+  // (negative seconds, positive nanoseconds)
+  gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC);
+  gpr_timespec bad_ts = {1, -1000, GPR_TIMESPAN};
+  EXPECT_DEATH(gpr_time_sub(now, bad_ts), ".*");
+}
+
+// Converting negative milliseconds to gpr_timespec keeps tv_nsec positive:
+// -1500 ms becomes (-2 s, 5 * 10^8 ns).
+TEST(TimespecTest, GrpcNegativeMillisToTimespec) {
+  gpr_timespec ts = grpc_millis_to_timespec(-1500, GPR_CLOCK_MONOTONIC);
+  // BUGFIX: these asserts used '=' (assignment), which overwrote the fields
+  // with the expected values and could therefore never fail. Use '=='.
+  GPR_ASSERT(ts.tv_sec == -2);
+  GPR_ASSERT(ts.tv_nsec == 5e8);
+  GPR_ASSERT(ts.clock_type == GPR_CLOCK_MONOTONIC);
+}
+
+// Fixture for the wall-clock-jump tests: the echo server runs in a child
+// process (SubProcess) so only the client's view of time is manipulated via
+// set_now_offset(); each test gets a fresh channel/stub.
+class TimeChangeTest : public ::testing::Test {
+ protected:
+  TimeChangeTest() {}
+
+  static void SetUpTestCase() {
+    auto port = grpc_pick_unused_port_or_die();
+    std::ostringstream addr_stream;
+    addr_stream << "localhost:" << port;
+    server_address_ = addr_stream.str();
+    // Launch the server binary that lives next to this test binary (g_root).
+    server_.reset(new SubProcess({
+        g_root + "/client_crash_test_server",
+        "--address=" + server_address_,
+    }));
+    GPR_ASSERT(server_);
+    // connect to server and make sure it's reachable.
+    auto channel =
+        grpc::CreateChannel(server_address_, InsecureChannelCredentials());
+    GPR_ASSERT(channel);
+    EXPECT_TRUE(channel->WaitForConnected(
+        grpc_timeout_milliseconds_to_deadline(30000)));
+  }
+
+  static void TearDownTestCase() { server_.reset(); }
+
+  void SetUp() {
+    channel_ =
+        grpc::CreateChannel(server_address_, InsecureChannelCredentials());
+    GPR_ASSERT(channel_);
+    stub_ = grpc::testing::EchoTestService::NewStub(channel_);
+  }
+
+  // Undo any clock shift a test left behind so tests stay independent.
+  void TearDown() { reset_now_offset(); }
+
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> CreateStub() {
+    return grpc::testing::EchoTestService::NewStub(channel_);
+  }
+
+  std::shared_ptr<Channel> GetChannel() { return channel_; }
+  // time jump offsets in milliseconds
+  const int TIME_OFFSET1 = 20123;
+  const int TIME_OFFSET2 = 5678;
+
+ private:
+  static TString server_address_;
+  static std::unique_ptr<SubProcess> server_;
+  std::shared_ptr<Channel> channel_;
+  std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+};
+TString TimeChangeTest::server_address_;
+std::unique_ptr<SubProcess> TimeChangeTest::server_;
+
+// Wall-clock time jumps forward on client before bidi stream is created;
+// the RPC should still complete (deadlines track the monotonic clock).
+TEST_F(TimeChangeTest, TimeJumpForwardBeforeStreamCreated) {
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000));
+  context.AddMetadata(kServerResponseStreamsToSend, "1");
+
+  auto channel = GetChannel();
+  GPR_ASSERT(channel);
+  EXPECT_TRUE(
+      channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000)));
+  auto stub = CreateStub();
+
+  // time jumps forward by TIME_OFFSET1 milliseconds
+  set_now_offset(TIME_OFFSET1);
+  auto stream = stub->BidiStream(&context);
+  request.set_message("Hello");
+  EXPECT_TRUE(stream->Write(request));
+
+  EXPECT_TRUE(stream->WritesDone());
+  EXPECT_TRUE(stream->Read(&response));
+
+  auto status = stream->Finish();
+  EXPECT_TRUE(status.ok());
+}
+
+// Wall-clock time jumps back on client before bidi stream is created;
+// the RPC should still complete and echo correctly.
+TEST_F(TimeChangeTest, TimeJumpBackBeforeStreamCreated) {
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000));
+  context.AddMetadata(kServerResponseStreamsToSend, "1");
+
+  auto channel = GetChannel();
+  GPR_ASSERT(channel);
+  EXPECT_TRUE(
+      channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000)));
+  auto stub = CreateStub();
+
+  // time jumps back by TIME_OFFSET1 milliseconds
+  set_now_offset(-TIME_OFFSET1);
+  auto stream = stub->BidiStream(&context);
+  request.set_message("Hello");
+  EXPECT_TRUE(stream->Write(request));
+
+  EXPECT_TRUE(stream->WritesDone());
+  EXPECT_TRUE(stream->Read(&response));
+  EXPECT_EQ(request.message(), response.message());
+
+  auto status = stream->Finish();
+  EXPECT_TRUE(status.ok());
+}
+
+// Wall-clock time jumps forward on client while call is in progress;
+// the remaining writes/reads should still succeed.
+TEST_F(TimeChangeTest, TimeJumpForwardAfterStreamCreated) {
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000));
+  context.AddMetadata(kServerResponseStreamsToSend, "2");
+
+  auto channel = GetChannel();
+  GPR_ASSERT(channel);
+  EXPECT_TRUE(
+      channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000)));
+  auto stub = CreateStub();
+
+  auto stream = stub->BidiStream(&context);
+
+  request.set_message("Hello");
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->Read(&response));
+
+  // time jumps forward by TIME_OFFSET1 milliseconds.
+  set_now_offset(TIME_OFFSET1);
+
+  request.set_message("World");
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->WritesDone());
+  EXPECT_TRUE(stream->Read(&response));
+
+  auto status = stream->Finish();
+  EXPECT_TRUE(status.ok());
+}
+
+// Wall-clock time jumps back on client while call is in progress;
+// the remaining writes/reads should still succeed.
+TEST_F(TimeChangeTest, TimeJumpBackAfterStreamCreated) {
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000));
+  context.AddMetadata(kServerResponseStreamsToSend, "2");
+
+  auto channel = GetChannel();
+  GPR_ASSERT(channel);
+  EXPECT_TRUE(
+      channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000)));
+  auto stub = CreateStub();
+
+  auto stream = stub->BidiStream(&context);
+
+  request.set_message("Hello");
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->Read(&response));
+
+  // time jumps back TIME_OFFSET1 milliseconds.
+  set_now_offset(-TIME_OFFSET1);
+
+  request.set_message("World");
+  EXPECT_TRUE(stream->Write(request));
+  EXPECT_TRUE(stream->WritesDone());
+  EXPECT_TRUE(stream->Read(&response));
+
+  auto status = stream->Finish();
+  EXPECT_TRUE(status.ok());
+}
+
+// Wall-clock time jumps forward and backwards repeatedly during a single
+// call; every stream operation should still succeed.
+TEST_F(TimeChangeTest, TimeJumpForwardAndBackDuringCall) {
+  EchoRequest request;
+  EchoResponse response;
+  ClientContext context;
+  context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000));
+  context.AddMetadata(kServerResponseStreamsToSend, "2");
+
+  auto channel = GetChannel();
+  GPR_ASSERT(channel);
+
+  EXPECT_TRUE(
+      channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000)));
+  auto stub = CreateStub();
+  auto stream = stub->BidiStream(&context);
+
+  request.set_message("Hello");
+  EXPECT_TRUE(stream->Write(request));
+
+  // time jumps back by TIME_OFFSET2 milliseconds
+  set_now_offset(-TIME_OFFSET2);
+
+  EXPECT_TRUE(stream->Read(&response));
+  request.set_message("World");
+
+  // time jumps forward by TIME_OFFSET1 milliseconds
+  set_now_offset(TIME_OFFSET1);
+
+  EXPECT_TRUE(stream->Write(request));
+
+  // time jumps back by TIME_OFFSET2 milliseconds
+  set_now_offset(-TIME_OFFSET2);
+
+  EXPECT_TRUE(stream->WritesDone());
+
+  // time jumps back by TIME_OFFSET2 milliseconds
+  set_now_offset(-TIME_OFFSET2);
+
+  EXPECT_TRUE(stream->Read(&response));
+
+  // time jumps back by TIME_OFFSET2 milliseconds
+  set_now_offset(-TIME_OFFSET2);
+
+  auto status = stream->Finish();
+  EXPECT_TRUE(status.ok());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace grpc
+
+// Entry point: derives g_root (directory of this binary, where the server
+// helper binary lives), installs the offsettable clock, then runs gtest.
+int main(int argc, char** argv) {
+  TString me = argv[0];
+  // get index of last slash in path to test binary
+  auto lslash = me.rfind('/');
+  // set g_root = path to directory containing test binary
+  if (lslash != TString::npos) {
+    g_root = me.substr(0, lslash);
+  } else {
+    g_root = ".";
+  }
+
+  // Install the shiftable realtime clock before any gRPC code reads time.
+  gpr_mu_init(&g_mu);
+  gpr_now_impl = now_impl;
+
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  auto ret = RUN_ALL_TESTS();
+  return ret;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/xds_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/xds_end2end_test.cc
new file mode 100644
index 0000000000..603e6186bf
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/xds_end2end_test.cc
@@ -0,0 +1,5832 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <deque>
+#include <memory>
+#include <mutex>
+#include <numeric>
+#include <set>
+#include <sstream>
+#include <util/generic/string.h>
+#include <thread>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "y_absl/strings/str_cat.h"
+#include "y_absl/types/optional.h"
+
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/time.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+
+#include "src/core/ext/filters/client_channel/backup_poller.h"
+#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h"
+#include "src/core/ext/filters/client_channel/server_address.h"
+#include "src/core/ext/xds/xds_api.h"
+#include "src/core/ext/xds/xds_channel_args.h"
+#include "src/core/ext/xds/xds_client.h"
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/gpr/tmpfile.h"
+#include "src/core/lib/gprpp/map.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
+#include "src/core/lib/gprpp/sync.h"
+#include "src/core/lib/iomgr/parse_address.h"
+#include "src/core/lib/iomgr/sockaddr.h"
+#include "src/core/lib/security/credentials/fake/fake_credentials.h"
+#include "src/cpp/client/secure_credentials.h"
+#include "src/cpp/server/secure_server_credentials.h"
+
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/end2end/test_service_impl.h"
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/ads_for_test.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/cds_for_test.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/eds_for_test.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/lds_rds_for_test.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/lrs_for_test.grpc.pb.h"
+
+#include "src/proto/grpc/testing/xds/v3/ads.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/v3/cluster.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/v3/discovery.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/v3/endpoint.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/v3/http_connection_manager.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/v3/listener.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/v3/lrs.grpc.pb.h"
+#include "src/proto/grpc/testing/xds/v3/route.grpc.pb.h"
+
+namespace grpc {
+namespace testing {
+namespace {
+
+using std::chrono::system_clock;
+
+using ::envoy::config::cluster::v3::CircuitBreakers;
+using ::envoy::config::cluster::v3::Cluster;
+using ::envoy::config::cluster::v3::RoutingPriority;
+using ::envoy::config::endpoint::v3::ClusterLoadAssignment;
+using ::envoy::config::endpoint::v3::HealthStatus;
+using ::envoy::config::listener::v3::Listener;
+using ::envoy::config::route::v3::RouteConfiguration;
+using ::envoy::extensions::filters::network::http_connection_manager::v3::
+ HttpConnectionManager;
+using ::envoy::type::v3::FractionalPercent;
+
+constexpr char kLdsTypeUrl[] =
+ "type.googleapis.com/envoy.config.listener.v3.Listener";
+constexpr char kRdsTypeUrl[] =
+ "type.googleapis.com/envoy.config.route.v3.RouteConfiguration";
+constexpr char kCdsTypeUrl[] =
+ "type.googleapis.com/envoy.config.cluster.v3.Cluster";
+constexpr char kEdsTypeUrl[] =
+ "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment";
+
+constexpr char kLdsV2TypeUrl[] = "type.googleapis.com/envoy.api.v2.Listener";
+constexpr char kRdsV2TypeUrl[] =
+ "type.googleapis.com/envoy.api.v2.RouteConfiguration";
+constexpr char kCdsV2TypeUrl[] = "type.googleapis.com/envoy.api.v2.Cluster";
+constexpr char kEdsV2TypeUrl[] =
+ "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment";
+
+constexpr char kDefaultLocalityRegion[] = "xds_default_locality_region";
+constexpr char kDefaultLocalityZone[] = "xds_default_locality_zone";
+constexpr char kLbDropType[] = "lb";
+constexpr char kThrottleDropType[] = "throttle";
+constexpr char kServerName[] = "server.example.com";
+constexpr char kDefaultRouteConfigurationName[] = "route_config_name";
+constexpr char kDefaultClusterName[] = "cluster_name";
+constexpr char kDefaultEdsServiceName[] = "eds_service_name";
+constexpr int kDefaultLocalityWeight = 3;
+constexpr int kDefaultLocalityPriority = 0;
+
+constexpr char kRequestMessage[] = "Live long and prosper.";
+constexpr char kDefaultServiceConfig[] =
+ "{\n"
+ " \"loadBalancingConfig\":[\n"
+ " { \"does_not_exist\":{} },\n"
+ " { \"eds_experimental\":{\n"
+ " \"clusterName\": \"server.example.com\",\n"
+ " \"lrsLoadReportingServerName\": \"\"\n"
+ " } }\n"
+ " ]\n"
+ "}";
+constexpr char kDefaultServiceConfigWithoutLoadReporting[] =
+ "{\n"
+ " \"loadBalancingConfig\":[\n"
+ " { \"does_not_exist\":{} },\n"
+ " { \"eds_experimental\":{\n"
+ " \"clusterName\": \"server.example.com\"\n"
+ " } }\n"
+ " ]\n"
+ "}";
+
+constexpr char kBootstrapFileV3[] =
+ "{\n"
+ " \"xds_servers\": [\n"
+ " {\n"
+ " \"server_uri\": \"fake:///xds_server\",\n"
+ " \"channel_creds\": [\n"
+ " {\n"
+ " \"type\": \"fake\"\n"
+ " }\n"
+ " ],\n"
+ " \"server_features\": [\"xds_v3\"]\n"
+ " }\n"
+ " ],\n"
+ " \"node\": {\n"
+ " \"id\": \"xds_end2end_test\",\n"
+ " \"cluster\": \"test\",\n"
+ " \"metadata\": {\n"
+ " \"foo\": \"bar\"\n"
+ " },\n"
+ " \"locality\": {\n"
+ " \"region\": \"corp\",\n"
+ " \"zone\": \"svl\",\n"
+ " \"subzone\": \"mp3\"\n"
+ " }\n"
+ " }\n"
+ "}\n";
+
+constexpr char kBootstrapFileV2[] =
+ "{\n"
+ " \"xds_servers\": [\n"
+ " {\n"
+ " \"server_uri\": \"fake:///xds_server\",\n"
+ " \"channel_creds\": [\n"
+ " {\n"
+ " \"type\": \"fake\"\n"
+ " }\n"
+ " ]\n"
+ " }\n"
+ " ],\n"
+ " \"node\": {\n"
+ " \"id\": \"xds_end2end_test\",\n"
+ " \"cluster\": \"test\",\n"
+ " \"metadata\": {\n"
+ " \"foo\": \"bar\"\n"
+ " },\n"
+ " \"locality\": {\n"
+ " \"region\": \"corp\",\n"
+ " \"zone\": \"svl\",\n"
+ " \"subzone\": \"mp3\"\n"
+ " }\n"
+ " }\n"
+ "}\n";
+
+char* g_bootstrap_file_v3;
+char* g_bootstrap_file_v2;
+
+// Writes the v3 and v2 bootstrap JSON blobs to freshly created temp files and
+// records the generated paths in g_bootstrap_file_v3 / g_bootstrap_file_v2.
+// The files (and their path strings) intentionally live for the whole test
+// run. Fails fast if a temp file cannot be created, instead of passing a null
+// FILE* to fputs (undefined behavior).
+void WriteBootstrapFiles() {
+  char* bootstrap_file;
+  FILE* out = gpr_tmpfile("xds_bootstrap_v3", &bootstrap_file);
+  GPR_ASSERT(out != nullptr);
+  fputs(kBootstrapFileV3, out);
+  fclose(out);
+  g_bootstrap_file_v3 = bootstrap_file;
+  out = gpr_tmpfile("xds_bootstrap_v2", &bootstrap_file);
+  GPR_ASSERT(out != nullptr);
+  fputs(kBootstrapFileV2, out);
+  fclose(out);
+  g_bootstrap_file_v2 = bootstrap_file;
+}
+
+// Helper class to minimize the number of unique ports we use for this test.
+// Ports are picked lazily the first time they are needed and then recycled
+// across test cases via Reset().
+class PortSaver {
+ public:
+  // Returns the next port in the cached sequence, asking the port server for
+  // a fresh unused port only once the cache has been exhausted.
+  int GetPort() {
+    if (next_ == ports_.size()) {
+      ports_.push_back(grpc_pick_unused_port_or_die());
+    }
+    return ports_[next_++];
+  }
+
+  // Restarts the sequence so previously picked ports are handed out again.
+  void Reset() { next_ = 0; }
+
+ private:
+  std::vector<int> ports_;  // all ports picked so far
+  size_t next_ = 0;         // index of the next port to hand out
+};
+
+PortSaver* g_port_saver = nullptr;
+
+// Mixin that wraps a service implementation and counts, under a mutex, how
+// many requests it has received and how many responses it has sent.  Tests
+// use the counters to assert on load distribution across backends/balancers.
+template <typename ServiceType>
+class CountedService : public ServiceType {
+ public:
+  // Number of requests recorded via IncreaseRequestCount().
+  size_t request_count() {
+    grpc_core::MutexLock lock(&mu_);
+    return request_count_;
+  }
+
+  // Number of responses recorded via IncreaseResponseCount().
+  size_t response_count() {
+    grpc_core::MutexLock lock(&mu_);
+    return response_count_;
+  }
+
+  void IncreaseResponseCount() {
+    grpc_core::MutexLock lock(&mu_);
+    ++response_count_;
+  }
+  void IncreaseRequestCount() {
+    grpc_core::MutexLock lock(&mu_);
+    ++request_count_;
+  }
+
+  // Zeroes both counters; used between test phases.
+  void ResetCounters() {
+    grpc_core::MutexLock lock(&mu_);
+    request_count_ = 0;
+    response_count_ = 0;
+  }
+
+ private:
+  grpc_core::Mutex mu_;  // guards both counters
+  size_t request_count_ = 0;
+  size_t response_count_ = 0;
+};
+
+const char g_kCallCredsMdKey[] = "Balancer should not ...";
+const char g_kCallCredsMdValue[] = "... receive me";
+
+// Backend Echo service used as the target of xds-balanced RPCs.  On top of
+// the request/response counting from CountedService it (a) asserts that the
+// call credentials metadata reaches the backend (the balancer must NOT
+// consume it), and (b) records the peer string of every client that called.
+template <typename RpcService>
+class BackendServiceImpl
+    : public CountedService<TestMultipleServiceImpl<RpcService>> {
+ public:
+  BackendServiceImpl() {}
+
+  Status Echo(ServerContext* context, const EchoRequest* request,
+              EchoResponse* response) override {
+    // Backend should receive the call credentials metadata.
+    auto call_credentials_entry =
+        context->client_metadata().find(g_kCallCredsMdKey);
+    EXPECT_NE(call_credentials_entry, context->client_metadata().end());
+    if (call_credentials_entry != context->client_metadata().end()) {
+      EXPECT_EQ(call_credentials_entry->second, g_kCallCredsMdValue);
+    }
+    CountedService<TestMultipleServiceImpl<RpcService>>::IncreaseRequestCount();
+    // Delegate the actual echo behavior to the shared test service impl.
+    const auto status =
+        TestMultipleServiceImpl<RpcService>::Echo(context, request, response);
+    CountedService<
+        TestMultipleServiceImpl<RpcService>>::IncreaseResponseCount();
+    AddClient(context->peer());
+    return status;
+  }
+
+  // Echo1/Echo2 exist so tests can route different RPC methods to the same
+  // underlying behavior.
+  Status Echo1(ServerContext* context, const EchoRequest* request,
+               EchoResponse* response) override {
+    return Echo(context, request, response);
+  }
+
+  Status Echo2(ServerContext* context, const EchoRequest* request,
+               EchoResponse* response) override {
+    return Echo(context, request, response);
+  }
+
+  void Start() {}
+  void Shutdown() {}
+
+  // Returns a snapshot of the peer strings seen so far.
+  std::set<TString> clients() {
+    grpc_core::MutexLock lock(&clients_mu_);
+    return clients_;
+  }
+
+ private:
+  void AddClient(const TString& client) {
+    grpc_core::MutexLock lock(&clients_mu_);
+    clients_.insert(client);
+  }
+
+  grpc_core::Mutex clients_mu_;  // guards clients_
+  std::set<TString> clients_;
+};
+
+// Proto-independent view of the load-report stats a client sends for one
+// cluster (LRS ClusterStats).  Tests convert incoming protos into this form
+// and then assert on per-locality and aggregate counters.
+class ClientStats {
+ public:
+  // Per-locality request counters; keyed by sub_zone in locality_stats_.
+  struct LocalityStats {
+    LocalityStats() {}
+
+    // Converts from proto message class.
+    template <class UpstreamLocalityStats>
+    LocalityStats(const UpstreamLocalityStats& upstream_locality_stats)
+        : total_successful_requests(
+              upstream_locality_stats.total_successful_requests()),
+          total_requests_in_progress(
+              upstream_locality_stats.total_requests_in_progress()),
+          total_error_requests(upstream_locality_stats.total_error_requests()),
+          total_issued_requests(
+              upstream_locality_stats.total_issued_requests()) {}
+
+    // Accumulates another locality's counters into this one.
+    LocalityStats& operator+=(const LocalityStats& other) {
+      total_successful_requests += other.total_successful_requests;
+      total_requests_in_progress += other.total_requests_in_progress;
+      total_error_requests += other.total_error_requests;
+      total_issued_requests += other.total_issued_requests;
+      return *this;
+    }
+
+    uint64_t total_successful_requests = 0;
+    uint64_t total_requests_in_progress = 0;
+    uint64_t total_error_requests = 0;
+    uint64_t total_issued_requests = 0;
+  };
+
+  ClientStats() {}
+
+  // Converts from proto message class.
+  template <class ClusterStats>
+  explicit ClientStats(const ClusterStats& cluster_stats)
+      : cluster_name_(cluster_stats.cluster_name()),
+        total_dropped_requests_(cluster_stats.total_dropped_requests()) {
+    for (const auto& input_locality_stats :
+         cluster_stats.upstream_locality_stats()) {
+      locality_stats_.emplace(input_locality_stats.locality().sub_zone(),
+                              LocalityStats(input_locality_stats));
+    }
+    for (const auto& input_dropped_requests :
+         cluster_stats.dropped_requests()) {
+      dropped_requests_.emplace(input_dropped_requests.category(),
+                                input_dropped_requests.dropped_count());
+    }
+  }
+
+  const TString& cluster_name() const { return cluster_name_; }
+
+  const std::map<TString, LocalityStats>& locality_stats() const {
+    return locality_stats_;
+  }
+  // The total_*() accessors below sum the given counter over all localities.
+  uint64_t total_successful_requests() const {
+    uint64_t sum = 0;
+    for (auto& p : locality_stats_) {
+      sum += p.second.total_successful_requests;
+    }
+    return sum;
+  }
+  uint64_t total_requests_in_progress() const {
+    uint64_t sum = 0;
+    for (auto& p : locality_stats_) {
+      sum += p.second.total_requests_in_progress;
+    }
+    return sum;
+  }
+  uint64_t total_error_requests() const {
+    uint64_t sum = 0;
+    for (auto& p : locality_stats_) {
+      sum += p.second.total_error_requests;
+    }
+    return sum;
+  }
+  uint64_t total_issued_requests() const {
+    uint64_t sum = 0;
+    for (auto& p : locality_stats_) {
+      sum += p.second.total_issued_requests;
+    }
+    return sum;
+  }
+
+  uint64_t total_dropped_requests() const { return total_dropped_requests_; }
+
+  // Returns the drop count for the given category; the category must have
+  // been reported (asserts otherwise).
+  uint64_t dropped_requests(const TString& category) const {
+    auto iter = dropped_requests_.find(category);
+    GPR_ASSERT(iter != dropped_requests_.end());
+    return iter->second;
+  }
+
+  // Merges another report into this one (counters add; cluster name kept).
+  ClientStats& operator+=(const ClientStats& other) {
+    for (const auto& p : other.locality_stats_) {
+      locality_stats_[p.first] += p.second;
+    }
+    total_dropped_requests_ += other.total_dropped_requests_;
+    for (const auto& p : other.dropped_requests_) {
+      dropped_requests_[p.first] += p.second;
+    }
+    return *this;
+  }
+
+ private:
+  TString cluster_name_;
+  std::map<TString, LocalityStats> locality_stats_;  // keyed by sub_zone
+  uint64_t total_dropped_requests_ = 0;
+  std::map<TString, uint64_t> dropped_requests_;  // keyed by drop category
+};
+
+class AdsServiceImpl : public std::enable_shared_from_this<AdsServiceImpl> {
+ public:
+ struct ResponseState {
+ enum State { NOT_SENT, SENT, ACKED, NACKED };
+ State state = NOT_SENT;
+ TString error_message;
+ };
+
+  // Arguments from which BuildEdsResource() constructs a
+  // ClusterLoadAssignment: the localities to advertise plus optional drop
+  // categories/denominator for the drop policy.
+  struct EdsResourceArgs {
+    // One locality: its sub_zone name, backend ports, LB weight, priority,
+    // and optional per-endpoint health statuses (parallel to `ports`).
+    struct Locality {
+      // sub_zone is taken by value so the std::move() below actually moves;
+      // moving from a const reference would silently copy.
+      Locality(TString sub_zone, std::vector<int> ports,
+               int lb_weight = kDefaultLocalityWeight,
+               int priority = kDefaultLocalityPriority,
+               std::vector<HealthStatus> health_statuses = {})
+          : sub_zone(std::move(sub_zone)),
+            ports(std::move(ports)),
+            lb_weight(lb_weight),
+            priority(priority),
+            health_statuses(std::move(health_statuses)) {}
+
+      const TString sub_zone;
+      std::vector<int> ports;
+      int lb_weight;
+      int priority;
+      std::vector<HealthStatus> health_statuses;
+    };
+
+    EdsResourceArgs() = default;
+    explicit EdsResourceArgs(std::vector<Locality> locality_list)
+        : locality_list(std::move(locality_list)) {}
+
+    std::vector<Locality> locality_list;
+    std::map<TString, uint32_t> drop_categories;
+    FractionalPercent::DenominatorType drop_denominator =
+        FractionalPercent::MILLION;
+  };
+
+  // Builds and installs the default LDS, RDS, and CDS resources so a test
+  // only has to supply EDS data; enable_load_reporting adds a self-pointing
+  // LRS server to the default cluster.
+  explicit AdsServiceImpl(bool enable_load_reporting)
+      : v2_rpc_service_(this, /*is_v2=*/true),
+        v3_rpc_service_(this, /*is_v2=*/false) {
+    // Construct RDS response data.
+    default_route_config_.set_name(kDefaultRouteConfigurationName);
+    auto* virtual_host = default_route_config_.add_virtual_hosts();
+    virtual_host->add_domains("*");
+    auto* route = virtual_host->add_routes();
+    route->mutable_match()->set_prefix("");
+    route->mutable_route()->set_cluster(kDefaultClusterName);
+    SetRdsResource(default_route_config_);
+    // Construct LDS response data (with inlined RDS result).
+    default_listener_ = BuildListener(default_route_config_);
+    SetLdsResource(default_listener_);
+    // Construct CDS response data.
+    default_cluster_.set_name(kDefaultClusterName);
+    default_cluster_.set_type(Cluster::EDS);
+    auto* eds_config = default_cluster_.mutable_eds_cluster_config();
+    eds_config->mutable_eds_config()->mutable_ads();
+    eds_config->set_service_name(kDefaultEdsServiceName);
+    default_cluster_.set_lb_policy(Cluster::ROUND_ROBIN);
+    if (enable_load_reporting) {
+      default_cluster_.mutable_lrs_server()->mutable_self();
+    }
+    SetCdsResource(default_cluster_);
+  }
+
+ bool seen_v2_client() const { return seen_v2_client_; }
+ bool seen_v3_client() const { return seen_v3_client_; }
+
+ ::envoy::service::discovery::v2::AggregatedDiscoveryService::Service*
+ v2_rpc_service() {
+ return &v2_rpc_service_;
+ }
+
+ ::envoy::service::discovery::v3::AggregatedDiscoveryService::Service*
+ v3_rpc_service() {
+ return &v3_rpc_service_;
+ }
+
+ Listener default_listener() const { return default_listener_; }
+ RouteConfiguration default_route_config() const {
+ return default_route_config_;
+ }
+ Cluster default_cluster() const { return default_cluster_; }
+
+ ResponseState lds_response_state() {
+ grpc_core::MutexLock lock(&ads_mu_);
+ return resource_type_response_state_[kLdsTypeUrl];
+ }
+
+ ResponseState rds_response_state() {
+ grpc_core::MutexLock lock(&ads_mu_);
+ return resource_type_response_state_[kRdsTypeUrl];
+ }
+
+ ResponseState cds_response_state() {
+ grpc_core::MutexLock lock(&ads_mu_);
+ return resource_type_response_state_[kCdsTypeUrl];
+ }
+
+ ResponseState eds_response_state() {
+ grpc_core::MutexLock lock(&ads_mu_);
+ return resource_type_response_state_[kEdsTypeUrl];
+ }
+
+ void SetResourceIgnore(const TString& type_url) {
+ grpc_core::MutexLock lock(&ads_mu_);
+ resource_types_to_ignore_.emplace(type_url);
+ }
+
+  // Removes the named resource (the version still advances so subscribed
+  // clients are told it is gone) and queues an update for each subscriber.
+  void UnsetResource(const TString& type_url, const TString& name) {
+    grpc_core::MutexLock lock(&ads_mu_);
+    ResourceState& state = resource_map_[type_url][name];
+    ++state.version;
+    state.resource.reset();
+    // state.version is an int, so use %d (not %u) to avoid a printf-style
+    // conversion mismatch.
+    gpr_log(GPR_INFO, "ADS[%p]: Unsetting %s resource %s to version %d", this,
+            type_url.c_str(), name.c_str(), state.version);
+    for (SubscriptionState* subscription : state.subscriptions) {
+      subscription->update_queue->emplace_back(type_url, name);
+    }
+  }
+
+  // Installs (or replaces) the named resource, bumping its version and
+  // queueing an update for every subscribed client stream.
+  void SetResource(google::protobuf::Any resource, const TString& type_url,
+                   const TString& name) {
+    grpc_core::MutexLock lock(&ads_mu_);
+    ResourceState& state = resource_map_[type_url][name];
+    ++state.version;
+    state.resource = std::move(resource);
+    // state.version is an int, so use %d (not %u) to avoid a printf-style
+    // conversion mismatch.
+    gpr_log(GPR_INFO, "ADS[%p]: Updating %s resource %s to version %d", this,
+            type_url.c_str(), name.c_str(), state.version);
+    for (SubscriptionState* subscription : state.subscriptions) {
+      subscription->update_queue->emplace_back(type_url, name);
+    }
+  }
+
+ void SetLdsResource(const Listener& listener) {
+ google::protobuf::Any resource;
+ resource.PackFrom(listener);
+ SetResource(std::move(resource), kLdsTypeUrl, listener.name());
+ }
+
+ void SetRdsResource(const RouteConfiguration& route) {
+ google::protobuf::Any resource;
+ resource.PackFrom(route);
+ SetResource(std::move(resource), kRdsTypeUrl, route.name());
+ }
+
+ void SetCdsResource(const Cluster& cluster) {
+ google::protobuf::Any resource;
+ resource.PackFrom(cluster);
+ SetResource(std::move(resource), kCdsTypeUrl, cluster.name());
+ }
+
+ void SetEdsResource(const ClusterLoadAssignment& assignment) {
+ google::protobuf::Any resource;
+ resource.PackFrom(assignment);
+ SetResource(std::move(resource), kEdsTypeUrl, assignment.cluster_name());
+ }
+
+ void SetLdsToUseDynamicRds() {
+ auto listener = default_listener_;
+ HttpConnectionManager http_connection_manager;
+ auto* rds = http_connection_manager.mutable_rds();
+ rds->set_route_config_name(kDefaultRouteConfigurationName);
+ rds->mutable_config_source()->mutable_ads();
+ listener.mutable_api_listener()->mutable_api_listener()->PackFrom(
+ http_connection_manager);
+ SetLdsResource(listener);
+ }
+
+  // Builds a Listener named kServerName whose api_listener wraps an
+  // HttpConnectionManager with the given route config inlined (i.e. no
+  // separate RDS fetch is needed).
+  static Listener BuildListener(const RouteConfiguration& route_config) {
+    HttpConnectionManager http_connection_manager;
+    *(http_connection_manager.mutable_route_config()) = route_config;
+    Listener listener;
+    listener.set_name(kServerName);
+    listener.mutable_api_listener()->mutable_api_listener()->PackFrom(
+        http_connection_manager);
+    return listener;
+  }
+
+  // Builds a ClusterLoadAssignment from EdsResourceArgs: one endpoint group
+  // per locality (all endpoints on 127.0.0.1), plus a drop policy if any
+  // drop categories were specified.
+  static ClusterLoadAssignment BuildEdsResource(
+      const EdsResourceArgs& args,
+      const char* eds_service_name = kDefaultEdsServiceName) {
+    ClusterLoadAssignment assignment;
+    assignment.set_cluster_name(eds_service_name);
+    for (const auto& locality : args.locality_list) {
+      auto* endpoints = assignment.add_endpoints();
+      endpoints->mutable_load_balancing_weight()->set_value(locality.lb_weight);
+      endpoints->set_priority(locality.priority);
+      endpoints->mutable_locality()->set_region(kDefaultLocalityRegion);
+      endpoints->mutable_locality()->set_zone(kDefaultLocalityZone);
+      endpoints->mutable_locality()->set_sub_zone(locality.sub_zone);
+      for (size_t i = 0; i < locality.ports.size(); ++i) {
+        const int& port = locality.ports[i];
+        auto* lb_endpoints = endpoints->add_lb_endpoints();
+        // health_statuses is parallel to ports; UNKNOWN entries are simply
+        // left unset on the endpoint.
+        if (locality.health_statuses.size() > i &&
+            locality.health_statuses[i] != HealthStatus::UNKNOWN) {
+          lb_endpoints->set_health_status(locality.health_statuses[i]);
+        }
+        auto* endpoint = lb_endpoints->mutable_endpoint();
+        auto* address = endpoint->mutable_address();
+        auto* socket_address = address->mutable_socket_address();
+        socket_address->set_address("127.0.0.1");
+        socket_address->set_port_value(port);
+      }
+    }
+    if (!args.drop_categories.empty()) {
+      auto* policy = assignment.mutable_policy();
+      for (const auto& p : args.drop_categories) {
+        const TString& name = p.first;
+        const uint32_t parts_per_million = p.second;
+        auto* drop_overload = policy->add_drop_overloads();
+        drop_overload->set_category(name);
+        auto* drop_percentage = drop_overload->mutable_drop_percentage();
+        drop_percentage->set_numerator(parts_per_million);
+        drop_percentage->set_denominator(args.drop_denominator);
+      }
+    }
+    return assignment;
+  }
+
+ void Start() {
+ grpc_core::MutexLock lock(&ads_mu_);
+ ads_done_ = false;
+ }
+
+ void Shutdown() {
+ {
+ grpc_core::MutexLock lock(&ads_mu_);
+ NotifyDoneWithAdsCallLocked();
+ resource_type_response_state_.clear();
+ }
+ gpr_log(GPR_INFO, "ADS[%p]: shut down", this);
+ }
+
+ void NotifyDoneWithAdsCall() {
+ grpc_core::MutexLock lock(&ads_mu_);
+ NotifyDoneWithAdsCallLocked();
+ }
+
+ void NotifyDoneWithAdsCallLocked() {
+ if (!ads_done_) {
+ ads_done_ = true;
+ ads_cond_.Broadcast();
+ }
+ }
+
+ std::set<TString> clients() {
+ grpc_core::MutexLock lock(&clients_mu_);
+ return clients_;
+ }
+
+ private:
+ // A queue of resource type/name pairs that have changed since the client
+ // subscribed to them.
+ using UpdateQueue = std::deque<
+ std::pair<TString /* type url */, TString /* resource name */>>;
+
+ // A struct representing a client's subscription to a particular resource.
+ struct SubscriptionState {
+ // Version that the client currently knows about.
+ int current_version = 0;
+ // The queue upon which to place updates when the resource is updated.
+ UpdateQueue* update_queue;
+ };
+
+  // A struct representing a client's subscription to all the resources.
+ using SubscriptionNameMap =
+ std::map<TString /* resource_name */, SubscriptionState>;
+ using SubscriptionMap =
+ std::map<TString /* type_url */, SubscriptionNameMap>;
+
+ // A struct representing the current state for a resource:
+ // - the version of the resource that is set by the SetResource() methods.
+ // - a list of subscriptions interested in this resource.
+ struct ResourceState {
+ int version = 0;
+ y_absl::optional<google::protobuf::Any> resource;
+ std::set<SubscriptionState*> subscriptions;
+ };
+
+ // A struct representing the current state for all resources:
+ // LDS, CDS, EDS, and RDS for the class as a whole.
+ using ResourceNameMap =
+ std::map<TString /* resource_name */, ResourceState>;
+ using ResourceMap = std::map<TString /* type_url */, ResourceNameMap>;
+
+ template <class RpcApi, class DiscoveryRequest, class DiscoveryResponse>
+ class RpcService : public RpcApi::Service {
+ public:
+ using Stream = ServerReaderWriter<DiscoveryResponse, DiscoveryRequest>;
+
+ RpcService(AdsServiceImpl* parent, bool is_v2)
+ : parent_(parent), is_v2_(is_v2) {}
+
+    // Fake ADS stream handler.  A dedicated reader thread feeds incoming
+    // DiscoveryRequests into a queue; this main loop alternately (1) handles
+    // queued requests (ACK/NACK bookkeeping, subscriptions, responses) and
+    // (2) pushes out updates queued by Set*/Unset* resource calls, until the
+    // stream closes or the service is shut down.
+    Status StreamAggregatedResources(ServerContext* context,
+                                     Stream* stream) override {
+      gpr_log(GPR_INFO, "ADS[%p]: StreamAggregatedResources starts", this);
+      parent_->AddClient(context->peer());
+      if (is_v2_) {
+        parent_->seen_v2_client_ = true;
+      } else {
+        parent_->seen_v3_client_ = true;
+      }
+      // Resources (type/name pairs) that have changed since the client
+      // subscribed to them.
+      UpdateQueue update_queue;
+      // Resources that the client will be subscribed to keyed by resource type
+      // url.
+      SubscriptionMap subscription_map;
+      // Immediately-invoked lambda so `return` can bail out early while the
+      // subscription-cleanup code after it still runs.
+      [&]() {
+        {
+          grpc_core::MutexLock lock(&parent_->ads_mu_);
+          if (parent_->ads_done_) return;
+        }
+        // Balancer shouldn't receive the call credentials metadata.
+        EXPECT_EQ(context->client_metadata().find(g_kCallCredsMdKey),
+                  context->client_metadata().end());
+        // Current Version map keyed by resource type url.
+        std::map<TString, int> resource_type_version;
+        // Creating blocking thread to read from stream.
+        std::deque<DiscoveryRequest> requests;
+        bool stream_closed = false;
+        // Take a reference of the AdsServiceImpl object, reference will go
+        // out of scope after the reader thread is joined.
+        std::shared_ptr<AdsServiceImpl> ads_service_impl =
+            parent_->shared_from_this();
+        std::thread reader(std::bind(&RpcService::BlockingRead, this, stream,
+                                     &requests, &stream_closed));
+        // Main loop to look for requests and updates.
+        while (true) {
+          // Look for new requests and decide what to handle.
+          y_absl::optional<DiscoveryResponse> response;
+          // Boolean to keep track if the loop received any work to do: a
+          // request or an update; regardless whether a response was actually
+          // sent out.
+          bool did_work = false;
+          {
+            grpc_core::MutexLock lock(&parent_->ads_mu_);
+            if (stream_closed) break;
+            if (!requests.empty()) {
+              DiscoveryRequest request = std::move(requests.front());
+              requests.pop_front();
+              did_work = true;
+              gpr_log(GPR_INFO,
+                      "ADS[%p]: Received request for type %s with content %s",
+                      this, request.type_url().c_str(),
+                      request.DebugString().c_str());
+              const TString v3_resource_type =
+                  TypeUrlToV3(request.type_url());
+              // As long as we are not in shutdown, identify ACK and NACK by
+              // looking for version information and comparing it to nonce (this
+              // server ensures they are always set to the same in a response.)
+              auto it =
+                  parent_->resource_type_response_state_.find(v3_resource_type);
+              if (it != parent_->resource_type_response_state_.end()) {
+                if (!request.response_nonce().empty()) {
+                  it->second.state =
+                      (!request.version_info().empty() &&
+                       request.version_info() == request.response_nonce())
+                          ? ResponseState::ACKED
+                          : ResponseState::NACKED;
+                }
+                if (request.has_error_detail()) {
+                  it->second.error_message = request.error_detail().message();
+                }
+              }
+              // As long as the test did not tell us to ignore this type of
+              // request, look at all the resource names.
+              if (parent_->resource_types_to_ignore_.find(v3_resource_type) ==
+                  parent_->resource_types_to_ignore_.end()) {
+                auto& subscription_name_map =
+                    subscription_map[v3_resource_type];
+                auto& resource_name_map =
+                    parent_->resource_map_[v3_resource_type];
+                std::set<TString> resources_in_current_request;
+                std::set<TString> resources_added_to_response;
+                for (const TString& resource_name :
+                     request.resource_names()) {
+                  resources_in_current_request.emplace(resource_name);
+                  auto& subscription_state =
+                      subscription_name_map[resource_name];
+                  auto& resource_state = resource_name_map[resource_name];
+                  // Subscribe if needed.
+                  parent_->MaybeSubscribe(v3_resource_type, resource_name,
+                                          &subscription_state, &resource_state,
+                                          &update_queue);
+                  // Send update if needed.
+                  if (ClientNeedsResourceUpdate(resource_state,
+                                                &subscription_state)) {
+                    gpr_log(GPR_INFO,
+                            "ADS[%p]: Sending update for type=%s name=%s "
+                            "version=%d",
+                            this, request.type_url().c_str(),
+                            resource_name.c_str(), resource_state.version);
+                    resources_added_to_response.emplace(resource_name);
+                    if (!response.has_value()) response.emplace();
+                    if (resource_state.resource.has_value()) {
+                      auto* resource = response->add_resources();
+                      resource->CopyFrom(resource_state.resource.value());
+                      if (is_v2_) {
+                        resource->set_type_url(request.type_url());
+                      }
+                    }
+                  } else {
+                    gpr_log(GPR_INFO,
+                            "ADS[%p]: client does not need update for "
+                            "type=%s name=%s version=%d",
+                            this, request.type_url().c_str(),
+                            resource_name.c_str(), resource_state.version);
+                  }
+                }
+                // Process unsubscriptions for any resource no longer
+                // present in the request's resource list.
+                parent_->ProcessUnsubscriptions(
+                    v3_resource_type, resources_in_current_request,
+                    &subscription_name_map, &resource_name_map);
+                // Send response if needed.
+                if (!resources_added_to_response.empty()) {
+                  CompleteBuildingDiscoveryResponse(
+                      v3_resource_type, request.type_url(),
+                      ++resource_type_version[v3_resource_type],
+                      subscription_name_map, resources_added_to_response,
+                      &response.value());
+                }
+              }
+            }
+          }
+          // Write outside the lock: Stream::Write may block.
+          if (response.has_value()) {
+            gpr_log(GPR_INFO, "ADS[%p]: Sending response: %s", this,
+                    response->DebugString().c_str());
+            stream->Write(response.value());
+          }
+          response.reset();
+          // Look for updates and decide what to handle.
+          {
+            grpc_core::MutexLock lock(&parent_->ads_mu_);
+            if (!update_queue.empty()) {
+              const TString resource_type =
+                  std::move(update_queue.front().first);
+              const TString resource_name =
+                  std::move(update_queue.front().second);
+              update_queue.pop_front();
+              const TString v2_resource_type = TypeUrlToV2(resource_type);
+              did_work = true;
+              gpr_log(GPR_INFO, "ADS[%p]: Received update for type=%s name=%s",
+                      this, resource_type.c_str(), resource_name.c_str());
+              auto& subscription_name_map = subscription_map[resource_type];
+              auto& resource_name_map = parent_->resource_map_[resource_type];
+              auto it = subscription_name_map.find(resource_name);
+              if (it != subscription_name_map.end()) {
+                SubscriptionState& subscription_state = it->second;
+                ResourceState& resource_state =
+                    resource_name_map[resource_name];
+                if (ClientNeedsResourceUpdate(resource_state,
+                                              &subscription_state)) {
+                  gpr_log(
+                      GPR_INFO,
+                      "ADS[%p]: Sending update for type=%s name=%s version=%d",
+                      this, resource_type.c_str(), resource_name.c_str(),
+                      resource_state.version);
+                  response.emplace();
+                  if (resource_state.resource.has_value()) {
+                    auto* resource = response->add_resources();
+                    resource->CopyFrom(resource_state.resource.value());
+                    if (is_v2_) {
+                      resource->set_type_url(v2_resource_type);
+                    }
+                  }
+                  CompleteBuildingDiscoveryResponse(
+                      resource_type, v2_resource_type,
+                      ++resource_type_version[resource_type],
+                      subscription_name_map, {resource_name},
+                      &response.value());
+                }
+              }
+            }
+          }
+          if (response.has_value()) {
+            gpr_log(GPR_INFO, "ADS[%p]: Sending update response: %s", this,
+                    response->DebugString().c_str());
+            stream->Write(response.value());
+          }
+          // If we didn't find anything to do, delay before the next loop
+          // iteration; otherwise, check whether we should exit and then
+          // immediately continue.
+          gpr_timespec deadline =
+              grpc_timeout_milliseconds_to_deadline(did_work ? 0 : 10);
+          {
+            grpc_core::MutexLock lock(&parent_->ads_mu_);
+            if (!parent_->ads_cond_.WaitUntil(
+                    &parent_->ads_mu_, [this] { return parent_->ads_done_; },
+                    deadline)) {
+              break;
+            }
+          }
+        }
+        reader.join();
+      }();
+      // Clean up any subscriptions that were still active when the call
+      // finished.
+      {
+        grpc_core::MutexLock lock(&parent_->ads_mu_);
+        for (auto& p : subscription_map) {
+          const TString& type_url = p.first;
+          SubscriptionNameMap& subscription_name_map = p.second;
+          for (auto& q : subscription_name_map) {
+            const TString& resource_name = q.first;
+            SubscriptionState& subscription_state = q.second;
+            ResourceState& resource_state =
+                parent_->resource_map_[type_url][resource_name];
+            resource_state.subscriptions.erase(&subscription_state);
+          }
+        }
+      }
+      gpr_log(GPR_INFO, "ADS[%p]: StreamAggregatedResources done", this);
+      parent_->RemoveClient(context->peer());
+      return Status::OK;
+    }
+
+ private:
+    // Maps a v3 resource type URL to its v2 equivalent; any unrecognized
+    // type URL is passed through unchanged.
+    static TString TypeUrlToV2(const TString& resource_type) {
+      if (resource_type == kLdsTypeUrl) {
+        return kLdsV2TypeUrl;
+      } else if (resource_type == kRdsTypeUrl) {
+        return kRdsV2TypeUrl;
+      } else if (resource_type == kCdsTypeUrl) {
+        return kCdsV2TypeUrl;
+      } else if (resource_type == kEdsTypeUrl) {
+        return kEdsV2TypeUrl;
+      }
+      return resource_type;
+    }
+
+    // Maps a v2 resource type URL to its v3 equivalent; any unrecognized
+    // type URL (including an already-v3 one) is passed through unchanged.
+    static TString TypeUrlToV3(const TString& resource_type) {
+      if (resource_type == kLdsV2TypeUrl) {
+        return kLdsTypeUrl;
+      } else if (resource_type == kRdsV2TypeUrl) {
+        return kRdsTypeUrl;
+      } else if (resource_type == kCdsV2TypeUrl) {
+        return kCdsTypeUrl;
+      } else if (resource_type == kEdsV2TypeUrl) {
+        return kEdsTypeUrl;
+      }
+      return resource_type;
+    }
+
+ // Starting a thread to do blocking read on the stream until cancel.
+    // Starting a thread to do blocking read on the stream until cancel.
+    // Each request read is pushed onto `requests` under ads_mu_; on the
+    // first request the client's node info is validated.  When Read()
+    // returns false the stream is closed and *stream_closed is set so the
+    // main loop can exit.
+    void BlockingRead(Stream* stream, std::deque<DiscoveryRequest>* requests,
+                      bool* stream_closed) {
+      DiscoveryRequest request;
+      bool seen_first_request = false;
+      while (stream->Read(&request)) {
+        if (!seen_first_request) {
+          // The first request on an ADS stream must carry node metadata.
+          EXPECT_TRUE(request.has_node());
+          ASSERT_FALSE(request.node().client_features().empty());
+          EXPECT_EQ(request.node().client_features(0),
+                    "envoy.lb.does_not_support_overprovisioning");
+          CheckBuildVersion(request);
+          seen_first_request = true;
+        }
+        {
+          grpc_core::MutexLock lock(&parent_->ads_mu_);
+          requests->emplace_back(std::move(request));
+        }
+      }
+      gpr_log(GPR_INFO, "ADS[%p]: Null read, stream closed", this);
+      grpc_core::MutexLock lock(&parent_->ads_mu_);
+      *stream_closed = true;
+    }
+
+  // v2 requests must populate the node's build_version field; verify it.
+  static void CheckBuildVersion(
+      const ::envoy::api::v2::DiscoveryRequest& request) {
+    const auto& build_version = request.node().build_version();
+    EXPECT_FALSE(build_version.empty());
+  }
+
+ // No-op overload for v3 requests: no build_version check applies here
+ // (only the v2 overload above verifies that field).
+ static void CheckBuildVersion(
+ const ::envoy::service::discovery::v3::DiscoveryRequest& request) {}
+
+ // Completes building a DiscoveryResponse by adding common information
+ // for all resources and by adding all subscribed resources for LDS and
+ // CDS (which use state-of-the-world semantics).
+ void CompleteBuildingDiscoveryResponse(
+ const TString& resource_type, const TString& v2_resource_type,
+ const int version, const SubscriptionNameMap& subscription_name_map,
+ const std::set<TString>& resources_added_to_response,
+ DiscoveryResponse* response) {
+ // Record that at least one response for this type has been sent.
+ auto& response_state =
+ parent_->resource_type_response_state_[resource_type];
+ if (response_state.state == ResponseState::NOT_SENT) {
+ response_state.state = ResponseState::SENT;
+ }
+ response->set_type_url(is_v2_ ? v2_resource_type : resource_type);
+ // This fake server uses the version number as the nonce as well.
+ response->set_version_info(y_absl::StrCat(version));
+ response->set_nonce(y_absl::StrCat(version));
+ if (resource_type == kLdsTypeUrl || resource_type == kCdsTypeUrl) {
+ // For LDS and CDS we must send back all subscribed resources
+ // (even the unchanged ones)
+ for (const auto& p : subscription_name_map) {
+ const TString& resource_name = p.first;
+ if (resources_added_to_response.find(resource_name) ==
+ resources_added_to_response.end()) {
+ const ResourceState& resource_state =
+ parent_->resource_map_[resource_type][resource_name];
+ if (resource_state.resource.has_value()) {
+ auto* resource = response->add_resources();
+ resource->CopyFrom(resource_state.resource.value());
+ // v2 clients expect the per-resource type URL in v2 form.
+ if (is_v2_) {
+ resource->set_type_url(v2_resource_type);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ AdsServiceImpl* parent_;
+ const bool is_v2_;
+ };
+
+  // Checks whether the client needs to receive a newer version of
+  // the resource. If so, updates subscription_state->current_version and
+  // returns true.
+  static bool ClientNeedsResourceUpdate(const ResourceState& resource_state,
+                                        SubscriptionState* subscription_state) {
+    // Already at (or ahead of) the latest version: nothing to send.
+    if (subscription_state->current_version >= resource_state.version) {
+      return false;
+    }
+    subscription_state->current_version = resource_state.version;
+    return true;
+  }
+
+  // Subscribes to a resource if not already subscribed:
+  // 1. Sets the update_queue field in subscription_state.
+  // 2. Adds subscription_state to resource_state->subscriptions.
+  void MaybeSubscribe(const TString& resource_type,
+                      const TString& resource_name,
+                      SubscriptionState* subscription_state,
+                      ResourceState* resource_state,
+                      UpdateQueue* update_queue) {
+    // The update_queue will be null if we were not previously subscribed.
+    if (subscription_state->update_queue != nullptr) return;
+    subscription_state->update_queue = update_queue;
+    resource_state->subscriptions.emplace(subscription_state);
+    // Log the SubscriptionState object itself — not &subscription_state,
+    // which is the address of the local pointer parameter and therefore
+    // meaningless once this function returns. Logging the object keeps
+    // this line correlatable with the "Unsubscribe" log in
+    // ProcessUnsubscriptions(), which prints the same object's address.
+    gpr_log(GPR_INFO, "ADS[%p]: subscribe to resource type %s name %s state %p",
+            this, resource_type.c_str(), resource_name.c_str(),
+            subscription_state);
+  }
+
+ // Removes subscriptions for resources no longer present in the
+ // current request.
+ // For each previously subscribed name that is absent from
+ // resources_in_current_request, detaches the subscription from the
+ // resource's subscriber set and erases the subscription entry. A
+ // resource entry is dropped entirely once it has neither subscribers
+ // nor a value, preserving the resource_map_ invariant documented on
+ // that member. NOTE(review): this mutates state guarded by ads_mu_ —
+ // callers appear to hold that lock; confirm before reuse.
+ void ProcessUnsubscriptions(
+ const TString& resource_type,
+ const std::set<TString>& resources_in_current_request,
+ SubscriptionNameMap* subscription_name_map,
+ ResourceNameMap* resource_name_map) {
+ for (auto it = subscription_name_map->begin();
+ it != subscription_name_map->end();) {
+ const TString& resource_name = it->first;
+ SubscriptionState& subscription_state = it->second;
+ if (resources_in_current_request.find(resource_name) !=
+ resources_in_current_request.end()) {
+ // Still subscribed: keep the entry.
+ ++it;
+ continue;
+ }
+ gpr_log(GPR_INFO, "ADS[%p]: Unsubscribe to type=%s name=%s state=%p",
+ this, resource_type.c_str(), resource_name.c_str(),
+ &subscription_state);
+ auto resource_it = resource_name_map->find(resource_name);
+ GPR_ASSERT(resource_it != resource_name_map->end());
+ auto& resource_state = resource_it->second;
+ resource_state.subscriptions.erase(&subscription_state);
+ // Drop the resource entry once nobody subscribes and no value exists.
+ if (resource_state.subscriptions.empty() &&
+ !resource_state.resource.has_value()) {
+ resource_name_map->erase(resource_it);
+ }
+ // erase() returns the next valid iterator.
+ it = subscription_name_map->erase(it);
+ }
+ }
+
+  // Records a connected client peer (thread-safe).
+  void AddClient(const TString& client) {
+    grpc_core::MutexLock guard(&clients_mu_);
+    clients_.emplace(client);
+  }
+
+  // Forgets a client peer once its call has finished (thread-safe).
+  void RemoveClient(const TString& client) {
+    grpc_core::MutexLock guard(&clients_mu_);
+    clients_.erase(client);
+  }
+
+ RpcService<::envoy::service::discovery::v2::AggregatedDiscoveryService,
+ ::envoy::api::v2::DiscoveryRequest,
+ ::envoy::api::v2::DiscoveryResponse>
+ v2_rpc_service_;
+ RpcService<::envoy::service::discovery::v3::AggregatedDiscoveryService,
+ ::envoy::service::discovery::v3::DiscoveryRequest,
+ ::envoy::service::discovery::v3::DiscoveryResponse>
+ v3_rpc_service_;
+
+ std::atomic_bool seen_v2_client_{false};
+ std::atomic_bool seen_v3_client_{false};
+
+ grpc_core::CondVar ads_cond_;
+ // Protect the members below.
+ grpc_core::Mutex ads_mu_;
+ bool ads_done_ = false;
+ Listener default_listener_;
+ RouteConfiguration default_route_config_;
+ Cluster default_cluster_;
+ std::map<TString /* type_url */, ResponseState>
+ resource_type_response_state_;
+ std::set<TString /*resource_type*/> resource_types_to_ignore_;
+ // An instance data member containing the current state of all resources.
+ // Note that an entry will exist whenever either of the following is true:
+ // - The resource exists (i.e., has been created by SetResource() and has not
+ // yet been destroyed by UnsetResource()).
+ // - There is at least one subscription for the resource.
+ ResourceMap resource_map_;
+
+ grpc_core::Mutex clients_mu_;
+ std::set<TString> clients_;
+};
+
+// Fake LRS (Load Reporting Service) server. Serves both the v2 and v3
+// load-stats protocols through two nested RpcService instances that share
+// this object's state. Tests retrieve the load reports collected from the
+// client via WaitForLoadReport().
+class LrsServiceImpl : public std::enable_shared_from_this<LrsServiceImpl> {
+ public:
+ explicit LrsServiceImpl(int client_load_reporting_interval_seconds)
+ : v2_rpc_service_(this),
+ v3_rpc_service_(this),
+ client_load_reporting_interval_seconds_(
+ client_load_reporting_interval_seconds),
+ cluster_names_({kDefaultClusterName}) {}
+
+ ::envoy::service::load_stats::v2::LoadReportingService::Service*
+ v2_rpc_service() {
+ return &v2_rpc_service_;
+ }
+
+ ::envoy::service::load_stats::v3::LoadReportingService::Service*
+ v3_rpc_service() {
+ return &v3_rpc_service_;
+ }
+
+ // Request/response counts aggregated across both protocol versions.
+ size_t request_count() {
+ return v2_rpc_service_.request_count() + v3_rpc_service_.request_count();
+ }
+
+ size_t response_count() {
+ return v2_rpc_service_.response_count() + v3_rpc_service_.response_count();
+ }
+
+ // Must be called before the LRS call is started.
+ void set_send_all_clusters(bool send_all_clusters) {
+ send_all_clusters_ = send_all_clusters;
+ }
+ // Must be called before the LRS call is started.
+ void set_cluster_names(const std::set<TString>& cluster_names) {
+ cluster_names_ = cluster_names;
+ }
+
+ // Resets per-run state so the service can be reused across test runs.
+ void Start() {
+ lrs_done_ = false;
+ result_queue_.clear();
+ }
+
+ // Releases any StreamLoadStats call blocked in its final wait.
+ void Shutdown() {
+ {
+ grpc_core::MutexLock lock(&lrs_mu_);
+ NotifyDoneWithLrsCallLocked();
+ }
+ gpr_log(GPR_INFO, "LRS[%p]: shut down", this);
+ }
+
+ // Blocks until at least one load report has been received, then pops
+ // and returns the oldest one. The CondVar lives on the stack; the
+ // load_report_cond_ pointer lets StreamLoadStats signal it while it is
+ // registered.
+ std::vector<ClientStats> WaitForLoadReport() {
+ grpc_core::MutexLock lock(&load_report_mu_);
+ grpc_core::CondVar cv;
+ if (result_queue_.empty()) {
+ load_report_cond_ = &cv;
+ load_report_cond_->WaitUntil(&load_report_mu_,
+ [this] { return !result_queue_.empty(); });
+ load_report_cond_ = nullptr;
+ }
+ std::vector<ClientStats> result = std::move(result_queue_.front());
+ result_queue_.pop_front();
+ return result;
+ }
+
+ // Same as Shutdown()'s first step, but without logging; lets a test end
+ // the in-flight LRS call explicitly.
+ void NotifyDoneWithLrsCall() {
+ grpc_core::MutexLock lock(&lrs_mu_);
+ NotifyDoneWithLrsCallLocked();
+ }
+
+ private:
+ // One concrete LRS service (v2 or v3); all state lives in parent_.
+ template <class RpcApi, class LoadStatsRequest, class LoadStatsResponse>
+ class RpcService : public CountedService<typename RpcApi::Service> {
+ public:
+ using Stream = ServerReaderWriter<LoadStatsResponse, LoadStatsRequest>;
+
+ explicit RpcService(LrsServiceImpl* parent) : parent_(parent) {}
+
+ // Handles one LRS stream: reads the initial request, replies with the
+ // cluster list and reporting interval, then queues every incoming load
+ // report until the client closes or the test signals done.
+ Status StreamLoadStats(ServerContext* /*context*/,
+ Stream* stream) override {
+ gpr_log(GPR_INFO, "LRS[%p]: StreamLoadStats starts", this);
+ EXPECT_GT(parent_->client_load_reporting_interval_seconds_, 0);
+ // Take a reference of the LrsServiceImpl object, reference will go
+ // out of scope after this method exits.
+ std::shared_ptr<LrsServiceImpl> lrs_service_impl =
+ parent_->shared_from_this();
+ // Read initial request.
+ LoadStatsRequest request;
+ if (stream->Read(&request)) {
+ CountedService<typename RpcApi::Service>::IncreaseRequestCount();
+ // Verify client features.
+ EXPECT_THAT(
+ request.node().client_features(),
+ ::testing::Contains("envoy.lrs.supports_send_all_clusters"));
+ // Send initial response.
+ LoadStatsResponse response;
+ if (parent_->send_all_clusters_) {
+ response.set_send_all_clusters(true);
+ } else {
+ for (const TString& cluster_name : parent_->cluster_names_) {
+ response.add_clusters(cluster_name);
+ }
+ }
+ response.mutable_load_reporting_interval()->set_seconds(
+ parent_->client_load_reporting_interval_seconds_);
+ stream->Write(response);
+ CountedService<typename RpcApi::Service>::IncreaseResponseCount();
+ // Wait for report.
+ request.Clear();
+ while (stream->Read(&request)) {
+ gpr_log(GPR_INFO, "LRS[%p]: received client load report message: %s",
+ this, request.DebugString().c_str());
+ std::vector<ClientStats> stats;
+ for (const auto& cluster_stats : request.cluster_stats()) {
+ stats.emplace_back(cluster_stats);
+ }
+ // Queue the report and wake any WaitForLoadReport() caller.
+ grpc_core::MutexLock lock(&parent_->load_report_mu_);
+ parent_->result_queue_.emplace_back(std::move(stats));
+ if (parent_->load_report_cond_ != nullptr) {
+ parent_->load_report_cond_->Signal();
+ }
+ }
+ // Wait until notified done.
+ grpc_core::MutexLock lock(&parent_->lrs_mu_);
+ parent_->lrs_cv_.WaitUntil(&parent_->lrs_mu_,
+ [this] { return parent_->lrs_done_; });
+ }
+ gpr_log(GPR_INFO, "LRS[%p]: StreamLoadStats done", this);
+ return Status::OK;
+ }
+
+ private:
+ LrsServiceImpl* parent_;
+ };
+
+ // Caller must hold lrs_mu_.
+ void NotifyDoneWithLrsCallLocked() {
+ if (!lrs_done_) {
+ lrs_done_ = true;
+ lrs_cv_.Broadcast();
+ }
+ }
+
+ RpcService<::envoy::service::load_stats::v2::LoadReportingService,
+ ::envoy::service::load_stats::v2::LoadStatsRequest,
+ ::envoy::service::load_stats::v2::LoadStatsResponse>
+ v2_rpc_service_;
+ RpcService<::envoy::service::load_stats::v3::LoadReportingService,
+ ::envoy::service::load_stats::v3::LoadStatsRequest,
+ ::envoy::service::load_stats::v3::LoadStatsResponse>
+ v3_rpc_service_;
+
+ const int client_load_reporting_interval_seconds_;
+ bool send_all_clusters_ = false;
+ std::set<TString> cluster_names_;
+
+ grpc_core::CondVar lrs_cv_;
+ grpc_core::Mutex lrs_mu_; // Protects lrs_done_.
+ bool lrs_done_ = false;
+
+ grpc_core::Mutex load_report_mu_; // Protects the members below.
+ grpc_core::CondVar* load_report_cond_ = nullptr;
+ std::deque<std::vector<ClientStats>> result_queue_;
+};
+
+// Describes one parameterization of the end-to-end suite: which resolver
+// is used, whether load reporting is enabled, whether RDS is exercised,
+// and whether the v2 xDS API is used.
+class TestType {
+ public:
+  TestType(bool use_xds_resolver, bool enable_load_reporting,
+           bool enable_rds_testing = false, bool use_v2 = false)
+      : use_xds_resolver_(use_xds_resolver),
+        enable_load_reporting_(enable_load_reporting),
+        enable_rds_testing_(enable_rds_testing),
+        use_v2_(use_v2) {}
+
+  bool use_xds_resolver() const { return use_xds_resolver_; }
+  bool enable_load_reporting() const { return enable_load_reporting_; }
+  bool enable_rds_testing() const { return enable_rds_testing_; }
+  bool use_v2() const { return use_v2_; }
+
+  // Builds the human-readable instance name used in gtest parameterized
+  // test names, e.g. "XdsResolverV3WithLoadReportingRds".
+  TString AsString() const {
+    TString name = use_xds_resolver_ ? "XdsResolver" : "FakeResolver";
+    name += use_v2_ ? "V2" : "V3";
+    if (enable_load_reporting_) name += "WithLoadReporting";
+    if (enable_rds_testing_) name += "Rds";
+    return name;
+  }
+
+ private:
+  const bool use_xds_resolver_;
+  const bool enable_load_reporting_;
+  const bool enable_rds_testing_;
+  const bool use_v2_;
+};
+
+// Base fixture for all xDS end-to-end tests. Spins up num_backends echo
+// servers and num_balancers fake xDS servers (ADS + LRS), wires the client
+// channel to them via either the xds or the fake resolver (per TestType),
+// and provides RPC-sending and resolution helpers for subclasses.
+class XdsEnd2endTest : public ::testing::TestWithParam<TestType> {
+ protected:
+ XdsEnd2endTest(size_t num_backends, size_t num_balancers,
+ int client_load_reporting_interval_seconds = 100)
+ : num_backends_(num_backends),
+ num_balancers_(num_balancers),
+ client_load_reporting_interval_seconds_(
+ client_load_reporting_interval_seconds) {}
+
+ static void SetUpTestCase() {
+ // Make the backup poller poll very frequently in order to pick up
+ // updates from all the subchannels's FDs.
+ GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1);
+#if TARGET_OS_IPHONE
+ // Workaround Apple CFStream bug
+ gpr_setenv("grpc_cfstream", "0");
+#endif
+ grpc_init();
+ }
+
+ static void TearDownTestCase() { grpc_shutdown(); }
+
+ // Per-test setup: selects the bootstrap file for the xDS version under
+ // test, injects fake resolvers, starts backends and balancers, and
+ // creates fresh stubs.
+ void SetUp() override {
+ gpr_setenv("GRPC_XDS_EXPERIMENTAL_V3_SUPPORT", "true");
+ gpr_setenv("GRPC_XDS_BOOTSTRAP",
+ GetParam().use_v2() ? g_bootstrap_file_v2 : g_bootstrap_file_v3);
+ g_port_saver->Reset();
+ response_generator_ =
+ grpc_core::MakeRefCounted<grpc_core::FakeResolverResponseGenerator>();
+ // Inject xDS channel response generator.
+ lb_channel_response_generator_ =
+ grpc_core::MakeRefCounted<grpc_core::FakeResolverResponseGenerator>();
+ xds_channel_args_to_add_.emplace_back(
+ grpc_core::FakeResolverResponseGenerator::MakeChannelArg(
+ lb_channel_response_generator_.get()));
+ if (xds_resource_does_not_exist_timeout_ms_ > 0) {
+ xds_channel_args_to_add_.emplace_back(grpc_channel_arg_integer_create(
+ const_cast<char*>(GRPC_ARG_XDS_RESOURCE_DOES_NOT_EXIST_TIMEOUT_MS),
+ xds_resource_does_not_exist_timeout_ms_));
+ }
+ xds_channel_args_.num_args = xds_channel_args_to_add_.size();
+ xds_channel_args_.args = xds_channel_args_to_add_.data();
+ grpc_core::internal::SetXdsChannelArgsForTest(&xds_channel_args_);
+ // Make sure each test creates a new XdsClient instance rather than
+ // reusing the one from the previous test. This avoids spurious failures
+ // caused when a load reporting test runs after a non-load reporting test
+ // and the XdsClient is still talking to the old LRS server, which fails
+ // because it's not expecting the client to connect. It also
+ // ensures that each test can independently set the global channel
+ // args for the xDS channel.
+ grpc_core::internal::UnsetGlobalXdsClientForTest();
+ // Start the backends.
+ for (size_t i = 0; i < num_backends_; ++i) {
+ backends_.emplace_back(new BackendServerThread);
+ backends_.back()->Start();
+ }
+ // Start the load balancers.
+ for (size_t i = 0; i < num_balancers_; ++i) {
+ balancers_.emplace_back(
+ new BalancerServerThread(GetParam().enable_load_reporting()
+ ? client_load_reporting_interval_seconds_
+ : 0));
+ balancers_.back()->Start();
+ if (GetParam().enable_rds_testing()) {
+ balancers_[i]->ads_service()->SetLdsToUseDynamicRds();
+ }
+ }
+ ResetStub();
+ }
+
+ // EDS service name to use: a dedicated name with the xds resolver,
+ // otherwise the server name itself.
+ const char* DefaultEdsServiceName() const {
+ return GetParam().use_xds_resolver() ? kDefaultEdsServiceName : kServerName;
+ }
+
+ void TearDown() override {
+ ShutdownAllBackends();
+ for (auto& balancer : balancers_) balancer->Shutdown();
+ // Clear global xDS channel args, since they will go out of scope
+ // when this test object is destroyed.
+ grpc_core::internal::SetXdsChannelArgsForTest(nullptr);
+ }
+
+ // Backend lifecycle helpers.
+ void StartAllBackends() {
+ for (auto& backend : backends_) backend->Start();
+ }
+
+ void StartBackend(size_t index) { backends_[index]->Start(); }
+
+ void ShutdownAllBackends() {
+ for (auto& backend : backends_) backend->Shutdown();
+ }
+
+ void ShutdownBackend(size_t index) { backends_[index]->Shutdown(); }
+
+ // Recreates the channel (optionally with a failover timeout) and the
+ // stubs for all three echo services.
+ void ResetStub(int failover_timeout = 0) {
+ channel_ = CreateChannel(failover_timeout);
+ stub_ = grpc::testing::EchoTestService::NewStub(channel_);
+ stub1_ = grpc::testing::EchoTest1Service::NewStub(channel_);
+ stub2_ = grpc::testing::EchoTest2Service::NewStub(channel_);
+ }
+
+ // Creates a channel to server_name via the resolver selected by the
+ // test parameter, using fake transport security plus test call creds.
+ std::shared_ptr<Channel> CreateChannel(
+ int failover_timeout = 0, const char* server_name = kServerName) {
+ ChannelArguments args;
+ if (failover_timeout > 0) {
+ args.SetInt(GRPC_ARG_PRIORITY_FAILOVER_TIMEOUT_MS, failover_timeout);
+ }
+ // If the parent channel is using the fake resolver, we inject the
+ // response generator here.
+ if (!GetParam().use_xds_resolver()) {
+ args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR,
+ response_generator_.get());
+ }
+ TString uri = y_absl::StrCat(
+ GetParam().use_xds_resolver() ? "xds" : "fake", ":///", server_name);
+ // TODO(dgq): templatize tests to run everything using both secure and
+ // insecure channel credentials.
+ grpc_channel_credentials* channel_creds =
+ grpc_fake_transport_security_credentials_create();
+ grpc_call_credentials* call_creds = grpc_md_only_test_credentials_create(
+ g_kCallCredsMdKey, g_kCallCredsMdValue, false);
+ std::shared_ptr<ChannelCredentials> creds(
+ new SecureChannelCredentials(grpc_composite_channel_credentials_create(
+ channel_creds, call_creds, nullptr)));
+ call_creds->Unref();
+ channel_creds->Unref();
+ return ::grpc::CreateCustomChannel(uri, creds, args);
+ }
+
+ // Which of the three registered echo services an RPC targets.
+ enum RpcService {
+ SERVICE_ECHO,
+ SERVICE_ECHO1,
+ SERVICE_ECHO2,
+ };
+
+ // Which method on the chosen service an RPC invokes.
+ enum RpcMethod {
+ METHOD_ECHO,
+ METHOD_ECHO1,
+ METHOD_ECHO2,
+ };
+
+ // Options for a single test RPC; the set_* methods chain so options can
+ // be built inline at the call site.
+ struct RpcOptions {
+ RpcService service = SERVICE_ECHO;
+ RpcMethod method = METHOD_ECHO;
+ int timeout_ms = 1000;
+ bool wait_for_ready = false;
+ bool server_fail = false;
+ std::vector<std::pair<TString, TString>> metadata;
+
+ RpcOptions() {}
+
+ RpcOptions& set_rpc_service(RpcService rpc_service) {
+ service = rpc_service;
+ return *this;
+ }
+
+ RpcOptions& set_rpc_method(RpcMethod rpc_method) {
+ method = rpc_method;
+ return *this;
+ }
+
+ RpcOptions& set_timeout_ms(int rpc_timeout_ms) {
+ timeout_ms = rpc_timeout_ms;
+ return *this;
+ }
+
+ RpcOptions& set_wait_for_ready(bool rpc_wait_for_ready) {
+ wait_for_ready = rpc_wait_for_ready;
+ return *this;
+ }
+
+ RpcOptions& set_server_fail(bool rpc_server_fail) {
+ server_fail = rpc_server_fail;
+ return *this;
+ }
+
+ RpcOptions& set_metadata(
+ std::vector<std::pair<TString, TString>> rpc_metadata) {
+ metadata = rpc_metadata;
+ return *this;
+ }
+ };
+
+ // Dispatches to the method selected in rpc_options on the given stub.
+ // NOTE(review): the switch has no default case, so control would fall
+ // off the end (UB) for an out-of-range enum value — relies on the enum
+ // being exhaustive; confirm compilers warn on this.
+ template <typename Stub>
+ Status SendRpcMethod(Stub* stub, const RpcOptions& rpc_options,
+ ClientContext* context, EchoRequest& request,
+ EchoResponse* response) {
+ switch (rpc_options.method) {
+ case METHOD_ECHO:
+ return (*stub)->Echo(context, request, response);
+ case METHOD_ECHO1:
+ return (*stub)->Echo1(context, request, response);
+ case METHOD_ECHO2:
+ return (*stub)->Echo2(context, request, response);
+ }
+ }
+
+ // Resets request counters on backends [start_index, stop_index);
+ // stop_index == 0 means "through the last backend".
+ void ResetBackendCounters(size_t start_index = 0, size_t stop_index = 0) {
+ if (stop_index == 0) stop_index = backends_.size();
+ for (size_t i = start_index; i < stop_index; ++i) {
+ backends_[i]->backend_service()->ResetCounters();
+ backends_[i]->backend_service1()->ResetCounters();
+ backends_[i]->backend_service2()->ResetCounters();
+ }
+ }
+
+ // Returns true iff every backend in [start_index, stop_index) has served
+ // at least one request for the service selected in rpc_options.
+ bool SeenAllBackends(size_t start_index = 0, size_t stop_index = 0,
+ const RpcOptions& rpc_options = RpcOptions()) {
+ if (stop_index == 0) stop_index = backends_.size();
+ for (size_t i = start_index; i < stop_index; ++i) {
+ switch (rpc_options.service) {
+ case SERVICE_ECHO:
+ if (backends_[i]->backend_service()->request_count() == 0)
+ return false;
+ break;
+ case SERVICE_ECHO1:
+ if (backends_[i]->backend_service1()->request_count() == 0)
+ return false;
+ break;
+ case SERVICE_ECHO2:
+ if (backends_[i]->backend_service2()->request_count() == 0)
+ return false;
+ break;
+ }
+ }
+ return true;
+ }
+
+ // Sends one RPC and tallies its outcome into the supplied counters;
+ // drops are distinguished by the LB policy's drop status message.
+ void SendRpcAndCount(int* num_total, int* num_ok, int* num_failure,
+ int* num_drops,
+ const RpcOptions& rpc_options = RpcOptions()) {
+ const Status status = SendRpc(rpc_options);
+ if (status.ok()) {
+ ++*num_ok;
+ } else {
+ if (status.error_message() == "Call dropped by load balancing policy") {
+ ++*num_drops;
+ } else {
+ ++*num_failure;
+ }
+ }
+ ++*num_total;
+ }
+
+ // Sends RPCs until every backend in the range has been hit at least
+ // once; returns (ok, failure, drop) counts of the warm-up traffic.
+ std::tuple<int, int, int> WaitForAllBackends(
+ size_t start_index = 0, size_t stop_index = 0, bool reset_counters = true,
+ const RpcOptions& rpc_options = RpcOptions(),
+ bool allow_failures = false) {
+ int num_ok = 0;
+ int num_failure = 0;
+ int num_drops = 0;
+ int num_total = 0;
+ while (!SeenAllBackends(start_index, stop_index, rpc_options)) {
+ SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops,
+ rpc_options);
+ }
+ if (reset_counters) ResetBackendCounters();
+ gpr_log(GPR_INFO,
+ "Performed %d warm up requests against the backends. "
+ "%d succeeded, %d failed, %d dropped.",
+ num_total, num_ok, num_failure, num_drops);
+ if (!allow_failures) EXPECT_EQ(num_failure, 0);
+ return std::make_tuple(num_ok, num_failure, num_drops);
+ }
+
+ // Sends RPCs until the given backend has served at least one; loops
+ // indefinitely if traffic never reaches it.
+ void WaitForBackend(size_t backend_idx, bool reset_counters = true,
+ bool require_success = false) {
+ gpr_log(GPR_INFO, "========= WAITING FOR BACKEND %lu ==========",
+ static_cast<unsigned long>(backend_idx));
+ do {
+ Status status = SendRpc();
+ if (require_success) {
+ EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+ << " message=" << status.error_message();
+ }
+ } while (backends_[backend_idx]->backend_service()->request_count() == 0);
+ if (reset_counters) ResetBackendCounters();
+ gpr_log(GPR_INFO, "========= BACKEND %lu READY ==========",
+ static_cast<unsigned long>(backend_idx));
+ }
+
+ // Converts localhost ports into a resolver-ready address list.
+ grpc_core::ServerAddressList CreateAddressListFromPortList(
+ const std::vector<int>& ports) {
+ grpc_core::ServerAddressList addresses;
+ for (int port : ports) {
+ TString lb_uri_str = y_absl::StrCat("ipv4:127.0.0.1:", port);
+ grpc_uri* lb_uri = grpc_uri_parse(lb_uri_str.c_str(), true);
+ GPR_ASSERT(lb_uri != nullptr);
+ grpc_resolved_address address;
+ GPR_ASSERT(grpc_parse_uri(lb_uri, &address));
+ addresses.emplace_back(address.addr, address.len, nullptr);
+ grpc_uri_destroy(lb_uri);
+ }
+ return addresses;
+ }
+
+ // Pushes a resolution result (addresses + service config) to the fake
+ // resolver used by the parent channel.
+ void SetNextResolution(const std::vector<int>& ports) {
+ if (GetParam().use_xds_resolver()) return; // Not used with xds resolver.
+ grpc_core::ExecCtx exec_ctx;
+ grpc_core::Resolver::Result result;
+ result.addresses = CreateAddressListFromPortList(ports);
+ grpc_error* error = GRPC_ERROR_NONE;
+ const char* service_config_json =
+ GetParam().enable_load_reporting()
+ ? kDefaultServiceConfig
+ : kDefaultServiceConfigWithoutLoadReporting;
+ result.service_config =
+ grpc_core::ServiceConfig::Create(nullptr, service_config_json, &error);
+ ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
+ ASSERT_NE(result.service_config.get(), nullptr);
+ response_generator_->SetResponse(std::move(result));
+ }
+
+ // Points the xDS (LB) channel at every balancer this fixture started.
+ void SetNextResolutionForLbChannelAllBalancers(
+ const char* service_config_json = nullptr,
+ const char* expected_targets = nullptr) {
+ std::vector<int> ports;
+ for (size_t i = 0; i < balancers_.size(); ++i) {
+ ports.emplace_back(balancers_[i]->port());
+ }
+ SetNextResolutionForLbChannel(ports, service_config_json, expected_targets);
+ }
+
+ // Pushes a resolution result for the xDS (LB) channel; expected_targets,
+ // if set, is enforced by the fake security connector.
+ void SetNextResolutionForLbChannel(const std::vector<int>& ports,
+ const char* service_config_json = nullptr,
+ const char* expected_targets = nullptr) {
+ grpc_core::ExecCtx exec_ctx;
+ grpc_core::Resolver::Result result;
+ result.addresses = CreateAddressListFromPortList(ports);
+ if (service_config_json != nullptr) {
+ grpc_error* error = GRPC_ERROR_NONE;
+ result.service_config = grpc_core::ServiceConfig::Create(
+ nullptr, service_config_json, &error);
+ ASSERT_NE(result.service_config.get(), nullptr);
+ ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
+ }
+ if (expected_targets != nullptr) {
+ grpc_arg expected_targets_arg = grpc_channel_arg_string_create(
+ const_cast<char*>(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS),
+ const_cast<char*>(expected_targets));
+ result.args =
+ grpc_channel_args_copy_and_add(nullptr, &expected_targets_arg, 1);
+ }
+ lb_channel_response_generator_->SetResponse(std::move(result));
+ }
+
+ // Queues the result the fake resolver returns on re-resolution.
+ void SetNextReresolutionResponse(const std::vector<int>& ports) {
+ grpc_core::ExecCtx exec_ctx;
+ grpc_core::Resolver::Result result;
+ result.addresses = CreateAddressListFromPortList(ports);
+ response_generator_->SetReresolutionResponse(std::move(result));
+ }
+
+ // Ports of backends [start_index, stop_index); stop_index == 0 means
+ // "through the last backend".
+ const std::vector<int> GetBackendPorts(size_t start_index = 0,
+ size_t stop_index = 0) const {
+ if (stop_index == 0) stop_index = backends_.size();
+ std::vector<int> backend_ports;
+ for (size_t i = start_index; i < stop_index; ++i) {
+ backend_ports.push_back(backends_[i]->port());
+ }
+ return backend_ports;
+ }
+
+ // Sends a single echo RPC per rpc_options; if response is null a
+ // throwaway response object is used.
+ Status SendRpc(const RpcOptions& rpc_options = RpcOptions(),
+ EchoResponse* response = nullptr) {
+ const bool local_response = (response == nullptr);
+ if (local_response) response = new EchoResponse;
+ EchoRequest request;
+ ClientContext context;
+ for (const auto& metadata : rpc_options.metadata) {
+ context.AddMetadata(metadata.first, metadata.second);
+ }
+ if (rpc_options.timeout_ms != 0) {
+ context.set_deadline(
+ grpc_timeout_milliseconds_to_deadline(rpc_options.timeout_ms));
+ }
+ if (rpc_options.wait_for_ready) context.set_wait_for_ready(true);
+ request.set_message(kRequestMessage);
+ if (rpc_options.server_fail) {
+ request.mutable_param()->mutable_expected_error()->set_code(
+ GRPC_STATUS_FAILED_PRECONDITION);
+ }
+ Status status;
+ switch (rpc_options.service) {
+ case SERVICE_ECHO:
+ status =
+ SendRpcMethod(&stub_, rpc_options, &context, request, response);
+ break;
+ case SERVICE_ECHO1:
+ status =
+ SendRpcMethod(&stub1_, rpc_options, &context, request, response);
+ break;
+ case SERVICE_ECHO2:
+ status =
+ SendRpcMethod(&stub2_, rpc_options, &context, request, response);
+ break;
+ }
+ if (local_response) delete response;
+ return status;
+ }
+
+ // Sends `times` RPCs, asserting each succeeds and echoes the request.
+ void CheckRpcSendOk(const size_t times = 1,
+ const RpcOptions& rpc_options = RpcOptions()) {
+ for (size_t i = 0; i < times; ++i) {
+ EchoResponse response;
+ const Status status = SendRpc(rpc_options, &response);
+ EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+ << " message=" << status.error_message();
+ EXPECT_EQ(response.message(), kRequestMessage);
+ }
+ }
+
+ // Sends `times` RPCs, asserting each fails.
+ void CheckRpcSendFailure(const size_t times = 1,
+ const RpcOptions& rpc_options = RpcOptions()) {
+ for (size_t i = 0; i < times; ++i) {
+ const Status status = SendRpc(rpc_options);
+ EXPECT_FALSE(status.ok());
+ }
+ }
+
+ // Installs a route config either as an RDS resource or inlined into the
+ // listener, depending on the RDS test parameter.
+ void SetRouteConfiguration(int idx, const RouteConfiguration& route_config) {
+ if (GetParam().enable_rds_testing()) {
+ balancers_[idx]->ads_service()->SetRdsResource(route_config);
+ } else {
+ balancers_[idx]->ads_service()->SetLdsResource(
+ AdsServiceImpl::BuildListener(route_config));
+ }
+ }
+
+ // Response state of whichever resource type carries the route config in
+ // this parameterization (RDS or LDS).
+ AdsServiceImpl::ResponseState RouteConfigurationResponseState(int idx) const {
+ AdsServiceImpl* ads_service = balancers_[idx]->ads_service();
+ if (GetParam().enable_rds_testing()) {
+ return ads_service->rds_response_state();
+ }
+ return ads_service->lds_response_state();
+ }
+
+ public:
+ // This method could benefit test subclasses; to make it accessible
+ // via bind with a qualified name, it needs to be public.
+ void SetEdsResourceWithDelay(size_t i,
+ const ClusterLoadAssignment& assignment,
+ int delay_ms) {
+ GPR_ASSERT(delay_ms > 0);
+ gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms));
+ balancers_[i]->ads_service()->SetEdsResource(assignment);
+ }
+
+ protected:
+ // Runs one gRPC server (backend or balancer) on its own thread, with
+ // startup/shutdown handshaking between the test thread and the server
+ // thread.
+ class ServerThread {
+ public:
+ ServerThread() : port_(g_port_saver->GetPort()) {}
+ virtual ~ServerThread(){};
+
+ void Start() {
+ gpr_log(GPR_INFO, "starting %s server on port %d", Type(), port_);
+ GPR_ASSERT(!running_);
+ running_ = true;
+ StartAllServices();
+ grpc_core::Mutex mu;
+ // We need to acquire the lock here in order to prevent the notify_one
+ // by ServerThread::Serve from firing before the wait below is hit.
+ grpc_core::MutexLock lock(&mu);
+ grpc_core::CondVar cond;
+ thread_.reset(
+ new std::thread(std::bind(&ServerThread::Serve, this, &mu, &cond)));
+ cond.Wait(&mu);
+ gpr_log(GPR_INFO, "%s server startup complete", Type());
+ }
+
+ void Serve(grpc_core::Mutex* mu, grpc_core::CondVar* cond) {
+ // We need to acquire the lock here in order to prevent the notify_one
+ // below from firing before its corresponding wait is executed.
+ grpc_core::MutexLock lock(mu);
+ std::ostringstream server_address;
+ server_address << "localhost:" << port_;
+ ServerBuilder builder;
+ std::shared_ptr<ServerCredentials> creds(new SecureServerCredentials(
+ grpc_fake_transport_security_server_credentials_create()));
+ builder.AddListeningPort(server_address.str(), creds);
+ RegisterAllServices(&builder);
+ server_ = builder.BuildAndStart();
+ cond->Signal();
+ }
+
+ void Shutdown() {
+ if (!running_) return;
+ gpr_log(GPR_INFO, "%s about to shutdown", Type());
+ ShutdownAllServices();
+ server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
+ thread_->join();
+ gpr_log(GPR_INFO, "%s shutdown completed", Type());
+ running_ = false;
+ }
+
+ int port() const { return port_; }
+
+ private:
+ virtual void RegisterAllServices(ServerBuilder* builder) = 0;
+ virtual void StartAllServices() = 0;
+ virtual void ShutdownAllServices() = 0;
+
+ virtual const char* Type() = 0;
+
+ const int port_;
+ std::unique_ptr<Server> server_;
+ std::unique_ptr<std::thread> thread_;
+ bool running_ = false;
+ };
+
+ // A backend server hosting all three echo services.
+ class BackendServerThread : public ServerThread {
+ public:
+ BackendServiceImpl<::grpc::testing::EchoTestService::Service>*
+ backend_service() {
+ return &backend_service_;
+ }
+ BackendServiceImpl<::grpc::testing::EchoTest1Service::Service>*
+ backend_service1() {
+ return &backend_service1_;
+ }
+ BackendServiceImpl<::grpc::testing::EchoTest2Service::Service>*
+ backend_service2() {
+ return &backend_service2_;
+ }
+
+ private:
+ void RegisterAllServices(ServerBuilder* builder) override {
+ builder->RegisterService(&backend_service_);
+ builder->RegisterService(&backend_service1_);
+ builder->RegisterService(&backend_service2_);
+ }
+
+ void StartAllServices() override {
+ backend_service_.Start();
+ backend_service1_.Start();
+ backend_service2_.Start();
+ }
+
+ void ShutdownAllServices() override {
+ backend_service_.Shutdown();
+ backend_service1_.Shutdown();
+ backend_service2_.Shutdown();
+ }
+
+ const char* Type() override { return "Backend"; }
+
+ BackendServiceImpl<::grpc::testing::EchoTestService::Service>
+ backend_service_;
+ BackendServiceImpl<::grpc::testing::EchoTest1Service::Service>
+ backend_service1_;
+ BackendServiceImpl<::grpc::testing::EchoTest2Service::Service>
+ backend_service2_;
+ };
+
+ // A balancer server hosting the fake ADS and LRS services (v2 and v3).
+ class BalancerServerThread : public ServerThread {
+ public:
+ explicit BalancerServerThread(int client_load_reporting_interval = 0)
+ : ads_service_(new AdsServiceImpl(client_load_reporting_interval > 0)),
+ lrs_service_(new LrsServiceImpl(client_load_reporting_interval)) {}
+
+ AdsServiceImpl* ads_service() { return ads_service_.get(); }
+ LrsServiceImpl* lrs_service() { return lrs_service_.get(); }
+
+ private:
+ void RegisterAllServices(ServerBuilder* builder) override {
+ builder->RegisterService(ads_service_->v2_rpc_service());
+ builder->RegisterService(ads_service_->v3_rpc_service());
+ builder->RegisterService(lrs_service_->v2_rpc_service());
+ builder->RegisterService(lrs_service_->v3_rpc_service());
+ }
+
+ void StartAllServices() override {
+ ads_service_->Start();
+ lrs_service_->Start();
+ }
+
+ void ShutdownAllServices() override {
+ ads_service_->Shutdown();
+ lrs_service_->Shutdown();
+ }
+
+ const char* Type() override { return "Balancer"; }
+
+ std::shared_ptr<AdsServiceImpl> ads_service_;
+ std::shared_ptr<LrsServiceImpl> lrs_service_;
+ };
+
+ // Fixture state: server threads, channel/stubs, and the fake-resolver
+ // plumbing for both the parent channel and the xDS channel.
+ const size_t num_backends_;
+ const size_t num_balancers_;
+ const int client_load_reporting_interval_seconds_;
+ std::shared_ptr<Channel> channel_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<grpc::testing::EchoTest1Service::Stub> stub1_;
+ std::unique_ptr<grpc::testing::EchoTest2Service::Stub> stub2_;
+ std::vector<std::unique_ptr<BackendServerThread>> backends_;
+ std::vector<std::unique_ptr<BalancerServerThread>> balancers_;
+ grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator>
+ response_generator_;
+ grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator>
+ lb_channel_response_generator_;
+ int xds_resource_does_not_exist_timeout_ms_ = 0;
+ y_absl::InlinedVector<grpc_arg, 2> xds_channel_args_to_add_;
+ grpc_channel_args xds_channel_args_;
+};
+
// Fixture used by most tests below: 4 backends behind 1 balancer, with
// client load reporting disabled (default interval of 0).
class BasicTest : public XdsEnd2endTest {
 public:
  BasicTest() : XdsEnd2endTest(4, 1) {}
};
+
+// Tests that the balancer sends the correct response to the client, and the
+// client sends RPCs to the backends using the default child policy.
+TEST_P(BasicTest, Vanilla) {
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ const size_t kNumRpcsPerAddress = 100;
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts()},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+ // Make sure that trying to connect works without a call.
+ channel_->GetState(true /* try_to_connect */);
+ // We need to wait for all backends to come online.
+ WaitForAllBackends();
+ // Send kNumRpcsPerAddress RPCs per server.
+ CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
+ // Each backend should have gotten 100 requests.
+ for (size_t i = 0; i < backends_.size(); ++i) {
+ EXPECT_EQ(kNumRpcsPerAddress,
+ backends_[i]->backend_service()->request_count());
+ }
+ // Check LB policy name for the channel.
+ EXPECT_EQ((GetParam().use_xds_resolver() ? "xds_cluster_manager_experimental"
+ : "eds_experimental"),
+ channel_->GetLoadBalancingPolicyName());
+}
+
// Tests that endpoints reported with a non-serving health status (DRAINING)
// are excluded from the set of backends that receive traffic.
TEST_P(BasicTest, IgnoresUnhealthyEndpoints) {
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  const size_t kNumRpcsPerAddress = 100;
  // The health-status list covers only the first endpoint, so backend 0 is
  // marked DRAINING and the rest default to healthy.
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0",
       GetBackendPorts(),
       kDefaultLocalityWeight,
       kDefaultLocalityPriority,
       {HealthStatus::DRAINING}},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
  // Make sure that trying to connect works without a call.
  channel_->GetState(true /* try_to_connect */);
  // We need to wait for all *healthy* backends (1..N-1) to come online.
  WaitForAllBackends(/*start_index=*/1);
  // Send kNumRpcsPerAddress RPCs per healthy server.
  CheckRpcSendOk(kNumRpcsPerAddress * (num_backends_ - 1));
  // Each healthy backend should have gotten 100 requests; the DRAINING
  // backend (index 0) is intentionally excluded from the check.
  for (size_t i = 1; i < backends_.size(); ++i) {
    EXPECT_EQ(kNumRpcsPerAddress,
              backends_[i]->backend_service()->request_count());
  }
}
+
+// Tests that subchannel sharing works when the same backend is listed multiple
+// times.
// Tests that subchannel sharing works when the same backend is listed multiple
// times.
TEST_P(BasicTest, SameBackendListedMultipleTimes) {
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  // Same backend listed twice.
  std::vector<int> ports(2, backends_[0]->port());
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", ports},
  });
  const size_t kNumRpcsPerAddress = 10;
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
  // We need to wait for the backend to come online.
  WaitForBackend(0);
  // Send kNumRpcsPerAddress RPCs per server.
  CheckRpcSendOk(kNumRpcsPerAddress * ports.size());
  // Backend should have gotten 20 requests: both serverlist entries resolve
  // to the same backend, so it absorbs the full load.
  EXPECT_EQ(kNumRpcsPerAddress * ports.size(),
            backends_[0]->backend_service()->request_count());
  // And they should have come from a single client port, because of
  // subchannel sharing.
  EXPECT_EQ(1UL, backends_[0]->backend_service()->clients().size());
}
+
+// Tests that RPCs will be blocked until a non-empty serverlist is received.
+TEST_P(BasicTest, InitiallyEmptyServerlist) {
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor();
+ const int kCallDeadlineMs = kServerlistDelayMs * 2;
+ // First response is an empty serverlist, sent right away.
+ AdsServiceImpl::EdsResourceArgs::Locality empty_locality("locality0", {});
+ AdsServiceImpl::EdsResourceArgs args({
+ empty_locality,
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+ // Send non-empty serverlist only after kServerlistDelayMs.
+ args = AdsServiceImpl::EdsResourceArgs({
+ {"locality0", GetBackendPorts()},
+ });
+ std::thread delayed_resource_setter(
+ std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0,
+ AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()),
+ kServerlistDelayMs));
+ const auto t0 = system_clock::now();
+ // Client will block: LB will initially send empty serverlist.
+ CheckRpcSendOk(
+ 1, RpcOptions().set_timeout_ms(kCallDeadlineMs).set_wait_for_ready(true));
+ const auto ellapsed_ms =
+ std::chrono::duration_cast<std::chrono::milliseconds>(
+ system_clock::now() - t0);
+ // but eventually, the LB sends a serverlist update that allows the call to
+ // proceed. The call delay must be larger than the delay in sending the
+ // populated serverlist but under the call's deadline (which is enforced by
+ // the call's deadline).
+ EXPECT_GT(ellapsed_ms.count(), kServerlistDelayMs);
+ delayed_resource_setter.join();
+}
+
+// Tests that RPCs will fail with UNAVAILABLE instead of DEADLINE_EXCEEDED if
+// all the servers are unreachable.
// Tests that RPCs will fail with UNAVAILABLE instead of DEADLINE_EXCEEDED if
// all the servers are unreachable.
TEST_P(BasicTest, AllServersUnreachableFailFast) {
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  const size_t kNumUnreachableServers = 5;
  // Grab ports that are guaranteed to have no server listening on them.
  std::vector<int> ports;
  for (size_t i = 0; i < kNumUnreachableServers; ++i) {
    ports.push_back(g_port_saver->GetPort());
  }
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", ports},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
  const Status status = SendRpc();
  // The error shouldn't be DEADLINE_EXCEEDED: connection failures to every
  // endpoint should surface promptly as UNAVAILABLE.
  EXPECT_EQ(StatusCode::UNAVAILABLE, status.error_code());
}
+
+// Tests that RPCs fail when the backends are down, and will succeed again after
+// the backends are restarted.
// Tests that RPCs fail when the backends are down, and will succeed again after
// the backends are restarted.
TEST_P(BasicTest, BackendsRestart) {
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts()},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
  WaitForAllBackends();
  // Stop backends. RPCs should fail.
  ShutdownAllBackends();
  // Sending multiple failed requests instead of just one to ensure that the
  // client notices that all backends are down before we restart them. If we
  // didn't do this, then a single RPC could fail here due to the race condition
  // between the LB pick and the GOAWAY from the chosen backend being shut down,
  // which would not actually prove that the client noticed that all of the
  // backends are down. Then, when we send another request below (which we
  // expect to succeed), if the callbacks happen in the wrong order, the same
  // race condition could happen again due to the client not yet having noticed
  // that the backends were all down.
  CheckRpcSendFailure(num_backends_);
  // Restart all backends. RPCs should start succeeding again.
  // wait_for_ready lets the RPC ride out reconnection; 2s bounds the wait.
  CheckRpcSendOk(1, RpcOptions().set_timeout_ms(2000).set_wait_for_ready(true));
}
+
// Tests that an EDS update identical to the currently-applied resource is
// ignored and does not perturb the round_robin child policy.
TEST_P(BasicTest, IgnoresDuplicateUpdates) {
  const size_t kNumRpcsPerAddress = 100;
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts()},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
  // Wait for all backends to come online.
  WaitForAllBackends();
  // Send kNumRpcsPerAddress RPCs per server, but send an EDS update in
  // between. If the update is not ignored, this will cause the
  // round_robin policy to see an update, which will randomly reset its
  // position in the address list.
  for (size_t i = 0; i < kNumRpcsPerAddress; ++i) {
    CheckRpcSendOk(2);
    balancers_[0]->ads_service()->SetEdsResource(
        AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
    CheckRpcSendOk(2);
  }
  // Each backend should have gotten the right number of requests.
  // NOTE(review): the check starts at index 1, skipping backend 0 —
  // presumably deliberate slack for where the round-robin rotation begins;
  // confirm intent before extending the check to backend 0.
  for (size_t i = 1; i < backends_.size(); ++i) {
    EXPECT_EQ(kNumRpcsPerAddress,
              backends_[i]->backend_service()->request_count());
  }
}
+
// Tests in this section exercise behavior specific to the xds resolver
// (LDS/RDS/CDS resources, not just EDS); same fixture as BasicTest.
using XdsResolverOnlyTest = BasicTest;

// Tests switching over from one cluster to another.
TEST_P(XdsResolverOnlyTest, ChangeClusters) {
  const char* kNewClusterName = "new_cluster_name";
  const char* kNewEdsServiceName = "new_eds_service_name";
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  // Default cluster serves backends 0 and 1.
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts(0, 2)},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args));
  // We need to wait for all backends to come online.
  WaitForAllBackends(0, 2);
  // Populate new EDS resource, covering backends 2 and 3.
  AdsServiceImpl::EdsResourceArgs args2({
      {"locality0", GetBackendPorts(2, 4)},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args2, kNewEdsServiceName));
  // Populate new CDS resource pointing at the new EDS service name.
  Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
  new_cluster.set_name(kNewClusterName);
  new_cluster.mutable_eds_cluster_config()->set_service_name(
      kNewEdsServiceName);
  balancers_[0]->ads_service()->SetCdsResource(new_cluster);
  // Change RDS resource to point to new cluster.
  RouteConfiguration new_route_config =
      balancers_[0]->ads_service()->default_route_config();
  new_route_config.mutable_virtual_hosts(0)
      ->mutable_routes(0)
      ->mutable_route()
      ->set_cluster(kNewClusterName);
  Listener listener =
      balancers_[0]->ads_service()->BuildListener(new_route_config);
  balancers_[0]->ads_service()->SetLdsResource(listener);
  // Wait for all new backends to be used.
  std::tuple<int, int, int> counts = WaitForAllBackends(2, 4);
  // Make sure no RPCs failed in the transition (tuple element 1 is the
  // failure count).
  EXPECT_EQ(0, std::get<1>(counts));
}
+
+// Tests that we go into TRANSIENT_FAILURE if the Cluster disappears.
// Tests that we go into TRANSIENT_FAILURE if the Cluster disappears.
TEST_P(XdsResolverOnlyTest, ClusterRemoved) {
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts()},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args));
  // We need to wait for all backends to come online.
  WaitForAllBackends();
  // Unset CDS resource.
  balancers_[0]->ads_service()->UnsetResource(kCdsTypeUrl, kDefaultClusterName);
  // Wait for RPCs to start failing. Deliberately an empty-body spin: each
  // iteration issues one RPC and stops on the first failure.
  do {
  } while (SendRpc(RpcOptions(), nullptr).ok());
  // Make sure RPCs are still failing.
  CheckRpcSendFailure(1000);
  // Make sure we ACK'ed the update (removal must not be NACKed).
  EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state,
            AdsServiceImpl::ResponseState::ACKED);
}
+
+// Tests that we restart all xDS requests when we reestablish the ADS call.
// Tests that we restart all xDS requests when we reestablish the ADS call.
TEST_P(XdsResolverOnlyTest, RestartsRequestsUponReconnection) {
  // Use a separate RDS resource (not inlined in the Listener) so that the
  // RDS request stream is also exercised across the reconnect.
  balancers_[0]->ads_service()->SetLdsToUseDynamicRds();
  const char* kNewClusterName = "new_cluster_name";
  const char* kNewEdsServiceName = "new_eds_service_name";
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts(0, 2)},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args));
  // We need to wait for all backends to come online.
  WaitForAllBackends(0, 2);
  // Now shut down and restart the balancer. When the client
  // reconnects, it should automatically restart the requests for all
  // resource types.
  balancers_[0]->Shutdown();
  balancers_[0]->Start();
  // Make sure things are still working.
  CheckRpcSendOk(100);
  // Populate new EDS resource.
  AdsServiceImpl::EdsResourceArgs args2({
      {"locality0", GetBackendPorts(2, 4)},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args2, kNewEdsServiceName));
  // Populate new CDS resource.
  Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
  new_cluster.set_name(kNewClusterName);
  new_cluster.mutable_eds_cluster_config()->set_service_name(
      kNewEdsServiceName);
  balancers_[0]->ads_service()->SetCdsResource(new_cluster);
  // Change RDS resource to point to new cluster.
  RouteConfiguration new_route_config =
      balancers_[0]->ads_service()->default_route_config();
  new_route_config.mutable_virtual_hosts(0)
      ->mutable_routes(0)
      ->mutable_route()
      ->set_cluster(kNewClusterName);
  balancers_[0]->ads_service()->SetRdsResource(new_route_config);
  // Wait for all new backends to be used.
  std::tuple<int, int, int> counts = WaitForAllBackends(2, 4);
  // Make sure no RPCs failed in the transition.
  EXPECT_EQ(0, std::get<1>(counts));
}
+
// Tests that a default route whose match is the prefix "/" (rather than an
// empty prefix) is accepted and routes traffic normally.
TEST_P(XdsResolverOnlyTest, DefaultRouteSpecifiesSlashPrefix) {
  RouteConfiguration route_config =
      balancers_[0]->ads_service()->default_route_config();
  route_config.mutable_virtual_hosts(0)
      ->mutable_routes(0)
      ->mutable_match()
      ->set_prefix("/");
  balancers_[0]->ads_service()->SetLdsResource(
      AdsServiceImpl::BuildListener(route_config));
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts()},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args));
  // We need to wait for all backends to come online; success here shows the
  // "/" prefix route was accepted and used.
  WaitForAllBackends();
}
+
+TEST_P(XdsResolverOnlyTest, CircuitBreaking) {
+ class TestRpc {
+ public:
+ TestRpc() {}
+
+ void StartRpc(grpc::testing::EchoTestService::Stub* stub) {
+ sender_thread_ = std::thread([this, stub]() {
+ EchoResponse response;
+ EchoRequest request;
+ request.mutable_param()->set_client_cancel_after_us(1 * 1000 * 1000);
+ request.set_message(kRequestMessage);
+ status_ = stub->Echo(&context_, request, &response);
+ });
+ }
+
+ void CancelRpc() {
+ context_.TryCancel();
+ sender_thread_.join();
+ }
+
+ private:
+ std::thread sender_thread_;
+ ClientContext context_;
+ Status status_;
+ };
+
+ gpr_setenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING", "true");
+ constexpr size_t kMaxConcurrentRequests = 10;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ // Update CDS resource to set max concurrent request.
+ CircuitBreakers circuit_breaks;
+ Cluster cluster = balancers_[0]->ads_service()->default_cluster();
+ auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds();
+ threshold->set_priority(RoutingPriority::DEFAULT);
+ threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests);
+ balancers_[0]->ads_service()->SetCdsResource(cluster);
+ // Send exactly max_concurrent_requests long RPCs.
+ TestRpc rpcs[kMaxConcurrentRequests];
+ for (size_t i = 0; i < kMaxConcurrentRequests; ++i) {
+ rpcs[i].StartRpc(stub_.get());
+ }
+ // Wait for all RPCs to be in flight.
+ while (backends_[0]->backend_service()->RpcsWaitingForClientCancel() <
+ kMaxConcurrentRequests) {
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_micros(1 * 1000, GPR_TIMESPAN)));
+ }
+ // Sending a RPC now should fail, the error message should tell us
+ // we hit the max concurrent requests limit and got dropped.
+ Status status = SendRpc();
+ EXPECT_FALSE(status.ok());
+ EXPECT_EQ(status.error_message(), "Call dropped by load balancing policy");
+ // Cancel one RPC to allow another one through
+ rpcs[0].CancelRpc();
+ status = SendRpc();
+ EXPECT_TRUE(status.ok());
+ for (size_t i = 1; i < kMaxConcurrentRequests; ++i) {
+ rpcs[i].CancelRpc();
+ }
+ // Make sure RPCs go to the correct backend:
+ EXPECT_EQ(kMaxConcurrentRequests + 1,
+ backends_[0]->backend_service()->request_count());
+ gpr_unsetenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING");
+}
+
+TEST_P(XdsResolverOnlyTest, CircuitBreakingDisabled) {
+ class TestRpc {
+ public:
+ TestRpc() {}
+
+ void StartRpc(grpc::testing::EchoTestService::Stub* stub) {
+ sender_thread_ = std::thread([this, stub]() {
+ EchoResponse response;
+ EchoRequest request;
+ request.mutable_param()->set_client_cancel_after_us(1 * 1000 * 1000);
+ request.set_message(kRequestMessage);
+ status_ = stub->Echo(&context_, request, &response);
+ });
+ }
+
+ void CancelRpc() {
+ context_.TryCancel();
+ sender_thread_.join();
+ }
+
+ private:
+ std::thread sender_thread_;
+ ClientContext context_;
+ Status status_;
+ };
+
+ constexpr size_t kMaxConcurrentRequests = 10;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ // Update CDS resource to set max concurrent request.
+ CircuitBreakers circuit_breaks;
+ Cluster cluster = balancers_[0]->ads_service()->default_cluster();
+ auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds();
+ threshold->set_priority(RoutingPriority::DEFAULT);
+ threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests);
+ balancers_[0]->ads_service()->SetCdsResource(cluster);
+ // Send exactly max_concurrent_requests long RPCs.
+ TestRpc rpcs[kMaxConcurrentRequests];
+ for (size_t i = 0; i < kMaxConcurrentRequests; ++i) {
+ rpcs[i].StartRpc(stub_.get());
+ }
+ // Wait for all RPCs to be in flight.
+ while (backends_[0]->backend_service()->RpcsWaitingForClientCancel() <
+ kMaxConcurrentRequests) {
+ gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+ gpr_time_from_micros(1 * 1000, GPR_TIMESPAN)));
+ }
+ // Sending a RPC now should not fail as circuit breaking is disabled.
+ Status status = SendRpc();
+ EXPECT_TRUE(status.ok());
+ for (size_t i = 0; i < kMaxConcurrentRequests; ++i) {
+ rpcs[i].CancelRpc();
+ }
+ // Make sure RPCs go to the correct backend:
+ EXPECT_EQ(kMaxConcurrentRequests + 1,
+ backends_[0]->backend_service()->request_count());
+}
+
// Tests that two channels targeting different xDS server names share a single
// XdsClient (and therefore a single connection to the balancer).
TEST_P(XdsResolverOnlyTest, MultipleChannelsShareXdsClient) {
  const char* kNewServerName = "new-server.example.com";
  // Add a second Listener so the new server name resolves via xDS too.
  Listener listener = balancers_[0]->ads_service()->default_listener();
  listener.set_name(kNewServerName);
  balancers_[0]->ads_service()->SetLdsResource(listener);
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts()},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args));
  WaitForAllBackends();
  // Create second channel and tell it to connect to kNewServerName.
  auto channel2 = CreateChannel(/*failover_timeout=*/0, kNewServerName);
  channel2->GetState(/*try_to_connect=*/true);
  ASSERT_TRUE(
      channel2->WaitForConnected(grpc_timeout_milliseconds_to_deadline(100)));
  // Make sure there's only one client connected to the ADS service: both
  // channels must be going through the same shared XdsClient.
  EXPECT_EQ(1UL, balancers_[0]->ads_service()->clients().size());
}
+
// Fixture with load reporting enabled: 4 backends, 1 balancer, and a 3-second
// client load-reporting interval.
class XdsResolverLoadReportingOnlyTest : public XdsEnd2endTest {
 public:
  XdsResolverLoadReportingOnlyTest() : XdsEnd2endTest(4, 1, 3) {}
};
+
+// Tests load reporting when switching over from one cluster to another.
// Tests load reporting when switching over from one cluster to another.
TEST_P(XdsResolverLoadReportingOnlyTest, ChangeClusters) {
  const char* kNewClusterName = "new_cluster_name";
  const char* kNewEdsServiceName = "new_eds_service_name";
  balancers_[0]->lrs_service()->set_cluster_names(
      {kDefaultClusterName, kNewClusterName});
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  // cluster kDefaultClusterName -> locality0 -> backends 0 and 1
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts(0, 2)},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args));
  // cluster kNewClusterName -> locality1 -> backends 2 and 3
  AdsServiceImpl::EdsResourceArgs args2({
      {"locality1", GetBackendPorts(2, 4)},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args2, kNewEdsServiceName));
  // CDS resource for kNewClusterName.
  Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
  new_cluster.set_name(kNewClusterName);
  new_cluster.mutable_eds_cluster_config()->set_service_name(
      kNewEdsServiceName);
  balancers_[0]->ads_service()->SetCdsResource(new_cluster);
  // Wait for all backends to come online.
  int num_ok = 0;
  int num_failure = 0;
  int num_drops = 0;
  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(0, 2);
  // The load report received at the balancer should be correct: exactly the
  // RPCs counted by WaitForAllBackends, all attributed to locality0 of the
  // default cluster, with nothing left in flight.
  std::vector<ClientStats> load_report =
      balancers_[0]->lrs_service()->WaitForLoadReport();
  EXPECT_THAT(
      load_report,
      ::testing::ElementsAre(::testing::AllOf(
          ::testing::Property(&ClientStats::cluster_name, kDefaultClusterName),
          ::testing::Property(
              &ClientStats::locality_stats,
              ::testing::ElementsAre(::testing::Pair(
                  "locality0",
                  ::testing::AllOf(
                      ::testing::Field(&ClientStats::LocalityStats::
                                           total_successful_requests,
                                       num_ok),
                      ::testing::Field(&ClientStats::LocalityStats::
                                           total_requests_in_progress,
                                       0UL),
                      ::testing::Field(
                          &ClientStats::LocalityStats::total_error_requests,
                          num_failure),
                      ::testing::Field(
                          &ClientStats::LocalityStats::total_issued_requests,
                          num_failure + num_ok))))),
          ::testing::Property(&ClientStats::total_dropped_requests,
                              num_drops))));
  // Change RDS resource to point to new cluster.
  RouteConfiguration new_route_config =
      balancers_[0]->ads_service()->default_route_config();
  new_route_config.mutable_virtual_hosts(0)
      ->mutable_routes(0)
      ->mutable_route()
      ->set_cluster(kNewClusterName);
  Listener listener =
      balancers_[0]->ads_service()->BuildListener(new_route_config);
  balancers_[0]->ads_service()->SetLdsResource(listener);
  // Wait for all new backends to be used.
  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(2, 4);
  // The load report received at the balancer should be correct. During the
  // switchover RPCs are split between the two clusters, so the per-cluster
  // totals are only bounded (Lt/Le) rather than exact; exact totals are
  // checked across both clusters below.
  load_report = balancers_[0]->lrs_service()->WaitForLoadReport();
  EXPECT_THAT(
      load_report,
      ::testing::ElementsAre(
          ::testing::AllOf(
              ::testing::Property(&ClientStats::cluster_name,
                                  kDefaultClusterName),
              ::testing::Property(
                  &ClientStats::locality_stats,
                  ::testing::ElementsAre(::testing::Pair(
                      "locality0",
                      ::testing::AllOf(
                          ::testing::Field(&ClientStats::LocalityStats::
                                               total_successful_requests,
                                           ::testing::Lt(num_ok)),
                          ::testing::Field(&ClientStats::LocalityStats::
                                               total_requests_in_progress,
                                           0UL),
                          ::testing::Field(
                              &ClientStats::LocalityStats::total_error_requests,
                              ::testing::Le(num_failure)),
                          ::testing::Field(
                              &ClientStats::LocalityStats::
                                  total_issued_requests,
                              ::testing::Le(num_failure + num_ok)))))),
              ::testing::Property(&ClientStats::total_dropped_requests,
                                  num_drops)),
          ::testing::AllOf(
              ::testing::Property(&ClientStats::cluster_name, kNewClusterName),
              ::testing::Property(
                  &ClientStats::locality_stats,
                  ::testing::ElementsAre(::testing::Pair(
                      "locality1",
                      ::testing::AllOf(
                          ::testing::Field(&ClientStats::LocalityStats::
                                               total_successful_requests,
                                           ::testing::Le(num_ok)),
                          ::testing::Field(&ClientStats::LocalityStats::
                                               total_requests_in_progress,
                                           0UL),
                          ::testing::Field(
                              &ClientStats::LocalityStats::total_error_requests,
                              ::testing::Le(num_failure)),
                          ::testing::Field(
                              &ClientStats::LocalityStats::
                                  total_issued_requests,
                              ::testing::Le(num_failure + num_ok)))))),
              ::testing::Property(&ClientStats::total_dropped_requests,
                                  num_drops))));
  // Sanity-check that the two clusters' reports sum to the observed totals.
  int total_ok = 0;
  int total_failure = 0;
  for (const ClientStats& client_stats : load_report) {
    total_ok += client_stats.total_successful_requests();
    total_failure += client_stats.total_error_requests();
  }
  EXPECT_EQ(total_ok, num_ok);
  EXPECT_EQ(total_failure, num_failure);
  // The LRS service got a single request, and sent a single response.
  EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count());
  EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count());
}
+
// Tests of the secure-naming (server authorization) check on the channel to
// the xDS server; same fixture as BasicTest.
using SecureNamingTest = BasicTest;

// Tests that secure naming check passes if target name is expected.
TEST_P(SecureNamingTest, TargetNameIsExpected) {
  SetNextResolution({});
  // "xds_server" matches the expected target configured for the LB channel.
  SetNextResolutionForLbChannel({balancers_[0]->port()}, nullptr, "xds_server");
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts()},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
  CheckRpcSendOk();
}

// Tests that secure naming check fails if target name is unexpected.
TEST_P(SecureNamingTest, TargetNameIsUnexpected) {
  // threadsafe style is required because this fixture spawns threads.
  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
  SetNextResolution({});
  SetNextResolutionForLbChannel({balancers_[0]->port()}, nullptr,
                                "incorrect_server_name");
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts()},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
  // Make sure that we blow up (via abort() from the security connector) when
  // the name from the balancer doesn't match expectations.
  ASSERT_DEATH_IF_SUPPORTED({ CheckRpcSendOk(); }, "");
}
+
// Tests validating the client's handling of malformed LDS responses (each
// should be NACKed with a specific error message); same fixture as BasicTest.
using LdsTest = BasicTest;

// Tests that LDS client should send a NACK if there is no API listener in the
// Listener in the LDS response.
TEST_P(LdsTest, NoApiListener) {
  auto listener = balancers_[0]->ads_service()->default_listener();
  listener.clear_api_listener();
  balancers_[0]->ads_service()->SetLdsResource(listener);
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  CheckRpcSendFailure();
  const auto& response_state =
      balancers_[0]->ads_service()->lds_response_state();
  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
  EXPECT_EQ(response_state.error_message, "Listener has no ApiListener.");
}

// Tests that LDS client should send a NACK if the route_specifier in the
// http_connection_manager is neither inlined route_config nor RDS.
TEST_P(LdsTest, WrongRouteSpecifier) {
  auto listener = balancers_[0]->ads_service()->default_listener();
  HttpConnectionManager http_connection_manager;
  // scoped_routes is an unsupported route_specifier variant.
  http_connection_manager.mutable_scoped_routes();
  listener.mutable_api_listener()->mutable_api_listener()->PackFrom(
      http_connection_manager);
  balancers_[0]->ads_service()->SetLdsResource(listener);
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  CheckRpcSendFailure();
  const auto& response_state =
      balancers_[0]->ads_service()->lds_response_state();
  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
  EXPECT_EQ(response_state.error_message,
            "HttpConnectionManager neither has inlined route_config nor RDS.");
}

// Tests that LDS client should send a NACK if the rds message in the
// http_connection_manager is missing the config_source field.
TEST_P(LdsTest, RdsMissingConfigSource) {
  auto listener = balancers_[0]->ads_service()->default_listener();
  HttpConnectionManager http_connection_manager;
  http_connection_manager.mutable_rds()->set_route_config_name(
      kDefaultRouteConfigurationName);
  listener.mutable_api_listener()->mutable_api_listener()->PackFrom(
      http_connection_manager);
  balancers_[0]->ads_service()->SetLdsResource(listener);
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  CheckRpcSendFailure();
  const auto& response_state =
      balancers_[0]->ads_service()->lds_response_state();
  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
  EXPECT_EQ(response_state.error_message,
            "HttpConnectionManager missing config_source for RDS.");
}

// Tests that LDS client should send a NACK if the rds message in the
// http_connection_manager has a config_source field that does not specify ADS.
TEST_P(LdsTest, RdsConfigSourceDoesNotSpecifyAds) {
  auto listener = balancers_[0]->ads_service()->default_listener();
  HttpConnectionManager http_connection_manager;
  auto* rds = http_connection_manager.mutable_rds();
  rds->set_route_config_name(kDefaultRouteConfigurationName);
  // "self" is a valid ConfigSource, but the client only supports ADS.
  rds->mutable_config_source()->mutable_self();
  listener.mutable_api_listener()->mutable_api_listener()->PackFrom(
      http_connection_manager);
  balancers_[0]->ads_service()->SetLdsResource(listener);
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  CheckRpcSendFailure();
  const auto& response_state =
      balancers_[0]->ads_service()->lds_response_state();
  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
  EXPECT_EQ(response_state.error_message,
            "HttpConnectionManager ConfigSource for RDS does not specify ADS.");
}
+
// Tests covering route configuration handling, whether delivered inline in
// LDS or via RDS (the parameterization decides); same fixture as BasicTest.
using LdsRdsTest = BasicTest;

// Tests that LDS client should send an ACK upon correct LDS response (with
// inlined RDS result).
TEST_P(LdsRdsTest, Vanilla) {
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  (void)SendRpc();
  EXPECT_EQ(RouteConfigurationResponseState(0).state,
            AdsServiceImpl::ResponseState::ACKED);
  // Make sure we actually used the RPC service for the right version of xDS:
  // exactly one of the v2/v3 services must have been contacted, matching the
  // test parameter.
  EXPECT_EQ(balancers_[0]->ads_service()->seen_v2_client(),
            GetParam().use_v2());
  EXPECT_NE(balancers_[0]->ads_service()->seen_v3_client(),
            GetParam().use_v2());
}
+
+// Tests that we go into TRANSIENT_FAILURE if the Listener is removed.
// Tests that we go into TRANSIENT_FAILURE if the Listener is removed.
TEST_P(LdsRdsTest, ListenerRemoved) {
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  AdsServiceImpl::EdsResourceArgs args({
      {"locality0", GetBackendPorts()},
  });
  balancers_[0]->ads_service()->SetEdsResource(
      AdsServiceImpl::BuildEdsResource(args));
  // We need to wait for all backends to come online.
  WaitForAllBackends();
  // Unset LDS resource.
  balancers_[0]->ads_service()->UnsetResource(kLdsTypeUrl, kServerName);
  // Wait for RPCs to start failing. Deliberately an empty-body spin: each
  // iteration issues one RPC and stops on the first failure.
  do {
  } while (SendRpc(RpcOptions(), nullptr).ok());
  // Make sure RPCs are still failing.
  CheckRpcSendFailure(1000);
  // Make sure we ACK'ed the update (removal must not be NACKed).
  EXPECT_EQ(balancers_[0]->ads_service()->lds_response_state().state,
            AdsServiceImpl::ResponseState::ACKED);
}
+
+// Tests that LDS client ACKs but fails if matching domain can't be found in
+// the LDS response.
// Tests that LDS client ACKs but fails if matching domain can't be found in
// the LDS response.
TEST_P(LdsRdsTest, NoMatchedDomain) {
  RouteConfiguration route_config =
      balancers_[0]->ads_service()->default_route_config();
  route_config.mutable_virtual_hosts(0)->clear_domains();
  route_config.mutable_virtual_hosts(0)->add_domains("unmatched_domain");
  SetRouteConfiguration(0, route_config);
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  // The config is valid (so it is ACKed) but no virtual host matches the
  // target, so RPCs fail.
  CheckRpcSendFailure();
  // Do a bit of polling, to allow the ACK to get to the ADS server.
  channel_->WaitForConnected(grpc_timeout_milliseconds_to_deadline(100));
  const auto& response_state = RouteConfigurationResponseState(0);
  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED);
}
+
+// Tests that LDS client should choose the virtual host with matching domain if
+// multiple virtual hosts exist in the LDS response.
// Tests that LDS client should choose the virtual host with matching domain if
// multiple virtual hosts exist in the LDS response.
TEST_P(LdsRdsTest, ChooseMatchedDomain) {
  RouteConfiguration route_config =
      balancers_[0]->ads_service()->default_route_config();
  // Duplicate the default virtual host, then make the first copy unmatchable
  // so only the second (appended) one can serve the target.
  *(route_config.add_virtual_hosts()) = route_config.virtual_hosts(0);
  route_config.mutable_virtual_hosts(0)->clear_domains();
  route_config.mutable_virtual_hosts(0)->add_domains("unmatched_domain");
  SetRouteConfiguration(0, route_config);
  SetNextResolution({});
  SetNextResolutionForLbChannelAllBalancers();
  (void)SendRpc();
  EXPECT_EQ(RouteConfigurationResponseState(0).state,
            AdsServiceImpl::ResponseState::ACKED);
}
+
+// Verifies that when a virtual host carries multiple routes, the client
+// still works using the last one: route 0 is given a cluster_header action
+// (which gRPC does not use), leaving the appended copy as the usable route.
+TEST_P(LdsRdsTest, ChooseLastRoute) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* vhost = route_config.mutable_virtual_hosts(0);
+  // Append a copy of the original route, then clobber the first route's
+  // action with cluster_header.
+  *vhost->add_routes() = vhost->routes(0);
+  vhost->mutable_routes(0)->mutable_route()->mutable_cluster_header();
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  (void)SendRpc();
+  EXPECT_EQ(RouteConfigurationResponseState(0).state,
+            AdsServiceImpl::ResponseState::ACKED);
+}
+
+// Verifies that a route match that sets case_sensitive to false is rejected
+// (NACKed) by the client.
+TEST_P(LdsRdsTest, RouteMatchHasCaseSensitiveFalse) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->mutable_case_sensitive()
+      ->set_value(false);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message,
+            "case_sensitive if set must be set to true.");
+}
+
+// Verifies that a route match containing query_parameters is skipped; with
+// no usable routes remaining, the update is NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasQueryParameters) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* match = route_config.mutable_virtual_hosts(0)
+                    ->mutable_routes(0)
+                    ->mutable_match();
+  match->set_prefix("/grpc.testing.EchoTest1Service/");
+  match->add_query_parameters();
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that the prefixes "" and "/" are both accepted, i.e. the update
+// is ACKed.
+TEST_P(LdsRdsTest, RouteMatchHasValidPrefixEmptyOrSingleSlash) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* vhost = route_config.mutable_virtual_hosts(0);
+  vhost->mutable_routes(0)->mutable_match()->set_prefix("");
+  auto* slash_route = vhost->add_routes();
+  slash_route->mutable_match()->set_prefix("/");
+  slash_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  (void)SendRpc();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::ACKED);
+}
+
+// Verifies that a prefix lacking a leading "/" is skipped; with no usable
+// routes remaining, the update is NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixNoLeadingSlash) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_prefix("grpc.testing.EchoTest1Service/");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that a prefix with content beyond the service name (an extra
+// trailing segment) is skipped; the update is then NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixExtraContent) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_prefix("/grpc.testing.EchoTest1Service/Echo1/");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that the prefix "//" is skipped; with no usable routes remaining,
+// the update is NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixDoubleSlash) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_prefix("//");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that an empty path matcher is skipped; with no usable routes
+// remaining, the update is NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPathEmptyPath) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_path("");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that a path matcher lacking a leading "/" is skipped; with no
+// usable routes remaining, the update is NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPathNoLeadingSlash) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_path("grpc.testing.EchoTest1Service/Echo1");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that a path matcher with too many slashes (trailing "/") is
+// skipped; with no usable routes remaining, the update is NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPathTooManySlashes) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_path("/grpc.testing.EchoTest1Service/Echo1/");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that a path matcher with only one slash (no "/" between service
+// and method) is skipped; the update is then NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPathOnlyOneSlash) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_path("/grpc.testing.EchoTest1Service.Echo1");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that a path matcher missing the service segment ("//Echo1") is
+// skipped; with no usable routes remaining, the update is NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPathMissingService) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_path("//Echo1");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that a path matcher missing the method segment is skipped; with
+// no usable routes remaining, the update is NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPathMissingMethod) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  route_config.mutable_virtual_hosts(0)
+      ->mutable_routes(0)
+      ->mutable_match()
+      ->set_path("/grpc.testing.EchoTest1Service/");
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No valid routes specified.");
+}
+
+// Verifies that a path matcher carrying an unparseable regex causes the
+// update to be NACKed.
+TEST_P(LdsRdsTest, RouteMatchHasInvalidPathRegex) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* route = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  // "a[z-a]" contains an invalid (reversed) character range.
+  route->mutable_match()->mutable_safe_regex()->set_regex("a[z-a]");
+  route->mutable_route()->set_cluster(kNewCluster1Name);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message,
+            "Invalid regex string specified in path matcher.");
+}
+
+// Verifies that a route whose action is not a RouteAction (here: a redirect)
+// causes the update to be NACKed.
+TEST_P(LdsRdsTest, RouteHasNoRouteAction) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  // Replace the route's action with a redirect, which gRPC does not support.
+  route_config.mutable_virtual_hosts(0)->mutable_routes(0)->mutable_redirect();
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message, "No RouteAction found in route.");
+}
+
+// Verifies that a RouteAction with an empty cluster name is NACKed, even
+// when a valid catch-all route follows it.
+TEST_P(LdsRdsTest, RouteActionClusterHasEmptyClusterName) {
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* vhost = route_config.mutable_virtual_hosts(0);
+  auto* bad_route = vhost->mutable_routes(0);
+  bad_route->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  bad_route->mutable_route()->set_cluster("");
+  auto* default_route = vhost->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message,
+            "RouteAction cluster contains empty cluster name.");
+}
+
+// Verifies that a weighted_clusters action whose total_weight disagrees with
+// the sum of its cluster weights is NACKed.
+TEST_P(LdsRdsTest, RouteActionWeightedTargetHasIncorrectTotalWeightSet) {
+  const size_t kWeight75 = 75;
+  const char* kNewCluster1Name = "new_cluster_1";
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* vhost = route_config.mutable_virtual_hosts(0);
+  auto* route = vhost->mutable_routes(0);
+  route->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* weighted_clusters =
+      route->mutable_route()->mutable_weighted_clusters();
+  auto* cluster_weight = weighted_clusters->add_clusters();
+  cluster_weight->set_name(kNewCluster1Name);
+  cluster_weight->mutable_weight()->set_value(kWeight75);
+  // Deliberately off by one relative to the sum of the cluster weights.
+  weighted_clusters->mutable_total_weight()->set_value(kWeight75 + 1);
+  auto* default_route = vhost->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message,
+            "RouteAction weighted_cluster has incorrect total weight");
+}
+
+// Verifies that a weighted_clusters entry carrying an empty cluster name is
+// NACKed.
+TEST_P(LdsRdsTest, RouteActionWeightedTargetClusterHasEmptyClusterName) {
+  const size_t kWeight75 = 75;
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* vhost = route_config.mutable_virtual_hosts(0);
+  auto* route = vhost->mutable_routes(0);
+  route->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* weighted_clusters =
+      route->mutable_route()->mutable_weighted_clusters();
+  auto* cluster_weight = weighted_clusters->add_clusters();
+  cluster_weight->set_name("");
+  cluster_weight->mutable_weight()->set_value(kWeight75);
+  weighted_clusters->mutable_total_weight()->set_value(kWeight75);
+  auto* default_route = vhost->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(
+      rds_response.error_message,
+      "RouteAction weighted_cluster cluster contains empty cluster name.");
+}
+
+// Verifies that a weighted_clusters entry that omits its weight is NACKed.
+TEST_P(LdsRdsTest, RouteActionWeightedTargetClusterHasNoWeight) {
+  const size_t kWeight75 = 75;
+  const char* kNewCluster1Name = "new_cluster_1";
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* vhost = route_config.mutable_virtual_hosts(0);
+  auto* route = vhost->mutable_routes(0);
+  route->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* weighted_clusters =
+      route->mutable_route()->mutable_weighted_clusters();
+  // Note: no weight is set on this cluster entry.
+  weighted_clusters->add_clusters()->set_name(kNewCluster1Name);
+  weighted_clusters->mutable_total_weight()->set_value(kWeight75);
+  auto* default_route = vhost->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message,
+            "RouteAction weighted_cluster cluster missing weight");
+}
+
+// Verifies that a header matcher carrying an unparseable regex causes the
+// update to be NACKed.
+TEST_P(LdsRdsTest, RouteHeaderMatchInvalidRegex) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* route = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* header_matcher = route->mutable_match()->add_headers();
+  header_matcher->set_name("header1");
+  // "a[z-a]" contains an invalid (reversed) character range.
+  header_matcher->mutable_safe_regex_match()->set_regex("a[z-a]");
+  route->mutable_route()->set_cluster(kNewCluster1Name);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message,
+            "Invalid regex string specified in header matcher.");
+}
+
+// Verifies that a range header matcher whose end is smaller than its start
+// causes the update to be NACKed.
+TEST_P(LdsRdsTest, RouteHeaderMatchInvalidRange) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* route = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* header_matcher = route->mutable_match()->add_headers();
+  header_matcher->set_name("header1");
+  auto* range = header_matcher->mutable_range_match();
+  range->set_start(1001);
+  range->set_end(1000);
+  route->mutable_route()->set_cluster(kNewCluster1Name);
+  SetRouteConfiguration(0, route_config);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& rds_response = RouteConfigurationResponseState(0);
+  EXPECT_EQ(rds_response.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(rds_response.error_message,
+            "Invalid range header matcher specifier specified: end "
+            "cannot be smaller than start.");
+}
+
+// Tests that LDS client should choose the default route (with no matching
+// specified) after unable to find a match with previous routes.
+// Echo1/Echo2 RPCs match the path routes below and must land on their
+// dedicated clusters; plain Echo RPCs match nothing and use the catch-all
+// default route.
+TEST_P(LdsRdsTest, XdsRoutingPathMatching) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const size_t kNumEcho1Rpcs = 10;
+  const size_t kNumEcho2Rpcs = 20;
+  const size_t kNumEchoRpcs = 30;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  // Backends 0-1 back the default cluster, backend 2 backs cluster 1, and
+  // backend 3 backs cluster 2.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 2)},
+  });
+  AdsServiceImpl::EdsResourceArgs args1({
+      {"locality0", GetBackendPorts(2, 3)},
+  });
+  AdsServiceImpl::EdsResourceArgs args2({
+      {"locality0", GetBackendPorts(3, 4)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Populating Route Configurations for LDS.
+  // Exact-path routes for Echo1 and Echo2, an Echo3 route that reuses the
+  // default cluster, plus a catch-all (empty prefix) default route.
+  RouteConfiguration new_route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/Echo1");
+  route1->mutable_route()->set_cluster(kNewCluster1Name);
+  auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  route2->mutable_match()->set_path("/grpc.testing.EchoTest2Service/Echo2");
+  route2->mutable_route()->set_cluster(kNewCluster2Name);
+  auto* route3 = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  route3->mutable_match()->set_path("/grpc.testing.EchoTest3Service/Echo3");
+  route3->mutable_route()->set_cluster(kDefaultClusterName);
+  auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  // Wait for the default-cluster backends (0-1) to start receiving traffic.
+  WaitForAllBackends(0, 2);
+  CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true));
+  CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions()
+                                    .set_rpc_service(SERVICE_ECHO1)
+                                    .set_rpc_method(METHOD_ECHO1)
+                                    .set_wait_for_ready(true));
+  CheckRpcSendOk(kNumEcho2Rpcs, RpcOptions()
+                                    .set_rpc_service(SERVICE_ECHO2)
+                                    .set_rpc_method(METHOD_ECHO2)
+                                    .set_wait_for_ready(true));
+  // Make sure RPCs all go to the correct backend.
+  // The 30 plain Echo RPCs are split evenly across backends 0-1; all Echo1
+  // traffic lands on backend 2 and all Echo2 traffic on backend 3.
+  for (size_t i = 0; i < 2; ++i) {
+    EXPECT_EQ(kNumEchoRpcs / 2,
+              backends_[i]->backend_service()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service1()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service2()->request_count());
+  }
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[2]->backend_service2()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+  EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count());
+}
+
+// Tests routing by path prefix: RPCs whose paths start with the Echo1/Echo2
+// service prefixes go to dedicated clusters, while everything else falls
+// through to the catch-all default route.
+TEST_P(LdsRdsTest, XdsRoutingPrefixMatching) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const size_t kNumEcho1Rpcs = 10;
+  const size_t kNumEcho2Rpcs = 20;
+  const size_t kNumEchoRpcs = 30;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  // Backends 0-1 back the default cluster, backend 2 backs cluster 1, and
+  // backend 3 backs cluster 2.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 2)},
+  });
+  AdsServiceImpl::EdsResourceArgs args1({
+      {"locality0", GetBackendPorts(2, 3)},
+  });
+  AdsServiceImpl::EdsResourceArgs args2({
+      {"locality0", GetBackendPorts(3, 4)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Populating Route Configurations for LDS.
+  // Prefix routes for the Echo1/Echo2 services plus a catch-all default.
+  RouteConfiguration new_route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  route1->mutable_route()->set_cluster(kNewCluster1Name);
+  auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  route2->mutable_match()->set_prefix("/grpc.testing.EchoTest2Service/");
+  route2->mutable_route()->set_cluster(kNewCluster2Name);
+  auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  // Wait for the default-cluster backends (0-1) to start receiving traffic.
+  WaitForAllBackends(0, 2);
+  CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true));
+  CheckRpcSendOk(
+      kNumEcho1Rpcs,
+      RpcOptions().set_rpc_service(SERVICE_ECHO1).set_wait_for_ready(true));
+  CheckRpcSendOk(
+      kNumEcho2Rpcs,
+      RpcOptions().set_rpc_service(SERVICE_ECHO2).set_wait_for_ready(true));
+  // Make sure RPCs all go to the correct backend.
+  // Plain Echo RPCs split evenly across backends 0-1; Echo1 traffic goes to
+  // backend 2 and Echo2 traffic to backend 3, exclusively.
+  for (size_t i = 0; i < 2; ++i) {
+    EXPECT_EQ(kNumEchoRpcs / 2,
+              backends_[i]->backend_service()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service1()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service2()->request_count());
+  }
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[2]->backend_service2()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+  EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count());
+}
+
+// Tests routing by safe_regex path matchers: regexes matching the Echo1 and
+// Echo2 method paths route those RPCs to dedicated clusters, with a
+// catch-all default route for the rest.
+TEST_P(LdsRdsTest, XdsRoutingPathRegexMatching) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const size_t kNumEcho1Rpcs = 10;
+  const size_t kNumEcho2Rpcs = 20;
+  const size_t kNumEchoRpcs = 30;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  // Backends 0-1 back the default cluster, backend 2 backs cluster 1, and
+  // backend 3 backs cluster 2.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 2)},
+  });
+  AdsServiceImpl::EdsResourceArgs args1({
+      {"locality0", GetBackendPorts(2, 3)},
+  });
+  AdsServiceImpl::EdsResourceArgs args2({
+      {"locality0", GetBackendPorts(3, 4)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Populating Route Configurations for LDS.
+  RouteConfiguration new_route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  // Will match "/grpc.testing.EchoTest1Service/"
+  route1->mutable_match()->mutable_safe_regex()->set_regex(".*1.*");
+  route1->mutable_route()->set_cluster(kNewCluster1Name);
+  auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  // Will match "/grpc.testing.EchoTest2Service/"
+  route2->mutable_match()->mutable_safe_regex()->set_regex(".*2.*");
+  route2->mutable_route()->set_cluster(kNewCluster2Name);
+  auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  // Wait for the default-cluster backends (0-1) to start receiving traffic.
+  WaitForAllBackends(0, 2);
+  CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true));
+  CheckRpcSendOk(
+      kNumEcho1Rpcs,
+      RpcOptions().set_rpc_service(SERVICE_ECHO1).set_wait_for_ready(true));
+  CheckRpcSendOk(
+      kNumEcho2Rpcs,
+      RpcOptions().set_rpc_service(SERVICE_ECHO2).set_wait_for_ready(true));
+  // Make sure RPCs all go to the correct backend.
+  // Plain Echo RPCs split evenly across backends 0-1; Echo1 traffic goes to
+  // backend 2 and Echo2 traffic to backend 3, exclusively.
+  for (size_t i = 0; i < 2; ++i) {
+    EXPECT_EQ(kNumEchoRpcs / 2,
+              backends_[i]->backend_service()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service1()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service2()->request_count());
+  }
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[2]->backend_service2()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+  EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count());
+}
+
+// Tests that a weighted_clusters route splits Echo1 traffic across two
+// clusters in approximately the configured 75/25 proportion (verified
+// within a tolerance), while plain Echo RPCs use the default cluster.
+TEST_P(LdsRdsTest, XdsRoutingWeightedCluster) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const size_t kNumEcho1Rpcs = 1000;
+  const size_t kNumEchoRpcs = 10;
+  const size_t kWeight75 = 75;
+  const size_t kWeight25 = 25;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  // Backend 0 backs the default cluster, backend 1 backs cluster 1 (weight
+  // 75), and backend 2 backs cluster 2 (weight 25).
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1)},
+  });
+  AdsServiceImpl::EdsResourceArgs args1({
+      {"locality0", GetBackendPorts(1, 2)},
+  });
+  AdsServiceImpl::EdsResourceArgs args2({
+      {"locality0", GetBackendPorts(2, 3)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Populating Route Configurations for LDS.
+  // The Echo1 prefix route uses weighted_clusters (75/25); everything else
+  // hits the catch-all default route.
+  RouteConfiguration new_route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* weighted_cluster1 =
+      route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+  weighted_cluster1->set_name(kNewCluster1Name);
+  weighted_cluster1->mutable_weight()->set_value(kWeight75);
+  auto* weighted_cluster2 =
+      route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+  weighted_cluster2->set_name(kNewCluster2Name);
+  weighted_cluster2->mutable_weight()->set_value(kWeight25);
+  route1->mutable_route()
+      ->mutable_weighted_clusters()
+      ->mutable_total_weight()
+      ->set_value(kWeight75 + kWeight25);
+  auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, new_route_config);
+  // Wait for backend 0 (default cluster) and, via Echo1 RPCs, for backends
+  // 1-2 (the two weighted clusters) to start receiving traffic.
+  WaitForAllBackends(0, 1);
+  WaitForAllBackends(1, 3, true, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  CheckRpcSendOk(kNumEchoRpcs);
+  CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  // Make sure RPCs all go to the correct backend.
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  const int weight_75_request_count =
+      backends_[1]->backend_service1()->request_count();
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  const int weight_25_request_count =
+      backends_[2]->backend_service1()->request_count();
+  // The split is probabilistic, so check each share against its expected
+  // fraction of kNumEcho1Rpcs within a relative tolerance.
+  const double kErrorTolerance = 0.2;
+  EXPECT_THAT(weight_75_request_count,
+              ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight75 / 100 *
+                                             (1 - kErrorTolerance)),
+                               ::testing::Le(kNumEcho1Rpcs * kWeight75 / 100 *
+                                             (1 + kErrorTolerance))));
+  // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the
+  // test from flaking while debugging potential root cause.
+  const double kErrorToleranceSmallLoad = 0.3;
+  gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+          weight_75_request_count, weight_25_request_count);
+  EXPECT_THAT(weight_25_request_count,
+              ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight25 / 100 *
+                                             (1 - kErrorToleranceSmallLoad)),
+                               ::testing::Le(kNumEcho1Rpcs * kWeight25 / 100 *
+                                             (1 + kErrorToleranceSmallLoad))));
+}
+
+// Verifies that a weighted_clusters route action works when installed on the
+// catch-all (empty-prefix) route itself: plain echo traffic should split
+// roughly 75/25 between the two new clusters (backends 1 and 2) and the
+// original default cluster (backend 0) should receive nothing.
+TEST_P(LdsRdsTest, RouteActionWeightedTargetDefaultRoute) {
+ const char* kNewCluster1Name = "new_cluster_1";
+ const char* kNewEdsService1Name = "new_eds_service_name_1";
+ const char* kNewCluster2Name = "new_cluster_2";
+ const char* kNewEdsService2Name = "new_eds_service_name_2";
+ const size_t kNumEchoRpcs = 1000;
+ const size_t kWeight75 = 75;
+ const size_t kWeight25 = 25;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ // One backend per cluster: backend 0 = default, 1 = cluster1, 2 = cluster2.
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ AdsServiceImpl::EdsResourceArgs args1({
+ {"locality0", GetBackendPorts(1, 2)},
+ });
+ AdsServiceImpl::EdsResourceArgs args2({
+ {"locality0", GetBackendPorts(2, 3)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+ // Populate new CDS resources.
+ Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+ new_cluster1.set_name(kNewCluster1Name);
+ new_cluster1.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsService1Name);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+ Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+ new_cluster2.set_name(kNewCluster2Name);
+ new_cluster2.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsService2Name);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+ // Populating Route Configurations for LDS.
+ // Unlike the prefix-matched variants of this test, the weighted action is
+ // placed on the empty-prefix route, so it applies to every RPC.
+ RouteConfiguration new_route_config =
+ balancers_[0]->ads_service()->default_route_config();
+ auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+ route1->mutable_match()->set_prefix("");
+ auto* weighted_cluster1 =
+ route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+ weighted_cluster1->set_name(kNewCluster1Name);
+ weighted_cluster1->mutable_weight()->set_value(kWeight75);
+ auto* weighted_cluster2 =
+ route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+ weighted_cluster2->set_name(kNewCluster2Name);
+ weighted_cluster2->mutable_weight()->set_value(kWeight25);
+ route1->mutable_route()
+ ->mutable_weighted_clusters()
+ ->mutable_total_weight()
+ ->set_value(kWeight75 + kWeight25);
+ SetRouteConfiguration(0, new_route_config);
+ WaitForAllBackends(1, 3);
+ CheckRpcSendOk(kNumEchoRpcs);
+ // Make sure RPCs all go to the correct backend.
+ EXPECT_EQ(0, backends_[0]->backend_service()->request_count());
+ const int weight_75_request_count =
+ backends_[1]->backend_service()->request_count();
+ const int weight_25_request_count =
+ backends_[2]->backend_service()->request_count();
+ // Statistical check: the 75%-weight cluster should receive 75% +/- 20%
+ // of the echo RPCs.
+ const double kErrorTolerance = 0.2;
+ EXPECT_THAT(weight_75_request_count,
+ ::testing::AllOf(::testing::Ge(kNumEchoRpcs * kWeight75 / 100 *
+ (1 - kErrorTolerance)),
+ ::testing::Le(kNumEchoRpcs * kWeight75 / 100 *
+ (1 + kErrorTolerance))));
+ // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the
+ // test from flaking while debugging potential root cause.
+ // The 25% branch sees fewer samples, so it gets a wider tolerance band.
+ const double kErrorToleranceSmallLoad = 0.3;
+ gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+ weight_75_request_count, weight_25_request_count);
+ EXPECT_THAT(weight_25_request_count,
+ ::testing::AllOf(::testing::Ge(kNumEchoRpcs * kWeight25 / 100 *
+ (1 - kErrorToleranceSmallLoad)),
+ ::testing::Le(kNumEchoRpcs * kWeight25 / 100 *
+ (1 + kErrorToleranceSmallLoad))));
+}
+
+// Verifies that updating ONLY the weights of an existing weighted_clusters
+// route (75/25 -> 50/50) takes effect. Phase 1 checks the 75/25 split; the
+// default route is then repointed at a third cluster (backend 3) purely so the
+// test can detect when the client has picked up the new config, after which
+// phase 2 checks the 50/50 split on the same two clusters.
+TEST_P(LdsRdsTest, XdsRoutingWeightedClusterUpdateWeights) {
+ const char* kNewCluster1Name = "new_cluster_1";
+ const char* kNewEdsService1Name = "new_eds_service_name_1";
+ const char* kNewCluster2Name = "new_cluster_2";
+ const char* kNewEdsService2Name = "new_eds_service_name_2";
+ const char* kNewCluster3Name = "new_cluster_3";
+ const char* kNewEdsService3Name = "new_eds_service_name_3";
+ const size_t kNumEcho1Rpcs = 1000;
+ const size_t kNumEchoRpcs = 10;
+ const size_t kWeight75 = 75;
+ const size_t kWeight25 = 25;
+ const size_t kWeight50 = 50;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ // One backend per cluster: 0 = default, 1 = cluster1, 2 = cluster2,
+ // 3 = cluster3 (used only after the config update).
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ AdsServiceImpl::EdsResourceArgs args1({
+ {"locality0", GetBackendPorts(1, 2)},
+ });
+ AdsServiceImpl::EdsResourceArgs args2({
+ {"locality0", GetBackendPorts(2, 3)},
+ });
+ AdsServiceImpl::EdsResourceArgs args3({
+ {"locality0", GetBackendPorts(3, 4)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args3, kNewEdsService3Name));
+ // Populate new CDS resources.
+ Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+ new_cluster1.set_name(kNewCluster1Name);
+ new_cluster1.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsService1Name);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+ Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+ new_cluster2.set_name(kNewCluster2Name);
+ new_cluster2.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsService2Name);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+ Cluster new_cluster3 = balancers_[0]->ads_service()->default_cluster();
+ new_cluster3.set_name(kNewCluster3Name);
+ new_cluster3.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsService3Name);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster3);
+ // Populating Route Configurations.
+ // Echo1 RPCs hit the weighted route; all other RPCs fall through to the
+ // default route (default cluster -> backend 0).
+ RouteConfiguration new_route_config =
+ balancers_[0]->ads_service()->default_route_config();
+ auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+ route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+ auto* weighted_cluster1 =
+ route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+ weighted_cluster1->set_name(kNewCluster1Name);
+ weighted_cluster1->mutable_weight()->set_value(kWeight75);
+ auto* weighted_cluster2 =
+ route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+ weighted_cluster2->set_name(kNewCluster2Name);
+ weighted_cluster2->mutable_weight()->set_value(kWeight25);
+ route1->mutable_route()
+ ->mutable_weighted_clusters()
+ ->mutable_total_weight()
+ ->set_value(kWeight75 + kWeight25);
+ auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+ default_route->mutable_match()->set_prefix("");
+ default_route->mutable_route()->set_cluster(kDefaultClusterName);
+ SetRouteConfiguration(0, new_route_config);
+ WaitForAllBackends(0, 1);
+ WaitForAllBackends(1, 3, true, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ CheckRpcSendOk(kNumEchoRpcs);
+ CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ // Make sure RPCs all go to the correct backend.
+ EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+ EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+ const int weight_75_request_count =
+ backends_[1]->backend_service1()->request_count();
+ EXPECT_EQ(0, backends_[1]->backend_service2()->request_count());
+ EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+ const int weight_25_request_count =
+ backends_[2]->backend_service1()->request_count();
+ EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+ // Statistical check on the initial 75/25 split.
+ const double kErrorTolerance = 0.2;
+ EXPECT_THAT(weight_75_request_count,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight75 / 100 *
+ (1 - kErrorTolerance)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight75 / 100 *
+ (1 + kErrorTolerance))));
+ // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the
+ // test from flaking while debugging potential root cause.
+ const double kErrorToleranceSmallLoad = 0.3;
+ gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+ weight_75_request_count, weight_25_request_count);
+ EXPECT_THAT(weight_25_request_count,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight25 / 100 *
+ (1 - kErrorToleranceSmallLoad)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight25 / 100 *
+ (1 + kErrorToleranceSmallLoad))));
+ // Change Route Configurations: same clusters different weights.
+ weighted_cluster1->mutable_weight()->set_value(kWeight50);
+ weighted_cluster2->mutable_weight()->set_value(kWeight50);
+ // Change default route to a new cluster to help to identify when new polices
+ // are seen by the client.
+ default_route->mutable_route()->set_cluster(kNewCluster3Name);
+ SetRouteConfiguration(0, new_route_config);
+ // Counters are reset so the phase-2 assertions only observe RPCs sent
+ // after the client has switched to the updated config (detected by echo
+ // traffic reaching backend 3).
+ ResetBackendCounters();
+ WaitForAllBackends(3, 4);
+ CheckRpcSendOk(kNumEchoRpcs);
+ CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ // Make sure RPCs all go to the correct backend.
+ EXPECT_EQ(0, backends_[0]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+ EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+ const int weight_50_request_count_1 =
+ backends_[1]->backend_service1()->request_count();
+ EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+ const int weight_50_request_count_2 =
+ backends_[2]->backend_service1()->request_count();
+ EXPECT_EQ(kNumEchoRpcs, backends_[3]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+ EXPECT_THAT(weight_50_request_count_1,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight50 / 100 *
+ (1 - kErrorTolerance)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight50 / 100 *
+ (1 + kErrorTolerance))));
+ EXPECT_THAT(weight_50_request_count_2,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight50 / 100 *
+ (1 - kErrorTolerance)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight50 / 100 *
+ (1 + kErrorTolerance))));
+}
+
+// Verifies that the SET of clusters inside a weighted_clusters action can be
+// swapped out across route-config updates. Three phases on the Echo1 route:
+//   1. 75% cluster1 / 25% default cluster  (backends 1 / 0)
+//   2. 50% cluster1 / 50% cluster2         (backends 1 / 2)
+//   3. 75% cluster1 / 25% cluster3         (backends 1 / 3)
+// Plain echo traffic always uses the default route (backend 0).
+TEST_P(LdsRdsTest, XdsRoutingWeightedClusterUpdateClusters) {
+ const char* kNewCluster1Name = "new_cluster_1";
+ const char* kNewEdsService1Name = "new_eds_service_name_1";
+ const char* kNewCluster2Name = "new_cluster_2";
+ const char* kNewEdsService2Name = "new_eds_service_name_2";
+ const char* kNewCluster3Name = "new_cluster_3";
+ const char* kNewEdsService3Name = "new_eds_service_name_3";
+ const size_t kNumEcho1Rpcs = 1000;
+ const size_t kNumEchoRpcs = 10;
+ const size_t kWeight75 = 75;
+ const size_t kWeight25 = 25;
+ const size_t kWeight50 = 50;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ // One backend per cluster: 0 = default, 1 = cluster1, 2 = cluster2,
+ // 3 = cluster3.
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ AdsServiceImpl::EdsResourceArgs args1({
+ {"locality0", GetBackendPorts(1, 2)},
+ });
+ AdsServiceImpl::EdsResourceArgs args2({
+ {"locality0", GetBackendPorts(2, 3)},
+ });
+ AdsServiceImpl::EdsResourceArgs args3({
+ {"locality0", GetBackendPorts(3, 4)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args3, kNewEdsService3Name));
+ // Populate new CDS resources.
+ Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+ new_cluster1.set_name(kNewCluster1Name);
+ new_cluster1.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsService1Name);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+ Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+ new_cluster2.set_name(kNewCluster2Name);
+ new_cluster2.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsService2Name);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+ Cluster new_cluster3 = balancers_[0]->ads_service()->default_cluster();
+ new_cluster3.set_name(kNewCluster3Name);
+ new_cluster3.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsService3Name);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster3);
+ // Populating Route Configurations.
+ // Phase 1: weighted pair is (cluster1 @75, default cluster @25).
+ RouteConfiguration new_route_config =
+ balancers_[0]->ads_service()->default_route_config();
+ auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+ route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+ auto* weighted_cluster1 =
+ route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+ weighted_cluster1->set_name(kNewCluster1Name);
+ weighted_cluster1->mutable_weight()->set_value(kWeight75);
+ auto* weighted_cluster2 =
+ route1->mutable_route()->mutable_weighted_clusters()->add_clusters();
+ weighted_cluster2->set_name(kDefaultClusterName);
+ weighted_cluster2->mutable_weight()->set_value(kWeight25);
+ route1->mutable_route()
+ ->mutable_weighted_clusters()
+ ->mutable_total_weight()
+ ->set_value(kWeight75 + kWeight25);
+ auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes();
+ default_route->mutable_match()->set_prefix("");
+ default_route->mutable_route()->set_cluster(kDefaultClusterName);
+ SetRouteConfiguration(0, new_route_config);
+ WaitForAllBackends(0, 1);
+ WaitForAllBackends(1, 2, true, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ CheckRpcSendOk(kNumEchoRpcs);
+ CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ // Make sure RPCs all go to the correct backend.
+ // In phase 1 the 25% branch is the default cluster, so its Echo1 traffic
+ // lands on backend 0's Echo1 service.
+ EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+ int weight_25_request_count =
+ backends_[0]->backend_service1()->request_count();
+ EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+ int weight_75_request_count =
+ backends_[1]->backend_service1()->request_count();
+ EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[2]->backend_service1()->request_count());
+ EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+ const double kErrorTolerance = 0.2;
+ EXPECT_THAT(weight_75_request_count,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight75 / 100 *
+ (1 - kErrorTolerance)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight75 / 100 *
+ (1 + kErrorTolerance))));
+ // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the
+ // test from flaking while debugging potential root cause.
+ const double kErrorToleranceSmallLoad = 0.3;
+ gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+ weight_75_request_count, weight_25_request_count);
+ EXPECT_THAT(weight_25_request_count,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight25 / 100 *
+ (1 - kErrorToleranceSmallLoad)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight25 / 100 *
+ (1 + kErrorToleranceSmallLoad))));
+ // Change Route Configurations: new set of clusters with different weights.
+ // Phase 2: swap the 25% member from the default cluster to cluster2 and
+ // rebalance to 50/50.
+ weighted_cluster1->mutable_weight()->set_value(kWeight50);
+ weighted_cluster2->set_name(kNewCluster2Name);
+ weighted_cluster2->mutable_weight()->set_value(kWeight50);
+ SetRouteConfiguration(0, new_route_config);
+ ResetBackendCounters();
+ WaitForAllBackends(2, 3, true, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ CheckRpcSendOk(kNumEchoRpcs);
+ CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ // Make sure RPCs all go to the correct backend.
+ EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+ EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+ const int weight_50_request_count_1 =
+ backends_[1]->backend_service1()->request_count();
+ EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+ const int weight_50_request_count_2 =
+ backends_[2]->backend_service1()->request_count();
+ EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[3]->backend_service1()->request_count());
+ EXPECT_THAT(weight_50_request_count_1,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight50 / 100 *
+ (1 - kErrorTolerance)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight50 / 100 *
+ (1 + kErrorTolerance))));
+ EXPECT_THAT(weight_50_request_count_2,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight50 / 100 *
+ (1 - kErrorTolerance)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight50 / 100 *
+ (1 + kErrorTolerance))));
+ // Change Route Configurations.
+ // Phase 3: swap the 25% member again, this time to cluster3.
+ weighted_cluster1->mutable_weight()->set_value(kWeight75);
+ weighted_cluster2->set_name(kNewCluster3Name);
+ weighted_cluster2->mutable_weight()->set_value(kWeight25);
+ SetRouteConfiguration(0, new_route_config);
+ ResetBackendCounters();
+ WaitForAllBackends(3, 4, true, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ CheckRpcSendOk(kNumEchoRpcs);
+ CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+ // Make sure RPCs all go to the correct backend.
+ EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+ EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+ weight_75_request_count = backends_[1]->backend_service1()->request_count();
+ EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[2]->backend_service1()->request_count());
+ EXPECT_EQ(0, backends_[3]->backend_service()->request_count());
+ weight_25_request_count = backends_[3]->backend_service1()->request_count();
+ EXPECT_THAT(weight_75_request_count,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight75 / 100 *
+ (1 - kErrorTolerance)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight75 / 100 *
+ (1 + kErrorTolerance))));
+ // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the
+ // test from flaking while debugging potential root cause.
+ gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs",
+ weight_75_request_count, weight_25_request_count);
+ EXPECT_THAT(weight_25_request_count,
+ ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight25 / 100 *
+ (1 - kErrorToleranceSmallLoad)),
+ ::testing::Le(kNumEcho1Rpcs * kWeight25 / 100 *
+ (1 + kErrorToleranceSmallLoad))));
+}
+
+// Verifies that repointing the default route's (single, non-weighted) cluster
+// via a RouteConfiguration update moves traffic from the old cluster
+// (backend 0) to the new one (backend 1).
+TEST_P(LdsRdsTest, XdsRoutingClusterUpdateClusters) {
+ const char* kNewClusterName = "new_cluster";
+ const char* kNewEdsServiceName = "new_eds_service_name";
+ const size_t kNumEchoRpcs = 5;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ // Backend 0 serves the default cluster, backend 1 the new cluster.
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ AdsServiceImpl::EdsResourceArgs args1({
+ {"locality0", GetBackendPorts(1, 2)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName));
+ // Populate new CDS resources.
+ Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
+ new_cluster.set_name(kNewClusterName);
+ new_cluster.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsServiceName);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+ // Send Route Configuration.
+ RouteConfiguration new_route_config =
+ balancers_[0]->ads_service()->default_route_config();
+ SetRouteConfiguration(0, new_route_config);
+ WaitForAllBackends(0, 1);
+ CheckRpcSendOk(kNumEchoRpcs);
+ // Make sure RPCs all go to the correct backend.
+ EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+ // Change Route Configurations: new default cluster.
+ auto* default_route =
+ new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+ default_route->mutable_route()->set_cluster(kNewClusterName);
+ SetRouteConfiguration(0, new_route_config);
+ WaitForAllBackends(1, 2);
+ CheckRpcSendOk(kNumEchoRpcs);
+ // Make sure RPCs all go to the correct backend.
+ EXPECT_EQ(kNumEchoRpcs, backends_[1]->backend_service()->request_count());
+}
+
+// Verifies that an RPC whose routing pick is delayed (because its target
+// backend is down) still completes on the ORIGINAL cluster after a
+// RouteConfiguration update lands — i.e. the in-flight pick is committed to
+// the old route rather than re-routed by the update.
+TEST_P(LdsRdsTest, XdsRoutingClusterUpdateClustersWithPickingDelays) {
+ const char* kNewClusterName = "new_cluster";
+ const char* kNewEdsServiceName = "new_eds_service_name";
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ AdsServiceImpl::EdsResourceArgs args1({
+ {"locality0", GetBackendPorts(1, 2)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName));
+ // Populate new CDS resources.
+ Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
+ new_cluster.set_name(kNewClusterName);
+ new_cluster.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsServiceName);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+ // Bring down the current backend: 0, this will delay route picking time,
+ // resulting in un-committed RPCs.
+ ShutdownBackend(0);
+ // Send a RouteConfiguration with a default route that points to
+ // backend 0.
+ RouteConfiguration new_route_config =
+ balancers_[0]->ads_service()->default_route_config();
+ SetRouteConfiguration(0, new_route_config);
+ // Send exactly one RPC with no deadline and with wait_for_ready=true.
+ // This RPC will not complete until after backend 0 is started.
+ std::thread sending_rpc([this]() {
+ CheckRpcSendOk(1, RpcOptions().set_wait_for_ready(true).set_timeout_ms(0));
+ });
+ // Send a non-wait_for_ready RPC which should fail, this will tell us
+ // that the client has received the update and attempted to connect.
+ const Status status = SendRpc(RpcOptions().set_timeout_ms(0));
+ EXPECT_FALSE(status.ok());
+ // Send a update RouteConfiguration to use backend 1.
+ auto* default_route =
+ new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+ default_route->mutable_route()->set_cluster(kNewClusterName);
+ SetRouteConfiguration(0, new_route_config);
+ // Wait for RPCs to go to the new backend: 1, this ensures that the client has
+ // processed the update.
+ WaitForAllBackends(1, 2, false, RpcOptions(), true);
+ // Bring up the previous backend: 0, this will allow the delayed RPC to
+ // finally call on_call_committed upon completion.
+ StartBackend(0);
+ // Join before asserting so the delayed RPC's completion is observed.
+ sending_rpc.join();
+ // Make sure RPCs go to the correct backend:
+ // backend 0 gets the delayed (pre-update) RPC, backend 1 gets the one
+ // RPC sent by WaitForAllBackends after the update.
+ EXPECT_EQ(1, backends_[0]->backend_service()->request_count());
+ EXPECT_EQ(1, backends_[1]->backend_service()->request_count());
+}
+
+// Exercises every header-matcher kind on a single route (exact, safe-regex,
+// range, present, prefix, suffix + invert). RPCs carrying metadata that
+// satisfies ALL matchers are routed to the new cluster (backend 1); plain
+// RPCs fall through to the default route (backend 0).
+TEST_P(LdsRdsTest, XdsRoutingHeadersMatching) {
+ const char* kNewClusterName = "new_cluster";
+ const char* kNewEdsServiceName = "new_eds_service_name";
+ const size_t kNumEcho1Rpcs = 100;
+ const size_t kNumEchoRpcs = 5;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ AdsServiceImpl::EdsResourceArgs args1({
+ {"locality0", GetBackendPorts(1, 2)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName));
+ // Populate new CDS resources.
+ Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
+ new_cluster.set_name(kNewClusterName);
+ new_cluster.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsServiceName);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+ // Populating Route Configurations for LDS.
+ RouteConfiguration route_config =
+ balancers_[0]->ads_service()->default_route_config();
+ auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+ route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+ auto* header_matcher1 = route1->mutable_match()->add_headers();
+ header_matcher1->set_name("header1");
+ header_matcher1->set_exact_match("POST,PUT,GET");
+ auto* header_matcher2 = route1->mutable_match()->add_headers();
+ header_matcher2->set_name("header2");
+ header_matcher2->mutable_safe_regex_match()->set_regex("[a-z]*");
+ auto* header_matcher3 = route1->mutable_match()->add_headers();
+ header_matcher3->set_name("header3");
+ header_matcher3->mutable_range_match()->set_start(1);
+ header_matcher3->mutable_range_match()->set_end(1000);
+ auto* header_matcher4 = route1->mutable_match()->add_headers();
+ header_matcher4->set_name("header4");
+ // present_match(false): matches only when "header4" is ABSENT — and the
+ // metadata below indeed never sends it.
+ header_matcher4->set_present_match(false);
+ auto* header_matcher5 = route1->mutable_match()->add_headers();
+ header_matcher5->set_name("header5");
+ header_matcher5->set_prefix_match("/grpc");
+ auto* header_matcher6 = route1->mutable_match()->add_headers();
+ header_matcher6->set_name("header6");
+ // invert_match: header6 must NOT end in ".cc" ("grpc.java" qualifies).
+ header_matcher6->set_suffix_match(".cc");
+ header_matcher6->set_invert_match(true);
+ route1->mutable_route()->set_cluster(kNewClusterName);
+ auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes();
+ default_route->mutable_match()->set_prefix("");
+ default_route->mutable_route()->set_cluster(kDefaultClusterName);
+ SetRouteConfiguration(0, route_config);
+ // NOTE(review): "header1" is sent three times (POST, PUT, GET); matching
+ // the exact value "POST,PUT,GET" presumably relies on repeated metadata
+ // values being comma-joined before matching — confirm against the matcher
+ // implementation.
+ std::vector<std::pair<TString, TString>> metadata = {
+ {"header1", "POST"}, {"header2", "blah"},
+ {"header3", "1"}, {"header5", "/grpc.testing.EchoTest1Service/"},
+ {"header1", "PUT"}, {"header6", "grpc.java"},
+ {"header1", "GET"},
+ };
+ const auto header_match_rpc_options = RpcOptions()
+ .set_rpc_service(SERVICE_ECHO1)
+ .set_rpc_method(METHOD_ECHO1)
+ .set_metadata(std::move(metadata));
+ // Make sure all backends are up.
+ WaitForAllBackends(0, 1);
+ WaitForAllBackends(1, 2, true, header_match_rpc_options);
+ // Send RPCs.
+ CheckRpcSendOk(kNumEchoRpcs);
+ CheckRpcSendOk(kNumEcho1Rpcs, header_match_rpc_options);
+ EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+ EXPECT_EQ(0, backends_[0]->backend_service2()->request_count());
+ EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+ EXPECT_EQ(kNumEcho1Rpcs, backends_[1]->backend_service1()->request_count());
+ EXPECT_EQ(0, backends_[1]->backend_service2()->request_count());
+ // The config with header matchers must have been ACKed, not NACKed.
+ const auto& response_state = RouteConfigurationResponseState(0);
+ EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED);
+}
+
+// Verifies matching on the special "content-type" header: a route requiring
+// content-type "notapplication/grpc" never matches gRPC traffic, while the
+// default route requiring "application/grpc" matches every RPC, so all
+// traffic stays on the default cluster (backend 0).
+TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingSpecialHeaderContentType) {
+ const char* kNewClusterName = "new_cluster";
+ const char* kNewEdsServiceName = "new_eds_service_name";
+ const size_t kNumEchoRpcs = 100;
+ SetNextResolution({});
+ SetNextResolutionForLbChannelAllBalancers();
+ // Populate new EDS resources.
+ AdsServiceImpl::EdsResourceArgs args({
+ {"locality0", GetBackendPorts(0, 1)},
+ });
+ AdsServiceImpl::EdsResourceArgs args1({
+ {"locality0", GetBackendPorts(1, 2)},
+ });
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args));
+ balancers_[0]->ads_service()->SetEdsResource(
+ AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName));
+ // Populate new CDS resources.
+ Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
+ new_cluster.set_name(kNewClusterName);
+ new_cluster.mutable_eds_cluster_config()->set_service_name(
+ kNewEdsServiceName);
+ balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+ // Populating Route Configurations for LDS.
+ // First route can never match a gRPC call (wrong content-type value).
+ RouteConfiguration route_config =
+ balancers_[0]->ads_service()->default_route_config();
+ auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+ route1->mutable_match()->set_prefix("");
+ auto* header_matcher1 = route1->mutable_match()->add_headers();
+ header_matcher1->set_name("content-type");
+ header_matcher1->set_exact_match("notapplication/grpc");
+ route1->mutable_route()->set_cluster(kNewClusterName);
+ auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes();
+ default_route->mutable_match()->set_prefix("");
+ auto* header_matcher2 = default_route->mutable_match()->add_headers();
+ header_matcher2->set_name("content-type");
+ header_matcher2->set_exact_match("application/grpc");
+ default_route->mutable_route()->set_cluster(kDefaultClusterName);
+ SetRouteConfiguration(0, route_config);
+ // Make sure the backend is up.
+ WaitForAllBackends(0, 1);
+ // Send RPCs.
+ CheckRpcSendOk(kNumEchoRpcs);
+ // Everything must land on the default cluster; the new cluster's backend
+ // stays untouched.
+ EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+ EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+ const auto& response_state = RouteConfigurationResponseState(0);
+ EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED);
+}
+
+// Tests that header matchers on headers reserved for internal use
+// ("-bin" suffixed metadata and "grpc-previous-rpc-attempts") never match:
+// routes keyed on them are skipped and traffic falls through to the
+// default route, even when the client actually sends those keys.
+TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingSpecialCasesToIgnore) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const size_t kNumEchoRpcs = 100;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1)},
+  });
+  AdsServiceImpl::EdsResourceArgs args1({
+      {"locality0", GetBackendPorts(1, 2)},
+  });
+  AdsServiceImpl::EdsResourceArgs args2({
+      {"locality0", GetBackendPorts(2, 3)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  // Populating Route Configurations for LDS.
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  // Route 1 matches on presence of a "-bin" (binary) metadata key.
+  auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("");
+  auto* header_matcher1 = route1->mutable_match()->add_headers();
+  header_matcher1->set_name("grpc-foo-bin");
+  header_matcher1->set_present_match(true);
+  route1->mutable_route()->set_cluster(kNewCluster1Name);
+  // Route 2 matches on presence of the grpc-previous-rpc-attempts header.
+  auto route2 = route_config.mutable_virtual_hosts(0)->add_routes();
+  route2->mutable_match()->set_prefix("");
+  auto* header_matcher2 = route2->mutable_match()->add_headers();
+  header_matcher2->set_name("grpc-previous-rpc-attempts");
+  header_matcher2->set_present_match(true);
+  route2->mutable_route()->set_cluster(kNewCluster2Name);
+  auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  // Send metadata that would satisfy both present_match matchers if these
+  // special header names were not ignored for routing purposes.
+  std::vector<std::pair<TString, TString>> metadata = {
+      {"grpc-foo-bin", "grpc-foo-bin"},
+      {"grpc-previous-rpc-attempts", "grpc-previous-rpc-attempts"},
+  };
+  WaitForAllBackends(0, 1);
+  CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_metadata(metadata));
+  // Verify that only the default backend got RPCs since all previous routes
+  // were mismatched.
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[2]->backend_service()->request_count());
+  const auto& response_state = RouteConfigurationResponseState(0);
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED);
+}
+
+// Tests probabilistic matching via runtime_fraction: roughly 25% of RPCs
+// should match route 1 (new cluster / backend 1) and the remaining ~75%
+// should fall through to the default route (backend 0), within tolerance.
+TEST_P(LdsRdsTest, XdsRoutingRuntimeFractionMatching) {
+  const char* kNewClusterName = "new_cluster";
+  const char* kNewEdsServiceName = "new_eds_service_name";
+  const size_t kNumRpcs = 1000;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1)},
+  });
+  AdsServiceImpl::EdsResourceArgs args1({
+      {"locality0", GetBackendPorts(1, 2)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName));
+  // Populate new CDS resources.
+  Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
+  new_cluster.set_name(kNewClusterName);
+  new_cluster.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsServiceName);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+  // Populating Route Configurations for LDS.
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  // Numerator 25 against the FractionalPercent default denominator
+  // (presumably HUNDRED — confirm against Envoy's proto), i.e. ~25% match.
+  route1->mutable_match()
+      ->mutable_runtime_fraction()
+      ->mutable_default_value()
+      ->set_numerator(25);
+  route1->mutable_route()->set_cluster(kNewClusterName);
+  auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  WaitForAllBackends(0, 2);
+  CheckRpcSendOk(kNumRpcs);
+  const int default_backend_count =
+      backends_[0]->backend_service()->request_count();
+  const int matched_backend_count =
+      backends_[1]->backend_service()->request_count();
+  // Statistical check: expect a 75/25 split within +/-20%.
+  const double kErrorTolerance = 0.2;
+  EXPECT_THAT(default_backend_count,
+              ::testing::AllOf(
+                  ::testing::Ge(kNumRpcs * 75 / 100 * (1 - kErrorTolerance)),
+                  ::testing::Le(kNumRpcs * 75 / 100 * (1 + kErrorTolerance))));
+  EXPECT_THAT(matched_backend_count,
+              ::testing::AllOf(
+                  ::testing::Ge(kNumRpcs * 25 / 100 * (1 - kErrorTolerance)),
+                  ::testing::Le(kNumRpcs * 25 / 100 * (1 + kErrorTolerance))));
+  const auto& response_state = RouteConfigurationResponseState(0);
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED);
+}
+
+// Tests that each kind of header matcher (exact, range, regex) correctly
+// rejects a near-miss value, so all RPCs fall through to the default route.
+TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingUnmatchCases) {
+  const char* kNewCluster1Name = "new_cluster_1";
+  const char* kNewEdsService1Name = "new_eds_service_name_1";
+  const char* kNewCluster2Name = "new_cluster_2";
+  const char* kNewEdsService2Name = "new_eds_service_name_2";
+  const char* kNewCluster3Name = "new_cluster_3";
+  const char* kNewEdsService3Name = "new_eds_service_name_3";
+  const size_t kNumEcho1Rpcs = 100;
+  const size_t kNumEchoRpcs = 5;
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1)},
+  });
+  AdsServiceImpl::EdsResourceArgs args1({
+      {"locality0", GetBackendPorts(1, 2)},
+  });
+  AdsServiceImpl::EdsResourceArgs args2({
+      {"locality0", GetBackendPorts(2, 3)},
+  });
+  AdsServiceImpl::EdsResourceArgs args3({
+      {"locality0", GetBackendPorts(3, 4)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args3, kNewEdsService3Name));
+  // Populate new CDS resources.
+  Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster1.set_name(kNewCluster1Name);
+  new_cluster1.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService1Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster1);
+  Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster2.set_name(kNewCluster2Name);
+  new_cluster2.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService2Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster2);
+  Cluster new_cluster3 = balancers_[0]->ads_service()->default_cluster();
+  new_cluster3.set_name(kNewCluster3Name);
+  new_cluster3.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsService3Name);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster3);
+  // Populating Route Configurations for LDS.
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  // Route 1: exact match on "header1".
+  auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* header_matcher1 = route1->mutable_match()->add_headers();
+  header_matcher1->set_name("header1");
+  header_matcher1->set_exact_match("POST");
+  route1->mutable_route()->set_cluster(kNewCluster1Name);
+  // Route 2: numeric range match on "header2".
+  auto route2 = route_config.mutable_virtual_hosts(0)->add_routes();
+  route2->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* header_matcher2 = route2->mutable_match()->add_headers();
+  header_matcher2->set_name("header2");
+  header_matcher2->mutable_range_match()->set_start(1);
+  header_matcher2->mutable_range_match()->set_end(1000);
+  route2->mutable_route()->set_cluster(kNewCluster2Name);
+  // Route 3: regex match on "header3".
+  auto route3 = route_config.mutable_virtual_hosts(0)->add_routes();
+  route3->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  auto* header_matcher3 = route3->mutable_match()->add_headers();
+  header_matcher3->set_name("header3");
+  header_matcher3->mutable_safe_regex_match()->set_regex("[a-z]*");
+  route3->mutable_route()->set_cluster(kNewCluster3Name);
+  auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  // Send headers which will mismatch each route:
+  // - "header1" is sent twice; the duplicate values are combined, so the
+  //   exact_match("POST") on route 1 fails.
+  // - "header2" is 1000, outside route 2's half-open range [1, 1000).
+  // - "header3" is "123", which does not match route 3's regex "[a-z]*".
+  std::vector<std::pair<TString, TString>> metadata = {
+      {"header1", "POST"},
+      {"header2", "1000"},
+      {"header3", "123"},
+      {"header1", "GET"},
+  };
+  WaitForAllBackends(0, 1);
+  CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_metadata(metadata));
+  CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions()
+                                    .set_rpc_service(SERVICE_ECHO1)
+                                    .set_rpc_method(METHOD_ECHO1)
+                                    .set_metadata(metadata));
+  // Verify that only the default backend got RPCs since all previous routes
+  // were mismatched.
+  for (size_t i = 1; i < 4; ++i) {
+    EXPECT_EQ(0, backends_[i]->backend_service()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service1()->request_count());
+    EXPECT_EQ(0, backends_[i]->backend_service2()->request_count());
+  }
+  EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(kNumEcho1Rpcs, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service2()->request_count());
+  const auto& response_state = RouteConfigurationResponseState(0);
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED);
+}
+
+// Tests that the client picks up a route update that remaps an existing
+// route (Echo1 -> Echo2 prefix) while the set of clusters stays unchanged,
+// and that traffic shifts accordingly without reconfiguring CDS/EDS.
+TEST_P(LdsRdsTest, XdsRoutingChangeRoutesWithoutChangingClusters) {
+  const char* kNewClusterName = "new_cluster";
+  const char* kNewEdsServiceName = "new_eds_service_name";
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Populate new EDS resources.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1)},
+  });
+  AdsServiceImpl::EdsResourceArgs args1({
+      {"locality0", GetBackendPorts(1, 2)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName));
+  // Populate new CDS resources.
+  Cluster new_cluster = balancers_[0]->ads_service()->default_cluster();
+  new_cluster.set_name(kNewClusterName);
+  new_cluster.mutable_eds_cluster_config()->set_service_name(
+      kNewEdsServiceName);
+  balancers_[0]->ads_service()->SetCdsResource(new_cluster);
+  // Populating Route Configurations for LDS.
+  RouteConfiguration route_config =
+      balancers_[0]->ads_service()->default_route_config();
+  // Initially, Echo1 RPCs go to the new cluster; everything else defaults.
+  auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/");
+  route1->mutable_route()->set_cluster(kNewClusterName);
+  auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes();
+  default_route->mutable_match()->set_prefix("");
+  default_route->mutable_route()->set_cluster(kDefaultClusterName);
+  SetRouteConfiguration(0, route_config);
+  // Make sure all backends are up and that requests for each RPC
+  // service go to the right backends.
+  WaitForAllBackends(0, 1, false);
+  WaitForAllBackends(1, 2, false, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  WaitForAllBackends(0, 1, false, RpcOptions().set_rpc_service(SERVICE_ECHO2));
+  // Requests for services Echo and Echo2 should have gone to backend 0.
+  EXPECT_EQ(1, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(1, backends_[0]->backend_service2()->request_count());
+  // Requests for service Echo1 should have gone to backend 1.
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  EXPECT_EQ(1, backends_[1]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service2()->request_count());
+  // Now send an update that changes the first route to match a
+  // different RPC service, and wait for the client to make the change.
+  route1->mutable_match()->set_prefix("/grpc.testing.EchoTest2Service/");
+  SetRouteConfiguration(0, route_config);
+  WaitForAllBackends(1, 2, true, RpcOptions().set_rpc_service(SERVICE_ECHO2));
+  // Now repeat the earlier test, making sure all traffic goes to the
+  // right place.
+  WaitForAllBackends(0, 1, false);
+  WaitForAllBackends(0, 1, false, RpcOptions().set_rpc_service(SERVICE_ECHO1));
+  WaitForAllBackends(1, 2, false, RpcOptions().set_rpc_service(SERVICE_ECHO2));
+  // Requests for services Echo and Echo1 should have gone to backend 0.
+  EXPECT_EQ(1, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(1, backends_[0]->backend_service1()->request_count());
+  EXPECT_EQ(0, backends_[0]->backend_service2()->request_count());
+  // Requests for service Echo2 should have gone to backend 1.
+  EXPECT_EQ(0, backends_[1]->backend_service()->request_count());
+  EXPECT_EQ(0, backends_[1]->backend_service1()->request_count());
+  EXPECT_EQ(1, backends_[1]->backend_service2()->request_count());
+}
+
+using CdsTest = BasicTest;
+
+// Tests that CDS client should send an ACK upon correct CDS response.
+TEST_P(CdsTest, Vanilla) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // One RPC drives the full xDS resolution flow; its result is irrelevant.
+  (void)SendRpc();
+  EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state,
+            AdsServiceImpl::ResponseState::ACKED);
+}
+
+// Tests that CDS client should send a NACK if the cluster type in CDS response
+// is other than EDS.
+TEST_P(CdsTest, WrongClusterType) {
+  auto cluster = balancers_[0]->ads_service()->default_cluster();
+  // STATIC is not supported by the client; only EDS clusters are accepted.
+  cluster.set_type(Cluster::STATIC);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& response_state =
+      balancers_[0]->ads_service()->cds_response_state();
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(response_state.error_message, "DiscoveryType is not EDS.");
+}
+
+// Tests that CDS client should send a NACK if the eds_config in CDS response is
+// other than ADS.
+TEST_P(CdsTest, WrongEdsConfig) {
+  auto cluster = balancers_[0]->ads_service()->default_cluster();
+  // Setting the "self" config source instead of ADS must be rejected.
+  cluster.mutable_eds_cluster_config()->mutable_eds_config()->mutable_self();
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& response_state =
+      balancers_[0]->ads_service()->cds_response_state();
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(response_state.error_message, "EDS ConfigSource is not ADS.");
+}
+
+// Tests that CDS client should send a NACK if the lb_policy in CDS response is
+// other than ROUND_ROBIN.
+TEST_P(CdsTest, WrongLbPolicy) {
+  auto cluster = balancers_[0]->ads_service()->default_cluster();
+  // LEAST_REQUEST is unsupported; the client only accepts ROUND_ROBIN.
+  cluster.set_lb_policy(Cluster::LEAST_REQUEST);
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& response_state =
+      balancers_[0]->ads_service()->cds_response_state();
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(response_state.error_message, "LB policy is not ROUND_ROBIN.");
+}
+
+// Tests that CDS client should send a NACK if the lrs_server in CDS response is
+// other than SELF.
+TEST_P(CdsTest, WrongLrsServer) {
+  auto cluster = balancers_[0]->ads_service()->default_cluster();
+  // Pointing LRS at ADS instead of "self" must be rejected.
+  cluster.mutable_lrs_server()->mutable_ads();
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+  const auto& response_state =
+      balancers_[0]->ads_service()->cds_response_state();
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(response_state.error_message, "LRS ConfigSource is not self.");
+}
+
+using EdsTest = BasicTest;
+
+// Tests that EDS client should send a NACK if the EDS update contains
+// sparse priorities.
+TEST_P(EdsTest, NacksSparsePriorityList) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // The single locality is assigned priority 1 with no priority 0,
+  // which makes the priority list sparse.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(), kDefaultLocalityWeight, 1},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args));
+  CheckRpcSendFailure();
+  const auto& response_state =
+      balancers_[0]->ads_service()->eds_response_state();
+  EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED);
+  EXPECT_EQ(response_state.error_message,
+            "EDS update includes sparse priority list");
+}
+
+// In most of our tests, we use different names for different resource
+// types, to make sure that there are no cut-and-paste errors in the code
+// that cause us to look at data for the wrong resource type. So we add
+// this test to make sure that the EDS resource name defaults to the
+// cluster name if not specified in the CDS resource.
+TEST_P(EdsTest, EdsServiceNameDefaultsToClusterName) {
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  // Register the EDS resource under the cluster's own name...
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, kDefaultClusterName));
+  // ...and clear service_name in CDS so the client must fall back to it.
+  Cluster cluster = balancers_[0]->ads_service()->default_cluster();
+  cluster.mutable_eds_cluster_config()->clear_service_name();
+  balancers_[0]->ads_service()->SetCdsResource(cluster);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendOk();
+}
+
+// Fixture that shortens the "resource does not exist" timeout so the
+// per-resource-type timeout tests below complete quickly.
+class TimeoutTest : public BasicTest {
+ protected:
+  void SetUp() override {
+    // Set before BasicTest::SetUp() so the shortened timeout takes effect.
+    xds_resource_does_not_exist_timeout_ms_ = 500;
+    BasicTest::SetUp();
+  }
+};
+
+// Tests that LDS client times out when no response received.
+TEST_P(TimeoutTest, Lds) {
+  // The mock balancer drops LDS requests, so the client never gets a config.
+  balancers_[0]->ads_service()->SetResourceIgnore(kLdsTypeUrl);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+}
+
+// Tests that RDS client times out when no response received.
+TEST_P(TimeoutTest, Rds) {
+  balancers_[0]->ads_service()->SetResourceIgnore(kRdsTypeUrl);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+}
+
+// Tests that CDS client times out when no response received.
+TEST_P(TimeoutTest, Cds) {
+  balancers_[0]->ads_service()->SetResourceIgnore(kCdsTypeUrl);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+}
+
+// Tests that EDS client times out when no response received.
+TEST_P(TimeoutTest, Eds) {
+  balancers_[0]->ads_service()->SetResourceIgnore(kEdsTypeUrl);
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  CheckRpcSendFailure();
+}
+
+using LocalityMapTest = BasicTest;
+
+// Tests that the localities in a locality map are picked according to their
+// weights.
+TEST_P(LocalityMapTest, WeightedRoundRobin) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 5000;
+  // Locality weights 2:8, so expected pick rates are 20% and 80%.
+  const int kLocalityWeight0 = 2;
+  const int kLocalityWeight1 = 8;
+  const int kTotalLocalityWeight = kLocalityWeight0 + kLocalityWeight1;
+  const double kLocalityWeightRate0 =
+      static_cast<double>(kLocalityWeight0) / kTotalLocalityWeight;
+  const double kLocalityWeightRate1 =
+      static_cast<double>(kLocalityWeight1) / kTotalLocalityWeight;
+  // ADS response contains 2 localities, each of which contains 1 backend.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), kLocalityWeight0},
+      {"locality1", GetBackendPorts(1, 2), kLocalityWeight1},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait for both backends to be ready.
+  WaitForAllBackends(0, 2);
+  // Send kNumRpcs RPCs.
+  CheckRpcSendOk(kNumRpcs);
+  // The locality picking rates should be roughly equal to the expectation.
+  const double locality_picked_rate_0 =
+      static_cast<double>(backends_[0]->backend_service()->request_count()) /
+      kNumRpcs;
+  const double locality_picked_rate_1 =
+      static_cast<double>(backends_[1]->backend_service()->request_count()) /
+      kNumRpcs;
+  // Statistical check: allow +/-20% deviation from the expected rates.
+  const double kErrorTolerance = 0.2;
+  EXPECT_THAT(locality_picked_rate_0,
+              ::testing::AllOf(
+                  ::testing::Ge(kLocalityWeightRate0 * (1 - kErrorTolerance)),
+                  ::testing::Le(kLocalityWeightRate0 * (1 + kErrorTolerance))));
+  EXPECT_THAT(locality_picked_rate_1,
+              ::testing::AllOf(
+                  ::testing::Ge(kLocalityWeightRate1 * (1 - kErrorTolerance)),
+                  ::testing::Le(kLocalityWeightRate1 * (1 + kErrorTolerance))));
+}
+
+// Tests that we correctly handle a locality containing no endpoints.
+TEST_P(LocalityMapTest, LocalityContainingNoEndpoints) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 5000;
+  // EDS response contains 2 localities, one with no endpoints.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+      {"locality1", {}},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait for all backends (all live in locality0) to be ready.
+  WaitForAllBackends();
+  // Send kNumRpcs RPCs.
+  CheckRpcSendOk(kNumRpcs);
+  // All traffic should go to the reachable locality, split evenly
+  // across its backends by round robin.
+  EXPECT_EQ(backends_[0]->backend_service()->request_count(),
+            kNumRpcs / backends_.size());
+  EXPECT_EQ(backends_[1]->backend_service()->request_count(),
+            kNumRpcs / backends_.size());
+  EXPECT_EQ(backends_[2]->backend_service()->request_count(),
+            kNumRpcs / backends_.size());
+  EXPECT_EQ(backends_[3]->backend_service()->request_count(),
+            kNumRpcs / backends_.size());
+}
+
+// EDS update with no localities: RPCs should fail with UNAVAILABLE
+// since there is nowhere to route traffic.
+TEST_P(LocalityMapTest, NoLocalities) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource({}, DefaultEdsServiceName()));
+  Status status = SendRpc();
+  EXPECT_FALSE(status.ok());
+  EXPECT_EQ(status.error_code(), StatusCode::UNAVAILABLE);
+}
+
+// Tests that the locality map can work properly even when it contains a large
+// number of localities.
+TEST_P(LocalityMapTest, StressTest) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumLocalities = 100;
+  // The first ADS response contains kNumLocalities localities, each of which
+  // contains backend 0.
+  AdsServiceImpl::EdsResourceArgs args;
+  for (size_t i = 0; i < kNumLocalities; ++i) {
+    TString name = y_absl::StrCat("locality", i);
+    AdsServiceImpl::EdsResourceArgs::Locality locality(name,
+                                                       {backends_[0]->port()});
+    args.locality_list.emplace_back(std::move(locality));
+  }
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // The second ADS response contains 1 locality, which contains backend 1.
+  // It is delivered on a background thread after a delay.
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality0", GetBackendPorts(1, 2)},
+  });
+  std::thread delayed_resource_setter(
+      std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0,
+                AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()),
+                60 * 1000));
+  // Wait until backend 0 is ready, before which kNumLocalities localities are
+  // received and handled by the xds policy.
+  WaitForBackend(0, /*reset_counters=*/false);
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  // Wait until backend 1 is ready, before which kNumLocalities localities are
+  // removed by the xds policy.
+  WaitForBackend(1);
+  delayed_resource_setter.join();
+}
+
+// Tests that the localities in a locality map are picked correctly after update
+// (addition, modification, deletion).
+TEST_P(LocalityMapTest, UpdateMap) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 3000;
+  // The locality weight for the first 3 localities.
+  const std::vector<int> kLocalityWeights0 = {2, 3, 4};
+  const double kTotalLocalityWeight0 =
+      std::accumulate(kLocalityWeights0.begin(), kLocalityWeights0.end(), 0);
+  std::vector<double> locality_weight_rate_0;
+  for (int weight : kLocalityWeights0) {
+    locality_weight_rate_0.push_back(weight / kTotalLocalityWeight0);
+  }
+  // Delete the first locality, keep the second locality, change the third
+  // locality's weight from 4 to 2, and add a new locality with weight 6.
+  const std::vector<int> kLocalityWeights1 = {3, 2, 6};
+  const double kTotalLocalityWeight1 =
+      std::accumulate(kLocalityWeights1.begin(), kLocalityWeights1.end(), 0);
+  // Index 0 is a placeholder so indices line up with backend numbers.
+  std::vector<double> locality_weight_rate_1 = {
+      0 /* placeholder for locality 0 */};
+  for (int weight : kLocalityWeights1) {
+    locality_weight_rate_1.push_back(weight / kTotalLocalityWeight1);
+  }
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), 2},
+      {"locality1", GetBackendPorts(1, 2), 3},
+      {"locality2", GetBackendPorts(2, 3), 4},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait for the first 3 backends to be ready.
+  WaitForAllBackends(0, 3);
+  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+  // Send kNumRpcs RPCs.
+  CheckRpcSendOk(kNumRpcs);
+  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+  // The picking rates of the first 3 backends should be roughly equal to the
+  // expectation.
+  std::vector<double> locality_picked_rates;
+  for (size_t i = 0; i < 3; ++i) {
+    locality_picked_rates.push_back(
+        static_cast<double>(backends_[i]->backend_service()->request_count()) /
+        kNumRpcs);
+  }
+  // Statistical check: allow +/-20% deviation from the expected rates.
+  const double kErrorTolerance = 0.2;
+  for (size_t i = 0; i < 3; ++i) {
+    gpr_log(GPR_INFO, "Locality %" PRIuPTR " rate %f", i,
+            locality_picked_rates[i]);
+    EXPECT_THAT(
+        locality_picked_rates[i],
+        ::testing::AllOf(
+            ::testing::Ge(locality_weight_rate_0[i] * (1 - kErrorTolerance)),
+            ::testing::Le(locality_weight_rate_0[i] * (1 + kErrorTolerance))));
+  }
+  // Apply the update described above (drop locality0, reweight, add
+  // locality3 on backend 3).
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality1", GetBackendPorts(1, 2), 3},
+      {"locality2", GetBackendPorts(2, 3), 2},
+      {"locality3", GetBackendPorts(3, 4), 6},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Backend 3 hasn't received any request.
+  EXPECT_EQ(0U, backends_[3]->backend_service()->request_count());
+  // Wait until the locality update has been processed, as signaled by backend 3
+  // receiving a request.
+  WaitForAllBackends(3, 4);
+  gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
+  // Send kNumRpcs RPCs.
+  CheckRpcSendOk(kNumRpcs);
+  gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");
+  // Backend 0 no longer receives any request.
+  EXPECT_EQ(0U, backends_[0]->backend_service()->request_count());
+  // The picking rates of the last 3 backends should be roughly equal to the
+  // expectation.
+  locality_picked_rates = {0 /* placeholder for backend 0 */};
+  for (size_t i = 1; i < 4; ++i) {
+    locality_picked_rates.push_back(
+        static_cast<double>(backends_[i]->backend_service()->request_count()) /
+        kNumRpcs);
+  }
+  for (size_t i = 1; i < 4; ++i) {
+    gpr_log(GPR_INFO, "Locality %" PRIuPTR " rate %f", i,
+            locality_picked_rates[i]);
+    EXPECT_THAT(
+        locality_picked_rates[i],
+        ::testing::AllOf(
+            ::testing::Ge(locality_weight_rate_1[i] * (1 - kErrorTolerance)),
+            ::testing::Le(locality_weight_rate_1[i] * (1 + kErrorTolerance))));
+  }
+}
+
+// Tests that we don't fail RPCs when replacing all of the localities in
+// a given priority.
+TEST_P(LocalityMapTest, ReplaceAllLocalitiesInPriority) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // First update: a single locality on backend 0.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Second update (sent after a 5s delay on a background thread) replaces
+  // the entire locality set with a new locality on backend 1.
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality1", GetBackendPorts(1, 2)},
+  });
+  std::thread delayed_resource_setter(std::bind(
+      &BasicTest::SetEdsResourceWithDelay, this, 0,
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), 5000));
+  // Wait for the first backend to be ready.
+  WaitForBackend(0);
+  // Keep sending RPCs until we switch over to backend 1, which tells us
+  // that we received the update. No RPCs should fail during this
+  // transition.
+  WaitForBackend(1, /*reset_counters=*/true, /*require_success=*/true);
+  delayed_resource_setter.join();
+}
+
+// Fixture for priority-failover tests. Re-creates the stub with a 500ms
+// argument to ResetStub (presumably the failover timeout — confirm against
+// BasicTest::ResetStub) so failover happens quickly in these tests.
+class FailoverTest : public BasicTest {
+ public:
+  void SetUp() override {
+    BasicTest::SetUp();
+    ResetStub(500);
+  }
+};
+
+// Localities with the highest priority are used when multiple priority exist.
+TEST_P(FailoverTest, ChooseHighestPriority) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // locality3 (backend 3) has priority 0, the highest, so it should get
+  // all the traffic.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1},
+      {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2},
+      {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3},
+      {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 0},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForBackend(3, false);
+  // The lower-priority backends (0-2) should have received no traffic.
+  for (size_t i = 0; i < 3; ++i) {
+    EXPECT_EQ(0U, backends_[i]->backend_service()->request_count());
+  }
+}
+
+// Does not choose priority with no endpoints.
+TEST_P(FailoverTest, DoesNotUsePriorityWithNoEndpoints) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // The highest priority (0, locality3) is empty, so traffic should fail
+  // over to priority 1 (locality0, backend 0).
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1},
+      {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2},
+      {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3},
+      {"locality3", {}, kDefaultLocalityWeight, 0},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForBackend(0, false);
+  // Backends at lower priorities (1 and 2) should have received nothing.
+  for (size_t i = 1; i < 3; ++i) {
+    EXPECT_EQ(0U, backends_[i]->backend_service()->request_count());
+  }
+}
+
+// Does not choose locality with no endpoints.
+TEST_P(FailoverTest, DoesNotUseLocalityWithNoEndpoints) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Both localities share priority 0; locality0 is empty, so all traffic
+  // should go to locality1's backends.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", {}, kDefaultLocalityWeight, 0},
+      {"locality1", GetBackendPorts(), kDefaultLocalityWeight, 0},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait for all backends to be used.
+  std::tuple<int, int, int> counts = WaitForAllBackends();
+  // Make sure no RPCs failed in the transition.
+  EXPECT_EQ(0, std::get<1>(counts));
+}
+
+// If the higher priority localities are not reachable, failover to the highest
+// priority among the rest.
+TEST_P(FailoverTest, Failover) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1},
+      {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2},
+      {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3},
+      {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 0},
+  });
+  // Kill priorities 0 (backend 3) and 1 (backend 0); the next-best
+  // reachable priority is 2 (backend 1).
+  ShutdownBackend(3);
+  ShutdownBackend(0);
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForBackend(1, false);
+  // Only backend 1 should have received traffic.
+  for (size_t i = 0; i < 4; ++i) {
+    if (i == 1) continue;
+    EXPECT_EQ(0U, backends_[i]->backend_service()->request_count());
+  }
+}
+
+// If a locality with higher priority than the current one becomes ready,
+// switch to it.
+TEST_P(FailoverTest, SwitchBackToHigherPriority) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 100;
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1},
+      {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2},
+      {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3},
+      {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 0},
+  });
+  // Priorities 0 (backend 3) and 1 (backend 0) start unreachable,
+  // so traffic initially fails over to priority 2 (backend 1).
+  ShutdownBackend(3);
+  ShutdownBackend(0);
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForBackend(1, false);
+  for (size_t i = 0; i < 4; ++i) {
+    if (i == 1) continue;
+    EXPECT_EQ(0U, backends_[i]->backend_service()->request_count());
+  }
+  // Once backend 0 (priority 1, now the best reachable priority) comes
+  // back, the client should switch to it and send it all traffic.
+  StartBackend(0);
+  WaitForBackend(0);
+  CheckRpcSendOk(kNumRpcs);
+  EXPECT_EQ(kNumRpcs, backends_[0]->backend_service()->request_count());
+}
+
+// The first update only contains unavailable priorities. The second update
+// contains available priorities.
+TEST_P(FailoverTest, UpdateInitialUnavailable) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // First update: two priorities whose backends (0 and 1) are shut down
+  // below, so initially no priority can serve traffic.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0},
+      {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 1},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Second update: adds priorities 2 and 3 (backends 2 and 3), which stay up.
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0},
+      {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 1},
+      {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 2},
+      {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 3},
+  });
+  ShutdownBackend(0);
+  ShutdownBackend(1);
+  // Deliver the second update 1000ms from now, on a separate thread.
+  std::thread delayed_resource_setter(std::bind(
+      &BasicTest::SetEdsResourceWithDelay, this, 0,
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), 1000));
+  gpr_timespec deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+                                       gpr_time_from_millis(500, GPR_TIMESPAN));
+  // Send 0.5 second worth of RPCs. They should all fail, since only the
+  // first (all-unavailable) update has been seen so far.
+  do {
+    CheckRpcSendFailure();
+  } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+  // After the second update arrives, priority 2 (backend 2) is the highest
+  // reachable priority and should receive all traffic.
+  WaitForBackend(2, false);
+  for (size_t i = 0; i < 4; ++i) {
+    if (i == 2) continue;
+    EXPECT_EQ(0U, backends_[i]->backend_service()->request_count());
+  }
+  delayed_resource_setter.join();
+}
+
+// Tests that after the localities' priorities are updated, we still choose the
+// highest READY priority with the updated localities.
+TEST_P(FailoverTest, UpdatePriority) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 100;
+  // First update: priority 0 is locality 3 (backend 3).
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1},
+      {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2},
+      {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3},
+      {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 0},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Second update: priorities are permuted so that priority 0 is now
+  // locality 1 (backend 1).
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 2},
+      {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 0},
+      {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 1},
+      {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 3},
+  });
+  // Deliver the second update 1000ms from now, on a separate thread.
+  std::thread delayed_resource_setter(std::bind(
+      &BasicTest::SetEdsResourceWithDelay, this, 0,
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), 1000));
+  // Under the first update, all traffic goes to backend 3 (priority 0).
+  WaitForBackend(3, false);
+  for (size_t i = 0; i < 3; ++i) {
+    EXPECT_EQ(0U, backends_[i]->backend_service()->request_count());
+  }
+  // Once the second update takes effect, backend 1 is the new priority 0.
+  WaitForBackend(1);
+  CheckRpcSendOk(kNumRpcs);
+  EXPECT_EQ(kNumRpcs, backends_[1]->backend_service()->request_count());
+  delayed_resource_setter.join();
+}
+
+// Moves all localities in the current priority to a higher priority.
+TEST_P(FailoverTest, MoveAllLocalitiesInCurrentPriorityToHigherPriority) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // First update:
+  // - Priority 0 is locality 0, containing backend 0, which is down.
+  // - Priority 1 is locality 1, containing backends 1 and 2, which are up.
+  ShutdownBackend(0);
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0},
+      {"locality1", GetBackendPorts(1, 3), kDefaultLocalityWeight, 1},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Second update:
+  // - Priority 0 contains both localities 0 and 1.
+  // - Priority 1 is not present.
+  // - We add backend 3 to locality 1, just so we have a way to know
+  // when the update has been seen by the client.
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0},
+      {"locality1", GetBackendPorts(1, 4), kDefaultLocalityWeight, 0},
+  });
+  // Deliver the second update 1000ms from now, on a separate thread.
+  std::thread delayed_resource_setter(std::bind(
+      &BasicTest::SetEdsResourceWithDelay, this, 0,
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), 1000));
+  // When we get the first update, all backends in priority 0 are down,
+  // so we will create priority 1. Backends 1 and 2 should have traffic,
+  // but backend 3 should not.
+  WaitForAllBackends(1, 3, false);
+  EXPECT_EQ(0UL, backends_[3]->backend_service()->request_count());
+  // When backend 3 gets traffic, we know the second update has been seen.
+  WaitForBackend(3);
+  // The ADS service of balancer 0 got at least 1 response.
+  EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  delayed_resource_setter.join();
+}
+
+// Drop tests reuse the basic fixture; only the EDS drop config varies.
+using DropTest = BasicTest;
+
+// Tests that RPCs are dropped according to the drop config.
+TEST_P(DropTest, Vanilla) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 5000;
+  const uint32_t kDropPerMillionForLb = 100000;
+  const uint32_t kDropPerMillionForThrottle = 200000;
+  const double kDropRateForLb = kDropPerMillionForLb / 1000000.0;
+  const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0;
+  // Combined rate: LB drops first, then throttling applies to the remainder.
+  const double kDropRateForLbAndThrottle =
+      kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle;
+  // Advertise a single locality with two drop categories.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb},
+                          {kThrottleDropType, kDropPerMillionForThrottle}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForAllBackends();
+  // Issue kNumRpcs RPCs, tallying how many were dropped by the LB policy.
+  size_t drop_count = 0;
+  for (size_t rpc = 0; rpc < kNumRpcs; ++rpc) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (status.ok()) {
+      EXPECT_EQ(response.message(), kRequestMessage);
+    } else if (status.error_message() ==
+               "Call dropped by load balancing policy") {
+      ++drop_count;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+    }
+  }
+  // The observed drop rate should match the configured rate within tolerance.
+  const double seen_drop_rate = static_cast<double>(drop_count) / kNumRpcs;
+  const double kErrorTolerance = 0.2;
+  EXPECT_THAT(
+      seen_drop_rate,
+      ::testing::AllOf(
+          ::testing::Ge(kDropRateForLbAndThrottle * (1 - kErrorTolerance)),
+          ::testing::Le(kDropRateForLbAndThrottle * (1 + kErrorTolerance))));
+}
+
+// Tests that drop config is converted correctly from per hundred.
+TEST_P(DropTest, DropPerHundred) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 5000;
+  const uint32_t kDropPerHundredForLb = 10;
+  const double kDropRateForLb = kDropPerHundredForLb / 100.0;
+  // Advertise one locality with a single drop category, expressed with the
+  // HUNDRED denominator.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerHundredForLb}};
+  args.drop_denominator = FractionalPercent::HUNDRED;
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForAllBackends();
+  // Issue kNumRpcs RPCs, tallying how many were dropped by the LB policy.
+  size_t drop_count = 0;
+  for (size_t rpc = 0; rpc < kNumRpcs; ++rpc) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (status.ok()) {
+      EXPECT_EQ(response.message(), kRequestMessage);
+    } else if (status.error_message() ==
+               "Call dropped by load balancing policy") {
+      ++drop_count;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+    }
+  }
+  // The observed drop rate should match the configured rate within tolerance.
+  const double seen_drop_rate = static_cast<double>(drop_count) / kNumRpcs;
+  const double kErrorTolerance = 0.2;
+  EXPECT_THAT(
+      seen_drop_rate,
+      ::testing::AllOf(::testing::Ge(kDropRateForLb * (1 - kErrorTolerance)),
+                       ::testing::Le(kDropRateForLb * (1 + kErrorTolerance))));
+}
+
+// Tests that drop config is converted correctly from per ten thousand.
+TEST_P(DropTest, DropPerTenThousand) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 5000;
+  const uint32_t kDropPerTenThousandForLb = 1000;
+  const double kDropRateForLb = kDropPerTenThousandForLb / 10000.0;
+  // Advertise one locality with a single drop category, expressed with the
+  // TEN_THOUSAND denominator.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerTenThousandForLb}};
+  args.drop_denominator = FractionalPercent::TEN_THOUSAND;
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForAllBackends();
+  // Issue kNumRpcs RPCs, tallying how many were dropped by the LB policy.
+  size_t drop_count = 0;
+  for (size_t rpc = 0; rpc < kNumRpcs; ++rpc) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (status.ok()) {
+      EXPECT_EQ(response.message(), kRequestMessage);
+    } else if (status.error_message() ==
+               "Call dropped by load balancing policy") {
+      ++drop_count;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+    }
+  }
+  // The observed drop rate should match the configured rate within tolerance.
+  const double seen_drop_rate = static_cast<double>(drop_count) / kNumRpcs;
+  const double kErrorTolerance = 0.2;
+  EXPECT_THAT(
+      seen_drop_rate,
+      ::testing::AllOf(::testing::Ge(kDropRateForLb * (1 - kErrorTolerance)),
+                       ::testing::Le(kDropRateForLb * (1 + kErrorTolerance))));
+}
+
+// Tests that drop is working correctly after update.
+TEST_P(DropTest, Update) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 3000;
+  const uint32_t kDropPerMillionForLb = 100000;
+  const uint32_t kDropPerMillionForThrottle = 200000;
+  const double kDropRateForLb = kDropPerMillionForLb / 1000000.0;
+  const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0;
+  // Combined rate when both categories are active: LB drops first, then
+  // throttling applies to the remaining traffic.
+  const double KDropRateForLbAndThrottle =
+      kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle;
+  // The first ADS response contains one drop category.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  WaitForAllBackends();
+  // Send kNumRpcs RPCs and count the drops.
+  size_t num_drops = 0;
+  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        status.error_message() == "Call dropped by load balancing policy") {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+  // The drop rate should be roughly equal to the expectation.
+  double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs;
+  gpr_log(GPR_INFO, "First batch drop rate %f", seen_drop_rate);
+  const double kErrorTolerance = 0.3;
+  EXPECT_THAT(
+      seen_drop_rate,
+      ::testing::AllOf(::testing::Ge(kDropRateForLb * (1 - kErrorTolerance)),
+                       ::testing::Le(kDropRateForLb * (1 + kErrorTolerance))));
+  // The second ADS response contains two drop categories, send an update EDS
+  // response.
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb},
+                          {kThrottleDropType, kDropPerMillionForThrottle}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until the drop rate increases to the middle of the two configs, which
+  // implies that the update has been in effect.
+  const double kDropRateThreshold =
+      (kDropRateForLb + KDropRateForLbAndThrottle) / 2;
+  size_t num_rpcs = kNumRpcs;
+  // Keep sending RPCs one at a time until the cumulative observed drop rate
+  // crosses the threshold, showing the new config has taken effect.
+  while (seen_drop_rate < kDropRateThreshold) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    ++num_rpcs;
+    if (!status.ok() &&
+        status.error_message() == "Call dropped by load balancing policy") {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+    seen_drop_rate = static_cast<double>(num_drops) / num_rpcs;
+  }
+  // Send kNumRpcs RPCs and count the drops.
+  num_drops = 0;
+  gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        status.error_message() == "Call dropped by load balancing policy") {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");
+  // The new drop rate should be roughly equal to the expectation.
+  seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs;
+  gpr_log(GPR_INFO, "Second batch drop rate %f", seen_drop_rate);
+  EXPECT_THAT(
+      seen_drop_rate,
+      ::testing::AllOf(
+          ::testing::Ge(KDropRateForLbAndThrottle * (1 - kErrorTolerance)),
+          ::testing::Le(KDropRateForLbAndThrottle * (1 + kErrorTolerance))));
+}
+
+// Tests that all the RPCs are dropped if any drop category drops 100%.
+TEST_P(DropTest, DropAll) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 1000;
+  const uint32_t kDropPerMillionForLb = 100000;
+  const uint32_t kDropPerMillionForThrottle = 1000000;
+  // The throttle category drops 1000000 per million, i.e. 100% of calls.
+  AdsServiceImpl::EdsResourceArgs args;
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb},
+                          {kThrottleDropType, kDropPerMillionForThrottle}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Every one of the kNumRpcs RPCs must be rejected with UNAVAILABLE.
+  for (size_t rpc = 0; rpc < kNumRpcs; ++rpc) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    EXPECT_EQ(status.error_code(), StatusCode::UNAVAILABLE);
+    EXPECT_EQ(status.error_message(), "Call dropped by load balancing policy");
+  }
+}
+
+// Fixture for balancer-address-update tests: 4 backends and 3 balancers.
+class BalancerUpdateTest : public XdsEnd2endTest {
+ public:
+  BalancerUpdateTest() : XdsEnd2endTest(4, 3) {}
+};
+
+// Tests that the old LB call is still used after the balancer address update as
+// long as that call is still alive.
+TEST_P(BalancerUpdateTest, UpdateBalancersButKeepUsingOriginalBalancer) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Balancer 0 points the client at backend 0; balancer 1 at backend 1.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", {backends_[0]->port()}},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality0", {backends_[1]->port()}},
+  });
+  balancers_[1]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until the first backend is ready.
+  WaitForBackend(0);
+  // Send 10 requests.
+  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+  CheckRpcSendOk(10);
+  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+  // All 10 requests should have gone to the first backend.
+  EXPECT_EQ(10U, backends_[0]->backend_service()->request_count());
+  // The ADS service of balancer 0 sent at least 1 response; balancers 1 and 2
+  // have not been contacted and should have sent nothing.
+  EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[1]->ads_service()->eds_response_state().error_message;
+  EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[2]->ads_service()->eds_response_state().error_message;
+  // Re-resolve the LB channel to point only at balancer 1.
+  gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 ==========");
+  SetNextResolutionForLbChannel({balancers_[1]->port()});
+  gpr_log(GPR_INFO, "========= UPDATE 1 DONE ==========");
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  gpr_timespec deadline = gpr_time_add(
+      gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN));
+  // Send 10 seconds worth of RPCs
+  do {
+    CheckRpcSendOk();
+  } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+  // The current LB call is still working, so xds continued using it to the
+  // first balancer, which doesn't assign the second backend.
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  // The ADS service of balancer 0 sent at least 1 response; the others are
+  // still untouched because the original LB call stayed alive.
+  EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[1]->ads_service()->eds_response_state().error_message;
+  EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[2]->ads_service()->eds_response_state().error_message;
+}
+
+// Tests that the old LB call is still used after multiple balancer address
+// updates as long as that call is still alive. Send an update with the same set
+// of LBs as the one in SetUp() in order to verify that the LB channel inside
+// xds keeps the initial connection (which by definition is also present in the
+// update).
+TEST_P(BalancerUpdateTest, Repeated) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  // Balancer 0 points the client at backend 0; balancer 1 at backend 1.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", {backends_[0]->port()}},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality0", {backends_[1]->port()}},
+  });
+  balancers_[1]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until the first backend is ready.
+  WaitForBackend(0);
+  // Send 10 requests.
+  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+  CheckRpcSendOk(10);
+  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+  // All 10 requests should have gone to the first backend.
+  EXPECT_EQ(10U, backends_[0]->backend_service()->request_count());
+  // The ADS service of balancer 0 sent at least 1 response; balancers 1 and 2
+  // were never contacted.
+  EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[1]->ads_service()->eds_response_state().error_message;
+  EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[2]->ads_service()->eds_response_state().error_message;
+  // First re-resolution: all three balancers, including the one in use.
+  std::vector<int> ports;
+  ports.emplace_back(balancers_[0]->port());
+  ports.emplace_back(balancers_[1]->port());
+  ports.emplace_back(balancers_[2]->port());
+  gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 ==========");
+  SetNextResolutionForLbChannel(ports);
+  gpr_log(GPR_INFO, "========= UPDATE 1 DONE ==========");
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  gpr_timespec deadline = gpr_time_add(
+      gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN));
+  // Send 10 seconds worth of RPCs
+  do {
+    CheckRpcSendOk();
+  } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+  // xds continued using the original LB call to the first balancer, which
+  // doesn't assign the second backend.
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  // Second re-resolution: only balancers 0 and 1 (still includes the one in
+  // use, so the original connection should again be kept).
+  ports.clear();
+  ports.emplace_back(balancers_[0]->port());
+  ports.emplace_back(balancers_[1]->port());
+  gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 2 ==========");
+  SetNextResolutionForLbChannel(ports);
+  gpr_log(GPR_INFO, "========= UPDATE 2 DONE ==========");
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
+                          gpr_time_from_millis(10000, GPR_TIMESPAN));
+  // Send 10 seconds worth of RPCs
+  do {
+    CheckRpcSendOk();
+  } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0);
+  // xds continued using the original LB call to the first balancer, which
+  // doesn't assign the second backend.
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+}
+
+// Tests that if the balancer is down, the RPCs will still be sent to the
+// backends according to the last balancer response, until a new balancer is
+// reachable.
+TEST_P(BalancerUpdateTest, DeadUpdate) {
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  // Balancer 0 points the client at backend 0; balancer 1 at backend 1.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", {backends_[0]->port()}},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality0", {backends_[1]->port()}},
+  });
+  balancers_[1]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Start servers and send 10 RPCs per server.
+  gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH ==========");
+  CheckRpcSendOk(10);
+  gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH ==========");
+  // All 10 requests should have gone to the first backend.
+  EXPECT_EQ(10U, backends_[0]->backend_service()->request_count());
+  // The ADS service of balancer 0 sent at least 1 response; balancers 1 and 2
+  // were never contacted.
+  EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[1]->ads_service()->eds_response_state().error_message;
+  EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[2]->ads_service()->eds_response_state().error_message;
+  // Kill balancer 0
+  gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************");
+  balancers_[0]->Shutdown();
+  gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************");
+  // This is serviced by the existing child policy.
+  gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH ==========");
+  CheckRpcSendOk(10);
+  gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH ==========");
+  // All 10 requests should again have gone to the first backend, because the
+  // client keeps using the last EDS response it received.
+  EXPECT_EQ(20U, backends_[0]->backend_service()->request_count());
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  // The ADS service of no balancers sent anything
+  EXPECT_EQ(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[0]->ads_service()->eds_response_state().error_message;
+  EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[1]->ads_service()->eds_response_state().error_message;
+  EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[2]->ads_service()->eds_response_state().error_message;
+  // Switch the LB channel over to balancer 1 (which is alive).
+  gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 ==========");
+  SetNextResolutionForLbChannel({balancers_[1]->port()});
+  gpr_log(GPR_INFO, "========= UPDATE 1 DONE ==========");
+  // Wait until update has been processed, as signaled by the second backend
+  // receiving a request. In the meantime, the client continues to be serviced
+  // (by the first backend) without interruption.
+  EXPECT_EQ(0U, backends_[1]->backend_service()->request_count());
+  WaitForBackend(1);
+  // This is serviced by the updated RR policy
+  backends_[1]->backend_service()->ResetCounters();
+  gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH ==========");
+  CheckRpcSendOk(10);
+  gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH ==========");
+  // All 10 requests should have gone to the second backend.
+  EXPECT_EQ(10U, backends_[1]->backend_service()->request_count());
+  // The ADS service of balancer 1 sent at least 1 response.
+  EXPECT_EQ(balancers_[0]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[0]->ads_service()->eds_response_state().error_message;
+  EXPECT_GT(balancers_[1]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT);
+  EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state,
+            AdsServiceImpl::ResponseState::NOT_SENT)
+      << "Error Message:"
+      << balancers_[2]->ads_service()->eds_response_state().error_message;
+}
+
+// The re-resolution tests are deferred because they rely on the fallback mode,
+// which is not yet supported.
+
+// TODO(juanlishen): Add TEST_P(BalancerUpdateTest, ReresolveDeadBackend).
+
+// TODO(juanlishen): Add TEST_P(UpdatesWithClientLoadReportingTest,
+// ReresolveDeadBalancer)
+
+// Fixture for client load reporting (LRS) tests: 4 backends and 1 balancer.
+// NOTE(review): the third constructor argument (3) presumably configures the
+// load-reporting interval — confirm against XdsEnd2endTest's constructor.
+class ClientLoadReportingTest : public XdsEnd2endTest {
+ public:
+  ClientLoadReportingTest() : XdsEnd2endTest(4, 1, 3) {}
+};
+
+// Tests that the load report received at the balancer is correct.
+TEST_P(ClientLoadReportingTest, Vanilla) {
+  if (!GetParam().use_xds_resolver()) {
+    balancers_[0]->lrs_service()->set_cluster_names({kServerName});
+  }
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  const size_t kNumRpcsPerAddress = 10;
+  const size_t kNumFailuresPerAddress = 3;
+  // TODO(juanlishen): Partition the backends after multiple localities is
+  // tested.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until all backends are ready.  Capture the warm-up RPC counts so
+  // they can be folded into the load-report expectations below.
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+  // Send kNumRpcsPerAddress successful and kNumFailuresPerAddress failing
+  // RPCs per server.
+  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
+  CheckRpcSendFailure(kNumFailuresPerAddress * num_backends_,
+                      RpcOptions().set_server_fail(true));
+  // Check that each backend got the right number of requests.
+  for (size_t i = 0; i < backends_.size(); ++i) {
+    EXPECT_EQ(kNumRpcsPerAddress + kNumFailuresPerAddress,
+              backends_[i]->backend_service()->request_count());
+  }
+  // The load report received at the balancer should be correct, including
+  // the RPCs issued during the warm-up phase.
+  std::vector<ClientStats> load_report =
+      balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 1UL);
+  ClientStats& client_stats = load_report.front();
+  EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok,
+            client_stats.total_successful_requests());
+  EXPECT_EQ(0U, client_stats.total_requests_in_progress());
+  EXPECT_EQ((kNumRpcsPerAddress + kNumFailuresPerAddress) * num_backends_ +
+                num_ok + num_failure,
+            client_stats.total_issued_requests());
+  EXPECT_EQ(kNumFailuresPerAddress * num_backends_ + num_failure,
+            client_stats.total_error_requests());
+  EXPECT_EQ(0U, client_stats.total_dropped_requests());
+  // The LRS service got a single request, and sent a single response.
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count());
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count());
+}
+
+// Tests send_all_clusters.
+TEST_P(ClientLoadReportingTest, SendAllClusters) {
+  // The LRS server requests no specific cluster names; instead it asks the
+  // client to report loads for all clusters.
+  balancers_[0]->lrs_service()->set_send_all_clusters(true);
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  const size_t kNumRpcsPerAddress = 10;
+  const size_t kNumFailuresPerAddress = 3;
+  // TODO(juanlishen): Partition the backends after multiple localities is
+  // tested.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until all backends are ready.  Capture the warm-up RPC counts so
+  // they can be folded into the load-report expectations below.
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+  // Send kNumRpcsPerAddress successful and kNumFailuresPerAddress failing
+  // RPCs per server.
+  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
+  CheckRpcSendFailure(kNumFailuresPerAddress * num_backends_,
+                      RpcOptions().set_server_fail(true));
+  // Check that each backend got the right number of requests.
+  for (size_t i = 0; i < backends_.size(); ++i) {
+    EXPECT_EQ(kNumRpcsPerAddress + kNumFailuresPerAddress,
+              backends_[i]->backend_service()->request_count());
+  }
+  // The load report received at the balancer should be correct, including
+  // the RPCs issued during the warm-up phase.
+  std::vector<ClientStats> load_report =
+      balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 1UL);
+  ClientStats& client_stats = load_report.front();
+  EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok,
+            client_stats.total_successful_requests());
+  EXPECT_EQ(0U, client_stats.total_requests_in_progress());
+  EXPECT_EQ((kNumRpcsPerAddress + kNumFailuresPerAddress) * num_backends_ +
+                num_ok + num_failure,
+            client_stats.total_issued_requests());
+  EXPECT_EQ(kNumFailuresPerAddress * num_backends_ + num_failure,
+            client_stats.total_error_requests());
+  EXPECT_EQ(0U, client_stats.total_dropped_requests());
+  // The LRS service got a single request, and sent a single response.
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count());
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count());
+}
+
+// Tests that we don't include stats for clusters that are not requested
+// by the LRS server.
+TEST_P(ClientLoadReportingTest, HonorsClustersRequestedByLrsServer) {
+  // The LRS server asks only for a cluster name the client never uses, so
+  // the client should report no load at all.
+  balancers_[0]->lrs_service()->set_cluster_names({"bogus"});
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  const size_t kNumRpcsPerAddress = 100;
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until every backend has served at least one RPC.
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+  // Send kNumRpcsPerAddress RPCs per server.
+  CheckRpcSendOk(kNumRpcsPerAddress * num_backends_);
+  // Every backend should have received exactly kNumRpcsPerAddress requests.
+  for (size_t idx = 0; idx < backends_.size(); ++idx) {
+    EXPECT_EQ(kNumRpcsPerAddress,
+              backends_[idx]->backend_service()->request_count());
+  }
+  // The LRS service got a single request, and sent a single response.
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count());
+  EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count());
+  // Since the requested cluster is unknown to the client, the report is empty.
+  std::vector<ClientStats> load_report =
+      balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 0UL);
+}
+
+// Tests that if the balancer restarts, the client load report contains the
+// stats before and after the restart correctly.
+TEST_P(ClientLoadReportingTest, BalancerRestart) {
+  // With the fake resolver there is no xds cluster name, so tell the LRS
+  // service to request load for the server name instead.
+  if (!GetParam().use_xds_resolver()) {
+    balancers_[0]->lrs_service()->set_cluster_names({kServerName});
+  }
+  SetNextResolution({});
+  SetNextResolutionForLbChannel({balancers_[0]->port()});
+  // Split the backends: the first half serves traffic before the balancer
+  // restart, the second half after it.
+  const size_t kNumBackendsFirstPass = backends_.size() / 2;
+  const size_t kNumBackendsSecondPass =
+      backends_.size() - kNumBackendsFirstPass;
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts(0, kNumBackendsFirstPass)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait until all backends returned by the balancer are ready.
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) =
+      WaitForAllBackends(/* start_index */ 0,
+                         /* stop_index */ kNumBackendsFirstPass);
+  // The first load report should contain exactly the successful warm-up
+  // RPCs and nothing else.
+  std::vector<ClientStats> load_report =
+      balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 1UL);
+  ClientStats client_stats = std::move(load_report.front());
+  EXPECT_EQ(static_cast<size_t>(num_ok),
+            client_stats.total_successful_requests());
+  EXPECT_EQ(0U, client_stats.total_requests_in_progress());
+  EXPECT_EQ(0U, client_stats.total_error_requests());
+  EXPECT_EQ(0U, client_stats.total_dropped_requests());
+  // Shut down the balancer.
+  balancers_[0]->Shutdown();
+  // We should continue using the last EDS response we received from the
+  // balancer before it was shut down.
+  // Note: We need to use WaitForAllBackends() here instead of just
+  // CheckRpcSendOk(kNumBackendsFirstPass), because when the balancer
+  // shuts down, the XdsClient will generate an error to the
+  // ServiceConfigWatcher, which will cause the xds resolver to send a
+  // no-op update to the LB policy. When this update gets down to the
+  // round_robin child policy for the locality, it will generate a new
+  // subchannel list, which resets the start index randomly. So we need
+  // to be a little more permissive here to avoid spurious failures.
+  ResetBackendCounters();
+  // num_started accumulates every RPC issued after the first report, for
+  // comparison against the second load report below.
+  int num_started = std::get<0>(WaitForAllBackends(
+      /* start_index */ 0, /* stop_index */ kNumBackendsFirstPass));
+  // Now restart the balancer, this time pointing to the new backends.
+  balancers_[0]->Start();
+  args = AdsServiceImpl::EdsResourceArgs({
+      {"locality0", GetBackendPorts(kNumBackendsFirstPass)},
+  });
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  // Wait for queries to start going to one of the new backends.
+  // This tells us that we're now using the new serverlist.
+  std::tie(num_ok, num_failure, num_drops) =
+      WaitForAllBackends(/* start_index */ kNumBackendsFirstPass);
+  num_started += num_ok + num_failure + num_drops;
+  // Send one RPC per backend.
+  CheckRpcSendOk(kNumBackendsSecondPass);
+  num_started += kNumBackendsSecondPass;
+  // Check client stats: the post-restart report must cover all RPCs sent
+  // since the first report, including those sent while the balancer was
+  // down.
+  load_report = balancers_[0]->lrs_service()->WaitForLoadReport();
+  ASSERT_EQ(load_report.size(), 1UL);
+  client_stats = std::move(load_report.front());
+  EXPECT_EQ(num_started, client_stats.total_successful_requests());
+  EXPECT_EQ(0U, client_stats.total_requests_in_progress());
+  EXPECT_EQ(0U, client_stats.total_error_requests());
+  EXPECT_EQ(0U, client_stats.total_dropped_requests());
+}
+
+// Fixture for load-reporting tests that also configure drops.
+// Ctor args presumably map to (num_backends, num_balancers,
+// client_load_reporting_interval_seconds) on XdsEnd2endTest — TODO confirm
+// against the base class, which is not visible here.
+class ClientLoadReportingWithDropTest : public XdsEnd2endTest {
+ public:
+  ClientLoadReportingWithDropTest() : XdsEnd2endTest(4, 1, 20) {}
+};
+
+// Tests that the drop stats are correctly reported by client load reporting.
+TEST_P(ClientLoadReportingWithDropTest, Vanilla) {
+  // With the fake resolver there is no xds cluster name, so tell the LRS
+  // service to request load for the server name instead.
+  if (!GetParam().use_xds_resolver()) {
+    balancers_[0]->lrs_service()->set_cluster_names({kServerName});
+  }
+  SetNextResolution({});
+  SetNextResolutionForLbChannelAllBalancers();
+  const size_t kNumRpcs = 3000;
+  const uint32_t kDropPerMillionForLb = 100000;
+  const uint32_t kDropPerMillionForThrottle = 200000;
+  const double kDropRateForLb = kDropPerMillionForLb / 1000000.0;
+  const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0;
+  // Combined drop probability: LB drops are applied first; throttle drops
+  // apply only to the traffic that survived the LB drop.
+  // NOTE(review): should be spelled kDropRateForLbAndThrottle per the
+  // constant-naming convention used elsewhere in this file.
+  const double KDropRateForLbAndThrottle =
+      kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle;
+  // The ADS response contains two drop categories.
+  AdsServiceImpl::EdsResourceArgs args({
+      {"locality0", GetBackendPorts()},
+  });
+  args.drop_categories = {{kLbDropType, kDropPerMillionForLb},
+                          {kThrottleDropType, kDropPerMillionForThrottle}};
+  balancers_[0]->ads_service()->SetEdsResource(
+      AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()));
+  int num_ok = 0;
+  int num_failure = 0;
+  int num_drops = 0;
+  std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends();
+  // Warm-up RPCs also count toward the load report, so remember them.
+  const size_t num_warmup = num_ok + num_failure + num_drops;
+  // Send kNumRpcs RPCs and count the drops.
+  for (size_t i = 0; i < kNumRpcs; ++i) {
+    EchoResponse response;
+    const Status status = SendRpc(RpcOptions(), &response);
+    if (!status.ok() &&
+        status.error_message() == "Call dropped by load balancing policy") {
+      ++num_drops;
+    } else {
+      EXPECT_TRUE(status.ok()) << "code=" << status.error_code()
+                               << " message=" << status.error_message();
+      EXPECT_EQ(response.message(), kRequestMessage);
+    }
+  }
+  // The drop rate should be roughly equal to the expectation.
+  const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs;
+  const double kErrorTolerance = 0.2;
+  EXPECT_THAT(
+      seen_drop_rate,
+      ::testing::AllOf(
+          ::testing::Ge(KDropRateForLbAndThrottle * (1 - kErrorTolerance)),
+          ::testing::Le(KDropRateForLbAndThrottle * (1 + kErrorTolerance))));
+  // Check client stats. Reports may arrive in several batches, so keep
+  // merging them until they cover every RPC we issued.
+  const size_t total_rpc = num_warmup + kNumRpcs;
+  ClientStats client_stats;
+  do {
+    std::vector<ClientStats> load_reports =
+        balancers_[0]->lrs_service()->WaitForLoadReport();
+    for (const auto& load_report : load_reports) {
+      client_stats += load_report;
+    }
+  } while (client_stats.total_issued_requests() +
+               client_stats.total_dropped_requests() <
+           total_rpc);
+  EXPECT_EQ(num_drops, client_stats.total_dropped_requests());
+  EXPECT_THAT(
+      client_stats.dropped_requests(kLbDropType),
+      ::testing::AllOf(
+          ::testing::Ge(total_rpc * kDropRateForLb * (1 - kErrorTolerance)),
+          ::testing::Le(total_rpc * kDropRateForLb * (1 + kErrorTolerance))));
+  EXPECT_THAT(client_stats.dropped_requests(kThrottleDropType),
+              ::testing::AllOf(
+                  ::testing::Ge(total_rpc * (1 - kDropRateForLb) *
+                                kDropRateForThrottle * (1 - kErrorTolerance)),
+                  ::testing::Le(total_rpc * (1 - kDropRateForLb) *
+                                kDropRateForThrottle * (1 + kErrorTolerance))));
+}
+
+// Builds the human-readable suffix appended to parameterized test names.
+TString TestTypeName(const ::testing::TestParamInfo<TestType>& info) {
+  const TestType& test_type = info.param;
+  return test_type.AsString();
+}
+
+// TestType params:
+// - use_xds_resolver
+// - enable_load_reporting
+// - enable_rds_testing = false
+// - use_v2 = false
+
+// Run with all four combinations of resolver and load reporting.
+INSTANTIATE_TEST_SUITE_P(XdsTest, BasicTest,
+                         ::testing::Values(TestType(false, true),
+                                           TestType(false, false),
+                                           TestType(true, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// Run with both fake resolver and xds resolver.
+// Don't run with load reporting or v2 or RDS, since they are irrelevant to
+// the tests.
+INSTANTIATE_TEST_SUITE_P(XdsTest, SecureNamingTest,
+                         ::testing::Values(TestType(false, false),
+                                           TestType(true, false)),
+                         &TestTypeName);
+
+// LDS depends on XdsResolver.
+INSTANTIATE_TEST_SUITE_P(XdsTest, LdsTest,
+                         ::testing::Values(TestType(true, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// LDS/RDS common tests depend on XdsResolver.
+INSTANTIATE_TEST_SUITE_P(XdsTest, LdsRdsTest,
+                         ::testing::Values(TestType(true, false),
+                                           TestType(true, true),
+                                           TestType(true, false, true),
+                                           TestType(true, true, true),
+                                           // Also test with xDS v2.
+                                           TestType(true, true, true, true)),
+                         &TestTypeName);
+
+// CDS depends on XdsResolver.
+INSTANTIATE_TEST_SUITE_P(XdsTest, CdsTest,
+                         ::testing::Values(TestType(true, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// EDS could be tested with or without XdsResolver, but the tests would
+// be the same either way, so we test it only with XdsResolver.
+INSTANTIATE_TEST_SUITE_P(XdsTest, EdsTest,
+                         ::testing::Values(TestType(true, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// Test initial resource timeouts for each resource type.
+// Do this only for XdsResolver with RDS enabled, so that we can test
+// all resource types.
+// Run with V3 only, since the functionality is no different in V2.
+INSTANTIATE_TEST_SUITE_P(XdsTest, TimeoutTest,
+                         ::testing::Values(TestType(true, false, true)),
+                         &TestTypeName);
+
+// XdsResolverOnlyTest depends on XdsResolver.
+INSTANTIATE_TEST_SUITE_P(XdsTest, XdsResolverOnlyTest,
+                         ::testing::Values(TestType(true, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// XdsResolverLoadReportingOnlyTest depends on XdsResolver and load reporting.
+INSTANTIATE_TEST_SUITE_P(XdsTest, XdsResolverLoadReportingOnlyTest,
+                         ::testing::Values(TestType(true, true)),
+                         &TestTypeName);
+
+// Run with all four combinations of resolver and load reporting.
+INSTANTIATE_TEST_SUITE_P(XdsTest, LocalityMapTest,
+                         ::testing::Values(TestType(false, true),
+                                           TestType(false, false),
+                                           TestType(true, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// Run with all four combinations of resolver and load reporting.
+INSTANTIATE_TEST_SUITE_P(XdsTest, FailoverTest,
+                         ::testing::Values(TestType(false, true),
+                                           TestType(false, false),
+                                           TestType(true, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// Run with all four combinations of resolver and load reporting.
+INSTANTIATE_TEST_SUITE_P(XdsTest, DropTest,
+                         ::testing::Values(TestType(false, true),
+                                           TestType(false, false),
+                                           TestType(true, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// NOTE(review): unlike the suites above, TestType(true, false) is not
+// covered here — confirm whether that omission is intentional.
+INSTANTIATE_TEST_SUITE_P(XdsTest, BalancerUpdateTest,
+                         ::testing::Values(TestType(false, true),
+                                           TestType(false, false),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// Load reporting tests are not run with load reporting disabled.
+INSTANTIATE_TEST_SUITE_P(XdsTest, ClientLoadReportingTest,
+                         ::testing::Values(TestType(false, true),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+// Load reporting tests are not run with load reporting disabled.
+INSTANTIATE_TEST_SUITE_P(XdsTest, ClientLoadReportingWithDropTest,
+                         ::testing::Values(TestType(false, true),
+                                           TestType(true, true)),
+                         &TestTypeName);
+
+} // namespace
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  // Write the xds bootstrap files before any test creates a channel.
+  grpc::testing::WriteBootstrapFiles();
+  // Never freed — presumably intentional, since the port saver must
+  // outlive all tests (process lifetime).
+  grpc::testing::g_port_saver = new grpc::testing::PortSaver();
+  const auto result = RUN_ALL_TESTS();
+  return result;
+}
diff --git a/contrib/libs/grpc/test/cpp/end2end/ya.make b/contrib/libs/grpc/test/cpp/end2end/ya.make
new file mode 100644
index 0000000000..b9c1dc7fe0
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/end2end/ya.make
@@ -0,0 +1,67 @@
+LIBRARY()
+
+LICENSE(Apache-2.0)
+
+LICENSE_TEXTS(.yandex_meta/licenses.list.txt)
+
+OWNER(dvshkurko)
+
+PEERDIR(
+ contrib/libs/grpc/src/proto/grpc/health/v1
+ contrib/libs/grpc/src/proto/grpc/testing
+ contrib/libs/grpc/src/proto/grpc/testing/duplicate
+ contrib/libs/grpc/test/cpp/util
+ contrib/libs/grpc
+ contrib/restricted/googletest/googlemock
+ contrib/restricted/googletest/googletest
+)
+
+ADDINCL(
+ ${ARCADIA_BUILD_ROOT}/contrib/libs/grpc
+ contrib/libs/grpc
+)
+
+NO_COMPILER_WARNINGS()
+
+SRCS(
+ # async_end2end_test.cc
+ # channelz_service_test.cc
+ # client_callback_end2end_test.cc
+ # client_crash_test.cc
+ # client_crash_test_server.cc
+ # client_interceptors_end2end_test.cc
+ # client_lb_end2end_test.cc lb needs opencensus, not enabled.
+ # end2end_test.cc
+ # exception_test.cc
+ # filter_end2end_test.cc
+ # generic_end2end_test.cc
+ # grpclb_end2end_test.cc lb needs opencensus, not enabled.
+ # health_service_end2end_test.cc
+ # hybrid_end2end_test.cc
+ interceptors_util.cc
+ # mock_test.cc
+ # nonblocking_test.cc
+ # proto_server_reflection_test.cc
+ # raw_end2end_test.cc
+ # server_builder_plugin_test.cc
+ # server_crash_test.cc
+ # server_crash_test_client.cc
+ # server_early_return_test.cc
+ # server_interceptors_end2end_test.cc
+ # server_load_reporting_end2end_test.cc
+ # shutdown_test.cc
+ # streaming_throughput_test.cc
+ test_health_check_service_impl.cc
+ test_service_impl.cc
+ # thread_stress_test.cc
+ # time_change_test.cc
+)
+
+END()
+
+RECURSE_FOR_TESTS(
+ health
+ server_interceptors
+ # Needs new gtest
+ # thread
+)
diff --git a/contrib/libs/grpc/test/cpp/util/.yandex_meta/licenses.list.txt b/contrib/libs/grpc/test/cpp/util/.yandex_meta/licenses.list.txt
new file mode 100644
index 0000000000..d2dadabed9
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/.yandex_meta/licenses.list.txt
@@ -0,0 +1,32 @@
+====================Apache-2.0====================
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+
+
+====================COPYRIGHT====================
+ * Copyright 2015 gRPC authors.
+
+
+====================COPYRIGHT====================
+ * Copyright 2015-2016 gRPC authors.
+
+
+====================COPYRIGHT====================
+ * Copyright 2016 gRPC authors.
+
+
+====================COPYRIGHT====================
+ * Copyright 2017 gRPC authors.
+
+
+====================COPYRIGHT====================
+ * Copyright 2018 gRPC authors.
diff --git a/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.cc b/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.cc
new file mode 100644
index 0000000000..5971b53075
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.cc
@@ -0,0 +1,57 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/byte_buffer_proto_helper.h"
+
+namespace grpc {
+namespace testing {
+
+// Reassembles the slices of |buffer| into one contiguous string and parses
+// it into |message|. Returns true iff protobuf parsing succeeded.
+bool ParseFromByteBuffer(ByteBuffer* buffer, grpc::protobuf::Message* message) {
+  std::vector<Slice> slices;
+  (void)buffer->Dump(&slices);
+  TString contents;
+  contents.reserve(buffer->Length());
+  for (const Slice& slice : slices) {
+    contents.append(reinterpret_cast<const char*>(slice.begin()),
+                    slice.size());
+  }
+  return message->ParseFromString(contents);
+}
+
+// Serializes |message| and wraps the bytes in a freshly allocated
+// ByteBuffer. NOTE(review): a serialization failure is silently ignored
+// and yields a buffer over an empty string — confirm callers accept that.
+std::unique_ptr<ByteBuffer> SerializeToByteBuffer(
+    grpc::protobuf::Message* message) {
+  TString serialized;
+  message->SerializeToString(&serialized);
+  Slice slice(serialized);
+  return std::unique_ptr<ByteBuffer>(new ByteBuffer(&slice, 1));
+}
+
+// Serializes |message| into |buffer|, replacing its previous contents.
+// Returns false — leaving |buffer| untouched — if serialization fails.
+bool SerializeToByteBufferInPlace(grpc::protobuf::Message* message,
+                                  ByteBuffer* buffer) {
+  TString serialized;
+  if (!message->SerializeToString(&serialized)) {
+    return false;
+  }
+  buffer->Clear();
+  Slice slice(serialized);
+  ByteBuffer wrapped(&slice, 1);
+  buffer->Swap(&wrapped);
+  return true;
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.h b/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.h
new file mode 100644
index 0000000000..3d01fb2468
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.h
@@ -0,0 +1,42 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_BYTE_BUFFER_PROTO_HELPER_H
+#define GRPC_TEST_CPP_UTIL_BYTE_BUFFER_PROTO_HELPER_H
+
+#include <memory>
+
+#include <grpcpp/impl/codegen/config_protobuf.h>
+#include <grpcpp/support/byte_buffer.h>
+
+namespace grpc {
+namespace testing {
+
+// Reassembles |buffer| into a single string and parses it into |message|.
+// Returns true on parse success.
+bool ParseFromByteBuffer(ByteBuffer* buffer,
+                         ::grpc::protobuf::Message* message);
+
+// Serializes |message| and returns the bytes wrapped in a new ByteBuffer.
+std::unique_ptr<ByteBuffer> SerializeToByteBuffer(
+    ::grpc::protobuf::Message* message);
+
+// Serializes |message| into |buffer|, replacing its contents. Returns
+// false if serialization fails.
+bool SerializeToByteBufferInPlace(::grpc::protobuf::Message* message,
+                                  ByteBuffer* buffer);
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_BYTE_BUFFER_PROTO_HELPER_H
diff --git a/contrib/libs/grpc/test/cpp/util/byte_buffer_test.cc b/contrib/libs/grpc/test/cpp/util/byte_buffer_test.cc
new file mode 100644
index 0000000000..c63f351a8f
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/byte_buffer_test.cc
@@ -0,0 +1,134 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc++/support/byte_buffer.h>
+#include <grpcpp/impl/grpc_library.h>
+
+#include <cstring>
+#include <vector>
+
+#include <grpc/grpc.h>
+#include <grpc/slice.h>
+#include <grpcpp/support/slice.h>
+#include <gtest/gtest.h>
+
+#include "test/core/util/test_config.h"
+
+namespace grpc {
+
+static internal::GrpcLibraryInitializer g_gli_initializer;
+
+namespace {
+
+const char* kContent1 = "hello xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
+const char* kContent2 = "yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy world";
+
+// Fixture that initializes the gRPC core library once for the whole suite
+// and shuts it down when the suite finishes.
+class ByteBufferTest : public ::testing::Test {
+ protected:
+  static void SetUpTestCase() { grpc_init(); }
+
+  static void TearDownTestCase() { grpc_shutdown(); }
+};
+
+TEST_F(ByteBufferTest, CopyCtor) {
+  ByteBuffer buffer1;
+  EXPECT_FALSE(buffer1.Valid());
+  // NOTE(review): buffer2 is a *reference*, so this exercises reference
+  // binding rather than the copy constructor the test name suggests —
+  // consider `const ByteBuffer buffer2 = buffer1;` to actually copy.
+  const ByteBuffer& buffer2 = buffer1;
+  EXPECT_FALSE(buffer2.Valid());
+}
+
+// A buffer built from one slice reports that slice's length.
+TEST_F(ByteBufferTest, CreateFromSingleSlice) {
+  Slice s(kContent1);
+  ByteBuffer buffer(&s, 1);
+  EXPECT_EQ(strlen(kContent1), buffer.Length());
+}
+
+// A buffer built from several slices reports their combined length.
+TEST_F(ByteBufferTest, CreateFromVector) {
+  std::vector<Slice> slices;
+  slices.emplace_back(kContent1);
+  slices.emplace_back(kContent2);
+  ByteBuffer buffer(&slices[0], 2);
+  EXPECT_EQ(strlen(kContent1) + strlen(kContent2), buffer.Length());
+}
+
+// Clear() must drop all buffered data.
+TEST_F(ByteBufferTest, Clear) {
+  Slice s(kContent1);
+  ByteBuffer buffer(&s, 1);
+  buffer.Clear();
+  EXPECT_EQ(static_cast<size_t>(0), buffer.Length());
+}
+
+// NOTE(review): identical to CreateFromVector above; as written this test
+// adds no extra coverage.
+TEST_F(ByteBufferTest, Length) {
+  std::vector<Slice> slices;
+  slices.emplace_back(kContent1);
+  slices.emplace_back(kContent2);
+  ByteBuffer buffer(&slices[0], 2);
+  EXPECT_EQ(strlen(kContent1) + strlen(kContent2), buffer.Length());
+}
+
+// Returns true iff |a| and |b| hold exactly the same byte sequence.
+bool SliceEqual(const Slice& a, grpc_slice b) {
+  const size_t len = GRPC_SLICE_LENGTH(b);
+  if (a.size() != len) {
+    return false;
+  }
+  const uint8_t* b_bytes = GRPC_SLICE_START_PTR(b);
+  for (size_t i = 0; i < len; ++i) {
+    if (a.begin()[i] != b_bytes[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Dump must return slices whose bytes match the ones the buffer was
+// built from, even after the original slice vector is cleared.
+TEST_F(ByteBufferTest, Dump) {
+  grpc_slice hello = grpc_slice_from_copied_string(kContent1);
+  grpc_slice world = grpc_slice_from_copied_string(kContent2);
+  std::vector<Slice> slices;
+  slices.push_back(Slice(hello, Slice::STEAL_REF));
+  slices.push_back(Slice(world, Slice::STEAL_REF));
+  ByteBuffer buffer(&slices[0], 2);
+  slices.clear();
+  (void)buffer.Dump(&slices);
+  EXPECT_TRUE(SliceEqual(slices[0], hello));
+  EXPECT_TRUE(SliceEqual(slices[1], world));
+}
+
+// Serializing a ByteBuffer must hand the caller an owned (owned == true),
+// valid copy without consuming the source buffer.
+TEST_F(ByteBufferTest, SerializationMakesCopy) {
+  grpc_slice hello = grpc_slice_from_copied_string(kContent1);
+  grpc_slice world = grpc_slice_from_copied_string(kContent2);
+  std::vector<Slice> slices;
+  slices.push_back(Slice(hello, Slice::STEAL_REF));
+  slices.push_back(Slice(world, Slice::STEAL_REF));
+  ByteBuffer send_buffer;
+  bool owned = false;
+  ByteBuffer buffer(&slices[0], 2);
+  slices.clear();
+  auto status = SerializationTraits<ByteBuffer, void>::Serialize(
+      buffer, &send_buffer, &owned);
+  EXPECT_TRUE(status.ok());
+  EXPECT_TRUE(owned);
+  EXPECT_TRUE(send_buffer.Valid());
+}
+
+} // namespace
+} // namespace grpc
+
+int main(int argc, char** argv) {
+  // TestEnvironment is scoped to main so it outlives RUN_ALL_TESTS().
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  int ret = RUN_ALL_TESTS();
+  return ret;
+}
diff --git a/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.cc b/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.cc
new file mode 100644
index 0000000000..d4b4026774
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.cc
@@ -0,0 +1,115 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/support/port_platform.h>
+
+#include "test/cpp/util/channel_trace_proto_helper.h"
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpcpp/impl/codegen/config.h>
+#include <grpcpp/impl/codegen/config_protobuf.h>
+#include <gtest/gtest.h>
+
+#include "src/core/lib/iomgr/error.h"
+#include "src/core/lib/json/json.h"
+#include "src/proto/grpc/channelz/channelz.pb.h"
+
+namespace grpc {
+
+namespace {
+
+// Generic helper that takes in a json string, converts it to a proto, and
+// then back to json. This ensures that the json string was correctly formatted
+// according to https://developers.google.com/protocol-buffers/docs/proto3#json
+// NOTE(review): the name is missing an 'l' ("Vaidate" -> "Validate"); not
+// renamed here because every wrapper below would have to change in lockstep.
+template <typename Message>
+void VaidateProtoJsonTranslation(const TString& json_str) {
+  Message msg;
+  grpc::protobuf::json::JsonParseOptions parse_options;
+  // If the following line is failing, then uncomment the last line of the
+  // comment, and uncomment the lines that print the two strings. You can
+  // then compare the output, and determine what fields are missing.
+  //
+  // parse_options.ignore_unknown_fields = true;
+  grpc::protobuf::util::Status s =
+      grpc::protobuf::json::JsonStringToMessage(json_str, &msg, parse_options);
+  EXPECT_TRUE(s.ok());
+  TString proto_json_str;
+  grpc::protobuf::json::JsonPrintOptions print_options;
+  // We usually do not want this to be true, however it can be helpful to
+  // uncomment and see the output produced when all fields are printed.
+  // print_options.always_print_primitive_fields = true;
+  s = grpc::protobuf::json::MessageToJsonString(msg, &proto_json_str);
+  EXPECT_TRUE(s.ok());
+  // Parse JSON and re-dump to string, to make sure formatting is the
+  // same as what would be generated by our JSON library.
+  // NOTE(review): if the ASSERT below fails, |error| is not unref'd before
+  // the early return — presumably acceptable in a failing test, but worth
+  // confirming.
+  grpc_error* error = GRPC_ERROR_NONE;
+  grpc_core::Json parsed_json =
+      grpc_core::Json::Parse(proto_json_str.c_str(), &error);
+  ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
+  ASSERT_EQ(parsed_json.type(), grpc_core::Json::Type::OBJECT);
+  proto_json_str = parsed_json.Dump();
+  // uncomment these to compare the json strings.
+  // gpr_log(GPR_ERROR, "tracer json: %s", json_str.c_str());
+  // gpr_log(GPR_ERROR, "proto json: %s", proto_json_str.c_str());
+  EXPECT_EQ(json_str, proto_json_str);
+}
+
+} // namespace
+
+namespace testing {
+
+// Each wrapper below validates that |json_c_str| round-trips unchanged
+// through the corresponding grpc.channelz.v1 proto message.
+void ValidateChannelTraceProtoJsonTranslation(const char* json_c_str) {
+  VaidateProtoJsonTranslation<grpc::channelz::v1::ChannelTrace>(json_c_str);
+}
+
+void ValidateChannelProtoJsonTranslation(const char* json_c_str) {
+  VaidateProtoJsonTranslation<grpc::channelz::v1::Channel>(json_c_str);
+}
+
+void ValidateGetTopChannelsResponseProtoJsonTranslation(
+    const char* json_c_str) {
+  VaidateProtoJsonTranslation<grpc::channelz::v1::GetTopChannelsResponse>(
+      json_c_str);
+}
+
+void ValidateGetChannelResponseProtoJsonTranslation(const char* json_c_str) {
+  VaidateProtoJsonTranslation<grpc::channelz::v1::GetChannelResponse>(
+      json_c_str);
+}
+
+void ValidateGetServerResponseProtoJsonTranslation(const char* json_c_str) {
+  VaidateProtoJsonTranslation<grpc::channelz::v1::GetServerResponse>(
+      json_c_str);
+}
+
+void ValidateSubchannelProtoJsonTranslation(const char* json_c_str) {
+  VaidateProtoJsonTranslation<grpc::channelz::v1::Subchannel>(json_c_str);
+}
+
+void ValidateServerProtoJsonTranslation(const char* json_c_str) {
+  VaidateProtoJsonTranslation<grpc::channelz::v1::Server>(json_c_str);
+}
+
+void ValidateGetServersResponseProtoJsonTranslation(const char* json_c_str) {
+  VaidateProtoJsonTranslation<grpc::channelz::v1::GetServersResponse>(
+      json_c_str);
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.h b/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.h
new file mode 100644
index 0000000000..664e899deb
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.h
@@ -0,0 +1,37 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_CHANNEL_TRACE_PROTO_HELPER_H
+#define GRPC_TEST_CPP_UTIL_CHANNEL_TRACE_PROTO_HELPER_H
+
+namespace grpc {
+namespace testing {
+
+// Each function validates that |json_c_str| round-trips unchanged through
+// the corresponding grpc.channelz.v1 proto message (parse -> print -> parse).
+void ValidateChannelTraceProtoJsonTranslation(const char* json_c_str);
+void ValidateChannelProtoJsonTranslation(const char* json_c_str);
+void ValidateGetTopChannelsResponseProtoJsonTranslation(const char* json_c_str);
+void ValidateGetChannelResponseProtoJsonTranslation(const char* json_c_str);
+void ValidateGetServerResponseProtoJsonTranslation(const char* json_c_str);
+void ValidateSubchannelProtoJsonTranslation(const char* json_c_str);
+void ValidateServerProtoJsonTranslation(const char* json_c_str);
+void ValidateGetServersResponseProtoJsonTranslation(const char* json_c_str);
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_CHANNEL_TRACE_PROTO_HELPER_H
diff --git a/contrib/libs/grpc/test/cpp/util/channelz_sampler.cc b/contrib/libs/grpc/test/cpp/util/channelz_sampler.cc
new file mode 100644
index 0000000000..e6bde68556
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/channelz_sampler.cc
@@ -0,0 +1,588 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+#include <unistd.h>
+
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <ostream>
+#include <queue>
+#include <util/generic/string.h>
+
+#include "y_absl/strings/str_format.h"
+#include "y_absl/strings/str_join.h"
+#include "gflags/gflags.h"
+#include "google/protobuf/text_format.h"
+#include "grpc/grpc.h"
+#include "grpc/support/port_platform.h"
+#include "grpcpp/channel.h"
+#include "grpcpp/client_context.h"
+#include "grpcpp/create_channel.h"
+#include "grpcpp/ext/channelz_service_plugin.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/server.h"
+#include "grpcpp/server_builder.h"
+#include "grpcpp/server_context.h"
+#include "src/core/lib/json/json.h"
+#include "src/cpp/server/channelz/channelz_service.h"
+#include "src/proto/grpc/channelz/channelz.pb.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/test_config.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+DEFINE_string(server_address, "", "channelz server address");
+DEFINE_string(custom_credentials_type, "", "custom credentials type");
+DEFINE_int64(sampling_times, 1, "number of sampling");
+DEFINE_int64(sampling_interval_seconds, 0, "sampling interval in seconds");
+DEFINE_string(output_json, "", "output filename in json format");
+
+namespace {
+using grpc::ClientContext;
+using grpc::Status;
+using grpc::StatusCode;
+using grpc::channelz::v1::GetChannelRequest;
+using grpc::channelz::v1::GetChannelResponse;
+using grpc::channelz::v1::GetServerRequest;
+using grpc::channelz::v1::GetServerResponse;
+using grpc::channelz::v1::GetServerSocketsRequest;
+using grpc::channelz::v1::GetServerSocketsResponse;
+using grpc::channelz::v1::GetServersRequest;
+using grpc::channelz::v1::GetServersResponse;
+using grpc::channelz::v1::GetSocketRequest;
+using grpc::channelz::v1::GetSocketResponse;
+using grpc::channelz::v1::GetSubchannelRequest;
+using grpc::channelz::v1::GetSubchannelResponse;
+using grpc::channelz::v1::GetTopChannelsRequest;
+using grpc::channelz::v1::GetTopChannelsResponse;
+} // namespace
+
+class ChannelzSampler final {
+ public:
+ // Get server_id of a server
+ int64_t GetServerID(const grpc::channelz::v1::Server& server) {
+ return server.ref().server_id();
+ }
+
+ // Get channel_id of a channel
+ inline int64_t GetChannelID(const grpc::channelz::v1::Channel& channel) {
+ return channel.ref().channel_id();
+ }
+
+ // Get subchannel_id of a subchannel
+ inline int64_t GetSubchannelID(
+ const grpc::channelz::v1::Subchannel& subchannel) {
+ return subchannel.ref().subchannel_id();
+ }
+
+ // Get socket_id of a socket
+ inline int64_t GetSocketID(const grpc::channelz::v1::Socket& socket) {
+ return socket.ref().socket_id();
+ }
+
+ // Get name of a server
+ inline TString GetServerName(const grpc::channelz::v1::Server& server) {
+ return server.ref().name();
+ }
+
+ // Get name of a channel
+ inline TString GetChannelName(
+ const grpc::channelz::v1::Channel& channel) {
+ return channel.ref().name();
+ }
+
+ // Get name of a subchannel
+ inline TString GetSubchannelName(
+ const grpc::channelz::v1::Subchannel& subchannel) {
+ return subchannel.ref().name();
+ }
+
+ // Get name of a socket
+ inline TString GetSocketName(const grpc::channelz::v1::Socket& socket) {
+ return socket.ref().name();
+ }
+
+ // Get a channel based on channel_id
+ grpc::channelz::v1::Channel GetChannelRPC(int64_t channel_id) {
+ GetChannelRequest get_channel_request;
+ get_channel_request.set_channel_id(channel_id);
+ GetChannelResponse get_channel_response;
+ ClientContext get_channel_context;
+ get_channel_context.set_deadline(
+ grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_));
+ Status status = channelz_stub_->GetChannel(
+ &get_channel_context, get_channel_request, &get_channel_response);
+ if (!status.ok()) {
+ gpr_log(GPR_ERROR, "GetChannelRPC failed: %s",
+ get_channel_context.debug_error_string().c_str());
+ GPR_ASSERT(0);
+ }
+ return get_channel_response.channel();
+ }
+
+ // Get a subchannel based on subchannel_id
+ grpc::channelz::v1::Subchannel GetSubchannelRPC(int64_t subchannel_id) {
+ GetSubchannelRequest get_subchannel_request;
+ get_subchannel_request.set_subchannel_id(subchannel_id);
+ GetSubchannelResponse get_subchannel_response;
+ ClientContext get_subchannel_context;
+ get_subchannel_context.set_deadline(
+ grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_));
+ Status status = channelz_stub_->GetSubchannel(&get_subchannel_context,
+ get_subchannel_request,
+ &get_subchannel_response);
+ if (!status.ok()) {
+ gpr_log(GPR_ERROR, "GetSubchannelRPC failed: %s",
+ get_subchannel_context.debug_error_string().c_str());
+ GPR_ASSERT(0);
+ }
+ return get_subchannel_response.subchannel();
+ }
+
+ // get a socket based on socket_id
+ grpc::channelz::v1::Socket GetSocketRPC(int64_t socket_id) {
+ GetSocketRequest get_socket_request;
+ get_socket_request.set_socket_id(socket_id);
+ GetSocketResponse get_socket_response;
+ ClientContext get_socket_context;
+ get_socket_context.set_deadline(
+ grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_));
+ Status status = channelz_stub_->GetSocket(
+ &get_socket_context, get_socket_request, &get_socket_response);
+ if (!status.ok()) {
+ gpr_log(GPR_ERROR, "GetSocketRPC failed: %s",
+ get_socket_context.debug_error_string().c_str());
+ GPR_ASSERT(0);
+ }
+ return get_socket_response.socket();
+ }
+
+  // get the descendant channels/subchannels/sockets of a channel
+  // push descendant channels/subchannels to queue for layer traverse
+  // store descendant channels/subchannels/sockets for dumping data
+ void GetChannelDescedence(
+ const grpc::channelz::v1::Channel& channel,
+ std::queue<grpc::channelz::v1::Channel>& channel_queue,
+ std::queue<grpc::channelz::v1::Subchannel>& subchannel_queue) {
+ std::cout << " Channel ID" << GetChannelID(channel) << "_"
+ << GetChannelName(channel) << " descendence - ";
+ if (channel.channel_ref_size() > 0 || channel.subchannel_ref_size() > 0) {
+ if (channel.channel_ref_size() > 0) {
+ std::cout << "channel: ";
+ for (const auto& _channelref : channel.channel_ref()) {
+ int64_t ch_id = _channelref.channel_id();
+ std::cout << "ID" << ch_id << "_" << _channelref.name() << " ";
+ grpc::channelz::v1::Channel ch = GetChannelRPC(ch_id);
+ channel_queue.push(ch);
+ if (CheckID(ch_id)) {
+ all_channels_.push_back(ch);
+ StoreChannelInJson(ch);
+ }
+ }
+ if (channel.subchannel_ref_size() > 0) {
+ std::cout << ", ";
+ }
+ }
+ if (channel.subchannel_ref_size() > 0) {
+ std::cout << "subchannel: ";
+ for (const auto& _subchannelref : channel.subchannel_ref()) {
+ int64_t subch_id = _subchannelref.subchannel_id();
+ std::cout << "ID" << subch_id << "_" << _subchannelref.name() << " ";
+ grpc::channelz::v1::Subchannel subch = GetSubchannelRPC(subch_id);
+ subchannel_queue.push(subch);
+ if (CheckID(subch_id)) {
+ all_subchannels_.push_back(subch);
+ StoreSubchannelInJson(subch);
+ }
+ }
+ }
+ } else if (channel.socket_ref_size() > 0) {
+ std::cout << "socket: ";
+ for (const auto& _socketref : channel.socket_ref()) {
+ int64_t so_id = _socketref.socket_id();
+ std::cout << "ID" << so_id << "_" << _socketref.name() << " ";
+ grpc::channelz::v1::Socket so = GetSocketRPC(so_id);
+ if (CheckID(so_id)) {
+ all_sockets_.push_back(so);
+ StoreSocketInJson(so);
+ }
+ }
+ }
+ std::cout << std::endl;
+ }
+
+  // get the descendant channels/subchannels/sockets of a subchannel
+  // push descendant channels/subchannels to queue for layer traverse
+  // store descendant channels/subchannels/sockets for dumping data
+ void GetSubchannelDescedence(
+ grpc::channelz::v1::Subchannel& subchannel,
+ std::queue<grpc::channelz::v1::Channel>& channel_queue,
+ std::queue<grpc::channelz::v1::Subchannel>& subchannel_queue) {
+ std::cout << " Subchannel ID" << GetSubchannelID(subchannel) << "_"
+ << GetSubchannelName(subchannel) << " descendence - ";
+ if (subchannel.channel_ref_size() > 0 ||
+ subchannel.subchannel_ref_size() > 0) {
+ if (subchannel.channel_ref_size() > 0) {
+ std::cout << "channel: ";
+ for (const auto& _channelref : subchannel.channel_ref()) {
+ int64_t ch_id = _channelref.channel_id();
+ std::cout << "ID" << ch_id << "_" << _channelref.name() << " ";
+ grpc::channelz::v1::Channel ch = GetChannelRPC(ch_id);
+ channel_queue.push(ch);
+ if (CheckID(ch_id)) {
+ all_channels_.push_back(ch);
+ StoreChannelInJson(ch);
+ }
+ }
+ if (subchannel.subchannel_ref_size() > 0) {
+ std::cout << ", ";
+ }
+ }
+ if (subchannel.subchannel_ref_size() > 0) {
+ std::cout << "subchannel: ";
+ for (const auto& _subchannelref : subchannel.subchannel_ref()) {
+ int64_t subch_id = _subchannelref.subchannel_id();
+ std::cout << "ID" << subch_id << "_" << _subchannelref.name() << " ";
+ grpc::channelz::v1::Subchannel subch = GetSubchannelRPC(subch_id);
+ subchannel_queue.push(subch);
+ if (CheckID(subch_id)) {
+ all_subchannels_.push_back(subch);
+ StoreSubchannelInJson(subch);
+ }
+ }
+ }
+ } else if (subchannel.socket_ref_size() > 0) {
+ std::cout << "socket: ";
+ for (const auto& _socketref : subchannel.socket_ref()) {
+ int64_t so_id = _socketref.socket_id();
+ std::cout << "ID" << so_id << "_" << _socketref.name() << " ";
+ grpc::channelz::v1::Socket so = GetSocketRPC(so_id);
+ if (CheckID(so_id)) {
+ all_sockets_.push_back(so);
+ StoreSocketInJson(so);
+ }
+ }
+ }
+ std::cout << std::endl;
+ }
+
+ // Set up the channelz sampler client
+ // Initialize json as an array
+ void Setup(const TString& custom_credentials_type,
+ const TString& server_address) {
+ json_ = grpc_core::Json::Array();
+ rpc_timeout_seconds_ = 20;
+ grpc::ChannelArguments channel_args;
+ std::shared_ptr<grpc::ChannelCredentials> channel_creds =
+ grpc::testing::GetCredentialsProvider()->GetChannelCredentials(
+ custom_credentials_type, &channel_args);
+ if (!channel_creds) {
+ gpr_log(GPR_ERROR,
+ "Wrong user credential type: %s. Allowed credential types: "
+ "INSECURE_CREDENTIALS, ssl, alts, google_default_credentials.",
+ custom_credentials_type.c_str());
+ GPR_ASSERT(0);
+ }
+ std::shared_ptr<grpc::Channel> channel =
+ CreateChannel(server_address, channel_creds);
+ channelz_stub_ = grpc::channelz::v1::Channelz::NewStub(channel);
+ }
+
+ // Get all servers, keep querying until getting all
+ // Store servers for dumping data
+ // Need to check id repeating for servers
+ void GetServersRPC() {
+ int64_t server_start_id = 0;
+ while (true) {
+ GetServersRequest get_servers_request;
+ GetServersResponse get_servers_response;
+ ClientContext get_servers_context;
+ get_servers_context.set_deadline(
+ grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_));
+ get_servers_request.set_start_server_id(server_start_id);
+ Status status = channelz_stub_->GetServers(
+ &get_servers_context, get_servers_request, &get_servers_response);
+ if (!status.ok()) {
+ if (status.error_code() == StatusCode::UNIMPLEMENTED) {
+ gpr_log(GPR_ERROR,
+ "Error status UNIMPLEMENTED. Please check and make sure "
+ "channelz has been registered on the server being queried.");
+ } else {
+ gpr_log(GPR_ERROR,
+ "GetServers RPC with GetServersRequest.server_start_id=%d, "
+ "failed: %s",
+ int(server_start_id),
+ get_servers_context.debug_error_string().c_str());
+ }
+ GPR_ASSERT(0);
+ }
+ for (const auto& _server : get_servers_response.server()) {
+ all_servers_.push_back(_server);
+ StoreServerInJson(_server);
+ }
+ if (!get_servers_response.end()) {
+ server_start_id = GetServerID(all_servers_.back()) + 1;
+ } else {
+ break;
+ }
+ }
+ std::cout << "Number of servers = " << all_servers_.size() << std::endl;
+ }
+
+ // Get sockets that belongs to servers
+ // Store sockets for dumping data
+ void GetSocketsOfServers() {
+ for (const auto& _server : all_servers_) {
+ std::cout << "Server ID" << GetServerID(_server) << "_"
+ << GetServerName(_server) << " listen_socket - ";
+ for (const auto& _socket : _server.listen_socket()) {
+ int64_t so_id = _socket.socket_id();
+ std::cout << "ID" << so_id << "_" << _socket.name() << " ";
+ if (CheckID(so_id)) {
+ grpc::channelz::v1::Socket so = GetSocketRPC(so_id);
+ all_sockets_.push_back(so);
+ StoreSocketInJson(so);
+ }
+ }
+ std::cout << std::endl;
+ }
+ }
+
+ // Get all top channels, keep querying until getting all
+ // Store channels for dumping data
+ // No need to check id repeating for top channels
+ void GetTopChannelsRPC() {
+ int64_t channel_start_id = 0;
+ while (true) {
+ GetTopChannelsRequest get_top_channels_request;
+ GetTopChannelsResponse get_top_channels_response;
+ ClientContext get_top_channels_context;
+ get_top_channels_context.set_deadline(
+ grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_));
+ get_top_channels_request.set_start_channel_id(channel_start_id);
+ Status status = channelz_stub_->GetTopChannels(
+ &get_top_channels_context, get_top_channels_request,
+ &get_top_channels_response);
+ if (!status.ok()) {
+ gpr_log(GPR_ERROR,
+ "GetTopChannels RPC with "
+ "GetTopChannelsRequest.channel_start_id=%d failed: %s",
+ int(channel_start_id),
+ get_top_channels_context.debug_error_string().c_str());
+ GPR_ASSERT(0);
+ }
+ for (const auto& _topchannel : get_top_channels_response.channel()) {
+ top_channels_.push_back(_topchannel);
+ all_channels_.push_back(_topchannel);
+ StoreChannelInJson(_topchannel);
+ }
+ if (!get_top_channels_response.end()) {
+ channel_start_id = GetChannelID(top_channels_.back()) + 1;
+ } else {
+ break;
+ }
+ }
+ std::cout << std::endl
+ << "Number of top channels = " << top_channels_.size()
+ << std::endl;
+ }
+
+ // layer traverse for each top channel
+ void TraverseTopChannels() {
+ for (const auto& _topchannel : top_channels_) {
+ int tree_depth = 0;
+ std::queue<grpc::channelz::v1::Channel> channel_queue;
+ std::queue<grpc::channelz::v1::Subchannel> subchannel_queue;
+ std::cout << "Tree depth = " << tree_depth << std::endl;
+ GetChannelDescedence(_topchannel, channel_queue, subchannel_queue);
+ while (!channel_queue.empty() || !subchannel_queue.empty()) {
+ ++tree_depth;
+ std::cout << "Tree depth = " << tree_depth << std::endl;
+ int ch_q_size = channel_queue.size();
+ int subch_q_size = subchannel_queue.size();
+ for (int i = 0; i < ch_q_size; ++i) {
+ grpc::channelz::v1::Channel ch = channel_queue.front();
+ channel_queue.pop();
+ GetChannelDescedence(ch, channel_queue, subchannel_queue);
+ }
+ for (int i = 0; i < subch_q_size; ++i) {
+ grpc::channelz::v1::Subchannel subch = subchannel_queue.front();
+ subchannel_queue.pop();
+ GetSubchannelDescedence(subch, channel_queue, subchannel_queue);
+ }
+ }
+ std::cout << std::endl;
+ }
+ }
+
+ // dump data of all entities to stdout
+ void DumpStdout() {
+ TString data_str;
+ for (const auto& _channel : all_channels_) {
+ std::cout << "channel ID" << GetChannelID(_channel) << "_"
+ << GetChannelName(_channel) << " data:" << std::endl;
+ // TODO(mohanli): TextFormat::PrintToString records time as seconds and
+ // nanos. Need a more human readable way.
+ ::google::protobuf::TextFormat::PrintToString(_channel.data(), &data_str);
+ printf("%s\n", data_str.c_str());
+ }
+ for (const auto& _subchannel : all_subchannels_) {
+ std::cout << "subchannel ID" << GetSubchannelID(_subchannel) << "_"
+ << GetSubchannelName(_subchannel) << " data:" << std::endl;
+ ::google::protobuf::TextFormat::PrintToString(_subchannel.data(),
+ &data_str);
+ printf("%s\n", data_str.c_str());
+ }
+ for (const auto& _server : all_servers_) {
+ std::cout << "server ID" << GetServerID(_server) << "_"
+ << GetServerName(_server) << " data:" << std::endl;
+ ::google::protobuf::TextFormat::PrintToString(_server.data(), &data_str);
+ printf("%s\n", data_str.c_str());
+ }
+ for (const auto& _socket : all_sockets_) {
+ std::cout << "socket ID" << GetSocketID(_socket) << "_"
+ << GetSocketName(_socket) << " data:" << std::endl;
+ ::google::protobuf::TextFormat::PrintToString(_socket.data(), &data_str);
+ printf("%s\n", data_str.c_str());
+ }
+ }
+
+ // Store a channel in Json
+ void StoreChannelInJson(const grpc::channelz::v1::Channel& channel) {
+ TString id = grpc::to_string(GetChannelID(channel));
+ TString type = "Channel";
+ TString description;
+ ::google::protobuf::TextFormat::PrintToString(channel.data(), &description);
+ grpc_core::Json description_json = grpc_core::Json(description);
+ StoreEntityInJson(id, type, description_json);
+ }
+
+ // Store a subchannel in Json
+ void StoreSubchannelInJson(const grpc::channelz::v1::Subchannel& subchannel) {
+ TString id = grpc::to_string(GetSubchannelID(subchannel));
+ TString type = "Subchannel";
+ TString description;
+ ::google::protobuf::TextFormat::PrintToString(subchannel.data(),
+ &description);
+ grpc_core::Json description_json = grpc_core::Json(description);
+ StoreEntityInJson(id, type, description_json);
+ }
+
+ // Store a server in Json
+ void StoreServerInJson(const grpc::channelz::v1::Server& server) {
+ TString id = grpc::to_string(GetServerID(server));
+ TString type = "Server";
+ TString description;
+ ::google::protobuf::TextFormat::PrintToString(server.data(), &description);
+ grpc_core::Json description_json = grpc_core::Json(description);
+ StoreEntityInJson(id, type, description_json);
+ }
+
+ // Store a socket in Json
+ void StoreSocketInJson(const grpc::channelz::v1::Socket& socket) {
+ TString id = grpc::to_string(GetSocketID(socket));
+ TString type = "Socket";
+ TString description;
+ ::google::protobuf::TextFormat::PrintToString(socket.data(), &description);
+ grpc_core::Json description_json = grpc_core::Json(description);
+ StoreEntityInJson(id, type, description_json);
+ }
+
+ // Store an entity in Json
+ void StoreEntityInJson(TString& id, TString& type,
+ const grpc_core::Json& description) {
+ TString start, finish;
+ gpr_timespec ago = gpr_time_sub(
+ now_,
+ gpr_time_from_seconds(FLAGS_sampling_interval_seconds, GPR_TIMESPAN));
+ std::stringstream ss;
+ const time_t time_now = now_.tv_sec;
+ ss << std::put_time(std::localtime(&time_now), "%F %T");
+ finish = ss.str(); // example: "2019-02-01 12:12:18"
+ ss.str("");
+ const time_t time_ago = ago.tv_sec;
+ ss << std::put_time(std::localtime(&time_ago), "%F %T");
+ start = ss.str();
+ grpc_core::Json obj =
+ grpc_core::Json::Object{{"Task", y_absl::StrFormat("%s_ID%s", type, id)},
+ {"Start", start},
+ {"Finish", finish},
+ {"ID", id},
+ {"Type", type},
+ {"Description", description}};
+ json_.mutable_array()->push_back(obj);
+ }
+
+ // Dump data in json
+ TString DumpJson() { return json_.Dump(); }
+
+ // Check if one entity has been recorded
+ bool CheckID(int64_t id) {
+ if (id_set_.count(id) == 0) {
+ id_set_.insert(id);
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ // Record current time
+ void RecordNow() { now_ = gpr_now(GPR_CLOCK_REALTIME); }
+
+ private:
+ std::unique_ptr<grpc::channelz::v1::Channelz::Stub> channelz_stub_;
+ std::vector<grpc::channelz::v1::Channel> top_channels_;
+ std::vector<grpc::channelz::v1::Server> all_servers_;
+ std::vector<grpc::channelz::v1::Channel> all_channels_;
+ std::vector<grpc::channelz::v1::Subchannel> all_subchannels_;
+ std::vector<grpc::channelz::v1::Socket> all_sockets_;
+ std::unordered_set<int64_t> id_set_;
+ grpc_core::Json json_;
+ int64_t rpc_timeout_seconds_;
+ gpr_timespec now_;
+};
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ grpc::testing::InitTest(&argc, &argv, true);
+ std::ofstream output_file(FLAGS_output_json);
+ for (int i = 0; i < FLAGS_sampling_times; ++i) {
+ ChannelzSampler channelz_sampler;
+ channelz_sampler.Setup(FLAGS_custom_credentials_type, FLAGS_server_address);
+ std::cout << "Wait for sampling interval "
+ << FLAGS_sampling_interval_seconds << "s..." << std::endl;
+ const gpr_timespec kDelay = gpr_time_add(
+ gpr_now(GPR_CLOCK_MONOTONIC),
+ gpr_time_from_seconds(FLAGS_sampling_interval_seconds, GPR_TIMESPAN));
+ gpr_sleep_until(kDelay);
+ std::cout << "##### " << i << "th sampling #####" << std::endl;
+ channelz_sampler.RecordNow();
+ channelz_sampler.GetServersRPC();
+ channelz_sampler.GetSocketsOfServers();
+ channelz_sampler.GetTopChannelsRPC();
+ channelz_sampler.TraverseTopChannels();
+ channelz_sampler.DumpStdout();
+ if (!FLAGS_output_json.empty()) {
+ output_file << channelz_sampler.DumpJson() << "\n" << std::flush;
+ }
+ }
+ output_file.close();
+ return 0;
+}
diff --git a/contrib/libs/grpc/test/cpp/util/channelz_sampler_test.cc b/contrib/libs/grpc/test/cpp/util/channelz_sampler_test.cc
new file mode 100644
index 0000000000..d81dbb0d05
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/channelz_sampler_test.cc
@@ -0,0 +1,176 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <util/generic/string.h>
+#include <thread>
+
+#include "grpc/grpc.h"
+#include "grpc/support/alloc.h"
+#include "grpc/support/port_platform.h"
+#include "grpcpp/channel.h"
+#include "grpcpp/client_context.h"
+#include "grpcpp/create_channel.h"
+#include "grpcpp/ext/channelz_service_plugin.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/server.h"
+#include "grpcpp/server_builder.h"
+#include "grpcpp/server_context.h"
+#include "gtest/gtest.h"
+#include "src/core/lib/gpr/env.h"
+#include "src/cpp/server/channelz/channelz_service.h"
+#include "src/proto/grpc/testing/test.grpc.pb.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/subprocess.h"
+#include "test/cpp/util/test_credentials_provider.h"
+
+static TString g_root;
+
+namespace {
+using grpc::ClientContext;
+using grpc::Server;
+using grpc::ServerBuilder;
+using grpc::ServerContext;
+using grpc::Status;
+} // namespace
+
+// Test variables
+TString server_address("0.0.0.0:10000");
+TString custom_credentials_type("INSECURE_CREDENTIALS");
+TString sampling_times = "2";
+TString sampling_interval_seconds = "3";
+TString output_json("output.json");
+
+// Create an echo server
+class EchoServerImpl final : public grpc::testing::TestService::Service {
+ Status EmptyCall(::grpc::ServerContext* context,
+ const grpc::testing::Empty* request,
+ grpc::testing::Empty* response) {
+ return Status::OK;
+ }
+};
+
+// Run client in a thread
+void RunClient(const TString& client_id, gpr_event* done_ev) {
+ grpc::ChannelArguments channel_args;
+ std::shared_ptr<grpc::ChannelCredentials> channel_creds =
+ grpc::testing::GetCredentialsProvider()->GetChannelCredentials(
+ custom_credentials_type, &channel_args);
+ std::unique_ptr<grpc::testing::TestService::Stub> stub =
+ grpc::testing::TestService::NewStub(
+ grpc::CreateChannel(server_address, channel_creds));
+ gpr_log(GPR_INFO, "Client %s is echoing!", client_id.c_str());
+ while (true) {
+ if (gpr_event_wait(done_ev, grpc_timeout_seconds_to_deadline(1)) !=
+ nullptr) {
+ return;
+ }
+ grpc::testing::Empty request;
+ grpc::testing::Empty response;
+ ClientContext context;
+ Status status = stub->EmptyCall(&context, request, &response);
+ if (!status.ok()) {
+ gpr_log(GPR_ERROR, "Client echo failed.");
+ GPR_ASSERT(0);
+ }
+ }
+}
+
+// Create the channelz to test the connection to the server
+bool WaitForConnection(int wait_server_seconds) {
+ grpc::ChannelArguments channel_args;
+ std::shared_ptr<grpc::ChannelCredentials> channel_creds =
+ grpc::testing::GetCredentialsProvider()->GetChannelCredentials(
+ custom_credentials_type, &channel_args);
+ auto channel = grpc::CreateChannel(server_address, channel_creds);
+ return channel->WaitForConnected(
+ grpc_timeout_seconds_to_deadline(wait_server_seconds));
+}
+
+// Test the channelz sampler
+TEST(ChannelzSamplerTest, SimpleTest) {
+ // start server
+ ::grpc::channelz::experimental::InitChannelzService();
+ EchoServerImpl service;
+ grpc::ServerBuilder builder;
+ auto server_creds =
+ grpc::testing::GetCredentialsProvider()->GetServerCredentials(
+ custom_credentials_type);
+ builder.AddListeningPort(server_address, server_creds);
+ builder.RegisterService(&service);
+ std::unique_ptr<Server> server(builder.BuildAndStart());
+ gpr_log(GPR_INFO, "Server listening on %s", server_address.c_str());
+ const int kWaitForServerSeconds = 10;
+ ASSERT_TRUE(WaitForConnection(kWaitForServerSeconds));
+ // client threads
+ gpr_event done_ev1, done_ev2;
+ gpr_event_init(&done_ev1);
+ gpr_event_init(&done_ev2);
+ std::thread client_thread_1(RunClient, "1", &done_ev1);
+ std::thread client_thread_2(RunClient, "2", &done_ev2);
+ // Run the channelz sampler
+ grpc::SubProcess* test_driver = new grpc::SubProcess(
+ {g_root + "/channelz_sampler", "--server_address=" + server_address,
+ "--custom_credentials_type=" + custom_credentials_type,
+ "--sampling_times=" + sampling_times,
+ "--sampling_interval_seconds=" + sampling_interval_seconds,
+ "--output_json=" + output_json});
+ int status = test_driver->Join();
+ if (WIFEXITED(status)) {
+ if (WEXITSTATUS(status)) {
+ gpr_log(GPR_ERROR,
+ "Channelz sampler test test-runner exited with code %d",
+ WEXITSTATUS(status));
+ GPR_ASSERT(0); // log the line number of the assertion failure
+ }
+ } else if (WIFSIGNALED(status)) {
+ gpr_log(GPR_ERROR, "Channelz sampler test test-runner ended from signal %d",
+ WTERMSIG(status));
+ GPR_ASSERT(0);
+ } else {
+ gpr_log(GPR_ERROR,
+ "Channelz sampler test test-runner ended with unknown status %d",
+ status);
+ GPR_ASSERT(0);
+ }
+ delete test_driver;
+ gpr_event_set(&done_ev1, (void*)1);
+ gpr_event_set(&done_ev2, (void*)1);
+ client_thread_1.join();
+ client_thread_2.join();
+}
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ TString me = argv[0];
+ auto lslash = me.rfind('/');
+ if (lslash != TString::npos) {
+ g_root = me.substr(0, lslash);
+ } else {
+ g_root = ".";
+ }
+ int ret = RUN_ALL_TESTS();
+ return ret;
+}
diff --git a/contrib/libs/grpc/test/cpp/util/cli_call.cc b/contrib/libs/grpc/test/cpp/util/cli_call.cc
new file mode 100644
index 0000000000..5b3631667f
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/cli_call.cc
@@ -0,0 +1,229 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/cli_call.h"
+
+#include <grpc/grpc.h>
+#include <grpc/slice.h>
+#include <grpc/support/log.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/support/byte_buffer.h>
+
+#include <cmath>
+#include <iostream>
+#include <utility>
+
+namespace grpc {
+namespace testing {
+namespace {
+void* tag(int i) { return (void*)static_cast<intptr_t>(i); }
+} // namespace
+
+Status CliCall::Call(const std::shared_ptr<grpc::Channel>& channel,
+ const TString& method, const TString& request,
+ TString* response,
+ const OutgoingMetadataContainer& metadata,
+ IncomingMetadataContainer* server_initial_metadata,
+ IncomingMetadataContainer* server_trailing_metadata) {
+ CliCall call(channel, method, metadata);
+ call.Write(request);
+ call.WritesDone();
+ if (!call.Read(response, server_initial_metadata)) {
+ fprintf(stderr, "Failed to read response.\n");
+ }
+ return call.Finish(server_trailing_metadata);
+}
+
+CliCall::CliCall(const std::shared_ptr<grpc::Channel>& channel,
+ const TString& method,
+ const OutgoingMetadataContainer& metadata, CliArgs args)
+ : stub_(new grpc::GenericStub(channel)) {
+ gpr_mu_init(&write_mu_);
+ gpr_cv_init(&write_cv_);
+ if (!metadata.empty()) {
+ for (OutgoingMetadataContainer::const_iterator iter = metadata.begin();
+ iter != metadata.end(); ++iter) {
+ ctx_.AddMetadata(iter->first, iter->second);
+ }
+ }
+
+ // Set deadline if timeout > 0 (default value -1 if no timeout specified)
+ if (args.timeout > 0) {
+ int64_t timeout_in_ns = ceil(args.timeout * 1e9);
+
+ // Convert timeout (in nanoseconds) to a deadline
+ auto deadline =
+ gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
+ gpr_time_from_nanos(timeout_in_ns, GPR_TIMESPAN));
+ ctx_.set_deadline(deadline);
+ } else if (args.timeout != -1) {
+ fprintf(
+ stderr,
+ "WARNING: Non-positive timeout value, skipping setting deadline.\n");
+ }
+
+ call_ = stub_->PrepareCall(&ctx_, method, &cq_);
+ call_->StartCall(tag(1));
+ void* got_tag;
+ bool ok;
+ cq_.Next(&got_tag, &ok);
+ GPR_ASSERT(ok);
+}
+
+CliCall::~CliCall() {
+ gpr_cv_destroy(&write_cv_);
+ gpr_mu_destroy(&write_mu_);
+}
+
+void CliCall::Write(const TString& request) {
+ void* got_tag;
+ bool ok;
+
+ gpr_slice s = gpr_slice_from_copied_buffer(request.data(), request.size());
+ grpc::Slice req_slice(s, grpc::Slice::STEAL_REF);
+ grpc::ByteBuffer send_buffer(&req_slice, 1);
+ call_->Write(send_buffer, tag(2));
+ cq_.Next(&got_tag, &ok);
+ GPR_ASSERT(ok);
+}
+
+bool CliCall::Read(TString* response,
+ IncomingMetadataContainer* server_initial_metadata) {
+ void* got_tag;
+ bool ok;
+
+ grpc::ByteBuffer recv_buffer;
+ call_->Read(&recv_buffer, tag(3));
+
+ if (!cq_.Next(&got_tag, &ok) || !ok) {
+ return false;
+ }
+ std::vector<grpc::Slice> slices;
+ GPR_ASSERT(recv_buffer.Dump(&slices).ok());
+
+ response->clear();
+ for (size_t i = 0; i < slices.size(); i++) {
+ response->append(reinterpret_cast<const char*>(slices[i].begin()),
+ slices[i].size());
+ }
+ if (server_initial_metadata) {
+ *server_initial_metadata = ctx_.GetServerInitialMetadata();
+ }
+ return true;
+}
+
+void CliCall::WritesDone() {
+ void* got_tag;
+ bool ok;
+
+ call_->WritesDone(tag(4));
+ cq_.Next(&got_tag, &ok);
+ GPR_ASSERT(ok);
+}
+
+void CliCall::WriteAndWait(const TString& request) {
+ grpc::Slice req_slice(request);
+ grpc::ByteBuffer send_buffer(&req_slice, 1);
+
+ gpr_mu_lock(&write_mu_);
+ call_->Write(send_buffer, tag(2));
+ write_done_ = false;
+ while (!write_done_) {
+ gpr_cv_wait(&write_cv_, &write_mu_, gpr_inf_future(GPR_CLOCK_MONOTONIC));
+ }
+ gpr_mu_unlock(&write_mu_);
+}
+
+void CliCall::WritesDoneAndWait() {
+ gpr_mu_lock(&write_mu_);
+ call_->WritesDone(tag(4));
+ write_done_ = false;
+ while (!write_done_) {
+ gpr_cv_wait(&write_cv_, &write_mu_, gpr_inf_future(GPR_CLOCK_MONOTONIC));
+ }
+ gpr_mu_unlock(&write_mu_);
+}
+
+bool CliCall::ReadAndMaybeNotifyWrite(
+ TString* response, IncomingMetadataContainer* server_initial_metadata) {
+ void* got_tag;
+ bool ok;
+ grpc::ByteBuffer recv_buffer;
+
+ call_->Read(&recv_buffer, tag(3));
+ bool cq_result = cq_.Next(&got_tag, &ok);
+
+ while (got_tag != tag(3)) {
+ gpr_mu_lock(&write_mu_);
+ write_done_ = true;
+ gpr_cv_signal(&write_cv_);
+ gpr_mu_unlock(&write_mu_);
+
+ cq_result = cq_.Next(&got_tag, &ok);
+ if (got_tag == tag(2)) {
+ GPR_ASSERT(ok);
+ }
+ }
+
+ if (!cq_result || !ok) {
+ // If the RPC is ended on the server side, we should still wait for the
+ // pending write on the client side to be done.
+ if (!ok) {
+ gpr_mu_lock(&write_mu_);
+ if (!write_done_) {
+ cq_.Next(&got_tag, &ok);
+ GPR_ASSERT(got_tag != tag(2));
+ write_done_ = true;
+ gpr_cv_signal(&write_cv_);
+ }
+ gpr_mu_unlock(&write_mu_);
+ }
+ return false;
+ }
+
+ std::vector<grpc::Slice> slices;
+ GPR_ASSERT(recv_buffer.Dump(&slices).ok());
+ response->clear();
+ for (size_t i = 0; i < slices.size(); i++) {
+ response->append(reinterpret_cast<const char*>(slices[i].begin()),
+ slices[i].size());
+ }
+ if (server_initial_metadata) {
+ *server_initial_metadata = ctx_.GetServerInitialMetadata();
+ }
+ return true;
+}
+
+Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) {
+ void* got_tag;
+ bool ok;
+ grpc::Status status;
+
+ call_->Finish(&status, tag(5));
+ cq_.Next(&got_tag, &ok);
+ GPR_ASSERT(ok);
+ if (server_trailing_metadata) {
+ *server_trailing_metadata = ctx_.GetServerTrailingMetadata();
+ }
+
+ return status;
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/cli_call.h b/contrib/libs/grpc/test/cpp/util/cli_call.h
new file mode 100644
index 0000000000..79d00d99f4
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/cli_call.h
@@ -0,0 +1,109 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_CLI_CALL_H
+#define GRPC_TEST_CPP_UTIL_CLI_CALL_H
+
+#include <grpcpp/channel.h>
+#include <grpcpp/completion_queue.h>
+#include <grpcpp/generic/generic_stub.h>
+#include <grpcpp/support/status.h>
+#include <grpcpp/support/string_ref.h>
+
+#include <map>
+
+namespace grpc {
+
+class ClientContext;
+
+// Optional arguments for constructing a CliCall.
+struct CliArgs {
+ // RPC timeout in seconds; negative presumably means "no deadline" — see
+ // the CliCall constructor in cli_call.cc for the authoritative handling.
+ double timeout = -1;
+};
+
+namespace testing {
+
+// CliCall handles the sending and receiving of generic messages given the name
+// of the remote method. This class is only used by GrpcTool. Its thread-safe
+// and thread-unsafe methods should not be used together.
+class CliCall final {
+ public:
+ typedef std::multimap<TString, TString> OutgoingMetadataContainer;
+ typedef std::multimap<grpc::string_ref, grpc::string_ref>
+ IncomingMetadataContainer;
+
+ CliCall(const std::shared_ptr<grpc::Channel>& channel,
+ const TString& method, const OutgoingMetadataContainer& metadata,
+ CliArgs args);
+ CliCall(const std::shared_ptr<grpc::Channel>& channel,
+ const TString& method, const OutgoingMetadataContainer& metadata)
+ : CliCall(channel, method, metadata, CliArgs{}) {}
+
+ ~CliCall();
+
+ // Perform a unary generic RPC.
+ static Status Call(const std::shared_ptr<grpc::Channel>& channel,
+ const TString& method, const TString& request,
+ TString* response,
+ const OutgoingMetadataContainer& metadata,
+ IncomingMetadataContainer* server_initial_metadata,
+ IncomingMetadataContainer* server_trailing_metadata);
+
+ // Send a generic request message in a synchronous manner. NOT thread-safe.
+ void Write(const TString& request);
+
+ // Signal that no more request messages will be sent, in a synchronous
+ // manner. NOT thread-safe.
+ void WritesDone();
+
+ // Receive a generic response message in a synchronous manner. NOT
+ // thread-safe.
+ bool Read(TString* response,
+ IncomingMetadataContainer* server_initial_metadata);
+
+ // Thread-safe write. Must be used with ReadAndMaybeNotifyWrite. Send out a
+ // generic request message and wait for ReadAndMaybeNotifyWrite to finish it.
+ void WriteAndWait(const TString& request);
+
+ // Thread-safe WritesDone. Must be used with ReadAndMaybeNotifyWrite. Send out
+ // WritesDone for generic request messages and wait for
+ // ReadAndMaybeNotifyWrite to finish it.
+ void WritesDoneAndWait();
+
+ // Thread-safe Read. Blockingly receive a generic response message. Notify
+ // writes if they are finished when this read is waiting for a response.
+ bool ReadAndMaybeNotifyWrite(
+ TString* response,
+ IncomingMetadataContainer* server_initial_metadata);
+
+ // Finish the RPC.
+ Status Finish(IncomingMetadataContainer* server_trailing_metadata);
+
+ TString peer() const { return ctx_.peer(); }
+
+ private:
+ std::unique_ptr<grpc::GenericStub> stub_;
+ grpc::ClientContext ctx_;
+ std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
+ grpc::CompletionQueue cq_;
+ gpr_mu write_mu_;
+ gpr_cv write_cv_; // Protected by write_mu_.
+ bool write_done_; // Protected by write_mu_.
+};
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_CLI_CALL_H
diff --git a/contrib/libs/grpc/test/cpp/util/cli_call_test.cc b/contrib/libs/grpc/test/cpp/util/cli_call_test.cc
new file mode 100644
index 0000000000..4f0544b2e5
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/cli_call_test.cc
@@ -0,0 +1,128 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/cli_call.h"
+
+#include <grpc/grpc.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <gtest/gtest.h>
+
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+namespace grpc {
+namespace testing {
+
+// Echo service for the tests below: echoes the request message back, mirrors
+// every piece of client metadata as server initial metadata, and attaches one
+// fixed trailing metadata entry.
+class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {
+ public:
+ Status Echo(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ if (!context->client_metadata().empty()) {
+ // Reflect all client metadata back so the test can compare it.
+ for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
+ iter = context->client_metadata().begin();
+ iter != context->client_metadata().end(); ++iter) {
+ context->AddInitialMetadata(ToString(iter->first),
+ ToString(iter->second));
+ }
+ }
+ context->AddTrailingMetadata("trailing_key", "trailing_value");
+ response->set_message(request->message());
+ return Status::OK;
+ }
+};
+
+// Fixture that starts an in-process Echo server on an unused port; the
+// channel and stub are created on demand via ResetStub().
+class CliCallTest : public ::testing::Test {
+ protected:
+ CliCallTest() {}
+
+ void SetUp() override {
+ int port = grpc_pick_unused_port_or_die();
+ server_address_ << "localhost:" << port;
+ // Setup server
+ ServerBuilder builder;
+ builder.AddListeningPort(server_address_.str(),
+ InsecureServerCredentials());
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ }
+
+ void TearDown() override { server_->Shutdown(); }
+
+ // Create a fresh insecure channel and stub to the test server.
+ void ResetStub() {
+ channel_ = grpc::CreateChannel(server_address_.str(),
+ InsecureChannelCredentials());
+ stub_ = grpc::testing::EchoTestService::NewStub(channel_);
+ }
+
+ std::shared_ptr<Channel> channel_;
+ std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ std::unique_ptr<Server> server_;
+ std::ostringstream server_address_;
+ TestServiceImpl service_;
+};
+
+// Send a rpc with a normal stub and then a CliCall. Verify they match.
+TEST_F(CliCallTest, SimpleRpc) {
+ ResetStub();
+ // Normal stub.
+ EchoRequest request;
+ EchoResponse response;
+ request.set_message("Hello");
+
+ ClientContext context;
+ context.AddMetadata("key1", "val1");
+ Status s = stub_->Echo(&context, request, &response);
+ EXPECT_EQ(response.message(), request.message());
+ EXPECT_TRUE(s.ok());
+
+ // Now issue the same RPC generically through CliCall, feeding it the
+ // serialized request bytes, and compare against the stub's results.
+ const TString kMethod("/grpc.testing.EchoTestService/Echo");
+ TString request_bin, response_bin, expected_response_bin;
+ EXPECT_TRUE(request.SerializeToString(&request_bin));
+ EXPECT_TRUE(response.SerializeToString(&expected_response_bin));
+ std::multimap<TString, TString> client_metadata;
+ std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
+ server_trailing_metadata;
+ client_metadata.insert(std::pair<TString, TString>("key1", "val1"));
+ Status s2 = CliCall::Call(channel_, kMethod, request_bin, &response_bin,
+ client_metadata, &server_initial_metadata,
+ &server_trailing_metadata);
+ EXPECT_TRUE(s2.ok());
+
+ EXPECT_EQ(expected_response_bin, response_bin);
+ EXPECT_EQ(context.GetServerInitialMetadata(), server_initial_metadata);
+ EXPECT_EQ(context.GetServerTrailingMetadata(), server_trailing_metadata);
+}
+
+} // namespace testing
+} // namespace grpc
+
+// Test entry point: set up the gRPC test environment, then run all tests.
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/util/cli_credentials.cc b/contrib/libs/grpc/test/cpp/util/cli_credentials.cc
new file mode 100644
index 0000000000..efd548eb9b
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/cli_credentials.cc
@@ -0,0 +1,245 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/cli_credentials.h"
+
+#include <gflags/gflags.h>
+#include <grpc/slice.h>
+#include <grpc/support/log.h>
+#include <grpcpp/impl/codegen/slice.h>
+
+#include "src/core/lib/iomgr/load_file.h"
+
+DEFINE_bool(
+ enable_ssl, false,
+ "Whether to use ssl/tls. Deprecated. Use --channel_creds_type=ssl.");
+DEFINE_bool(use_auth, false,
+ "Whether to create default google credentials. Deprecated. Use "
+ "--channel_creds_type=gdc.");
+DEFINE_string(
+ access_token, "",
+ "The access token that will be sent to the server to authenticate RPCs. "
+ "Deprecated. Use --call_creds=access_token=<token>.");
+DEFINE_string(
+ ssl_target, "",
+ "If not empty, treat the server host name as this for ssl/tls certificate "
+ "validation.");
+DEFINE_string(
+ ssl_client_cert, "",
+ "If not empty, load this PEM formatted client certificate file. Requires "
+ "use of --ssl_client_key.");
+DEFINE_string(
+ ssl_client_key, "",
+ "If not empty, load this PEM formatted private key. Requires use of "
+ "--ssl_client_cert");
+DEFINE_string(
+ local_connect_type, "local_tcp",
+ "The type of local connections for which local channel credentials will "
+ "be applied. Should be local_tcp or uds.");
+DEFINE_string(
+ channel_creds_type, "",
+ "The channel creds type: insecure, ssl, gdc (Google Default Credentials), "
+ "alts, or local.");
+DEFINE_string(
+ call_creds, "",
+ "Call credentials to use: none (default), or access_token=<token>. If "
+ "provided, the call creds are composited on top of channel creds.");
+
+namespace grpc {
+namespace testing {
+
+namespace {
+
+const char ACCESS_TOKEN_PREFIX[] = "access_token=";
+// Length of the prefix, excluding the trailing NUL.
+constexpr int ACCESS_TOKEN_PREFIX_LEN =
+ sizeof(ACCESS_TOKEN_PREFIX) / sizeof(*ACCESS_TOKEN_PREFIX) - 1;
+
+// Returns true iff |auth| begins with "access_token=" and has at least one
+// character after the prefix.
+bool IsAccessToken(const TString& auth) {
+ return auth.length() > ACCESS_TOKEN_PREFIX_LEN &&
+ auth.compare(0, ACCESS_TOKEN_PREFIX_LEN, ACCESS_TOKEN_PREFIX) == 0;
+}
+
+// Extracts the token that follows the "access_token=" prefix, or returns ""
+// if |auth| is not an access-token credential spec.
+TString AccessToken(const TString& auth) {
+ if (!IsAccessToken(auth)) {
+ return "";
+ }
+ // Return the part AFTER the prefix. The previous code built
+ // TString(auth.c_str(), ACCESS_TOKEN_PREFIX_LEN), which copies the FIRST
+ // ACCESS_TOKEN_PREFIX_LEN characters — i.e. it returned the literal
+ // "access_token=" prefix instead of the token. (Upstream gRPC uses
+ // grpc::string(auth, ACCESS_TOKEN_PREFIX_LEN), the substring-from-pos
+ // constructor; TString's (const char*, size_t) ctor means (data, length).)
+ return auth.substr(ACCESS_TOKEN_PREFIX_LEN);
+}
+
+} // namespace
+
+// Maps the deprecated --enable_ssl / --use_auth flags onto an equivalent
+// --channel_creds_type value; defaults to "insecure" when neither is set.
+TString CliCredentials::GetDefaultChannelCredsType() const {
+ // Compatibility logic for --enable_ssl.
+ if (FLAGS_enable_ssl) {
+ fprintf(stderr,
+ "warning: --enable_ssl is deprecated. Use "
+ "--channel_creds_type=ssl.\n");
+ return "ssl";
+ }
+ // Compatibility logic for --use_auth.
+ if (FLAGS_access_token.empty() && FLAGS_use_auth) {
+ fprintf(stderr,
+ "warning: --use_auth is deprecated. Use "
+ "--channel_creds_type=gdc.\n");
+ return "gdc";
+ }
+ return "insecure";
+}
+
+// Maps the deprecated --access_token flag onto an equivalent --call_creds
+// value; returns "none" when the legacy flag is unset.
+TString CliCredentials::GetDefaultCallCreds() const {
+ if (!FLAGS_access_token.empty()) {
+ fprintf(stderr,
+ "warning: --access_token is deprecated. Use "
+ "--call_creds=access_token=<token>.\n");
+ return TString("access_token=") + FLAGS_access_token;
+ }
+ return "none";
+}
+
+// Builds the base channel credentials selected by --channel_creds_type.
+// Returns a null shared_ptr (after printing an error to stderr) when the
+// flag value — or, for "local", the --local_connect_type — is unrecognized.
+std::shared_ptr<grpc::ChannelCredentials>
+CliCredentials::GetChannelCredentials() const {
+ if (FLAGS_channel_creds_type.compare("insecure") == 0) {
+ return grpc::InsecureChannelCredentials();
+ } else if (FLAGS_channel_creds_type.compare("ssl") == 0) {
+ grpc::SslCredentialsOptions ssl_creds_options;
+ // TODO(@Capstan): This won't affect Google Default Credentials using SSL.
+ if (!FLAGS_ssl_client_cert.empty()) {
+ // Load the client certificate chain from disk; on failure the slice
+ // stays empty and GRPC_LOG_IF_ERROR reports the problem.
+ grpc_slice cert_slice = grpc_empty_slice();
+ GRPC_LOG_IF_ERROR(
+ "load_file",
+ grpc_load_file(FLAGS_ssl_client_cert.c_str(), 1, &cert_slice));
+ ssl_creds_options.pem_cert_chain =
+ grpc::StringFromCopiedSlice(cert_slice);
+ grpc_slice_unref(cert_slice);
+ }
+ if (!FLAGS_ssl_client_key.empty()) {
+ grpc_slice key_slice = grpc_empty_slice();
+ GRPC_LOG_IF_ERROR(
+ "load_file",
+ grpc_load_file(FLAGS_ssl_client_key.c_str(), 1, &key_slice));
+ ssl_creds_options.pem_private_key =
+ grpc::StringFromCopiedSlice(key_slice);
+ grpc_slice_unref(key_slice);
+ }
+ return grpc::SslCredentials(ssl_creds_options);
+ } else if (FLAGS_channel_creds_type.compare("gdc") == 0) {
+ return grpc::GoogleDefaultCredentials();
+ } else if (FLAGS_channel_creds_type.compare("alts") == 0) {
+ return grpc::experimental::AltsCredentials(
+ grpc::experimental::AltsCredentialsOptions());
+ } else if (FLAGS_channel_creds_type.compare("local") == 0) {
+ if (FLAGS_local_connect_type.compare("local_tcp") == 0) {
+ return grpc::experimental::LocalCredentials(LOCAL_TCP);
+ } else if (FLAGS_local_connect_type.compare("uds") == 0) {
+ return grpc::experimental::LocalCredentials(UDS);
+ } else {
+ fprintf(stderr,
+ "--local_connect_type=%s invalid; must be local_tcp or uds.\n",
+ FLAGS_local_connect_type.c_str());
+ }
+ }
+ fprintf(stderr,
+ "--channel_creds_type=%s invalid; must be insecure, ssl, gdc, "
+ "alts, or local.\n",
+ FLAGS_channel_creds_type.c_str());
+ return std::shared_ptr<grpc::ChannelCredentials>();
+}
+
+// Builds per-call credentials from --call_creds. Returns a null shared_ptr
+// both for "none" (valid — nothing to composite) and for invalid values
+// (after printing an error to stderr).
+std::shared_ptr<grpc::CallCredentials> CliCredentials::GetCallCredentials()
+ const {
+ if (IsAccessToken(FLAGS_call_creds.c_str())) {
+ return grpc::AccessTokenCredentials(AccessToken(FLAGS_call_creds.c_str()));
+ }
+ if (FLAGS_call_creds.compare("none") == 0) {
+ // Nothing to do; creds, if any, are baked into the channel.
+ return std::shared_ptr<grpc::CallCredentials>();
+ }
+ fprintf(stderr,
+ "--call_creds=%s invalid; must be none "
+ "or access_token=<token>.\n",
+ FLAGS_call_creds.c_str());
+ return std::shared_ptr<grpc::CallCredentials>();
+}
+
+// Entry point: resolves the legacy flags into --channel_creds_type and
+// --call_creds, applies the insecure->ssl upgrade when an access token is
+// present, and returns the channel credentials composited with call
+// credentials when both exist. NOTE(review): mutates the flag globals as a
+// side effect so later reads see the resolved values.
+std::shared_ptr<grpc::ChannelCredentials> CliCredentials::GetCredentials()
+ const {
+ if (FLAGS_call_creds.empty()) {
+ FLAGS_call_creds = GetDefaultCallCreds();
+ } else if (!FLAGS_access_token.empty() && !IsAccessToken(FLAGS_call_creds.c_str())) {
+ fprintf(stderr,
+ "warning: ignoring --access_token because --call_creds "
+ "already set to %s.\n",
+ FLAGS_call_creds.c_str());
+ }
+ if (FLAGS_channel_creds_type.empty()) {
+ FLAGS_channel_creds_type = GetDefaultChannelCredsType();
+ } else if (FLAGS_enable_ssl && FLAGS_channel_creds_type.compare("ssl") != 0) {
+ fprintf(stderr,
+ "warning: ignoring --enable_ssl because "
+ "--channel_creds_type already set to %s.\n",
+ FLAGS_channel_creds_type.c_str());
+ } else if (FLAGS_use_auth && FLAGS_channel_creds_type.compare("gdc") != 0) {
+ fprintf(stderr,
+ "warning: ignoring --use_auth because "
+ "--channel_creds_type already set to %s.\n",
+ FLAGS_channel_creds_type.c_str());
+ }
+ // Legacy transport upgrade logic for insecure requests.
+ if (IsAccessToken(FLAGS_call_creds.c_str()) &&
+ FLAGS_channel_creds_type.compare("insecure") == 0) {
+ fprintf(stderr,
+ "warning: --channel_creds_type=insecure upgraded to ssl because "
+ "an access token was provided.\n");
+ FLAGS_channel_creds_type = "ssl";
+ }
+ std::shared_ptr<grpc::ChannelCredentials> channel_creds =
+ GetChannelCredentials();
+ // Composite any call-type credentials on top of the base channel.
+ std::shared_ptr<grpc::CallCredentials> call_creds = GetCallCredentials();
+ return (channel_creds == nullptr || call_creds == nullptr)
+ ? channel_creds
+ : grpc::CompositeChannelCredentials(channel_creds, call_creds);
+}
+
+// Usage text describing the credential-related command-line flags.
+const TString CliCredentials::GetCredentialUsage() const {
+ return " --enable_ssl ; Set whether to use ssl "
+ "(deprecated)\n"
+ " --use_auth ; Set whether to create default google"
+ " credentials\n"
+ " ; (deprecated)\n"
+ " --access_token ; Set the access token in metadata,"
+ " overrides --use_auth\n"
+ " ; (deprecated)\n"
+ " --ssl_target ; Set server host for ssl validation\n"
+ " --ssl_client_cert ; Client cert for ssl\n"
+ " --ssl_client_key ; Client private key for ssl\n"
+ " --local_connect_type ; Set to local_tcp or uds\n"
+ " --channel_creds_type ; Set to insecure, ssl, gdc, alts, or "
+ "local\n"
+ " --call_creds ; Set to none, or"
+ " access_token=<token>\n";
+}
+
+// Returns the --ssl_target override only when the channel uses SSL-based
+// credentials (ssl or gdc); empty otherwise.
+const TString CliCredentials::GetSslTargetNameOverride() const {
+ bool use_ssl = FLAGS_channel_creds_type.compare("ssl") == 0 ||
+ FLAGS_channel_creds_type.compare("gdc") == 0;
+ return use_ssl ? FLAGS_ssl_target : "";
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/cli_credentials.h b/contrib/libs/grpc/test/cpp/util/cli_credentials.h
new file mode 100644
index 0000000000..3e695692fa
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/cli_credentials.h
@@ -0,0 +1,55 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_CLI_CREDENTIALS_H
+#define GRPC_TEST_CPP_UTIL_CLI_CREDENTIALS_H
+
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/support/config.h>
+
+namespace grpc {
+namespace testing {
+
+// Translates credential-related command-line flags into gRPC channel/call
+// credentials for the CLI tools. Subclass to support additional credential
+// types.
+class CliCredentials {
+ public:
+ virtual ~CliCredentials() {}
+ // Returns the fully-composited channel credentials derived from the flags
+ // (see cli_credentials.cc); may be null if the flags are invalid.
+ std::shared_ptr<grpc::ChannelCredentials> GetCredentials() const;
+ // Usage text for the credential-related flags.
+ virtual const TString GetCredentialUsage() const;
+ // Host name to use for ssl validation, or "" when not using ssl/gdc.
+ virtual const TString GetSslTargetNameOverride() const;
+
+ protected:
+ // Returns the appropriate channel_creds_type value for the set of legacy
+ // flag arguments.
+ virtual TString GetDefaultChannelCredsType() const;
+ // Returns the appropriate call_creds value for the set of legacy flag
+ // arguments.
+ virtual TString GetDefaultCallCreds() const;
+ // Returns the base transport channel credentials. Child classes can override
+ // to support additional channel_creds_types unknown to this base class.
+ virtual std::shared_ptr<grpc::ChannelCredentials> GetChannelCredentials()
+ const;
+ // Returns call credentials to composite onto the base transport channel
+ // credentials. Child classes can override to support additional
+ // authentication flags unknown to this base class.
+ virtual std::shared_ptr<grpc::CallCredentials> GetCallCredentials() const;
+};
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_CLI_CREDENTIALS_H
diff --git a/contrib/libs/grpc/test/cpp/util/config_grpc_cli.h b/contrib/libs/grpc/test/cpp/util/config_grpc_cli.h
new file mode 100644
index 0000000000..358884196d
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/config_grpc_cli.h
@@ -0,0 +1,70 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_CONFIG_GRPC_CLI_H
+#define GRPC_TEST_CPP_UTIL_CONFIG_GRPC_CLI_H
+
+#include <grpcpp/impl/codegen/config_protobuf.h>
+
+#ifndef GRPC_CUSTOM_DYNAMICMESSAGEFACTORY
+#include <google/protobuf/dynamic_message.h>
+#define GRPC_CUSTOM_DYNAMICMESSAGEFACTORY \
+ ::google::protobuf::DynamicMessageFactory
+#endif
+
+#ifndef GRPC_CUSTOM_DESCRIPTORPOOLDATABASE
+#include <google/protobuf/descriptor.h>
+#define GRPC_CUSTOM_DESCRIPTORPOOLDATABASE \
+ ::google::protobuf::DescriptorPoolDatabase
+#define GRPC_CUSTOM_MERGEDDESCRIPTORDATABASE \
+ ::google::protobuf::MergedDescriptorDatabase
+#endif
+
+#ifndef GRPC_CUSTOM_TEXTFORMAT
+#include <google/protobuf/text_format.h>
+#define GRPC_CUSTOM_TEXTFORMAT ::google::protobuf::TextFormat
+#endif
+
+#ifndef GRPC_CUSTOM_DISKSOURCETREE
+#include <google/protobuf/compiler/importer.h>
+#define GRPC_CUSTOM_DISKSOURCETREE ::google::protobuf::compiler::DiskSourceTree
+#define GRPC_CUSTOM_IMPORTER ::google::protobuf::compiler::Importer
+#define GRPC_CUSTOM_MULTIFILEERRORCOLLECTOR \
+ ::google::protobuf::compiler::MultiFileErrorCollector
+#endif
+
+namespace grpc {
+namespace protobuf {
+
+typedef GRPC_CUSTOM_DYNAMICMESSAGEFACTORY DynamicMessageFactory;
+
+typedef GRPC_CUSTOM_DESCRIPTORPOOLDATABASE DescriptorPoolDatabase;
+typedef GRPC_CUSTOM_MERGEDDESCRIPTORDATABASE MergedDescriptorDatabase;
+
+typedef GRPC_CUSTOM_TEXTFORMAT TextFormat;
+
+namespace compiler {
+typedef GRPC_CUSTOM_DISKSOURCETREE DiskSourceTree;
+typedef GRPC_CUSTOM_IMPORTER Importer;
+typedef GRPC_CUSTOM_MULTIFILEERRORCOLLECTOR MultiFileErrorCollector;
+} // namespace compiler
+
+} // namespace protobuf
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_CONFIG_GRPC_CLI_H
diff --git a/contrib/libs/grpc/test/cpp/util/create_test_channel.cc b/contrib/libs/grpc/test/cpp/util/create_test_channel.cc
new file mode 100644
index 0000000000..86d8e22af1
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/create_test_channel.cc
@@ -0,0 +1,252 @@
+/*
+ *
+ * Copyright 2015-2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/create_test_channel.h"
+
+#include <gflags/gflags.h>
+
+#include <grpc/support/log.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/security/credentials.h>
+
+#include "test/cpp/util/test_credentials_provider.h"
+
+DEFINE_string(
+ grpc_test_use_grpclb_with_child_policy, "",
+ "If non-empty, set a static service config on channels created by "
+ "grpc::CreateTestChannel, that configures the grpclb LB policy "
+ "with a child policy being the value of this flag (e.g. round_robin "
+ "or pick_first).");
+
+namespace grpc {
+
+namespace {
+
+const char kProdTlsCredentialsType[] = "prod_ssl";
+
+// CredentialTypeProvider yielding plain SSL channel credentials with default
+// (non-test) root certificates; provides no server-side credentials.
+class SslCredentialProvider : public testing::CredentialTypeProvider {
+ public:
+ std::shared_ptr<ChannelCredentials> GetChannelCredentials(
+ grpc::ChannelArguments* /*args*/) override {
+ return grpc::SslCredentials(SslCredentialsOptions());
+ }
+ std::shared_ptr<ServerCredentials> GetServerCredentials() override {
+ return nullptr;
+ }
+};
+
+gpr_once g_once_init_add_prod_ssl_provider = GPR_ONCE_INIT;
+// Register ssl with non-test roots type to the credentials provider.
+// Invoked at most once via g_once_init_add_prod_ssl_provider.
+void AddProdSslType() {
+ testing::GetCredentialsProvider()->AddSecureType(
+ kProdTlsCredentialsType, std::unique_ptr<testing::CredentialTypeProvider>(
+ new SslCredentialProvider));
+}
+
+// If --grpc_test_use_grpclb_with_child_policy is set, inject a static service
+// config into |args| selecting the grpclb LB policy with that child policy.
+void MaybeSetCustomChannelArgs(grpc::ChannelArguments* args) {
+ if (FLAGS_grpc_test_use_grpclb_with_child_policy.size() > 0) {
+ args->SetString("grpc.service_config",
+ "{\"loadBalancingConfig\":[{\"grpclb\":{\"childPolicy\":[{"
+ "\"" +
+ FLAGS_grpc_test_use_grpclb_with_child_policy +
+ "\":{}}]}}]}");
+ }
+}
+
+} // namespace
+
+// When cred_type is 'ssl', if server is empty, override_hostname is used to
+// create channel. Otherwise, connect to server and override hostname if
+// override_hostname is provided.
+// When cred_type is not 'ssl', override_hostname is ignored.
+// Set use_prod_root to true to use the SSL root for connecting to google.
+// In this case, path to the roots pem file must be set via environment variable
+// GRPC_DEFAULT_SSL_ROOTS_FILE_PATH.
+// Otherwise, root for test SSL cert will be used.
+// creds will be used to create a channel when cred_type is 'ssl'.
+// Use examples:
+// CreateTestChannel(
+// "1.1.1.1:12345", "ssl", "override.hostname.com", false, creds);
+// CreateTestChannel("test.google.com:443", "ssl", "", true, creds);
+// same as above
+// CreateTestChannel("", "ssl", "test.google.com:443", true, creds);
+// Delegates to the interceptor-aware overload with no interceptors.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& cred_type,
+ const TString& override_hostname, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds,
+ const ChannelArguments& args) {
+ return CreateTestChannel(server, cred_type, override_hostname, use_prod_roots,
+ creds, args,
+ /*interceptor_creators=*/{});
+}
+
+// transport_security variant; delegates with no interceptors.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds,
+ const ChannelArguments& args) {
+ return CreateTestChannel(server, override_hostname, security_type,
+ use_prod_roots, creds, args,
+ /*interceptor_creators=*/{});
+}
+
+// Delegates with default ChannelArguments.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds) {
+ return CreateTestChannel(server, override_hostname, security_type,
+ use_prod_roots, creds, ChannelArguments());
+}
+
+// Delegates with no call credentials.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots) {
+ return CreateTestChannel(server, override_hostname, security_type,
+ use_prod_roots, std::shared_ptr<CallCredentials>());
+}
+
+// Shortcut for end2end and interop tests.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, testing::transport_security security_type) {
+ return CreateTestChannel(server, "foo.test.google.fr", security_type, false);
+}
+
+// Creates a channel using credentials obtained from the test credentials
+// provider for |credential_type|, optionally composited with per-call creds.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& credential_type,
+ const std::shared_ptr<CallCredentials>& creds) {
+ ChannelArguments channel_args;
+ MaybeSetCustomChannelArgs(&channel_args);
+ std::shared_ptr<ChannelCredentials> channel_creds =
+ testing::GetCredentialsProvider()->GetChannelCredentials(credential_type,
+ &channel_args);
+ GPR_ASSERT(channel_creds != nullptr);
+ if (creds.get()) {
+ channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds);
+ }
+ return ::grpc::CreateCustomChannel(server, channel_creds, channel_args);
+}
+
+// Main implementation. Builds channel credentials according to |cred_type|
+// ("" = insecure, "ssl" = TLS with test or production roots, otherwise a
+// provider-registered type) and creates the channel, with or without
+// interceptors.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& cred_type,
+ const TString& override_hostname, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args,
+ std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ interceptor_creators) {
+ ChannelArguments channel_args(args);
+ MaybeSetCustomChannelArgs(&channel_args);
+ std::shared_ptr<ChannelCredentials> channel_creds;
+ if (cred_type.empty()) {
+ if (interceptor_creators.empty()) {
+ return ::grpc::CreateCustomChannel(server, InsecureChannelCredentials(),
+ channel_args);
+ } else {
+ return experimental::CreateCustomChannelWithInterceptors(
+ server, InsecureChannelCredentials(), channel_args,
+ std::move(interceptor_creators));
+ }
+ } else if (cred_type == testing::kTlsCredentialsType) { // cred_type == "ssl"
+ if (use_prod_roots) {
+ // Lazily register the "prod_ssl" provider exactly once.
+ gpr_once_init(&g_once_init_add_prod_ssl_provider, &AddProdSslType);
+ channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials(
+ kProdTlsCredentialsType, &channel_args);
+ if (!server.empty() && !override_hostname.empty()) {
+ channel_args.SetSslTargetNameOverride(override_hostname);
+ }
+ } else {
+ // override_hostname is discarded as the provider handles it.
+ channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials(
+ testing::kTlsCredentialsType, &channel_args);
+ }
+ GPR_ASSERT(channel_creds != nullptr);
+
+ // With ssl, an empty server means "connect to the override hostname".
+ const TString& connect_to = server.empty() ? override_hostname : server;
+ if (creds.get()) {
+ channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds);
+ }
+ if (interceptor_creators.empty()) {
+ return ::grpc::CreateCustomChannel(connect_to, channel_creds,
+ channel_args);
+ } else {
+ return experimental::CreateCustomChannelWithInterceptors(
+ connect_to, channel_creds, channel_args,
+ std::move(interceptor_creators));
+ }
+ } else {
+ channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials(
+ cred_type, &channel_args);
+ GPR_ASSERT(channel_creds != nullptr);
+
+ if (interceptor_creators.empty()) {
+ return ::grpc::CreateCustomChannel(server, channel_creds, channel_args);
+ } else {
+ return experimental::CreateCustomChannelWithInterceptors(
+ server, channel_creds, channel_args, std::move(interceptor_creators));
+ }
+ }
+}
+
+// Maps the transport_security enum onto a credential-type string, then
+// delegates to the main implementation.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args,
+ std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ interceptor_creators) {
+ TString credential_type =
+ security_type == testing::ALTS
+ ? testing::kAltsCredentialsType
+ : (security_type == testing::TLS ? testing::kTlsCredentialsType
+ : testing::kInsecureCredentialsType);
+ return CreateTestChannel(server, credential_type, override_hostname,
+ use_prod_roots, creds, args,
+ std::move(interceptor_creators));
+}
+
+// Delegates with default ChannelArguments.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds,
+ std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ interceptor_creators) {
+ return CreateTestChannel(server, override_hostname, security_type,
+ use_prod_roots, creds, ChannelArguments(),
+ std::move(interceptor_creators));
+}
+
+// Interceptor-aware variant of the credential-type overload above.
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& credential_type,
+ const std::shared_ptr<CallCredentials>& creds,
+ std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ interceptor_creators) {
+ ChannelArguments channel_args;
+ MaybeSetCustomChannelArgs(&channel_args);
+ std::shared_ptr<ChannelCredentials> channel_creds =
+ testing::GetCredentialsProvider()->GetChannelCredentials(credential_type,
+ &channel_args);
+ GPR_ASSERT(channel_creds != nullptr);
+ if (creds.get()) {
+ channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds);
+ }
+ return experimental::CreateCustomChannelWithInterceptors(
+ server, channel_creds, channel_args, std::move(interceptor_creators));
+}
+
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/create_test_channel.h b/contrib/libs/grpc/test/cpp/util/create_test_channel.h
new file mode 100644
index 0000000000..ed4ce6c11b
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/create_test_channel.h
@@ -0,0 +1,99 @@
+/*
+ *
+ * Copyright 2015-2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H
+#define GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H
+
+#include <memory>
+
+#include <grpcpp/channel.h>
+#include <grpcpp/impl/codegen/client_interceptor.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/support/channel_arguments.h>
+
+namespace grpc {
+class Channel;
+
+namespace testing {
+
+typedef enum { INSECURE = 0, TLS, ALTS } transport_security;
+
+} // namespace testing
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, testing::transport_security security_type);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds,
+ const ChannelArguments& args);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& cred_type,
+ const TString& override_hostname, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds,
+ const ChannelArguments& args);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& credential_type,
+ const std::shared_ptr<CallCredentials>& creds);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds,
+ std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ interceptor_creators);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& override_hostname,
+ testing::transport_security security_type, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args,
+ std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ interceptor_creators);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& cred_type,
+ const TString& override_hostname, bool use_prod_roots,
+ const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args,
+ std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ interceptor_creators);
+
+std::shared_ptr<Channel> CreateTestChannel(
+ const TString& server, const TString& credential_type,
+ const std::shared_ptr<CallCredentials>& creds,
+ std::vector<
+ std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ interceptor_creators);
+
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H
diff --git a/contrib/libs/grpc/test/cpp/util/error_details_test.cc b/contrib/libs/grpc/test/cpp/util/error_details_test.cc
new file mode 100644
index 0000000000..630ab1d98f
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/error_details_test.cc
@@ -0,0 +1,125 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpcpp/support/error_details.h>
+#include <gtest/gtest.h>
+
+#include "src/proto/grpc/status/status.pb.h"
+#include "src/proto/grpc/testing/echo_messages.pb.h"
+#include "test/core/util/test_config.h"
+
+namespace grpc {
+namespace {
+
+TEST(ExtractTest, Success) {
+ google::rpc::Status expected;
+ expected.set_code(13); // INTERNAL
+ expected.set_message("I am an error message");
+ testing::EchoRequest expected_details;
+ expected_details.set_message(TString(100, '\0'));
+ expected.add_details()->PackFrom(expected_details);
+
+ google::rpc::Status to;
+ TString error_details = expected.SerializeAsString();
+ Status from(static_cast<StatusCode>(expected.code()), expected.message(),
+ error_details);
+ EXPECT_TRUE(ExtractErrorDetails(from, &to).ok());
+ EXPECT_EQ(expected.code(), to.code());
+ EXPECT_EQ(expected.message(), to.message());
+ EXPECT_EQ(1, to.details_size());
+ testing::EchoRequest details;
+ to.details(0).UnpackTo(&details);
+ EXPECT_EQ(expected_details.message(), details.message());
+}
+
+TEST(ExtractTest, NullInput) {
+ EXPECT_EQ(StatusCode::FAILED_PRECONDITION,
+ ExtractErrorDetails(Status(), nullptr).error_code());
+}
+
+TEST(ExtractTest, Unparsable) {
+ TString error_details("I am not a status object");
+ Status from(StatusCode::INTERNAL, "", error_details);
+ google::rpc::Status to;
+ EXPECT_EQ(StatusCode::INVALID_ARGUMENT,
+ ExtractErrorDetails(from, &to).error_code());
+}
+
+TEST(SetTest, Success) {
+ google::rpc::Status expected;
+ expected.set_code(13); // INTERNAL
+ expected.set_message("I am an error message");
+ testing::EchoRequest expected_details;
+ expected_details.set_message(TString(100, '\0'));
+ expected.add_details()->PackFrom(expected_details);
+
+ Status to;
+ Status s = SetErrorDetails(expected, &to);
+ EXPECT_TRUE(s.ok());
+ EXPECT_EQ(expected.code(), to.error_code());
+ EXPECT_EQ(expected.message(), to.error_message());
+ EXPECT_EQ(expected.SerializeAsString(), to.error_details());
+}
+
+TEST(SetTest, NullInput) {
+ EXPECT_EQ(StatusCode::FAILED_PRECONDITION,
+ SetErrorDetails(google::rpc::Status(), nullptr).error_code());
+}
+
+TEST(SetTest, OutOfScopeErrorCode) {
+ google::rpc::Status expected;
+ expected.set_code(17); // Out of scope (UNAUTHENTICATED is 16).
+ expected.set_message("I am an error message");
+ testing::EchoRequest expected_details;
+ expected_details.set_message(TString(100, '\0'));
+ expected.add_details()->PackFrom(expected_details);
+
+ Status to;
+ Status s = SetErrorDetails(expected, &to);
+ EXPECT_TRUE(s.ok());
+ EXPECT_EQ(StatusCode::UNKNOWN, to.error_code());
+ EXPECT_EQ(expected.message(), to.error_message());
+ EXPECT_EQ(expected.SerializeAsString(), to.error_details());
+}
+
+TEST(SetTest, ValidScopeErrorCode) {
+ for (int c = StatusCode::OK; c <= StatusCode::UNAUTHENTICATED; c++) {
+ google::rpc::Status expected;
+ expected.set_code(c);
+ expected.set_message("I am an error message");
+ testing::EchoRequest expected_details;
+ expected_details.set_message(TString(100, '\0'));
+ expected.add_details()->PackFrom(expected_details);
+
+ Status to;
+ Status s = SetErrorDetails(expected, &to);
+ EXPECT_TRUE(s.ok());
+ EXPECT_EQ(c, to.error_code());
+ EXPECT_EQ(expected.message(), to.error_message());
+ EXPECT_EQ(expected.SerializeAsString(), to.error_details());
+ }
+}
+
+} // namespace
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/util/grpc_cli.cc b/contrib/libs/grpc/test/cpp/util/grpc_cli.cc
new file mode 100644
index 0000000000..45c6b94f84
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/grpc_cli.cc
@@ -0,0 +1,90 @@
+/*
+
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+/*
+ A command line tool to talk to a grpc server.
+ Run `grpc_cli help` command to see its usage information.
+
+ Example of talking to grpc interop server:
+ grpc_cli call localhost:50051 UnaryCall "response_size:10" \
+ --protofiles=src/proto/grpc/testing/test.proto --enable_ssl=false
+
+ Options:
+ 1. --protofiles, use this flag to provide proto files if the server
+ does not have the reflection service.
+ 2. --proto_path, if your proto file is not under current working directory,
+ use this flag to provide a search root. It should work similar to the
+ counterpart in protoc. This option is valid only when protofiles is
+ provided.
+ 3. --metadata specifies metadata to be sent to the server, such as:
+ --metadata="MyHeaderKey1:Value1:MyHeaderKey2:Value2"
+ 4. --enable_ssl, whether to use tls.
+ 5. --use_auth, if set to true, attach a GoogleDefaultCredentials to the call
+ 6. --infile, input filename (defaults to stdin)
+ 7. --outfile, output filename (defaults to stdout)
+ 8. --binary_input, use the serialized request as input. The serialized
+ request can be generated by calling something like:
+ protoc --proto_path=src/proto/grpc/testing/ \
+ --encode=grpc.testing.SimpleRequest \
+ src/proto/grpc/testing/messages.proto \
+ < input.txt > input.bin
+ If this is used and no proto file is provided in the argument list, the
+ method string has to be exact in the form of /package.service/method.
+ 9. --binary_output, use binary format response as output, it can
+ be later decoded using protoc:
+ protoc --proto_path=src/proto/grpc/testing/ \
+ --decode=grpc.testing.SimpleResponse \
+ src/proto/grpc/testing/messages.proto \
+ < output.bin > output.txt
+ 10. --default_service_config, optional default service config to use
+ on the channel. Note that this may be ignored if the name resolver
+ returns a service config.
+ 11. --display_peer_address, on CallMethod commands, log the peer socket
+ address of the connection that each RPC is made on to stderr.
+*/
+
+#include <fstream>
+#include <functional>
+#include <iostream>
+
+#include <gflags/gflags.h>
+#include <grpcpp/support/config.h>
+#include "test/cpp/util/cli_credentials.h"
+#include "test/cpp/util/grpc_tool.h"
+#include "test/cpp/util/test_config.h"
+
+DEFINE_string(outfile, "", "Output file (default is stdout)");
+
+static bool SimplePrint(const TString& outfile, const TString& output) {
+ if (outfile.empty()) {
+ std::cout << output << std::flush;
+ } else {
+ std::ofstream output_file(outfile, std::ios::app | std::ios::binary);
+ output_file << output << std::flush;
+ output_file.close();
+ }
+ return true;
+}
+
+int main(int argc, char** argv) {
+ grpc::testing::InitTest(&argc, &argv, true);
+
+ return grpc::testing::GrpcToolMainLib(
+ argc, (const char**)argv, grpc::testing::CliCredentials(),
+ std::bind(SimplePrint, TString(FLAGS_outfile.c_str()), std::placeholders::_1));
+}
diff --git a/contrib/libs/grpc/test/cpp/util/grpc_tool.cc b/contrib/libs/grpc/test/cpp/util/grpc_tool.cc
new file mode 100644
index 0000000000..30f3024e25
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/grpc_tool.cc
@@ -0,0 +1,985 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/grpc_tool.h"
+
+#include <gflags/gflags.h>
+#include <grpc/grpc.h>
+#include <grpc/support/port_platform.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/support/string_ref.h>
+
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <util/generic/string.h>
+#include <thread>
+
+#include "test/cpp/util/cli_call.h"
+#include "test/cpp/util/proto_file_parser.h"
+#include "test/cpp/util/proto_reflection_descriptor_database.h"
+#include "test/cpp/util/service_describer.h"
+
+#if GPR_WINDOWS
+#include <io.h>
+#else
+#include <unistd.h>
+#endif
+
+namespace grpc {
+namespace testing {
+
+DEFINE_bool(l, false, "Use a long listing format");
+DEFINE_bool(remotedb, true, "Use server types to parse and format messages");
+DEFINE_string(metadata, "",
+ "Metadata to send to server, in the form of key1:val1:key2:val2");
+DEFINE_string(proto_path, ".", "Path to look for the proto file.");
+DEFINE_string(protofiles, "", "Name of the proto file.");
+DEFINE_bool(binary_input, false, "Input in binary format");
+DEFINE_bool(binary_output, false, "Output in binary format");
+DEFINE_string(
+ default_service_config, "",
+ "Default service config to use on the channel, if non-empty. Note "
+ "that this will be ignored if the name resolver returns a service "
+ "config.");
+DEFINE_bool(
+ display_peer_address, false,
+ "Log the peer socket address of the connection that each RPC is made "
+ "on to stderr.");
+DEFINE_bool(json_input, false, "Input in json format");
+DEFINE_bool(json_output, false, "Output in json format");
+DEFINE_string(infile, "", "Input file (default is stdin)");
+DEFINE_bool(batch, false,
+ "Input contains multiple requests. Please do not use this to send "
+ "more than a few RPCs. gRPC CLI has very different performance "
+ "characteristics compared with normal RPC calls which make it "
+ "unsuitable for loadtesting or significant production traffic.");
+DEFINE_double(timeout, -1,
+ "Specify timeout in seconds, used to set the deadline for all "
+ "RPCs. The default value of -1 means no deadline has been set.");
+
+namespace {
+
+class GrpcTool {
+ public:
+ explicit GrpcTool();
+ virtual ~GrpcTool() {}
+
+ bool Help(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+ bool CallMethod(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+ bool ListServices(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+ bool PrintType(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+ // TODO(zyc): implement the following methods
+ // bool ListServices(int argc, const char** argv, GrpcToolOutputCallback
+ // callback);
+ // bool PrintTypeId(int argc, const char** argv, GrpcToolOutputCallback
+ // callback);
+ bool ParseMessage(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+ bool ToText(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+ bool ToJson(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+ bool ToBinary(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+
+ void SetPrintCommandMode(int exit_status) {
+ print_command_usage_ = true;
+ usage_exit_status_ = exit_status;
+ }
+
+ private:
+ void CommandUsage(const TString& usage) const;
+ bool print_command_usage_;
+ int usage_exit_status_;
+ const TString cred_usage_;
+};
+
+template <typename T>
+std::function<bool(GrpcTool*, int, const char**, const CliCredentials&,
+ GrpcToolOutputCallback)>
+BindWith5Args(T&& func) {
+ return std::bind(std::forward<T>(func), std::placeholders::_1,
+ std::placeholders::_2, std::placeholders::_3,
+ std::placeholders::_4, std::placeholders::_5);
+}
+
+template <typename T>
+size_t ArraySize(T& a) {
+ return ((sizeof(a) / sizeof(*(a))) /
+ static_cast<size_t>(!(sizeof(a) % sizeof(*(a)))));
+}
+
+void ParseMetadataFlag(
+ std::multimap<TString, TString>* client_metadata) {
+ if (FLAGS_metadata.empty()) {
+ return;
+ }
+ std::vector<TString> fields;
+ const char delim = ':';
+ const char escape = '\\';
+ size_t cur = -1;
+ std::stringstream ss;
+ while (++cur < FLAGS_metadata.length()) {
+ switch (FLAGS_metadata.at(cur)) {
+ case escape:
+ if (cur < FLAGS_metadata.length() - 1) {
+ char c = FLAGS_metadata.at(++cur);
+ if (c == delim || c == escape) {
+ ss << c;
+ continue;
+ }
+ }
+ fprintf(stderr, "Failed to parse metadata flag.\n");
+ exit(1);
+ case delim:
+ fields.push_back(ss.str());
+ ss.str("");
+ ss.clear();
+ break;
+ default:
+ ss << FLAGS_metadata.at(cur);
+ }
+ }
+ fields.push_back(ss.str());
+ if (fields.size() % 2) {
+ fprintf(stderr, "Failed to parse metadata flag.\n");
+ exit(1);
+ }
+ for (size_t i = 0; i < fields.size(); i += 2) {
+ client_metadata->insert(
+ std::pair<TString, TString>(fields[i], fields[i + 1]));
+ }
+}
+
+template <typename T>
+void PrintMetadata(const T& m, const TString& message) {
+ if (m.empty()) {
+ return;
+ }
+ fprintf(stderr, "%s\n", message.c_str());
+ TString pair;
+ for (typename T::const_iterator iter = m.begin(); iter != m.end(); ++iter) {
+ pair.clear();
+ pair.append(iter->first.data(), iter->first.size());
+ pair.append(" : ");
+ pair.append(iter->second.data(), iter->second.size());
+ fprintf(stderr, "%s\n", pair.c_str());
+ }
+}
+
+void ReadResponse(CliCall* call, const TString& method_name,
+ GrpcToolOutputCallback callback, ProtoFileParser* parser,
+ gpr_mu* parser_mu, bool print_mode) {
+ TString serialized_response_proto;
+ std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata;
+
+ for (bool receive_initial_metadata = true; call->ReadAndMaybeNotifyWrite(
+ &serialized_response_proto,
+ receive_initial_metadata ? &server_initial_metadata : nullptr);
+ receive_initial_metadata = false) {
+ fprintf(stderr, "got response.\n");
+ if (!FLAGS_binary_output) {
+ gpr_mu_lock(parser_mu);
+ serialized_response_proto = parser->GetFormattedStringFromMethod(
+ method_name, serialized_response_proto, false /* is_request */,
+ FLAGS_json_output);
+ if (parser->HasError() && print_mode) {
+ fprintf(stderr, "Failed to parse response.\n");
+ }
+ gpr_mu_unlock(parser_mu);
+ }
+ if (receive_initial_metadata) {
+ PrintMetadata(server_initial_metadata,
+ "Received initial metadata from server:");
+ }
+ if (!callback(serialized_response_proto) && print_mode) {
+ fprintf(stderr, "Failed to output response.\n");
+ }
+ }
+}
+
+std::shared_ptr<grpc::Channel> CreateCliChannel(
+ const TString& server_address, const CliCredentials& cred) {
+ grpc::ChannelArguments args;
+ if (!cred.GetSslTargetNameOverride().empty()) {
+ args.SetSslTargetNameOverride(cred.GetSslTargetNameOverride());
+ }
+ if (!FLAGS_default_service_config.empty()) {
+ args.SetString(GRPC_ARG_SERVICE_CONFIG,
+ FLAGS_default_service_config.c_str());
+ }
+ return ::grpc::CreateCustomChannel(server_address, cred.GetCredentials(),
+ args);
+}
+
+struct Command {
+ const char* command;
+ std::function<bool(GrpcTool*, int, const char**, const CliCredentials&,
+ GrpcToolOutputCallback)>
+ function;
+ int min_args;
+ int max_args;
+};
+
+const Command ops[] = {
+ {"help", BindWith5Args(&GrpcTool::Help), 0, INT_MAX},
+ {"ls", BindWith5Args(&GrpcTool::ListServices), 1, 3},
+ {"list", BindWith5Args(&GrpcTool::ListServices), 1, 3},
+ {"call", BindWith5Args(&GrpcTool::CallMethod), 2, 3},
+ {"type", BindWith5Args(&GrpcTool::PrintType), 2, 2},
+ {"parse", BindWith5Args(&GrpcTool::ParseMessage), 2, 3},
+ {"totext", BindWith5Args(&GrpcTool::ToText), 2, 3},
+ {"tobinary", BindWith5Args(&GrpcTool::ToBinary), 2, 3},
+ {"tojson", BindWith5Args(&GrpcTool::ToJson), 2, 3},
+};
+
+void Usage(const TString& msg) {
+ fprintf(
+ stderr,
+ "%s\n"
+ " grpc_cli ls ... ; List services\n"
+ " grpc_cli call ... ; Call method\n"
+ " grpc_cli type ... ; Print type\n"
+ " grpc_cli parse ... ; Parse message\n"
+ " grpc_cli totext ... ; Convert binary message to text\n"
+ " grpc_cli tojson ... ; Convert binary message to json\n"
+ " grpc_cli tobinary ... ; Convert text message to binary\n"
+ " grpc_cli help ... ; Print this message, or per-command usage\n"
+ "\n",
+ msg.c_str());
+
+ exit(1);
+}
+
+const Command* FindCommand(const TString& name) {
+ for (int i = 0; i < (int)ArraySize(ops); i++) {
+ if (name == ops[i].command) {
+ return &ops[i];
+ }
+ }
+ return nullptr;
+}
+} // namespace
+
+int GrpcToolMainLib(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ if (argc < 2) {
+ Usage("No command specified");
+ }
+
+ TString command = argv[1];
+ argc -= 2;
+ argv += 2;
+
+ const Command* cmd = FindCommand(command);
+ if (cmd != nullptr) {
+ GrpcTool grpc_tool;
+ if (argc < cmd->min_args || argc > cmd->max_args) {
+ // Force the command to print its usage message
+ fprintf(stderr, "\nWrong number of arguments for %s\n", command.c_str());
+ grpc_tool.SetPrintCommandMode(1);
+ return cmd->function(&grpc_tool, -1, nullptr, cred, callback);
+ }
+ const bool ok = cmd->function(&grpc_tool, argc, argv, cred, callback);
+ return ok ? 0 : 1;
+ } else {
+ Usage("Invalid command '" + TString(command.c_str()) + "'");
+ }
+ return 1;
+}
+
+GrpcTool::GrpcTool() : print_command_usage_(false), usage_exit_status_(0) {}
+
+void GrpcTool::CommandUsage(const TString& usage) const {
+ if (print_command_usage_) {
+ fprintf(stderr, "\n%s%s\n", usage.c_str(),
+ (usage.empty() || usage[usage.size() - 1] != '\n') ? "\n" : "");
+ exit(usage_exit_status_);
+ }
+}
+
+bool GrpcTool::Help(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ CommandUsage(
+ "Print help\n"
+ " grpc_cli help [subcommand]\n");
+
+ if (argc == 0) {
+ Usage("");
+ } else {
+ const Command* cmd = FindCommand(argv[0]);
+ if (cmd == nullptr) {
+ Usage("Unknown command '" + TString(argv[0]) + "'");
+ }
+ SetPrintCommandMode(0);
+ cmd->function(this, -1, nullptr, cred, callback);
+ }
+ return true;
+}
+
+bool GrpcTool::ListServices(int argc, const char** argv,
+ const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ CommandUsage(
+ "List services\n"
+ " grpc_cli ls <address> [<service>[/<method>]]\n"
+ " <address> ; host:port\n"
+ " <service> ; Exported service name\n"
+ " <method> ; Method name\n"
+ " --l ; Use a long listing format\n"
+ " --outfile ; Output filename (defaults to stdout)\n" +
+ cred.GetCredentialUsage());
+
+ TString server_address(argv[0]);
+ std::shared_ptr<grpc::Channel> channel =
+ CreateCliChannel(server_address, cred);
+ grpc::ProtoReflectionDescriptorDatabase desc_db(channel);
+ grpc::protobuf::DescriptorPool desc_pool(&desc_db);
+
+ std::vector<TString> service_list;
+ if (!desc_db.GetServices(&service_list)) {
+ fprintf(stderr, "Received an error when querying services endpoint.\n");
+ return false;
+ }
+
+ // If no service is specified, dump the list of services.
+ TString output;
+ if (argc < 2) {
+ // List all services, if --l is passed, then include full description,
+ // otherwise include a summarized list only.
+ if (FLAGS_l) {
+ output = DescribeServiceList(service_list, desc_pool);
+ } else {
+ for (auto it = service_list.begin(); it != service_list.end(); it++) {
+ auto const& service = *it;
+ output.append(service);
+ output.append("\n");
+ }
+ }
+ } else {
+ std::string service_name;
+ std::string method_name;
+ std::stringstream ss(argv[1]);
+
+ // Remove leading slashes.
+ while (ss.peek() == '/') {
+ ss.get();
+ }
+
+ // Parse service and method names. Support the following patterns:
+ // Service
+ // Service Method
+ // Service.Method
+ // Service/Method
+ if (argc == 3) {
+ std::getline(ss, service_name, '/');
+ method_name = argv[2];
+ } else {
+ if (std::getline(ss, service_name, '/')) {
+ std::getline(ss, method_name);
+ }
+ }
+
+ const grpc::protobuf::ServiceDescriptor* service =
+ desc_pool.FindServiceByName(google::protobuf::string(service_name));
+ if (service != nullptr) {
+ if (method_name.empty()) {
+ output = FLAGS_l ? DescribeService(service) : SummarizeService(service);
+ } else {
+ method_name.insert(0, ".");
+ method_name.insert(0, service_name);
+ const grpc::protobuf::MethodDescriptor* method =
+ desc_pool.FindMethodByName(google::protobuf::string(method_name));
+ if (method != nullptr) {
+ output = FLAGS_l ? DescribeMethod(method) : SummarizeMethod(method);
+ } else {
+ fprintf(stderr, "Method %s not found in service %s.\n",
+ method_name.c_str(), service_name.c_str());
+ return false;
+ }
+ }
+ } else {
+ if (!method_name.empty()) {
+ fprintf(stderr, "Service %s not found.\n", service_name.c_str());
+ return false;
+ } else {
+ const grpc::protobuf::MethodDescriptor* method =
+ desc_pool.FindMethodByName(google::protobuf::string(service_name));
+ if (method != nullptr) {
+ output = FLAGS_l ? DescribeMethod(method) : SummarizeMethod(method);
+ } else {
+ fprintf(stderr, "Service or method %s not found.\n",
+ service_name.c_str());
+ return false;
+ }
+ }
+ }
+ }
+ return callback(output);
+}
+
+bool GrpcTool::PrintType(int /*argc*/, const char** argv,
+ const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ CommandUsage(
+ "Print type\n"
+ " grpc_cli type <address> <type>\n"
+ " <address> ; host:port\n"
+ " <type> ; Protocol buffer type name\n" +
+ cred.GetCredentialUsage());
+
+ TString server_address(argv[0]);
+ std::shared_ptr<grpc::Channel> channel =
+ CreateCliChannel(server_address, cred);
+ grpc::ProtoReflectionDescriptorDatabase desc_db(channel);
+ grpc::protobuf::DescriptorPool desc_pool(&desc_db);
+
+ TString output;
+ const grpc::protobuf::Descriptor* descriptor =
+ desc_pool.FindMessageTypeByName(argv[1]);
+ if (descriptor != nullptr) {
+ output = descriptor->DebugString();
+ } else {
+ fprintf(stderr, "Type %s not found.\n", argv[1]);
+ return false;
+ }
+ return callback(output);
+}
+
+bool GrpcTool::CallMethod(int argc, const char** argv,
+ const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ CommandUsage(
+ "Call method\n"
+ " grpc_cli call <address> <service>[.<method>] <request>\n"
+ " <address> ; host:port\n"
+ " <service> ; Exported service name\n"
+ " <method> ; Method name\n"
+ " <request> ; Text protobuffer (overrides infile)\n"
+ " --protofiles ; Comma separated proto files used as a"
+ " fallback when parsing request/response\n"
+ " --proto_path ; The search path of proto files, valid"
+ " only when --protofiles is given\n"
+ " --noremotedb ; Don't attempt to use reflection service"
+ " at all\n"
+ " --metadata ; The metadata to be sent to the server\n"
+ " --infile ; Input filename (defaults to stdin)\n"
+ " --outfile ; Output filename (defaults to stdout)\n"
+ " --binary_input ; Input in binary format\n"
+ " --binary_output ; Output in binary format\n"
+ " --json_input ; Input in json format\n"
+ " --json_output ; Output in json format\n"
+ " --timeout ; Specify timeout (in seconds), used to "
+ "set the deadline for RPCs. The default value of -1 means no "
+ "deadline has been set.\n" +
+ cred.GetCredentialUsage());
+
+ std::stringstream output_ss;
+ TString request_text;
+ TString server_address(argv[0]);
+ TString method_name(argv[1]);
+ TString formatted_method_name;
+ std::unique_ptr<ProtoFileParser> parser;
+ TString serialized_request_proto;
+ CliArgs cli_args;
+ cli_args.timeout = FLAGS_timeout;
+ bool print_mode = false;
+
+ std::shared_ptr<grpc::Channel> channel =
+ CreateCliChannel(server_address, cred);
+
+ if (!FLAGS_binary_input || !FLAGS_binary_output) {
+ parser.reset(
+ new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr,
+ FLAGS_proto_path.c_str(), FLAGS_protofiles.c_str()));
+ if (parser->HasError()) {
+ fprintf(
+ stderr,
+ "Failed to find remote reflection service and local proto files.\n");
+ return false;
+ }
+ }
+
+ if (FLAGS_binary_input) {
+ formatted_method_name = method_name;
+ } else {
+ formatted_method_name = parser->GetFormattedMethodName(method_name);
+ if (parser->HasError()) {
+ fprintf(stderr, "Failed to find method %s in proto files.\n",
+ method_name.c_str());
+ }
+ }
+
+ if (argc == 3) {
+ request_text = argv[2];
+ }
+
+ if (parser->IsStreaming(method_name, true /* is_request */)) {
+ std::istream* input_stream;
+ std::ifstream input_file;
+
+ if (FLAGS_batch) {
+ fprintf(stderr, "Batch mode for streaming RPC is not supported.\n");
+ return false;
+ }
+
+ std::multimap<TString, TString> client_metadata;
+ ParseMetadataFlag(&client_metadata);
+ PrintMetadata(client_metadata, "Sending client initial metadata:");
+
+ CliCall call(channel, formatted_method_name, client_metadata, cli_args);
+ if (FLAGS_display_peer_address) {
+ fprintf(stderr, "New call for method_name:%s has peer address:|%s|\n",
+ formatted_method_name.c_str(), call.peer().c_str());
+ }
+
+ if (FLAGS_infile.empty()) {
+ if (isatty(fileno(stdin))) {
+ print_mode = true;
+ fprintf(stderr, "reading streaming request message from stdin...\n");
+ }
+ input_stream = &std::cin;
+ } else {
+ input_file.open(FLAGS_infile, std::ios::in | std::ios::binary);
+ input_stream = &input_file;
+ }
+
+ gpr_mu parser_mu;
+ gpr_mu_init(&parser_mu);
+ std::thread read_thread(ReadResponse, &call, method_name, callback,
+ parser.get(), &parser_mu, print_mode);
+
+ std::stringstream request_ss;
+ std::string line;
+ while (!request_text.empty() ||
+ (!input_stream->eof() && getline(*input_stream, line))) {
+ if (!request_text.empty()) {
+ if (FLAGS_binary_input) {
+ serialized_request_proto = request_text;
+ request_text.clear();
+ } else {
+ gpr_mu_lock(&parser_mu);
+ serialized_request_proto = parser->GetSerializedProtoFromMethod(
+ method_name, request_text, true /* is_request */,
+ FLAGS_json_input);
+ request_text.clear();
+ if (parser->HasError()) {
+ if (print_mode) {
+ fprintf(stderr, "Failed to parse request.\n");
+ }
+ gpr_mu_unlock(&parser_mu);
+ continue;
+ }
+ gpr_mu_unlock(&parser_mu);
+ }
+
+ call.WriteAndWait(serialized_request_proto);
+ if (print_mode) {
+ fprintf(stderr, "Request sent.\n");
+ }
+ } else {
+ if (line.length() == 0) {
+ request_text = request_ss.str();
+ request_ss.str(TString());
+ request_ss.clear();
+ } else {
+ request_ss << line << ' ';
+ }
+ }
+ }
+ if (input_file.is_open()) {
+ input_file.close();
+ }
+
+ call.WritesDoneAndWait();
+ read_thread.join();
+ gpr_mu_destroy(&parser_mu);
+
+ std::multimap<grpc::string_ref, grpc::string_ref> server_trailing_metadata;
+ Status status = call.Finish(&server_trailing_metadata);
+ PrintMetadata(server_trailing_metadata,
+ "Received trailing metadata from server:");
+
+ if (status.ok()) {
+ fprintf(stderr, "Stream RPC succeeded with OK status\n");
+ return true;
+ } else {
+ fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
+ status.error_code(), status.error_message().c_str());
+ return false;
+ }
+
+ } else { // parser->IsStreaming(method_name, true /* is_request */)
+ if (FLAGS_batch) {
+ if (parser->IsStreaming(method_name, false /* is_request */)) {
+ fprintf(stderr, "Batch mode for streaming RPC is not supported.\n");
+ return false;
+ }
+
+ std::istream* input_stream;
+ std::ifstream input_file;
+
+ if (FLAGS_infile.empty()) {
+ if (isatty(fileno(stdin))) {
+ print_mode = true;
+ fprintf(stderr, "reading request messages from stdin...\n");
+ }
+ input_stream = &std::cin;
+ } else {
+ input_file.open(FLAGS_infile, std::ios::in | std::ios::binary);
+ input_stream = &input_file;
+ }
+
+ std::multimap<TString, TString> client_metadata;
+ ParseMetadataFlag(&client_metadata);
+ if (print_mode) {
+ PrintMetadata(client_metadata, "Sending client initial metadata:");
+ }
+
+ std::stringstream request_ss;
+ std::string line;
+ while (!request_text.empty() ||
+ (!input_stream->eof() && getline(*input_stream, line))) {
+ if (!request_text.empty()) {
+ if (FLAGS_binary_input) {
+ serialized_request_proto = request_text;
+ request_text.clear();
+ } else {
+ serialized_request_proto = parser->GetSerializedProtoFromMethod(
+ method_name, request_text, true /* is_request */,
+ FLAGS_json_input);
+ request_text.clear();
+ if (parser->HasError()) {
+ if (print_mode) {
+ fprintf(stderr, "Failed to parse request.\n");
+ }
+ continue;
+ }
+ }
+
+ TString serialized_response_proto;
+ std::multimap<grpc::string_ref, grpc::string_ref>
+ server_initial_metadata, server_trailing_metadata;
+ CliCall call(channel, formatted_method_name, client_metadata,
+ cli_args);
+ if (FLAGS_display_peer_address) {
+ fprintf(stderr,
+ "New call for method_name:%s has peer address:|%s|\n",
+ formatted_method_name.c_str(), call.peer().c_str());
+ }
+ call.Write(serialized_request_proto);
+ call.WritesDone();
+ if (!call.Read(&serialized_response_proto,
+ &server_initial_metadata)) {
+ fprintf(stderr, "Failed to read response.\n");
+ }
+ Status status = call.Finish(&server_trailing_metadata);
+
+ if (status.ok()) {
+ if (print_mode) {
+ fprintf(stderr, "Rpc succeeded with OK status.\n");
+ PrintMetadata(server_initial_metadata,
+ "Received initial metadata from server:");
+ PrintMetadata(server_trailing_metadata,
+ "Received trailing metadata from server:");
+ }
+
+ if (FLAGS_binary_output) {
+ if (!callback(serialized_response_proto)) {
+ break;
+ }
+ } else {
+ TString response_text = parser->GetFormattedStringFromMethod(
+ method_name, serialized_response_proto,
+ false /* is_request */, FLAGS_json_output);
+
+ if (parser->HasError() && print_mode) {
+ fprintf(stderr, "Failed to parse response.\n");
+ } else {
+ if (!callback(response_text)) {
+ break;
+ }
+ }
+ }
+ } else {
+ if (print_mode) {
+ fprintf(stderr,
+ "Rpc failed with status code %d, error message: %s\n",
+ status.error_code(), status.error_message().c_str());
+ }
+ }
+ } else {
+ if (line.length() == 0) {
+ request_text = request_ss.str();
+ request_ss.str(TString());
+ request_ss.clear();
+ } else {
+ request_ss << line << ' ';
+ }
+ }
+ }
+
+ if (input_file.is_open()) {
+ input_file.close();
+ }
+
+ return true;
+ }
+
+ if (argc == 3) {
+ if (!FLAGS_infile.empty()) {
+ fprintf(stderr, "warning: request given in argv, ignoring --infile\n");
+ }
+ } else {
+ std::stringstream input_stream;
+ if (FLAGS_infile.empty()) {
+ if (isatty(fileno(stdin))) {
+ fprintf(stderr, "reading request message from stdin...\n");
+ }
+ input_stream << std::cin.rdbuf();
+ } else {
+ std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary);
+ input_stream << input_file.rdbuf();
+ input_file.close();
+ }
+ request_text = input_stream.str();
+ }
+
+ if (FLAGS_binary_input) {
+ serialized_request_proto = request_text;
+ } else {
+ serialized_request_proto = parser->GetSerializedProtoFromMethod(
+ method_name, request_text, true /* is_request */, FLAGS_json_input);
+ if (parser->HasError()) {
+ fprintf(stderr, "Failed to parse request.\n");
+ return false;
+ }
+ }
+ fprintf(stderr, "connecting to %s\n", server_address.c_str());
+
+ TString serialized_response_proto;
+ std::multimap<TString, TString> client_metadata;
+ std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata,
+ server_trailing_metadata;
+ ParseMetadataFlag(&client_metadata);
+ PrintMetadata(client_metadata, "Sending client initial metadata:");
+
+ CliCall call(channel, formatted_method_name, client_metadata, cli_args);
+ if (FLAGS_display_peer_address) {
+ fprintf(stderr, "New call for method_name:%s has peer address:|%s|\n",
+ formatted_method_name.c_str(), call.peer().c_str());
+ }
+ call.Write(serialized_request_proto);
+ call.WritesDone();
+
+ for (bool receive_initial_metadata = true; call.Read(
+ &serialized_response_proto,
+ receive_initial_metadata ? &server_initial_metadata : nullptr);
+ receive_initial_metadata = false) {
+ if (!FLAGS_binary_output) {
+ serialized_response_proto = parser->GetFormattedStringFromMethod(
+ method_name, serialized_response_proto, false /* is_request */,
+ FLAGS_json_output);
+ if (parser->HasError()) {
+ fprintf(stderr, "Failed to parse response.\n");
+ return false;
+ }
+ }
+
+ if (receive_initial_metadata) {
+ PrintMetadata(server_initial_metadata,
+ "Received initial metadata from server:");
+ }
+ if (!callback(serialized_response_proto)) {
+ return false;
+ }
+ }
+ Status status = call.Finish(&server_trailing_metadata);
+ PrintMetadata(server_trailing_metadata,
+ "Received trailing metadata from server:");
+ if (status.ok()) {
+ fprintf(stderr, "Rpc succeeded with OK status\n");
+ return true;
+ } else {
+ fprintf(stderr, "Rpc failed with status code %d, error message: %s\n",
+ status.error_code(), status.error_message().c_str());
+ return false;
+ }
+ }
+ GPR_UNREACHABLE_CODE(return false);
+}
+
+bool GrpcTool::ParseMessage(int argc, const char** argv,
+ const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ CommandUsage(
+ "Parse message\n"
+ " grpc_cli parse <address> <type> [<message>]\n"
+ " <address> ; host:port\n"
+ " <type> ; Protocol buffer type name\n"
+ " <message> ; Text protobuffer (overrides --infile)\n"
+ " --protofiles ; Comma separated proto files used as a"
+ " fallback when parsing request/response\n"
+ " --proto_path ; The search path of proto files, valid"
+ " only when --protofiles is given\n"
+ " --noremotedb ; Don't attempt to use reflection service"
+ " at all\n"
+ " --infile ; Input filename (defaults to stdin)\n"
+ " --outfile ; Output filename (defaults to stdout)\n"
+ " --binary_input ; Input in binary format\n"
+ " --binary_output ; Output in binary format\n"
+ " --json_input ; Input in json format\n"
+ " --json_output ; Output in json format\n" +
+ cred.GetCredentialUsage());
+
+ std::stringstream output_ss;
+ TString message_text;
+ TString server_address(argv[0]);
+ TString type_name(argv[1]);
+ std::unique_ptr<grpc::testing::ProtoFileParser> parser;
+ TString serialized_request_proto;
+
+ if (argc == 3) {
+ message_text = argv[2];
+ if (!FLAGS_infile.empty()) {
+ fprintf(stderr, "warning: message given in argv, ignoring --infile.\n");
+ }
+ } else {
+ std::stringstream input_stream;
+ if (FLAGS_infile.empty()) {
+ if (isatty(fileno(stdin))) {
+ fprintf(stderr, "reading request message from stdin...\n");
+ }
+ input_stream << std::cin.rdbuf();
+ } else {
+ std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary);
+ input_stream << input_file.rdbuf();
+ input_file.close();
+ }
+ message_text = input_stream.str();
+ }
+
+ if (!FLAGS_binary_input || !FLAGS_binary_output) {
+ std::shared_ptr<grpc::Channel> channel =
+ CreateCliChannel(server_address, cred);
+ parser.reset(
+ new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr,
+ FLAGS_proto_path.c_str(), FLAGS_protofiles.c_str()));
+ if (parser->HasError()) {
+ fprintf(
+ stderr,
+ "Failed to find remote reflection service and local proto files.\n");
+ return false;
+ }
+ }
+
+ if (FLAGS_binary_input) {
+ serialized_request_proto = message_text;
+ } else {
+ serialized_request_proto = parser->GetSerializedProtoFromMessageType(
+ type_name, message_text, FLAGS_json_input);
+ if (parser->HasError()) {
+ fprintf(stderr, "Failed to serialize the message.\n");
+ return false;
+ }
+ }
+
+ if (FLAGS_binary_output) {
+ output_ss << serialized_request_proto;
+ } else {
+ TString output_text;
+ output_text = parser->GetFormattedStringFromMessageType(
+ type_name, serialized_request_proto, FLAGS_json_output);
+ if (parser->HasError()) {
+ fprintf(stderr, "Failed to deserialize the message.\n");
+ return false;
+ }
+
+ output_ss << output_text << std::endl;
+ }
+
+ return callback(output_ss.str());
+}
+
+bool GrpcTool::ToText(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ CommandUsage(
+ "Convert binary message to text\n"
+ " grpc_cli totext <protofiles> <type>\n"
+ " <protofiles> ; Comma separated list of proto files\n"
+ " <type> ; Protocol buffer type name\n"
+ " --proto_path ; The search path of proto files\n"
+ " --infile ; Input filename (defaults to stdin)\n"
+ " --outfile ; Output filename (defaults to stdout)\n");
+
+ FLAGS_protofiles = argv[0];
+ FLAGS_remotedb = false;
+ FLAGS_binary_input = true;
+ FLAGS_binary_output = false;
+ return ParseMessage(argc, argv, cred, callback);
+}
+
+bool GrpcTool::ToJson(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ CommandUsage(
+ "Convert binary message to json\n"
+ " grpc_cli tojson <protofiles> <type>\n"
+ " <protofiles> ; Comma separated list of proto files\n"
+ " <type> ; Protocol buffer type name\n"
+ " --proto_path ; The search path of proto files\n"
+ " --infile ; Input filename (defaults to stdin)\n"
+ " --outfile ; Output filename (defaults to stdout)\n");
+
+ FLAGS_protofiles = argv[0];
+ FLAGS_remotedb = false;
+ FLAGS_binary_input = true;
+ FLAGS_binary_output = false;
+ FLAGS_json_output = true;
+ return ParseMessage(argc, argv, cred, callback);
+}
+
+bool GrpcTool::ToBinary(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback) {
+ CommandUsage(
+ "Convert text message to binary\n"
+ " grpc_cli tobinary <protofiles> <type> [<message>]\n"
+ " <protofiles> ; Comma separated list of proto files\n"
+ " <type> ; Protocol buffer type name\n"
+ " --proto_path ; The search path of proto files\n"
+ " --infile ; Input filename (defaults to stdin)\n"
+ " --outfile ; Output filename (defaults to stdout)\n");
+
+ FLAGS_protofiles = argv[0];
+ FLAGS_remotedb = false;
+ FLAGS_binary_input = false;
+ FLAGS_binary_output = true;
+ return ParseMessage(argc, argv, cred, callback);
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/grpc_tool.h b/contrib/libs/grpc/test/cpp/util/grpc_tool.h
new file mode 100644
index 0000000000..5bb43430d3
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/grpc_tool.h
@@ -0,0 +1,39 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_GRPC_TOOL_H
+#define GRPC_TEST_CPP_UTIL_GRPC_TOOL_H
+
+#include <functional>
+
+#include <grpcpp/support/config.h>
+
+#include "test/cpp/util/cli_credentials.h"
+
+namespace grpc {
+namespace testing {
+
+typedef std::function<bool(const TString&)> GrpcToolOutputCallback;
+
+int GrpcToolMainLib(int argc, const char** argv, const CliCredentials& cred,
+ GrpcToolOutputCallback callback);
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_GRPC_TOOL_H
diff --git a/contrib/libs/grpc/test/cpp/util/grpc_tool_test.cc b/contrib/libs/grpc/test/cpp/util/grpc_tool_test.cc
new file mode 100644
index 0000000000..ff610daadd
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/grpc_tool_test.cc
@@ -0,0 +1,1344 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/grpc_tool.h"
+
+#include <gflags/gflags.h>
+#include <grpc/grpc.h>
+#include <grpc/support/alloc.h>
+#include <grpcpp/channel.h>
+#include <grpcpp/client_context.h>
+#include <grpcpp/create_channel.h>
+#include <grpcpp/ext/proto_server_reflection_plugin.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+#include <grpcpp/server_context.h>
+#include <gtest/gtest.h>
+
+#include <chrono>
+#include <sstream>
+
+#include "src/core/lib/gpr/env.h"
+#include "src/core/lib/iomgr/load_file.h"
+#include "src/proto/grpc/testing/echo.grpc.pb.h"
+#include "src/proto/grpc/testing/echo.pb.h"
+#include "test/core/util/port.h"
+#include "test/core/util/test_config.h"
+#include "test/cpp/util/cli_credentials.h"
+#include "test/cpp/util/string_ref_helper.h"
+
+#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem"
+#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem"
+#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key"
+
+using grpc::testing::EchoRequest;
+using grpc::testing::EchoResponse;
+
+#define USAGE_REGEX "( grpc_cli .+\n){2,10}"
+
+#define ECHO_TEST_SERVICE_SUMMARY \
+ "Echo\n" \
+ "Echo1\n" \
+ "Echo2\n" \
+ "CheckDeadlineUpperBound\n" \
+ "CheckDeadlineSet\n" \
+ "CheckClientInitialMetadata\n" \
+ "RequestStream\n" \
+ "ResponseStream\n" \
+ "BidiStream\n" \
+ "Unimplemented\n"
+
+#define ECHO_TEST_SERVICE_DESCRIPTION \
+ "filename: src/proto/grpc/testing/echo.proto\n" \
+ "package: grpc.testing;\n" \
+ "service EchoTestService {\n" \
+ " rpc Echo(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \
+ "{}\n" \
+ " rpc Echo1(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \
+ "{}\n" \
+ " rpc Echo2(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \
+ "{}\n" \
+ " rpc CheckDeadlineUpperBound(grpc.testing.SimpleRequest) returns " \
+ "(grpc.testing.StringValue) {}\n" \
+ " rpc CheckDeadlineSet(grpc.testing.SimpleRequest) returns " \
+ "(grpc.testing.StringValue) {}\n" \
+ " rpc CheckClientInitialMetadata(grpc.testing.SimpleRequest) returns " \
+ "(grpc.testing.SimpleResponse) {}\n" \
+ " rpc RequestStream(stream grpc.testing.EchoRequest) returns " \
+ "(grpc.testing.EchoResponse) {}\n" \
+ " rpc ResponseStream(grpc.testing.EchoRequest) returns (stream " \
+ "grpc.testing.EchoResponse) {}\n" \
+ " rpc BidiStream(stream grpc.testing.EchoRequest) returns (stream " \
+ "grpc.testing.EchoResponse) {}\n" \
+ " rpc Unimplemented(grpc.testing.EchoRequest) returns " \
+ "(grpc.testing.EchoResponse) {}\n" \
+ "}\n" \
+ "\n"
+
+#define ECHO_METHOD_DESCRIPTION \
+ " rpc Echo(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \
+ "{}\n"
+
+#define ECHO_RESPONSE_MESSAGE_TEXT_FORMAT \
+ "message: \"echo\"\n" \
+ "param {\n" \
+ " host: \"localhost\"\n" \
+ " peer: \"peer\"\n" \
+ "}\n\n"
+
+#define ECHO_RESPONSE_MESSAGE_JSON_FORMAT \
+ "{\n" \
+ " \"message\": \"echo\",\n" \
+ " \"param\": {\n" \
+ " \"host\": \"localhost\",\n" \
+ " \"peer\": \"peer\"\n" \
+ " }\n" \
+ "}\n\n"
+
+DECLARE_string(channel_creds_type);
+DECLARE_string(ssl_target);
+
+namespace grpc {
+namespace testing {
+
+DECLARE_bool(binary_input);
+DECLARE_bool(binary_output);
+DECLARE_bool(json_input);
+DECLARE_bool(json_output);
+DECLARE_bool(l);
+DECLARE_bool(batch);
+DECLARE_string(metadata);
+DECLARE_string(protofiles);
+DECLARE_string(proto_path);
+DECLARE_string(default_service_config);
+DECLARE_double(timeout);
+
+namespace {
+
+const int kServerDefaultResponseStreamsToSend = 3;
+
+class TestCliCredentials final : public grpc::testing::CliCredentials {
+ public:
+ TestCliCredentials(bool secure = false) : secure_(secure) {}
+ std::shared_ptr<grpc::ChannelCredentials> GetChannelCredentials()
+ const override {
+ if (!secure_) {
+ return InsecureChannelCredentials();
+ }
+ grpc_slice ca_slice;
+ GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file",
+ grpc_load_file(CA_CERT_PATH, 1, &ca_slice)));
+ const char* test_root_cert =
+ reinterpret_cast<const char*> GRPC_SLICE_START_PTR(ca_slice);
+ SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
+ std::shared_ptr<grpc::ChannelCredentials> credential_ptr =
+ grpc::SslCredentials(grpc::SslCredentialsOptions(ssl_opts));
+ grpc_slice_unref(ca_slice);
+ return credential_ptr;
+ }
+ const TString GetCredentialUsage() const override { return ""; }
+
+ private:
+ const bool secure_;
+};
+
+bool PrintStream(std::stringstream* ss, const TString& output) {
+ (*ss) << output;
+ return true;
+}
+
+template <typename T>
+size_t ArraySize(T& a) {
+ return ((sizeof(a) / sizeof(*(a))) /
+ static_cast<size_t>(!(sizeof(a) % sizeof(*(a)))));
+}
+
+class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {
+ public:
+ Status Echo(ServerContext* context, const EchoRequest* request,
+ EchoResponse* response) override {
+ if (!context->client_metadata().empty()) {
+ for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
+ iter = context->client_metadata().begin();
+ iter != context->client_metadata().end(); ++iter) {
+ context->AddInitialMetadata(ToString(iter->first),
+ ToString(iter->second));
+ }
+ }
+ context->AddTrailingMetadata("trailing_key", "trailing_value");
+ response->set_message(request->message());
+ return Status::OK;
+ }
+
+ Status CheckDeadlineSet(ServerContext* context, const SimpleRequest* request,
+ StringValue* response) override {
+ response->set_message(context->deadline() !=
+ std::chrono::system_clock::time_point::max()
+ ? "true"
+ : "false");
+ return Status::OK;
+ }
+
+ // Check if deadline - current time <= timeout
+ // If deadline set, timeout + current time should be an upper bound for it
+ Status CheckDeadlineUpperBound(ServerContext* context,
+ const SimpleRequest* request,
+ StringValue* response) override {
+ auto seconds = std::chrono::duration_cast<std::chrono::seconds>(
+ context->deadline() - std::chrono::system_clock::now());
+
+ // Returning string instead of bool to avoid using embedded messages in
+ // proto3
+ response->set_message(seconds.count() <= FLAGS_timeout ? "true" : "false");
+ return Status::OK;
+ }
+
+ Status RequestStream(ServerContext* context,
+ ServerReader<EchoRequest>* reader,
+ EchoResponse* response) override {
+ EchoRequest request;
+ response->set_message("");
+ if (!context->client_metadata().empty()) {
+ for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
+ iter = context->client_metadata().begin();
+ iter != context->client_metadata().end(); ++iter) {
+ context->AddInitialMetadata(ToString(iter->first),
+ ToString(iter->second));
+ }
+ }
+ context->AddTrailingMetadata("trailing_key", "trailing_value");
+ while (reader->Read(&request)) {
+ response->mutable_message()->append(request.message());
+ }
+
+ return Status::OK;
+ }
+
+ Status ResponseStream(ServerContext* context, const EchoRequest* request,
+ ServerWriter<EchoResponse>* writer) override {
+ if (!context->client_metadata().empty()) {
+ for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
+ iter = context->client_metadata().begin();
+ iter != context->client_metadata().end(); ++iter) {
+ context->AddInitialMetadata(ToString(iter->first),
+ ToString(iter->second));
+ }
+ }
+ context->AddTrailingMetadata("trailing_key", "trailing_value");
+
+ EchoResponse response;
+ for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) {
+ response.set_message(request->message() + ToString(i));
+ writer->Write(response);
+ }
+
+ return Status::OK;
+ }
+
+ Status BidiStream(
+ ServerContext* context,
+ ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
+ EchoRequest request;
+ EchoResponse response;
+ if (!context->client_metadata().empty()) {
+ for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator
+ iter = context->client_metadata().begin();
+ iter != context->client_metadata().end(); ++iter) {
+ context->AddInitialMetadata(ToString(iter->first),
+ ToString(iter->second));
+ }
+ }
+ context->AddTrailingMetadata("trailing_key", "trailing_value");
+
+ while (stream->Read(&request)) {
+ response.set_message(request.message());
+ stream->Write(response);
+ }
+
+ return Status::OK;
+ }
+};
+
+} // namespace
+
+class GrpcToolTest : public ::testing::Test {
+ protected:
+ GrpcToolTest() {}
+
+ // SetUpServer cannot be used with EXPECT_EXIT. grpc_pick_unused_port_or_die()
+ // uses atexit() to free chosen ports, and it will spawn a new thread in
+ // resolve_address_posix.c:192 at exit time.
+ const TString SetUpServer(bool secure = false) {
+ std::ostringstream server_address;
+ int port = grpc_pick_unused_port_or_die();
+ server_address << "localhost:" << port;
+ // Setup server
+ ServerBuilder builder;
+ std::shared_ptr<grpc::ServerCredentials> creds;
+ grpc_slice cert_slice, key_slice;
+ GPR_ASSERT(GRPC_LOG_IF_ERROR(
+ "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice)));
+ GPR_ASSERT(GRPC_LOG_IF_ERROR(
+ "load_file", grpc_load_file(SERVER_KEY_PATH, 1, &key_slice)));
+ const char* server_cert =
+ reinterpret_cast<const char*> GRPC_SLICE_START_PTR(cert_slice);
+ const char* server_key =
+ reinterpret_cast<const char*> GRPC_SLICE_START_PTR(key_slice);
+ SslServerCredentialsOptions::PemKeyCertPair pkcp = {server_key,
+ server_cert};
+ if (secure) {
+ SslServerCredentialsOptions ssl_opts;
+ ssl_opts.pem_root_certs = "";
+ ssl_opts.pem_key_cert_pairs.push_back(pkcp);
+ creds = SslServerCredentials(ssl_opts);
+ } else {
+ creds = InsecureServerCredentials();
+ }
+ builder.AddListeningPort(server_address.str(), creds);
+ builder.RegisterService(&service_);
+ server_ = builder.BuildAndStart();
+ grpc_slice_unref(cert_slice);
+ grpc_slice_unref(key_slice);
+ return server_address.str();
+ }
+
+ void ShutdownServer() { server_->Shutdown(); }
+
+ std::unique_ptr<Server> server_;
+ TestServiceImpl service_;
+ reflection::ProtoServerReflectionPlugin plugin_;
+};
+
+TEST_F(GrpcToolTest, NoCommand) {
+ // Test input "grpc_cli"
+ std::stringstream output_stream;
+ const char* argv[] = {"grpc_cli"};
+ // Exit with 1, print usage instruction in stderr
+ EXPECT_EXIT(
+ GrpcToolMainLib(
+ ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream, std::placeholders::_1)),
+ ::testing::ExitedWithCode(1), "No command specified\n" USAGE_REGEX);
+ // No output
+ EXPECT_TRUE(0 == output_stream.tellp());
+}
+
+TEST_F(GrpcToolTest, InvalidCommand) {
+ // Test input "grpc_cli"
+ std::stringstream output_stream;
+ const char* argv[] = {"grpc_cli", "abc"};
+ // Exit with 1, print usage instruction in stderr
+ EXPECT_EXIT(
+ GrpcToolMainLib(
+ ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream, std::placeholders::_1)),
+ ::testing::ExitedWithCode(1), "Invalid command 'abc'\n" USAGE_REGEX);
+ // No output
+ EXPECT_TRUE(0 == output_stream.tellp());
+}
+
+TEST_F(GrpcToolTest, HelpCommand) {
+ // Test input "grpc_cli help"
+ std::stringstream output_stream;
+ const char* argv[] = {"grpc_cli", "help"};
+ // Exit with 1, print usage instruction in stderr
+ EXPECT_EXIT(GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)),
+ ::testing::ExitedWithCode(1), USAGE_REGEX);
+ // No output
+ EXPECT_TRUE(0 == output_stream.tellp());
+}
+
+TEST_F(GrpcToolTest, ListCommand) {
+ // Test input "grpc_cli list localhost:<port>"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "ls", server_address.c_str()};
+
+ FLAGS_l = false;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(),
+ "grpc.testing.EchoTestService\n"
+ "grpc.reflection.v1alpha.ServerReflection\n"));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, ListOneService) {
+ // Test input "grpc_cli list localhost:<port> grpc.testing.EchoTestService"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "ls", server_address.c_str(),
+ "grpc.testing.EchoTestService"};
+ // without -l flag
+ FLAGS_l = false;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: ECHO_TEST_SERVICE_SUMMARY
+ EXPECT_TRUE(0 ==
+ strcmp(output_stream.str().c_str(), ECHO_TEST_SERVICE_SUMMARY));
+
+ // with -l flag
+ output_stream.str(TString());
+ output_stream.clear();
+ FLAGS_l = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: ECHO_TEST_SERVICE_DESCRIPTION
+ EXPECT_TRUE(
+ 0 == strcmp(output_stream.str().c_str(), ECHO_TEST_SERVICE_DESCRIPTION));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, TypeCommand) {
+ // Test input "grpc_cli type localhost:<port> grpc.testing.EchoRequest"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "type", server_address.c_str(),
+ "grpc.testing.EchoRequest"};
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ const grpc::protobuf::Descriptor* desc =
+ grpc::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
+ "grpc.testing.EchoRequest");
+ // Expected output: the DebugString of grpc.testing.EchoRequest
+ EXPECT_TRUE(0 ==
+ strcmp(output_stream.str().c_str(), desc->DebugString().c_str()));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, ListOneMethod) {
+ // Test input "grpc_cli list localhost:<port> grpc.testing.EchoTestService"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "ls", server_address.c_str(),
+ "grpc.testing.EchoTestService.Echo"};
+ // without -l flag
+ FLAGS_l = false;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: "Echo"
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), "Echo\n"));
+
+ // with -l flag
+ output_stream.str(TString());
+ output_stream.clear();
+ FLAGS_l = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: ECHO_METHOD_DESCRIPTION
+ EXPECT_TRUE(0 ==
+ strcmp(output_stream.str().c_str(), ECHO_METHOD_DESCRIPTION));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, TypeNotFound) {
+ // Test input "grpc_cli type localhost:<port> grpc.testing.DummyRequest"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "type", server_address.c_str(),
+ "grpc.testing.DummyRequest"};
+
+ EXPECT_TRUE(1 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommand) {
+ // Test input "grpc_cli call localhost:<port> Echo "message: 'Hello'"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo",
+ "message: 'Hello'"};
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: "message: \"Hello\""
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"Hello\""));
+
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+
+ // Expected output:
+ // {
+ // "message": "Hello"
+ // }
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "{\n \"message\": \"Hello\"\n}"));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandJsonInput) {
+ // Test input "grpc_cli call localhost:<port> Echo "{ \"message\": \"Hello\"}"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo",
+ "{ \"message\": \"Hello\"}"};
+
+ FLAGS_json_input = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: "message: \"Hello\""
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"Hello\""));
+
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+ FLAGS_json_input = false;
+
+ // Expected output:
+ // {
+ // "message": "Hello"
+ // }
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "{\n \"message\": \"Hello\"\n}"));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandBatch) {
+ // Test input "grpc_cli call Echo"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo",
+ "message: 'Hello0'"};
+
+ // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_batch = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_batch = false;
+
+ // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage:
+ // "Hello2"\n"
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "message: \"Hello0\"\nmessage: "
+ "\"Hello1\"\nmessage: \"Hello2\"\n"));
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+ ss.clear();
+ ss.seekg(0);
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_batch = true;
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+ FLAGS_batch = false;
+
+ // Expected output:
+ // {
+ // "message": "Hello0"
+ // }
+ // {
+ // "message": "Hello1"
+ // }
+ // {
+ // "message": "Hello2"
+ // }
+ // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage:
+ // "Hello2"\n"
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "{\n \"message\": \"Hello0\"\n}\n"
+ "{\n \"message\": \"Hello1\"\n}\n"
+ "{\n \"message\": \"Hello2\"\n}\n"));
+
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandBatchJsonInput) {
+ // Test input "grpc_cli call Echo"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo",
+ "{\"message\": \"Hello0\"}"};
+
+ // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss(
+ "{\"message\": \"Hello1\"}\n\n{\"message\": \"Hello2\" }\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_json_input = true;
+ FLAGS_batch = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_batch = false;
+
+ // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage:
+ // "Hello2"\n"
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "message: \"Hello0\"\nmessage: "
+ "\"Hello1\"\nmessage: \"Hello2\"\n"));
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+ ss.clear();
+ ss.seekg(0);
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_batch = true;
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+ FLAGS_batch = false;
+ FLAGS_json_input = false;
+
+ // Expected output:
+ // {
+ // "message": "Hello0"
+ // }
+ // {
+ // "message": "Hello1"
+ // }
+ // {
+ // "message": "Hello2"
+ // }
+  // The same three responses, this time rendered as JSON objects by the
+  // --json_output flag (text-format expectation above no longer applies).
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "{\n \"message\": \"Hello0\"\n}\n"
+ "{\n \"message\": \"Hello1\"\n}\n"
+ "{\n \"message\": \"Hello2\"\n}\n"));
+
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandBatchWithBadRequest) {
+ // Test input "grpc_cli call Echo"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo",
+ "message: 'Hello0'"};
+
+ // Mock std::cin input "message: 1\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss("message: 1\n\n message: 'Hello2'\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_batch = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_batch = false;
+
+ // Expected output: "message: "Hello0"\nmessage: "Hello2"\n"
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "message: \"Hello0\"\nmessage: \"Hello2\"\n"));
+
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+ ss.clear();
+ ss.seekg(0);
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_batch = true;
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+ FLAGS_batch = false;
+
+ // Expected output:
+ // {
+ // "message": "Hello0"
+ // }
+ // {
+ // "message": "Hello2"
+ // }
+  // Expected JSON output contains only the well-formed requests, "Hello0"
+  // and "Hello2"; the bad request "message: 1" is skipped.
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "{\n \"message\": \"Hello0\"\n}\n"
+ "{\n \"message\": \"Hello2\"\n}\n"));
+
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandBatchJsonInputWithBadRequest) {
+ // Test input "grpc_cli call Echo"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo",
+ "{ \"message\": \"Hello0\"}"};
+
+ // Mock std::cin input "message: 1\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss(
+ "{ \"message\": 1 }\n\n { \"message\": \"Hello2\" }\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_batch = true;
+ FLAGS_json_input = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_input = false;
+ FLAGS_batch = false;
+
+ // Expected output: "message: "Hello0"\nmessage: "Hello2"\n"
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "message: \"Hello0\"\nmessage: \"Hello2\"\n"));
+
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+ ss.clear();
+ ss.seekg(0);
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_batch = true;
+ FLAGS_json_input = true;
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+ FLAGS_json_input = false;
+ FLAGS_batch = false;
+
+ // Expected output:
+ // {
+ // "message": "Hello0"
+ // }
+ // {
+ // "message": "Hello2"
+ // }
+  // Expected JSON output contains only the well-formed requests, "Hello0"
+  // and "Hello2"; the bad request "{ \"message\": 1 }" is skipped.
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "{\n \"message\": \"Hello0\"\n}\n"
+ "{\n \"message\": \"Hello2\"\n}\n"));
+
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandRequestStream) {
+ // Test input: grpc_cli call localhost:<port> RequestStream "message:
+ // 'Hello0'"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "RequestStream", "message: 'Hello0'"};
+
+ // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: "message: \"Hello0Hello1Hello2\""
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "message: \"Hello0Hello1Hello2\""));
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandRequestStreamJsonInput) {
+ // Test input: grpc_cli call localhost:<port> RequestStream "{ \"message\":
+ // \"Hello0\"}"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "RequestStream", "{ \"message\": \"Hello0\" }"};
+
+ // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss(
+ "{ \"message\": \"Hello1\" }\n\n{ \"message\": \"Hello2\" }\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_json_input = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_input = false;
+
+ // Expected output: "message: \"Hello0Hello1Hello2\""
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "message: \"Hello0Hello1Hello2\""));
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandRequestStreamWithBadRequest) {
+ // Test input: grpc_cli call localhost:<port> RequestStream "message:
+ // 'Hello0'"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "RequestStream", "message: 'Hello0'"};
+
+ // Mock std::cin input "bad_field: 'Hello1'\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss("bad_field: 'Hello1'\n\n message: 'Hello2'\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: "message: \"Hello0Hello2\""
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"Hello0Hello2\""));
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandRequestStreamWithBadRequestJsonInput) {
+ // Test input: grpc_cli call localhost:<port> RequestStream "message:
+ // 'Hello0'"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "RequestStream", "{ \"message\": \"Hello0\" }"};
+
+ // Mock std::cin input "bad_field: 'Hello1'\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss(
+ "{ \"bad_field\": \"Hello1\" }\n\n{ \"message\": \"Hello2\" }\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ FLAGS_json_input = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_input = false;
+
+ // Expected output: "message: \"Hello0Hello2\""
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"Hello0Hello2\""));
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandWithTimeoutDeadlineSet) {
+ // Test input "grpc_cli call CheckDeadlineSet --timeout=5000.25"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "CheckDeadlineSet"};
+
+ // Set timeout to 5000.25 seconds
+ FLAGS_timeout = 5000.25;
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: "message: "true"", deadline set
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"true\""));
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandWithTimeoutDeadlineUpperBound) {
+ // Test input "grpc_cli call CheckDeadlineUpperBound --timeout=900"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "CheckDeadlineUpperBound"};
+
+ // Set timeout to 900 seconds
+ FLAGS_timeout = 900;
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: "message: "true""
+ // deadline not greater than timeout + current time
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"true\""));
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandWithNegativeTimeoutValue) {
+ // Test input "grpc_cli call CheckDeadlineSet --timeout=-5"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "CheckDeadlineSet"};
+
+ // Set timeout to -5 (deadline not set)
+ FLAGS_timeout = -5;
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: "message: "false"", deadline not set
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"false\""));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandWithDefaultTimeoutValue) {
+ // Test input "grpc_cli call CheckDeadlineSet --timeout=-1"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "CheckDeadlineSet"};
+
+ // Set timeout to -1 (default value, deadline not set)
+ FLAGS_timeout = -1;
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: "message: "false"", deadline not set
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"false\""));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandResponseStream) {
+ // Test input: grpc_cli call localhost:<port> ResponseStream "message:
+ // 'Hello'"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "ResponseStream", "message: 'Hello'"};
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: "message: \"Hello{n}\""
+ for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) {
+ TString expected_response_text =
+ "message: \"Hello" + ToString(i) + "\"\n";
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ expected_response_text.c_str()));
+ }
+
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+
+ // Expected output: "{\n \"message\": \"Hello{n}\"\n}\n"
+ for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) {
+ TString expected_response_text =
+ "{\n \"message\": \"Hello" + ToString(i) + "\"\n}\n";
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ expected_response_text.c_str()));
+ }
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandBidiStream) {
+ // Test input: grpc_cli call localhost:<port> BidiStream "message: 'Hello0'"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "BidiStream", "message: 'Hello0'"};
+
+ // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: "message: \"Hello0\"\nmessage: \"Hello1\"\nmessage:
+ // \"Hello2\"\n\n"
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "message: \"Hello0\"\nmessage: "
+ "\"Hello1\"\nmessage: \"Hello2\"\n"));
+ std::cin.rdbuf(orig);
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandBidiStreamWithBadRequest) {
+ // Test input: grpc_cli call localhost:<port> BidiStream "message: 'Hello0'"
+ std::stringstream output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(),
+ "BidiStream", "message: 'Hello0'"};
+
+  // Mock std::cin input "message: 1.0\n\n message: 'Hello2'\n\n"
+ std::streambuf* orig = std::cin.rdbuf();
+ std::istringstream ss("message: 1.0\n\n message: 'Hello2'\n\n");
+ std::cin.rdbuf(ss.rdbuf());
+
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+  // Expected output: "message: \"Hello0\"\nmessage: \"Hello2\"\n" (the
+  // malformed request "message: 1.0" is dropped).
+ EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(),
+ "message: \"Hello0\"\nmessage: \"Hello2\"\n"));
+ std::cin.rdbuf(orig);
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, ParseCommand) {
+ // Test input "grpc_cli parse localhost:<port> grpc.testing.EchoResponse
+ // ECHO_RESPONSE_MESSAGE"
+ std::stringstream output_stream;
+ std::stringstream binary_output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "parse", server_address.c_str(),
+ "grpc.testing.EchoResponse",
+ ECHO_RESPONSE_MESSAGE_TEXT_FORMAT};
+
+ FLAGS_binary_input = false;
+ FLAGS_binary_output = false;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: ECHO_RESPONSE_MESSAGE_TEXT_FORMAT
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(),
+ ECHO_RESPONSE_MESSAGE_TEXT_FORMAT));
+
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+
+ // Expected output: ECHO_RESPONSE_MESSAGE_JSON_FORMAT
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(),
+ ECHO_RESPONSE_MESSAGE_JSON_FORMAT));
+
+ // Parse text message to binary message and then parse it back to text message
+ output_stream.str(TString());
+ output_stream.clear();
+ FLAGS_binary_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ TString binary_data = output_stream.str();
+ output_stream.str(TString());
+ output_stream.clear();
+ argv[4] = binary_data.c_str();
+ FLAGS_binary_input = true;
+ FLAGS_binary_output = false;
+ EXPECT_TRUE(0 == GrpcToolMainLib(5, argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: ECHO_RESPONSE_MESSAGE
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(),
+ ECHO_RESPONSE_MESSAGE_TEXT_FORMAT));
+
+ FLAGS_binary_input = false;
+ FLAGS_binary_output = false;
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, ParseCommandJsonFormat) {
+ // Test input "grpc_cli parse localhost:<port> grpc.testing.EchoResponse
+ // ECHO_RESPONSE_MESSAGE_JSON_FORMAT"
+ std::stringstream output_stream;
+ std::stringstream binary_output_stream;
+
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "parse", server_address.c_str(),
+ "grpc.testing.EchoResponse",
+ ECHO_RESPONSE_MESSAGE_JSON_FORMAT};
+
+ FLAGS_json_input = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+
+ // Expected output: ECHO_RESPONSE_MESSAGE_TEXT_FORMAT
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(),
+ ECHO_RESPONSE_MESSAGE_TEXT_FORMAT));
+
+ // with json_output
+ output_stream.str(TString());
+ output_stream.clear();
+
+ FLAGS_json_output = true;
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_json_output = false;
+ FLAGS_json_input = false;
+
+ // Expected output: ECHO_RESPONSE_MESSAGE_JSON_FORMAT
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(),
+ ECHO_RESPONSE_MESSAGE_JSON_FORMAT));
+
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, TooFewArguments) {
+ // Test input "grpc_cli call Echo"
+ std::stringstream output_stream;
+ const char* argv[] = {"grpc_cli", "call", "Echo"};
+
+ // Exit with 1
+ EXPECT_EXIT(
+ GrpcToolMainLib(
+ ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream, std::placeholders::_1)),
+ ::testing::ExitedWithCode(1), ".*Wrong number of arguments for call.*");
+ // No output
+ EXPECT_TRUE(0 == output_stream.tellp());
+}
+
+TEST_F(GrpcToolTest, TooManyArguments) {
+ // Test input "grpc_cli call localhost:<port> Echo Echo "message: 'Hello'"
+ std::stringstream output_stream;
+ const char* argv[] = {"grpc_cli", "call", "localhost:10000",
+ "Echo", "Echo", "message: 'Hello'"};
+
+ // Exit with 1
+ EXPECT_EXIT(
+ GrpcToolMainLib(
+ ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream, std::placeholders::_1)),
+ ::testing::ExitedWithCode(1), ".*Wrong number of arguments for call.*");
+ // No output
+ EXPECT_TRUE(0 == output_stream.tellp());
+}
+
+TEST_F(GrpcToolTest, CallCommandWithMetadata) {
+ // Test input "grpc_cli call localhost:<port> Echo "message: 'Hello'"
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo",
+ "message: 'Hello'"};
+
+ {
+ std::stringstream output_stream;
+ FLAGS_metadata = "key0:val0:key1:valq:key2:val2";
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv,
+ TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: "message: \"Hello\""
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"Hello\""));
+ }
+
+ {
+ std::stringstream output_stream;
+ FLAGS_metadata = "key:val\\:val";
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv,
+ TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: "message: \"Hello\""
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"Hello\""));
+ }
+
+ {
+ std::stringstream output_stream;
+ FLAGS_metadata = "key:val\\\\val";
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv,
+ TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ // Expected output: "message: \"Hello\""
+ EXPECT_TRUE(nullptr !=
+ strstr(output_stream.str().c_str(), "message: \"Hello\""));
+ }
+
+ FLAGS_metadata = "";
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, CallCommandWithBadMetadata) {
+ // Test input "grpc_cli call localhost:10000 Echo "message: 'Hello'"
+ const char* argv[] = {"grpc_cli", "call", "localhost:10000",
+ "grpc.testing.EchoTestService.Echo",
+ "message: 'Hello'"};
+ FLAGS_protofiles = "src/proto/grpc/testing/echo.proto";
+ char* test_srcdir = gpr_getenv("TEST_SRCDIR");
+ if (test_srcdir != nullptr) {
+ FLAGS_proto_path = test_srcdir + TString("/com_github_grpc_grpc");
+ }
+
+ {
+ std::stringstream output_stream;
+ FLAGS_metadata = "key0:val0:key1";
+ // Exit with 1
+ EXPECT_EXIT(
+ GrpcToolMainLib(
+ ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream, std::placeholders::_1)),
+ ::testing::ExitedWithCode(1), ".*Failed to parse metadata flag.*");
+ }
+
+ {
+ std::stringstream output_stream;
+ FLAGS_metadata = "key:val\\val";
+ // Exit with 1
+ EXPECT_EXIT(
+ GrpcToolMainLib(
+ ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream, std::placeholders::_1)),
+ ::testing::ExitedWithCode(1), ".*Failed to parse metadata flag.*");
+ }
+
+ FLAGS_metadata = "";
+ FLAGS_protofiles = "";
+
+ gpr_free(test_srcdir);
+}
+
+TEST_F(GrpcToolTest, ListCommand_OverrideSslHostName) {
+ const TString server_address = SetUpServer(true);
+
+ // Test input "grpc_cli ls localhost:<port> --channel_creds_type=ssl
+ // --ssl_target=z.test.google.fr"
+ std::stringstream output_stream;
+ const char* argv[] = {"grpc_cli", "ls", server_address.c_str()};
+ FLAGS_l = false;
+ FLAGS_channel_creds_type = "ssl";
+ FLAGS_ssl_target = "z.test.google.fr";
+ EXPECT_TRUE(
+ 0 == GrpcToolMainLib(
+ ArraySize(argv), argv, TestCliCredentials(true),
+ std::bind(PrintStream, &output_stream, std::placeholders::_1)));
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(),
+ "grpc.testing.EchoTestService\n"
+ "grpc.reflection.v1alpha.ServerReflection\n"));
+
+ FLAGS_channel_creds_type = "";
+ FLAGS_ssl_target = "";
+ ShutdownServer();
+}
+
+TEST_F(GrpcToolTest, ConfiguringDefaultServiceConfig) {
+ // Test input "grpc_cli list localhost:<port>
+ // --default_service_config={\"loadBalancingConfig\":[{\"pick_first\":{}}]}"
+ std::stringstream output_stream;
+ const TString server_address = SetUpServer();
+ const char* argv[] = {"grpc_cli", "ls", server_address.c_str()};
+ // Just check that the tool is still operational when --default_service_config
+ // is configured. This particular service config is in reality redundant with
+ // the channel's default configuration.
+ FLAGS_l = false;
+ FLAGS_default_service_config =
+ "{\"loadBalancingConfig\":[{\"pick_first\":{}}]}";
+ EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(),
+ std::bind(PrintStream, &output_stream,
+ std::placeholders::_1)));
+ FLAGS_default_service_config = "";
+ EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(),
+ "grpc.testing.EchoTestService\n"
+ "grpc.reflection.v1alpha.ServerReflection\n"));
+ ShutdownServer();
+}
+
+} // namespace testing
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+ return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/util/metrics_server.cc b/contrib/libs/grpc/test/cpp/util/metrics_server.cc
new file mode 100644
index 0000000000..0493da053e
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/metrics_server.cc
@@ -0,0 +1,117 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *is % allowed in string
+ */
+
+#include "test/cpp/util/metrics_server.h"
+
+#include <grpc/support/log.h>
+#include <grpcpp/server.h>
+#include <grpcpp/server_builder.h>
+
+#include "src/proto/grpc/testing/metrics.grpc.pb.h"
+#include "src/proto/grpc/testing/metrics.pb.h"
+
+namespace grpc {
+namespace testing {
+
+QpsGauge::QpsGauge()
+ : start_time_(gpr_now(GPR_CLOCK_REALTIME)), num_queries_(0) {}
+
+void QpsGauge::Reset() {
+ std::lock_guard<std::mutex> lock(num_queries_mu_);
+ num_queries_ = 0;
+ start_time_ = gpr_now(GPR_CLOCK_REALTIME);
+}
+
+void QpsGauge::Incr() {
+ std::lock_guard<std::mutex> lock(num_queries_mu_);
+ num_queries_++;
+}
+
+long QpsGauge::Get() {
+ std::lock_guard<std::mutex> lock(num_queries_mu_);
+ gpr_timespec time_diff =
+ gpr_time_sub(gpr_now(GPR_CLOCK_REALTIME), start_time_);
+ long duration_secs = time_diff.tv_sec > 0 ? time_diff.tv_sec : 1;
+ return num_queries_ / duration_secs;
+}
+
+grpc::Status MetricsServiceImpl::GetAllGauges(
+ ServerContext* /*context*/, const EmptyMessage* /*request*/,
+ ServerWriter<GaugeResponse>* writer) {
+ gpr_log(GPR_DEBUG, "GetAllGauges called");
+
+ std::lock_guard<std::mutex> lock(mu_);
+ for (auto it = qps_gauges_.begin(); it != qps_gauges_.end(); it++) {
+ GaugeResponse resp;
+ resp.set_name(it->first); // Gauge name
+ resp.set_long_value(it->second->Get()); // Gauge value
+ writer->Write(resp);
+ }
+
+ return Status::OK;
+}
+
+grpc::Status MetricsServiceImpl::GetGauge(ServerContext* /*context*/,
+ const GaugeRequest* request,
+ GaugeResponse* response) {
+ std::lock_guard<std::mutex> lock(mu_);
+
+ const auto it = qps_gauges_.find(request->name());
+ if (it != qps_gauges_.end()) {
+ response->set_name(it->first);
+ response->set_long_value(it->second->Get());
+ }
+
+ return Status::OK;
+}
+
+std::shared_ptr<QpsGauge> MetricsServiceImpl::CreateQpsGauge(
+ const TString& name, bool* already_present) {
+ std::lock_guard<std::mutex> lock(mu_);
+
+ std::shared_ptr<QpsGauge> qps_gauge(new QpsGauge());
+ const auto p = qps_gauges_.insert(std::make_pair(name, qps_gauge));
+
+ // p.first is an iterator pointing to <name, shared_ptr<QpsGauge>> pair.
+ // p.second is a boolean which is set to 'true' if the QpsGauge is
+  // successfully inserted in the qps_gauges_ map and 'false' if it is already
+ // present in the map
+ *already_present = !p.second;
+ return p.first->second;
+}
+
+// Starts the metrics server and returns the grpc::Server instance. Call Wait()
+// on the returned server instance.
+std::unique_ptr<grpc::Server> MetricsServiceImpl::StartServer(int port) {
+ gpr_log(GPR_INFO, "Building metrics server..");
+
+ const TString address = "0.0.0.0:" + ToString(port);
+
+ ServerBuilder builder;
+ builder.AddListeningPort(address, grpc::InsecureServerCredentials());
+ builder.RegisterService(this);
+
+ std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+ gpr_log(GPR_INFO, "Metrics server %s started. Ready to receive requests..",
+ address.c_str());
+
+ return server;
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/metrics_server.h b/contrib/libs/grpc/test/cpp/util/metrics_server.h
new file mode 100644
index 0000000000..10ffa7b4dd
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/metrics_server.h
@@ -0,0 +1,98 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *is % allowed in string
+ */
+#ifndef GRPC_TEST_CPP_METRICS_SERVER_H
+#define GRPC_TEST_CPP_METRICS_SERVER_H
+
+#include <map>
+#include <mutex>
+
+#include <grpcpp/server.h>
+
+#include "src/proto/grpc/testing/metrics.grpc.pb.h"
+#include "src/proto/grpc/testing/metrics.pb.h"
+
+/*
+ * This implements a Metrics server defined in
+ * src/proto/grpc/testing/metrics.proto. Any
+ * test service can use this to export Metrics (TODO (sreek): Only Gauges for
+ * now).
+ *
+ * Example:
+ * MetricsServiceImpl metricsImpl;
+ * ..
+ * // Create QpsGauge(s). Note: QpsGauges can be created even after calling
+ * // 'StartServer'.
+ * QpsGauge qps_gauge1 = metricsImpl.CreateQpsGauge("foo", is_present);
+ * // qps_gauge1 can now be used anywhere in the program by first making a
+ * // one-time call qps_gauge1.Reset() and then calling qps_gauge1.Incr()
+ * // every time to increment a query counter
+ *
+ * ...
+ * // Create the metrics server
+ * std::unique_ptr<grpc::Server> server = metricsImpl.StartServer(port);
+ * server->Wait(); // Note: This is blocking.
+ */
+namespace grpc {
+namespace testing {
+
+class QpsGauge {
+ public:
+ QpsGauge();
+
+ // Initialize the internal timer and reset the query count to 0
+ void Reset();
+
+ // Increment the query count by 1
+ void Incr();
+
+ // Return the current qps (i.e query count divided by the time since this
+ // QpsGauge object created (or Reset() was called))
+ long Get();
+
+ private:
+ gpr_timespec start_time_;
+ long num_queries_;
+ std::mutex num_queries_mu_;
+};
+
+class MetricsServiceImpl final : public MetricsService::Service {
+ public:
+ grpc::Status GetAllGauges(ServerContext* context, const EmptyMessage* request,
+ ServerWriter<GaugeResponse>* writer) override;
+
+ grpc::Status GetGauge(ServerContext* context, const GaugeRequest* request,
+ GaugeResponse* response) override;
+
+ // Create a QpsGauge with name 'name'. is_present is set to true if the Gauge
+ // is already present in the map.
+ // NOTE: CreateQpsGauge can be called anytime (i.e before or after calling
+ // StartServer).
+ std::shared_ptr<QpsGauge> CreateQpsGauge(const TString& name,
+ bool* already_present);
+
+ std::unique_ptr<grpc::Server> StartServer(int port);
+
+ private:
+ std::map<string, std::shared_ptr<QpsGauge>> qps_gauges_;
+ std::mutex mu_;
+};
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_METRICS_SERVER_H
diff --git a/contrib/libs/grpc/test/cpp/util/proto_file_parser.cc b/contrib/libs/grpc/test/cpp/util/proto_file_parser.cc
new file mode 100644
index 0000000000..b0912a712c
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/proto_file_parser.cc
@@ -0,0 +1,323 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/proto_file_parser.h"
+
+#include <algorithm>
+#include <iostream>
+#include <sstream>
+#include <unordered_set>
+
+#include <grpcpp/support/config.h>
+
+namespace grpc {
+namespace testing {
+namespace {
+
+// Match the user input method string to the full_name from method descriptor.
+bool MethodNameMatch(const TString& full_name, const TString& input) {
+ TString clean_input = input;
+ std::replace(clean_input.begin(), clean_input.vend(), '/', '.');
+ if (clean_input.size() > full_name.size()) {
+ return false;
+ }
+ return full_name.compare(full_name.size() - clean_input.size(),
+ clean_input.size(), clean_input) == 0;
+}
+} // namespace
+
+class ErrorPrinter : public protobuf::compiler::MultiFileErrorCollector {
+ public:
+ explicit ErrorPrinter(ProtoFileParser* parser) : parser_(parser) {}
+
+ void AddError(const google::protobuf::string& filename, int line, int column,
+ const google::protobuf::string& message) override {
+ std::ostringstream oss;
+ oss << "error " << filename << " " << line << " " << column << " "
+ << message << "\n";
+ parser_->LogError(oss.str());
+ }
+
+ void AddWarning(const google::protobuf::string& filename, int line, int column,
+ const google::protobuf::string& message) override {
+ std::cerr << "warning " << filename << " " << line << " " << column << " "
+ << message << std::endl;
+ }
+
+ private:
+ ProtoFileParser* parser_; // not owned
+};
+
+ProtoFileParser::ProtoFileParser(const std::shared_ptr<grpc::Channel>& channel,
+ const TString& proto_path,
+ const TString& protofiles)
+ : has_error_(false),
+ dynamic_factory_(new protobuf::DynamicMessageFactory()) {
+ std::vector<TString> service_list;
+ if (channel) {
+ reflection_db_.reset(new grpc::ProtoReflectionDescriptorDatabase(channel));
+ reflection_db_->GetServices(&service_list);
+ }
+
+ std::unordered_set<TString> known_services;
+ if (!protofiles.empty()) {
+ source_tree_.MapPath("", google::protobuf::string(proto_path));
+ error_printer_.reset(new ErrorPrinter(this));
+ importer_.reset(
+ new protobuf::compiler::Importer(&source_tree_, error_printer_.get()));
+
+ std::string file_name;
+ std::stringstream ss(protofiles);
+ while (std::getline(ss, file_name, ',')) {
+ const auto* file_desc = importer_->Import(google::protobuf::string(file_name.c_str()));
+ if (file_desc) {
+ for (int i = 0; i < file_desc->service_count(); i++) {
+ service_desc_list_.push_back(file_desc->service(i));
+ known_services.insert(file_desc->service(i)->full_name());
+ }
+ } else {
+ std::cerr << file_name << " not found" << std::endl;
+ }
+ }
+
+ file_db_.reset(new protobuf::DescriptorPoolDatabase(*importer_->pool()));
+ }
+
+ if (!reflection_db_ && !file_db_) {
+ LogError("No available proto database");
+ return;
+ }
+
+ if (!reflection_db_) {
+ desc_db_ = std::move(file_db_);
+ } else if (!file_db_) {
+ desc_db_ = std::move(reflection_db_);
+ } else {
+ desc_db_.reset(new protobuf::MergedDescriptorDatabase(reflection_db_.get(),
+ file_db_.get()));
+ }
+
+ desc_pool_.reset(new protobuf::DescriptorPool(desc_db_.get()));
+
+ for (auto it = service_list.begin(); it != service_list.end(); it++) {
+ if (known_services.find(*it) == known_services.end()) {
+ if (const protobuf::ServiceDescriptor* service_desc =
+ desc_pool_->FindServiceByName(google::protobuf::string(*it))) {
+ service_desc_list_.push_back(service_desc);
+ known_services.insert(*it);
+ }
+ }
+ }
+}
+
+ProtoFileParser::~ProtoFileParser() {}
+
+TString ProtoFileParser::GetFullMethodName(const TString& method) {
+ has_error_ = false;
+
+ if (known_methods_.find(method) != known_methods_.end()) {
+ return known_methods_[method];
+ }
+
+ const protobuf::MethodDescriptor* method_descriptor = nullptr;
+ for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
+ it++) {
+ const auto* service_desc = *it;
+ for (int j = 0; j < service_desc->method_count(); j++) {
+ const auto* method_desc = service_desc->method(j);
+ if (MethodNameMatch(method_desc->full_name(), method)) {
+ if (method_descriptor) {
+ std::ostringstream error_stream;
+ error_stream << "Ambiguous method names: ";
+ error_stream << method_descriptor->full_name() << " ";
+ error_stream << method_desc->full_name();
+ LogError(error_stream.str());
+ }
+ method_descriptor = method_desc;
+ }
+ }
+ }
+ if (!method_descriptor) {
+ LogError("Method name not found");
+ }
+ if (has_error_) {
+ return "";
+ }
+
+ known_methods_[method] = method_descriptor->full_name();
+
+ return method_descriptor->full_name();
+}
+
+TString ProtoFileParser::GetFormattedMethodName(const TString& method) {
+ has_error_ = false;
+ TString formatted_method_name = GetFullMethodName(method);
+ if (has_error_) {
+ return "";
+ }
+ size_t last_dot = formatted_method_name.find_last_of('.');
+ if (last_dot != TString::npos) {
+ formatted_method_name[last_dot] = '/';
+ }
+ formatted_method_name.insert(formatted_method_name.begin(), '/');
+ return formatted_method_name;
+}
+
+TString ProtoFileParser::GetMessageTypeFromMethod(const TString& method,
+ bool is_request) {
+ has_error_ = false;
+ TString full_method_name = GetFullMethodName(method);
+ if (has_error_) {
+ return "";
+ }
+ const protobuf::MethodDescriptor* method_desc =
+ desc_pool_->FindMethodByName(google::protobuf::string(full_method_name));
+ if (!method_desc) {
+ LogError("Method not found");
+ return "";
+ }
+
+ return is_request ? method_desc->input_type()->full_name()
+ : method_desc->output_type()->full_name();
+}
+
+bool ProtoFileParser::IsStreaming(const TString& method, bool is_request) {
+ has_error_ = false;
+
+ TString full_method_name = GetFullMethodName(method);
+ if (has_error_) {
+ return false;
+ }
+
+ const protobuf::MethodDescriptor* method_desc =
+ desc_pool_->FindMethodByName(google::protobuf::string(full_method_name));
+ if (!method_desc) {
+ LogError("Method not found");
+ return false;
+ }
+
+ return is_request ? method_desc->client_streaming()
+ : method_desc->server_streaming();
+}
+
+TString ProtoFileParser::GetSerializedProtoFromMethod(
+ const TString& method, const TString& formatted_proto,
+ bool is_request, bool is_json_format) {
+ has_error_ = false;
+ TString message_type_name = GetMessageTypeFromMethod(method, is_request);
+ if (has_error_) {
+ return "";
+ }
+ return GetSerializedProtoFromMessageType(message_type_name, formatted_proto,
+ is_json_format);
+}
+
+TString ProtoFileParser::GetFormattedStringFromMethod(
+ const TString& method, const TString& serialized_proto,
+ bool is_request, bool is_json_format) {
+ has_error_ = false;
+ TString message_type_name = GetMessageTypeFromMethod(method, is_request);
+ if (has_error_) {
+ return "";
+ }
+ return GetFormattedStringFromMessageType(message_type_name, serialized_proto,
+ is_json_format);
+}
+
+TString ProtoFileParser::GetSerializedProtoFromMessageType(
+ const TString& message_type_name, const TString& formatted_proto,
+ bool is_json_format) {
+ has_error_ = false;
+ google::protobuf::string serialized;
+ const protobuf::Descriptor* desc =
+ desc_pool_->FindMessageTypeByName(google::protobuf::string(message_type_name));
+ if (!desc) {
+ LogError("Message type not found");
+ return "";
+ }
+ std::unique_ptr<grpc::protobuf::Message> msg(
+ dynamic_factory_->GetPrototype(desc)->New());
+ bool ok;
+ if (is_json_format) {
+ ok = grpc::protobuf::json::JsonStringToMessage(google::protobuf::string(formatted_proto), msg.get())
+ .ok();
+ if (!ok) {
+ LogError("Failed to convert json format to proto.");
+ return "";
+ }
+ } else {
+ ok = protobuf::TextFormat::ParseFromString(google::protobuf::string(formatted_proto), msg.get());
+ if (!ok) {
+ LogError("Failed to convert text format to proto.");
+ return "";
+ }
+ }
+
+ ok = msg->SerializeToString(&serialized);
+ if (!ok) {
+ LogError("Failed to serialize proto.");
+ return "";
+ }
+ return serialized;
+}
+
+TString ProtoFileParser::GetFormattedStringFromMessageType(
+ const TString& message_type_name, const TString& serialized_proto,
+ bool is_json_format) {
+ has_error_ = false;
+ const protobuf::Descriptor* desc =
+ desc_pool_->FindMessageTypeByName(google::protobuf::string(message_type_name));
+ if (!desc) {
+ LogError("Message type not found");
+ return "";
+ }
+ std::unique_ptr<grpc::protobuf::Message> msg(
+ dynamic_factory_->GetPrototype(desc)->New());
+ if (!msg->ParseFromString(google::protobuf::string(serialized_proto))) {
+ LogError("Failed to deserialize proto.");
+ return "";
+ }
+ google::protobuf::string formatted_string;
+
+ if (is_json_format) {
+ grpc::protobuf::json::JsonPrintOptions jsonPrintOptions;
+ jsonPrintOptions.add_whitespace = true;
+ if (!grpc::protobuf::json::MessageToJsonString(
+ *msg.get(), &formatted_string, jsonPrintOptions)
+ .ok()) {
+ LogError("Failed to print proto message to json format");
+ return "";
+ }
+ } else {
+ if (!protobuf::TextFormat::PrintToString(*msg.get(), &formatted_string)) {
+ LogError("Failed to print proto message to text format");
+ return "";
+ }
+ }
+ return formatted_string;
+}
+
+void ProtoFileParser::LogError(const TString& error_msg) {
+ if (!error_msg.empty()) {
+ std::cerr << error_msg << std::endl;
+ }
+ has_error_ = true;
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/proto_file_parser.h b/contrib/libs/grpc/test/cpp/util/proto_file_parser.h
new file mode 100644
index 0000000000..c0445641c7
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/proto_file_parser.h
@@ -0,0 +1,129 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_PROTO_FILE_PARSER_H
+#define GRPC_TEST_CPP_UTIL_PROTO_FILE_PARSER_H
+
+#include <memory>
+
+#include <grpcpp/channel.h>
+
+#include "test/cpp/util/config_grpc_cli.h"
+#include "test/cpp/util/proto_reflection_descriptor_database.h"
+
+namespace grpc {
+namespace testing {
+class ErrorPrinter;
+
+// Find method and associated request/response types.
+class ProtoFileParser {
+ public:
+ // The parser will search proto files using the server reflection service
+ // provided on the given channel. The given protofiles in a source tree rooted
+ // from proto_path will also be searched.
+ ProtoFileParser(const std::shared_ptr<grpc::Channel>& channel,
+ const TString& proto_path, const TString& protofiles);
+
+ ~ProtoFileParser();
+
+ // The input method name in the following four functions could be a partial
+ // string such as Service.Method or even just Method. It will log an error if
+ // there is ambiguity.
+ // Full method name is in the form of Service.Method, it's good to be used in
+ // descriptor database queries.
+ TString GetFullMethodName(const TString& method);
+
+ // Formatted method name is in the form of /Service/Method, it's good to be
+ // used as the argument of Stub::Call()
+ TString GetFormattedMethodName(const TString& method);
+
+ /// Converts a text or json string to its binary proto representation for the
+ /// given method's input or return type.
+ /// \param method the name of the method (does not need to be fully qualified
+ /// name)
+ /// \param formatted_proto the text- or json-formatted proto string
+ /// \param is_request if \c true the resolved type is that of the input
+ /// parameter of the method, otherwise it is the output type
+ /// \param is_json_format if \c true the \c formatted_proto is treated as a
+ /// json-formatted proto, otherwise it is treated as a text-formatted
+ /// proto
+ /// \return the serialised binary proto representation of \c formatted_proto
+ TString GetSerializedProtoFromMethod(const TString& method,
+ const TString& formatted_proto,
+ bool is_request,
+ bool is_json_format);
+
+ /// Converts a text or json string to its proto representation for the given
+ /// message type.
+ /// \param formatted_proto the text- or json-formatted proto string
+ /// \return the serialised binary proto representation of \c formatted_proto
+ TString GetSerializedProtoFromMessageType(
+ const TString& message_type_name, const TString& formatted_proto,
+ bool is_json_format);
+
+ /// Converts a binary proto string to its text or json string representation
+ /// for the given method's input or return type.
+ /// \param method the name of the method (does not need to be a fully
+ /// qualified name)
+ /// \param the serialised binary proto representation of type
+ /// \c message_type_name
+ /// \return the text- or json-formatted proto string of \c serialized_proto
+ TString GetFormattedStringFromMethod(const TString& method,
+ const TString& serialized_proto,
+ bool is_request,
+ bool is_json_format);
+
+ /// Converts a binary proto string to its text or json string representation
+ /// for the given message type.
+ /// \param the serialised binary proto representation of type
+ /// \c message_type_name
+ /// \return the text- or json-formatted proto string of \c serialized_proto
+ TString GetFormattedStringFromMessageType(
+ const TString& message_type_name, const TString& serialized_proto,
+ bool is_json_format);
+
+ bool IsStreaming(const TString& method, bool is_request);
+
+ bool HasError() const { return has_error_; }
+
+ void LogError(const TString& error_msg);
+
+ private:
+ TString GetMessageTypeFromMethod(const TString& method,
+ bool is_request);
+
+ bool has_error_;
+ TString request_text_;
+ protobuf::compiler::DiskSourceTree source_tree_;
+ std::unique_ptr<ErrorPrinter> error_printer_;
+ std::unique_ptr<protobuf::compiler::Importer> importer_;
+ std::unique_ptr<grpc::ProtoReflectionDescriptorDatabase> reflection_db_;
+ std::unique_ptr<protobuf::DescriptorPoolDatabase> file_db_;
+ std::unique_ptr<protobuf::DescriptorDatabase> desc_db_;
+ std::unique_ptr<protobuf::DescriptorPool> desc_pool_;
+ std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_;
+ std::unique_ptr<grpc::protobuf::Message> request_prototype_;
+ std::unique_ptr<grpc::protobuf::Message> response_prototype_;
+ std::unordered_map<TString, TString> known_methods_;
+ std::vector<const protobuf::ServiceDescriptor*> service_desc_list_;
+};
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_PROTO_FILE_PARSER_H
diff --git a/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.cc b/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.cc
new file mode 100644
index 0000000000..27a4c1e4cf
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.cc
@@ -0,0 +1,333 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/proto_reflection_descriptor_database.h"
+
+#include <vector>
+
+#include <grpc/support/log.h>
+
+using grpc::reflection::v1alpha::ErrorResponse;
+using grpc::reflection::v1alpha::ListServiceResponse;
+using grpc::reflection::v1alpha::ServerReflection;
+using grpc::reflection::v1alpha::ServerReflectionRequest;
+using grpc::reflection::v1alpha::ServerReflectionResponse;
+
+namespace grpc {
+
+ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
+ std::unique_ptr<ServerReflection::Stub> stub)
+ : stub_(std::move(stub)) {}
+
+ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
+ const std::shared_ptr<grpc::Channel>& channel)
+ : stub_(ServerReflection::NewStub(channel)) {}
+
+ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() {
+ if (stream_) {
+ stream_->WritesDone();
+ Status status = stream_->Finish();
+ if (!status.ok()) {
+ if (status.error_code() == StatusCode::UNIMPLEMENTED) {
+ fprintf(stderr,
+ "Reflection request not implemented; "
+ "is the ServerReflection service enabled?\n");
+ } else {
+ fprintf(stderr,
+ "ServerReflectionInfo rpc failed. Error code: %d, message: %s, "
+ "debug info: %s\n",
+ static_cast<int>(status.error_code()),
+ status.error_message().c_str(),
+ ctx_.debug_error_string().c_str());
+ }
+ }
+ }
+}
+
+bool ProtoReflectionDescriptorDatabase::FindFileByName(
+ const google::protobuf::string& filename, protobuf::FileDescriptorProto* output) {
+ if (cached_db_.FindFileByName(filename, output)) {
+ return true;
+ }
+
+ if (known_files_.find(filename) != known_files_.end()) {
+ return false;
+ }
+
+ ServerReflectionRequest request;
+ request.set_file_by_filename(filename);
+ ServerReflectionResponse response;
+
+ if (!DoOneRequest(request, response)) {
+ return false;
+ }
+
+ if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
+ AddFileFromResponse(response.file_descriptor_response());
+ } else if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
+ const ErrorResponse& error = response.error_response();
+ if (error.error_code() == StatusCode::NOT_FOUND) {
+ gpr_log(GPR_INFO, "NOT_FOUND from server for FindFileByName(%s)",
+ filename.c_str());
+ } else {
+ gpr_log(GPR_INFO,
+ "Error on FindFileByName(%s)\n\tError code: %d\n"
+ "\tError Message: %s",
+ filename.c_str(), error.error_code(),
+ error.error_message().c_str());
+ }
+ } else {
+ gpr_log(
+ GPR_INFO,
+ "Error on FindFileByName(%s) response type\n"
+ "\tExpecting: %d\n\tReceived: %d",
+ filename.c_str(),
+ ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
+ response.message_response_case());
+ }
+
+ return cached_db_.FindFileByName(filename, output);
+}
+
+bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol(
+ const google::protobuf::string& symbol_name, protobuf::FileDescriptorProto* output) {
+ if (cached_db_.FindFileContainingSymbol(symbol_name, output)) {
+ return true;
+ }
+
+ if (missing_symbols_.find(symbol_name) != missing_symbols_.end()) {
+ return false;
+ }
+
+ ServerReflectionRequest request;
+ request.set_file_containing_symbol(symbol_name);
+ ServerReflectionResponse response;
+
+ if (!DoOneRequest(request, response)) {
+ return false;
+ }
+
+ if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
+ AddFileFromResponse(response.file_descriptor_response());
+ } else if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
+ const ErrorResponse& error = response.error_response();
+ if (error.error_code() == StatusCode::NOT_FOUND) {
+ missing_symbols_.insert(symbol_name);
+ gpr_log(GPR_INFO,
+ "NOT_FOUND from server for FindFileContainingSymbol(%s)",
+ symbol_name.c_str());
+ } else {
+ gpr_log(GPR_INFO,
+ "Error on FindFileContainingSymbol(%s)\n"
+ "\tError code: %d\n\tError Message: %s",
+ symbol_name.c_str(), error.error_code(),
+ error.error_message().c_str());
+ }
+ } else {
+ gpr_log(
+ GPR_INFO,
+ "Error on FindFileContainingSymbol(%s) response type\n"
+ "\tExpecting: %d\n\tReceived: %d",
+ symbol_name.c_str(),
+ ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
+ response.message_response_case());
+ }
+ return cached_db_.FindFileContainingSymbol(symbol_name, output);
+}
+
+bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension(
+ const google::protobuf::string& containing_type, int field_number,
+ protobuf::FileDescriptorProto* output) {
+ if (cached_db_.FindFileContainingExtension(containing_type, field_number,
+ output)) {
+ return true;
+ }
+
+ if (missing_extensions_.find(containing_type) != missing_extensions_.end() &&
+ missing_extensions_[containing_type].find(field_number) !=
+ missing_extensions_[containing_type].end()) {
+ gpr_log(GPR_INFO, "nested map.");
+ return false;
+ }
+
+ ServerReflectionRequest request;
+ request.mutable_file_containing_extension()->set_containing_type(
+ containing_type);
+ request.mutable_file_containing_extension()->set_extension_number(
+ field_number);
+ ServerReflectionResponse response;
+
+ if (!DoOneRequest(request, response)) {
+ return false;
+ }
+
+ if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
+ AddFileFromResponse(response.file_descriptor_response());
+ } else if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
+ const ErrorResponse& error = response.error_response();
+ if (error.error_code() == StatusCode::NOT_FOUND) {
+ if (missing_extensions_.find(containing_type) ==
+ missing_extensions_.end()) {
+ missing_extensions_[containing_type] = {};
+ }
+ missing_extensions_[containing_type].insert(field_number);
+ gpr_log(GPR_INFO,
+ "NOT_FOUND from server for FindFileContainingExtension(%s, %d)",
+ containing_type.c_str(), field_number);
+ } else {
+ gpr_log(GPR_INFO,
+ "Error on FindFileContainingExtension(%s, %d)\n"
+ "\tError code: %d\n\tError Message: %s",
+ containing_type.c_str(), field_number, error.error_code(),
+ error.error_message().c_str());
+ }
+ } else {
+ gpr_log(
+ GPR_INFO,
+ "Error on FindFileContainingExtension(%s, %d) response type\n"
+ "\tExpecting: %d\n\tReceived: %d",
+ containing_type.c_str(), field_number,
+ ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
+ response.message_response_case());
+ }
+
+ return cached_db_.FindFileContainingExtension(containing_type, field_number,
+ output);
+}
+
+bool ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers(
+ const google::protobuf::string& extendee_type, std::vector<int>* output) {
+ if (cached_extension_numbers_.find(extendee_type) !=
+ cached_extension_numbers_.end()) {
+ *output = cached_extension_numbers_[extendee_type];
+ return true;
+ }
+
+ ServerReflectionRequest request;
+ request.set_all_extension_numbers_of_type(extendee_type);
+ ServerReflectionResponse response;
+
+ if (!DoOneRequest(request, response)) {
+ return false;
+ }
+
+ if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::
+ kAllExtensionNumbersResponse) {
+ auto number = response.all_extension_numbers_response().extension_number();
+ *output = std::vector<int>(number.begin(), number.end());
+ cached_extension_numbers_[extendee_type] = *output;
+ return true;
+ } else if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
+ const ErrorResponse& error = response.error_response();
+ if (error.error_code() == StatusCode::NOT_FOUND) {
+ gpr_log(GPR_INFO, "NOT_FOUND from server for FindAllExtensionNumbers(%s)",
+ extendee_type.c_str());
+ } else {
+ gpr_log(GPR_INFO,
+ "Error on FindAllExtensionNumbersExtension(%s)\n"
+ "\tError code: %d\n\tError Message: %s",
+ extendee_type.c_str(), error.error_code(),
+ error.error_message().c_str());
+ }
+ }
+ return false;
+}
+
+bool ProtoReflectionDescriptorDatabase::GetServices(
+ std::vector<TString>* output) {
+ ServerReflectionRequest request;
+ request.set_list_services("");
+ ServerReflectionResponse response;
+
+ if (!DoOneRequest(request, response)) {
+ return false;
+ }
+
+ if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kListServicesResponse) {
+ const ListServiceResponse& ls_response = response.list_services_response();
+ for (int i = 0; i < ls_response.service_size(); ++i) {
+ (*output).push_back(ls_response.service(i).name());
+ }
+ return true;
+ } else if (response.message_response_case() ==
+ ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
+ const ErrorResponse& error = response.error_response();
+ gpr_log(GPR_INFO,
+ "Error on GetServices()\n\tError code: %d\n"
+ "\tError Message: %s",
+ error.error_code(), error.error_message().c_str());
+ } else {
+ gpr_log(
+ GPR_INFO,
+ "Error on GetServices() response type\n\tExpecting: %d\n\tReceived: %d",
+ ServerReflectionResponse::MessageResponseCase::kListServicesResponse,
+ response.message_response_case());
+ }
+ return false;
+}
+
+const protobuf::FileDescriptorProto
+ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse(
+ const TString& byte_fd_proto) {
+ protobuf::FileDescriptorProto file_desc_proto;
+ file_desc_proto.ParseFromString(google::protobuf::string(byte_fd_proto));
+ return file_desc_proto;
+}
+
+void ProtoReflectionDescriptorDatabase::AddFileFromResponse(
+ const grpc::reflection::v1alpha::FileDescriptorResponse& response) {
+ for (int i = 0; i < response.file_descriptor_proto_size(); ++i) {
+ const protobuf::FileDescriptorProto file_proto =
+ ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i));
+ if (known_files_.find(file_proto.name()) == known_files_.end()) {
+ known_files_.insert(file_proto.name());
+ cached_db_.Add(file_proto);
+ }
+ }
+}
+
+const std::shared_ptr<ProtoReflectionDescriptorDatabase::ClientStream>
+ProtoReflectionDescriptorDatabase::GetStream() {
+ if (!stream_) {
+ stream_ = stub_->ServerReflectionInfo(&ctx_);
+ }
+ return stream_;
+}
+
+bool ProtoReflectionDescriptorDatabase::DoOneRequest(
+ const ServerReflectionRequest& request,
+ ServerReflectionResponse& response) {
+ bool success = false;
+ stream_mutex_.lock();
+ if (GetStream()->Write(request) && GetStream()->Read(&response)) {
+ success = true;
+ }
+ stream_mutex_.unlock();
+ return success;
+}
+
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.h b/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.h
new file mode 100644
index 0000000000..cdd6f0cccd
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.h
@@ -0,0 +1,111 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+#ifndef GRPC_TEST_CPP_PROTO_SERVER_REFLECTION_DATABSE_H
+#define GRPC_TEST_CPP_PROTO_SERVER_REFLECTION_DATABSE_H
+
+#include <mutex>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/impl/codegen/config_protobuf.h>
+#include "src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h"
+
+namespace grpc {
+
+// ProtoReflectionDescriptorDatabase takes a stub of ServerReflection and
+// provides the methods defined by DescriptorDatabase interfaces. It can be used
+// to feed a DescriptorPool instance.
+class ProtoReflectionDescriptorDatabase : public protobuf::DescriptorDatabase {
+ public:
+ explicit ProtoReflectionDescriptorDatabase(
+ std::unique_ptr<reflection::v1alpha::ServerReflection::Stub> stub);
+
+ explicit ProtoReflectionDescriptorDatabase(
+ const std::shared_ptr<grpc::Channel>& channel);
+
+ virtual ~ProtoReflectionDescriptorDatabase();
+
+ // The following four methods implement DescriptorDatabase interfaces.
+ //
+ // Find a file by file name. Fills in *output and returns true if found.
+ // Otherwise, returns false, leaving the contents of *output undefined.
+ bool FindFileByName(const google::protobuf::string& filename,
+ protobuf::FileDescriptorProto* output) override;
+
+ // Find the file that declares the given fully-qualified symbol name.
+ // If found, fills in *output and returns true, otherwise returns false
+ // and leaves *output undefined.
+ bool FindFileContainingSymbol(const google::protobuf::string& symbol_name,
+ protobuf::FileDescriptorProto* output) override;
+
+ // Find the file which defines an extension extending the given message type
+ // with the given field number. If found, fills in *output and returns true,
+ // otherwise returns false and leaves *output undefined. containing_type
+ // must be a fully-qualified type name.
+ bool FindFileContainingExtension(
+ const google::protobuf::string& containing_type, int field_number,
+ protobuf::FileDescriptorProto* output) override;
+
+ // Finds the tag numbers used by all known extensions of
+ // extendee_type, and appends them to output in an undefined
+ // order. This method is best-effort: it's not guaranteed that the
+ // database will find all extensions, and it's not guaranteed that
+ // FindFileContainingExtension will return true on all of the found
+ // numbers. Returns true if the search was successful, otherwise
+ // returns false and leaves output unchanged.
+ bool FindAllExtensionNumbers(const google::protobuf::string& extendee_type,
+ std::vector<int>* output) override;
+
+ // Provide a list of full names of registered services
+ bool GetServices(std::vector<TString>* output);
+
+ private:
+ typedef ClientReaderWriter<
+ grpc::reflection::v1alpha::ServerReflectionRequest,
+ grpc::reflection::v1alpha::ServerReflectionResponse>
+ ClientStream;
+
+ const protobuf::FileDescriptorProto ParseFileDescriptorProtoResponse(
+ const TString& byte_fd_proto);
+
+ void AddFileFromResponse(
+ const grpc::reflection::v1alpha::FileDescriptorResponse& response);
+
+ const std::shared_ptr<ClientStream> GetStream();
+
+ bool DoOneRequest(
+ const grpc::reflection::v1alpha::ServerReflectionRequest& request,
+ grpc::reflection::v1alpha::ServerReflectionResponse& response);
+
+ std::shared_ptr<ClientStream> stream_;
+ grpc::ClientContext ctx_;
+ std::unique_ptr<grpc::reflection::v1alpha::ServerReflection::Stub> stub_;
+ std::unordered_set<string> known_files_;
+ std::unordered_set<string> missing_symbols_;
+ std::unordered_map<string, std::unordered_set<int>> missing_extensions_;
+ std::unordered_map<string, std::vector<int>> cached_extension_numbers_;
+ std::mutex stream_mutex_;
+
+ protobuf::SimpleDescriptorDatabase cached_db_;
+};
+
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_METRICS_SERVER_H
diff --git a/contrib/libs/grpc/test/cpp/util/service_describer.cc b/contrib/libs/grpc/test/cpp/util/service_describer.cc
new file mode 100644
index 0000000000..2af1104b97
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/service_describer.cc
@@ -0,0 +1,92 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/service_describer.h"
+
+#include <iostream>
+#include <sstream>
+#include <util/generic/string.h>
+#include <vector>
+
+namespace grpc {
+namespace testing {
+
+TString DescribeServiceList(std::vector<TString> service_list,
+ grpc::protobuf::DescriptorPool& desc_pool) {
+ std::stringstream result;
+ for (auto it = service_list.begin(); it != service_list.end(); it++) {
+ auto const& service = *it;
+ const grpc::protobuf::ServiceDescriptor* service_desc =
+ desc_pool.FindServiceByName(google::protobuf::string(service));
+ if (service_desc != nullptr) {
+ result << DescribeService(service_desc);
+ }
+ }
+ return result.str();
+}
+
+TString DescribeService(const grpc::protobuf::ServiceDescriptor* service) {
+ TString result;
+ if (service->options().deprecated()) {
+ result.append("DEPRECATED\n");
+ }
+ result.append("filename: " + service->file()->name() + "\n");
+
+ TString package = service->full_name();
+ size_t pos = package.rfind("." + service->name());
+ if (pos != TString::npos) {
+ package.erase(pos);
+ result.append("package: " + package + ";\n");
+ }
+ result.append("service " + service->name() + " {\n");
+ for (int i = 0; i < service->method_count(); ++i) {
+ result.append(DescribeMethod(service->method(i)));
+ }
+ result.append("}\n\n");
+ return result;
+}
+
// Formats one rpc method as a single proto-IDL-style line, e.g.
// " rpc Foo(stream pkg.Req) returns (pkg.Resp) {}" — "stream " is inserted
// on either side depending on the client/server streaming flags.
TString DescribeMethod(const grpc::protobuf::MethodDescriptor* method) {
  std::stringstream result;
  result << " rpc " << method->name()
         << (method->client_streaming() ? "(stream " : "(")
         << method->input_type()->full_name() << ") returns "
         << (method->server_streaming() ? "(stream " : "(")
         << method->output_type()->full_name() << ") {}\n";
  if (method->options().deprecated()) {
    // NOTE(review): this appends AFTER the "\n" above, so the marker lands on
    // the following line with no trailing newline — unlike DescribeService,
    // which prints DEPRECATED first. Looks intentional upstream; confirm
    // before "fixing", as callers may compare this output verbatim.
    result << " DEPRECATED";
  }
  return result.str();
}
+
+TString SummarizeService(const grpc::protobuf::ServiceDescriptor* service) {
+ TString result;
+ for (int i = 0; i < service->method_count(); ++i) {
+ result.append(SummarizeMethod(service->method(i)));
+ }
+ return result;
+}
+
+TString SummarizeMethod(const grpc::protobuf::MethodDescriptor* method) {
+ TString result = method->name();
+ result.append("\n");
+ return result;
+}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/service_describer.h b/contrib/libs/grpc/test/cpp/util/service_describer.h
new file mode 100644
index 0000000000..a473f03744
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/service_describer.h
@@ -0,0 +1,42 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_SERVICE_DESCRIBER_H
+#define GRPC_TEST_CPP_UTIL_SERVICE_DESCRIBER_H
+
+#include <grpcpp/support/config.h>
+#include "test/cpp/util/config_grpc_cli.h"
+
namespace grpc {
namespace testing {

// Describes, in proto-IDL-like text, every service in service_list that can
// be resolved through desc_pool; unresolvable names are skipped.
TString DescribeServiceList(std::vector<TString> service_list,
                            grpc::protobuf::DescriptorPool& desc_pool);

// Full description of one service: filename, package, and all rpc methods.
TString DescribeService(const grpc::protobuf::ServiceDescriptor* service);

// Single "rpc ..." line for one method, including streaming markers.
TString DescribeMethod(const grpc::protobuf::MethodDescriptor* method);

// Newline-separated method names of the service.
TString SummarizeService(const grpc::protobuf::ServiceDescriptor* service);

// The method's name followed by a newline.
TString SummarizeMethod(const grpc::protobuf::MethodDescriptor* method);

}  // namespace testing
}  // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_SERVICE_DESCRIBER_H
diff --git a/contrib/libs/grpc/test/cpp/util/slice_test.cc b/contrib/libs/grpc/test/cpp/util/slice_test.cc
new file mode 100644
index 0000000000..d7e945ae38
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/slice_test.cc
@@ -0,0 +1,144 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc++/support/slice.h>
+#include <grpcpp/impl/grpc_library.h>
+
+#include <grpc/grpc.h>
+#include <grpc/slice.h>
+#include <gtest/gtest.h>
+
+#include "test/core/util/test_config.h"
+
+namespace grpc {
+
+static internal::GrpcLibraryInitializer g_gli_initializer;
+
+namespace {
+
+const char* kContent = "hello xxxxxxxxxxxxxxxxxxxx world";
+
// Fixture for grpc::Slice tests. grpc core is initialized once for the whole
// test case because Slice construction requires an initialized library.
class SliceTest : public ::testing::Test {
 protected:
  static void SetUpTestCase() { grpc_init(); }

  static void TearDownTestCase() { grpc_shutdown(); }

  // Checks only that the slice's length matches the expected content.
  void CheckSliceSize(const Slice& s, const TString& content) {
    EXPECT_EQ(content.size(), s.size());
  }
  // Checks both length and the exact bytes of the slice.
  void CheckSlice(const Slice& s, const TString& content) {
    EXPECT_EQ(content.size(), s.size());
    EXPECT_EQ(content,
              TString(reinterpret_cast<const char*>(s.begin()), s.size()));
  }
};
+
// Default-constructed Slice is empty.
TEST_F(SliceTest, Empty) {
  Slice empty_slice;
  CheckSlice(empty_slice, "");
}

// Slice(len) creates a slice of the given length (contents unspecified,
// so only the size is checked).
TEST_F(SliceTest, Sized) {
  Slice sized_slice(strlen(kContent));
  CheckSliceSize(sized_slice, kContent);
}

// Construction from a NUL-terminated C string.
TEST_F(SliceTest, String) {
  Slice spp(kContent);
  CheckSlice(spp, kContent);
}

// Construction from an explicit (buffer, length) pair.
TEST_F(SliceTest, Buf) {
  Slice spp(kContent, strlen(kContent));
  CheckSlice(spp, kContent);
}

// STATIC_SLICE: wraps a buffer with static storage duration.
TEST_F(SliceTest, StaticBuf) {
  Slice spp(kContent, strlen(kContent), Slice::STATIC_SLICE);
  CheckSlice(spp, kContent);
}

// Custom deleter overload: the slice releases the heap buffer through the
// provided callback.
TEST_F(SliceTest, SliceNew) {
  char* x = new char[strlen(kContent) + 1];
  strcpy(x, kContent);
  Slice spp(x, strlen(x), [](void* p) { delete[] static_cast<char*>(p); });
  CheckSlice(spp, kContent);
}

// A no-op deleter lets the slice reference memory it must not free.
TEST_F(SliceTest, SliceNewDoNothing) {
  Slice spp(const_cast<char*>(kContent), strlen(kContent), [](void* /*p*/) {});
  CheckSlice(spp, kContent);
}

// User-data overload: the deleter receives the trailing `t` pointer rather
// than the buffer itself, so both the buffer and its bookkeeping struct can
// be released together.
TEST_F(SliceTest, SliceNewWithUserData) {
  struct stest {
    char* x;
    int y;
  };
  auto* t = new stest;
  t->x = new char[strlen(kContent) + 1];
  strcpy(t->x, kContent);
  Slice spp(t->x, strlen(t->x),
            [](void* p) {
              auto* t = static_cast<stest*>(p);
              delete[] t->x;
              delete t;
            },
            t);
  CheckSlice(spp, kContent);
}

// (pointer, length) deleter overload: the callback is given the buffer length.
TEST_F(SliceTest, SliceNewLen) {
  Slice spp(const_cast<char*>(kContent), strlen(kContent),
            [](void* /*p*/, size_t l) { EXPECT_EQ(l, strlen(kContent)); });
  CheckSlice(spp, kContent);
}

// STEAL_REF: the Slice takes over the existing reference — no unref here.
TEST_F(SliceTest, Steal) {
  grpc_slice s = grpc_slice_from_copied_string(kContent);
  Slice spp(s, Slice::STEAL_REF);
  CheckSlice(spp, kContent);
}

// ADD_REF: the Slice takes its own reference, so the original ref can be
// dropped immediately.
TEST_F(SliceTest, Add) {
  grpc_slice s = grpc_slice_from_copied_string(kContent);
  Slice spp(s, Slice::ADD_REF);
  grpc_slice_unref(s);
  CheckSlice(spp, kContent);
}

// c_slice() returns a grpc_slice aliasing the same bytes (same start/end
// pointers); the returned ref must be unref'd by the caller.
TEST_F(SliceTest, Cslice) {
  grpc_slice s = grpc_slice_from_copied_string(kContent);
  Slice spp(s, Slice::STEAL_REF);
  CheckSlice(spp, kContent);
  grpc_slice c_slice = spp.c_slice();
  EXPECT_EQ(GRPC_SLICE_START_PTR(s), GRPC_SLICE_START_PTR(c_slice));
  EXPECT_EQ(GRPC_SLICE_END_PTR(s), GRPC_SLICE_END_PTR(c_slice));
  grpc_slice_unref(c_slice);
}
+
+} // namespace
+} // namespace grpc
+
+int main(int argc, char** argv) {
+ grpc::testing::TestEnvironment env(argc, argv);
+ ::testing::InitGoogleTest(&argc, argv);
+ int ret = RUN_ALL_TESTS();
+ return ret;
+}
diff --git a/contrib/libs/grpc/test/cpp/util/string_ref_helper.cc b/contrib/libs/grpc/test/cpp/util/string_ref_helper.cc
new file mode 100644
index 0000000000..e573f5d33a
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/string_ref_helper.cc
@@ -0,0 +1,29 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/string_ref_helper.h"
+
+namespace grpc {
+namespace testing {
+
// Materializes a non-owning grpc::string_ref into an owning TString.
// The explicit (data, size) constructor preserves embedded NUL bytes.
TString ToString(const grpc::string_ref& r) {
  return TString(r.data(), r.size());
}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/string_ref_helper.h b/contrib/libs/grpc/test/cpp/util/string_ref_helper.h
new file mode 100644
index 0000000000..e9e941f319
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/string_ref_helper.h
@@ -0,0 +1,32 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_STRING_REF_HELPER_H
+#define GRPC_TEST_CPP_UTIL_STRING_REF_HELPER_H
+
+#include <grpcpp/support/string_ref.h>
+
namespace grpc {
namespace testing {

// Copies the bytes referenced by `r` into an owning TString
// (string_ref itself does not own its data).
TString ToString(const grpc::string_ref& r);

}  // namespace testing
}  // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_STRING_REF_HELPER_H
diff --git a/contrib/libs/grpc/test/cpp/util/string_ref_test.cc b/contrib/libs/grpc/test/cpp/util/string_ref_test.cc
new file mode 100644
index 0000000000..8e3259b764
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/string_ref_test.cc
@@ -0,0 +1,205 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpcpp/support/string_ref.h>
+
+#include <string.h>
+
+#include <gtest/gtest.h>
+
+#include "test/core/util/test_config.h"
+
+namespace grpc {
+namespace {
+
+const char kTestString[] = "blah";
+const char kTestStringWithEmbeddedNull[] = "blah\0foo";
+const size_t kTestStringWithEmbeddedNullLength = 8;
+const char kTestUnrelatedString[] = "foo";
+
+class StringRefTest : public ::testing::Test {};
+
// Default-constructed string_ref views nothing: zero length, null data.
TEST_F(StringRefTest, Empty) {
  string_ref s;
  EXPECT_EQ(0U, s.length());
  EXPECT_EQ(nullptr, s.data());
}

// Construction from a C string aliases the buffer (no copy).
TEST_F(StringRefTest, FromCString) {
  string_ref s(kTestString);
  EXPECT_EQ(strlen(kTestString), s.length());
  EXPECT_EQ(kTestString, s.data());
}

TEST_F(StringRefTest, FromCStringWithLength) {
  string_ref s(kTestString, 2);
  EXPECT_EQ(2U, s.length());
  EXPECT_EQ(kTestString, s.data());
}

// Construction from a string aliases the string's own buffer.
TEST_F(StringRefTest, FromString) {
  string copy(kTestString);
  string_ref s(copy);
  EXPECT_EQ(copy.data(), s.data());
  EXPECT_EQ(copy.length(), s.length());
}

TEST_F(StringRefTest, CopyConstructor) {
  string_ref s1(kTestString);
  ;  // NOTE(review): stray empty statement; harmless, kept as-is.
  const string_ref& s2(s1);
  EXPECT_EQ(s1.length(), s2.length());
  EXPECT_EQ(s1.data(), s2.data());
}

// Embedded NULs do not truncate the view.
TEST_F(StringRefTest, FromStringWithEmbeddedNull) {
  string copy(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength);
  string_ref s(copy);
  EXPECT_EQ(copy.data(), s.data());
  EXPECT_EQ(copy.length(), s.length());
  EXPECT_EQ(kTestStringWithEmbeddedNullLength, s.length());
}

TEST_F(StringRefTest, Assignment) {
  string_ref s1(kTestString);
  ;  // NOTE(review): stray empty statement; harmless, kept as-is.
  string_ref s2;
  EXPECT_EQ(nullptr, s2.data());
  s2 = s1;
  EXPECT_EQ(s1.length(), s2.length());
  EXPECT_EQ(s1.data(), s2.data());
}

// Forward iteration visits every character in order.
TEST_F(StringRefTest, Iterator) {
  string_ref s(kTestString);
  size_t i = 0;
  for (auto it = s.cbegin(); it != s.cend(); ++it) {
    auto val = kTestString[i++];
    EXPECT_EQ(val, *it);
  }
  EXPECT_EQ(strlen(kTestString), i);
}

// Reverse iteration visits every character in reverse order.
TEST_F(StringRefTest, ReverseIterator) {
  string_ref s(kTestString);
  size_t i = strlen(kTestString);
  for (auto rit = s.crbegin(); rit != s.crend(); ++rit) {
    auto val = kTestString[--i];
    EXPECT_EQ(val, *rit);
  }
  EXPECT_EQ(0U, i);
}

// length()/size()/max_size()/empty() agree with each other for a view.
TEST_F(StringRefTest, Capacity) {
  string_ref empty;
  EXPECT_EQ(0U, empty.length());
  EXPECT_EQ(0U, empty.size());
  EXPECT_EQ(0U, empty.max_size());
  EXPECT_TRUE(empty.empty());

  string_ref s(kTestString);
  EXPECT_EQ(strlen(kTestString), s.length());
  EXPECT_EQ(s.length(), s.size());
  EXPECT_EQ(s.max_size(), s.length());
  EXPECT_FALSE(s.empty());
}

// compare() is 0 only for equal contents; embedded NULs participate.
TEST_F(StringRefTest, Compare) {
  string_ref s1(kTestString);
  string s1_copy(kTestString);
  string_ref s2(kTestUnrelatedString);
  string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength);
  EXPECT_EQ(0, s1.compare(s1_copy));
  EXPECT_NE(0, s1.compare(s2));
  EXPECT_NE(0, s1.compare(s3));
}

// "blah\0foo" starts with "blah" but not the other way around.
TEST_F(StringRefTest, StartsWith) {
  string_ref s1(kTestString);
  string_ref s2(kTestUnrelatedString);
  string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength);
  EXPECT_TRUE(s1.starts_with(s1));
  EXPECT_FALSE(s1.starts_with(s2));
  EXPECT_FALSE(s2.starts_with(s1));
  EXPECT_FALSE(s1.starts_with(s3));
  EXPECT_TRUE(s3.starts_with(s1));
}

// "blah\0foo" ends with "foo".
TEST_F(StringRefTest, Endswith) {
  string_ref s1(kTestString);
  string_ref s2(kTestUnrelatedString);
  string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength);
  EXPECT_TRUE(s1.ends_with(s1));
  EXPECT_FALSE(s1.ends_with(s2));
  EXPECT_FALSE(s2.ends_with(s1));
  EXPECT_FALSE(s2.ends_with(s3));
  EXPECT_TRUE(s3.ends_with(s2));
}

// find() of substrings and single characters, including across the
// embedded NUL ("foo" begins at index 5 of "blah\0foo").
TEST_F(StringRefTest, Find) {
  string_ref s1(kTestString);
  string_ref s2(kTestUnrelatedString);
  string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength);
  EXPECT_EQ(0U, s1.find(s1));
  EXPECT_EQ(0U, s2.find(s2));
  EXPECT_EQ(0U, s3.find(s3));
  EXPECT_EQ(string_ref::npos, s1.find(s2));
  EXPECT_EQ(string_ref::npos, s2.find(s1));
  EXPECT_EQ(string_ref::npos, s1.find(s3));
  EXPECT_EQ(0U, s3.find(s1));
  EXPECT_EQ(5U, s3.find(s2));
  EXPECT_EQ(string_ref::npos, s1.find('z'));
  EXPECT_EQ(1U, s2.find('o'));
}

// substr() slices around the embedded NUL: [0,4) == "blah", [5,...) == "foo".
TEST_F(StringRefTest, SubString) {
  string_ref s(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength);
  string_ref sub1 = s.substr(0, 4);
  EXPECT_EQ(string_ref(kTestString), sub1);
  string_ref sub2 = s.substr(5);
  EXPECT_EQ(string_ref(kTestUnrelatedString), sub2);
}

// Full set of relational operators; ordering is lexicographic by bytes.
TEST_F(StringRefTest, ComparisonOperators) {
  string_ref s1(kTestString);
  string_ref s2(kTestUnrelatedString);
  string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength);
  EXPECT_EQ(s1, s1);
  EXPECT_EQ(s2, s2);
  EXPECT_EQ(s3, s3);
  EXPECT_GE(s1, s1);
  EXPECT_GE(s2, s2);
  EXPECT_GE(s3, s3);
  EXPECT_LE(s1, s1);
  EXPECT_LE(s2, s2);
  EXPECT_LE(s3, s3);
  EXPECT_NE(s1, s2);
  EXPECT_NE(s1, s3);
  EXPECT_NE(s2, s3);
  EXPECT_GT(s3, s1);
  EXPECT_LT(s1, s3);
}
+
+} // namespace
+} // namespace grpc
+
// Test entry point: grpc test environment first, then googletest.
int main(int argc, char** argv) {
  grpc::testing::TestEnvironment env(argc, argv);
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
diff --git a/contrib/libs/grpc/test/cpp/util/subprocess.cc b/contrib/libs/grpc/test/cpp/util/subprocess.cc
new file mode 100644
index 0000000000..648bd50274
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/subprocess.cc
@@ -0,0 +1,44 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/subprocess.h"
+
+#include <vector>
+
+#include "test/core/util/subprocess.h"
+
+namespace grpc {
+
+static gpr_subprocess* MakeProcess(const std::vector<TString>& args) {
+ std::vector<const char*> vargs;
+ for (auto it = args.begin(); it != args.end(); ++it) {
+ vargs.push_back(it->c_str());
+ }
+ return gpr_subprocess_create(vargs.size(), &vargs[0]);
+}
+
// Creates the child process immediately via gpr_subprocess_create
// (see MakeProcess); the handle is owned for this object's lifetime.
SubProcess::SubProcess(const std::vector<TString>& args)
    : subprocess_(MakeProcess(args)) {}

SubProcess::~SubProcess() { gpr_subprocess_destroy(subprocess_); }

// Waits for the child; returns gpr_subprocess_join's result (see the core
// subprocess util for exact return-value semantics).
int SubProcess::Join() { return gpr_subprocess_join(subprocess_); }

// Requests interruption of the child; does not wait for it to exit.
void SubProcess::Interrupt() { gpr_subprocess_interrupt(subprocess_); }
+
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/subprocess.h b/contrib/libs/grpc/test/cpp/util/subprocess.h
new file mode 100644
index 0000000000..84dda31dd1
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/subprocess.h
@@ -0,0 +1,47 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_SUBPROCESS_H
+#define GRPC_TEST_CPP_UTIL_SUBPROCESS_H
+
+#include <initializer_list>
+#include <util/generic/string.h>
+#include <vector>
+
+struct gpr_subprocess;
+
+namespace grpc {
+
+class SubProcess {
+ public:
+ SubProcess(const std::vector<TString>& args);
+ ~SubProcess();
+
+ int Join();
+ void Interrupt();
+
+ private:
+ SubProcess(const SubProcess& other);
+ SubProcess& operator=(const SubProcess& other);
+
+ gpr_subprocess* const subprocess_;
+};
+
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_SUBPROCESS_H
diff --git a/contrib/libs/grpc/test/cpp/util/test_config.h b/contrib/libs/grpc/test/cpp/util/test_config.h
new file mode 100644
index 0000000000..094ed44f63
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/test_config.h
@@ -0,0 +1,30 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_TEST_CONFIG_H
+#define GRPC_TEST_CPP_UTIL_TEST_CONFIG_H
+
namespace grpc {
namespace testing {

// Parses command-line flags for tests. When remove_flags is true, the flags
// recognized by the parser are stripped out of argc/argv (gflags semantics —
// see test_config_cc.cc, which forwards to ParseCommandLineFlags).
void InitTest(int* argc, char*** argv, bool remove_flags);

}  // namespace testing
}  // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_TEST_CONFIG_H
diff --git a/contrib/libs/grpc/test/cpp/util/test_config_cc.cc b/contrib/libs/grpc/test/cpp/util/test_config_cc.cc
new file mode 100644
index 0000000000..e4b6886335
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/test_config_cc.cc
@@ -0,0 +1,37 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <gflags/gflags.h>
+#include "test/cpp/util/test_config.h"
+
+// In some distros, gflags is in the namespace google, and in some others,
+// in gflags. This hack is enabling us to find both.
+namespace google {}
+namespace gflags {}
+using namespace google;
+using namespace gflags;
+
namespace grpc {
namespace testing {

// gflags-based implementation of InitTest(): forwards directly to
// ParseCommandLineFlags, found via the google/gflags namespace hack above
// (distros disagree on which namespace gflags lives in).
void InitTest(int* argc, char*** argv, bool remove_flags) {
  ParseCommandLineFlags(argc, argv, remove_flags);
}

}  // namespace testing
}  // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/test_credentials_provider.cc b/contrib/libs/grpc/test/cpp/util/test_credentials_provider.cc
new file mode 100644
index 0000000000..f7134b773f
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/test_credentials_provider.cc
@@ -0,0 +1,181 @@
+
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "test/cpp/util/test_credentials_provider.h"
+
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+
+#include <mutex>
+#include <unordered_map>
+
+#include <gflags/gflags.h>
+#include <grpc/support/log.h>
+#include <grpc/support/sync.h>
+#include <grpcpp/security/server_credentials.h>
+
+#include "test/core/end2end/data/ssl_test_data.h"
+
+DEFINE_string(tls_cert_file, "", "The TLS cert file used when --use_tls=true");
+DEFINE_string(tls_key_file, "", "The TLS key file used when --use_tls=true");
+
+namespace grpc {
+namespace testing {
+namespace {
+
+TString ReadFile(const TString& src_path) {
+ std::ifstream src;
+ src.open(src_path, std::ifstream::in | std::ifstream::binary);
+
+ TString contents;
+ src.seekg(0, std::ios::end);
+ contents.reserve(src.tellg());
+ src.seekg(0, std::ios::beg);
+ contents.assign((std::istreambuf_iterator<char>(src)),
+ (std::istreambuf_iterator<char>()));
+ return contents;
+}
+
// Default CredentialsProvider: supports the four built-in credential types
// (insecure, alts, ssl, google-default) plus any types registered at runtime
// through AddSecureType(). The server-side ssl key/cert can be overridden via
// the --tls_key_file / --tls_cert_file flags; otherwise the baked-in test
// certificates from ssl_test_data are used.
class DefaultCredentialsProvider : public CredentialsProvider {
 public:
  DefaultCredentialsProvider() {
    // Flag-provided overrides are read once at construction time.
    if (!FLAGS_tls_key_file.empty()) {
      custom_server_key_ = ReadFile(FLAGS_tls_key_file);
    }
    if (!FLAGS_tls_cert_file.empty()) {
      custom_server_cert_ = ReadFile(FLAGS_tls_cert_file);
    }
  }
  ~DefaultCredentialsProvider() override {}

  void AddSecureType(
      const TString& type,
      std::unique_ptr<CredentialTypeProvider> type_provider) override {
    // This clobbers any existing entry for type, except the defaults, which
    // can't be clobbered.
    std::unique_lock<std::mutex> lock(mu_);
    auto it = std::find(added_secure_type_names_.begin(),
                        added_secure_type_names_.end(), type);
    if (it == added_secure_type_names_.end()) {
      added_secure_type_names_.push_back(type);
      added_secure_type_providers_.push_back(std::move(type_provider));
    } else {
      // Same index in both parallel vectors: replace the provider in place.
      added_secure_type_providers_[it - added_secure_type_names_.begin()] =
          std::move(type_provider);
    }
  }

  // Built-in types are handled inline; anything else is looked up among the
  // providers registered via AddSecureType(). Returns nullptr (after logging)
  // for unknown types.
  std::shared_ptr<ChannelCredentials> GetChannelCredentials(
      const TString& type, ChannelArguments* args) override {
    if (type == grpc::testing::kInsecureCredentialsType) {
      return InsecureChannelCredentials();
    } else if (type == grpc::testing::kAltsCredentialsType) {
      grpc::experimental::AltsCredentialsOptions alts_opts;
      return grpc::experimental::AltsCredentials(alts_opts);
    } else if (type == grpc::testing::kTlsCredentialsType) {
      // The test root cert only matches this override name, so force it.
      SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
      args->SetSslTargetNameOverride("foo.test.google.fr");
      return grpc::SslCredentials(ssl_opts);
    } else if (type == grpc::testing::kGoogleDefaultCredentialsType) {
      return grpc::GoogleDefaultCredentials();
    } else {
      std::unique_lock<std::mutex> lock(mu_);
      auto it(std::find(added_secure_type_names_.begin(),
                        added_secure_type_names_.end(), type));
      if (it == added_secure_type_names_.end()) {
        gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
        return nullptr;
      }
      return added_secure_type_providers_[it - added_secure_type_names_.begin()]
          ->GetChannelCredentials(args);
    }
  }

  // Server-side counterpart of GetChannelCredentials(). For ssl, the
  // flag-provided key/cert pair wins over the baked-in test pair.
  std::shared_ptr<ServerCredentials> GetServerCredentials(
      const TString& type) override {
    if (type == grpc::testing::kInsecureCredentialsType) {
      return InsecureServerCredentials();
    } else if (type == grpc::testing::kAltsCredentialsType) {
      grpc::experimental::AltsServerCredentialsOptions alts_opts;
      return grpc::experimental::AltsServerCredentials(alts_opts);
    } else if (type == grpc::testing::kTlsCredentialsType) {
      SslServerCredentialsOptions ssl_opts;
      ssl_opts.pem_root_certs = "";
      if (!custom_server_key_.empty() && !custom_server_cert_.empty()) {
        SslServerCredentialsOptions::PemKeyCertPair pkcp = {
            custom_server_key_, custom_server_cert_};
        ssl_opts.pem_key_cert_pairs.push_back(pkcp);
      } else {
        SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key,
                                                            test_server1_cert};
        ssl_opts.pem_key_cert_pairs.push_back(pkcp);
      }
      return SslServerCredentials(ssl_opts);
    } else {
      std::unique_lock<std::mutex> lock(mu_);
      auto it(std::find(added_secure_type_names_.begin(),
                        added_secure_type_names_.end(), type));
      if (it == added_secure_type_names_.end()) {
        gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
        return nullptr;
      }
      return added_secure_type_providers_[it - added_secure_type_names_.begin()]
          ->GetServerCredentials();
    }
  }
  // Secure types only: ssl plus everything registered via AddSecureType().
  // Note that the insecure type is not part of this list.
  std::vector<TString> GetSecureCredentialsTypeList() override {
    std::vector<TString> types;
    types.push_back(grpc::testing::kTlsCredentialsType);
    std::unique_lock<std::mutex> lock(mu_);
    for (auto it = added_secure_type_names_.begin();
         it != added_secure_type_names_.end(); it++) {
      types.push_back(*it);
    }
    return types;
  }

 private:
  std::mutex mu_;  // guards the two parallel vectors below
  // Parallel vectors: the name at index i maps to the provider at index i.
  std::vector<TString> added_secure_type_names_;
  std::vector<std::unique_ptr<CredentialTypeProvider>>
      added_secure_type_providers_;
  TString custom_server_key_;
  TString custom_server_cert_;
};
+
// Process-wide provider instance; lazily created by GetCredentialsProvider()
// and never deleted (leaked at shutdown).
CredentialsProvider* g_provider = nullptr;
+
+} // namespace
+
// Returns the global provider, creating a DefaultCredentialsProvider on
// first use. Not thread-safe: the lazy initialization is unsynchronized.
CredentialsProvider* GetCredentialsProvider() {
  if (g_provider == nullptr) {
    g_provider = new DefaultCredentialsProvider;
  }
  return g_provider;
}
+
// Installs a custom global provider (ownership transfers to this module).
// Must be called before the first GetCredentialsProvider() call — the
// assertion below forbids replacing an already-set provider.
void SetCredentialsProvider(CredentialsProvider* provider) {
  // For now, forbids overriding provider.
  GPR_ASSERT(g_provider == nullptr);
  g_provider = provider;
}
+
+} // namespace testing
+} // namespace grpc
diff --git a/contrib/libs/grpc/test/cpp/util/test_credentials_provider.h b/contrib/libs/grpc/test/cpp/util/test_credentials_provider.h
new file mode 100644
index 0000000000..acba277ada
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/test_credentials_provider.h
@@ -0,0 +1,85 @@
+/*
+ *
+ * Copyright 2016 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_TEST_CPP_UTIL_TEST_CREDENTIALS_PROVIDER_H
+#define GRPC_TEST_CPP_UTIL_TEST_CREDENTIALS_PROVIDER_H
+
+#include <memory>
+
+#include <grpcpp/security/credentials.h>
+#include <grpcpp/security/server_credentials.h>
+#include <grpcpp/support/channel_arguments.h>
+
+namespace grpc {
+namespace testing {
+
const char kInsecureCredentialsType[] = "INSECURE_CREDENTIALS";
// For real credentials, like tls/ssl, this name should match the AuthContext
// property "transport_security_type".
const char kTlsCredentialsType[] = "ssl";
const char kAltsCredentialsType[] = "alts";
const char kGoogleDefaultCredentialsType[] = "google_default_credentials";

// Provide test credentials of a particular type.
class CredentialTypeProvider {
 public:
  virtual ~CredentialTypeProvider() {}

  // Channel-side credentials; may also adjust the channel arguments.
  virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials(
      ChannelArguments* args) = 0;
  // Server-side credentials for the same type.
  virtual std::shared_ptr<ServerCredentials> GetServerCredentials() = 0;
};

// Provide test credentials. Thread-safe.
class CredentialsProvider {
 public:
  virtual ~CredentialsProvider() {}

  // Add a secure type in addition to the defaults. The default provider has
  // (kInsecureCredentialsType, kTlsCredentialsType).
  virtual void AddSecureType(
      const TString& type,
      std::unique_ptr<CredentialTypeProvider> type_provider) = 0;

  // Provide channel credentials according to the given type. Alter the channel
  // arguments if needed. Return nullptr if type is not registered.
  virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials(
      const TString& type, ChannelArguments* args) = 0;

  // Provide server credentials according to the given type.
  // Return nullptr if type is not registered.
  virtual std::shared_ptr<ServerCredentials> GetServerCredentials(
      const TString& type) = 0;

  // Provide a list of secure credentials type.
  virtual std::vector<TString> GetSecureCredentialsTypeList() = 0;
};

// Get the current provider. Create a default one if not set.
// Not thread-safe.
CredentialsProvider* GetCredentialsProvider();

// Set the global provider. Takes ownership.
// NOTE(review): the default implementation asserts that no provider has been
// set yet (it forbids overriding), so "the previous provider is destroyed"
// does not currently happen — call this at most once, before any
// GetCredentialsProvider() call.
// Not thread-safe.
void SetCredentialsProvider(CredentialsProvider* provider);
+
+} // namespace testing
+} // namespace grpc
+
+#endif // GRPC_TEST_CPP_UTIL_TEST_CREDENTIALS_PROVIDER_H
diff --git a/contrib/libs/grpc/test/cpp/util/time_test.cc b/contrib/libs/grpc/test/cpp/util/time_test.cc
new file mode 100644
index 0000000000..bcbfa14f94
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/time_test.cc
@@ -0,0 +1,72 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/support/time.h>
+#include <grpcpp/support/time.h>
+#include <gtest/gtest.h>
+
+#include "test/core/util/test_config.h"
+
+using std::chrono::duration_cast;
+using std::chrono::microseconds;
+using std::chrono::system_clock;
+
+namespace grpc {
+namespace {
+
+// Empty fixture: tests below only exercise free conversion functions.
+class TimeTest : public ::testing::Test {};
+
+// Round-trips an absolute realtime timestamp through the gRPC C++
+// conversion helpers: gpr_timespec -> system_clock::time_point
+// (Timespec2Timepoint) -> gpr_timespec (Timepoint2Timespec), checking
+// that no precision is lost in either direction.
+TEST_F(TimeTest, AbsolutePointTest) {
+  int64_t us = 10000000L;
+  gpr_timespec ts = gpr_time_from_micros(us, GPR_TIMESPAN);
+  // Reinterpret the span as an absolute realtime timestamp so it is
+  // comparable with a system_clock::time_point.
+  ts.clock_type = GPR_CLOCK_REALTIME;
+  system_clock::time_point tp{microseconds(us)};
+  system_clock::time_point tp_converted = Timespec2Timepoint(ts);
+  gpr_timespec ts_converted;
+  Timepoint2Timespec(tp_converted, &ts_converted);
+  // timespec -> timepoint -> timespec must reproduce both fields exactly.
+  EXPECT_TRUE(ts.tv_sec == ts_converted.tv_sec);
+  EXPECT_TRUE(ts.tv_nsec == ts_converted.tv_nsec);
+  system_clock::time_point tp_converted_2 = Timespec2Timepoint(ts_converted);
+  // Both conversion paths must agree with the directly constructed timepoint.
+  EXPECT_TRUE(tp == tp_converted);
+  EXPECT_TRUE(tp == tp_converted_2);
+}
+
+// gpr_inf_future is treated specially and mapped to/from time_point::max()
+TEST_F(TimeTest, InfFuture) {
+  // C -> C++ direction: the sentinel "infinite future" timespec becomes
+  // the largest representable time_point.
+  EXPECT_EQ(system_clock::time_point::max(),
+            Timespec2Timepoint(gpr_inf_future(GPR_CLOCK_REALTIME)));
+  gpr_timespec from_time_point_max;
+  // C++ -> C direction: time_point::max() maps back to gpr_inf_future.
+  Timepoint2Timespec(system_clock::time_point::max(), &from_time_point_max);
+  EXPECT_EQ(
+      0, gpr_time_cmp(gpr_inf_future(GPR_CLOCK_REALTIME), from_time_point_max));
+  // This will cause an overflow
+  // (seconds-granularity time_point::max() exceeds the timespec range, so
+  // the converter is expected to saturate to gpr_inf_future rather than wrap).
+  Timepoint2Timespec(
+      std::chrono::time_point<system_clock, std::chrono::seconds>::max(),
+      &from_time_point_max);
+  EXPECT_EQ(
+      0, gpr_time_cmp(gpr_inf_future(GPR_CLOCK_REALTIME), from_time_point_max));
+}
+
+} // namespace
+} // namespace grpc
+
+// Standard gRPC test entry point: set up the gRPC test environment
+// (tracing, shutdown hooks) before handing control to GoogleTest.
+int main(int argc, char** argv) {
+  grpc::testing::TestEnvironment env(argc, argv);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/contrib/libs/grpc/test/cpp/util/ya.make b/contrib/libs/grpc/test/cpp/util/ya.make
new file mode 100644
index 0000000000..f043cc5b14
--- /dev/null
+++ b/contrib/libs/grpc/test/cpp/util/ya.make
@@ -0,0 +1,39 @@
+# Arcadia build description for the gRPC C++ test utility library
+# (test/cpp/util): CLI helpers, reflection descriptor DB, and test config.
+LIBRARY()
+
+LICENSE(Apache-2.0)
+
+LICENSE_TEXTS(.yandex_meta/licenses.list.txt)
+
+OWNER(orivej)
+
+PEERDIR(
+    contrib/libs/gflags
+    contrib/libs/protoc
+    contrib/libs/grpc/src/proto/grpc/reflection/v1alpha
+    contrib/restricted/googletest/googlemock
+    contrib/restricted/googletest/googletest
+)
+
+ADDINCL(
+    # Build root first: generated headers shadow the checked-in tree.
+    ${ARCADIA_BUILD_ROOT}/contrib/libs/grpc
+    contrib/libs/grpc
+)
+
+# Third-party code: keep upstream sources warning-clean without patching.
+NO_COMPILER_WARNINGS()
+
+SRCS(
+    byte_buffer_proto_helper.cc
+    # grpc_cli_libs:
+    cli_call.cc
+    cli_credentials.cc
+    grpc_tool.cc
+    proto_file_parser.cc
+    service_describer.cc
+    string_ref_helper.cc
+    # grpc++_proto_reflection_desc_db:
+    proto_reflection_descriptor_database.cc
+    # grpc++_test_config:
+    test_config_cc.cc
+)
+
+END()