diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /contrib/libs/grpc/test/cpp | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'contrib/libs/grpc/test/cpp')
88 files changed, 35933 insertions, 0 deletions
diff --git a/contrib/libs/grpc/test/cpp/README-iOS.md b/contrib/libs/grpc/test/cpp/README-iOS.md new file mode 100644 index 0000000000..898931085b --- /dev/null +++ b/contrib/libs/grpc/test/cpp/README-iOS.md @@ -0,0 +1,52 @@ +## C++ tests on iOS + +[GTMGoogleTestRunner](https://github.com/google/google-toolbox-for-mac/blob/master/UnitTesting/GTMGoogleTestRunner.mm) is used to convert googletest cases to XCTest that can be run on iOS. GTMGoogleTestRunner doesn't execute the `main` function, so we can't have any test logic in `main`. +However, it's ok to call `::testing::InitGoogleTest` in `main`, as `GTMGoogleTestRunner` [calls InitGoogleTest](https://github.com/google/google-toolbox-for-mac/blob/master/UnitTesting/GTMGoogleTestRunner.mm#L151). +`grpc::testing::TestEnvironment` can also be called from `main`, as it does some test initialization (install crash handler, seed RNG) that's not strictly required to run testcases on iOS. + + +## Porting exising C++ tests to run on iOS + +Please follow these guidelines when porting tests to run on iOS: + +- Tests need to use the googletest framework +- Any setup/teardown code in `main` needs to be moved to `SetUpTestCase`/`TearDownTestCase`, and `TEST` needs to be changed to `TEST_F`. +- [Death tests](https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#death-tests) are not supported on iOS, so use the `*_IF_SUPPORTED()` macros to ensure that your code compiles on iOS. + +For example, the following test +```c++ +TEST(MyTest, TestOne) { + ASSERT_DEATH(ThisShouldDie(), ""); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc_init(); + return RUN_ALL_TESTS(); + grpc_shutdown(); +} +``` + +should be changed to +```c++ +class MyTest : public ::testing::Test { + protected: + static void SetUpTestCase() { grpc_init(); } + static void TearDownTestCase() { grpc_shutdown(); } +}; + +TEST_F(MyTest, TestOne) { + ASSERT_DEATH_IF_SUPPORTED(ThisShouldDie(), ""); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} +``` + +## Limitations + +Due to a [limitation](https://github.com/google/google-toolbox-for-mac/blob/master/UnitTesting/GTMGoogleTestRunner.mm#L48-L56) in GTMGoogleTestRunner, `SetUpTestCase`/`TeardownTestCase` will be called before/after *every* individual test case, similar to `SetUp`/`TearDown`. diff --git a/contrib/libs/grpc/test/cpp/end2end/.yandex_meta/licenses.list.txt b/contrib/libs/grpc/test/cpp/end2end/.yandex_meta/licenses.list.txt new file mode 100644 index 0000000000..a07ea0849d --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/.yandex_meta/licenses.list.txt @@ -0,0 +1,36 @@ +====================Apache-2.0==================== + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + + +====================COPYRIGHT==================== + * Copyright 2015 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2016 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2017 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2018 gRPC authors. + + +====================COPYRIGHT==================== +# Copyright 2019 gRPC authors. + + +====================COPYRIGHT==================== +// Copyright 2019 The gRPC Authors diff --git a/contrib/libs/grpc/test/cpp/end2end/async_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/async_end2end_test.cc new file mode 100644 index 0000000000..45df8718f9 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/async_end2end_test.cc @@ -0,0 +1,1952 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <cinttypes> +#include <memory> +#include <thread> + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/ext/health_check_service_server_builder_option.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/iomgr/port.h" +#include "src/proto/grpc/health/v1/health.grpc.pb.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/string_ref_helper.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GRPC_POSIX_SOCKET_EV +#include "src/core/lib/iomgr/ev_posix.h" +#endif // GRPC_POSIX_SOCKET_EV + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using grpc::testing::kTlsCredentialsType; +using std::chrono::system_clock; + +namespace grpc { +namespace testing { + +namespace { + +void* tag(int i) { return (void*)static_cast<intptr_t>(i); } +int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); } + +class Verifier { + public: + Verifier() : lambda_run_(false) {} + // Expect sets the expected ok value for a specific tag + Verifier& Expect(int i, bool expect_ok) { + return ExpectUnless(i, expect_ok, false); + } + // ExpectUnless sets the expected ok value for a specific tag + // unless the tag was already marked seen (as a result of ExpectMaybe) + Verifier& ExpectUnless(int i, bool expect_ok, bool seen) { + if (!seen) { + expectations_[tag(i)] = expect_ok; + } + return *this; + } + // ExpectMaybe sets the expected ok value for a specific tag, but does not + // require it to appear + // If it does, sets *seen to true + Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) { + if (!*seen) { + maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen}; + } + return *this; + } + + // Next waits for 1 async tag to complete, checks its + // expectations, and returns the tag + int Next(CompletionQueue* cq, bool ignore_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + GotTag(got_tag, ok, ignore_ok); + return detag(got_tag); + } + + template <typename T> + CompletionQueue::NextStatus DoOnceThenAsyncNext( + CompletionQueue* cq, void** got_tag, bool* ok, T deadline, + std::function<void(void)> lambda) { + if (lambda_run_) { + return cq->AsyncNext(got_tag, ok, deadline); + } else { + lambda_run_ = true; + return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline); + } + } + + // Verify keeps calling Next until all currently set + // expected tags are complete + void Verify(CompletionQueue* cq) { Verify(cq, false); } + + // This version of Verify allows optionally ignoring the + // outcome of the expectation + void Verify(CompletionQueue* cq, bool ignore_ok) { + GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty()); + while (!expectations_.empty()) { + Next(cq, ignore_ok); + } + maybe_expectations_.clear(); + } + + // This version of Verify stops after a certain deadline + void Verify(CompletionQueue* cq, + std::chrono::system_clock::time_point deadline) { + if (expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(cq->AsyncNext(&got_tag, &ok, deadline), + CompletionQueue::TIMEOUT); + } else { + while (!expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(cq->AsyncNext(&got_tag, &ok, deadline), + CompletionQueue::GOT_EVENT); + GotTag(got_tag, ok, false); + } + } + maybe_expectations_.clear(); + } + + // This version of Verify stops after a certain deadline, and uses the + // DoThenAsyncNext API + // to call the lambda + void Verify(CompletionQueue* cq, + std::chrono::system_clock::time_point deadline, + const std::function<void(void)>& lambda) { + if (expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::TIMEOUT); + } else { + while (!expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::GOT_EVENT); + GotTag(got_tag, ok, false); + } + } + maybe_expectations_.clear(); + } + + private: + void GotTag(void* got_tag, bool ok, bool ignore_ok) { + auto it = expectations_.find(got_tag); + if (it != expectations_.end()) { + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } + expectations_.erase(it); + } else { + auto it2 = maybe_expectations_.find(got_tag); + if (it2 != maybe_expectations_.end()) { + if (it2->second.seen != nullptr) { + EXPECT_FALSE(*it2->second.seen); + *it2->second.seen = true; + } + if (!ignore_ok) { + EXPECT_EQ(it2->second.ok, ok); + } + maybe_expectations_.erase(it2); + } else { + gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag); + abort(); + } + } + } + + struct MaybeExpect { + bool ok; + bool* seen; + }; + + std::map<void*, bool> expectations_; + std::map<void*, MaybeExpect> maybe_expectations_; + bool lambda_run_; +}; + +bool plugin_has_sync_methods(std::unique_ptr<ServerBuilderPlugin>& plugin) { + return plugin->has_sync_methods(); +} + +// This class disables the server builder plugins that may add sync services to +// the server. If there are sync services, UnimplementedRpc test will triger +// the sync unknown rpc routine on the server side, rather than the async one +// that needs to be tested here. +class ServerBuilderSyncPluginDisabler : public ::grpc::ServerBuilderOption { + public: + void UpdateArguments(ChannelArguments* /*arg*/) override {} + + void UpdatePlugins( + std::vector<std::unique_ptr<ServerBuilderPlugin>>* plugins) override { + plugins->erase(std::remove_if(plugins->begin(), plugins->end(), + plugin_has_sync_methods), + plugins->end()); + } +}; + +class TestScenario { + public: + TestScenario(bool inproc_stub, const TString& creds_type, bool hcs, + const TString& content) + : inproc(inproc_stub), + health_check_service(hcs), + credentials_type(creds_type), + message_content(content) {} + void Log() const; + bool inproc; + bool health_check_service; + const TString credentials_type; + const TString message_content; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{inproc=" << (scenario.inproc ? "true" : "false") + << ", credentials='" << scenario.credentials_type + << ", health_check_service=" + << (scenario.health_check_service ? "true" : "false") + << "', message_size=" << scenario.message_content.size() << "}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_DEBUG, "%s", out.str().c_str()); +} + +class HealthCheck : public health::v1::Health::Service {}; + +class AsyncEnd2endTest : public ::testing::TestWithParam<TestScenario> { + protected: + AsyncEnd2endTest() { GetParam().Log(); } + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port_; + + // Setup server + BuildAndStartServer(); + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (cq_->Next(&ignored_tag, &ignored_ok)) + ; + stub_.reset(); + grpc_recycle_unused_port(port_); + } + + void BuildAndStartServer() { + ServerBuilder builder; + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + builder.AddListeningPort(server_address_.str(), server_creds); + service_.reset(new grpc::testing::EchoTestService::AsyncService()); + builder.RegisterService(service_.get()); + if (GetParam().health_check_service) { + builder.RegisterService(&health_check_); + } + cq_ = builder.AddCompletionQueue(); + + // TODO(zyc): make a test option to choose wheather sync plugins should be + // deleted + std::unique_ptr<ServerBuilderOption> sync_plugin_disabler( + new ServerBuilderSyncPluginDisabler()); + builder.SetOption(move(sync_plugin_disabler)); + server_ = builder.BuildAndStart(); + } + + void ResetStub() { + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + std::shared_ptr<Channel> channel = + !(GetParam().inproc) ? ::grpc::CreateCustomChannel( + server_address_.str(), channel_creds, args) + : server_->InProcessChannel(args); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + void SendRpc(int num_rpcs) { + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, + cq_.get(), cq_.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + } + + std::unique_ptr<ServerCompletionQueue> cq_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::unique_ptr<grpc::testing::EchoTestService::AsyncService> service_; + HealthCheck health_check_; + std::ostringstream server_address_; + int port_; +}; + +TEST_P(AsyncEnd2endTest, SimpleRpc) { + ResetStub(); + SendRpc(1); +} + +TEST_P(AsyncEnd2endTest, SimpleRpcWithExpectedError) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + ErrorStatus error_status; + + send_request.set_message(GetParam().message_content); + error_status.set_code(1); // CANCELLED + error_status.set_error_message("cancel error message"); + *send_request.mutable_param()->mutable_expected_error() = error_status; + + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + srv_ctx.AsyncNotifyWhenDone(tag(5)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish( + send_response, + Status( + static_cast<StatusCode>(recv_request.param().expected_error().code()), + recv_request.param().expected_error().error_message()), + tag(3)); + Verifier().Expect(3, true).Expect(4, true).Expect(5, true).Verify(cq_.get()); + + EXPECT_EQ(recv_response.message(), ""); + EXPECT_EQ(recv_status.error_code(), error_status.code()); + EXPECT_EQ(recv_status.error_message(), error_status.error_message()); + EXPECT_FALSE(srv_ctx.IsCancelled()); +} + +TEST_P(AsyncEnd2endTest, SequentialRpcs) { + ResetStub(); + SendRpc(10); +} + +TEST_P(AsyncEnd2endTest, ReconnectChannel) { + // GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS is set to 100ms in main() + if (GetParam().inproc) { + return; + } + int poller_slowdown_factor = 1; +#ifdef GRPC_POSIX_SOCKET_EV + // It needs 2 pollset_works to reconnect the channel with polling engine + // "poll" + grpc_core::UniquePtr<char> poller = GPR_GLOBAL_CONFIG_GET(grpc_poll_strategy); + if (0 == strcmp(poller.get(), "poll")) { + poller_slowdown_factor = 2; + } +#endif // GRPC_POSIX_SOCKET_EV + ResetStub(); + SendRpc(1); + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (cq_->Next(&ignored_tag, &ignored_ok)) + ; + BuildAndStartServer(); + // It needs more than GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS time to + // reconnect the channel. + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis( + 300 * poller_slowdown_factor * grpc_test_slowdown_factor(), + GPR_TIMESPAN))); + SendRpc(1); +} + +// We do not need to protect notify because the use is synchronized. +void ServerWait(Server* server, int* notify) { + server->Wait(); + *notify = 1; +} +TEST_P(AsyncEnd2endTest, WaitAndShutdownTest) { + int notify = 0; + std::thread wait_thread(&ServerWait, server_.get(), ¬ify); + ResetStub(); + SendRpc(1); + EXPECT_EQ(0, notify); + server_->Shutdown(); + wait_thread.join(); + EXPECT_EQ(1, notify); +} + +TEST_P(AsyncEnd2endTest, ShutdownThenWait) { + ResetStub(); + SendRpc(1); + std::thread t([this]() { server_->Shutdown(); }); + server_->Wait(); + t.join(); +} + +// Test a simple RPC using the async version of Next +TEST_P(AsyncEnd2endTest, AsyncNextRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + std::chrono::system_clock::time_point time_now( + std::chrono::system_clock::now()); + std::chrono::system_clock::time_point time_limit( + std::chrono::system_clock::now() + std::chrono::seconds(10)); + Verifier().Verify(cq_.get(), time_now); + Verifier().Verify(cq_.get(), time_now); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq_.get(), time_limit); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify( + cq_.get(), std::chrono::system_clock::time_point::max()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// Test a simple RPC using the async version of Next +TEST_P(AsyncEnd2endTest, DoThenAsyncNextRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + + std::chrono::system_clock::time_point time_now( + std::chrono::system_clock::now()); + std::chrono::system_clock::time_point time_limit( + std::chrono::system_clock::now() + std::chrono::seconds(10)); + Verifier().Verify(cq_.get(), time_now); + Verifier().Verify(cq_.get(), time_now); + + auto resp_writer_ptr = &response_writer; + auto lambda_2 = [&, this, resp_writer_ptr]() { + service_->RequestEcho(&srv_ctx, &recv_request, resp_writer_ptr, cq_.get(), + cq_.get(), tag(2)); + }; + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq_.get(), time_limit, lambda_2); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + auto lambda_3 = [resp_writer_ptr, send_response]() { + resp_writer_ptr->Finish(send_response, Status::OK, tag(3)); + }; + Verifier().Expect(3, true).Expect(4, true).Verify( + cq_.get(), std::chrono::system_clock::time_point::max(), lambda_3); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// Two pings and a final pong. +TEST_P(AsyncEnd2endTest, SimpleClientStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream( + stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1))); + + service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + Verifier().Expect(2, true).Expect(1, true).Verify(cq_.get()); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_stream->Write(send_request, tag(5)); + srv_stream.Read(&recv_request, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + EXPECT_EQ(send_request.message(), recv_request.message()); + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_stream.Finish(send_response, Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// Two pings and a final pong. +TEST_P(AsyncEnd2endTest, SimpleClientStreamingWithCoalescingApi) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + cli_ctx.set_initial_metadata_corked(true); + // tag:1 never comes up since no op is performed + std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream( + stub_->AsyncRequestStream(&cli_ctx, &recv_response, cq_.get(), tag(1))); + + service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + cli_stream->Write(send_request, tag(3)); + + bool seen3 = false; + + Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get()); + + srv_stream.Read(&recv_request, tag(4)); + + Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_stream->WriteLast(send_request, WriteOptions(), tag(5)); + srv_stream.Read(&recv_request, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + srv_stream.Read(&recv_request, tag(7)); + Verifier().Expect(7, false).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_stream.Finish(send_response, Status::OK, tag(8)); + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(8, true).Expect(9, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, two pongs. +TEST_P(AsyncEnd2endTest, SimpleServerStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(3)); + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, two pongs. Using WriteAndFinish API +TEST_P(AsyncEnd2endTest, SimpleServerStreamingWithCoalescingApiWAF) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(3)); + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + srv_stream.WriteAndFinish(send_response, WriteOptions(), Status::OK, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->Read(&recv_response, tag(7)); + Verifier().Expect(7, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(8)); + Verifier().Expect(8, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, two pongs. Using WriteLast API +TEST_P(AsyncEnd2endTest, SimpleServerStreamingWithCoalescingApiWL) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(3)); + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + srv_stream.WriteLast(send_response, WriteOptions(), tag(5)); + cli_stream->Read(&recv_response, tag(6)); + srv_stream.Finish(Status::OK, tag(7)); + Verifier().Expect(5, true).Expect(6, true).Expect(7, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, one pong. +TEST_P(AsyncEnd2endTest, SimpleBidiStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); + + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, one pong. Using server:WriteAndFinish api +TEST_P(AsyncEnd2endTest, SimpleBidiStreamingWithCoalescingApiWAF) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + cli_ctx.set_initial_metadata_corked(true); + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); + + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + cli_stream->WriteLast(send_request, WriteOptions(), tag(3)); + + bool seen3 = false; + + Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get()); + + srv_stream.Read(&recv_request, tag(4)); + + Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + srv_stream.Read(&recv_request, tag(5)); + Verifier().Expect(5, false).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_stream.WriteAndFinish(send_response, WriteOptions(), Status::OK, tag(6)); + cli_stream->Read(&recv_response, tag(7)); + Verifier().Expect(6, true).Expect(7, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->Finish(&recv_status, tag(8)); + Verifier().Expect(8, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// One ping, one pong. Using server:WriteLast api +TEST_P(AsyncEnd2endTest, SimpleBidiStreamingWithCoalescingApiWL) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + cli_ctx.set_initial_metadata_corked(true); + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); + + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + cli_stream->WriteLast(send_request, WriteOptions(), tag(3)); + + bool seen3 = false; + + Verifier().Expect(2, true).ExpectMaybe(3, true, &seen3).Verify(cq_.get()); + + srv_stream.Read(&recv_request, tag(4)); + + Verifier().ExpectUnless(3, true, seen3).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + srv_stream.Read(&recv_request, tag(5)); + Verifier().Expect(5, false).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_stream.WriteLast(send_response, WriteOptions(), tag(6)); + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(6, true).Expect(7, true).Expect(8, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// Metadata tests +TEST_P(AsyncEnd2endTest, ClientInitialMetadataRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair<TString, TString> meta1("key1", "val1"); + std::pair<TString, TString> meta2("key2", "val2"); + std::pair<TString, TString> meta3("g.r.d-bin", "xyz"); + cli_ctx.AddMetadata(meta1.first, meta1.second); + cli_ctx.AddMetadata(meta2.first, meta2.second); + cli_ctx.AddMetadata(meta3.first, meta3.second); + + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + const auto& client_initial_metadata = srv_ctx.client_metadata(); + EXPECT_EQ(meta1.second, + ToString(client_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(client_initial_metadata.find(meta2.first)->second)); + EXPECT_EQ(meta3.second, + ToString(client_initial_metadata.find(meta3.first)->second)); + EXPECT_GE(client_initial_metadata.size(), static_cast<size_t>(2)); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +TEST_P(AsyncEnd2endTest, ServerInitialMetadataRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair<TString, TString> meta1("key1", "val1"); + std::pair<TString, TString> meta2("key2", "val2"); + + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->ReadInitialMetadata(tag(4)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + srv_ctx.AddInitialMetadata(meta1.first, meta1.second); + srv_ctx.AddInitialMetadata(meta2.first, meta2.second); + response_writer.SendInitialMetadata(tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + const auto& server_initial_metadata = cli_ctx.GetServerInitialMetadata(); + EXPECT_EQ(meta1.second, + ToString(server_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(server_initial_metadata.find(meta2.first)->second)); + EXPECT_EQ(static_cast<size_t>(2), server_initial_metadata.size()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(5)); + response_reader->Finish(&recv_response, &recv_status, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +// 1 ping, 2 pongs. +TEST_P(AsyncEnd2endTest, ServerInitialMetadataServerStreaming) { + ResetStub(); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); + + std::pair<::TString, ::TString> meta1("key1", "val1"); + std::pair<::TString, ::TString> meta2("key2", "val2"); + + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + cli_stream->ReadInitialMetadata(tag(11)); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + + srv_ctx.AddInitialMetadata(meta1.first, meta1.second); + srv_ctx.AddInitialMetadata(meta2.first, meta2.second); + srv_stream.SendInitialMetadata(tag(10)); + Verifier().Expect(10, true).Expect(11, true).Verify(cq_.get()); + auto server_initial_metadata = cli_ctx.GetServerInitialMetadata(); + EXPECT_EQ(meta1.second, + ToString(server_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(server_initial_metadata.find(meta2.first)->second)); + EXPECT_EQ(static_cast<size_t>(2), server_initial_metadata.size()); + + srv_stream.Write(send_response, tag(3)); + + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +// 1 ping, 2 pongs. +// Test for server initial metadata being sent implicitly +TEST_P(AsyncEnd2endTest, ServerInitialMetadataServerStreamingImplicit) { + ResetStub(); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair<::TString, ::TString> meta1("key1", "val1"); + std::pair<::TString, ::TString> meta2("key2", "val2"); + + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, cq_.get(), tag(1))); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + srv_ctx.AddInitialMetadata(meta1.first, meta1.second); + srv_ctx.AddInitialMetadata(meta2.first, meta2.second); + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(3)); + + cli_stream->Read(&recv_response, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + auto server_initial_metadata = cli_ctx.GetServerInitialMetadata(); + EXPECT_EQ(meta1.second, + ToString(server_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(server_initial_metadata.find(meta2.first)->second)); + EXPECT_EQ(static_cast<size_t>(2), server_initial_metadata.size()); + + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status.ok()); +} + +TEST_P(AsyncEnd2endTest, ServerTrailingMetadataRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair<TString, TString> meta1("key1", "val1"); + std::pair<TString, TString> meta2("key2", "val2"); + + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->Finish(&recv_response, &recv_status, tag(5)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + response_writer.SendInitialMetadata(tag(3)); + Verifier().Expect(3, true).Verify(cq_.get()); + + send_response.set_message(recv_request.message()); + srv_ctx.AddTrailingMetadata(meta1.first, meta1.second); + srv_ctx.AddTrailingMetadata(meta2.first, meta2.second); + response_writer.Finish(send_response, Status::OK, tag(4)); + + Verifier().Expect(4, true).Expect(5, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + const auto& server_trailing_metadata = cli_ctx.GetServerTrailingMetadata(); + EXPECT_EQ(meta1.second, + ToString(server_trailing_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(server_trailing_metadata.find(meta2.first)->second)); + EXPECT_EQ(static_cast<size_t>(2), server_trailing_metadata.size()); +} + +TEST_P(AsyncEnd2endTest, MetadataRpc) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::pair<TString, TString> meta1("key1", "val1"); + std::pair<TString, TString> meta2( + "key2-bin", + TString("\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc", 13)); + std::pair<TString, TString> meta3("key3", "val3"); + std::pair<TString, TString> meta6( + "key4-bin", + TString("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d", + 14)); + std::pair<TString, TString> meta5("key5", "val5"); + std::pair<TString, TString> meta4( + "key6-bin", + TString( + "\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee", 15)); + + cli_ctx.AddMetadata(meta1.first, meta1.second); + cli_ctx.AddMetadata(meta2.first, meta2.second); + + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->ReadInitialMetadata(tag(4)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + const auto& client_initial_metadata = srv_ctx.client_metadata(); + EXPECT_EQ(meta1.second, + ToString(client_initial_metadata.find(meta1.first)->second)); + EXPECT_EQ(meta2.second, + ToString(client_initial_metadata.find(meta2.first)->second)); + EXPECT_GE(client_initial_metadata.size(), static_cast<size_t>(2)); + + srv_ctx.AddInitialMetadata(meta3.first, meta3.second); + srv_ctx.AddInitialMetadata(meta4.first, meta4.second); + response_writer.SendInitialMetadata(tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + const auto& server_initial_metadata = cli_ctx.GetServerInitialMetadata(); + EXPECT_EQ(meta3.second, + ToString(server_initial_metadata.find(meta3.first)->second)); + EXPECT_EQ(meta4.second, + ToString(server_initial_metadata.find(meta4.first)->second)); + EXPECT_GE(server_initial_metadata.size(), static_cast<size_t>(2)); + + send_response.set_message(recv_request.message()); + srv_ctx.AddTrailingMetadata(meta5.first, meta5.second); + srv_ctx.AddTrailingMetadata(meta6.first, meta6.second); + response_writer.Finish(send_response, Status::OK, tag(5)); + response_reader->Finish(&recv_response, &recv_status, tag(6)); + + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + const auto& server_trailing_metadata = cli_ctx.GetServerTrailingMetadata(); + EXPECT_EQ(meta5.second, + ToString(server_trailing_metadata.find(meta5.first)->second)); + EXPECT_EQ(meta6.second, + ToString(server_trailing_metadata.find(meta6.first)->second)); + EXPECT_GE(server_trailing_metadata.size(), static_cast<size_t>(2)); +} + +// Server uses AsyncNotifyWhenDone API to check for cancellation +TEST_P(AsyncEnd2endTest, ServerCheckCancellation) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + srv_ctx.AsyncNotifyWhenDone(tag(5)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_ctx.TryCancel(); + Verifier().Expect(5, true).Expect(4, true).Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + + EXPECT_EQ(StatusCode::CANCELLED, recv_status.error_code()); +} + +// Server uses AsyncNotifyWhenDone API to check for normal finish +TEST_P(AsyncEnd2endTest, ServerCheckDone) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx, send_request, cq_.get())); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + srv_ctx.AsyncNotifyWhenDone(tag(5)); + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq_.get(), + cq_.get(), tag(2)); + + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Expect(5, true).Verify(cq_.get()); + EXPECT_FALSE(srv_ctx.IsCancelled()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +TEST_P(AsyncEnd2endTest, UnimplementedRpc) { + ChannelArguments args; + const auto& channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + std::shared_ptr<Channel> channel = + !(GetParam().inproc) ? ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args) + : server_->InProcessChannel(args); + std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + send_request.set_message(GetParam().message_content); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub->AsyncUnimplemented(&cli_ctx, send_request, cq_.get())); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + Verifier().Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); +} + +// This class is for testing scenarios where RPCs are cancelled on the server +// by calling ServerContext::TryCancel(). Server uses AsyncNotifyWhenDone +// API to check for cancellation +class AsyncEnd2endServerTryCancelTest : public AsyncEnd2endTest { + protected: + typedef enum { + DO_NOT_CANCEL = 0, + CANCEL_BEFORE_PROCESSING, + CANCEL_DURING_PROCESSING, + CANCEL_AFTER_PROCESSING + } ServerTryCancelRequestPhase; + + // Helper for testing client-streaming RPCs which are cancelled on the server. + // Depending on the value of server_try_cancel parameter, this will test one + // of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading + // any messages from the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading + // messages from the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading all + // messages from the client (but before sending any status back to the + // client) + void TestClientStreamingServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + // Initiate the 'RequestStream' call on client + CompletionQueue cli_cq; + + std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream( + stub_->AsyncRequestStream(&cli_ctx, &recv_response, &cli_cq, tag(1))); + + // On the server, request to be notified of 'RequestStream' calls + // and receive the 'RequestStream' call just made by the client + srv_ctx.AsyncNotifyWhenDone(tag(11)); + service_->RequestRequestStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + std::thread t1([&cli_cq] { Verifier().Expect(1, true).Verify(&cli_cq); }); + Verifier().Expect(2, true).Verify(cq_.get()); + t1.join(); + + bool expected_server_cq_result = true; + bool expected_client_cq_result = true; + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + srv_ctx.TryCancel(); + Verifier().Expect(11, true).Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + + // Since cancellation is done before server reads any results, we know + // for sure that all server cq results will return false from this + // point forward + expected_server_cq_result = false; + expected_client_cq_result = false; + } + + bool ignore_client_cq_result = + (server_try_cancel == CANCEL_DURING_PROCESSING) || + (server_try_cancel == CANCEL_BEFORE_PROCESSING); + + std::thread cli_thread([&cli_cq, &cli_stream, &expected_client_cq_result, + &ignore_client_cq_result] { + EchoRequest send_request; + // Client sends 3 messages (tags 3, 4 and 5) + for (int tag_idx = 3; tag_idx <= 5; tag_idx++) { + send_request.set_message("Ping " + ToString(tag_idx)); + cli_stream->Write(send_request, tag(tag_idx)); + Verifier() + .Expect(tag_idx, expected_client_cq_result) + .Verify(&cli_cq, ignore_client_cq_result); + } + cli_stream->WritesDone(tag(6)); + // Ignore ok on WritesDone since cancel can affect it + Verifier() + .Expect(6, expected_client_cq_result) + .Verify(&cli_cq, ignore_client_cq_result); + }); + + bool ignore_cq_result = false; + bool want_done_tag = false; + std::thread* server_try_cancel_thd = nullptr; + + auto verif = Verifier(); + + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread([&srv_ctx] { srv_ctx.TryCancel(); }); + // Server will cancel the RPC in a parallel thread while reading the + // requests from the client. Since the cancellation can happen at anytime, + // some of the cq results (i.e those until cancellation) might be true but + // its non deterministic. So better to ignore the cq results + ignore_cq_result = true; + // Expect that we might possibly see the done tag that + // indicates cancellation completion in this case + want_done_tag = true; + verif.Expect(11, true); + } + + // Server reads 3 messages (tags 6, 7 and 8) + // But if want_done_tag is true, we might also see tag 11 + for (int tag_idx = 6; tag_idx <= 8; tag_idx++) { + srv_stream.Read(&recv_request, tag(tag_idx)); + // Note that we'll add something to the verifier and verify that + // something was seen, but it might be tag 11 and not what we + // just added + int got_tag = verif.Expect(tag_idx, expected_server_cq_result) + .Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == tag_idx) || (got_tag == 11 && want_done_tag)); + if (got_tag == 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + // Now get the other entry that we were waiting on + EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), tag_idx); + } + } + + cli_thread.join(); + + if (server_try_cancel_thd != nullptr) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + srv_ctx.TryCancel(); + want_done_tag = true; + verif.Expect(11, true); + } + + if (want_done_tag) { + verif.Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + } + + // The RPC has been cancelled at this point for sure (i.e irrespective of + // the value of `server_try_cancel` is). So, from this point forward, we + // know that cq results are supposed to return false on server. + + // Server sends the final message and cancelled status (but the RPC is + // already cancelled at this point. So we expect the operation to fail) + srv_stream.Finish(send_response, Status::CANCELLED, tag(9)); + Verifier().Expect(9, false).Verify(cq_.get()); + + // Client will see the cancellation + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(10, true).Verify(&cli_cq); + EXPECT_FALSE(recv_status.ok()); + EXPECT_EQ(::grpc::StatusCode::CANCELLED, recv_status.error_code()); + + cli_cq.Shutdown(); + void* dummy_tag; + bool dummy_ok; + while (cli_cq.Next(&dummy_tag, &dummy_ok)) { + } + } + + // Helper for testing server-streaming RPCs which are cancelled on the server. + // Depending on the value of server_try_cancel parameter, this will test one + // of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before sending + // any messages to the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while sending + // messages to the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after sending all + // messages to the client (but before sending any status back to the + // client) + void TestServerStreamingServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); + + send_request.set_message("Ping"); + // Initiate the 'ResponseStream' call on the client + CompletionQueue cli_cq; + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub_->AsyncResponseStream(&cli_ctx, send_request, &cli_cq, tag(1))); + // On the server, request to be notified of 'ResponseStream' calls and + // receive the call just made by the client + srv_ctx.AsyncNotifyWhenDone(tag(11)); + service_->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + std::thread t1([&cli_cq] { Verifier().Expect(1, true).Verify(&cli_cq); }); + Verifier().Expect(2, true).Verify(cq_.get()); + t1.join(); + + EXPECT_EQ(send_request.message(), recv_request.message()); + + bool expected_cq_result = true; + bool ignore_cq_result = false; + bool want_done_tag = false; + bool expected_client_cq_result = true; + bool ignore_client_cq_result = + (server_try_cancel != CANCEL_BEFORE_PROCESSING); + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + srv_ctx.TryCancel(); + Verifier().Expect(11, true).Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + + // We know for sure that all cq results will be false from this point + // since the server cancelled the RPC + expected_cq_result = false; + expected_client_cq_result = false; + } + + std::thread cli_thread([&cli_cq, &cli_stream, &expected_client_cq_result, + &ignore_client_cq_result] { + // Client attempts to read the three messages from the server + for (int tag_idx = 6; tag_idx <= 8; tag_idx++) { + EchoResponse recv_response; + cli_stream->Read(&recv_response, tag(tag_idx)); + Verifier() + .Expect(tag_idx, expected_client_cq_result) + .Verify(&cli_cq, ignore_client_cq_result); + } + }); + + std::thread* server_try_cancel_thd = nullptr; + + auto verif = Verifier(); + + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread([&srv_ctx] { srv_ctx.TryCancel(); }); + + // Server will cancel the RPC in a parallel thread while writing responses + // to the client. Since the cancellation can happen at anytime, some of + // the cq results (i.e those until cancellation) might be true but it is + // non deterministic. So better to ignore the cq results + ignore_cq_result = true; + // Expect that we might possibly see the done tag that + // indicates cancellation completion in this case + want_done_tag = true; + verif.Expect(11, true); + } + + // Server sends three messages (tags 3, 4 and 5) + // But if want_done tag is true, we might also see tag 11 + for (int tag_idx = 3; tag_idx <= 5; tag_idx++) { + send_response.set_message("Pong " + ToString(tag_idx)); + srv_stream.Write(send_response, tag(tag_idx)); + // Note that we'll add something to the verifier and verify that + // something was seen, but it might be tag 11 and not what we + // just added + int got_tag = verif.Expect(tag_idx, expected_cq_result) + .Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == tag_idx) || (got_tag == 11 && want_done_tag)); + if (got_tag == 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + // Now get the other entry that we were waiting on + EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), tag_idx); + } + } + + if (server_try_cancel_thd != nullptr) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + srv_ctx.TryCancel(); + want_done_tag = true; + verif.Expect(11, true); + } + + if (want_done_tag) { + verif.Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + } + + cli_thread.join(); + + // The RPC has been cancelled at this point for sure (i.e irrespective of + // the value of `server_try_cancel` is). So, from this point forward, we + // know that cq results are supposed to return false on server. + + // Server finishes the stream (but the RPC is already cancelled) + srv_stream.Finish(Status::CANCELLED, tag(9)); + Verifier().Expect(9, false).Verify(cq_.get()); + + // Client will see the cancellation + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(10, true).Verify(&cli_cq); + EXPECT_FALSE(recv_status.ok()); + EXPECT_EQ(::grpc::StatusCode::CANCELLED, recv_status.error_code()); + + cli_cq.Shutdown(); + void* dummy_tag; + bool dummy_ok; + while (cli_cq.Next(&dummy_tag, &dummy_ok)) { + } + } + + // Helper for testing bidirectinal-streaming RPCs which are cancelled on the + // server. + // + // Depending on the value of server_try_cancel parameter, this will + // test one of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading/ + // writing any messages from/to the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading + // messages from the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading all + // messages from the client (but before sending any status back to the + // client) + void TestBidiStreamingServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + // Initiate the call from the client side + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub_->AsyncBidiStream(&cli_ctx, cq_.get(), tag(1))); + + // On the server, request to be notified of the 'BidiStream' call and + // receive the call just made by the client + srv_ctx.AsyncNotifyWhenDone(tag(11)); + service_->RequestBidiStream(&srv_ctx, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + + auto verif = Verifier(); + + // Client sends the first and the only message + send_request.set_message("Ping"); + cli_stream->Write(send_request, tag(3)); + verif.Expect(3, true); + + bool expected_cq_result = true; + bool ignore_cq_result = false; + bool want_done_tag = false; + + int got_tag, got_tag2; + bool tag_3_done = false; + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + srv_ctx.TryCancel(); + verif.Expect(11, true); + // We know for sure that all server cq results will be false from + // this point since the server cancelled the RPC. However, we can't + // say for sure about the client + expected_cq_result = false; + ignore_cq_result = true; + + do { + got_tag = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT(((got_tag == 3) && !tag_3_done) || (got_tag == 11)); + if (got_tag == 3) { + tag_3_done = true; + } + } while (got_tag != 11); + EXPECT_TRUE(srv_ctx.IsCancelled()); + } + + std::thread* server_try_cancel_thd = nullptr; + + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread([&srv_ctx] { srv_ctx.TryCancel(); }); + + // Since server is going to cancel the RPC in a parallel thread, some of + // the cq results (i.e those until the cancellation) might be true. Since + // that number is non-deterministic, it is better to ignore the cq results + ignore_cq_result = true; + // Expect that we might possibly see the done tag that + // indicates cancellation completion in this case + want_done_tag = true; + verif.Expect(11, true); + } + + srv_stream.Read(&recv_request, tag(4)); + verif.Expect(4, expected_cq_result); + got_tag = tag_3_done ? 3 : verif.Next(cq_.get(), ignore_cq_result); + got_tag2 = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 3) || (got_tag == 4) || + (got_tag == 11 && want_done_tag)); + GPR_ASSERT((got_tag2 == 3) || (got_tag2 == 4) || + (got_tag2 == 11 && want_done_tag)); + // If we get 3 and 4, we don't need to wait for 11, but if + // we get 11, we should also clear 3 and 4 + if (got_tag + got_tag2 != 7) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + got_tag = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 3) || (got_tag == 4)); + } + + send_response.set_message("Pong"); + srv_stream.Write(send_response, tag(5)); + verif.Expect(5, expected_cq_result); + + cli_stream->Read(&recv_response, tag(6)); + verif.Expect(6, expected_cq_result); + got_tag = verif.Next(cq_.get(), ignore_cq_result); + got_tag2 = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 5) || (got_tag == 6) || + (got_tag == 11 && want_done_tag)); + GPR_ASSERT((got_tag2 == 5) || (got_tag2 == 6) || + (got_tag2 == 11 && want_done_tag)); + // If we get 5 and 6, we don't need to wait for 11, but if + // we get 11, we should also clear 5 and 6 + if (got_tag + got_tag2 != 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + got_tag = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 5) || (got_tag == 6)); + } + + // This is expected to succeed in all cases + cli_stream->WritesDone(tag(7)); + verif.Expect(7, true); + // TODO(vjpai): Consider whether the following is too flexible + // or whether it should just be reset to ignore_cq_result + bool ignore_cq_wd_result = + ignore_cq_result || (server_try_cancel == CANCEL_BEFORE_PROCESSING); + got_tag = verif.Next(cq_.get(), ignore_cq_wd_result); + GPR_ASSERT((got_tag == 7) || (got_tag == 11 && want_done_tag)); + if (got_tag == 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + // Now get the other entry that we were waiting on + EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_wd_result), 7); + } + + // This is expected to fail in all cases i.e for all values of + // server_try_cancel. This is because at this point, either there are no + // more msgs from the client (because client called WritesDone) or the RPC + // is cancelled on the server + srv_stream.Read(&recv_request, tag(8)); + verif.Expect(8, false); + got_tag = verif.Next(cq_.get(), ignore_cq_result); + GPR_ASSERT((got_tag == 8) || (got_tag == 11 && want_done_tag)); + if (got_tag == 11) { + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + // Now get the other entry that we were waiting on + EXPECT_EQ(verif.Next(cq_.get(), ignore_cq_result), 8); + } + + if (server_try_cancel_thd != nullptr) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + srv_ctx.TryCancel(); + want_done_tag = true; + verif.Expect(11, true); + } + + if (want_done_tag) { + verif.Verify(cq_.get()); + EXPECT_TRUE(srv_ctx.IsCancelled()); + want_done_tag = false; + } + + // The RPC has been cancelled at this point for sure (i.e irrespective of + // the value of `server_try_cancel` is). So, from this point forward, we + // know that cq results are supposed to return false on server. + + srv_stream.Finish(Status::CANCELLED, tag(9)); + Verifier().Expect(9, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(10, true).Verify(cq_.get()); + EXPECT_FALSE(recv_status.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, recv_status.error_code()); + } +}; + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelBefore) { + TestClientStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelDuring) { + TestClientStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ClientStreamingServerTryCancelAfter) { + TestClientStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelBefore) { + TestServerStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelDuring) { + TestServerStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerStreamingServerTryCancelAfter) { + TestServerStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelBefore) { + TestBidiStreamingServerCancel(CANCEL_BEFORE_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelDuring) { + TestBidiStreamingServerCancel(CANCEL_DURING_PROCESSING); +} + +TEST_P(AsyncEnd2endServerTryCancelTest, ServerBidiStreamingTryCancelAfter) { + TestBidiStreamingServerCancel(CANCEL_AFTER_PROCESSING); +} + +std::vector<TestScenario> CreateTestScenarios(bool /*test_secure*/, + bool test_message_size_limit) { + std::vector<TestScenario> scenarios; + std::vector<TString> credentials_types; + std::vector<TString> messages; + + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. + return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + + if (insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { + credentials_types.push_back(*sec); + } + GPR_ASSERT(!credentials_types.empty()); + + messages.push_back("Hello"); + if (test_message_size_limit) { + for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024; + k *= 32) { + TString big_msg; + for (size_t i = 0; i < k * 1024; ++i) { + char c = 'a' + (i % 26); + big_msg += c; + } + messages.push_back(big_msg); + } + if (!BuiltUnderMsan()) { + // 4MB message processing with SSL is very slow under msan + // (causes timeouts) and doesn't really increase the signal from tests. + // Reserve 100 bytes for other fields of the message proto. + messages.push_back( + TString(GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH - 100, 'a')); + } + } + + // TODO (sreek) Renable tests with health check service after the issue + // https://github.com/grpc/grpc/issues/11223 is resolved + for (auto health_check_service : {false}) { + for (auto msg = messages.begin(); msg != messages.end(); msg++) { + for (auto cred = credentials_types.begin(); + cred != credentials_types.end(); ++cred) { + scenarios.emplace_back(false, *cred, health_check_service, *msg); + } + if (insec_ok()) { + scenarios.emplace_back(true, kInsecureCredentialsType, + health_check_service, *msg); + } + } + } + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(AsyncEnd2end, AsyncEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(true, true))); +INSTANTIATE_TEST_SUITE_P(AsyncEnd2endServerTryCancel, + AsyncEnd2endServerTryCancelTest, + ::testing::ValuesIn(CreateTestScenarios(false, + false))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + // Change the backup poll interval from 5s to 100ms to speed up the + // ReconnectChannel test + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 100); + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/cfstream_test.cc b/contrib/libs/grpc/test/cpp/end2end/cfstream_test.cc new file mode 100644 index 0000000000..e6695982bd --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/cfstream_test.cc @@ -0,0 +1,496 @@ +/* + * + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/lib/iomgr/port.h" + +#include <algorithm> +#include <memory> +#include <mutex> +#include <random> +#include <thread> + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/atm.h> +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/health_check_service_interface.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <gtest/gtest.h> + +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/gpr/env.h" + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/debugger_macros.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GRPC_CFSTREAM +using grpc::ClientAsyncResponseReader; +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using grpc::testing::RequestParams; +using std::chrono::system_clock; + +namespace grpc { +namespace testing { +namespace { + +struct TestScenario { + TestScenario(const TString& creds_type, const TString& content) + : credentials_type(creds_type), message_content(content) {} + const TString credentials_type; + const TString message_content; +}; + +class CFStreamTest : public ::testing::TestWithParam<TestScenario> { + protected: + CFStreamTest() + : server_host_("grpctest"), + interface_("lo0"), + ipv4_address_("10.0.0.1") {} + + void DNSUp() { + std::ostringstream cmd; + // Add DNS entry for server_host_ in /etc/hosts + cmd << "echo '" << ipv4_address_ << " " << server_host_ + << " ' | sudo tee -a /etc/hosts"; + std::system(cmd.str().c_str()); + } + + void DNSDown() { + std::ostringstream cmd; + // Remove DNS entry for server_host_ in /etc/hosts + cmd << "sudo sed -i '.bak' '/" << server_host_ << "/d' /etc/hosts"; + std::system(cmd.str().c_str()); + } + + void InterfaceUp() { + std::ostringstream cmd; + cmd << "sudo /sbin/ifconfig " << interface_ << " alias " << ipv4_address_; + std::system(cmd.str().c_str()); + } + + void InterfaceDown() { + std::ostringstream cmd; + cmd << "sudo /sbin/ifconfig " << interface_ << " -alias " << ipv4_address_; + std::system(cmd.str().c_str()); + } + + void NetworkUp() { + gpr_log(GPR_DEBUG, "Bringing network up"); + InterfaceUp(); + DNSUp(); + } + + void NetworkDown() { + gpr_log(GPR_DEBUG, "Bringing network down"); + InterfaceDown(); + DNSDown(); + } + + void SetUp() override { + NetworkUp(); + grpc_init(); + StartServer(); + } + + void TearDown() override { + NetworkDown(); + StopServer(); + grpc_shutdown(); + } + + void StartServer() { + port_ = grpc_pick_unused_port_or_die(); + server_.reset(new ServerData(port_, GetParam().credentials_type)); + server_->Start(server_host_); + } + void StopServer() { server_->Shutdown(); } + + std::unique_ptr<grpc::testing::EchoTestService::Stub> BuildStub( + const std::shared_ptr<Channel>& channel) { + return grpc::testing::EchoTestService::NewStub(channel); + } + + std::shared_ptr<Channel> BuildChannel() { + std::ostringstream server_address; + server_address << server_host_ << ":" << port_; + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + return CreateCustomChannel(server_address.str(), channel_creds, args); + } + + int GetStreamID(ClientContext& context) { + int stream_id = 0; + grpc_call* call = context.c_call(); + if (call) { + grpc_chttp2_stream* stream = grpc_chttp2_stream_from_call(call); + if (stream) { + stream_id = stream->id; + } + } + return stream_id; + } + + void SendRpc( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + bool expect_success = false) { + auto response = std::unique_ptr<EchoResponse>(new EchoResponse()); + EchoRequest request; + auto& msg = GetParam().message_content; + request.set_message(msg); + ClientContext context; + Status status = stub->Echo(&context, request, response.get()); + int stream_id = GetStreamID(context); + if (status.ok()) { + gpr_log(GPR_DEBUG, "RPC with stream_id %d succeeded", stream_id); + EXPECT_EQ(msg, response->message()); + } else { + gpr_log(GPR_DEBUG, "RPC with stream_id %d failed: %s", stream_id, + status.error_message().c_str()); + } + if (expect_success) { + EXPECT_TRUE(status.ok()); + } + } + void SendAsyncRpc( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + RequestParams param = RequestParams()) { + EchoRequest request; + request.set_message(GetParam().message_content); + *request.mutable_param() = std::move(param); + AsyncClientCall* call = new AsyncClientCall; + + call->response_reader = + stub->PrepareAsyncEcho(&call->context, request, &cq_); + + call->response_reader->StartCall(); + call->response_reader->Finish(&call->reply, &call->status, (void*)call); + } + + void ShutdownCQ() { cq_.Shutdown(); } + + bool CQNext(void** tag, bool* ok) { + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10); + auto ret = cq_.AsyncNext(tag, ok, deadline); + if (ret == grpc::CompletionQueue::GOT_EVENT) { + return true; + } else if (ret == grpc::CompletionQueue::SHUTDOWN) { + return false; + } else { + GPR_ASSERT(ret == grpc::CompletionQueue::TIMEOUT); + // This can happen if we hit the Apple CFStream bug which results in the + // read stream hanging. We are ignoring hangs and timeouts, but these + // tests are still useful as they can catch memory memory corruptions, + // crashes and other bugs that don't result in test hang/timeout. + return false; + } + } + + bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(false /* try_to_connect */)) == + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool WaitForChannelReady(Channel* channel, int timeout_seconds = 10) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(true /* try_to_connect */)) != + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + struct AsyncClientCall { + EchoResponse reply; + ClientContext context; + Status status; + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader; + }; + + private: + struct ServerData { + int port_; + const TString creds_; + std::unique_ptr<Server> server_; + TestServiceImpl service_; + std::unique_ptr<std::thread> thread_; + bool server_ready_ = false; + + ServerData(int port, const TString& creds) + : port_(port), creds_(creds) {} + + void Start(const TString& server_host) { + gpr_log(GPR_INFO, "starting server on port %d", port_); + std::mutex mu; + std::unique_lock<std::mutex> lock(mu); + std::condition_variable cond; + thread_.reset(new std::thread( + std::bind(&ServerData::Serve, this, server_host, &mu, &cond))); + cond.wait(lock, [this] { return server_ready_; }); + server_ready_ = false; + gpr_log(GPR_INFO, "server startup complete"); + } + + void Serve(const TString& server_host, std::mutex* mu, + std::condition_variable* cond) { + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + auto server_creds = + GetCredentialsProvider()->GetServerCredentials(creds_); + builder.AddListeningPort(server_address.str(), server_creds); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + std::lock_guard<std::mutex> lock(*mu); + server_ready_ = true; + cond->notify_one(); + } + + void Shutdown(bool join = true) { + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + if (join) thread_->join(); + } + }; + + CompletionQueue cq_; + const TString server_host_; + const TString interface_; + const TString ipv4_address_; + std::unique_ptr<ServerData> server_; + int port_; +}; + +std::vector<TestScenario> CreateTestScenarios() { + std::vector<TestScenario> scenarios; + std::vector<TString> credentials_types; + std::vector<TString> messages; + + credentials_types.push_back(kInsecureCredentialsType); + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { + credentials_types.push_back(*sec); + } + + messages.push_back("🖖"); + for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024; k *= 32) { + TString big_msg; + for (size_t i = 0; i < k * 1024; ++i) { + char c = 'a' + (i % 26); + big_msg += c; + } + messages.push_back(big_msg); + } + for (auto cred = credentials_types.begin(); cred != credentials_types.end(); + ++cred) { + for (auto msg = messages.begin(); msg != messages.end(); msg++) { + scenarios.emplace_back(*cred, *msg); + } + } + + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(CFStreamTest, CFStreamTest, + ::testing::ValuesIn(CreateTestScenarios())); + +// gRPC should automatically detech network flaps (without enabling keepalives) +// when CFStream is enabled +TEST_P(CFStreamTest, NetworkTransition) { + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + SendRpc(stub, /*expect_success=*/true); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + std::atomic_bool shutdown{false}; + std::thread sender = std::thread([this, &stub, &shutdown]() { + while (true) { + if (shutdown.load()) { + return; + } + SendRpc(stub); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + }); + + // bring down network + NetworkDown(); + + // network going down should be detected by cfstream + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + + // bring network interface back up + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + NetworkUp(); + + // channel should reconnect + EXPECT_TRUE(WaitForChannelReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + shutdown.store(true); + sender.join(); +} + +// Network flaps while RPCs are in flight +TEST_P(CFStreamTest, NetworkFlapRpcsInFlight) { + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + std::atomic_int rpcs_sent{0}; + + // Channel should be in READY state after we send some RPCs + for (int i = 0; i < 10; ++i) { + RequestParams param; + param.set_skip_cancelled_check(true); + SendAsyncRpc(stub, param); + ++rpcs_sent; + } + EXPECT_TRUE(WaitForChannelReady(channel.get())); + + // Bring down the network + NetworkDown(); + + std::thread thd = std::thread([this, &rpcs_sent]() { + void* got_tag; + bool ok = false; + bool network_down = true; + int total_completions = 0; + + while (CQNext(&got_tag, &ok)) { + ++total_completions; + GPR_ASSERT(ok); + AsyncClientCall* call = static_cast<AsyncClientCall*>(got_tag); + int stream_id = GetStreamID(call->context); + if (!call->status.ok()) { + gpr_log(GPR_DEBUG, "RPC with stream_id %d failed with error: %s", + stream_id, call->status.error_message().c_str()); + // Bring network up when RPCs start failing + if (network_down) { + NetworkUp(); + network_down = false; + } + } else { + gpr_log(GPR_DEBUG, "RPC with stream_id %d succeeded", stream_id); + } + delete call; + } + // Remove line below and uncomment the following line after Apple CFStream + // bug has been fixed. + (void)rpcs_sent; + // EXPECT_EQ(total_completions, rpcs_sent); + }); + + for (int i = 0; i < 100; ++i) { + RequestParams param; + param.set_skip_cancelled_check(true); + SendAsyncRpc(stub, param); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + ++rpcs_sent; + } + + ShutdownCQ(); + + thd.join(); +} + +// Send a bunch of RPCs, some of which are expected to fail. +// We should get back a response for all RPCs +TEST_P(CFStreamTest, ConcurrentRpc) { + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + std::atomic_int rpcs_sent{0}; + std::thread thd = std::thread([this, &rpcs_sent]() { + void* got_tag; + bool ok = false; + int total_completions = 0; + + while (CQNext(&got_tag, &ok)) { + ++total_completions; + GPR_ASSERT(ok); + AsyncClientCall* call = static_cast<AsyncClientCall*>(got_tag); + int stream_id = GetStreamID(call->context); + if (!call->status.ok()) { + gpr_log(GPR_DEBUG, "RPC with stream_id %d failed with error: %s", + stream_id, call->status.error_message().c_str()); + // Bring network up when RPCs start failing + } else { + gpr_log(GPR_DEBUG, "RPC with stream_id %d succeeded", stream_id); + } + delete call; + } + // Remove line below and uncomment the following line after Apple CFStream + // bug has been fixed. + (void)rpcs_sent; + // EXPECT_EQ(total_completions, rpcs_sent); + }); + + for (int i = 0; i < 10; ++i) { + if (i % 3 == 0) { + RequestParams param; + ErrorStatus* error = param.mutable_expected_error(); + error->set_code(StatusCode::INTERNAL); + error->set_error_message("internal error"); + SendAsyncRpc(stub, param); + } else if (i % 5 == 0) { + RequestParams param; + param.set_echo_metadata(true); + DebugInfo* info = param.mutable_debug_info(); + info->add_stack_entries("stack_entry1"); + info->add_stack_entries("stack_entry2"); + info->set_detail("detailed debug info"); + SendAsyncRpc(stub, param); + } else { + SendAsyncRpc(stub); + } + ++rpcs_sent; + } + + ShutdownCQ(); + + thd.join(); +} + +} // namespace +} // namespace testing +} // namespace grpc +#endif // GRPC_CFSTREAM + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + gpr_setenv("grpc_cfstream", "1"); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/channelz_service_test.cc b/contrib/libs/grpc/test/cpp/end2end/channelz_service_test.cc new file mode 100644 index 0000000000..9c723bebb6 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/channelz_service_test.cc @@ -0,0 +1,767 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include <grpc/grpc.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/security/server_credentials.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include <grpcpp/ext/channelz_service_plugin.h> +#include "src/core/lib/gpr/env.h" +#include "src/proto/grpc/channelz/channelz.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +#include <gtest/gtest.h> + +using grpc::channelz::v1::GetChannelRequest; +using grpc::channelz::v1::GetChannelResponse; +using grpc::channelz::v1::GetServerRequest; +using grpc::channelz::v1::GetServerResponse; +using grpc::channelz::v1::GetServerSocketsRequest; +using grpc::channelz::v1::GetServerSocketsResponse; +using grpc::channelz::v1::GetServersRequest; +using grpc::channelz::v1::GetServersResponse; +using grpc::channelz::v1::GetSocketRequest; +using grpc::channelz::v1::GetSocketResponse; +using grpc::channelz::v1::GetSubchannelRequest; +using grpc::channelz::v1::GetSubchannelResponse; +using grpc::channelz::v1::GetTopChannelsRequest; +using grpc::channelz::v1::GetTopChannelsResponse; + +namespace grpc { +namespace testing { +namespace { + +// Proxy service supports N backends. Sends RPC to backend dictated by +// request->backend_channel_idx(). +class Proxy : public ::grpc::testing::EchoTestService::Service { + public: + Proxy() {} + + void AddChannelToBackend(const std::shared_ptr<Channel>& channel) { + stubs_.push_back(grpc::testing::EchoTestService::NewStub(channel)); + } + + Status Echo(ServerContext* server_context, const EchoRequest* request, + EchoResponse* response) override { + std::unique_ptr<ClientContext> client_context = + ClientContext::FromServerContext(*server_context); + size_t idx = request->param().backend_channel_idx(); + GPR_ASSERT(idx < stubs_.size()); + return stubs_[idx]->Echo(client_context.get(), *request, response); + } + + Status BidiStream(ServerContext* server_context, + ServerReaderWriter<EchoResponse, EchoRequest>* + stream_from_client) override { + EchoRequest request; + EchoResponse response; + std::unique_ptr<ClientContext> client_context = + ClientContext::FromServerContext(*server_context); + + // always use the first proxy for streaming + auto stream_to_backend = stubs_[0]->BidiStream(client_context.get()); + while (stream_from_client->Read(&request)) { + stream_to_backend->Write(request); + stream_to_backend->Read(&response); + stream_from_client->Write(response); + } + + stream_to_backend->WritesDone(); + return stream_to_backend->Finish(); + } + + private: + std::vector<std::unique_ptr<::grpc::testing::EchoTestService::Stub>> stubs_; +}; + +} // namespace + +class ChannelzServerTest : public ::testing::Test { + public: + ChannelzServerTest() {} + static void SetUpTestCase() { +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + } + void SetUp() override { + // ensure channel server is brought up on all severs we build. + ::grpc::channelz::experimental::InitChannelzService(); + + // We set up a proxy server with channelz enabled. + proxy_port_ = grpc_pick_unused_port_or_die(); + ServerBuilder proxy_builder; + TString proxy_server_address = "localhost:" + to_string(proxy_port_); + proxy_builder.AddListeningPort(proxy_server_address, + InsecureServerCredentials()); + // forces channelz and channel tracing to be enabled. + proxy_builder.AddChannelArgument(GRPC_ARG_ENABLE_CHANNELZ, 1); + proxy_builder.AddChannelArgument( + GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, 1024); + proxy_builder.RegisterService(&proxy_service_); + proxy_server_ = proxy_builder.BuildAndStart(); + } + + // Sets the proxy up to have an arbitrary number of backends. + void ConfigureProxy(size_t num_backends) { + backends_.resize(num_backends); + for (size_t i = 0; i < num_backends; ++i) { + // create a new backend. + backends_[i].port = grpc_pick_unused_port_or_die(); + ServerBuilder backend_builder; + TString backend_server_address = + "localhost:" + to_string(backends_[i].port); + backend_builder.AddListeningPort(backend_server_address, + InsecureServerCredentials()); + backends_[i].service.reset(new TestServiceImpl); + // ensure that the backend itself has channelz disabled. + backend_builder.AddChannelArgument(GRPC_ARG_ENABLE_CHANNELZ, 0); + backend_builder.RegisterService(backends_[i].service.get()); + backends_[i].server = backend_builder.BuildAndStart(); + // set up a channel to the backend. We ensure that this channel has + // channelz enabled since these channels (proxy outbound to backends) + // are the ones that our test will actually be validating. + ChannelArguments args; + args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 1); + args.SetInt(GRPC_ARG_MAX_CHANNEL_TRACE_EVENT_MEMORY_PER_NODE, 1024); + std::shared_ptr<Channel> channel_to_backend = ::grpc::CreateCustomChannel( + backend_server_address, InsecureChannelCredentials(), args); + proxy_service_.AddChannelToBackend(channel_to_backend); + } + } + + void ResetStubs() { + string target = "dns:localhost:" + to_string(proxy_port_); + ChannelArguments args; + // disable channelz. We only want to focus on proxy to backend outbound. + args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 0); + std::shared_ptr<Channel> channel = + ::grpc::CreateCustomChannel(target, InsecureChannelCredentials(), args); + channelz_stub_ = grpc::channelz::v1::Channelz::NewStub(channel); + echo_stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + std::unique_ptr<grpc::testing::EchoTestService::Stub> NewEchoStub() { + string target = "dns:localhost:" + to_string(proxy_port_); + ChannelArguments args; + // disable channelz. We only want to focus on proxy to backend outbound. + args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 0); + // This ensures that gRPC will not do connection sharing. + args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true); + std::shared_ptr<Channel> channel = + ::grpc::CreateCustomChannel(target, InsecureChannelCredentials(), args); + return grpc::testing::EchoTestService::NewStub(channel); + } + + void SendSuccessfulEcho(int channel_idx) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + request.mutable_param()->set_backend_channel_idx(channel_idx); + ClientContext context; + Status s = echo_stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + } + + void SendSuccessfulStream(int num_messages) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + ClientContext context; + auto stream_to_proxy = echo_stub_->BidiStream(&context); + for (int i = 0; i < num_messages; ++i) { + EXPECT_TRUE(stream_to_proxy->Write(request)); + EXPECT_TRUE(stream_to_proxy->Read(&response)); + } + stream_to_proxy->WritesDone(); + Status s = stream_to_proxy->Finish(); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + } + + void SendFailedEcho(int channel_idx) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + request.mutable_param()->set_backend_channel_idx(channel_idx); + auto* error = request.mutable_param()->mutable_expected_error(); + error->set_code(13); // INTERNAL + error->set_error_message("error"); + ClientContext context; + Status s = echo_stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + } + + // Uses GetTopChannels to return the channel_id of a particular channel, + // so that the unit tests may test GetChannel call. + intptr_t GetChannelId(int channel_idx) { + GetTopChannelsRequest request; + GetTopChannelsResponse response; + request.set_start_channel_id(0); + ClientContext context; + Status s = channelz_stub_->GetTopChannels(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_GT(response.channel_size(), channel_idx); + return response.channel(channel_idx).ref().channel_id(); + } + + static string to_string(const int number) { + std::stringstream strs; + strs << number; + return strs.str(); + } + + protected: + // package of data needed for each backend server. + struct BackendData { + std::unique_ptr<Server> server; + int port; + std::unique_ptr<TestServiceImpl> service; + }; + + std::unique_ptr<grpc::channelz::v1::Channelz::Stub> channelz_stub_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> echo_stub_; + + // proxy server to ping with channelz requests. + std::unique_ptr<Server> proxy_server_; + int proxy_port_; + Proxy proxy_service_; + + // backends. All implement the echo service. + std::vector<BackendData> backends_; +}; + +TEST_F(ChannelzServerTest, BasicTest) { + ResetStubs(); + ConfigureProxy(1); + GetTopChannelsRequest request; + GetTopChannelsResponse response; + request.set_start_channel_id(0); + ClientContext context; + Status s = channelz_stub_->GetTopChannels(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel_size(), 1); +} + +TEST_F(ChannelzServerTest, HighStartId) { + ResetStubs(); + ConfigureProxy(1); + GetTopChannelsRequest request; + GetTopChannelsResponse response; + request.set_start_channel_id(10000); + ClientContext context; + Status s = channelz_stub_->GetTopChannels(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel_size(), 0); +} + +TEST_F(ChannelzServerTest, SuccessfulRequestTest) { + ResetStubs(); + ConfigureProxy(1); + SendSuccessfulEcho(0); + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(0)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), 1); + EXPECT_EQ(response.channel().data().calls_succeeded(), 1); + EXPECT_EQ(response.channel().data().calls_failed(), 0); +} + +TEST_F(ChannelzServerTest, FailedRequestTest) { + ResetStubs(); + ConfigureProxy(1); + SendFailedEcho(0); + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(0)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), 1); + EXPECT_EQ(response.channel().data().calls_succeeded(), 0); + EXPECT_EQ(response.channel().data().calls_failed(), 1); +} + +TEST_F(ChannelzServerTest, ManyRequestsTest) { + ResetStubs(); + ConfigureProxy(1); + // send some RPCs + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(0); + } + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(0)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), + kNumSuccess + kNumFailed); + EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess); + EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed); +} + +TEST_F(ChannelzServerTest, ManyChannels) { + ResetStubs(); + const int kNumChannels = 4; + ConfigureProxy(kNumChannels); + GetTopChannelsRequest request; + GetTopChannelsResponse response; + request.set_start_channel_id(0); + ClientContext context; + Status s = channelz_stub_->GetTopChannels(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel_size(), kNumChannels); +} + +TEST_F(ChannelzServerTest, ManyRequestsManyChannels) { + ResetStubs(); + const int kNumChannels = 4; + ConfigureProxy(kNumChannels); + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + SendSuccessfulEcho(2); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(1); + SendFailedEcho(2); + } + + // the first channel saw only successes + { + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(0)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), kNumSuccess); + EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess); + EXPECT_EQ(response.channel().data().calls_failed(), 0); + } + + // the second channel saw only failures + { + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(1)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), kNumFailed); + EXPECT_EQ(response.channel().data().calls_succeeded(), 0); + EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed); + } + + // the third channel saw both + { + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(2)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), + kNumSuccess + kNumFailed); + EXPECT_EQ(response.channel().data().calls_succeeded(), kNumSuccess); + EXPECT_EQ(response.channel().data().calls_failed(), kNumFailed); + } + + // the fourth channel saw nothing + { + GetChannelRequest request; + GetChannelResponse response; + request.set_channel_id(GetChannelId(3)); + ClientContext context; + Status s = channelz_stub_->GetChannel(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.channel().data().calls_started(), 0); + EXPECT_EQ(response.channel().data().calls_succeeded(), 0); + EXPECT_EQ(response.channel().data().calls_failed(), 0); + } +} + +TEST_F(ChannelzServerTest, ManySubchannels) { + ResetStubs(); + const int kNumChannels = 4; + ConfigureProxy(kNumChannels); + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + SendSuccessfulEcho(2); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(1); + SendFailedEcho(2); + } + GetTopChannelsRequest gtc_request; + GetTopChannelsResponse gtc_response; + gtc_request.set_start_channel_id(0); + ClientContext context; + Status s = + channelz_stub_->GetTopChannels(&context, gtc_request, >c_response); + EXPECT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(gtc_response.channel_size(), kNumChannels); + for (int i = 0; i < gtc_response.channel_size(); ++i) { + // if the channel sent no RPCs, then expect no subchannels to have been + // created. + if (gtc_response.channel(i).data().calls_started() == 0) { + EXPECT_EQ(gtc_response.channel(i).subchannel_ref_size(), 0); + continue; + } + // The resolver must return at least one address. + ASSERT_GT(gtc_response.channel(i).subchannel_ref_size(), 0); + GetSubchannelRequest gsc_request; + GetSubchannelResponse gsc_response; + gsc_request.set_subchannel_id( + gtc_response.channel(i).subchannel_ref(0).subchannel_id()); + ClientContext context; + Status s = + channelz_stub_->GetSubchannel(&context, gsc_request, &gsc_response); + EXPECT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(gtc_response.channel(i).data().calls_started(), + gsc_response.subchannel().data().calls_started()); + EXPECT_EQ(gtc_response.channel(i).data().calls_succeeded(), + gsc_response.subchannel().data().calls_succeeded()); + EXPECT_EQ(gtc_response.channel(i).data().calls_failed(), + gsc_response.subchannel().data().calls_failed()); + } +} + +TEST_F(ChannelzServerTest, BasicServerTest) { + ResetStubs(); + ConfigureProxy(1); + GetServersRequest request; + GetServersResponse response; + request.set_start_server_id(0); + ClientContext context; + Status s = channelz_stub_->GetServers(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.server_size(), 1); +} + +TEST_F(ChannelzServerTest, BasicGetServerTest) { + ResetStubs(); + ConfigureProxy(1); + GetServersRequest get_servers_request; + GetServersResponse get_servers_response; + get_servers_request.set_start_server_id(0); + ClientContext get_servers_context; + Status s = channelz_stub_->GetServers( + &get_servers_context, get_servers_request, &get_servers_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_servers_response.server_size(), 1); + GetServerRequest get_server_request; + GetServerResponse get_server_response; + get_server_request.set_server_id( + get_servers_response.server(0).ref().server_id()); + ClientContext get_server_context; + s = channelz_stub_->GetServer(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_servers_response.server(0).ref().server_id(), + get_server_response.server().ref().server_id()); +} + +TEST_F(ChannelzServerTest, ServerCallTest) { + ResetStubs(); + ConfigureProxy(1); + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(0); + } + GetServersRequest request; + GetServersResponse response; + request.set_start_server_id(0); + ClientContext context; + Status s = channelz_stub_->GetServers(&context, request, &response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(response.server_size(), 1); + EXPECT_EQ(response.server(0).data().calls_succeeded(), kNumSuccess); + EXPECT_EQ(response.server(0).data().calls_failed(), kNumFailed); + // This is success+failure+1 because the call that retrieved this information + // will be counted as started. It will not track success/failure until after + // it has returned, so that is not included in the response. + EXPECT_EQ(response.server(0).data().calls_started(), + kNumSuccess + kNumFailed + 1); +} + +TEST_F(ChannelzServerTest, ManySubchannelsAndSockets) { + ResetStubs(); + const int kNumChannels = 4; + ConfigureProxy(kNumChannels); + const int kNumSuccess = 10; + const int kNumFailed = 11; + for (int i = 0; i < kNumSuccess; ++i) { + SendSuccessfulEcho(0); + SendSuccessfulEcho(2); + } + for (int i = 0; i < kNumFailed; ++i) { + SendFailedEcho(1); + SendFailedEcho(2); + } + GetTopChannelsRequest gtc_request; + GetTopChannelsResponse gtc_response; + gtc_request.set_start_channel_id(0); + ClientContext context; + Status s = + channelz_stub_->GetTopChannels(&context, gtc_request, >c_response); + EXPECT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(gtc_response.channel_size(), kNumChannels); + for (int i = 0; i < gtc_response.channel_size(); ++i) { + // if the channel sent no RPCs, then expect no subchannels to have been + // created. + if (gtc_response.channel(i).data().calls_started() == 0) { + EXPECT_EQ(gtc_response.channel(i).subchannel_ref_size(), 0); + continue; + } + // The resolver must return at least one address. + ASSERT_GT(gtc_response.channel(i).subchannel_ref_size(), 0); + // First grab the subchannel + GetSubchannelRequest get_subchannel_req; + GetSubchannelResponse get_subchannel_resp; + get_subchannel_req.set_subchannel_id( + gtc_response.channel(i).subchannel_ref(0).subchannel_id()); + ClientContext get_subchannel_ctx; + Status s = channelz_stub_->GetSubchannel( + &get_subchannel_ctx, get_subchannel_req, &get_subchannel_resp); + EXPECT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(get_subchannel_resp.subchannel().socket_ref_size(), 1); + // Now grab the socket. + GetSocketRequest get_socket_req; + GetSocketResponse get_socket_resp; + ClientContext get_socket_ctx; + get_socket_req.set_socket_id( + get_subchannel_resp.subchannel().socket_ref(0).socket_id()); + s = channelz_stub_->GetSocket(&get_socket_ctx, get_socket_req, + &get_socket_resp); + EXPECT_TRUE( + get_subchannel_resp.subchannel().socket_ref(0).name().find("http")); + EXPECT_TRUE(s.ok()) << s.error_message(); + // calls started == streams started AND stream succeeded. Since none of + // these RPCs were canceled, all of the streams will succeeded even though + // the RPCs they represent might have failed. + EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(), + get_socket_resp.socket().data().streams_started()); + EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(), + get_socket_resp.socket().data().streams_succeeded()); + // All of the calls were unary, so calls started == messages sent. + EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_started(), + get_socket_resp.socket().data().messages_sent()); + // We only get responses when the RPC was successful, so + // calls succeeded == messages received. + EXPECT_EQ(get_subchannel_resp.subchannel().data().calls_succeeded(), + get_socket_resp.socket().data().messages_received()); + } +} + +TEST_F(ChannelzServerTest, StreamingRPC) { + ResetStubs(); + ConfigureProxy(1); + const int kNumMessages = 5; + SendSuccessfulStream(kNumMessages); + // Get the channel + GetChannelRequest get_channel_request; + GetChannelResponse get_channel_response; + get_channel_request.set_channel_id(GetChannelId(0)); + ClientContext get_channel_context; + Status s = channelz_stub_->GetChannel( + &get_channel_context, get_channel_request, &get_channel_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_channel_response.channel().data().calls_started(), 1); + EXPECT_EQ(get_channel_response.channel().data().calls_succeeded(), 1); + EXPECT_EQ(get_channel_response.channel().data().calls_failed(), 0); + // Get the subchannel + ASSERT_GT(get_channel_response.channel().subchannel_ref_size(), 0); + GetSubchannelRequest get_subchannel_request; + GetSubchannelResponse get_subchannel_response; + ClientContext get_subchannel_context; + get_subchannel_request.set_subchannel_id( + get_channel_response.channel().subchannel_ref(0).subchannel_id()); + s = channelz_stub_->GetSubchannel(&get_subchannel_context, + get_subchannel_request, + &get_subchannel_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_subchannel_response.subchannel().data().calls_started(), 1); + EXPECT_EQ(get_subchannel_response.subchannel().data().calls_succeeded(), 1); + EXPECT_EQ(get_subchannel_response.subchannel().data().calls_failed(), 0); + // Get the socket + ASSERT_GT(get_subchannel_response.subchannel().socket_ref_size(), 0); + GetSocketRequest get_socket_request; + GetSocketResponse get_socket_response; + ClientContext get_socket_context; + get_socket_request.set_socket_id( + get_subchannel_response.subchannel().socket_ref(0).socket_id()); + EXPECT_TRUE( + get_subchannel_response.subchannel().socket_ref(0).name().find("http")); + s = channelz_stub_->GetSocket(&get_socket_context, get_socket_request, + &get_socket_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_socket_response.socket().data().streams_started(), 1); + EXPECT_EQ(get_socket_response.socket().data().streams_succeeded(), 1); + EXPECT_EQ(get_socket_response.socket().data().streams_failed(), 0); + EXPECT_EQ(get_socket_response.socket().data().messages_sent(), kNumMessages); + EXPECT_EQ(get_socket_response.socket().data().messages_received(), + kNumMessages); +} + +TEST_F(ChannelzServerTest, GetServerSocketsTest) { + ResetStubs(); + ConfigureProxy(1); + GetServersRequest get_server_request; + GetServersResponse get_server_response; + get_server_request.set_start_server_id(0); + ClientContext get_server_context; + Status s = channelz_stub_->GetServers(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_response.server_size(), 1); + GetServerSocketsRequest get_server_sockets_request; + GetServerSocketsResponse get_server_sockets_response; + get_server_sockets_request.set_server_id( + get_server_response.server(0).ref().server_id()); + get_server_sockets_request.set_start_socket_id(0); + ClientContext get_server_sockets_context; + s = channelz_stub_->GetServerSockets(&get_server_sockets_context, + get_server_sockets_request, + &get_server_sockets_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_sockets_response.socket_ref_size(), 1); + EXPECT_TRUE(get_server_sockets_response.socket_ref(0).name().find("http")); +} + +TEST_F(ChannelzServerTest, GetServerSocketsPaginationTest) { + ResetStubs(); + ConfigureProxy(1); + std::vector<std::unique_ptr<grpc::testing::EchoTestService::Stub>> stubs; + const int kNumServerSocketsCreated = 20; + for (int i = 0; i < kNumServerSocketsCreated; ++i) { + stubs.push_back(NewEchoStub()); + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + request.mutable_param()->set_backend_channel_idx(0); + ClientContext context; + Status s = stubs.back()->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + } + GetServersRequest get_server_request; + GetServersResponse get_server_response; + get_server_request.set_start_server_id(0); + ClientContext get_server_context; + Status s = channelz_stub_->GetServers(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_response.server_size(), 1); + // Make a request that gets all of the serversockets + { + GetServerSocketsRequest get_server_sockets_request; + GetServerSocketsResponse get_server_sockets_response; + get_server_sockets_request.set_server_id( + get_server_response.server(0).ref().server_id()); + get_server_sockets_request.set_start_socket_id(0); + ClientContext get_server_sockets_context; + s = channelz_stub_->GetServerSockets(&get_server_sockets_context, + get_server_sockets_request, + &get_server_sockets_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + // We add one to account the channelz stub that will end up creating + // a serversocket. + EXPECT_EQ(get_server_sockets_response.socket_ref_size(), + kNumServerSocketsCreated + 1); + EXPECT_TRUE(get_server_sockets_response.end()); + } + // Now we make a request that exercises pagination. + { + GetServerSocketsRequest get_server_sockets_request; + GetServerSocketsResponse get_server_sockets_response; + get_server_sockets_request.set_server_id( + get_server_response.server(0).ref().server_id()); + get_server_sockets_request.set_start_socket_id(0); + const int kMaxResults = 10; + get_server_sockets_request.set_max_results(kMaxResults); + ClientContext get_server_sockets_context; + s = channelz_stub_->GetServerSockets(&get_server_sockets_context, + get_server_sockets_request, + &get_server_sockets_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_sockets_response.socket_ref_size(), kMaxResults); + EXPECT_FALSE(get_server_sockets_response.end()); + } +} + +TEST_F(ChannelzServerTest, GetServerListenSocketsTest) { + ResetStubs(); + ConfigureProxy(1); + GetServersRequest get_server_request; + GetServersResponse get_server_response; + get_server_request.set_start_server_id(0); + ClientContext get_server_context; + Status s = channelz_stub_->GetServers(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_response.server_size(), 1); + EXPECT_EQ(get_server_response.server(0).listen_socket_size(), 1); + GetSocketRequest get_socket_request; + GetSocketResponse get_socket_response; + get_socket_request.set_socket_id( + get_server_response.server(0).listen_socket(0).socket_id()); + EXPECT_TRUE( + get_server_response.server(0).listen_socket(0).name().find("http")); + ClientContext get_socket_context; + s = channelz_stub_->GetSocket(&get_socket_context, get_socket_request, + &get_socket_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/client_callback_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/client_callback_end2end_test.cc new file mode 100644 index 0000000000..12cb40a953 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/client_callback_end2end_test.cc @@ -0,0 +1,1565 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/support/client_callback.h> +#include <gtest/gtest.h> + +#include <algorithm> +#include <condition_variable> +#include <functional> +#include <mutex> +#include <sstream> +#include <thread> + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" +#include "test/cpp/util/test_credentials_provider.h" + +// MAYBE_SKIP_TEST is a macro to determine if this particular test configuration +// should be skipped based on a decision made at SetUp time. In particular, any +// callback tests can only be run if the iomgr can run in the background or if +// the transport is in-process. +#define MAYBE_SKIP_TEST \ + do { \ + if (do_not_test_) { \ + return; \ + } \ + } while (0) + +namespace grpc { +namespace testing { +namespace { + +enum class Protocol { INPROC, TCP }; + +class TestScenario { + public: + TestScenario(bool serve_callback, Protocol protocol, bool intercept, + const TString& creds_type) + : callback_server(serve_callback), + protocol(protocol), + use_interceptors(intercept), + credentials_type(creds_type) {} + void Log() const; + bool callback_server; + Protocol protocol; + bool use_interceptors; + const TString credentials_type; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{callback_server=" + << (scenario.callback_server ? "true" : "false") << ",protocol=" + << (scenario.protocol == Protocol::INPROC ? "INPROC" : "TCP") + << ",intercept=" << (scenario.use_interceptors ? "true" : "false") + << ",creds=" << scenario.credentials_type << "}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_DEBUG, "%s", out.str().c_str()); +} + +class ClientCallbackEnd2endTest + : public ::testing::TestWithParam<TestScenario> { + protected: + ClientCallbackEnd2endTest() { GetParam().Log(); } + + void SetUp() override { + ServerBuilder builder; + + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + // TODO(vjpai): Support testing of AuthMetadataProcessor + + if (GetParam().protocol == Protocol::TCP) { + picked_port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << picked_port_; + builder.AddListeningPort(server_address_.str(), server_creds); + } + if (!GetParam().callback_server) { + builder.RegisterService(&service_); + } else { + builder.RegisterService(&callback_service_); + } + + if (GetParam().use_interceptors) { + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + // Add 20 dummy server interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + } + + server_ = builder.BuildAndStart(); + is_server_started_ = true; + if (GetParam().protocol == Protocol::TCP && + !grpc_iomgr_run_in_background()) { + do_not_test_ = true; + } + } + + void ResetStub() { + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + switch (GetParam().protocol) { + case Protocol::TCP: + if (!GetParam().use_interceptors) { + channel_ = ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args); + } else { + channel_ = CreateCustomChannelWithInterceptors( + server_address_.str(), channel_creds, args, + CreateDummyClientInterceptors()); + } + break; + case Protocol::INPROC: + if (!GetParam().use_interceptors) { + channel_ = server_->InProcessChannel(args); + } else { + channel_ = server_->experimental().InProcessChannelWithInterceptors( + args, CreateDummyClientInterceptors()); + } + break; + default: + assert(false); + } + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + generic_stub_.reset(new GenericStub(channel_)); + DummyInterceptor::Reset(); + } + + void TearDown() override { + if (is_server_started_) { + // Although we would normally do an explicit shutdown, the server + // should also work correctly with just a destructor call. The regular + // end2end test uses explicit shutdown, so let this one just do reset. + server_.reset(); + } + if (picked_port_ > 0) { + grpc_recycle_unused_port(picked_port_); + } + } + + void SendRpcs(int num_rpcs, bool with_binary_metadata) { + TString test_string(""); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + + test_string += "Hello world. "; + request.set_message(test_string); + TString val; + if (with_binary_metadata) { + request.mutable_param()->set_echo_metadata(true); + char bytes[8] = {'\0', '\1', '\2', '\3', + '\4', '\5', '\6', static_cast<char>(i)}; + val = TString(bytes, 8); + cli_ctx.AddMetadata("custom-bin", val); + } + + cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->experimental_async()->Echo( + &cli_ctx, &request, &response, + [&cli_ctx, &request, &response, &done, &mu, &cv, val, + with_binary_metadata](Status s) { + GPR_ASSERT(s.ok()); + + EXPECT_EQ(request.message(), response.message()); + if (with_binary_metadata) { + EXPECT_EQ( + 1u, cli_ctx.GetServerTrailingMetadata().count("custom-bin")); + EXPECT_EQ(val, ToString(cli_ctx.GetServerTrailingMetadata() + .find("custom-bin") + ->second)); + } + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } + } + } + + void SendRpcsGeneric(int num_rpcs, bool maybe_except) { + const TString kMethodName("/grpc.testing.EchoTestService/Echo"); + TString test_string(""); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest request; + std::unique_ptr<ByteBuffer> send_buf; + ByteBuffer recv_buf; + ClientContext cli_ctx; + + test_string += "Hello world. "; + request.set_message(test_string); + send_buf = SerializeToByteBuffer(&request); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + generic_stub_->experimental().UnaryCall( + &cli_ctx, kMethodName, send_buf.get(), &recv_buf, + [&request, &recv_buf, &done, &mu, &cv, maybe_except](Status s) { + GPR_ASSERT(s.ok()); + + EchoResponse response; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buf, &response)); + EXPECT_EQ(request.message(), response.message()); + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); +#if GRPC_ALLOW_EXCEPTIONS + if (maybe_except) { + throw - 1; + } +#else + GPR_ASSERT(!maybe_except); +#endif + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } + } + } + + void SendGenericEchoAsBidi(int num_rpcs, int reuses, bool do_writes_done) { + const TString kMethodName("/grpc.testing.EchoTestService/Echo"); + TString test_string(""); + for (int i = 0; i < num_rpcs; i++) { + test_string += "Hello world. "; + class Client : public grpc::experimental::ClientBidiReactor<ByteBuffer, + ByteBuffer> { + public: + Client(ClientCallbackEnd2endTest* test, const TString& method_name, + const TString& test_str, int reuses, bool do_writes_done) + : reuses_remaining_(reuses), do_writes_done_(do_writes_done) { + activate_ = [this, test, method_name, test_str] { + if (reuses_remaining_ > 0) { + cli_ctx_.reset(new ClientContext); + reuses_remaining_--; + test->generic_stub_->experimental().PrepareBidiStreamingCall( + cli_ctx_.get(), method_name, this); + request_.set_message(test_str); + send_buf_ = SerializeToByteBuffer(&request_); + StartWrite(send_buf_.get()); + StartRead(&recv_buf_); + StartCall(); + } else { + std::unique_lock<std::mutex> l(mu_); + done_ = true; + cv_.notify_one(); + } + }; + activate_(); + } + void OnWriteDone(bool /*ok*/) override { + if (do_writes_done_) { + StartWritesDone(); + } + } + void OnReadDone(bool /*ok*/) override { + EchoResponse response; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response)); + EXPECT_EQ(request_.message(), response.message()); + }; + void OnDone(const Status& s) override { + EXPECT_TRUE(s.ok()); + activate_(); + } + void Await() { + std::unique_lock<std::mutex> l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + EchoRequest request_; + std::unique_ptr<ByteBuffer> send_buf_; + ByteBuffer recv_buf_; + std::unique_ptr<ClientContext> cli_ctx_; + int reuses_remaining_; + std::function<void()> activate_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + const bool do_writes_done_; + }; + + Client rpc(this, kMethodName, test_string, reuses, do_writes_done); + + rpc.Await(); + } + } + bool do_not_test_{false}; + bool is_server_started_{false}; + int picked_port_{0}; + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<grpc::GenericStub> generic_stub_; + TestServiceImpl service_; + CallbackTestServiceImpl callback_service_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; +}; + +TEST_P(ClientCallbackEnd2endTest, SimpleRpc) { + MAYBE_SKIP_TEST; + ResetStub(); + SendRpcs(1, false); +} + +TEST_P(ClientCallbackEnd2endTest, SimpleRpcExpectedError) { + MAYBE_SKIP_TEST; + ResetStub(); + + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + ErrorStatus error_status; + + request.set_message("Hello failure"); + error_status.set_code(1); // CANCELLED + error_status.set_error_message("cancel error message"); + *request.mutable_param()->mutable_expected_error() = error_status; + + std::mutex mu; + std::condition_variable cv; + bool done = false; + + stub_->experimental_async()->Echo( + &cli_ctx, &request, &response, + [&response, &done, &mu, &cv, &error_status](Status s) { + EXPECT_EQ("", response.message()); + EXPECT_EQ(error_status.code(), s.error_code()); + EXPECT_EQ(error_status.error_message(), s.error_message()); + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } +} + +TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLockNested) { + MAYBE_SKIP_TEST; + ResetStub(); + + // The request/response state associated with an RPC and the synchronization + // variables needed to notify its completion. + struct RpcState { + std::mutex mu; + std::condition_variable cv; + bool done = false; + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + + RpcState() = default; + ~RpcState() { + // Grab the lock to prevent destruction while another is still holding + // lock + std::lock_guard<std::mutex> lock(mu); + } + }; + std::vector<RpcState> rpc_state(3); + for (size_t i = 0; i < rpc_state.size(); i++) { + TString message = "Hello locked world"; + message += ToString(i); + rpc_state[i].request.set_message(message); + } + + // Grab a lock and then start an RPC whose callback grabs the same lock and + // then calls this function to start the next RPC under lock (up to a limit of + // the size of the rpc_state vector). + std::function<void(int)> nested_call = [this, &nested_call, + &rpc_state](int index) { + std::lock_guard<std::mutex> l(rpc_state[index].mu); + stub_->experimental_async()->Echo( + &rpc_state[index].cli_ctx, &rpc_state[index].request, + &rpc_state[index].response, + [index, &nested_call, &rpc_state](Status s) { + std::lock_guard<std::mutex> l1(rpc_state[index].mu); + EXPECT_TRUE(s.ok()); + rpc_state[index].done = true; + rpc_state[index].cv.notify_all(); + // Call the next level of nesting if possible + if (index + 1 < rpc_state.size()) { + nested_call(index + 1); + } + }); + }; + + nested_call(0); + + // Wait for completion notifications from all RPCs. Order doesn't matter. + for (RpcState& state : rpc_state) { + std::unique_lock<std::mutex> l(state.mu); + while (!state.done) { + state.cv.wait(l); + } + EXPECT_EQ(state.request.message(), state.response.message()); + } +} + +TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLock) { + MAYBE_SKIP_TEST; + ResetStub(); + std::mutex mu; + std::condition_variable cv; + bool done = false; + EchoRequest request; + request.set_message("Hello locked world."); + EchoResponse response; + ClientContext cli_ctx; + { + std::lock_guard<std::mutex> l(mu); + stub_->experimental_async()->Echo( + &cli_ctx, &request, &response, + [&mu, &cv, &done, &request, &response](Status s) { + std::lock_guard<std::mutex> l(mu); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(request.message(), response.message()); + done = true; + cv.notify_one(); + }); + } + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } +} + +TEST_P(ClientCallbackEnd2endTest, SequentialRpcs) { + MAYBE_SKIP_TEST; + ResetStub(); + SendRpcs(10, false); +} + +TEST_P(ClientCallbackEnd2endTest, SendClientInitialMetadata) { + MAYBE_SKIP_TEST; + ResetStub(); + SimpleRequest request; + SimpleResponse response; + ClientContext cli_ctx; + + cli_ctx.AddMetadata(kCheckClientInitialMetadataKey, + kCheckClientInitialMetadataVal); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->experimental_async()->CheckClientInitialMetadata( + &cli_ctx, &request, &response, [&done, &mu, &cv](Status s) { + GPR_ASSERT(s.ok()); + + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } +} + +TEST_P(ClientCallbackEnd2endTest, SimpleRpcWithBinaryMetadata) { + MAYBE_SKIP_TEST; + ResetStub(); + SendRpcs(1, true); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialRpcsWithVariedBinaryMetadataValue) { + MAYBE_SKIP_TEST; + ResetStub(); + SendRpcs(10, true); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcs) { + MAYBE_SKIP_TEST; + ResetStub(); + SendRpcsGeneric(10, false); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidi) { + MAYBE_SKIP_TEST; + ResetStub(); + SendGenericEchoAsBidi(10, 1, /*do_writes_done=*/true); +} + +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidiWithReactorReuse) { + MAYBE_SKIP_TEST; + ResetStub(); + SendGenericEchoAsBidi(10, 10, /*do_writes_done=*/true); +} + +TEST_P(ClientCallbackEnd2endTest, GenericRpcNoWritesDone) { + MAYBE_SKIP_TEST; + ResetStub(); + SendGenericEchoAsBidi(1, 1, /*do_writes_done=*/false); +} + +#if GRPC_ALLOW_EXCEPTIONS +TEST_P(ClientCallbackEnd2endTest, ExceptingRpc) { + MAYBE_SKIP_TEST; + ResetStub(); + SendRpcsGeneric(10, true); +} +#endif + +TEST_P(ClientCallbackEnd2endTest, MultipleRpcsWithVariedBinaryMetadataValue) { + MAYBE_SKIP_TEST; + ResetStub(); + std::vector<std::thread> threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back([this] { SendRpcs(10, true); }); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +TEST_P(ClientCallbackEnd2endTest, MultipleRpcs) { + MAYBE_SKIP_TEST; + ResetStub(); + std::vector<std::thread> threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back([this] { SendRpcs(10, false); }); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +TEST_P(ClientCallbackEnd2endTest, CancelRpcBeforeStart) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.TryCancel(); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->experimental_async()->Echo( + &context, &request, &response, [&response, &done, &mu, &cv](Status s) { + EXPECT_EQ("", response.message()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, RequestEchoServerCancel) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.AddMetadata(kServerTryCancelRequest, + ToString(CANCEL_BEFORE_PROCESSING)); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->experimental_async()->Echo( + &context, &request, &response, [&done, &mu, &cv](Status s) { + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } +} + +struct ClientCancelInfo { + bool cancel{false}; + int ops_before_cancel; + + ClientCancelInfo() : cancel{false} {} + explicit ClientCancelInfo(int ops) : cancel{true}, ops_before_cancel{ops} {} +}; + +class WriteClient : public grpc::experimental::ClientWriteReactor<EchoRequest> { + public: + WriteClient(grpc::testing::EchoTestService::Stub* stub, + ServerTryCancelRequestPhase server_try_cancel, + int num_msgs_to_send, ClientCancelInfo client_cancel = {}) + : server_try_cancel_(server_try_cancel), + num_msgs_to_send_(num_msgs_to_send), + client_cancel_{client_cancel} { + TString msg{"Hello server."}; + for (int i = 0; i < num_msgs_to_send; i++) { + desired_ += msg; + } + if (server_try_cancel != DO_NOT_CANCEL) { + // Send server_try_cancel value in the client metadata + context_.AddMetadata(kServerTryCancelRequest, + ToString(server_try_cancel)); + } + context_.set_initial_metadata_corked(true); + stub->experimental_async()->RequestStream(&context_, &response_, this); + StartCall(); + request_.set_message(msg); + MaybeWrite(); + } + void OnWriteDone(bool ok) override { + if (ok) { + num_msgs_sent_++; + MaybeWrite(); + } + } + void OnDone(const Status& s) override { + gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent_); + int num_to_send = + (client_cancel_.cancel) + ? std::min(num_msgs_to_send_, client_cancel_.ops_before_cancel) + : num_msgs_to_send_; + switch (server_try_cancel_) { + case CANCEL_BEFORE_PROCESSING: + case CANCEL_DURING_PROCESSING: + // If the RPC is canceled by server before / during messages from the + // client, it means that the client most likely did not get a chance to + // send all the messages it wanted to send. i.e num_msgs_sent <= + // num_msgs_to_send + EXPECT_LE(num_msgs_sent_, num_to_send); + break; + case DO_NOT_CANCEL: + case CANCEL_AFTER_PROCESSING: + // If the RPC was not canceled or canceled after all messages were read + // by the server, the client did get a chance to send all its messages + EXPECT_EQ(num_msgs_sent_, num_to_send); + break; + default: + assert(false); + break; + } + if ((server_try_cancel_ == DO_NOT_CANCEL) && !client_cancel_.cancel) { + EXPECT_TRUE(s.ok()); + EXPECT_EQ(response_.message(), desired_); + } else { + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + } + std::unique_lock<std::mutex> l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock<std::mutex> l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + void MaybeWrite() { + if (client_cancel_.cancel && + num_msgs_sent_ == client_cancel_.ops_before_cancel) { + context_.TryCancel(); + } else if (num_msgs_to_send_ > num_msgs_sent_ + 1) { + StartWrite(&request_); + } else if (num_msgs_to_send_ == num_msgs_sent_ + 1) { + StartWriteLast(&request_, WriteOptions()); + } + } + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + const ServerTryCancelRequestPhase server_try_cancel_; + int num_msgs_sent_{0}; + const int num_msgs_to_send_; + TString desired_; + const ClientCancelInfo client_cancel_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; +}; + +TEST_P(ClientCallbackEnd2endTest, RequestStream) { + MAYBE_SKIP_TEST; + ResetStub(); + WriteClient test{stub_.get(), DO_NOT_CANCEL, 3}; + test.Await(); + // Make sure that the server interceptors were not notified to cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, ClientCancelsRequestStream) { + MAYBE_SKIP_TEST; + ResetStub(); + WriteClient test{stub_.get(), DO_NOT_CANCEL, 3, ClientCancelInfo{2}}; + test.Await(); + // Make sure that the server interceptors got the cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel before doing reading the request +TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelBeforeReads) { + MAYBE_SKIP_TEST; + ResetStub(); + WriteClient test{stub_.get(), CANCEL_BEFORE_PROCESSING, 1}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel while reading a request from the stream in parallel +TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelDuringRead) { + MAYBE_SKIP_TEST; + ResetStub(); + WriteClient test{stub_.get(), CANCEL_DURING_PROCESSING, 10}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel after reading all the requests but before returning to the +// client +TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelAfterReads) { + MAYBE_SKIP_TEST; + ResetStub(); + WriteClient test{stub_.get(), CANCEL_AFTER_PROCESSING, 4}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, UnaryReactor) { + MAYBE_SKIP_TEST; + ResetStub(); + class UnaryClient : public grpc::experimental::ClientUnaryReactor { + public: + UnaryClient(grpc::testing::EchoTestService::Stub* stub) { + cli_ctx_.AddMetadata("key1", "val1"); + cli_ctx_.AddMetadata("key2", "val2"); + request_.mutable_param()->set_echo_metadata_initially(true); + request_.set_message("Hello metadata"); + stub->experimental_async()->Echo(&cli_ctx_, &request_, &response_, this); + StartCall(); + } + void OnReadInitialMetadataDone(bool ok) override { + EXPECT_TRUE(ok); + EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1")); + EXPECT_EQ( + "val1", + ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second)); + EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2")); + EXPECT_EQ( + "val2", + ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second)); + initial_metadata_done_ = true; + } + void OnDone(const Status& s) override { + EXPECT_TRUE(initial_metadata_done_); + EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size()); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(request_.message(), response_.message()); + std::unique_lock<std::mutex> l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock<std::mutex> l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + EchoRequest request_; + EchoResponse response_; + ClientContext cli_ctx_; + std::mutex mu_; + std::condition_variable cv_; + bool done_{false}; + bool initial_metadata_done_{false}; + }; + + UnaryClient test{stub_.get()}; + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, GenericUnaryReactor) { + MAYBE_SKIP_TEST; + ResetStub(); + const TString kMethodName("/grpc.testing.EchoTestService/Echo"); + class UnaryClient : public grpc::experimental::ClientUnaryReactor { + public: + UnaryClient(grpc::GenericStub* stub, const TString& method_name) { + cli_ctx_.AddMetadata("key1", "val1"); + cli_ctx_.AddMetadata("key2", "val2"); + request_.mutable_param()->set_echo_metadata_initially(true); + request_.set_message("Hello metadata"); + send_buf_ = SerializeToByteBuffer(&request_); + + stub->experimental().PrepareUnaryCall(&cli_ctx_, method_name, + send_buf_.get(), &recv_buf_, this); + StartCall(); + } + void OnReadInitialMetadataDone(bool ok) override { + EXPECT_TRUE(ok); + EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1")); + EXPECT_EQ( + "val1", + ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second)); + EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2")); + EXPECT_EQ( + "val2", + ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second)); + initial_metadata_done_ = true; + } + void OnDone(const Status& s) override { + EXPECT_TRUE(initial_metadata_done_); + EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size()); + EXPECT_TRUE(s.ok()); + EchoResponse response; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response)); + EXPECT_EQ(request_.message(), response.message()); + std::unique_lock<std::mutex> l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock<std::mutex> l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + EchoRequest request_; + std::unique_ptr<ByteBuffer> send_buf_; + ByteBuffer recv_buf_; + ClientContext cli_ctx_; + std::mutex mu_; + std::condition_variable cv_; + bool done_{false}; + bool initial_metadata_done_{false}; + }; + + UnaryClient test{generic_stub_.get(), kMethodName}; + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel()); + } +} + +class ReadClient : public grpc::experimental::ClientReadReactor<EchoResponse> { + public: + ReadClient(grpc::testing::EchoTestService::Stub* stub, + ServerTryCancelRequestPhase server_try_cancel, + ClientCancelInfo client_cancel = {}) + : server_try_cancel_(server_try_cancel), client_cancel_{client_cancel} { + if (server_try_cancel_ != DO_NOT_CANCEL) { + // Send server_try_cancel value in the client metadata + context_.AddMetadata(kServerTryCancelRequest, + ToString(server_try_cancel)); + } + request_.set_message("Hello client "); + stub->experimental_async()->ResponseStream(&context_, &request_, this); + if (client_cancel_.cancel && + reads_complete_ == client_cancel_.ops_before_cancel) { + context_.TryCancel(); + } + // Even if we cancel, read until failure because there might be responses + // pending + StartRead(&response_); + StartCall(); + } + void OnReadDone(bool ok) override { + if (!ok) { + if (server_try_cancel_ == DO_NOT_CANCEL && !client_cancel_.cancel) { + EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend); + } + } else { + EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend); + EXPECT_EQ(response_.message(), + request_.message() + ToString(reads_complete_)); + reads_complete_++; + if (client_cancel_.cancel && + reads_complete_ == client_cancel_.ops_before_cancel) { + context_.TryCancel(); + } + // Even if we cancel, read until failure because there might be responses + // pending + StartRead(&response_); + } + } + void OnDone(const Status& s) override { + gpr_log(GPR_INFO, "Read %d messages", reads_complete_); + switch (server_try_cancel_) { + case DO_NOT_CANCEL: + if (!client_cancel_.cancel || client_cancel_.ops_before_cancel > + kServerDefaultResponseStreamsToSend) { + EXPECT_TRUE(s.ok()); + EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend); + } else { + EXPECT_GE(reads_complete_, client_cancel_.ops_before_cancel); + EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend); + // Status might be ok or cancelled depending on whether server + // sent status before client cancel went through + if (!s.ok()) { + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + } + } + break; + case CANCEL_BEFORE_PROCESSING: + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_EQ(reads_complete_, 0); + break; + case CANCEL_DURING_PROCESSING: + case CANCEL_AFTER_PROCESSING: + // If server canceled while writing messages, client must have read + // less than or equal to the expected number of messages. Even if the + // server canceled after writing all messages, the RPC may be canceled + // before the Client got a chance to read all the messages. + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend); + break; + default: + assert(false); + } + std::unique_lock<std::mutex> l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock<std::mutex> l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + const ServerTryCancelRequestPhase server_try_cancel_; + int reads_complete_{0}; + const ClientCancelInfo client_cancel_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; +}; + +TEST_P(ClientCallbackEnd2endTest, ResponseStream) { + MAYBE_SKIP_TEST; + ResetStub(); + ReadClient test{stub_.get(), DO_NOT_CANCEL}; + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, ClientCancelsResponseStream) { + MAYBE_SKIP_TEST; + ResetStub(); + ReadClient test{stub_.get(), DO_NOT_CANCEL, ClientCancelInfo{2}}; + test.Await(); + // Because cancel in this case races with server finish, we can't be sure that + // server interceptors even see cancellation +} + +// Server to cancel before sending any response messages +TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelBefore) { + MAYBE_SKIP_TEST; + ResetStub(); + ReadClient test{stub_.get(), CANCEL_BEFORE_PROCESSING}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel while writing a response to the stream in parallel +TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelDuring) { + MAYBE_SKIP_TEST; + ResetStub(); + ReadClient test{stub_.get(), CANCEL_DURING_PROCESSING}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel after writing all the respones to the stream but before +// returning to the client +TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelAfter) { + MAYBE_SKIP_TEST; + ResetStub(); + ReadClient test{stub_.get(), CANCEL_AFTER_PROCESSING}; + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +class BidiClient + : public grpc::experimental::ClientBidiReactor<EchoRequest, EchoResponse> { + public: + BidiClient(grpc::testing::EchoTestService::Stub* stub, + ServerTryCancelRequestPhase server_try_cancel, + int num_msgs_to_send, bool cork_metadata, bool first_write_async, + ClientCancelInfo client_cancel = {}) + : server_try_cancel_(server_try_cancel), + msgs_to_send_{num_msgs_to_send}, + client_cancel_{client_cancel} { + if (server_try_cancel_ != DO_NOT_CANCEL) { + // Send server_try_cancel value in the client metadata + context_.AddMetadata(kServerTryCancelRequest, + ToString(server_try_cancel)); + } + request_.set_message("Hello fren "); + context_.set_initial_metadata_corked(cork_metadata); + stub->experimental_async()->BidiStream(&context_, this); + MaybeAsyncWrite(first_write_async); + StartRead(&response_); + StartCall(); + } + void OnReadDone(bool ok) override { + if (!ok) { + if (server_try_cancel_ == DO_NOT_CANCEL) { + if (!client_cancel_.cancel) { + EXPECT_EQ(reads_complete_, msgs_to_send_); + } else { + EXPECT_LE(reads_complete_, writes_complete_); + } + } + } else { + EXPECT_LE(reads_complete_, msgs_to_send_); + EXPECT_EQ(response_.message(), request_.message()); + reads_complete_++; + StartRead(&response_); + } + } + void OnWriteDone(bool ok) override { + if (async_write_thread_.joinable()) { + async_write_thread_.join(); + RemoveHold(); + } + if (server_try_cancel_ == DO_NOT_CANCEL) { + EXPECT_TRUE(ok); + } else if (!ok) { + return; + } + writes_complete_++; + MaybeWrite(); + } + void OnDone(const Status& s) override { + gpr_log(GPR_INFO, "Sent %d messages", writes_complete_); + gpr_log(GPR_INFO, "Read %d messages", reads_complete_); + switch (server_try_cancel_) { + case DO_NOT_CANCEL: + if (!client_cancel_.cancel || + client_cancel_.ops_before_cancel > msgs_to_send_) { + EXPECT_TRUE(s.ok()); + EXPECT_EQ(writes_complete_, msgs_to_send_); + EXPECT_EQ(reads_complete_, writes_complete_); + } else { + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_EQ(writes_complete_, client_cancel_.ops_before_cancel); + EXPECT_LE(reads_complete_, writes_complete_); + } + break; + case CANCEL_BEFORE_PROCESSING: + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + // The RPC is canceled before the server did any work or returned any + // reads, but it's possible that some writes took place first from the + // client + EXPECT_LE(writes_complete_, msgs_to_send_); + EXPECT_EQ(reads_complete_, 0); + break; + case CANCEL_DURING_PROCESSING: + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_LE(writes_complete_, msgs_to_send_); + EXPECT_LE(reads_complete_, writes_complete_); + break; + case CANCEL_AFTER_PROCESSING: + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + EXPECT_EQ(writes_complete_, msgs_to_send_); + // The Server canceled after reading the last message and after writing + // the message to the client. However, the RPC cancellation might have + // taken effect before the client actually read the response. + EXPECT_LE(reads_complete_, writes_complete_); + break; + default: + assert(false); + } + std::unique_lock<std::mutex> l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock<std::mutex> l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + void MaybeAsyncWrite(bool first_write_async) { + if (first_write_async) { + // Make sure that we have a write to issue. + // TODO(vjpai): Make this work with 0 writes case as well. + assert(msgs_to_send_ >= 1); + + AddHold(); + async_write_thread_ = std::thread([this] { + std::unique_lock<std::mutex> lock(async_write_thread_mu_); + async_write_thread_cv_.wait( + lock, [this] { return async_write_thread_start_; }); + MaybeWrite(); + }); + std::lock_guard<std::mutex> lock(async_write_thread_mu_); + async_write_thread_start_ = true; + async_write_thread_cv_.notify_one(); + return; + } + MaybeWrite(); + } + void MaybeWrite() { + if (client_cancel_.cancel && + writes_complete_ == client_cancel_.ops_before_cancel) { + context_.TryCancel(); + } else if (writes_complete_ == msgs_to_send_) { + StartWritesDone(); + } else { + StartWrite(&request_); + } + } + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + const ServerTryCancelRequestPhase server_try_cancel_; + int reads_complete_{0}; + int writes_complete_{0}; + const int msgs_to_send_; + const ClientCancelInfo client_cancel_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + std::thread async_write_thread_; + bool async_write_thread_start_ = false; + std::mutex async_write_thread_mu_; + std::condition_variable async_write_thread_cv_; +}; + +TEST_P(ClientCallbackEnd2endTest, BidiStream) { + MAYBE_SKIP_TEST; + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/false, /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, BidiStreamFirstWriteAsync) { + MAYBE_SKIP_TEST; + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/false, /*first_write_async=*/true); + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, BidiStreamCorked) { + MAYBE_SKIP_TEST; + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/true, /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, BidiStreamCorkedFirstWriteAsync) { + MAYBE_SKIP_TEST; + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/true, /*first_write_async=*/true); + test.Await(); + // Make sure that the server interceptors were not notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, ClientCancelsBidiStream) { + MAYBE_SKIP_TEST; + ResetStub(); + BidiClient test(stub_.get(), DO_NOT_CANCEL, + kServerDefaultResponseStreamsToSend, + /*cork_metadata=*/false, /*first_write_async=*/false, + ClientCancelInfo(2)); + test.Await(); + // Make sure that the server interceptors were notified of a cancel + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel before reading/writing any requests/responses on the stream +TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelBefore) { + MAYBE_SKIP_TEST; + ResetStub(); + BidiClient test(stub_.get(), CANCEL_BEFORE_PROCESSING, /*num_msgs_to_send=*/2, + /*cork_metadata=*/false, /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel while reading/writing requests/responses on the stream in +// parallel +TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelDuring) { + MAYBE_SKIP_TEST; + ResetStub(); + BidiClient test(stub_.get(), CANCEL_DURING_PROCESSING, + /*num_msgs_to_send=*/10, /*cork_metadata=*/false, + /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Server to cancel after reading/writing all requests/responses on the stream +// but before returning to the client +TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelAfter) { + MAYBE_SKIP_TEST; + ResetStub(); + BidiClient test(stub_.get(), CANCEL_AFTER_PROCESSING, /*num_msgs_to_send=*/5, + /*cork_metadata=*/false, /*first_write_async=*/false); + test.Await(); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(ClientCallbackEnd2endTest, SimultaneousReadAndWritesDone) { + MAYBE_SKIP_TEST; + ResetStub(); + class Client : public grpc::experimental::ClientBidiReactor<EchoRequest, + EchoResponse> { + public: + Client(grpc::testing::EchoTestService::Stub* stub) { + request_.set_message("Hello bidi "); + stub->experimental_async()->BidiStream(&context_, this); + StartWrite(&request_); + StartCall(); + } + void OnReadDone(bool ok) override { + EXPECT_TRUE(ok); + EXPECT_EQ(response_.message(), request_.message()); + } + void OnWriteDone(bool ok) override { + EXPECT_TRUE(ok); + // Now send out the simultaneous Read and WritesDone + StartWritesDone(); + StartRead(&response_); + } + void OnDone(const Status& s) override { + EXPECT_TRUE(s.ok()); + EXPECT_EQ(response_.message(), request_.message()); + std::unique_lock<std::mutex> l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock<std::mutex> l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + } test{stub_.get()}; + + test.Await(); +} + +TEST_P(ClientCallbackEnd2endTest, UnimplementedRpc) { + MAYBE_SKIP_TEST; + ChannelArguments args; + const auto& channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + std::shared_ptr<Channel> channel = + (GetParam().protocol == Protocol::TCP) + ? ::grpc::CreateCustomChannel(server_address_.str(), channel_creds, + args) + : server_->InProcessChannel(args); + std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + request.set_message("Hello world."); + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub->experimental_async()->Unimplemented( + &cli_ctx, &request, &response, [&done, &mu, &cv](Status s) { + EXPECT_EQ(StatusCode::UNIMPLEMENTED, s.error_code()); + EXPECT_EQ("", s.error_message()); + + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } +} + +TEST_P(ClientCallbackEnd2endTest, + ResponseStreamExtraReactionFlowReadsUntilDone) { + MAYBE_SKIP_TEST; + ResetStub(); + class ReadAllIncomingDataClient + : public grpc::experimental::ClientReadReactor<EchoResponse> { + public: + ReadAllIncomingDataClient(grpc::testing::EchoTestService::Stub* stub) { + request_.set_message("Hello client "); + stub->experimental_async()->ResponseStream(&context_, &request_, this); + } + bool WaitForReadDone() { + std::unique_lock<std::mutex> l(mu_); + while (!read_done_) { + read_cv_.wait(l); + } + read_done_ = false; + return read_ok_; + } + void Await() { + std::unique_lock<std::mutex> l(mu_); + while (!done_) { + done_cv_.wait(l); + } + } + // RemoveHold under the same lock used for OnDone to make sure that we don't + // call OnDone directly or indirectly from the RemoveHold function. + void RemoveHoldUnderLock() { + std::unique_lock<std::mutex> l(mu_); + RemoveHold(); + } + const Status& status() { + std::unique_lock<std::mutex> l(mu_); + return status_; + } + + private: + void OnReadDone(bool ok) override { + std::unique_lock<std::mutex> l(mu_); + read_ok_ = ok; + read_done_ = true; + read_cv_.notify_one(); + } + void OnDone(const Status& s) override { + std::unique_lock<std::mutex> l(mu_); + done_ = true; + status_ = s; + done_cv_.notify_one(); + } + + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + bool read_ok_ = false; + bool read_done_ = false; + std::mutex mu_; + std::condition_variable read_cv_; + std::condition_variable done_cv_; + bool done_ = false; + Status status_; + } client{stub_.get()}; + + int reads_complete = 0; + client.AddHold(); + client.StartCall(); + + EchoResponse response; + bool read_ok = true; + while (read_ok) { + client.StartRead(&response); + read_ok = client.WaitForReadDone(); + if (read_ok) { + ++reads_complete; + } + } + client.RemoveHoldUnderLock(); + client.Await(); + + EXPECT_EQ(kServerDefaultResponseStreamsToSend, reads_complete); + EXPECT_EQ(client.status().error_code(), grpc::StatusCode::OK); +} + +std::vector<TestScenario> CreateTestScenarios(bool test_insecure) { +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + + std::vector<TestScenario> scenarios; + std::vector<TString> credentials_types{ + GetCredentialsProvider()->GetSecureCredentialsTypeList()}; + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. + return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + if (test_insecure && insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + GPR_ASSERT(!credentials_types.empty()); + + bool barr[]{false, true}; + Protocol parr[]{Protocol::INPROC, Protocol::TCP}; + for (Protocol p : parr) { + for (const auto& cred : credentials_types) { + // TODO(vjpai): Test inproc with secure credentials when feasible + if (p == Protocol::INPROC && + (cred != kInsecureCredentialsType || !insec_ok())) { + continue; + } + for (bool callback_server : barr) { + for (bool use_interceptors : barr) { + scenarios.emplace_back(callback_server, p, use_interceptors, cred); + } + } + } + } + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(ClientCallbackEnd2endTest, ClientCallbackEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(true))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/client_crash_test.cc b/contrib/libs/grpc/test/cpp/end2end/client_crash_test.cc new file mode 100644 index 0000000000..80e1869396 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/client_crash_test.cc @@ -0,0 +1,147 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/subprocess.h" + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +static TString g_root; + +namespace grpc { +namespace testing { + +namespace { + +class CrashTest : public ::testing::Test { + protected: + CrashTest() {} + + std::unique_ptr<grpc::testing::EchoTestService::Stub> CreateServerAndStub() { + auto port = grpc_pick_unused_port_or_die(); + std::ostringstream addr_stream; + addr_stream << "localhost:" << port; + auto addr = addr_stream.str(); + server_.reset(new SubProcess({ + g_root + "/client_crash_test_server", + "--address=" + addr, + })); + GPR_ASSERT(server_); + return grpc::testing::EchoTestService::NewStub( + grpc::CreateChannel(addr, InsecureChannelCredentials())); + } + + void KillServer() { server_.reset(); } + + private: + std::unique_ptr<SubProcess> server_; +}; + +TEST_F(CrashTest, KillBeforeWrite) { + auto stub = CreateServerAndStub(); + + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + KillServer(); + + request.set_message("You should be dead"); + // This may succeed or fail depending on the state of the TCP connection + stream->Write(request); + // But the read will definitely fail + EXPECT_FALSE(stream->Read(&response)); + + EXPECT_FALSE(stream->Finish().ok()); +} + +TEST_F(CrashTest, KillAfterWrite) { + auto stub = CreateServerAndStub(); + + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message("I'm going to kill you"); + EXPECT_TRUE(stream->Write(request)); + + KillServer(); + + // This may succeed or fail depending on how quick the server was + stream->Read(&response); + + EXPECT_FALSE(stream->Finish().ok()); +} + +} // namespace + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + TString me = argv[0]; + auto lslash = me.rfind('/'); + if (lslash != TString::npos) { + g_root = me.substr(0, lslash); + } else { + g_root = "."; + } + + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + // Order seems to matter on these tests: run three times to eliminate that + for (int i = 0; i < 3; i++) { + if (RUN_ALL_TESTS() != 0) { + return 1; + } + } + return 0; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/client_crash_test_server.cc b/contrib/libs/grpc/test/cpp/end2end/client_crash_test_server.cc new file mode 100644 index 0000000000..2d5be420f2 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/client_crash_test_server.cc @@ -0,0 +1,80 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <gflags/gflags.h> +#include <iostream> +#include <memory> +#include <util/generic/string.h> + +#include <grpc/support/log.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/test_config.h" + +DEFINE_string(address, "", "Address to bind to"); + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +// In some distros, gflags is in the namespace google, and in some others, +// in gflags. This hack is enabling us to find both. +namespace google {} +namespace gflags {} +using namespace google; +using namespace gflags; + +namespace grpc { +namespace testing { + +class ServiceImpl final : public ::grpc::testing::EchoTestService::Service { + Status BidiStream( + ServerContext* /*context*/, + ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + EchoRequest request; + EchoResponse response; + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + response.set_message(request.message()); + stream->Write(response); + } + return Status::OK; + } +}; + +void RunServer() { + ServiceImpl service; + + ServerBuilder builder; + builder.AddListeningPort(FLAGS_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr<Server> server(builder.BuildAndStart()); + std::cout << "Server listening on " << FLAGS_address << std::endl; + server->Wait(); +} +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + grpc::testing::RunServer(); + + return 0; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/client_interceptors_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/client_interceptors_end2end_test.cc new file mode 100644 index 0000000000..956876d9f6 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -0,0 +1,1194 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <vector> + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/support/client_interceptor.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +namespace { + +enum class RPCType { + kSyncUnary, + kSyncClientStreaming, + kSyncServerStreaming, + kSyncBidiStreaming, + kAsyncCQUnary, + kAsyncCQClientStreaming, + kAsyncCQServerStreaming, + kAsyncCQBidiStreaming, +}; + +/* Hijacks Echo RPC and fills in the expected values */ +class HijackingInterceptor : public experimental::Interceptor { + public: + HijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello1"); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message("Hello1"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class HijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptor(info); + } +}; + +class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor { + public: + HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + // Make a copy of the map + metadata_map_ = *map; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + req_ = req; + stub_ = grpc::testing::EchoTestService::NewStub( + methods->GetInterceptedChannel()); + ctx_.AddMetadata(metadata_map_.begin()->first, + metadata_map_.begin()->second); + stub_->experimental_async()->Echo(&ctx_, &req_, &resp_, + [this, methods](Status s) { + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp_.message(), "Hello"); + methods->Hijack(); + }); + // This is a Unary RPC and we have got nothing interesting to do in the + // PRE_SEND_CLOSE interception hook point for this interceptor, so let's + // return here. (We do not want to call methods->Proceed(). When the new + // RPC returns, we will call methods->Hijack() instead.) + return; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message(resp_.message()); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_; + std::multimap<TString, TString> metadata_map_; + ClientContext ctx_; + EchoRequest req_; + EchoResponse resp_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; +}; + +class HijackingInterceptorMakesAnotherCallFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptorMakesAnotherCall(info); + } +}; + +class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { + public: + BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue"); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message().find("Hello"), 0u); + msg = req.message(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey", + "testvalue"); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message(msg); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage()) + ->message() + .find("Hello"), + 0u); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + TString msg; +}; + +class ClientStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedSendMessage(); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { + EXPECT_FALSE(got_failed_send_); + got_failed_send_ = !methods->GetSendMessageStatus(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedSend() { return got_failed_send_; } + + private: + experimental::ClientRpcInfo* info_; + int count_ = 0; + static bool got_failed_send_; +}; + +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; + +class ClientStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ClientStreamingRpcHijackingInterceptor(info); + } +}; + +class ServerStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + got_failed_message_ = false; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedRecvMessage(); + } + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + // Only the last message will be a failure + EXPECT_FALSE(got_failed_message_); + got_failed_message_ = methods->GetRecvMessage() == nullptr; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedMessage() { return got_failed_message_; } + + private: + experimental::ClientRpcInfo* info_; + static bool got_failed_message_; + int count_ = 0; +}; + +bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false; + +class ServerStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ServerStreamingRpcHijackingInterceptor(info); + } +}; + +class BidiStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new BidiStreamingRpcHijackingInterceptor(info); + } +}; + +// The logging interceptor is for testing purposes only. It is used to verify +// that all the appropriate hook points are invoked for an RPC. The counts are +// reset each time a new object of LoggingInterceptor is created, so only a +// single RPC should be made on the channel before calling the Verify methods. +class LoggingInterceptor : public experimental::Interceptor { + public: + LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) { + pre_send_initial_metadata_ = false; + pre_send_message_count_ = 0; + pre_send_close_ = false; + post_recv_initial_metadata_ = false; + post_recv_message_count_ = 0; + post_recv_status_ = false; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + ASSERT_FALSE(pre_send_initial_metadata_); + pre_send_initial_metadata_ = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* send_msg = methods->GetSendMessage(); + if (send_msg == nullptr) { + // We did not get the non-serialized form of the message. Get the + // serialized form. + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EchoRequest req; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } else { + EXPECT_EQ( + static_cast<const EchoRequest*>(send_msg)->message().find("Hello"), + 0u); + } + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_TRUE(req.message().find("Hello") == 0u); + pre_send_message_count_++; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + pre_send_close_ = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + post_recv_initial_metadata_ = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + if (resp != nullptr) { + EXPECT_TRUE(resp->message().find("Hello") == 0u); + post_recv_message_count_++; + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + post_recv_status_ = true; + } + methods->Proceed(); + } + + static void VerifyCall(RPCType type) { + switch (type) { + case RPCType::kSyncUnary: + case RPCType::kAsyncCQUnary: + VerifyUnaryCall(); + break; + case RPCType::kSyncClientStreaming: + case RPCType::kAsyncCQClientStreaming: + VerifyClientStreamingCall(); + break; + case RPCType::kSyncServerStreaming: + case RPCType::kAsyncCQServerStreaming: + VerifyServerStreamingCall(); + break; + case RPCType::kSyncBidiStreaming: + case RPCType::kAsyncCQBidiStreaming: + VerifyBidiStreamingCall(); + break; + } + } + + static void VerifyCallCommon() { + EXPECT_TRUE(pre_send_initial_metadata_); + EXPECT_TRUE(pre_send_close_); + EXPECT_TRUE(post_recv_initial_metadata_); + EXPECT_TRUE(post_recv_status_); + } + + static void VerifyUnaryCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, 1); + EXPECT_EQ(post_recv_message_count_, 1); + } + + static void VerifyClientStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages); + EXPECT_EQ(post_recv_message_count_, 1); + } + + static void VerifyServerStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, 1); + EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages); + } + + static void VerifyBidiStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages); + EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages); + } + + private: + static bool pre_send_initial_metadata_; + static int pre_send_message_count_; + static bool pre_send_close_; + static bool post_recv_initial_metadata_; + static int post_recv_message_count_; + static bool post_recv_status_; +}; + +bool LoggingInterceptor::pre_send_initial_metadata_; +int LoggingInterceptor::pre_send_message_count_; +bool LoggingInterceptor::pre_send_close_; +bool LoggingInterceptor::post_recv_initial_metadata_; +int LoggingInterceptor::post_recv_message_count_; +bool LoggingInterceptor::post_recv_status_; + +class LoggingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +class TestScenario { + public: + explicit TestScenario(const RPCType& type) : type_(type) {} + + RPCType type() const { return type_; } + + private: + RPCType type_; +}; + +std::vector<TestScenario> CreateTestScenarios() { + std::vector<TestScenario> scenarios; + scenarios.emplace_back(RPCType::kSyncUnary); + scenarios.emplace_back(RPCType::kSyncClientStreaming); + scenarios.emplace_back(RPCType::kSyncServerStreaming); + scenarios.emplace_back(RPCType::kSyncBidiStreaming); + scenarios.emplace_back(RPCType::kAsyncCQUnary); + scenarios.emplace_back(RPCType::kAsyncCQServerStreaming); + return scenarios; +} + +class ParameterizedClientInterceptorsEnd2endTest + : public ::testing::TestWithParam<TestScenario> { + protected: + ParameterizedClientInterceptorsEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + ToString(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ParameterizedClientInterceptorsEnd2endTest() { server_->Shutdown(); } + + void SendRPC(const std::shared_ptr<Channel>& channel) { + switch (GetParam().type()) { + case RPCType::kSyncUnary: + MakeCall(channel); + break; + case RPCType::kSyncClientStreaming: + MakeClientStreamingCall(channel); + break; + case RPCType::kSyncServerStreaming: + MakeServerStreamingCall(channel); + break; + case RPCType::kSyncBidiStreaming: + MakeBidiStreamingCall(channel); + break; + case RPCType::kAsyncCQUnary: + MakeAsyncCQCall(channel); + break; + case RPCType::kAsyncCQClientStreaming: + // TODO(yashykt) : Fill this out + break; + case RPCType::kAsyncCQServerStreaming: + MakeAsyncCQServerStreamingCall(channel); + break; + case RPCType::kAsyncCQBidiStreaming: + // TODO(yashykt) : Fill this out + break; + } + } + + TString server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_P(ParameterizedClientInterceptorsEnd2endTest, + ClientInterceptorLoggingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + SendRPC(channel); + LoggingInterceptor::VerifyCall(GetParam().type()); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end, + ParameterizedClientInterceptorsEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios())); + +class ClientInterceptorsEnd2endTest + : public ::testing::TestWithParam<TestScenario> { + protected: + ClientInterceptorsEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + ToString(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); } + + TString server_address_; + TestServiceImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(ClientInterceptorsEnd2endTest, + LameChannelClientInterceptorHijackingTest) { + ChannelArguments args; + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back(std::unique_ptr<HijackingInterceptorFactory>( + new HijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, nullptr, args, std::move(creators)); + MakeCall(channel); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + // Add 20 dummy interceptors before hijacking interceptor + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + creators.push_back(std::unique_ptr<HijackingInterceptorFactory>( + new HijackingInterceptorFactory())); + // Add 20 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure only 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { + ChannelArguments args; + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + creators.push_back(std::unique_ptr<HijackingInterceptorFactory>( + new HijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + LoggingInterceptor::VerifyUnaryCall(); +} + +TEST_F(ClientInterceptorsEnd2endTest, + ClientInterceptorHijackingMakesAnotherCallTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + // Add 5 dummy interceptors before hijacking interceptor + creators.reserve(5); + for (auto i = 0; i < 5; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + creators.push_back( + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>( + new HijackingInterceptorMakesAnotherCallFactory())); + // Add 7 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 7; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + + MakeCall(channel); + // Make sure all interceptors were run once, since the hijacking interceptor + // makes an RPC on the intercepted channel + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12); +} + +class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsCallbackEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + ToString(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsCallbackEnd2endTest() { server_->Shutdown(); } + + TString server_address_; + TestServiceImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(ClientInterceptorsCallbackEnd2endTest, + ClientInterceptorLoggingTestWithCallback) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + MakeCallbackCall(channel); + LoggingInterceptor::VerifyUnaryCall(); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsCallbackEnd2endTest, + ClientInterceptorFactoryAllowsNullptrReturn) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors and 20 null interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + creators.push_back( + std::unique_ptr<NullInterceptorFactory>(new NullInterceptorFactory())); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + MakeCallbackCall(channel); + LoggingInterceptor::VerifyUnaryCall(); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsStreamingEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + ToString(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); } + + TString server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeClientStreamingCall(channel); + LoggingInterceptor::VerifyClientStreamingCall(); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); + LoggingInterceptor::VerifyServerStreamingCall(); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { + ChannelArguments args; + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>( + new ClientStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + req.mutable_param()->set_echo_metadata(true); + req.set_message("Hello"); + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(req)); + expected_resp += "Hello"; + } + // The interceptor will reject the 11th message + writer->Write(req); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), false); + EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>( + new ServerStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, + AsyncCQServerStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>( + new ServerStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeAsyncCQServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>( + new BidiStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); + LoggingInterceptor::VerifyBidiStreamingCall(); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +class ClientGlobalInterceptorEnd2endTest : public ::testing::Test { + protected: + ClientGlobalInterceptorEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + ToString(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientGlobalInterceptorEnd2endTest() { server_->Shutdown(); } + + TString server_address_; + TestServiceImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(ClientGlobalInterceptorEnd2endTest, DummyGlobalInterceptor) { + // We should ideally be registering a global interceptor only once per + // process, but for the purposes of testing, it should be fine to modify the + // registered global interceptor when there are no ongoing gRPC operations + DummyInterceptorFactory global_factory; + experimental::RegisterGlobalClientInterceptorFactory(&global_factory); + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + // Add 20 dummy interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure all 20 dummy interceptors were run with the global interceptor + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 21); + experimental::TestOnlyResetGlobalClientInterceptorFactory(); +} + +TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) { + // We should ideally be registering a global interceptor only once per + // process, but for the purposes of testing, it should be fine to modify the + // registered global interceptor when there are no ongoing gRPC operations + LoggingInterceptorFactory global_factory; + experimental::RegisterGlobalClientInterceptorFactory(&global_factory); + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + // Add 20 dummy interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + LoggingInterceptor::VerifyUnaryCall(); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + experimental::TestOnlyResetGlobalClientInterceptorFactory(); +} + +TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) { + // We should ideally be registering a global interceptor only once per + // process, but for the purposes of testing, it should be fine to modify the + // registered global interceptor when there are no ongoing gRPC operations + HijackingInterceptorFactory global_factory; + experimental::RegisterGlobalClientInterceptorFactory(&global_factory); + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + // Add 20 dummy interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + experimental::TestOnlyResetGlobalClientInterceptorFactory(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/client_lb_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/client_lb_end2end_test.cc new file mode 100644 index 0000000000..fd08dd163d --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/client_lb_end2end_test.cc @@ -0,0 +1,1990 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <algorithm> +#include <memory> +#include <mutex> +#include <random> +#include <set> +#include <util/generic/string.h> +#include <thread> + +#include "y_absl/strings/str_cat.h" + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/atm.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/health_check_service_interface.h> +#include <grpcpp/impl/codegen/sync.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/global_subchannel_pool.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/filters/client_channel/service_config.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/parse_address.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "src/proto/grpc/testing/xds/orca_load_report_for_test.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/core/util/test_lb_policies.h" +#include "test/cpp/end2end/test_service_impl.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +// defined in tcp_client.cc +extern grpc_tcp_client_vtable* grpc_tcp_client_impl; + +static grpc_tcp_client_vtable* default_client_impl; + +namespace grpc { +namespace testing { +namespace { + +gpr_atm g_connection_delay_ms; + +void tcp_client_connect_with_delay(grpc_closure* closure, grpc_endpoint** ep, + grpc_pollset_set* interested_parties, + const grpc_channel_args* channel_args, + const grpc_resolved_address* addr, + grpc_millis deadline) { + const int delay_ms = gpr_atm_acq_load(&g_connection_delay_ms); + if (delay_ms > 0) { + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms)); + } + default_client_impl->connect(closure, ep, interested_parties, channel_args, + addr, deadline + delay_ms); +} + +grpc_tcp_client_vtable delayed_connect = {tcp_client_connect_with_delay}; + +// Subclass of TestServiceImpl that increments a request counter for +// every call to the Echo RPC. +class MyTestServiceImpl : public TestServiceImpl { + public: + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + const udpa::data::orca::v1::OrcaLoadReport* load_report = nullptr; + { + grpc::internal::MutexLock lock(&mu_); + ++request_count_; + load_report = load_report_; + } + AddClient(context->peer().c_str()); + if (load_report != nullptr) { + // TODO(roth): Once we provide a more standard server-side API for + // populating this data, use that API here. + context->AddTrailingMetadata("x-endpoint-load-metrics-bin", + load_report->SerializeAsString()); + } + return TestServiceImpl::Echo(context, request, response); + } + + int request_count() { + grpc::internal::MutexLock lock(&mu_); + return request_count_; + } + + void ResetCounters() { + grpc::internal::MutexLock lock(&mu_); + request_count_ = 0; + } + + std::set<TString> clients() { + grpc::internal::MutexLock lock(&clients_mu_); + return clients_; + } + + void set_load_report(udpa::data::orca::v1::OrcaLoadReport* load_report) { + grpc::internal::MutexLock lock(&mu_); + load_report_ = load_report; + } + + private: + void AddClient(const TString& client) { + grpc::internal::MutexLock lock(&clients_mu_); + clients_.insert(client); + } + + grpc::internal::Mutex mu_; + int request_count_ = 0; + const udpa::data::orca::v1::OrcaLoadReport* load_report_ = nullptr; + grpc::internal::Mutex clients_mu_; + std::set<TString> clients_; +}; + +class FakeResolverResponseGeneratorWrapper { + public: + FakeResolverResponseGeneratorWrapper() + : response_generator_(grpc_core::MakeRefCounted< + grpc_core::FakeResolverResponseGenerator>()) {} + + FakeResolverResponseGeneratorWrapper( + FakeResolverResponseGeneratorWrapper&& other) noexcept { + response_generator_ = std::move(other.response_generator_); + } + + void SetNextResolution( + const std::vector<int>& ports, const char* service_config_json = nullptr, + const char* attribute_key = nullptr, + std::unique_ptr<grpc_core::ServerAddress::AttributeInterface> attribute = + nullptr) { + grpc_core::ExecCtx exec_ctx; + response_generator_->SetResponse(BuildFakeResults( + ports, service_config_json, attribute_key, std::move(attribute))); + } + + void SetNextResolutionUponError(const std::vector<int>& ports) { + grpc_core::ExecCtx exec_ctx; + response_generator_->SetReresolutionResponse(BuildFakeResults(ports)); + } + + void SetFailureOnReresolution() { + grpc_core::ExecCtx exec_ctx; + response_generator_->SetFailureOnReresolution(); + } + + grpc_core::FakeResolverResponseGenerator* Get() const { + return response_generator_.get(); + } + + private: + static grpc_core::Resolver::Result BuildFakeResults( + const std::vector<int>& ports, const char* service_config_json = nullptr, + const char* attribute_key = nullptr, + std::unique_ptr<grpc_core::ServerAddress::AttributeInterface> attribute = + nullptr) { + grpc_core::Resolver::Result result; + for (const int& port : ports) { + TString lb_uri_str = y_absl::StrCat("ipv4:127.0.0.1:", port); + grpc_uri* lb_uri = grpc_uri_parse(lb_uri_str.c_str(), true); + GPR_ASSERT(lb_uri != nullptr); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(lb_uri, &address)); + std::map<const char*, + std::unique_ptr<grpc_core::ServerAddress::AttributeInterface>> + attributes; + if (attribute != nullptr) { + attributes[attribute_key] = attribute->Copy(); + } + result.addresses.emplace_back(address.addr, address.len, + nullptr /* args */, std::move(attributes)); + grpc_uri_destroy(lb_uri); + } + if (service_config_json != nullptr) { + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, service_config_json, &result.service_config_error); + GPR_ASSERT(result.service_config != nullptr); + } + return result; + } + + grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator> + response_generator_; +}; + +class ClientLbEnd2endTest : public ::testing::Test { + protected: + ClientLbEnd2endTest() + : server_host_("localhost"), + kRequestMessage_("Live long and prosper."), + creds_(new SecureChannelCredentials( + grpc_fake_transport_security_credentials_create())) {} + + static void SetUpTestCase() { + // Make the backup poller poll very frequently in order to pick up + // updates from all the subchannels's FDs. + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + } + + void SetUp() override { grpc_init(); } + + void TearDown() override { + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + servers_.clear(); + creds_.reset(); + grpc_shutdown_blocking(); + } + + void CreateServers(size_t num_servers, + std::vector<int> ports = std::vector<int>()) { + servers_.clear(); + for (size_t i = 0; i < num_servers; ++i) { + int port = 0; + if (ports.size() == num_servers) port = ports[i]; + servers_.emplace_back(new ServerData(port)); + } + } + + void StartServer(size_t index) { servers_[index]->Start(server_host_); } + + void StartServers(size_t num_servers, + std::vector<int> ports = std::vector<int>()) { + CreateServers(num_servers, std::move(ports)); + for (size_t i = 0; i < num_servers; ++i) { + StartServer(i); + } + } + + std::vector<int> GetServersPorts(size_t start_index = 0) { + std::vector<int> ports; + for (size_t i = start_index; i < servers_.size(); ++i) { + ports.push_back(servers_[i]->port_); + } + return ports; + } + + FakeResolverResponseGeneratorWrapper BuildResolverResponseGenerator() { + return FakeResolverResponseGeneratorWrapper(); + } + + std::unique_ptr<grpc::testing::EchoTestService::Stub> BuildStub( + const std::shared_ptr<Channel>& channel) { + return grpc::testing::EchoTestService::NewStub(channel); + } + + std::shared_ptr<Channel> BuildChannel( + const TString& lb_policy_name, + const FakeResolverResponseGeneratorWrapper& response_generator, + ChannelArguments args = ChannelArguments()) { + if (lb_policy_name.size() > 0) { + args.SetLoadBalancingPolicyName(lb_policy_name); + } // else, default to pick first + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator.Get()); + return ::grpc::CreateCustomChannel("fake:///", creds_, args); + } + + bool SendRpc( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + EchoResponse* response = nullptr, int timeout_ms = 1000, + Status* result = nullptr, bool wait_for_ready = false) { + const bool local_response = (response == nullptr); + if (local_response) response = new EchoResponse; + EchoRequest request; + request.set_message(kRequestMessage_); + request.mutable_param()->set_echo_metadata(true); + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); + if (wait_for_ready) context.set_wait_for_ready(true); + context.AddMetadata("foo", "1"); + context.AddMetadata("bar", "2"); + context.AddMetadata("baz", "3"); + Status status = stub->Echo(&context, request, response); + if (result != nullptr) *result = status; + if (local_response) delete response; + return status.ok(); + } + + void CheckRpcSendOk( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + const grpc_core::DebugLocation& location, bool wait_for_ready = false) { + EchoResponse response; + Status status; + const bool success = + SendRpc(stub, &response, 2000, &status, wait_for_ready); + ASSERT_TRUE(success) << "From " << location.file() << ":" << location.line() + << "\n" + << "Error: " << status.error_message() << " " + << status.error_details(); + ASSERT_EQ(response.message(), kRequestMessage_) + << "From " << location.file() << ":" << location.line(); + if (!success) abort(); + } + + void CheckRpcSendFailure( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub) { + const bool success = SendRpc(stub); + EXPECT_FALSE(success); + } + + struct ServerData { + int port_; + std::unique_ptr<Server> server_; + MyTestServiceImpl service_; + std::unique_ptr<std::thread> thread_; + bool server_ready_ = false; + bool started_ = false; + + explicit ServerData(int port = 0) { + port_ = port > 0 ? port : 5100; // grpc_pick_unused_port_or_die(); + } + + void Start(const TString& server_host) { + gpr_log(GPR_INFO, "starting server on port %d", port_); + started_ = true; + grpc::internal::Mutex mu; + grpc::internal::MutexLock lock(&mu); + grpc::internal::CondVar cond; + thread_.reset(new std::thread( + std::bind(&ServerData::Serve, this, server_host, &mu, &cond))); + cond.WaitUntil(&mu, [this] { return server_ready_; }); + server_ready_ = false; + gpr_log(GPR_INFO, "server startup complete"); + } + + void Serve(const TString& server_host, grpc::internal::Mutex* mu, + grpc::internal::CondVar* cond) { + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + std::shared_ptr<ServerCredentials> creds(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create())); + builder.AddListeningPort(server_address.str(), std::move(creds)); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + grpc::internal::MutexLock lock(mu); + server_ready_ = true; + cond->Signal(); + } + + void Shutdown() { + if (!started_) return; + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + started_ = false; + } + + void SetServingStatus(const TString& service, bool serving) { + server_->GetHealthCheckService()->SetServingStatus(service, serving); + } + }; + + void ResetCounters() { + for (const auto& server : servers_) server->service_.ResetCounters(); + } + + void WaitForServer( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + size_t server_idx, const grpc_core::DebugLocation& location, + bool ignore_failure = false) { + do { + if (ignore_failure) { + SendRpc(stub); + } else { + CheckRpcSendOk(stub, location, true); + } + } while (servers_[server_idx]->service_.request_count() == 0); + ResetCounters(); + } + + bool WaitForChannelState( + Channel* channel, std::function<bool(grpc_connectivity_state)> predicate, + bool try_to_connect = false, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + while (true) { + grpc_connectivity_state state = channel->GetState(try_to_connect); + if (predicate(state)) break; + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) { + auto predicate = [](grpc_connectivity_state state) { + return state != GRPC_CHANNEL_READY; + }; + return WaitForChannelState(channel, predicate, false, timeout_seconds); + } + + bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) { + auto predicate = [](grpc_connectivity_state state) { + return state == GRPC_CHANNEL_READY; + }; + return WaitForChannelState(channel, predicate, true, timeout_seconds); + } + + bool SeenAllServers() { + for (const auto& server : servers_) { + if (server->service_.request_count() == 0) return false; + } + return true; + } + + // Updates \a connection_order by appending to it the index of the newly + // connected server. Must be called after every single RPC. + void UpdateConnectionOrder( + const std::vector<std::unique_ptr<ServerData>>& servers, + std::vector<int>* connection_order) { + for (size_t i = 0; i < servers.size(); ++i) { + if (servers[i]->service_.request_count() == 1) { + // Was the server index known? If not, update connection_order. + const auto it = + std::find(connection_order->begin(), connection_order->end(), i); + if (it == connection_order->end()) { + connection_order->push_back(i); + return; + } + } + } + } + + const TString server_host_; + std::vector<std::unique_ptr<ServerData>> servers_; + const TString kRequestMessage_; + std::shared_ptr<ChannelCredentials> creds_; +}; + +TEST_F(ClientLbEnd2endTest, ChannelStateConnectingWhenResolving) { + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("", response_generator); + auto stub = BuildStub(channel); + // Initial state should be IDLE. + EXPECT_EQ(channel->GetState(false /* try_to_connect */), GRPC_CHANNEL_IDLE); + // Tell the channel to try to connect. + // Note that this call also returns IDLE, since the state change has + // not yet occurred; it just gets triggered by this call. + EXPECT_EQ(channel->GetState(true /* try_to_connect */), GRPC_CHANNEL_IDLE); + // Now that the channel is trying to connect, we should be in state + // CONNECTING. + EXPECT_EQ(channel->GetState(false /* try_to_connect */), + GRPC_CHANNEL_CONNECTING); + // Return a resolver result, which allows the connection attempt to proceed. + response_generator.SetNextResolution(GetServersPorts()); + // We should eventually transition into state READY. + EXPECT_TRUE(WaitForChannelReady(channel.get())); +} + +TEST_F(ClientLbEnd2endTest, PickFirst) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel( + "", response_generator); // test that pick first is the default. + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + for (size_t i = 0; i < servers_.size(); ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // All requests should have gone to a single server. + bool found = false; + for (size_t i = 0; i < servers_.size(); ++i) { + const int request_count = servers_[i]->service_.request_count(); + if (request_count == kNumServers) { + found = true; + } else { + EXPECT_EQ(0, request_count); + } + } + EXPECT_TRUE(found); + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstProcessPending) { + StartServers(1); // Single server + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel( + "", response_generator); // test that pick first is the default. + auto stub = BuildStub(channel); + response_generator.SetNextResolution({servers_[0]->port_}); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Create a new channel and its corresponding PF LB policy, which will pick + // the subchannels in READY state from the previous RPC against the same + // target (even if it happened over a different channel, because subchannels + // are globally reused). Progress should happen without any transition from + // this READY state. + auto second_response_generator = BuildResolverResponseGenerator(); + auto second_channel = BuildChannel("", second_response_generator); + auto second_stub = BuildStub(second_channel); + second_response_generator.SetNextResolution({servers_[0]->port_}); + CheckRpcSendOk(second_stub, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, PickFirstSelectsReadyAtStartup) { + ChannelArguments args; + constexpr int kInitialBackOffMs = 5000; + args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs); + // Create 2 servers, but start only the second one. + std::vector<int> ports = { 5101, // grpc_pick_unused_port_or_die(), + 5102}; // grpc_pick_unused_port_or_die()}; + CreateServers(2, ports); + StartServer(1); + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("pick_first", response_generator1, args); + auto stub1 = BuildStub(channel1); + response_generator1.SetNextResolution(ports); + // Wait for second server to be ready. + WaitForServer(stub1, 1, DEBUG_LOCATION); + // Create a second channel with the same addresses. Its PF instance + // should immediately pick the second subchannel, since it's already + // in READY state. + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("pick_first", response_generator2, args); + response_generator2.SetNextResolution(ports); + // Check that the channel reports READY without waiting for the + // initial backoff. + EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1 /* timeout_seconds */)); +} + +TEST_F(ClientLbEnd2endTest, PickFirstBackOffInitialReconnect) { + ChannelArguments args; + constexpr int kInitialBackOffMs = 100; + args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs); + const std::vector<int> ports = {5103}; // {grpc_pick_unused_port_or_die()}; + const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + // The channel won't become connected (there's no server). + ASSERT_FALSE(channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 2))); + // Bring up a server on the chosen port. + StartServers(1, ports); + // Now it will. + ASSERT_TRUE(channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 2))); + const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC); + const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0)); + gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms); + // We should have waited at least kInitialBackOffMs. We substract one to + // account for test and precision accuracy drift. + EXPECT_GE(waited_ms, kInitialBackOffMs - 1); + // But not much more. + EXPECT_GT( + gpr_time_cmp( + grpc_timeout_milliseconds_to_deadline(kInitialBackOffMs * 1.10), t1), + 0); +} + +TEST_F(ClientLbEnd2endTest, PickFirstBackOffMinReconnect) { + ChannelArguments args; + constexpr int kMinReconnectBackOffMs = 1000; + args.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, kMinReconnectBackOffMs); + const std::vector<int> ports = {5104}; // {grpc_pick_unused_port_or_die()}; + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + // Make connection delay a 10% longer than it's willing to in order to make + // sure we are hitting the codepath that waits for the min reconnect backoff. + gpr_atm_rel_store(&g_connection_delay_ms, kMinReconnectBackOffMs * 1.10); + default_client_impl = grpc_tcp_client_impl; + grpc_set_tcp_client_impl(&delayed_connect); + const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC); + channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kMinReconnectBackOffMs * 2)); + const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC); + const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0)); + gpr_log(GPR_DEBUG, "Waited %" PRId64 " ms", waited_ms); + // We should have waited at least kMinReconnectBackOffMs. We substract one to + // account for test and precision accuracy drift. + EXPECT_GE(waited_ms, kMinReconnectBackOffMs - 1); + gpr_atm_rel_store(&g_connection_delay_ms, 0); +} + +TEST_F(ClientLbEnd2endTest, PickFirstResetConnectionBackoff) { + ChannelArguments args; + constexpr int kInitialBackOffMs = 1000; + args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs); + const std::vector<int> ports = {5105}; // {grpc_pick_unused_port_or_die()}; + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + // The channel won't become connected (there's no server). + EXPECT_FALSE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10))); + // Bring up a server on the chosen port. + StartServers(1, ports); + const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC); + // Wait for connect, but not long enough. This proves that we're + // being throttled by initial backoff. + EXPECT_FALSE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10))); + // Reset connection backoff. + experimental::ChannelResetConnectionBackoff(channel.get()); + // Wait for connect. Should happen as soon as the client connects to + // the newly started server, which should be before the initial + // backoff timeout elapses. + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(20))); + const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC); + const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0)); + gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms); + // We should have waited less than kInitialBackOffMs. + EXPECT_LT(waited_ms, kInitialBackOffMs); +} + +TEST_F(ClientLbEnd2endTest, + PickFirstResetConnectionBackoffNextAttemptStartsImmediately) { + ChannelArguments args; + constexpr int kInitialBackOffMs = 1000; + args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, kInitialBackOffMs); + const std::vector<int> ports = {5106}; // {grpc_pick_unused_port_or_die()}; + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + // Wait for connect, which should fail ~immediately, because the server + // is not up. + gpr_log(GPR_INFO, "=== INITIAL CONNECTION ATTEMPT"); + EXPECT_FALSE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10))); + // Reset connection backoff. + // Note that the time at which the third attempt will be started is + // actually computed at this point, so we record the start time here. + gpr_log(GPR_INFO, "=== RESETTING BACKOFF"); + const gpr_timespec t0 = gpr_now(GPR_CLOCK_MONOTONIC); + experimental::ChannelResetConnectionBackoff(channel.get()); + // Trigger a second connection attempt. This should also fail + // ~immediately, but the retry should be scheduled for + // kInitialBackOffMs instead of applying the multiplier. + gpr_log(GPR_INFO, "=== POLLING FOR SECOND CONNECTION ATTEMPT"); + EXPECT_FALSE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(10))); + // Bring up a server on the chosen port. + gpr_log(GPR_INFO, "=== STARTING BACKEND"); + StartServers(1, ports); + // Wait for connect. Should happen within kInitialBackOffMs. + // Give an extra 100ms to account for the time spent in the second and + // third connection attempts themselves (since what we really want to + // measure is the time between the two). As long as this is less than + // the 1.6x increase we would see if the backoff state was not reset + // properly, the test is still proving that the backoff was reset. + constexpr int kWaitMs = kInitialBackOffMs + 100; + gpr_log(GPR_INFO, "=== POLLING FOR THIRD CONNECTION ATTEMPT"); + EXPECT_TRUE(channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(kWaitMs))); + const gpr_timespec t1 = gpr_now(GPR_CLOCK_MONOTONIC); + const grpc_millis waited_ms = gpr_time_to_millis(gpr_time_sub(t1, t0)); + gpr_log(GPR_DEBUG, "Waited %" PRId64 " milliseconds", waited_ms); + EXPECT_LT(waited_ms, kWaitMs); +} + +TEST_F(ClientLbEnd2endTest, PickFirstUpdates) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + + std::vector<int> ports; + + // Perform one RPC against the first server. + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET [0] *******"); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 1); + + // An empty update will result in the channel going into TRANSIENT_FAILURE. + ports.clear(); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET none *******"); + grpc_connectivity_state channel_state; + do { + channel_state = channel->GetState(true /* try to connect */); + } while (channel_state == GRPC_CHANNEL_READY); + ASSERT_NE(channel_state, GRPC_CHANNEL_READY); + servers_[0]->service_.ResetCounters(); + + // Next update introduces servers_[1], making the channel recover. + ports.clear(); + ports.emplace_back(servers_[1]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET [1] *******"); + WaitForServer(stub, 1, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 0); + + // And again for servers_[2] + ports.clear(); + ports.emplace_back(servers_[2]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET [2] *******"); + WaitForServer(stub, 2, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 0); + EXPECT_EQ(servers_[1]->service_.request_count(), 0); + + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstUpdateSuperset) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + + std::vector<int> ports; + + // Perform one RPC against the first server. + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET [0] *******"); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 1); + servers_[0]->service_.ResetCounters(); + + // Send and superset update + ports.clear(); + ports.emplace_back(servers_[1]->port_); + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** SET superset *******"); + CheckRpcSendOk(stub, DEBUG_LOCATION); + // We stick to the previously connected server. + WaitForServer(stub, 0, DEBUG_LOCATION); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstGlobalSubchannelPool) { + // Start one server. + const int kNumServers = 1; + StartServers(kNumServers); + std::vector<int> ports = GetServersPorts(); + // Create two channels that (by default) use the global subchannel pool. + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("pick_first", response_generator1); + auto stub1 = BuildStub(channel1); + response_generator1.SetNextResolution(ports); + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("pick_first", response_generator2); + auto stub2 = BuildStub(channel2); + response_generator2.SetNextResolution(ports); + WaitForServer(stub1, 0, DEBUG_LOCATION); + // Send one RPC on each channel. + CheckRpcSendOk(stub1, DEBUG_LOCATION); + CheckRpcSendOk(stub2, DEBUG_LOCATION); + // The server receives two requests. + EXPECT_EQ(2, servers_[0]->service_.request_count()); + // The two requests are from the same client port, because the two channels + // share subchannels via the global subchannel pool. + EXPECT_EQ(1UL, servers_[0]->service_.clients().size()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstLocalSubchannelPool) { + // Start one server. + const int kNumServers = 1; + StartServers(kNumServers); + std::vector<int> ports = GetServersPorts(); + // Create two channels that use local subchannel pool. + ChannelArguments args; + args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1); + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("pick_first", response_generator1, args); + auto stub1 = BuildStub(channel1); + response_generator1.SetNextResolution(ports); + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("pick_first", response_generator2, args); + auto stub2 = BuildStub(channel2); + response_generator2.SetNextResolution(ports); + WaitForServer(stub1, 0, DEBUG_LOCATION); + // Send one RPC on each channel. + CheckRpcSendOk(stub1, DEBUG_LOCATION); + CheckRpcSendOk(stub2, DEBUG_LOCATION); + // The server receives two requests. + EXPECT_EQ(2, servers_[0]->service_.request_count()); + // The two requests are from two client ports, because the two channels didn't + // share subchannels with each other. + EXPECT_EQ(2UL, servers_[0]->service_.clients().size()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstManyUpdates) { + const int kNumUpdates = 1000; + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + std::vector<int> ports = GetServersPorts(); + for (size_t i = 0; i < kNumUpdates; ++i) { + std::shuffle(ports.begin(), ports.end(), + std::mt19937(std::random_device()())); + response_generator.SetNextResolution(ports); + // We should re-enter core at the end of the loop to give the resolution + // setting closure a chance to run. + if ((i + 1) % 10 == 0) CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstReresolutionNoSelected) { + // Prepare the ports for up servers and down servers. + const int kNumServers = 3; + const int kNumAliveServers = 1; + StartServers(kNumAliveServers); + std::vector<int> alive_ports, dead_ports; + for (size_t i = 0; i < kNumServers; ++i) { + if (i < kNumAliveServers) { + alive_ports.emplace_back(servers_[i]->port_); + } else { + dead_ports.emplace_back(5107 + i); + // dead_ports.emplace_back(grpc_pick_unused_port_or_die()); + } + } + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + // The initial resolution only contains dead ports. There won't be any + // selected subchannel. Re-resolution will return the same result. + response_generator.SetNextResolution(dead_ports); + gpr_log(GPR_INFO, "****** INITIAL RESOLUTION SET *******"); + for (size_t i = 0; i < 10; ++i) CheckRpcSendFailure(stub); + // Set a re-resolution result that contains reachable ports, so that the + // pick_first LB policy can recover soon. + response_generator.SetNextResolutionUponError(alive_ports); + gpr_log(GPR_INFO, "****** RE-RESOLUTION SET *******"); + WaitForServer(stub, 0, DEBUG_LOCATION, true /* ignore_failure */); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(servers_[0]->service_.request_count(), 1); + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstReconnectWithoutNewResolverResult) { + std::vector<int> ports = {5110}; // {grpc_pick_unused_port_or_die()}; + StartServers(1, ports); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** INITIAL CONNECTION *******"); + WaitForServer(stub, 0, DEBUG_LOCATION); + gpr_log(GPR_INFO, "****** STOPPING SERVER ******"); + servers_[0]->Shutdown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + gpr_log(GPR_INFO, "****** RESTARTING SERVER ******"); + StartServers(1, ports); + WaitForServer(stub, 0, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, + PickFirstReconnectWithoutNewResolverResultStartsFromTopOfList) { + std::vector<int> ports = {5111, // grpc_pick_unused_port_or_die(), + 5112}; // grpc_pick_unused_port_or_die()}; + CreateServers(2, ports); + StartServer(1); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("pick_first", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** INITIAL CONNECTION *******"); + WaitForServer(stub, 1, DEBUG_LOCATION); + gpr_log(GPR_INFO, "****** STOPPING SERVER ******"); + servers_[1]->Shutdown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + gpr_log(GPR_INFO, "****** STARTING BOTH SERVERS ******"); + StartServers(2, ports); + WaitForServer(stub, 0, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, PickFirstCheckStateBeforeStartWatch) { + std::vector<int> ports = {5113}; // {grpc_pick_unused_port_or_die()}; + StartServers(1, ports); + auto response_generator = BuildResolverResponseGenerator(); + auto channel_1 = BuildChannel("pick_first", response_generator); + auto stub_1 = BuildStub(channel_1); + response_generator.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** RESOLUTION SET FOR CHANNEL 1 *******"); + WaitForServer(stub_1, 0, DEBUG_LOCATION); + gpr_log(GPR_INFO, "****** CHANNEL 1 CONNECTED *******"); + servers_[0]->Shutdown(); + // Channel 1 will receive a re-resolution containing the same server. It will + // create a new subchannel and hold a ref to it. + StartServers(1, ports); + gpr_log(GPR_INFO, "****** SERVER RESTARTED *******"); + auto response_generator_2 = BuildResolverResponseGenerator(); + auto channel_2 = BuildChannel("pick_first", response_generator_2); + auto stub_2 = BuildStub(channel_2); + response_generator_2.SetNextResolution(ports); + gpr_log(GPR_INFO, "****** RESOLUTION SET FOR CHANNEL 2 *******"); + WaitForServer(stub_2, 0, DEBUG_LOCATION, true); + gpr_log(GPR_INFO, "****** CHANNEL 2 CONNECTED *******"); + servers_[0]->Shutdown(); + // Wait until the disconnection has triggered the connectivity notification. + // Otherwise, the subchannel may be picked for next call but will fail soon. + EXPECT_TRUE(WaitForChannelNotReady(channel_2.get())); + // Channel 2 will also receive a re-resolution containing the same server. + // Both channels will ref the same subchannel that failed. + StartServers(1, ports); + gpr_log(GPR_INFO, "****** SERVER RESTARTED AGAIN *******"); + gpr_log(GPR_INFO, "****** CHANNEL 2 STARTING A CALL *******"); + // The first call after the server restart will succeed. + CheckRpcSendOk(stub_2, DEBUG_LOCATION); + gpr_log(GPR_INFO, "****** CHANNEL 2 FINISHED A CALL *******"); + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel_1->GetLoadBalancingPolicyName()); + // Check LB policy name for the channel. + EXPECT_EQ("pick_first", channel_2->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, PickFirstIdleOnDisconnect) { + // Start server, send RPC, and make sure channel is READY. + const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("", response_generator); // pick_first is the default. + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + // Stop server. Channel should go into state IDLE. + response_generator.SetFailureOnReresolution(); + servers_[0]->Shutdown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + servers_.clear(); +} + +TEST_F(ClientLbEnd2endTest, PickFirstPendingUpdateAndSelectedSubchannelFails) { + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("", response_generator); // pick_first is the default. + auto stub = BuildStub(channel); + // Create a number of servers, but only start 1 of them. + CreateServers(10); + StartServer(0); + // Initially resolve to first server and make sure it connects. + gpr_log(GPR_INFO, "Phase 1: Connect to first server."); + response_generator.SetNextResolution({servers_[0]->port_}); + CheckRpcSendOk(stub, DEBUG_LOCATION, true /* wait_for_ready */); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + // Send a resolution update with the remaining servers, none of which are + // running yet, so the update will stay pending. Note that it's important + // to have multiple servers here, or else the test will be flaky; with only + // one server, the pending subchannel list has already gone into + // TRANSIENT_FAILURE due to hitting the end of the list by the time we + // check the state. + gpr_log(GPR_INFO, + "Phase 2: Resolver update pointing to remaining " + "(not started) servers."); + response_generator.SetNextResolution(GetServersPorts(1 /* start_index */)); + // RPCs will continue to be sent to the first server. + CheckRpcSendOk(stub, DEBUG_LOCATION); + // Now stop the first server, so that the current subchannel list + // fails. This should cause us to immediately swap over to the + // pending list, even though it's not yet connected. The state should + // be set to CONNECTING, since that's what the pending subchannel list + // was doing when we swapped over. + gpr_log(GPR_INFO, "Phase 3: Stopping first server."); + servers_[0]->Shutdown(); + WaitForChannelNotReady(channel.get()); + // TODO(roth): This should always return CONNECTING, but it's flaky + // between that and TRANSIENT_FAILURE. I suspect that this problem + // will go away once we move the backoff code out of the subchannel + // and into the LB policies. + EXPECT_THAT(channel->GetState(false), + ::testing::AnyOf(GRPC_CHANNEL_CONNECTING, + GRPC_CHANNEL_TRANSIENT_FAILURE)); + // Now start the second server. + gpr_log(GPR_INFO, "Phase 4: Starting second server."); + StartServer(1); + // The channel should go to READY state and RPCs should go to the + // second server. + WaitForChannelReady(channel.get()); + WaitForServer(stub, 1, DEBUG_LOCATION, true /* ignore_failure */); +} + +TEST_F(ClientLbEnd2endTest, PickFirstStaysIdleUponEmptyUpdate) { + // Start server, send RPC, and make sure channel is READY. + const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("", response_generator); // pick_first is the default. + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + // Stop server. Channel should go into state IDLE. + servers_[0]->Shutdown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + // Now send resolver update that includes no addresses. Channel + // should stay in state IDLE. + response_generator.SetNextResolution({}); + EXPECT_FALSE(channel->WaitForStateChange( + GRPC_CHANNEL_IDLE, grpc_timeout_seconds_to_deadline(3))); + // Now bring the backend back up and send a non-empty resolver update, + // and then try to send an RPC. Channel should go back into state READY. + StartServer(0); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); +} + +TEST_F(ClientLbEnd2endTest, RoundRobin) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + // Wait until all backends are ready. + do { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } while (!SeenAllServers()); + ResetCounters(); + // "Sync" to the end of the list. Next sequence of picks will start at the + // first server (index 0). + WaitForServer(stub, servers_.size() - 1, DEBUG_LOCATION); + std::vector<int> connection_order; + for (size_t i = 0; i < servers_.size(); ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + UpdateConnectionOrder(servers_, &connection_order); + } + // Backends should be iterated over in the order in which the addresses were + // given. + const auto expected = std::vector<int>{0, 1, 2}; + EXPECT_EQ(expected, connection_order); + // Check LB policy name for the channel. + EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinProcessPending) { + StartServers(1); // Single server + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution({servers_[0]->port_}); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Create a new channel and its corresponding RR LB policy, which will pick + // the subchannels in READY state from the previous RPC against the same + // target (even if it happened over a different channel, because subchannels + // are globally reused). Progress should happen without any transition from + // this READY state. + auto second_response_generator = BuildResolverResponseGenerator(); + auto second_channel = BuildChannel("round_robin", second_response_generator); + auto second_stub = BuildStub(second_channel); + second_response_generator.SetNextResolution({servers_[0]->port_}); + CheckRpcSendOk(second_stub, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinUpdates) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + std::vector<int> ports; + // Start with a single server. + gpr_log(GPR_INFO, "*** FIRST BACKEND ***"); + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Send RPCs. They should all go servers_[0] + for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(10, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(0, servers_[2]->service_.request_count()); + servers_[0]->service_.ResetCounters(); + // And now for the second server. + gpr_log(GPR_INFO, "*** SECOND BACKEND ***"); + ports.clear(); + ports.emplace_back(servers_[1]->port_); + response_generator.SetNextResolution(ports); + // Wait until update has been processed, as signaled by the second backend + // receiving a request. + EXPECT_EQ(0, servers_[1]->service_.request_count()); + WaitForServer(stub, 1, DEBUG_LOCATION); + for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(0, servers_[0]->service_.request_count()); + EXPECT_EQ(10, servers_[1]->service_.request_count()); + EXPECT_EQ(0, servers_[2]->service_.request_count()); + servers_[1]->service_.ResetCounters(); + // ... and for the last server. + gpr_log(GPR_INFO, "*** THIRD BACKEND ***"); + ports.clear(); + ports.emplace_back(servers_[2]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 2, DEBUG_LOCATION); + for (size_t i = 0; i < 10; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(0, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(10, servers_[2]->service_.request_count()); + servers_[2]->service_.ResetCounters(); + // Back to all servers. + gpr_log(GPR_INFO, "*** ALL BACKENDS ***"); + ports.clear(); + ports.emplace_back(servers_[0]->port_); + ports.emplace_back(servers_[1]->port_); + ports.emplace_back(servers_[2]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 0, DEBUG_LOCATION); + WaitForServer(stub, 1, DEBUG_LOCATION); + WaitForServer(stub, 2, DEBUG_LOCATION); + // Send three RPCs, one per server. + for (size_t i = 0; i < 3; ++i) CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(1, servers_[0]->service_.request_count()); + EXPECT_EQ(1, servers_[1]->service_.request_count()); + EXPECT_EQ(1, servers_[2]->service_.request_count()); + // An empty update will result in the channel going into TRANSIENT_FAILURE. + gpr_log(GPR_INFO, "*** NO BACKENDS ***"); + ports.clear(); + response_generator.SetNextResolution(ports); + grpc_connectivity_state channel_state; + do { + channel_state = channel->GetState(true /* try to connect */); + } while (channel_state == GRPC_CHANNEL_READY); + ASSERT_NE(channel_state, GRPC_CHANNEL_READY); + servers_[0]->service_.ResetCounters(); + // Next update introduces servers_[1], making the channel recover. + gpr_log(GPR_INFO, "*** BACK TO SECOND BACKEND ***"); + ports.clear(); + ports.emplace_back(servers_[1]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 1, DEBUG_LOCATION); + channel_state = channel->GetState(false /* try to connect */); + ASSERT_EQ(channel_state, GRPC_CHANNEL_READY); + // Check LB policy name for the channel. + EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinUpdateInError) { + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + std::vector<int> ports; + // Start with a single server. + ports.emplace_back(servers_[0]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Send RPCs. They should all go to servers_[0] + for (size_t i = 0; i < 10; ++i) SendRpc(stub); + EXPECT_EQ(10, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(0, servers_[2]->service_.request_count()); + servers_[0]->service_.ResetCounters(); + // Shutdown one of the servers to be sent in the update. + servers_[1]->Shutdown(); + ports.emplace_back(servers_[1]->port_); + ports.emplace_back(servers_[2]->port_); + response_generator.SetNextResolution(ports); + WaitForServer(stub, 0, DEBUG_LOCATION); + WaitForServer(stub, 2, DEBUG_LOCATION); + // Send three RPCs, one per server. + for (size_t i = 0; i < kNumServers; ++i) SendRpc(stub); + // The server in shutdown shouldn't receive any. + EXPECT_EQ(0, servers_[1]->service_.request_count()); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinManyUpdates) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + std::vector<int> ports = GetServersPorts(); + for (size_t i = 0; i < 1000; ++i) { + std::shuffle(ports.begin(), ports.end(), + std::mt19937(std::random_device()())); + response_generator.SetNextResolution(ports); + if (i % 10 == 0) CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Check LB policy name for the channel. + EXPECT_EQ("round_robin", channel->GetLoadBalancingPolicyName()); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinConcurrentUpdates) { + // TODO(dgq): replicate the way internal testing exercises the concurrent + // update provisions of RR. +} + +TEST_F(ClientLbEnd2endTest, RoundRobinReresolve) { + // Start servers and send one RPC per server. + const int kNumServers = 3; + std::vector<int> first_ports; + std::vector<int> second_ports; + first_ports.reserve(kNumServers); + for (int i = 0; i < kNumServers; ++i) { + // first_ports.push_back(grpc_pick_unused_port_or_die()); + first_ports.push_back(5114 + i); + } + second_ports.reserve(kNumServers); + for (int i = 0; i < kNumServers; ++i) { + // second_ports.push_back(grpc_pick_unused_port_or_die()); + second_ports.push_back(5117 + i); + } + StartServers(kNumServers, first_ports); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(first_ports); + // Send a number of RPCs, which succeed. + for (size_t i = 0; i < 100; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Kill all servers + gpr_log(GPR_INFO, "****** ABOUT TO KILL SERVERS *******"); + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + gpr_log(GPR_INFO, "****** SERVERS KILLED *******"); + gpr_log(GPR_INFO, "****** SENDING DOOMED REQUESTS *******"); + // Client requests should fail. Send enough to tickle all subchannels. + for (size_t i = 0; i < servers_.size(); ++i) CheckRpcSendFailure(stub); + gpr_log(GPR_INFO, "****** DOOMED REQUESTS SENT *******"); + // Bring servers back up on a different set of ports. We need to do this to be + // sure that the eventual success is *not* due to subchannel reconnection + // attempts and that an actual re-resolution has happened as a result of the + // RR policy going into transient failure when all its subchannels become + // unavailable (in transient failure as well). + gpr_log(GPR_INFO, "****** RESTARTING SERVERS *******"); + StartServers(kNumServers, second_ports); + // Don't notify of the update. Wait for the LB policy's re-resolution to + // "pull" the new ports. + response_generator.SetNextResolutionUponError(second_ports); + gpr_log(GPR_INFO, "****** SERVERS RESTARTED *******"); + gpr_log(GPR_INFO, "****** SENDING REQUEST TO SUCCEED *******"); + // Client request should eventually (but still fairly soon) succeed. + const gpr_timespec deadline = grpc_timeout_seconds_to_deadline(5); + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + while (gpr_time_cmp(deadline, now) > 0) { + if (SendRpc(stub)) break; + now = gpr_now(GPR_CLOCK_MONOTONIC); + } + ASSERT_GT(gpr_time_cmp(deadline, now), 0); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinTransientFailure) { + // Start servers and create channel. Channel should go to READY state. + const int kNumServers = 3; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + // Now kill the servers. The channel should transition to TRANSIENT_FAILURE. + // TODO(roth): This test should ideally check that even when the + // subchannels are in state CONNECTING for an extended period of time, + // we will still report TRANSIENT_FAILURE. Unfortunately, we don't + // currently have a good way to get a subchannel to report CONNECTING + // for a long period of time, since the servers in this test framework + // are on the loopback interface, which will immediately return a + // "Connection refused" error, so the subchannels will only be in + // CONNECTING state very briefly. When we have time, see if we can + // find a way to fix this. + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + auto predicate = [](grpc_connectivity_state state) { + return state == GRPC_CHANNEL_TRANSIENT_FAILURE; + }; + EXPECT_TRUE(WaitForChannelState(channel.get(), predicate)); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinTransientFailureAtStartup) { + // Create channel and return servers that don't exist. Channel should + // quickly transition into TRANSIENT_FAILURE. + // TODO(roth): This test should ideally check that even when the + // subchannels are in state CONNECTING for an extended period of time, + // we will still report TRANSIENT_FAILURE. Unfortunately, we don't + // currently have a good way to get a subchannel to report CONNECTING + // for a long period of time, since the servers in this test framework + // are on the loopback interface, which will immediately return a + // "Connection refused" error, so the subchannels will only be in + // CONNECTING state very briefly. When we have time, see if we can + // find a way to fix this. + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution({ + grpc_pick_unused_port_or_die(), + grpc_pick_unused_port_or_die(), + grpc_pick_unused_port_or_die(), + }); + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + auto predicate = [](grpc_connectivity_state state) { + return state == GRPC_CHANNEL_TRANSIENT_FAILURE; + }; + EXPECT_TRUE(WaitForChannelState(channel.get(), predicate, true)); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinSingleReconnect) { + const int kNumServers = 3; + StartServers(kNumServers); + const auto ports = GetServersPorts(); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(ports); + for (size_t i = 0; i < kNumServers; ++i) { + WaitForServer(stub, i, DEBUG_LOCATION); + } + for (size_t i = 0; i < servers_.size(); ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(1, servers_[i]->service_.request_count()) << "for backend #" << i; + } + // One request should have gone to each server. + for (size_t i = 0; i < servers_.size(); ++i) { + EXPECT_EQ(1, servers_[i]->service_.request_count()); + } + const auto pre_death = servers_[0]->service_.request_count(); + // Kill the first server. + servers_[0]->Shutdown(); + // Client request still succeed. May need retrying if RR had returned a pick + // before noticing the change in the server's connectivity. + while (!SendRpc(stub)) { + } // Retry until success. + // Send a bunch of RPCs that should succeed. + for (int i = 0; i < 10 * kNumServers; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + const auto post_death = servers_[0]->service_.request_count(); + // No requests have gone to the deceased server. + EXPECT_EQ(pre_death, post_death); + // Bring the first server back up. + StartServer(0); + // Requests should start arriving at the first server either right away (if + // the server managed to start before the RR policy retried the subchannel) or + // after the subchannel retry delay otherwise (RR's subchannel retried before + // the server was fully back up). + WaitForServer(stub, 0, DEBUG_LOCATION); +} + +// If health checking is required by client but health checking service +// is not running on the server, the channel should be treated as healthy. +TEST_F(ClientLbEnd2endTest, + RoundRobinServersHealthCheckingUnimplementedTreatedAsHealthy) { + StartServers(1); // Single server + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution({servers_[0]->port_}); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + CheckRpcSendOk(stub, DEBUG_LOCATION); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthChecking) { + EnableDefaultHealthCheckService(true); + // Start servers. + const int kNumServers = 3; + StartServers(kNumServers); + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + // Channel should not become READY, because health checks should be failing. + gpr_log(GPR_INFO, + "*** initial state: unknown health check service name for " + "all servers"); + EXPECT_FALSE(WaitForChannelReady(channel.get(), 1)); + // Now set one of the servers to be healthy. + // The channel should become healthy and all requests should go to + // the healthy server. + gpr_log(GPR_INFO, "*** server 0 healthy"); + servers_[0]->SetServingStatus("health_check_service_name", true); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + for (int i = 0; i < 10; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + EXPECT_EQ(10, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(0, servers_[2]->service_.request_count()); + // Now set a second server to be healthy. + gpr_log(GPR_INFO, "*** server 2 healthy"); + servers_[2]->SetServingStatus("health_check_service_name", true); + WaitForServer(stub, 2, DEBUG_LOCATION); + for (int i = 0; i < 10; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + EXPECT_EQ(5, servers_[0]->service_.request_count()); + EXPECT_EQ(0, servers_[1]->service_.request_count()); + EXPECT_EQ(5, servers_[2]->service_.request_count()); + // Now set the remaining server to be healthy. + gpr_log(GPR_INFO, "*** server 1 healthy"); + servers_[1]->SetServingStatus("health_check_service_name", true); + WaitForServer(stub, 1, DEBUG_LOCATION); + for (int i = 0; i < 9; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + EXPECT_EQ(3, servers_[0]->service_.request_count()); + EXPECT_EQ(3, servers_[1]->service_.request_count()); + EXPECT_EQ(3, servers_[2]->service_.request_count()); + // Now set one server to be unhealthy again. Then wait until the + // unhealthiness has hit the client. We know that the client will see + // this when we send kNumServers requests and one of the remaining servers + // sees two of the requests. + gpr_log(GPR_INFO, "*** server 0 unhealthy"); + servers_[0]->SetServingStatus("health_check_service_name", false); + do { + ResetCounters(); + for (int i = 0; i < kNumServers; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + } while (servers_[1]->service_.request_count() != 2 && + servers_[2]->service_.request_count() != 2); + // Now set the remaining two servers to be unhealthy. Make sure the + // channel leaves READY state and that RPCs fail. + gpr_log(GPR_INFO, "*** all servers unhealthy"); + servers_[1]->SetServingStatus("health_check_service_name", false); + servers_[2]->SetServingStatus("health_check_service_name", false); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + CheckRpcSendFailure(stub); + // Clean up. + EnableDefaultHealthCheckService(false); +} + +TEST_F(ClientLbEnd2endTest, + RoundRobinWithHealthCheckingHandlesSubchannelFailure) { + EnableDefaultHealthCheckService(true); + // Start servers. + const int kNumServers = 3; + StartServers(kNumServers); + servers_[0]->SetServingStatus("health_check_service_name", true); + servers_[1]->SetServingStatus("health_check_service_name", true); + servers_[2]->SetServingStatus("health_check_service_name", true); + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + WaitForServer(stub, 0, DEBUG_LOCATION); + // Stop server 0 and send a new resolver result to ensure that RR + // checks each subchannel's state. + servers_[0]->Shutdown(); + response_generator.SetNextResolution(GetServersPorts()); + // Send a bunch more RPCs. + for (size_t i = 0; i < 100; i++) { + SendRpc(stub); + } +} + +TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthCheckingInhibitPerChannel) { + EnableDefaultHealthCheckService(true); + // Start server. + const int kNumServers = 1; + StartServers(kNumServers); + // Create a channel with health-checking enabled. + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("round_robin", response_generator1, args); + auto stub1 = BuildStub(channel1); + std::vector<int> ports = GetServersPorts(); + response_generator1.SetNextResolution(ports); + // Create a channel with health checking enabled but inhibited. + args.SetInt(GRPC_ARG_INHIBIT_HEALTH_CHECKING, 1); + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("round_robin", response_generator2, args); + auto stub2 = BuildStub(channel2); + response_generator2.SetNextResolution(ports); + // First channel should not become READY, because health checks should be + // failing. + EXPECT_FALSE(WaitForChannelReady(channel1.get(), 1)); + CheckRpcSendFailure(stub1); + // Second channel should be READY. + EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1)); + CheckRpcSendOk(stub2, DEBUG_LOCATION); + // Enable health checks on the backend and wait for channel 1 to succeed. + servers_[0]->SetServingStatus("health_check_service_name", true); + CheckRpcSendOk(stub1, DEBUG_LOCATION, true /* wait_for_ready */); + // Check that we created only one subchannel to the backend. + EXPECT_EQ(1UL, servers_[0]->service_.clients().size()); + // Clean up. + EnableDefaultHealthCheckService(false); +} + +TEST_F(ClientLbEnd2endTest, RoundRobinWithHealthCheckingServiceNamePerChannel) { + EnableDefaultHealthCheckService(true); + // Start server. + const int kNumServers = 1; + StartServers(kNumServers); + // Create a channel with health-checking enabled. + ChannelArguments args; + args.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"); + auto response_generator1 = BuildResolverResponseGenerator(); + auto channel1 = BuildChannel("round_robin", response_generator1, args); + auto stub1 = BuildStub(channel1); + std::vector<int> ports = GetServersPorts(); + response_generator1.SetNextResolution(ports); + // Create a channel with health-checking enabled with a different + // service name. + ChannelArguments args2; + args2.SetServiceConfigJSON( + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name2\"}}"); + auto response_generator2 = BuildResolverResponseGenerator(); + auto channel2 = BuildChannel("round_robin", response_generator2, args2); + auto stub2 = BuildStub(channel2); + response_generator2.SetNextResolution(ports); + // Allow health checks from channel 2 to succeed. + servers_[0]->SetServingStatus("health_check_service_name2", true); + // First channel should not become READY, because health checks should be + // failing. + EXPECT_FALSE(WaitForChannelReady(channel1.get(), 1)); + CheckRpcSendFailure(stub1); + // Second channel should be READY. + EXPECT_TRUE(WaitForChannelReady(channel2.get(), 1)); + CheckRpcSendOk(stub2, DEBUG_LOCATION); + // Enable health checks for channel 1 and wait for it to succeed. + servers_[0]->SetServingStatus("health_check_service_name", true); + CheckRpcSendOk(stub1, DEBUG_LOCATION, true /* wait_for_ready */); + // Check that we created only one subchannel to the backend. + EXPECT_EQ(1UL, servers_[0]->service_.clients().size()); + // Clean up. + EnableDefaultHealthCheckService(false); +} + +TEST_F(ClientLbEnd2endTest, + RoundRobinWithHealthCheckingServiceNameChangesAfterSubchannelsCreated) { + EnableDefaultHealthCheckService(true); + // Start server. + const int kNumServers = 1; + StartServers(kNumServers); + // Create a channel with health-checking enabled. + const char* kServiceConfigJson = + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name\"}}"; + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("round_robin", response_generator); + auto stub = BuildStub(channel); + std::vector<int> ports = GetServersPorts(); + response_generator.SetNextResolution(ports, kServiceConfigJson); + servers_[0]->SetServingStatus("health_check_service_name", true); + EXPECT_TRUE(WaitForChannelReady(channel.get(), 1 /* timeout_seconds */)); + // Send an update on the channel to change it to use a health checking + // service name that is not being reported as healthy. + const char* kServiceConfigJson2 = + "{\"healthCheckConfig\": " + "{\"serviceName\": \"health_check_service_name2\"}}"; + response_generator.SetNextResolution(ports, kServiceConfigJson2); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + // Clean up. + EnableDefaultHealthCheckService(false); +} + +TEST_F(ClientLbEnd2endTest, ChannelIdleness) { + // Start server. + const int kNumServers = 1; + StartServers(kNumServers); + // Set max idle time and build the channel. + ChannelArguments args; + args.SetInt(GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS, 1000); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("", response_generator, args); + auto stub = BuildStub(channel); + // The initial channel state should be IDLE. + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + // After sending RPC, channel state should be READY. + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + // After a period time not using the channel, the channel state should switch + // to IDLE. + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(1200)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + // Sending a new RPC should awake the IDLE channel. + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); +} + +class ClientLbPickArgsTest : public ClientLbEnd2endTest { + protected: + void SetUp() override { + ClientLbEnd2endTest::SetUp(); + current_test_instance_ = this; + } + + static void SetUpTestCase() { + grpc_init(); + grpc_core::RegisterTestPickArgsLoadBalancingPolicy(SavePickArgs); + } + + static void TearDownTestCase() { grpc_shutdown_blocking(); } + + const std::vector<grpc_core::PickArgsSeen>& args_seen_list() { + grpc::internal::MutexLock lock(&mu_); + return args_seen_list_; + } + + private: + static void SavePickArgs(const grpc_core::PickArgsSeen& args_seen) { + ClientLbPickArgsTest* self = current_test_instance_; + grpc::internal::MutexLock lock(&self->mu_); + self->args_seen_list_.emplace_back(args_seen); + } + + static ClientLbPickArgsTest* current_test_instance_; + grpc::internal::Mutex mu_; + std::vector<grpc_core::PickArgsSeen> args_seen_list_; +}; + +ClientLbPickArgsTest* ClientLbPickArgsTest::current_test_instance_ = nullptr; + +TEST_F(ClientLbPickArgsTest, Basic) { + const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("test_pick_args_lb", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION, /*wait_for_ready=*/true); + // Check LB policy name for the channel. + EXPECT_EQ("test_pick_args_lb", channel->GetLoadBalancingPolicyName()); + // There will be two entries, one for the pick tried in state + // CONNECTING and another for the pick tried in state READY. + EXPECT_THAT(args_seen_list(), + ::testing::ElementsAre( + ::testing::AllOf( + ::testing::Field(&grpc_core::PickArgsSeen::path, + "/grpc.testing.EchoTestService/Echo"), + ::testing::Field(&grpc_core::PickArgsSeen::metadata, + ::testing::UnorderedElementsAre( + ::testing::Pair("foo", "1"), + ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3")))), + ::testing::AllOf( + ::testing::Field(&grpc_core::PickArgsSeen::path, + "/grpc.testing.EchoTestService/Echo"), + ::testing::Field(&grpc_core::PickArgsSeen::metadata, + ::testing::UnorderedElementsAre( + ::testing::Pair("foo", "1"), + ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3")))))); +} + +class ClientLbInterceptTrailingMetadataTest : public ClientLbEnd2endTest { + protected: + void SetUp() override { + ClientLbEnd2endTest::SetUp(); + current_test_instance_ = this; + } + + static void SetUpTestCase() { + grpc_init(); + grpc_core::RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy( + ReportTrailerIntercepted); + } + + static void TearDownTestCase() { grpc_shutdown_blocking(); } + + int trailers_intercepted() { + grpc::internal::MutexLock lock(&mu_); + return trailers_intercepted_; + } + + const grpc_core::MetadataVector& trailing_metadata() { + grpc::internal::MutexLock lock(&mu_); + return trailing_metadata_; + } + + const udpa::data::orca::v1::OrcaLoadReport* backend_load_report() { + grpc::internal::MutexLock lock(&mu_); + return load_report_.get(); + } + + private: + static void ReportTrailerIntercepted( + const grpc_core::TrailingMetadataArgsSeen& args_seen) { + const auto* backend_metric_data = args_seen.backend_metric_data; + ClientLbInterceptTrailingMetadataTest* self = current_test_instance_; + grpc::internal::MutexLock lock(&self->mu_); + self->trailers_intercepted_++; + self->trailing_metadata_ = args_seen.metadata; + if (backend_metric_data != nullptr) { + self->load_report_.reset(new udpa::data::orca::v1::OrcaLoadReport); + self->load_report_->set_cpu_utilization( + backend_metric_data->cpu_utilization); + self->load_report_->set_mem_utilization( + backend_metric_data->mem_utilization); + self->load_report_->set_rps(backend_metric_data->requests_per_second); + for (const auto& p : backend_metric_data->request_cost) { + TString name = TString(p.first); + (*self->load_report_->mutable_request_cost())[std::move(name)] = + p.second; + } + for (const auto& p : backend_metric_data->utilization) { + TString name = TString(p.first); + (*self->load_report_->mutable_utilization())[std::move(name)] = + p.second; + } + } + } + + static ClientLbInterceptTrailingMetadataTest* current_test_instance_; + grpc::internal::Mutex mu_; + int trailers_intercepted_ = 0; + grpc_core::MetadataVector trailing_metadata_; + std::unique_ptr<udpa::data::orca::v1::OrcaLoadReport> load_report_; +}; + +ClientLbInterceptTrailingMetadataTest* + ClientLbInterceptTrailingMetadataTest::current_test_instance_ = nullptr; + +TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesDisabled) { + const int kNumServers = 1; + const int kNumRpcs = 10; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("intercept_trailing_metadata_lb", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + for (size_t i = 0; i < kNumRpcs; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Check LB policy name for the channel. + EXPECT_EQ("intercept_trailing_metadata_lb", + channel->GetLoadBalancingPolicyName()); + EXPECT_EQ(kNumRpcs, trailers_intercepted()); + EXPECT_THAT(trailing_metadata(), + ::testing::UnorderedElementsAre( + // TODO(roth): Should grpc-status be visible here? + ::testing::Pair("grpc-status", "0"), + ::testing::Pair("user-agent", ::testing::_), + ::testing::Pair("foo", "1"), ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3"))); + EXPECT_EQ(nullptr, backend_load_report()); +} + +TEST_F(ClientLbInterceptTrailingMetadataTest, InterceptsRetriesEnabled) { + const int kNumServers = 1; + const int kNumRpcs = 10; + StartServers(kNumServers); + ChannelArguments args; + args.SetServiceConfigJSON( + "{\n" + " \"methodConfig\": [ {\n" + " \"name\": [\n" + " { \"service\": \"grpc.testing.EchoTestService\" }\n" + " ],\n" + " \"retryPolicy\": {\n" + " \"maxAttempts\": 3,\n" + " \"initialBackoff\": \"1s\",\n" + " \"maxBackoff\": \"120s\",\n" + " \"backoffMultiplier\": 1.6,\n" + " \"retryableStatusCodes\": [ \"ABORTED\" ]\n" + " }\n" + " } ]\n" + "}"); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("intercept_trailing_metadata_lb", response_generator, args); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + for (size_t i = 0; i < kNumRpcs; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + } + // Check LB policy name for the channel. + EXPECT_EQ("intercept_trailing_metadata_lb", + channel->GetLoadBalancingPolicyName()); + EXPECT_EQ(kNumRpcs, trailers_intercepted()); + EXPECT_THAT(trailing_metadata(), + ::testing::UnorderedElementsAre( + // TODO(roth): Should grpc-status be visible here? + ::testing::Pair("grpc-status", "0"), + ::testing::Pair("user-agent", ::testing::_), + ::testing::Pair("foo", "1"), ::testing::Pair("bar", "2"), + ::testing::Pair("baz", "3"))); + EXPECT_EQ(nullptr, backend_load_report()); +} + +TEST_F(ClientLbInterceptTrailingMetadataTest, BackendMetricData) { + const int kNumServers = 1; + const int kNumRpcs = 10; + StartServers(kNumServers); + udpa::data::orca::v1::OrcaLoadReport load_report; + load_report.set_cpu_utilization(0.5); + load_report.set_mem_utilization(0.75); + load_report.set_rps(25); + auto* request_cost = load_report.mutable_request_cost(); + (*request_cost)["foo"] = 0.8; + (*request_cost)["bar"] = 1.4; + auto* utilization = load_report.mutable_utilization(); + (*utilization)["baz"] = 1.1; + (*utilization)["quux"] = 0.9; + for (const auto& server : servers_) { + server->service_.set_load_report(&load_report); + } + auto response_generator = BuildResolverResponseGenerator(); + auto channel = + BuildChannel("intercept_trailing_metadata_lb", response_generator); + auto stub = BuildStub(channel); + response_generator.SetNextResolution(GetServersPorts()); + for (size_t i = 0; i < kNumRpcs; ++i) { + CheckRpcSendOk(stub, DEBUG_LOCATION); + auto* actual = backend_load_report(); + ASSERT_NE(actual, nullptr); + // TODO(roth): Change this to use EqualsProto() once that becomes + // available in OSS. + EXPECT_EQ(actual->cpu_utilization(), load_report.cpu_utilization()); + EXPECT_EQ(actual->mem_utilization(), load_report.mem_utilization()); + EXPECT_EQ(actual->rps(), load_report.rps()); + EXPECT_EQ(actual->request_cost().size(), load_report.request_cost().size()); + for (const auto& p : actual->request_cost()) { + auto it = load_report.request_cost().find(p.first); + ASSERT_NE(it, load_report.request_cost().end()); + EXPECT_EQ(it->second, p.second); + } + EXPECT_EQ(actual->utilization().size(), load_report.utilization().size()); + for (const auto& p : actual->utilization()) { + auto it = load_report.utilization().find(p.first); + ASSERT_NE(it, load_report.utilization().end()); + EXPECT_EQ(it->second, p.second); + } + } + // Check LB policy name for the channel. + EXPECT_EQ("intercept_trailing_metadata_lb", + channel->GetLoadBalancingPolicyName()); + EXPECT_EQ(kNumRpcs, trailers_intercepted()); +} + +class ClientLbAddressTest : public ClientLbEnd2endTest { + protected: + static const char* kAttributeKey; + + class Attribute : public grpc_core::ServerAddress::AttributeInterface { + public: + explicit Attribute(const TString& str) : str_(str) {} + + std::unique_ptr<AttributeInterface> Copy() const override { + return y_absl::make_unique<Attribute>(str_); + } + + int Cmp(const AttributeInterface* other) const override { + return str_.compare(static_cast<const Attribute*>(other)->str_); + } + + TString ToString() const override { return str_; } + + private: + TString str_; + }; + + void SetUp() override { + ClientLbEnd2endTest::SetUp(); + current_test_instance_ = this; + } + + static void SetUpTestCase() { + grpc_init(); + grpc_core::RegisterAddressTestLoadBalancingPolicy(SaveAddress); + } + + static void TearDownTestCase() { grpc_shutdown_blocking(); } + + const std::vector<TString>& addresses_seen() { + grpc::internal::MutexLock lock(&mu_); + return addresses_seen_; + } + + private: + static void SaveAddress(const grpc_core::ServerAddress& address) { + ClientLbAddressTest* self = current_test_instance_; + grpc::internal::MutexLock lock(&self->mu_); + self->addresses_seen_.emplace_back(address.ToString()); + } + + static ClientLbAddressTest* current_test_instance_; + grpc::internal::Mutex mu_; + std::vector<TString> addresses_seen_; +}; + +const char* ClientLbAddressTest::kAttributeKey = "attribute_key"; + +ClientLbAddressTest* ClientLbAddressTest::current_test_instance_ = nullptr; + +TEST_F(ClientLbAddressTest, Basic) { + const int kNumServers = 1; + StartServers(kNumServers); + auto response_generator = BuildResolverResponseGenerator(); + auto channel = BuildChannel("address_test_lb", response_generator); + auto stub = BuildStub(channel); + // Addresses returned by the resolver will have attached attributes. + response_generator.SetNextResolution(GetServersPorts(), nullptr, + kAttributeKey, + y_absl::make_unique<Attribute>("foo")); + CheckRpcSendOk(stub, DEBUG_LOCATION); + // Check LB policy name for the channel. + EXPECT_EQ("address_test_lb", channel->GetLoadBalancingPolicyName()); + // Make sure that the attributes wind up on the subchannels. + std::vector<TString> expected; + for (const int port : GetServersPorts()) { + expected.emplace_back(y_absl::StrCat( + "127.0.0.1:", port, " args={} attributes={", kAttributeKey, "=foo}")); + } + EXPECT_EQ(addresses_seen(), expected); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/delegating_channel_test.cc b/contrib/libs/grpc/test/cpp/end2end/delegating_channel_test.cc new file mode 100644 index 0000000000..5d025ecb94 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/delegating_channel_test.cc @@ -0,0 +1,100 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <vector> + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/delegating_channel.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/support/client_interceptor.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +namespace { + +class TestChannel : public experimental::DelegatingChannel { + public: + TestChannel(const std::shared_ptr<ChannelInterface>& delegate_channel) + : experimental::DelegatingChannel(delegate_channel) {} + // Always returns GRPC_CHANNEL_READY + grpc_connectivity_state GetState(bool /*try_to_connect*/) override { + return GRPC_CHANNEL_READY; + } +}; + +class DelegatingChannelTest : public ::testing::Test { + protected: + DelegatingChannelTest() { + int port = grpc_pick_unused_port_or_die(); + ServerBuilder builder; + server_address_ = "localhost:" + ToString(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~DelegatingChannelTest() { server_->Shutdown(); } + + TString server_address_; + TestServiceImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(DelegatingChannelTest, SimpleTest) { + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + std::shared_ptr<TestChannel> test_channel = + std::make_shared<TestChannel>(channel); + // gRPC channel should be in idle state at this point but our test channel + // will return ready. + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_IDLE); + EXPECT_EQ(test_channel->GetState(false), GRPC_CHANNEL_READY); + auto stub = grpc::testing::EchoTestService::NewStub(test_channel); + ClientContext ctx; + EchoRequest req; + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/end2end_test.cc new file mode 100644 index 0000000000..ad2ddb7e84 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/end2end_test.cc @@ -0,0 +1,2357 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/impl/codegen/status_code_enum.h> +#include <grpcpp/resource_quota.h> +#include <grpcpp/security/auth_metadata_processor.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/security/server_credentials.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/support/string_ref.h> +#include <grpcpp/test/channel_test_peer.h> + +#include <mutex> +#include <thread> + +#include "y_absl/strings/str_format.h" +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/string_ref_helper.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GRPC_POSIX_SOCKET_EV +#include "src/core/lib/iomgr/ev_posix.h" +#endif // GRPC_POSIX_SOCKET_EV + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using grpc::testing::kTlsCredentialsType; +using std::chrono::system_clock; + +// MAYBE_SKIP_TEST is a macro to determine if this particular test configuration +// should be skipped based on a decision made at SetUp time. In particular, +// tests that use the callback server can only be run if the iomgr can run in +// the background or if the transport is in-process. +#define MAYBE_SKIP_TEST \ + do { \ + if (do_not_test_) { \ + return; \ + } \ + } while (0) + +namespace grpc { +namespace testing { +namespace { + +bool CheckIsLocalhost(const TString& addr) { + const TString kIpv6("ipv6:[::1]:"); + const TString kIpv4MappedIpv6("ipv6:[::ffff:127.0.0.1]:"); + const TString kIpv4("ipv4:127.0.0.1:"); + return addr.substr(0, kIpv4.size()) == kIpv4 || + addr.substr(0, kIpv4MappedIpv6.size()) == kIpv4MappedIpv6 || + addr.substr(0, kIpv6.size()) == kIpv6; +} + +const int kClientChannelBackupPollIntervalMs = 200; + +const char kTestCredsPluginErrorMsg[] = "Could not find plugin metadata."; + +const char kFakeToken[] = "fake_token"; +const char kFakeSelector[] = "fake_selector"; +const char kExpectedFakeCredsDebugString[] = + "SecureCallCredentials{GoogleIAMCredentials{Token:present," + "AuthoritySelector:fake_selector}}"; + +const char kWrongToken[] = "wrong_token"; +const char kWrongSelector[] = "wrong_selector"; +const char kExpectedWrongCredsDebugString[] = + "SecureCallCredentials{GoogleIAMCredentials{Token:present," + "AuthoritySelector:wrong_selector}}"; + +const char kFakeToken1[] = "fake_token1"; +const char kFakeSelector1[] = "fake_selector1"; +const char kExpectedFakeCreds1DebugString[] = + "SecureCallCredentials{GoogleIAMCredentials{Token:present," + "AuthoritySelector:fake_selector1}}"; + +const char kFakeToken2[] = "fake_token2"; +const char kFakeSelector2[] = "fake_selector2"; +const char kExpectedFakeCreds2DebugString[] = + "SecureCallCredentials{GoogleIAMCredentials{Token:present," + "AuthoritySelector:fake_selector2}}"; + +const char kExpectedAuthMetadataPluginKeyFailureCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:TestPluginMetadata," + "value:Does not matter, will fail the key is invalid.}}"; +const char kExpectedAuthMetadataPluginValueFailureCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata," + "value:With illegal \n value.}}"; +const char kExpectedAuthMetadataPluginWithDeadlineCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:meta_key,value:Does not " + "matter}}"; +const char kExpectedNonBlockingAuthMetadataPluginFailureCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata," + "value:Does not matter, will fail anyway (see 3rd param)}}"; +const char + kExpectedNonBlockingAuthMetadataPluginAndProcessorSuccessCredsDebugString + [] = "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-" + "metadata,value:Dr Jekyll}}"; +const char + kExpectedNonBlockingAuthMetadataPluginAndProcessorFailureCredsDebugString + [] = "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-" + "metadata,value:Mr Hyde}}"; +const char kExpectedBlockingAuthMetadataPluginFailureCredsDebugString[] = + "SecureCallCredentials{TestMetadataCredentials{key:test-plugin-metadata," + "value:Does not matter, will fail anyway (see 3rd param)}}"; +const char kExpectedCompositeCallCredsDebugString[] = + "SecureCallCredentials{CompositeCallCredentials{TestMetadataCredentials{" + "key:call-creds-key1,value:call-creds-val1},TestMetadataCredentials{key:" + "call-creds-key2,value:call-creds-val2}}}"; + +class TestMetadataCredentialsPlugin : public MetadataCredentialsPlugin { + public: + static const char kGoodMetadataKey[]; + static const char kBadMetadataKey[]; + + TestMetadataCredentialsPlugin(const grpc::string_ref& metadata_key, + const grpc::string_ref& metadata_value, + bool is_blocking, bool is_successful, + int delay_ms) + : metadata_key_(metadata_key.data(), metadata_key.length()), + metadata_value_(metadata_value.data(), metadata_value.length()), + is_blocking_(is_blocking), + is_successful_(is_successful), + delay_ms_(delay_ms) {} + + bool IsBlocking() const override { return is_blocking_; } + + Status GetMetadata( + grpc::string_ref service_url, grpc::string_ref method_name, + const grpc::AuthContext& channel_auth_context, + std::multimap<TString, TString>* metadata) override { + if (delay_ms_ != 0) { + gpr_sleep_until( + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(delay_ms_, GPR_TIMESPAN))); + } + EXPECT_GT(service_url.length(), 0UL); + EXPECT_GT(method_name.length(), 0UL); + EXPECT_TRUE(channel_auth_context.IsPeerAuthenticated()); + EXPECT_TRUE(metadata != nullptr); + if (is_successful_) { + metadata->insert(std::make_pair(metadata_key_, metadata_value_)); + return Status::OK; + } else { + return Status(StatusCode::NOT_FOUND, kTestCredsPluginErrorMsg); + } + } + + TString DebugString() override { + return y_absl::StrFormat("TestMetadataCredentials{key:%s,value:%s}", + metadata_key_.c_str(), metadata_value_.c_str()); + } + + private: + TString metadata_key_; + TString metadata_value_; + bool is_blocking_; + bool is_successful_; + int delay_ms_; +}; + +const char TestMetadataCredentialsPlugin::kBadMetadataKey[] = + "TestPluginMetadata"; +const char TestMetadataCredentialsPlugin::kGoodMetadataKey[] = + "test-plugin-metadata"; + +class TestAuthMetadataProcessor : public AuthMetadataProcessor { + public: + static const char kGoodGuy[]; + + TestAuthMetadataProcessor(bool is_blocking) : is_blocking_(is_blocking) {} + + std::shared_ptr<CallCredentials> GetCompatibleClientCreds() { + return grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, kGoodGuy, + is_blocking_, true, 0))); + } + + std::shared_ptr<CallCredentials> GetIncompatibleClientCreds() { + return grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, "Mr Hyde", + is_blocking_, true, 0))); + } + + // Interface implementation + bool IsBlocking() const override { return is_blocking_; } + + Status Process(const InputMetadata& auth_metadata, AuthContext* context, + OutputMetadata* consumed_auth_metadata, + OutputMetadata* response_metadata) override { + EXPECT_TRUE(consumed_auth_metadata != nullptr); + EXPECT_TRUE(context != nullptr); + EXPECT_TRUE(response_metadata != nullptr); + auto auth_md = + auth_metadata.find(TestMetadataCredentialsPlugin::kGoodMetadataKey); + EXPECT_NE(auth_md, auth_metadata.end()); + string_ref auth_md_value = auth_md->second; + if (auth_md_value == kGoodGuy) { + context->AddProperty(kIdentityPropName, kGoodGuy); + context->SetPeerIdentityPropertyName(kIdentityPropName); + consumed_auth_metadata->insert(std::make_pair( + string(auth_md->first.data(), auth_md->first.length()), + string(auth_md->second.data(), auth_md->second.length()))); + return Status::OK; + } else { + return Status(StatusCode::UNAUTHENTICATED, + string("Invalid principal: ") + + string(auth_md_value.data(), auth_md_value.length())); + } + } + + private: + static const char kIdentityPropName[]; + bool is_blocking_; +}; + +const char TestAuthMetadataProcessor::kGoodGuy[] = "Dr Jekyll"; +const char TestAuthMetadataProcessor::kIdentityPropName[] = "novel identity"; + +class Proxy : public ::grpc::testing::EchoTestService::Service { + public: + Proxy(const std::shared_ptr<Channel>& channel) + : stub_(grpc::testing::EchoTestService::NewStub(channel)) {} + + Status Echo(ServerContext* server_context, const EchoRequest* request, + EchoResponse* response) override { + std::unique_ptr<ClientContext> client_context = + ClientContext::FromServerContext(*server_context); + return stub_->Echo(client_context.get(), *request, response); + } + + private: + std::unique_ptr<::grpc::testing::EchoTestService::Stub> stub_; +}; + +class TestServiceImplDupPkg + : public ::grpc::testing::duplicate::EchoTestService::Service { + public: + Status Echo(ServerContext* /*context*/, const EchoRequest* /*request*/, + EchoResponse* response) override { + response->set_message("no package"); + return Status::OK; + } +}; + +class TestScenario { + public: + TestScenario(bool interceptors, bool proxy, bool inproc_stub, + const TString& creds_type, bool use_callback_server) + : use_interceptors(interceptors), + use_proxy(proxy), + inproc(inproc_stub), + credentials_type(creds_type), + callback_server(use_callback_server) {} + void Log() const; + bool use_interceptors; + bool use_proxy; + bool inproc; + const TString credentials_type; + bool callback_server; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{use_interceptors=" + << (scenario.use_interceptors ? "true" : "false") + << ", use_proxy=" << (scenario.use_proxy ? "true" : "false") + << ", inproc=" << (scenario.inproc ? "true" : "false") + << ", server_type=" + << (scenario.callback_server ? "callback" : "sync") + << ", credentials='" << scenario.credentials_type << "'}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_DEBUG, "%s", out.str().c_str()); +} + +class End2endTest : public ::testing::TestWithParam<TestScenario> { + protected: + static void SetUpTestCase() { grpc_init(); } + static void TearDownTestCase() { grpc_shutdown(); } + End2endTest() + : is_server_started_(false), + kMaxMessageSize_(8192), + special_service_("special"), + first_picked_port_(0) { + GetParam().Log(); + } + + void SetUp() override { + if (GetParam().callback_server && !GetParam().inproc && + !grpc_iomgr_run_in_background()) { + do_not_test_ = true; + return; + } + } + + void TearDown() override { + if (is_server_started_) { + server_->Shutdown(); + if (proxy_server_) proxy_server_->Shutdown(); + } + if (first_picked_port_ > 0) { + grpc_recycle_unused_port(first_picked_port_); + } + } + + void StartServer(const std::shared_ptr<AuthMetadataProcessor>& processor) { + int port = grpc_pick_unused_port_or_die(); + first_picked_port_ = port; + server_address_ << "127.0.0.1:" << port; + // Setup server + BuildAndStartServer(processor); + } + + void RestartServer(const std::shared_ptr<AuthMetadataProcessor>& processor) { + if (is_server_started_) { + server_->Shutdown(); + BuildAndStartServer(processor); + } + } + + void BuildAndStartServer( + const std::shared_ptr<AuthMetadataProcessor>& processor) { + ServerBuilder builder; + ConfigureServerBuilder(&builder); + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + if (GetParam().credentials_type != kInsecureCredentialsType) { + server_creds->SetAuthMetadataProcessor(processor); + } + if (GetParam().use_interceptors) { + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + // Add 20 dummy server interceptors + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + } + builder.AddListeningPort(server_address_.str(), server_creds); + if (!GetParam().callback_server) { + builder.RegisterService(&service_); + } else { + builder.RegisterService(&callback_service_); + } + builder.RegisterService("foo.test.youtube.com", &special_service_); + builder.RegisterService(&dup_pkg_service_); + + builder.SetSyncServerOption(ServerBuilder::SyncServerOption::NUM_CQS, 4); + builder.SetSyncServerOption( + ServerBuilder::SyncServerOption::CQ_TIMEOUT_MSEC, 10); + + server_ = builder.BuildAndStart(); + is_server_started_ = true; + } + + virtual void ConfigureServerBuilder(ServerBuilder* builder) { + builder->SetMaxMessageSize( + kMaxMessageSize_); // For testing max message size. + } + + void ResetChannel( + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators = {}) { + if (!is_server_started_) { + StartServer(std::shared_ptr<AuthMetadataProcessor>()); + } + EXPECT_TRUE(is_server_started_); + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + if (!user_agent_prefix_.empty()) { + args.SetUserAgentPrefix(user_agent_prefix_); + } + args.SetString(GRPC_ARG_SECONDARY_USER_AGENT_STRING, "end2end_test"); + + if (!GetParam().inproc) { + if (!GetParam().use_interceptors) { + channel_ = ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args); + } else { + channel_ = CreateCustomChannelWithInterceptors( + server_address_.str(), channel_creds, args, + interceptor_creators.empty() ? CreateDummyClientInterceptors() + : std::move(interceptor_creators)); + } + } else { + if (!GetParam().use_interceptors) { + channel_ = server_->InProcessChannel(args); + } else { + channel_ = server_->experimental().InProcessChannelWithInterceptors( + args, interceptor_creators.empty() + ? CreateDummyClientInterceptors() + : std::move(interceptor_creators)); + } + } + } + + void ResetStub( + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators = {}) { + ResetChannel(std::move(interceptor_creators)); + if (GetParam().use_proxy) { + proxy_service_.reset(new Proxy(channel_)); + int port = grpc_pick_unused_port_or_die(); + std::ostringstream proxyaddr; + proxyaddr << "localhost:" << port; + ServerBuilder builder; + builder.AddListeningPort(proxyaddr.str(), InsecureServerCredentials()); + builder.RegisterService(proxy_service_.get()); + + builder.SetSyncServerOption(ServerBuilder::SyncServerOption::NUM_CQS, 4); + builder.SetSyncServerOption( + ServerBuilder::SyncServerOption::CQ_TIMEOUT_MSEC, 10); + + proxy_server_ = builder.BuildAndStart(); + + channel_ = + grpc::CreateChannel(proxyaddr.str(), InsecureChannelCredentials()); + } + + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + DummyInterceptor::Reset(); + } + + bool do_not_test_{false}; + bool is_server_started_; + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::unique_ptr<Server> proxy_server_; + std::unique_ptr<Proxy> proxy_service_; + std::ostringstream server_address_; + const int kMaxMessageSize_; + TestServiceImpl service_; + CallbackTestServiceImpl callback_service_; + TestServiceImpl special_service_; + TestServiceImplDupPkg dup_pkg_service_; + TString user_agent_prefix_; + int first_picked_port_; +}; + +static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs, + bool with_binary_metadata) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + + for (int i = 0; i < num_rpcs; ++i) { + ClientContext context; + if (with_binary_metadata) { + char bytes[8] = {'\0', '\1', '\2', '\3', + '\4', '\5', '\6', static_cast<char>(i)}; + context.AddMetadata("custom-bin", TString(bytes, 8)); + } + context.set_compression_algorithm(GRPC_COMPRESS_GZIP); + Status s = stub->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + } +} + +// This class is for testing scenarios where RPCs are cancelled on the server +// by calling ServerContext::TryCancel() +class End2endServerTryCancelTest : public End2endTest { + protected: + // Helper for testing client-streaming RPCs which are cancelled on the server. + // Depending on the value of server_try_cancel parameter, this will test one + // of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading + // any messages from the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading + // messages from the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading all + // the messages from the client + // + // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL. + void TestRequestStreamServerCancel( + ServerTryCancelRequestPhase server_try_cancel, int num_msgs_to_send) { + MAYBE_SKIP_TEST; + RestartServer(std::shared_ptr<AuthMetadataProcessor>()); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + // Send server_try_cancel value in the client metadata + context.AddMetadata(kServerTryCancelRequest, + ToString(server_try_cancel)); + + auto stream = stub_->RequestStream(&context, &response); + + int num_msgs_sent = 0; + while (num_msgs_sent < num_msgs_to_send) { + request.set_message("hello"); + if (!stream->Write(request)) { + break; + } + num_msgs_sent++; + } + gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent); + + stream->WritesDone(); + Status s = stream->Finish(); + + // At this point, we know for sure that RPC was cancelled by the server + // since we passed server_try_cancel value in the metadata. Depending on the + // value of server_try_cancel, the RPC might have been cancelled by the + // server at different stages. The following validates our expectations of + // number of messages sent in various cancellation scenarios: + + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: + case CANCEL_DURING_PROCESSING: + // If the RPC is cancelled by server before / during messages from the + // client, it means that the client most likely did not get a chance to + // send all the messages it wanted to send. i.e num_msgs_sent <= + // num_msgs_to_send + EXPECT_LE(num_msgs_sent, num_msgs_to_send); + break; + + case CANCEL_AFTER_PROCESSING: + // If the RPC was cancelled after all messages were read by the server, + // the client did get a chance to send all its messages + EXPECT_EQ(num_msgs_sent, num_msgs_to_send); + break; + + default: + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } + } + + // Helper for testing server-streaming RPCs which are cancelled on the server. + // Depending on the value of server_try_cancel parameter, this will test one + // of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before writing + // any messages to the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while writing + // messages to the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after writing all + // the messages to the client + // + // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL. + void TestResponseStreamServerCancel( + ServerTryCancelRequestPhase server_try_cancel) { + MAYBE_SKIP_TEST; + RestartServer(std::shared_ptr<AuthMetadataProcessor>()); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + // Send server_try_cancel in the client metadata + context.AddMetadata(kServerTryCancelRequest, + ToString(server_try_cancel)); + + request.set_message("hello"); + auto stream = stub_->ResponseStream(&context, request); + + int num_msgs_read = 0; + while (num_msgs_read < kServerDefaultResponseStreamsToSend) { + if (!stream->Read(&response)) { + break; + } + EXPECT_EQ(response.message(), + request.message() + ToString(num_msgs_read)); + num_msgs_read++; + } + gpr_log(GPR_INFO, "Read %d messages", num_msgs_read); + + Status s = stream->Finish(); + + // Depending on the value of server_try_cancel, the RPC might have been + // cancelled by the server at different stages. The following validates our + // expectations of number of messages read in various cancellation + // scenarios: + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: + // Server cancelled before sending any messages. Which means the client + // wouldn't have read any + EXPECT_EQ(num_msgs_read, 0); + break; + + case CANCEL_DURING_PROCESSING: + // Server cancelled while writing messages. Client must have read less + // than or equal to the expected number of messages + EXPECT_LE(num_msgs_read, kServerDefaultResponseStreamsToSend); + break; + + case CANCEL_AFTER_PROCESSING: + // Even though the Server cancelled after writing all messages, the RPC + // may be cancelled before the Client got a chance to read all the + // messages. + EXPECT_LE(num_msgs_read, kServerDefaultResponseStreamsToSend); + break; + + default: { + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + } + + EXPECT_FALSE(s.ok()); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } + } + + // Helper for testing bidirectional-streaming RPCs which are cancelled on the + // server. Depending on the value of server_try_cancel parameter, this will + // test one of the following three scenarios: + // CANCEL_BEFORE_PROCESSING: Rpc is cancelled by the server before reading/ + // writing any messages from/to the client + // + // CANCEL_DURING_PROCESSING: Rpc is cancelled by the server while reading/ + // writing messages from/to the client + // + // CANCEL_AFTER PROCESSING: Rpc is cancelled by server after reading/writing + // all the messages from/to the client + // + // NOTE: Do not call this function with server_try_cancel == DO_NOT_CANCEL. + void TestBidiStreamServerCancel(ServerTryCancelRequestPhase server_try_cancel, + int num_messages) { + MAYBE_SKIP_TEST; + RestartServer(std::shared_ptr<AuthMetadataProcessor>()); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + // Send server_try_cancel in the client metadata + context.AddMetadata(kServerTryCancelRequest, + ToString(server_try_cancel)); + + auto stream = stub_->BidiStream(&context); + + int num_msgs_read = 0; + int num_msgs_sent = 0; + while (num_msgs_sent < num_messages) { + request.set_message("hello " + ToString(num_msgs_sent)); + if (!stream->Write(request)) { + break; + } + num_msgs_sent++; + + if (!stream->Read(&response)) { + break; + } + num_msgs_read++; + + EXPECT_EQ(response.message(), request.message()); + } + gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent); + gpr_log(GPR_INFO, "Read %d messages", num_msgs_read); + + stream->WritesDone(); + Status s = stream->Finish(); + + // Depending on the value of server_try_cancel, the RPC might have been + // cancelled by the server at different stages. The following validates our + // expectations of number of messages read in various cancellation + // scenarios: + switch (server_try_cancel) { + case CANCEL_BEFORE_PROCESSING: + EXPECT_EQ(num_msgs_read, 0); + break; + + case CANCEL_DURING_PROCESSING: + EXPECT_LE(num_msgs_sent, num_messages); + EXPECT_LE(num_msgs_read, num_msgs_sent); + break; + + case CANCEL_AFTER_PROCESSING: + EXPECT_EQ(num_msgs_sent, num_messages); + + // The Server cancelled after reading the last message and after writing + // the message to the client. However, the RPC cancellation might have + // taken effect before the client actually read the response. + EXPECT_LE(num_msgs_read, num_msgs_sent); + break; + + default: + gpr_log(GPR_ERROR, "Invalid server_try_cancel value: %d", + server_try_cancel); + EXPECT_TRUE(server_try_cancel > DO_NOT_CANCEL && + server_try_cancel <= CANCEL_AFTER_PROCESSING); + break; + } + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + // Make sure that the server interceptors were notified + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } + } +}; + +TEST_P(End2endServerTryCancelTest, RequestEchoServerCancel) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerTryCancelRequest, + ToString(CANCEL_BEFORE_PROCESSING)); + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); +} + +// Server to cancel before doing reading the request +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelBeforeReads) { + TestRequestStreamServerCancel(CANCEL_BEFORE_PROCESSING, 1); +} + +// Server to cancel while reading a request from the stream in parallel +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelDuringRead) { + TestRequestStreamServerCancel(CANCEL_DURING_PROCESSING, 10); +} + +// Server to cancel after reading all the requests but before returning to the +// client +TEST_P(End2endServerTryCancelTest, RequestStreamServerCancelAfterReads) { + TestRequestStreamServerCancel(CANCEL_AFTER_PROCESSING, 4); +} + +// Server to cancel before sending any response messages +TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelBefore) { + TestResponseStreamServerCancel(CANCEL_BEFORE_PROCESSING); +} + +// Server to cancel while writing a response to the stream in parallel +TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelDuring) { + TestResponseStreamServerCancel(CANCEL_DURING_PROCESSING); +} + +// Server to cancel after writing all the respones to the stream but before +// returning to the client +TEST_P(End2endServerTryCancelTest, ResponseStreamServerCancelAfter) { + TestResponseStreamServerCancel(CANCEL_AFTER_PROCESSING); +} + +// Server to cancel before reading/writing any requests/responses on the stream +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelBefore) { + TestBidiStreamServerCancel(CANCEL_BEFORE_PROCESSING, 2); +} + +// Server to cancel while reading/writing requests/responses on the stream in +// parallel +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelDuring) { + TestBidiStreamServerCancel(CANCEL_DURING_PROCESSING, 10); +} + +// Server to cancel after reading/writing all requests/responses on the stream +// but before returning to the client +TEST_P(End2endServerTryCancelTest, BidiStreamServerCancelAfter) { + TestBidiStreamServerCancel(CANCEL_AFTER_PROCESSING, 5); +} + +TEST_P(End2endTest, SimpleRpcWithCustomUserAgentPrefix) { + MAYBE_SKIP_TEST; + // User-Agent is an HTTP header for HTTP transports only + if (GetParam().inproc) { + return; + } + user_agent_prefix_ = "custom_prefix"; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + request.mutable_param()->set_echo_metadata(true); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + const auto& trailing_metadata = context.GetServerTrailingMetadata(); + auto iter = trailing_metadata.find("user-agent"); + EXPECT_TRUE(iter != trailing_metadata.end()); + TString expected_prefix = user_agent_prefix_ + " grpc-c++/"; + EXPECT_TRUE(iter->second.starts_with(expected_prefix)) << iter->second; +} + +TEST_P(End2endTest, MultipleRpcsWithVariedBinaryMetadataValue) { + MAYBE_SKIP_TEST; + ResetStub(); + std::vector<std::thread> threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back(SendRpc, stub_.get(), 10, true); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +TEST_P(End2endTest, MultipleRpcs) { + MAYBE_SKIP_TEST; + ResetStub(); + std::vector<std::thread> threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back(SendRpc, stub_.get(), 10, false); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +TEST_P(End2endTest, ManyStubs) { + MAYBE_SKIP_TEST; + ResetStub(); + ChannelTestPeer peer(channel_.get()); + int registered_calls_pre = peer.registered_calls(); + int registration_attempts_pre = peer.registration_attempts(); + for (int i = 0; i < 1000; ++i) { + grpc::testing::EchoTestService::NewStub(channel_); + } + EXPECT_EQ(peer.registered_calls(), registered_calls_pre); + EXPECT_GT(peer.registration_attempts(), registration_attempts_pre); +} + +TEST_P(End2endTest, EmptyBinaryMetadata) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + ClientContext context; + context.AddMetadata("custom-bin", ""); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, ReconnectChannel) { + MAYBE_SKIP_TEST; + if (GetParam().inproc) { + return; + } + int poller_slowdown_factor = 1; + // It needs 2 pollset_works to reconnect the channel with polling engine + // "poll" +#ifdef GRPC_POSIX_SOCKET_EV + grpc_core::UniquePtr<char> poller = GPR_GLOBAL_CONFIG_GET(grpc_poll_strategy); + if (0 == strcmp(poller.get(), "poll")) { + poller_slowdown_factor = 2; + } +#endif // GRPC_POSIX_SOCKET_EV + ResetStub(); + SendRpc(stub_.get(), 1, false); + RestartServer(std::shared_ptr<AuthMetadataProcessor>()); + // It needs more than GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS time to + // reconnect the channel. Make it a factor of 5x + gpr_sleep_until( + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(kClientChannelBackupPollIntervalMs * 5 * + poller_slowdown_factor * + grpc_test_slowdown_factor(), + GPR_TIMESPAN))); + SendRpc(stub_.get(), 1, false); +} + +TEST_P(End2endTest, RequestStreamOneRequest) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + EXPECT_TRUE(stream->Write(request)); + stream->WritesDone(); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(context.debug_error_string().empty()); +} + +TEST_P(End2endTest, RequestStreamOneRequestWithCoalescingApi) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.set_initial_metadata_corked(true); + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + stream->WriteLast(request, WriteOptions()); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, RequestStreamTwoRequests) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Write(request)); + stream->WritesDone(); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), "hellohello"); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, RequestStreamTwoRequestsWithWriteThrough) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + EXPECT_TRUE(stream->Write(request, WriteOptions().set_write_through())); + EXPECT_TRUE(stream->Write(request, WriteOptions().set_write_through())); + stream->WritesDone(); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), "hellohello"); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, RequestStreamTwoRequestsWithCoalescingApi) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.set_initial_metadata_corked(true); + auto stream = stub_->RequestStream(&context, &response); + request.set_message("hello"); + EXPECT_TRUE(stream->Write(request)); + stream->WriteLast(request, WriteOptions()); + Status s = stream->Finish(); + EXPECT_EQ(response.message(), "hellohello"); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, ResponseStream) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + + auto stream = stub_->ResponseStream(&context, request); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) { + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + ToString(i)); + } + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, ResponseStreamWithCoalescingApi) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.AddMetadata(kServerUseCoalescingApi, "1"); + + auto stream = stub_->ResponseStream(&context, request); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) { + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + ToString(i)); + } + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +// This was added to prevent regression from issue: +// https://github.com/grpc/grpc/issues/11546 +TEST_P(End2endTest, ResponseStreamWithEverythingCoalesced) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.AddMetadata(kServerUseCoalescingApi, "1"); + // We will only send one message, forcing everything (init metadata, message, + // trailing) to be coalesced together. + context.AddMetadata(kServerResponseStreamsToSend, "1"); + + auto stream = stub_->ResponseStream(&context, request); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "0"); + + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, BidiStream) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + TString msg("hello"); + + auto stream = stub_->BidiStream(&context); + + for (int i = 0; i < kServerDefaultResponseStreamsToSend; ++i) { + request.set_message(msg + ToString(i)); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + } + + stream->WritesDone(); + EXPECT_FALSE(stream->Read(&response)); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, BidiStreamWithCoalescingApi) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.AddMetadata(kServerFinishAfterNReads, "3"); + context.set_initial_metadata_corked(true); + TString msg("hello"); + + auto stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "1"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "2"); + stream->WriteLast(request, WriteOptions()); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + EXPECT_FALSE(stream->Read(&response)); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +// This was added to prevent regression from issue: +// https://github.com/grpc/grpc/issues/11546 +TEST_P(End2endTest, BidiStreamWithEverythingCoalesced) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.AddMetadata(kServerFinishAfterNReads, "1"); + context.set_initial_metadata_corked(true); + TString msg("hello"); + + auto stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + stream->WriteLast(request, WriteOptions()); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + EXPECT_FALSE(stream->Read(&response)); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +// Talk to the two services with the same name but different package names. +// The two stubs are created on the same channel. +TEST_P(End2endTest, DiffPackageServices) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + + std::unique_ptr<grpc::testing::duplicate::EchoTestService::Stub> dup_pkg_stub( + grpc::testing::duplicate::EchoTestService::NewStub(channel_)); + ClientContext context2; + s = dup_pkg_stub->Echo(&context2, request, &response); + EXPECT_EQ("no package", response.message()); + EXPECT_TRUE(s.ok()); +} + +template <class ServiceType> +void CancelRpc(ClientContext* context, int delay_us, ServiceType* service) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(delay_us, GPR_TIMESPAN))); + while (!service->signal_client()) { + } + context->TryCancel(); +} + +TEST_P(End2endTest, CancelRpcBeforeStart) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + context.TryCancel(); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ("", response.message()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(End2endTest, CancelRpcAfterStart) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + request.mutable_param()->set_server_notify_client_when_started(true); + request.mutable_param()->set_skip_cancelled_check(true); + Status s; + std::thread echo_thread([this, &s, &context, &request, &response] { + s = stub_->Echo(&context, request, &response); + EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); + }); + if (!GetParam().callback_server) { + service_.ClientWaitUntilRpcStarted(); + } else { + callback_service_.ClientWaitUntilRpcStarted(); + } + + context.TryCancel(); + + if (!GetParam().callback_server) { + service_.SignalServerToContinue(); + } else { + callback_service_.SignalServerToContinue(); + } + + echo_thread.join(); + EXPECT_EQ("", response.message()); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Client cancels request stream after sending two messages +TEST_P(End2endTest, ClientCancelsRequestStream) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + + auto stream = stub_->RequestStream(&context, &response); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Write(request)); + + context.TryCancel(); + + Status s = stream->Finish(); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + + EXPECT_EQ(response.message(), ""); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Client cancels server stream after sending some messages +TEST_P(End2endTest, ClientCancelsResponseStream) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("hello"); + + auto stream = stub_->ResponseStream(&context, request); + + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "0"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "1"); + + context.TryCancel(); + + // The cancellation races with responses, so there might be zero or + // one responses pending, read till failure + + if (stream->Read(&response)) { + EXPECT_EQ(response.message(), request.message() + "2"); + // Since we have cancelled, we expect the next attempt to read to fail + EXPECT_FALSE(stream->Read(&response)); + } + + Status s = stream->Finish(); + // The final status could be either of CANCELLED or OK depending on + // who won the race. + EXPECT_GE(grpc::StatusCode::CANCELLED, s.error_code()); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +// Client cancels bidi stream after sending some messages +TEST_P(End2endTest, ClientCancelsBidi) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + TString msg("hello"); + + auto stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "1"); + EXPECT_TRUE(stream->Write(request)); + + context.TryCancel(); + + // The cancellation races with responses, so there might be zero or + // one responses pending, read till failure + + if (stream->Read(&response)) { + EXPECT_EQ(response.message(), request.message()); + // Since we have cancelled, we expect the next attempt to read to fail + EXPECT_FALSE(stream->Read(&response)); + } + + Status s = stream->Finish(); + EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code()); + if (GetParam().use_interceptors) { + EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel()); + } +} + +TEST_P(End2endTest, RpcMaxMessageSize) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message(string(kMaxMessageSize_ * 2, 'a')); + request.mutable_param()->set_server_die(true); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); +} + +void ReaderThreadFunc(ClientReaderWriter<EchoRequest, EchoResponse>* stream, + gpr_event* ev) { + EchoResponse resp; + gpr_event_set(ev, (void*)1); + while (stream->Read(&resp)) { + gpr_log(GPR_INFO, "Read message"); + } +} + +// Run a Read and a WritesDone simultaneously. +TEST_P(End2endTest, SimultaneousReadWritesDone) { + MAYBE_SKIP_TEST; + ResetStub(); + ClientContext context; + gpr_event ev; + gpr_event_init(&ev); + auto stream = stub_->BidiStream(&context); + std::thread reader_thread(ReaderThreadFunc, stream.get(), &ev); + gpr_event_wait(&ev, gpr_inf_future(GPR_CLOCK_REALTIME)); + stream->WritesDone(); + reader_thread.join(); + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); +} + +TEST_P(End2endTest, ChannelState) { + MAYBE_SKIP_TEST; + if (GetParam().inproc) { + return; + } + + ResetStub(); + // Start IDLE + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(false)); + + // Did not ask to connect, no state change. + CompletionQueue cq; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(10); + channel_->NotifyOnStateChange(GRPC_CHANNEL_IDLE, deadline, &cq, nullptr); + void* tag; + bool ok = true; + cq.Next(&tag, &ok); + EXPECT_FALSE(ok); + + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(true)); + EXPECT_TRUE(channel_->WaitForStateChange(GRPC_CHANNEL_IDLE, + gpr_inf_future(GPR_CLOCK_REALTIME))); + auto state = channel_->GetState(false); + EXPECT_TRUE(state == GRPC_CHANNEL_CONNECTING || state == GRPC_CHANNEL_READY); +} + +// Takes 10s. +TEST_P(End2endTest, ChannelStateTimeout) { + if ((GetParam().credentials_type != kInsecureCredentialsType) || + GetParam().inproc) { + return; + } + int port = grpc_pick_unused_port_or_die(); + std::ostringstream server_address; + server_address << "127.0.0.1:" << port; + // Channel to non-existing server + auto channel = + grpc::CreateChannel(server_address.str(), InsecureChannelCredentials()); + // Start IDLE + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel->GetState(true)); + + auto state = GRPC_CHANNEL_IDLE; + for (int i = 0; i < 10; i++) { + channel->WaitForStateChange( + state, std::chrono::system_clock::now() + std::chrono::seconds(1)); + state = channel->GetState(false); + } +} + +// Talking to a non-existing service. +TEST_P(End2endTest, NonExistingService) { + MAYBE_SKIP_TEST; + ResetChannel(); + std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel_); + + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + Status s = stub->Unimplemented(&context, request, &response); + EXPECT_EQ(StatusCode::UNIMPLEMENTED, s.error_code()); + EXPECT_EQ("", s.error_message()); +} + +// Ask the server to send back a serialized proto in trailer. +// This is an example of setting error details. +TEST_P(End2endTest, BinaryTrailerTest) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + + request.mutable_param()->set_echo_metadata(true); + DebugInfo* info = request.mutable_param()->mutable_debug_info(); + info->add_stack_entries("stack_entry_1"); + info->add_stack_entries("stack_entry_2"); + info->add_stack_entries("stack_entry_3"); + info->set_detail("detailed debug info"); + TString expected_string = info->SerializeAsString(); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + auto trailers = context.GetServerTrailingMetadata(); + EXPECT_EQ(1u, trailers.count(kDebugInfoTrailerKey)); + auto iter = trailers.find(kDebugInfoTrailerKey); + EXPECT_EQ(expected_string, iter->second); + // Parse the returned trailer into a DebugInfo proto. + DebugInfo returned_info; + EXPECT_TRUE(returned_info.ParseFromString(ToString(iter->second))); +} + +TEST_P(End2endTest, ExpectErrorTest) { + MAYBE_SKIP_TEST; + ResetStub(); + + std::vector<ErrorStatus> expected_status; + expected_status.emplace_back(); + expected_status.back().set_code(13); // INTERNAL + // No Error message or details + + expected_status.emplace_back(); + expected_status.back().set_code(13); // INTERNAL + expected_status.back().set_error_message("text error message"); + expected_status.back().set_binary_error_details("text error details"); + + expected_status.emplace_back(); + expected_status.back().set_code(13); // INTERNAL + expected_status.back().set_error_message("text error message"); + expected_status.back().set_binary_error_details( + "\x0\x1\x2\x3\x4\x5\x6\x8\x9\xA\xB"); + + for (auto iter = expected_status.begin(); iter != expected_status.end(); + ++iter) { + EchoRequest request; + EchoResponse response; + ClientContext context; + request.set_message("Hello"); + auto* error = request.mutable_param()->mutable_expected_error(); + error->set_code(iter->code()); + error->set_error_message(iter->error_message()); + error->set_binary_error_details(iter->binary_error_details()); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(iter->code(), s.error_code()); + EXPECT_EQ(iter->error_message(), s.error_message()); + EXPECT_EQ(iter->binary_error_details(), s.error_details()); + EXPECT_TRUE(context.debug_error_string().find("created") != + TString::npos); + EXPECT_TRUE(context.debug_error_string().find("file") != TString::npos); + EXPECT_TRUE(context.debug_error_string().find("line") != TString::npos); + EXPECT_TRUE(context.debug_error_string().find("status") != + TString::npos); + EXPECT_TRUE(context.debug_error_string().find("13") != TString::npos); + } +} + +////////////////////////////////////////////////////////////////////////// +// Test with and without a proxy. +class ProxyEnd2endTest : public End2endTest { + protected: +}; + +TEST_P(ProxyEnd2endTest, SimpleRpc) { + MAYBE_SKIP_TEST; + ResetStub(); + SendRpc(stub_.get(), 1, false); +} + +TEST_P(ProxyEnd2endTest, SimpleRpcWithEmptyMessages) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_TRUE(s.ok()); +} + +TEST_P(ProxyEnd2endTest, MultipleRpcs) { + MAYBE_SKIP_TEST; + ResetStub(); + std::vector<std::thread> threads; + threads.reserve(10); + for (int i = 0; i < 10; ++i) { + threads.emplace_back(SendRpc, stub_.get(), 10, false); + } + for (int i = 0; i < 10; ++i) { + threads[i].join(); + } +} + +// Set a 10us deadline and make sure proper error is returned. +TEST_P(ProxyEnd2endTest, RpcDeadlineExpires) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_skip_cancelled_check(true); + // Let server sleep for 40 ms first to guarantee expiry. + // 40 ms might seem a bit extreme but the timer manager would have been just + // initialized (when ResetStub() was called) and there are some warmup costs + // i.e the timer thread many not have even started. There might also be other + // delays in the timer manager thread (in acquiring locks, timer data + // structure manipulations, starting backup timer threads) that add to the + // delays. 40ms is still not enough in some cases but this significantly + // reduces the test flakes + request.mutable_param()->set_server_sleep_us(40 * 1000); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(1); + context.set_deadline(deadline); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(StatusCode::DEADLINE_EXCEEDED, s.error_code()); +} + +// Set a long but finite deadline. +TEST_P(ProxyEnd2endTest, RpcLongDeadline) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::hours(1); + context.set_deadline(deadline); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +// Ask server to echo back the deadline it sees. +TEST_P(ProxyEnd2endTest, EchoDeadline) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_echo_deadline(true); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::seconds(100); + context.set_deadline(deadline); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + gpr_timespec sent_deadline; + Timepoint2Timespec(deadline, &sent_deadline); + // We want to allow some reasonable error given: + // - request_deadline() only has 1sec resolution so the best we can do is +-1 + // - if sent_deadline.tv_nsec is very close to the next second's boundary we + // can end up being off by 2 in one direction. + EXPECT_LE(response.param().request_deadline() - sent_deadline.tv_sec, 2); + EXPECT_GE(response.param().request_deadline() - sent_deadline.tv_sec, -1); +} + +// Ask server to echo back the deadline it sees. The rpc has no deadline. +TEST_P(ProxyEnd2endTest, EchoDeadlineForNoDeadlineRpc) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_echo_deadline(true); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(response.param().request_deadline(), + gpr_inf_future(GPR_CLOCK_REALTIME).tv_sec); +} + +TEST_P(ProxyEnd2endTest, UnimplementedRpc) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + Status s = stub_->Unimplemented(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), grpc::StatusCode::UNIMPLEMENTED); + EXPECT_EQ(s.error_message(), ""); + EXPECT_EQ(response.message(), ""); +} + +// Client cancels rpc after 10ms +TEST_P(ProxyEnd2endTest, ClientCancelsRpc) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + const int kCancelDelayUs = 10 * 1000; + request.mutable_param()->set_client_cancel_after_us(kCancelDelayUs); + + ClientContext context; + std::thread cancel_thread; + if (!GetParam().callback_server) { + cancel_thread = std::thread( + [&context, this](int delay) { CancelRpc(&context, delay, &service_); }, + kCancelDelayUs); + // Note: the unusual pattern above (and below) is caused by a conflict + // between two sets of compiler expectations. clang allows const to be + // captured without mention, so there is no need to capture kCancelDelayUs + // (and indeed clang-tidy complains if you do so). OTOH, a Windows compiler + // in our tests requires an explicit capture even for const. We square this + // circle by passing the const value in as an argument to the lambda. + } else { + cancel_thread = std::thread( + [&context, this](int delay) { + CancelRpc(&context, delay, &callback_service_); + }, + kCancelDelayUs); + } + Status s = stub_->Echo(&context, request, &response); + cancel_thread.join(); + EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); + EXPECT_EQ(s.error_message(), "Cancelled"); +} + +// Server cancels rpc after 1ms +TEST_P(ProxyEnd2endTest, ServerCancelsRpc) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_server_cancel_after_us(1000); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); + EXPECT_TRUE(s.error_message().empty()); +} + +// Make the response larger than the flow control window. +TEST_P(ProxyEnd2endTest, HugeResponse) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("huge response"); + const size_t kResponseSize = 1024 * (1024 + 10); + request.mutable_param()->set_response_message_length(kResponseSize); + + ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::seconds(20); + context.set_deadline(deadline); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(kResponseSize, response.message().size()); + EXPECT_TRUE(s.ok()); +} + +TEST_P(ProxyEnd2endTest, Peer) { + MAYBE_SKIP_TEST; + // Peer is not meaningful for inproc + if (GetParam().inproc) { + return; + } + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("hello"); + request.mutable_param()->set_echo_peer(true); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(CheckIsLocalhost(response.param().peer())); + EXPECT_TRUE(CheckIsLocalhost(context.peer())); +} + +////////////////////////////////////////////////////////////////////////// +class SecureEnd2endTest : public End2endTest { + protected: + SecureEnd2endTest() { + GPR_ASSERT(!GetParam().use_proxy); + GPR_ASSERT(GetParam().credentials_type != kInsecureCredentialsType); + } +}; + +TEST_P(SecureEnd2endTest, SimpleRpcWithHost) { + MAYBE_SKIP_TEST; + ResetStub(); + + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + context.set_authority("foo.test.youtube.com"); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(response.has_param()); + EXPECT_EQ("special", response.param().host()); + EXPECT_TRUE(s.ok()); +} + +bool MetadataContains( + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata, + const TString& key, const TString& value) { + int count = 0; + + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator iter = + metadata.begin(); + iter != metadata.end(); ++iter) { + if (ToString(iter->first) == key && ToString(iter->second) == value) { + count++; + } + } + return count == 1; +} + +TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginAndProcessorSuccess) { + MAYBE_SKIP_TEST; + auto* processor = new TestAuthMetadataProcessor(true); + StartServer(std::shared_ptr<AuthMetadataProcessor>(processor)); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetCompatibleClientCreds()); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + request.mutable_param()->set_expected_client_identity( + TestAuthMetadataProcessor::kGoodGuy); + request.mutable_param()->set_expected_transport_security_type( + GetParam().credentials_type); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + + // Metadata should have been consumed by the processor. + EXPECT_FALSE(MetadataContains( + context.GetServerTrailingMetadata(), GRPC_AUTHORIZATION_METADATA_KEY, + TString("Bearer ") + TestAuthMetadataProcessor::kGoodGuy)); +} + +TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginAndProcessorFailure) { + MAYBE_SKIP_TEST; + auto* processor = new TestAuthMetadataProcessor(true); + StartServer(std::shared_ptr<AuthMetadataProcessor>(processor)); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetIncompatibleClientCreds()); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAUTHENTICATED); +} + +TEST_P(SecureEnd2endTest, SetPerCallCredentials) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr<CallCredentials> creds = + GoogleIAMCredentials(kFakeToken, kFakeSelector); + context.set_credentials(creds); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCredsDebugString); +} + +class CredentialsInterceptor : public experimental::Interceptor { + public: + CredentialsInterceptor(experimental::ClientRpcInfo* info) : info_(info) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + std::shared_ptr<CallCredentials> creds = + GoogleIAMCredentials(kFakeToken, kFakeSelector); + info_->client_context()->set_credentials(creds); + } + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_ = nullptr; +}; + +class CredentialsInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + CredentialsInterceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) { + return new CredentialsInterceptor(info); + } +}; + +TEST_P(SecureEnd2endTest, CallCredentialsInterception) { + MAYBE_SKIP_TEST; + if (!GetParam().use_interceptors) { + return; + } + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators; + interceptor_creators.push_back(std::unique_ptr<CredentialsInterceptorFactory>( + new CredentialsInterceptorFactory())); + ResetStub(std::move(interceptor_creators)); + EchoRequest request; + EchoResponse response; + ClientContext context; + + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCredsDebugString); +} + +TEST_P(SecureEnd2endTest, CallCredentialsInterceptionWithSetCredentials) { + MAYBE_SKIP_TEST; + if (!GetParam().use_interceptors) { + return; + } + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators; + interceptor_creators.push_back(std::unique_ptr<CredentialsInterceptorFactory>( + new CredentialsInterceptorFactory())); + ResetStub(std::move(interceptor_creators)); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr<CallCredentials> creds1 = + GoogleIAMCredentials(kWrongToken, kWrongSelector); + context.set_credentials(creds1); + EXPECT_EQ(context.credentials(), creds1); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedWrongCredsDebugString); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCredsDebugString); +} + +TEST_P(SecureEnd2endTest, OverridePerCallCredentials) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + std::shared_ptr<CallCredentials> creds1 = + GoogleIAMCredentials(kFakeToken1, kFakeSelector1); + context.set_credentials(creds1); + EXPECT_EQ(context.credentials(), creds1); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCreds1DebugString); + std::shared_ptr<CallCredentials> creds2 = + GoogleIAMCredentials(kFakeToken2, kFakeSelector2); + context.set_credentials(creds2); + EXPECT_EQ(context.credentials(), creds2); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken2)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector2)); + EXPECT_FALSE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY, + kFakeToken1)); + EXPECT_FALSE(MetadataContains(context.GetServerTrailingMetadata(), + GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY, + kFakeSelector1)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedFakeCreds2DebugString); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); +} + +TEST_P(SecureEnd2endTest, AuthMetadataPluginKeyFailure) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kBadMetadataKey, + "Does not matter, will fail the key is invalid.", false, true, + 0)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedAuthMetadataPluginKeyFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, AuthMetadataPluginValueFailure) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, + "With illegal \n value.", false, true, 0)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedAuthMetadataPluginValueFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, AuthMetadataPluginWithDeadline) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + request.mutable_param()->set_skip_cancelled_check(true); + EchoResponse response; + ClientContext context; + const int delay = 100; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(delay); + context.set_deadline(deadline); + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin("meta_key", "Does not matter", true, + true, delay)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + if (!s.ok()) { + EXPECT_TRUE(s.error_code() == StatusCode::DEADLINE_EXCEEDED || + s.error_code() == StatusCode::UNAVAILABLE); + } + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedAuthMetadataPluginWithDeadlineCredsDebugString); +} + +TEST_P(SecureEnd2endTest, AuthMetadataPluginWithCancel) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + request.mutable_param()->set_skip_cancelled_check(true); + EchoResponse response; + ClientContext context; + const int delay = 100; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin("meta_key", "Does not matter", true, + true, delay)))); + request.set_message("Hello"); + + std::thread cancel_thread([&] { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(delay, GPR_TIMESPAN))); + context.TryCancel(); + }); + Status s = stub_->Echo(&context, request, &response); + if (!s.ok()) { + EXPECT_TRUE(s.error_code() == StatusCode::CANCELLED || + s.error_code() == StatusCode::UNAVAILABLE); + } + cancel_thread.join(); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedAuthMetadataPluginWithDeadlineCredsDebugString); +} + +TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginFailure) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, + "Does not matter, will fail anyway (see 3rd param)", false, false, + 0)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(s.error_message(), + TString("Getting metadata from plugin failed with error: ") + + kTestCredsPluginErrorMsg); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedNonBlockingAuthMetadataPluginFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginAndProcessorSuccess) { + MAYBE_SKIP_TEST; + auto* processor = new TestAuthMetadataProcessor(false); + StartServer(std::shared_ptr<AuthMetadataProcessor>(processor)); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetCompatibleClientCreds()); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + request.mutable_param()->set_expected_client_identity( + TestAuthMetadataProcessor::kGoodGuy); + request.mutable_param()->set_expected_transport_security_type( + GetParam().credentials_type); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + + // Metadata should have been consumed by the processor. + EXPECT_FALSE(MetadataContains( + context.GetServerTrailingMetadata(), GRPC_AUTHORIZATION_METADATA_KEY, + TString("Bearer ") + TestAuthMetadataProcessor::kGoodGuy)); + EXPECT_EQ( + context.credentials()->DebugString(), + kExpectedNonBlockingAuthMetadataPluginAndProcessorSuccessCredsDebugString); +} + +TEST_P(SecureEnd2endTest, NonBlockingAuthMetadataPluginAndProcessorFailure) { + MAYBE_SKIP_TEST; + auto* processor = new TestAuthMetadataProcessor(false); + StartServer(std::shared_ptr<AuthMetadataProcessor>(processor)); + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(processor->GetIncompatibleClientCreds()); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAUTHENTICATED); + EXPECT_EQ( + context.credentials()->DebugString(), + kExpectedNonBlockingAuthMetadataPluginAndProcessorFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, BlockingAuthMetadataPluginFailure) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_credentials(grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin( + TestMetadataCredentialsPlugin::kGoodMetadataKey, + "Does not matter, will fail anyway (see 3rd param)", true, false, + 0)))); + request.set_message("Hello"); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(s.error_message(), + TString("Getting metadata from plugin failed with error: ") + + kTestCredsPluginErrorMsg); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedBlockingAuthMetadataPluginFailureCredsDebugString); +} + +TEST_P(SecureEnd2endTest, CompositeCallCreds) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + ClientContext context; + const char kMetadataKey1[] = "call-creds-key1"; + const char kMetadataKey2[] = "call-creds-key2"; + const char kMetadataVal1[] = "call-creds-val1"; + const char kMetadataVal2[] = "call-creds-val2"; + + context.set_credentials(grpc::CompositeCallCredentials( + grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin(kMetadataKey1, kMetadataVal1, + true, true, 0))), + grpc::MetadataCredentialsFromPlugin( + std::unique_ptr<MetadataCredentialsPlugin>( + new TestMetadataCredentialsPlugin(kMetadataKey2, kMetadataVal2, + true, true, 0))))); + request.set_message("Hello"); + request.mutable_param()->set_echo_metadata(true); + + Status s = stub_->Echo(&context, request, &response); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + kMetadataKey1, kMetadataVal1)); + EXPECT_TRUE(MetadataContains(context.GetServerTrailingMetadata(), + kMetadataKey2, kMetadataVal2)); + EXPECT_EQ(context.credentials()->DebugString(), + kExpectedCompositeCallCredsDebugString); +} + +TEST_P(SecureEnd2endTest, ClientAuthContext) { + MAYBE_SKIP_TEST; + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + request.mutable_param()->set_check_auth_context(GetParam().credentials_type == + kTlsCredentialsType); + request.mutable_param()->set_expected_transport_security_type( + GetParam().credentials_type); + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + + std::shared_ptr<const AuthContext> auth_ctx = context.auth_context(); + std::vector<grpc::string_ref> tst = + auth_ctx->FindPropertyValues("transport_security_type"); + ASSERT_EQ(1u, tst.size()); + EXPECT_EQ(GetParam().credentials_type, ToString(tst[0])); + if (GetParam().credentials_type == kTlsCredentialsType) { + EXPECT_EQ("x509_subject_alternative_name", + auth_ctx->GetPeerIdentityPropertyName()); + EXPECT_EQ(4u, auth_ctx->GetPeerIdentity().size()); + EXPECT_EQ("*.test.google.fr", ToString(auth_ctx->GetPeerIdentity()[0])); + EXPECT_EQ("waterzooi.test.google.be", + ToString(auth_ctx->GetPeerIdentity()[1])); + EXPECT_EQ("*.test.youtube.com", ToString(auth_ctx->GetPeerIdentity()[2])); + EXPECT_EQ("192.168.1.3", ToString(auth_ctx->GetPeerIdentity()[3])); + } +} + +class ResourceQuotaEnd2endTest : public End2endTest { + public: + ResourceQuotaEnd2endTest() + : server_resource_quota_("server_resource_quota") {} + + virtual void ConfigureServerBuilder(ServerBuilder* builder) override { + builder->SetResourceQuota(server_resource_quota_); + } + + private: + ResourceQuota server_resource_quota_; +}; + +TEST_P(ResourceQuotaEnd2endTest, SimpleRequest) { + MAYBE_SKIP_TEST; + ResetStub(); + + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +// TODO(vjpai): refactor arguments into a struct if it makes sense +std::vector<TestScenario> CreateTestScenarios(bool use_proxy, + bool test_insecure, + bool test_secure, + bool test_inproc, + bool test_callback_server) { + std::vector<TestScenario> scenarios; + std::vector<TString> credentials_types; + + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, + kClientChannelBackupPollIntervalMs); +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + + if (test_secure) { + credentials_types = + GetCredentialsProvider()->GetSecureCredentialsTypeList(); + } + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. + return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + if (test_insecure && insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + + // Test callback with inproc or if the event-engine allows it + GPR_ASSERT(!credentials_types.empty()); + for (const auto& cred : credentials_types) { + scenarios.emplace_back(false, false, false, cred, false); + scenarios.emplace_back(true, false, false, cred, false); + if (test_callback_server) { + // Note that these scenarios will be dynamically disabled if the event + // engine doesn't run in the background + scenarios.emplace_back(false, false, false, cred, true); + scenarios.emplace_back(true, false, false, cred, true); + } + if (use_proxy) { + scenarios.emplace_back(false, true, false, cred, false); + scenarios.emplace_back(true, true, false, cred, false); + } + } + if (test_inproc && insec_ok()) { + scenarios.emplace_back(false, false, true, kInsecureCredentialsType, false); + scenarios.emplace_back(true, false, true, kInsecureCredentialsType, false); + if (test_callback_server) { + scenarios.emplace_back(false, false, true, kInsecureCredentialsType, + true); + scenarios.emplace_back(true, false, true, kInsecureCredentialsType, true); + } + } + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P( + End2end, End2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); + +INSTANTIATE_TEST_SUITE_P( + End2endServerTryCancel, End2endServerTryCancelTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); + +INSTANTIATE_TEST_SUITE_P( + ProxyEnd2end, ProxyEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(true, true, true, true, true))); + +INSTANTIATE_TEST_SUITE_P( + SecureEnd2end, SecureEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, false, true, false, true))); + +INSTANTIATE_TEST_SUITE_P( + ResourceQuotaEnd2end, ResourceQuotaEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/exception_test.cc b/contrib/libs/grpc/test/cpp/end2end/exception_test.cc new file mode 100644 index 0000000000..cd29eb8a10 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/exception_test.cc @@ -0,0 +1,123 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <exception> +#include <memory> + +#include <grpc/impl/codegen/port_platform.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/test_config.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { + +const char* kErrorMessage = "This service caused an exception"; + +#if GRPC_ALLOW_EXCEPTIONS +class ExceptingServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + Status Echo(ServerContext* /*server_context*/, const EchoRequest* /*request*/, + EchoResponse* /*response*/) override { + throw - 1; + } + Status RequestStream(ServerContext* /*context*/, + ServerReader<EchoRequest>* /*reader*/, + EchoResponse* /*response*/) override { + throw ServiceException(); + } + + private: + class ServiceException final : public std::exception { + public: + ServiceException() {} + + private: + const char* what() const noexcept override { return kErrorMessage; } + }; +}; + +class ExceptionTest : public ::testing::Test { + protected: + ExceptionTest() {} + + void SetUp() override { + ServerBuilder builder; + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { server_->Shutdown(); } + + void ResetStub() { + channel_ = server_->InProcessChannel(ChannelArguments()); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + ExceptingServiceImpl service_; +}; + +TEST_F(ExceptionTest, Unary) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message("test"); + + for (int i = 0; i < 10; i++) { + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNKNOWN); + } +} + +TEST_F(ExceptionTest, RequestStream) { + ResetStub(); + EchoResponse response; + + for (int i = 0; i < 10; i++) { + ClientContext context; + auto stream = stub_->RequestStream(&context, &response); + stream->WritesDone(); + Status s = stream->Finish(); + + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_code(), StatusCode::UNKNOWN); + } +} + +#endif // GRPC_ALLOW_EXCEPTIONS + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/filter_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/filter_end2end_test.cc new file mode 100644 index 0000000000..2f26d0716c --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/filter_end2end_test.cc @@ -0,0 +1,346 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <mutex> +#include <thread> + +#include <grpc/grpc.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/async_generic_service.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/support/config.h> +#include <grpcpp/support/slice.h> + +#include "src/cpp/common/channel_filter.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +namespace grpc { +namespace testing { +namespace { + +void* tag(int i) { return (void*)static_cast<intptr_t>(i); } + +void verify_ok(CompletionQueue* cq, int i, bool expect_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + EXPECT_EQ(expect_ok, ok); + EXPECT_EQ(tag(i), got_tag); +} + +namespace { + +int global_num_connections = 0; +int global_num_calls = 0; +std::mutex global_mu; + +void IncrementConnectionCounter() { + std::unique_lock<std::mutex> lock(global_mu); + ++global_num_connections; +} + +void ResetConnectionCounter() { + std::unique_lock<std::mutex> lock(global_mu); + global_num_connections = 0; +} + +int GetConnectionCounterValue() { + std::unique_lock<std::mutex> lock(global_mu); + return global_num_connections; +} + +void IncrementCallCounter() { + std::unique_lock<std::mutex> lock(global_mu); + ++global_num_calls; +} + +void ResetCallCounter() { + std::unique_lock<std::mutex> lock(global_mu); + global_num_calls = 0; +} + +int GetCallCounterValue() { + std::unique_lock<std::mutex> lock(global_mu); + return global_num_calls; +} + +} // namespace + +class ChannelDataImpl : public ChannelData { + public: + grpc_error* Init(grpc_channel_element* /*elem*/, + grpc_channel_element_args* /*args*/) { + IncrementConnectionCounter(); + return GRPC_ERROR_NONE; + } +}; + +class CallDataImpl : public CallData { + public: + void StartTransportStreamOpBatch(grpc_call_element* elem, + TransportStreamOpBatch* op) override { + // Incrementing the counter could be done from Init(), but we want + // to test that the individual methods are actually called correctly. + if (op->recv_initial_metadata() != nullptr) IncrementCallCounter(); + grpc_call_next_op(elem, op->op()); + } +}; + +class FilterEnd2endTest : public ::testing::Test { + protected: + FilterEnd2endTest() : server_host_("localhost") {} + + static void SetUpTestCase() { + // Workaround for + // https://github.com/google/google-toolbox-for-mac/issues/242 + static bool setup_done = false; + if (!setup_done) { + setup_done = true; + grpc::RegisterChannelFilter<ChannelDataImpl, CallDataImpl>( + "test-filter", GRPC_SERVER_CHANNEL, INT_MAX, nullptr); + } + } + + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + server_address_ << server_host_ << ":" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterAsyncGenericService(&generic_service_); + srv_cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cli_cq_.Shutdown(); + srv_cq_->Shutdown(); + while (cli_cq_.Next(&ignored_tag, &ignored_ok)) + ; + while (srv_cq_->Next(&ignored_tag, &ignored_ok)) + ; + } + + void ResetStub() { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + generic_stub_.reset(new GenericStub(channel)); + ResetConnectionCounter(); + ResetCallCounter(); + } + + void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); } + void client_ok(int i) { verify_ok(&cli_cq_, i, true); } + void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); } + void client_fail(int i) { verify_ok(&cli_cq_, i, false); } + + void SendRpc(int num_rpcs) { + const TString kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello world. Hello world. Hello world."); + std::thread request_call([this]() { server_ok(4); }); + std::unique_ptr<GenericClientAsyncReaderWriter> call = + generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_); + call->StartCall(tag(1)); + client_ok(1); + std::unique_ptr<ByteBuffer> send_buffer = + SerializeToByteBuffer(&send_request); + call->Write(*send_buffer, tag(2)); + // Send ByteBuffer can be destroyed after calling Write. + send_buffer.reset(); + client_ok(2); + call->WritesDone(tag(3)); + client_ok(3); + + generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(), + srv_cq_.get(), tag(4)); + + request_call.join(); + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + ByteBuffer recv_buffer; + stream.Read(&recv_buffer, tag(5)); + server_ok(5); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + stream.Write(*send_buffer, tag(6)); + send_buffer.reset(); + server_ok(6); + + stream.Finish(Status::OK, tag(7)); + server_ok(7); + + recv_buffer.Clear(); + call->Read(&recv_buffer, tag(8)); + client_ok(8); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + + call->Finish(&recv_status, tag(9)); + client_ok(9); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + } + + CompletionQueue cli_cq_; + std::unique_ptr<ServerCompletionQueue> srv_cq_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<grpc::GenericStub> generic_stub_; + std::unique_ptr<Server> server_; + AsyncGenericService generic_service_; + const TString server_host_; + std::ostringstream server_address_; +}; + +TEST_F(FilterEnd2endTest, SimpleRpc) { + ResetStub(); + EXPECT_EQ(0, GetConnectionCounterValue()); + EXPECT_EQ(0, GetCallCounterValue()); + SendRpc(1); + EXPECT_EQ(1, GetConnectionCounterValue()); + EXPECT_EQ(1, GetCallCounterValue()); +} + +TEST_F(FilterEnd2endTest, SequentialRpcs) { + ResetStub(); + EXPECT_EQ(0, GetConnectionCounterValue()); + EXPECT_EQ(0, GetCallCounterValue()); + SendRpc(10); + EXPECT_EQ(1, GetConnectionCounterValue()); + EXPECT_EQ(10, GetCallCounterValue()); +} + +// One ping, one pong. +TEST_F(FilterEnd2endTest, SimpleBidiStreaming) { + ResetStub(); + EXPECT_EQ(0, GetConnectionCounterValue()); + EXPECT_EQ(0, GetCallCounterValue()); + + const TString kMethodName( + "/grpc.cpp.test.util.EchoTestService/BidiStream"); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter srv_stream(&srv_ctx); + + cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + send_request.set_message("Hello"); + std::thread request_call([this]() { server_ok(2); }); + std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream = + generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_); + cli_stream->StartCall(tag(1)); + client_ok(1); + + generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(), + srv_cq_.get(), tag(2)); + + request_call.join(); + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + + std::unique_ptr<ByteBuffer> send_buffer = + SerializeToByteBuffer(&send_request); + cli_stream->Write(*send_buffer, tag(3)); + send_buffer.reset(); + client_ok(3); + + ByteBuffer recv_buffer; + srv_stream.Read(&recv_buffer, tag(4)); + server_ok(4); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + srv_stream.Write(*send_buffer, tag(5)); + send_buffer.reset(); + server_ok(5); + + cli_stream->Read(&recv_buffer, tag(6)); + client_ok(6); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + client_ok(7); + + srv_stream.Read(&recv_buffer, tag(8)); + server_fail(8); + + srv_stream.Finish(Status::OK, tag(9)); + server_ok(9); + + cli_stream->Finish(&recv_status, tag(10)); + client_ok(10); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + + EXPECT_EQ(1, GetCallCounterValue()); + EXPECT_EQ(1, GetConnectionCounterValue()); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/flaky_network_test.cc b/contrib/libs/grpc/test/cpp/end2end/flaky_network_test.cc new file mode 100644 index 0000000000..3ee75952c0 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/flaky_network_test.cc @@ -0,0 +1,558 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/atm.h> +#include <grpc/support/log.h> +#include <grpc/support/port_platform.h> +#include <grpc/support/string_util.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/health_check_service_interface.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <gtest/gtest.h> + +#include <algorithm> +#include <condition_variable> +#include <memory> +#include <mutex> +#include <random> +#include <thread> + +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/gpr/env.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/debugger_macros.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GPR_LINUX +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { +namespace { + +struct TestScenario { + TestScenario(const TString& creds_type, const TString& content) + : credentials_type(creds_type), message_content(content) {} + const TString credentials_type; + const TString message_content; +}; + +class FlakyNetworkTest : public ::testing::TestWithParam<TestScenario> { + protected: + FlakyNetworkTest() + : server_host_("grpctest"), + interface_("lo:1"), + ipv4_address_("10.0.0.1"), + netmask_("/32") {} + + void InterfaceUp() { + std::ostringstream cmd; + // create interface_ with address ipv4_address_ + cmd << "ip addr add " << ipv4_address_ << netmask_ << " dev " << interface_; + std::system(cmd.str().c_str()); + } + + void InterfaceDown() { + std::ostringstream cmd; + // remove interface_ + cmd << "ip addr del " << ipv4_address_ << netmask_ << " dev " << interface_; + std::system(cmd.str().c_str()); + } + + void DNSUp() { + std::ostringstream cmd; + // Add DNS entry for server_host_ in /etc/hosts + cmd << "echo '" << ipv4_address_ << " " << server_host_ + << "' >> /etc/hosts"; + std::system(cmd.str().c_str()); + } + + void DNSDown() { + std::ostringstream cmd; + // Remove DNS entry for server_host_ from /etc/hosts + // NOTE: we can't do this in one step with sed -i because when we are + // running under docker, the file is mounted by docker so we can't change + // its inode from within the container (sed -i creates a new file and + // replaces the old file, which changes the inode) + cmd << "sed '/" << server_host_ << "/d' /etc/hosts > /etc/hosts.orig"; + std::system(cmd.str().c_str()); + + // clear the stream + cmd.str(""); + + cmd << "cat /etc/hosts.orig > /etc/hosts"; + std::system(cmd.str().c_str()); + } + + void DropPackets() { + std::ostringstream cmd; + // drop packets with src IP = ipv4_address_ + cmd << "iptables -A INPUT -s " << ipv4_address_ << " -j DROP"; + + std::system(cmd.str().c_str()); + // clear the stream + cmd.str(""); + + // drop packets with dst IP = ipv4_address_ + cmd << "iptables -A INPUT -d " << ipv4_address_ << " -j DROP"; + } + + void RestoreNetwork() { + std::ostringstream cmd; + // remove iptables rule to drop packets with src IP = ipv4_address_ + cmd << "iptables -D INPUT -s " << ipv4_address_ << " -j DROP"; + std::system(cmd.str().c_str()); + // clear the stream + cmd.str(""); + // remove iptables rule to drop packets with dest IP = ipv4_address_ + cmd << "iptables -D INPUT -d " << ipv4_address_ << " -j DROP"; + } + + void FlakeNetwork() { + std::ostringstream cmd; + // Emulate a flaky network connection over interface_. Add a delay of 100ms + // +/- 20ms, 0.1% packet loss, 1% duplicates and 0.01% corrupt packets. + cmd << "tc qdisc replace dev " << interface_ + << " root netem delay 100ms 20ms distribution normal loss 0.1% " + "duplicate " + "0.1% corrupt 0.01% "; + std::system(cmd.str().c_str()); + } + + void UnflakeNetwork() { + // Remove simulated network flake on interface_ + std::ostringstream cmd; + cmd << "tc qdisc del dev " << interface_ << " root netem"; + std::system(cmd.str().c_str()); + } + + void NetworkUp() { + InterfaceUp(); + DNSUp(); + } + + void NetworkDown() { + InterfaceDown(); + DNSDown(); + } + + void SetUp() override { + NetworkUp(); + grpc_init(); + StartServer(); + } + + void TearDown() override { + NetworkDown(); + StopServer(); + grpc_shutdown(); + } + + void StartServer() { + // TODO (pjaikumar): Ideally, we should allocate the port dynamically using + // grpc_pick_unused_port_or_die(). That doesn't work inside some docker + // containers because port_server listens on localhost which maps to + // ip6-looopback, but ipv6 support is not enabled by default in docker. + port_ = SERVER_PORT; + + server_.reset(new ServerData(port_, GetParam().credentials_type)); + server_->Start(server_host_); + } + void StopServer() { server_->Shutdown(); } + + std::unique_ptr<grpc::testing::EchoTestService::Stub> BuildStub( + const std::shared_ptr<Channel>& channel) { + return grpc::testing::EchoTestService::NewStub(channel); + } + + std::shared_ptr<Channel> BuildChannel( + const TString& lb_policy_name, + ChannelArguments args = ChannelArguments()) { + if (lb_policy_name.size() > 0) { + args.SetLoadBalancingPolicyName(lb_policy_name); + } // else, default to pick first + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + std::ostringstream server_address; + server_address << server_host_ << ":" << port_; + return CreateCustomChannel(server_address.str(), channel_creds, args); + } + + bool SendRpc( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + int timeout_ms = 0, bool wait_for_ready = false) { + auto response = std::unique_ptr<EchoResponse>(new EchoResponse()); + EchoRequest request; + auto& msg = GetParam().message_content; + request.set_message(msg); + ClientContext context; + if (timeout_ms > 0) { + context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); + // Allow an RPC to be canceled (for deadline exceeded) after it has + // reached the server. + request.mutable_param()->set_skip_cancelled_check(true); + } + // See https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md for + // details of wait-for-ready semantics + if (wait_for_ready) { + context.set_wait_for_ready(true); + } + Status status = stub->Echo(&context, request, response.get()); + auto ok = status.ok(); + int stream_id = 0; + grpc_call* call = context.c_call(); + if (call) { + grpc_chttp2_stream* stream = grpc_chttp2_stream_from_call(call); + if (stream) { + stream_id = stream->id; + } + } + if (ok) { + gpr_log(GPR_DEBUG, "RPC with stream_id %d succeeded", stream_id); + } else { + gpr_log(GPR_DEBUG, "RPC with stream_id %d failed: %s", stream_id, + status.error_message().c_str()); + } + return ok; + } + + struct ServerData { + int port_; + const TString creds_; + std::unique_ptr<Server> server_; + TestServiceImpl service_; + std::unique_ptr<std::thread> thread_; + bool server_ready_ = false; + + ServerData(int port, const TString& creds) + : port_(port), creds_(creds) {} + + void Start(const TString& server_host) { + gpr_log(GPR_INFO, "starting server on port %d", port_); + std::mutex mu; + std::unique_lock<std::mutex> lock(mu); + std::condition_variable cond; + thread_.reset(new std::thread( + std::bind(&ServerData::Serve, this, server_host, &mu, &cond))); + cond.wait(lock, [this] { return server_ready_; }); + server_ready_ = false; + gpr_log(GPR_INFO, "server startup complete"); + } + + void Serve(const TString& server_host, std::mutex* mu, + std::condition_variable* cond) { + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + auto server_creds = + GetCredentialsProvider()->GetServerCredentials(creds_); + builder.AddListeningPort(server_address.str(), server_creds); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + std::lock_guard<std::mutex> lock(*mu); + server_ready_ = true; + cond->notify_one(); + } + + void Shutdown() { + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + } + }; + + bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(false /* try_to_connect */)) == + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(true /* try_to_connect */)) != + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + private: + const TString server_host_; + const TString interface_; + const TString ipv4_address_; + const TString netmask_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<ServerData> server_; + const int SERVER_PORT = 32750; + int port_; +}; + +std::vector<TestScenario> CreateTestScenarios() { + std::vector<TestScenario> scenarios; + std::vector<TString> credentials_types; + std::vector<TString> messages; + + credentials_types.push_back(kInsecureCredentialsType); + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { + credentials_types.push_back(*sec); + } + + messages.push_back("🖖"); + for (size_t k = 1; k < GRPC_DEFAULT_MAX_RECV_MESSAGE_LENGTH / 1024; k *= 32) { + TString big_msg; + for (size_t i = 0; i < k * 1024; ++i) { + char c = 'a' + (i % 26); + big_msg += c; + } + messages.push_back(big_msg); + } + for (auto cred = credentials_types.begin(); cred != credentials_types.end(); + ++cred) { + for (auto msg = messages.begin(); msg != messages.end(); msg++) { + scenarios.emplace_back(*cred, *msg); + } + } + + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(FlakyNetworkTest, FlakyNetworkTest, + ::testing::ValuesIn(CreateTestScenarios())); + +// Network interface connected to server flaps +TEST_P(FlakyNetworkTest, NetworkTransition) { + const int kKeepAliveTimeMs = 1000; + const int kKeepAliveTimeoutMs = 1000; + ChannelArguments args; + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + + auto channel = BuildChannel("pick_first", args); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + std::atomic_bool shutdown{false}; + std::thread sender = std::thread([this, &stub, &shutdown]() { + while (true) { + if (shutdown.load()) { + return; + } + SendRpc(stub); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + }); + + // bring down network + NetworkDown(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + // bring network interface back up + InterfaceUp(); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + // Restore DNS entry for server + DNSUp(); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + shutdown.store(true); + sender.join(); +} + +// Traffic to server server is blackholed temporarily with keepalives enabled +TEST_P(FlakyNetworkTest, ServerUnreachableWithKeepalive) { + const int kKeepAliveTimeMs = 1000; + const int kKeepAliveTimeoutMs = 1000; + const int kReconnectBackoffMs = 1000; + ChannelArguments args; + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + // max time for a connection attempt + args.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, kReconnectBackoffMs); + // max time between reconnect attempts + args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, kReconnectBackoffMs); + + gpr_log(GPR_DEBUG, "FlakyNetworkTest.ServerUnreachableWithKeepalive start"); + auto channel = BuildChannel("pick_first", args); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + std::atomic_bool shutdown{false}; + std::thread sender = std::thread([this, &stub, &shutdown]() { + while (true) { + if (shutdown.load()) { + return; + } + SendRpc(stub); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + }); + + // break network connectivity + gpr_log(GPR_DEBUG, "Adding iptables rule to drop packets"); + DropPackets(); + std::this_thread::sleep_for(std::chrono::milliseconds(10000)); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + // bring network interface back up + RestoreNetwork(); + gpr_log(GPR_DEBUG, "Removed iptables rule to drop packets"); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + shutdown.store(true); + sender.join(); + gpr_log(GPR_DEBUG, "FlakyNetworkTest.ServerUnreachableWithKeepalive end"); +} + +// +// Traffic to server server is blackholed temporarily with keepalives disabled +TEST_P(FlakyNetworkTest, ServerUnreachableNoKeepalive) { + auto channel = BuildChannel("pick_first", ChannelArguments()); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + // break network connectivity + DropPackets(); + + std::thread sender = std::thread([this, &stub]() { + // RPC with deadline should timeout + EXPECT_FALSE(SendRpc(stub, /*timeout_ms=*/500, /*wait_for_ready=*/true)); + // RPC without deadline forever until call finishes + EXPECT_TRUE(SendRpc(stub, /*timeout_ms=*/0, /*wait_for_ready=*/true)); + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + // bring network interface back up + RestoreNetwork(); + + // wait for RPC to finish + sender.join(); +} + +// Send RPCs over a flaky network connection +TEST_P(FlakyNetworkTest, FlakyNetwork) { + const int kKeepAliveTimeMs = 1000; + const int kKeepAliveTimeoutMs = 1000; + const int kMessageCount = 100; + ChannelArguments args; + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + + auto channel = BuildChannel("pick_first", args); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + // simulate flaky network (packet loss, corruption and delays) + FlakeNetwork(); + for (int i = 0; i < kMessageCount; ++i) { + SendRpc(stub); + } + // remove network flakiness + UnflakeNetwork(); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); +} + +// Server is shutdown gracefully and restarted. Client keepalives are enabled +TEST_P(FlakyNetworkTest, ServerRestartKeepaliveEnabled) { + const int kKeepAliveTimeMs = 1000; + const int kKeepAliveTimeoutMs = 1000; + ChannelArguments args; + args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepAliveTimeMs); + args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepAliveTimeoutMs); + args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args.SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + + auto channel = BuildChannel("pick_first", args); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + // server goes down, client should detect server going down and calls should + // fail + StopServer(); + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + EXPECT_FALSE(SendRpc(stub)); + + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // server restarts, calls succeed + StartServer(); + EXPECT_TRUE(WaitForChannelReady(channel.get())); + // EXPECT_TRUE(SendRpc(stub)); +} + +// Server is shutdown gracefully and restarted. Client keepalives are enabled +TEST_P(FlakyNetworkTest, ServerRestartKeepaliveDisabled) { + auto channel = BuildChannel("pick_first", ChannelArguments()); + auto stub = BuildStub(channel); + // Channel should be in READY state after we send an RPC + EXPECT_TRUE(SendRpc(stub)); + EXPECT_EQ(channel->GetState(false), GRPC_CHANNEL_READY); + + // server sends GOAWAY when it's shutdown, so client attempts to reconnect + StopServer(); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + EXPECT_TRUE(WaitForChannelNotReady(channel.get())); + + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // server restarts, calls succeed + StartServer(); + EXPECT_TRUE(WaitForChannelReady(channel.get())); +} + +} // namespace +} // namespace testing +} // namespace grpc +#endif // GPR_LINUX + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/generic_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/generic_end2end_test.cc new file mode 100644 index 0000000000..59eec49fb2 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/generic_end2end_test.cc @@ -0,0 +1,430 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <thread> + +#include <grpc/grpc.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/async_generic_service.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/support/slice.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +namespace grpc { +namespace testing { +namespace { + +void* tag(int i) { return (void*)static_cast<intptr_t>(i); } + +void verify_ok(CompletionQueue* cq, int i, bool expect_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + EXPECT_EQ(expect_ok, ok); + EXPECT_EQ(tag(i), got_tag); +} + +class GenericEnd2endTest : public ::testing::Test { + protected: + GenericEnd2endTest() : server_host_("localhost") {} + + void SetUp() override { + shut_down_ = false; + int port = grpc_pick_unused_port_or_die(); + server_address_ << server_host_ << ":" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterAsyncGenericService(&generic_service_); + // Include a second call to RegisterAsyncGenericService to make sure that + // we get an error in the log, since it is not allowed to have 2 async + // generic services + builder.RegisterAsyncGenericService(&generic_service_); + srv_cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + } + + void ShutDownServerAndCQs() { + if (!shut_down_) { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cli_cq_.Shutdown(); + srv_cq_->Shutdown(); + while (cli_cq_.Next(&ignored_tag, &ignored_ok)) + ; + while (srv_cq_->Next(&ignored_tag, &ignored_ok)) + ; + shut_down_ = true; + } + } + void TearDown() override { ShutDownServerAndCQs(); } + + void ResetStub() { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + generic_stub_.reset(new GenericStub(channel)); + } + + void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); } + void client_ok(int i) { verify_ok(&cli_cq_, i, true); } + void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); } + void client_fail(int i) { verify_ok(&cli_cq_, i, false); } + + void SendRpc(int num_rpcs) { + SendRpc(num_rpcs, false, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + + void SendRpc(int num_rpcs, bool check_deadline, gpr_timespec deadline) { + const TString kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello world. Hello world. Hello world."); + + if (check_deadline) { + cli_ctx.set_deadline(deadline); + } + + // Rather than using the original kMethodName, make a short-lived + // copy to also confirm that we don't refer to this object beyond + // the initial call preparation + const TString* method_name = new TString(kMethodName); + + std::unique_ptr<GenericClientAsyncReaderWriter> call = + generic_stub_->PrepareCall(&cli_ctx, *method_name, &cli_cq_); + + delete method_name; // Make sure that this is not needed after invocation + + std::thread request_call([this]() { server_ok(4); }); + call->StartCall(tag(1)); + client_ok(1); + std::unique_ptr<ByteBuffer> send_buffer = + SerializeToByteBuffer(&send_request); + call->Write(*send_buffer, tag(2)); + // Send ByteBuffer can be destroyed after calling Write. + send_buffer.reset(); + client_ok(2); + call->WritesDone(tag(3)); + client_ok(3); + + generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(), + srv_cq_.get(), tag(4)); + + request_call.join(); + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + + if (check_deadline) { + EXPECT_TRUE(gpr_time_similar(deadline, srv_ctx.raw_deadline(), + gpr_time_from_millis(1000, GPR_TIMESPAN))); + } + + ByteBuffer recv_buffer; + stream.Read(&recv_buffer, tag(5)); + server_ok(5); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + stream.Write(*send_buffer, tag(6)); + send_buffer.reset(); + server_ok(6); + + stream.Finish(Status::OK, tag(7)); + server_ok(7); + + recv_buffer.Clear(); + call->Read(&recv_buffer, tag(8)); + client_ok(8); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + + call->Finish(&recv_status, tag(9)); + client_ok(9); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + } + + // Return errors to up to one call that comes in on the supplied completion + // queue, until the CQ is being shut down (and therefore we can no longer + // enqueue further events). + void DriveCompletionQueue() { + enum class Event : uintptr_t { + kCallReceived, + kResponseSent, + }; + // Request the call, but only if the main thread hasn't beaten us to + // shutting down the CQ. + grpc::GenericServerContext server_context; + grpc::GenericServerAsyncReaderWriter reader_writer(&server_context); + + { + std::lock_guard<std::mutex> lock(shutting_down_mu_); + if (!shutting_down_) { + generic_service_.RequestCall( + &server_context, &reader_writer, srv_cq_.get(), srv_cq_.get(), + reinterpret_cast<void*>(Event::kCallReceived)); + } + } + // Process events. + { + Event event; + bool ok; + while (srv_cq_->Next(reinterpret_cast<void**>(&event), &ok)) { + std::lock_guard<std::mutex> lock(shutting_down_mu_); + if (shutting_down_) { + // The main thread has started shutting down. Simply continue to drain + // events. + continue; + } + + switch (event) { + case Event::kCallReceived: + reader_writer.Finish( + ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "go away"), + reinterpret_cast<void*>(Event::kResponseSent)); + break; + + case Event::kResponseSent: + // We are done. + break; + } + } + } + } + + CompletionQueue cli_cq_; + std::unique_ptr<ServerCompletionQueue> srv_cq_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<grpc::GenericStub> generic_stub_; + std::unique_ptr<Server> server_; + AsyncGenericService generic_service_; + const TString server_host_; + std::ostringstream server_address_; + bool shutting_down_; + bool shut_down_; + std::mutex shutting_down_mu_; +}; + +TEST_F(GenericEnd2endTest, SimpleRpc) { + ResetStub(); + SendRpc(1); +} + +TEST_F(GenericEnd2endTest, SequentialRpcs) { + ResetStub(); + SendRpc(10); +} + +TEST_F(GenericEnd2endTest, SequentialUnaryRpcs) { + ResetStub(); + const int num_rpcs = 10; + const TString kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello world. Hello world. Hello world."); + + std::unique_ptr<ByteBuffer> cli_send_buffer = + SerializeToByteBuffer(&send_request); + std::thread request_call([this]() { server_ok(4); }); + std::unique_ptr<GenericClientAsyncResponseReader> call = + generic_stub_->PrepareUnaryCall(&cli_ctx, kMethodName, + *cli_send_buffer.get(), &cli_cq_); + call->StartCall(); + ByteBuffer cli_recv_buffer; + call->Finish(&cli_recv_buffer, &recv_status, tag(1)); + std::thread client_check([this] { client_ok(1); }); + + generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(), + srv_cq_.get(), tag(4)); + request_call.join(); + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + + ByteBuffer srv_recv_buffer; + stream.Read(&srv_recv_buffer, tag(5)); + server_ok(5); + EXPECT_TRUE(ParseFromByteBuffer(&srv_recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + std::unique_ptr<ByteBuffer> srv_send_buffer = + SerializeToByteBuffer(&send_response); + stream.Write(*srv_send_buffer, tag(6)); + server_ok(6); + + stream.Finish(Status::OK, tag(7)); + server_ok(7); + + client_check.join(); + EXPECT_TRUE(ParseFromByteBuffer(&cli_recv_buffer, &recv_response)); + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } +} + +// One ping, one pong. +TEST_F(GenericEnd2endTest, SimpleBidiStreaming) { + ResetStub(); + + const TString kMethodName( + "/grpc.cpp.test.util.EchoTestService/BidiStream"); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter srv_stream(&srv_ctx); + + cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + send_request.set_message("Hello"); + std::thread request_call([this]() { server_ok(2); }); + std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream = + generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_); + cli_stream->StartCall(tag(1)); + client_ok(1); + + generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(), + srv_cq_.get(), tag(2)); + request_call.join(); + + EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length())); + EXPECT_EQ(kMethodName, srv_ctx.method()); + + std::unique_ptr<ByteBuffer> send_buffer = + SerializeToByteBuffer(&send_request); + cli_stream->Write(*send_buffer, tag(3)); + send_buffer.reset(); + client_ok(3); + + ByteBuffer recv_buffer; + srv_stream.Read(&recv_buffer, tag(4)); + server_ok(4); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + srv_stream.Write(*send_buffer, tag(5)); + send_buffer.reset(); + server_ok(5); + + cli_stream->Read(&recv_buffer, tag(6)); + client_ok(6); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + client_ok(7); + + srv_stream.Read(&recv_buffer, tag(8)); + server_fail(8); + + srv_stream.Finish(Status::OK, tag(9)); + server_ok(9); + + cli_stream->Finish(&recv_status, tag(10)); + client_ok(10); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +TEST_F(GenericEnd2endTest, Deadline) { + ResetStub(); + SendRpc(1, true, + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(10, GPR_TIMESPAN))); +} + +TEST_F(GenericEnd2endTest, ShortDeadline) { + ResetStub(); + + ClientContext cli_ctx; + EchoRequest request; + EchoResponse response; + + shutting_down_ = false; + std::thread driver([this] { DriveCompletionQueue(); }); + + request.set_message(""); + cli_ctx.set_deadline(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(500, GPR_TIMESPAN))); + Status s = stub_->Echo(&cli_ctx, request, &response); + EXPECT_FALSE(s.ok()); + { + std::lock_guard<std::mutex> lock(shutting_down_mu_); + shutting_down_ = true; + } + ShutDownServerAndCQs(); + driver.join(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/grpclb_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/grpclb_end2end_test.cc new file mode 100644 index 0000000000..6208dc2535 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/grpclb_end2end_test.cc @@ -0,0 +1,2029 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <deque> +#include <memory> +#include <mutex> +#include <set> +#include <sstream> +#include <util/generic/string.h> +#include <thread> + +#include "y_absl/strings/str_cat.h" +#include "y_absl/strings/str_format.h" + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/impl/codegen/sync.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_balancer_addresses.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/filters/client_channel/service_config.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/parse_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/core/lib/transport/authority_override.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" + +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +#include "src/proto/grpc/lb/v1/load_balancer.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +// TODO(dgq): Other scenarios in need of testing: +// - Send a serverlist with faulty ip:port addresses (port > 2^16, etc). +// - Test reception of invalid serverlist +// - Test against a non-LB server. +// - Random LB server closing the stream unexpectedly. +// +// Findings from end to end testing to be covered here: +// - Handling of LB servers restart, including reconnection after backing-off +// retries. +// - Destruction of load balanced channel (and therefore of grpclb instance) +// while: +// 1) the internal LB call is still active. This should work by virtue +// of the weak reference the LB call holds. The call should be terminated as +// part of the grpclb shutdown process. +// 2) the retry timer is active. Again, the weak reference it holds should +// prevent a premature call to \a glb_destroy. + +using std::chrono::system_clock; + +using grpc::lb::v1::LoadBalancer; +using grpc::lb::v1::LoadBalanceRequest; +using grpc::lb::v1::LoadBalanceResponse; + +namespace grpc { +namespace testing { +namespace { + +constexpr char kDefaultServiceConfig[] = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"grpclb\":{} }\n" + " ]\n" + "}"; + +template <typename ServiceType> +class CountedService : public ServiceType { + public: + size_t request_count() { + grpc::internal::MutexLock lock(&mu_); + return request_count_; + } + + size_t response_count() { + grpc::internal::MutexLock lock(&mu_); + return response_count_; + } + + void IncreaseResponseCount() { + grpc::internal::MutexLock lock(&mu_); + ++response_count_; + } + void IncreaseRequestCount() { + grpc::internal::MutexLock lock(&mu_); + ++request_count_; + } + + void ResetCounters() { + grpc::internal::MutexLock lock(&mu_); + request_count_ = 0; + response_count_ = 0; + } + + protected: + grpc::internal::Mutex mu_; + + private: + size_t request_count_ = 0; + size_t response_count_ = 0; +}; + +using BackendService = CountedService<TestServiceImpl>; +using BalancerService = CountedService<LoadBalancer::Service>; + +const char g_kCallCredsMdKey[] = "Balancer should not ..."; +const char g_kCallCredsMdValue[] = "... receive me"; + +class BackendServiceImpl : public BackendService { + public: + BackendServiceImpl() {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + // Backend should receive the call credentials metadata. + auto call_credentials_entry = + context->client_metadata().find(g_kCallCredsMdKey); + EXPECT_NE(call_credentials_entry, context->client_metadata().end()); + if (call_credentials_entry != context->client_metadata().end()) { + EXPECT_EQ(call_credentials_entry->second, g_kCallCredsMdValue); + } + IncreaseRequestCount(); + const auto status = TestServiceImpl::Echo(context, request, response); + IncreaseResponseCount(); + AddClient(context->peer().c_str()); + return status; + } + + void Start() {} + + void Shutdown() {} + + std::set<TString> clients() { + grpc::internal::MutexLock lock(&clients_mu_); + return clients_; + } + + private: + void AddClient(const TString& client) { + grpc::internal::MutexLock lock(&clients_mu_); + clients_.insert(client); + } + + grpc::internal::Mutex mu_; + grpc::internal::Mutex clients_mu_; + std::set<TString> clients_; +}; + +TString Ip4ToPackedString(const char* ip_str) { + struct in_addr ip4; + GPR_ASSERT(inet_pton(AF_INET, ip_str, &ip4) == 1); + return TString(reinterpret_cast<const char*>(&ip4), sizeof(ip4)); +} + +struct ClientStats { + size_t num_calls_started = 0; + size_t num_calls_finished = 0; + size_t num_calls_finished_with_client_failed_to_send = 0; + size_t num_calls_finished_known_received = 0; + std::map<TString, size_t> drop_token_counts; + + ClientStats& operator+=(const ClientStats& other) { + num_calls_started += other.num_calls_started; + num_calls_finished += other.num_calls_finished; + num_calls_finished_with_client_failed_to_send += + other.num_calls_finished_with_client_failed_to_send; + num_calls_finished_known_received += + other.num_calls_finished_known_received; + for (const auto& p : other.drop_token_counts) { + drop_token_counts[p.first] += p.second; + } + return *this; + } + + void Reset() { + num_calls_started = 0; + num_calls_finished = 0; + num_calls_finished_with_client_failed_to_send = 0; + num_calls_finished_known_received = 0; + drop_token_counts.clear(); + } +}; + +class BalancerServiceImpl : public BalancerService { + public: + using Stream = ServerReaderWriter<LoadBalanceResponse, LoadBalanceRequest>; + using ResponseDelayPair = std::pair<LoadBalanceResponse, int>; + + explicit BalancerServiceImpl(int client_load_reporting_interval_seconds) + : client_load_reporting_interval_seconds_( + client_load_reporting_interval_seconds) {} + + Status BalanceLoad(ServerContext* context, Stream* stream) override { + gpr_log(GPR_INFO, "LB[%p]: BalanceLoad", this); + { + grpc::internal::MutexLock lock(&mu_); + if (serverlist_done_) goto done; + } + { + // Balancer shouldn't receive the call credentials metadata. + EXPECT_EQ(context->client_metadata().find(g_kCallCredsMdKey), + context->client_metadata().end()); + LoadBalanceRequest request; + std::vector<ResponseDelayPair> responses_and_delays; + + if (!stream->Read(&request)) { + goto done; + } else { + if (request.has_initial_request()) { + grpc::internal::MutexLock lock(&mu_); + service_names_.push_back(request.initial_request().name()); + } + } + IncreaseRequestCount(); + gpr_log(GPR_INFO, "LB[%p]: received initial message '%s'", this, + request.DebugString().c_str()); + + // TODO(juanlishen): Initial response should always be the first response. + if (client_load_reporting_interval_seconds_ > 0) { + LoadBalanceResponse initial_response; + initial_response.mutable_initial_response() + ->mutable_client_stats_report_interval() + ->set_seconds(client_load_reporting_interval_seconds_); + stream->Write(initial_response); + } + + { + grpc::internal::MutexLock lock(&mu_); + responses_and_delays = responses_and_delays_; + } + for (const auto& response_and_delay : responses_and_delays) { + SendResponse(stream, response_and_delay.first, + response_and_delay.second); + } + { + grpc::internal::MutexLock lock(&mu_); + serverlist_cond_.WaitUntil(&mu_, [this] { return serverlist_done_; }); + } + + if (client_load_reporting_interval_seconds_ > 0) { + request.Clear(); + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "LB[%p]: received client load report message '%s'", + this, request.DebugString().c_str()); + GPR_ASSERT(request.has_client_stats()); + ClientStats load_report; + load_report.num_calls_started = + request.client_stats().num_calls_started(); + load_report.num_calls_finished = + request.client_stats().num_calls_finished(); + load_report.num_calls_finished_with_client_failed_to_send = + request.client_stats() + .num_calls_finished_with_client_failed_to_send(); + load_report.num_calls_finished_known_received = + request.client_stats().num_calls_finished_known_received(); + for (const auto& drop_token_count : + request.client_stats().calls_finished_with_drop()) { + load_report + .drop_token_counts[drop_token_count.load_balance_token()] = + drop_token_count.num_calls(); + } + // We need to acquire the lock here in order to prevent the notify_one + // below from firing before its corresponding wait is executed. + grpc::internal::MutexLock lock(&mu_); + load_report_queue_.emplace_back(std::move(load_report)); + if (load_report_cond_ != nullptr) load_report_cond_->Signal(); + } + } + } + done: + gpr_log(GPR_INFO, "LB[%p]: done", this); + return Status::OK; + } + + void add_response(const LoadBalanceResponse& response, int send_after_ms) { + grpc::internal::MutexLock lock(&mu_); + responses_and_delays_.push_back(std::make_pair(response, send_after_ms)); + } + + void Start() { + grpc::internal::MutexLock lock(&mu_); + serverlist_done_ = false; + responses_and_delays_.clear(); + load_report_queue_.clear(); + } + + void Shutdown() { + NotifyDoneWithServerlists(); + gpr_log(GPR_INFO, "LB[%p]: shut down", this); + } + + static LoadBalanceResponse BuildResponseForBackends( + const std::vector<int>& backend_ports, + const std::map<TString, size_t>& drop_token_counts) { + LoadBalanceResponse response; + for (const auto& drop_token_count : drop_token_counts) { + for (size_t i = 0; i < drop_token_count.second; ++i) { + auto* server = response.mutable_server_list()->add_servers(); + server->set_drop(true); + server->set_load_balance_token(drop_token_count.first); + } + } + for (const int& backend_port : backend_ports) { + auto* server = response.mutable_server_list()->add_servers(); + server->set_ip_address(Ip4ToPackedString("127.0.0.1")); + server->set_port(backend_port); + static int token_count = 0; + server->set_load_balance_token( + y_absl::StrFormat("token%03d", ++token_count)); + } + return response; + } + + ClientStats WaitForLoadReport() { + grpc::internal::MutexLock lock(&mu_); + grpc::internal::CondVar cv; + if (load_report_queue_.empty()) { + load_report_cond_ = &cv; + load_report_cond_->WaitUntil( + &mu_, [this] { return !load_report_queue_.empty(); }); + load_report_cond_ = nullptr; + } + ClientStats load_report = std::move(load_report_queue_.front()); + load_report_queue_.pop_front(); + return load_report; + } + + void NotifyDoneWithServerlists() { + grpc::internal::MutexLock lock(&mu_); + if (!serverlist_done_) { + serverlist_done_ = true; + serverlist_cond_.Broadcast(); + } + } + + std::vector<TString> service_names() { + grpc::internal::MutexLock lock(&mu_); + return service_names_; + } + + private: + void SendResponse(Stream* stream, const LoadBalanceResponse& response, + int delay_ms) { + gpr_log(GPR_INFO, "LB[%p]: sleeping for %d ms...", this, delay_ms); + if (delay_ms > 0) { + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms)); + } + gpr_log(GPR_INFO, "LB[%p]: Woke up! Sending response '%s'", this, + response.DebugString().c_str()); + IncreaseResponseCount(); + stream->Write(response); + } + + const int client_load_reporting_interval_seconds_; + std::vector<ResponseDelayPair> responses_and_delays_; + std::vector<TString> service_names_; + + grpc::internal::Mutex mu_; + grpc::internal::CondVar serverlist_cond_; + bool serverlist_done_ = false; + grpc::internal::CondVar* load_report_cond_ = nullptr; + std::deque<ClientStats> load_report_queue_; +}; + +class GrpclbEnd2endTest : public ::testing::Test { + protected: + GrpclbEnd2endTest(size_t num_backends, size_t num_balancers, + int client_load_reporting_interval_seconds) + : server_host_("localhost"), + num_backends_(num_backends), + num_balancers_(num_balancers), + client_load_reporting_interval_seconds_( + client_load_reporting_interval_seconds) {} + + static void SetUpTestCase() { + // Make the backup poller poll very frequently in order to pick up + // updates from all the subchannels's FDs. + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + grpc_init(); + } + + static void TearDownTestCase() { grpc_shutdown(); } + + void SetUp() override { + response_generator_ = + grpc_core::MakeRefCounted<grpc_core::FakeResolverResponseGenerator>(); + // Start the backends. + for (size_t i = 0; i < num_backends_; ++i) { + backends_.emplace_back(new ServerThread<BackendServiceImpl>("backend")); + backends_.back()->Start(server_host_); + } + // Start the load balancers. + for (size_t i = 0; i < num_balancers_; ++i) { + balancers_.emplace_back(new ServerThread<BalancerServiceImpl>( + "balancer", client_load_reporting_interval_seconds_)); + balancers_.back()->Start(server_host_); + } + ResetStub(); + } + + void TearDown() override { + ShutdownAllBackends(); + for (auto& balancer : balancers_) balancer->Shutdown(); + } + + void StartAllBackends() { + for (auto& backend : backends_) backend->Start(server_host_); + } + + void StartBackend(size_t index) { backends_[index]->Start(server_host_); } + + void ShutdownAllBackends() { + for (auto& backend : backends_) backend->Shutdown(); + } + + void ShutdownBackend(size_t index) { backends_[index]->Shutdown(); } + + void ResetStub(int fallback_timeout = 0, + const TString& expected_targets = "") { + ChannelArguments args; + if (fallback_timeout > 0) args.SetGrpclbFallbackTimeout(fallback_timeout); + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + if (!expected_targets.empty()) { + args.SetString(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS, expected_targets); + } + std::ostringstream uri; + uri << "fake:///" << kApplicationTargetName_; + // TODO(dgq): templatize tests to run everything using both secure and + // insecure channel credentials. + grpc_channel_credentials* channel_creds = + grpc_fake_transport_security_credentials_create(); + grpc_call_credentials* call_creds = grpc_md_only_test_credentials_create( + g_kCallCredsMdKey, g_kCallCredsMdValue, false); + std::shared_ptr<ChannelCredentials> creds( + new SecureChannelCredentials(grpc_composite_channel_credentials_create( + channel_creds, call_creds, nullptr))); + call_creds->Unref(); + channel_creds->Unref(); + channel_ = ::grpc::CreateCustomChannel(uri.str(), creds, args); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void ResetBackendCounters() { + for (auto& backend : backends_) backend->service_.ResetCounters(); + } + + ClientStats WaitForLoadReports() { + ClientStats client_stats; + for (auto& balancer : balancers_) { + client_stats += balancer->service_.WaitForLoadReport(); + } + return client_stats; + } + + bool SeenAllBackends(size_t start_index = 0, size_t stop_index = 0) { + if (stop_index == 0) stop_index = backends_.size(); + for (size_t i = start_index; i < stop_index; ++i) { + if (backends_[i]->service_.request_count() == 0) return false; + } + return true; + } + + void SendRpcAndCount(int* num_total, int* num_ok, int* num_failure, + int* num_drops) { + const Status status = SendRpc(); + if (status.ok()) { + ++*num_ok; + } else { + if (status.error_message() == "Call dropped by load balancing policy") { + ++*num_drops; + } else { + ++*num_failure; + } + } + ++*num_total; + } + + std::tuple<int, int, int> WaitForAllBackends(int num_requests_multiple_of = 1, + size_t start_index = 0, + size_t stop_index = 0) { + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + int num_total = 0; + while (!SeenAllBackends(start_index, stop_index)) { + SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops); + } + while (num_total % num_requests_multiple_of != 0) { + SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops); + } + ResetBackendCounters(); + gpr_log(GPR_INFO, + "Performed %d warm up requests (a multiple of %d) against the " + "backends. %d succeeded, %d failed, %d dropped.", + num_total, num_requests_multiple_of, num_ok, num_failure, + num_drops); + return std::make_tuple(num_ok, num_failure, num_drops); + } + + void WaitForBackend(size_t backend_idx) { + do { + (void)SendRpc(); + } while (backends_[backend_idx]->service_.request_count() == 0); + ResetBackendCounters(); + } + + struct AddressData { + int port; + TString balancer_name; + }; + + static grpc_core::ServerAddressList CreateLbAddressesFromAddressDataList( + const std::vector<AddressData>& address_data) { + grpc_core::ServerAddressList addresses; + for (const auto& addr : address_data) { + TString lb_uri_str = y_absl::StrCat("ipv4:127.0.0.1:", addr.port); + grpc_uri* lb_uri = grpc_uri_parse(lb_uri_str.c_str(), true); + GPR_ASSERT(lb_uri != nullptr); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(lb_uri, &address)); + grpc_arg arg = grpc_core::CreateAuthorityOverrideChannelArg( + addr.balancer_name.c_str()); + grpc_channel_args* args = + grpc_channel_args_copy_and_add(nullptr, &arg, 1); + addresses.emplace_back(address.addr, address.len, args); + grpc_uri_destroy(lb_uri); + } + return addresses; + } + + static grpc_core::Resolver::Result MakeResolverResult( + const std::vector<AddressData>& balancer_address_data, + const std::vector<AddressData>& backend_address_data = {}, + const char* service_config_json = kDefaultServiceConfig) { + grpc_core::Resolver::Result result; + result.addresses = + CreateLbAddressesFromAddressDataList(backend_address_data); + grpc_error* error = GRPC_ERROR_NONE; + result.service_config = + grpc_core::ServiceConfig::Create(nullptr, service_config_json, &error); + GPR_ASSERT(error == GRPC_ERROR_NONE); + grpc_core::ServerAddressList balancer_addresses = + CreateLbAddressesFromAddressDataList(balancer_address_data); + grpc_arg arg = CreateGrpclbBalancerAddressesArg(&balancer_addresses); + result.args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + return result; + } + + void SetNextResolutionAllBalancers( + const char* service_config_json = kDefaultServiceConfig) { + std::vector<AddressData> addresses; + for (size_t i = 0; i < balancers_.size(); ++i) { + addresses.emplace_back(AddressData{balancers_[i]->port_, ""}); + } + SetNextResolution(addresses, {}, service_config_json); + } + + void SetNextResolution( + const std::vector<AddressData>& balancer_address_data, + const std::vector<AddressData>& backend_address_data = {}, + const char* service_config_json = kDefaultServiceConfig) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = MakeResolverResult( + balancer_address_data, backend_address_data, service_config_json); + response_generator_->SetResponse(std::move(result)); + } + + void SetNextReresolutionResponse( + const std::vector<AddressData>& balancer_address_data, + const std::vector<AddressData>& backend_address_data = {}, + const char* service_config_json = kDefaultServiceConfig) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = MakeResolverResult( + balancer_address_data, backend_address_data, service_config_json); + response_generator_->SetReresolutionResponse(std::move(result)); + } + + const std::vector<int> GetBackendPorts(size_t start_index = 0, + size_t stop_index = 0) const { + if (stop_index == 0) stop_index = backends_.size(); + std::vector<int> backend_ports; + for (size_t i = start_index; i < stop_index; ++i) { + backend_ports.push_back(backends_[i]->port_); + } + return backend_ports; + } + + void ScheduleResponseForBalancer(size_t i, + const LoadBalanceResponse& response, + int delay_ms) { + balancers_[i]->service_.add_response(response, delay_ms); + } + + Status SendRpc(EchoResponse* response = nullptr, int timeout_ms = 1000, + bool wait_for_ready = false, + const Status& expected_status = Status::OK) { + const bool local_response = (response == nullptr); + if (local_response) response = new EchoResponse; + EchoRequest request; + request.set_message(kRequestMessage_); + if (!expected_status.ok()) { + auto* error = request.mutable_param()->mutable_expected_error(); + error->set_code(expected_status.error_code()); + error->set_error_message(expected_status.error_message()); + } + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); + if (wait_for_ready) context.set_wait_for_ready(true); + Status status = stub_->Echo(&context, request, response); + if (local_response) delete response; + return status; + } + + void CheckRpcSendOk(const size_t times = 1, const int timeout_ms = 1000, + bool wait_for_ready = false) { + for (size_t i = 0; i < times; ++i) { + EchoResponse response; + const Status status = SendRpc(&response, timeout_ms, wait_for_ready); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage_); + } + } + + void CheckRpcSendFailure() { + const Status status = SendRpc(); + EXPECT_FALSE(status.ok()); + } + + template <typename T> + struct ServerThread { + template <typename... Args> + explicit ServerThread(const TString& type, Args&&... args) + : port_(grpc_pick_unused_port_or_die()), + type_(type), + service_(std::forward<Args>(args)...) {} + + void Start(const TString& server_host) { + gpr_log(GPR_INFO, "starting %s server on port %d", type_.c_str(), port_); + GPR_ASSERT(!running_); + running_ = true; + service_.Start(); + grpc::internal::Mutex mu; + // We need to acquire the lock here in order to prevent the notify_one + // by ServerThread::Serve from firing before the wait below is hit. + grpc::internal::MutexLock lock(&mu); + grpc::internal::CondVar cond; + thread_.reset(new std::thread( + std::bind(&ServerThread::Serve, this, server_host, &mu, &cond))); + cond.Wait(&mu); + gpr_log(GPR_INFO, "%s server startup complete", type_.c_str()); + } + + void Serve(const TString& server_host, grpc::internal::Mutex* mu, + grpc::internal::CondVar* cond) { + // We need to acquire the lock here in order to prevent the notify_one + // below from firing before its corresponding wait is executed. + grpc::internal::MutexLock lock(mu); + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + std::shared_ptr<ServerCredentials> creds(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create())); + builder.AddListeningPort(server_address.str(), creds); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + cond->Signal(); + } + + void Shutdown() { + if (!running_) return; + gpr_log(GPR_INFO, "%s about to shutdown", type_.c_str()); + service_.Shutdown(); + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + gpr_log(GPR_INFO, "%s shutdown completed", type_.c_str()); + running_ = false; + } + + const int port_; + TString type_; + T service_; + std::unique_ptr<Server> server_; + std::unique_ptr<std::thread> thread_; + bool running_ = false; + }; + + const TString server_host_; + const size_t num_backends_; + const size_t num_balancers_; + const int client_load_reporting_interval_seconds_; + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::vector<std::unique_ptr<ServerThread<BackendServiceImpl>>> backends_; + std::vector<std::unique_ptr<ServerThread<BalancerServiceImpl>>> balancers_; + grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator> + response_generator_; + const TString kRequestMessage_ = "Live long and prosper."; + const TString kApplicationTargetName_ = "application_target_name"; +}; + +class SingleBalancerTest : public GrpclbEnd2endTest { + public: + SingleBalancerTest() : GrpclbEnd2endTest(4, 1, 0) {} +}; + +TEST_F(SingleBalancerTest, Vanilla) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, ReturnServerStatus) { + SetNextResolutionAllBalancers(); + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send a request that the backend will fail, and make sure we get + // back the right status. + Status expected(StatusCode::INVALID_ARGUMENT, "He's dead, Jim!"); + Status actual = SendRpc(/*response=*/nullptr, /*timeout_ms=*/1000, + /*wait_for_ready=*/false, expected); + EXPECT_EQ(actual.error_code(), expected.error_code()); + EXPECT_EQ(actual.error_message(), expected.error_message()); +} + +TEST_F(SingleBalancerTest, SelectGrpclbWithMigrationServiceConfig) { + SetNextResolutionAllBalancers( + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"grpclb\":{} }\n" + " ]\n" + "}"); + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + CheckRpcSendOk(1, 1000 /* timeout_ms */, true /* wait_for_ready */); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, + SelectGrpclbWithMigrationServiceConfigAndNoAddresses) { + const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor(); + ResetStub(kFallbackTimeoutMs); + SetNextResolution({}, {}, + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"grpclb\":{} }\n" + " ]\n" + "}"); + // Try to connect. + EXPECT_EQ(GRPC_CHANNEL_IDLE, channel_->GetState(true)); + // Should go into state TRANSIENT_FAILURE when we enter fallback mode. + const gpr_timespec deadline = grpc_timeout_seconds_to_deadline(1); + grpc_connectivity_state state; + while ((state = channel_->GetState(false)) != + GRPC_CHANNEL_TRANSIENT_FAILURE) { + ASSERT_TRUE(channel_->WaitForStateChange(state, deadline)); + } + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, UsePickFirstChildPolicy) { + SetNextResolutionAllBalancers( + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"grpclb\":{\n" + " \"childPolicy\":[\n" + " { \"pick_first\":{} }\n" + " ]\n" + " } }\n" + " ]\n" + "}"); + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + const size_t kNumRpcs = num_backends_ * 2; + CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // Check that all requests went to the first backend. This verifies + // that we used pick_first instead of round_robin as the child policy. + EXPECT_EQ(backends_[0]->service_.request_count(), kNumRpcs); + for (size_t i = 1; i < backends_.size(); ++i) { + EXPECT_EQ(backends_[i]->service_.request_count(), 0UL); + } + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, SwapChildPolicy) { + SetNextResolutionAllBalancers( + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"grpclb\":{\n" + " \"childPolicy\":[\n" + " { \"pick_first\":{} }\n" + " ]\n" + " } }\n" + " ]\n" + "}"); + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + const size_t kNumRpcs = num_backends_ * 2; + CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */); + // Check that all requests went to the first backend. This verifies + // that we used pick_first instead of round_robin as the child policy. + EXPECT_EQ(backends_[0]->service_.request_count(), kNumRpcs); + for (size_t i = 1; i < backends_.size(); ++i) { + EXPECT_EQ(backends_[i]->service_.request_count(), 0UL); + } + // Send new resolution that removes child policy from service config. + SetNextResolutionAllBalancers(); + WaitForAllBackends(); + CheckRpcSendOk(kNumRpcs, 1000 /* timeout_ms */, true /* wait_for_ready */); + // Check that every backend saw the same number of requests. This verifies + // that we used round_robin. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(backends_[i]->service_.request_count(), 2UL); + } + // Done. + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, SameBackendListedMultipleTimes) { + SetNextResolutionAllBalancers(); + // Same backend listed twice. + std::vector<int> ports; + ports.push_back(backends_[0]->port_); + ports.push_back(backends_[0]->port_); + const size_t kNumRpcsPerAddress = 10; + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(ports, {}), 0); + // We need to wait for the backend to come online. + WaitForBackend(0); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * ports.size()); + // Backend should have gotten 20 requests. + EXPECT_EQ(kNumRpcsPerAddress * 2, backends_[0]->service_.request_count()); + // And they should have come from a single client port, because of + // subchannel sharing. + EXPECT_EQ(1UL, backends_[0]->service_.clients().size()); + balancers_[0]->service_.NotifyDoneWithServerlists(); +} + +TEST_F(SingleBalancerTest, SecureNaming) { + ResetStub(0, kApplicationTargetName_ + ";lb"); + SetNextResolution({AddressData{balancers_[0]->port_, "lb"}}); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // Check LB policy name for the channel. + EXPECT_EQ("grpclb", channel_->GetLoadBalancingPolicyName()); +} + +TEST_F(SingleBalancerTest, SecureNamingDeathTest) { + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; + // Make sure that we blow up (via abort() from the security connector) when + // the name from the balancer doesn't match expectations. + ASSERT_DEATH_IF_SUPPORTED( + { + ResetStub(0, kApplicationTargetName_ + ";lb"); + SetNextResolution({AddressData{balancers_[0]->port_, "woops"}}); + channel_->WaitForConnected(grpc_timeout_seconds_to_deadline(1)); + }, + ""); +} + +TEST_F(SingleBalancerTest, InitiallyEmptyServerlist) { + SetNextResolutionAllBalancers(); + const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor(); + const int kCallDeadlineMs = kServerlistDelayMs * 2; + // First response is an empty serverlist, sent right away. + ScheduleResponseForBalancer(0, LoadBalanceResponse(), 0); + // Send non-empty serverlist only after kServerlistDelayMs + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + kServerlistDelayMs); + const auto t0 = system_clock::now(); + // Client will block: LB will initially send empty serverlist. + CheckRpcSendOk(1, kCallDeadlineMs, true /* wait_for_ready */); + const auto ellapsed_ms = + std::chrono::duration_cast<std::chrono::milliseconds>( + system_clock::now() - t0); + // but eventually, the LB sends a serverlist update that allows the call to + // proceed. The call delay must be larger than the delay in sending the + // populated serverlist but under the call's deadline (which is enforced by + // the call's deadline). + EXPECT_GT(ellapsed_ms.count(), kServerlistDelayMs); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent two responses. + EXPECT_EQ(2U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, AllServersUnreachableFailFast) { + SetNextResolutionAllBalancers(); + const size_t kNumUnreachableServers = 5; + std::vector<int> ports; + for (size_t i = 0; i < kNumUnreachableServers; ++i) { + ports.push_back(grpc_pick_unused_port_or_die()); + } + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(ports, {}), 0); + const Status status = SendRpc(); + // The error shouldn't be DEADLINE_EXCEEDED. + EXPECT_EQ(StatusCode::UNAVAILABLE, status.error_code()); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, Fallback) { + SetNextResolutionAllBalancers(); + const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor(); + const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor(); + const size_t kNumBackendsInResolution = backends_.size() / 2; + + ResetStub(kFallbackTimeoutMs); + std::vector<AddressData> balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector<AddressData> backend_addresses; + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + + // Send non-empty serverlist only after kServerlistDelayMs. + ScheduleResponseForBalancer( + 0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(kNumBackendsInResolution /* start_index */), {}), + kServerlistDelayMs); + + // Wait until all the fallback backends are reachable. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + WaitForBackend(i); + } + + // The first request. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(kNumBackendsInResolution); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + + // Fallback is used: each backend returned by the resolver should have + // gotten one request. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + + // Wait until the serverlist reception has been processed and all backends + // in the serverlist are reachable. + for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) { + WaitForBackend(i); + } + + // Send out the second request. + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(backends_.size() - kNumBackendsInResolution); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + + // Serverlist is used: each backend returned by the balancer should + // have gotten one request. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, FallbackUpdate) { + SetNextResolutionAllBalancers(); + const int kFallbackTimeoutMs = 200 * grpc_test_slowdown_factor(); + const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor(); + const size_t kNumBackendsInResolution = backends_.size() / 3; + const size_t kNumBackendsInResolutionUpdate = backends_.size() / 3; + + ResetStub(kFallbackTimeoutMs); + std::vector<AddressData> balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector<AddressData> backend_addresses; + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + + // Send non-empty serverlist only after kServerlistDelayMs. + ScheduleResponseForBalancer( + 0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(kNumBackendsInResolution + + kNumBackendsInResolutionUpdate /* start_index */), + {}), + kServerlistDelayMs); + + // Wait until all the fallback backends are reachable. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + WaitForBackend(i); + } + + // The first request. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(kNumBackendsInResolution); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + + // Fallback is used: each backend returned by the resolver should have + // gotten one request. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution; i < backends_.size(); ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + + balancer_addresses.clear(); + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + backend_addresses.clear(); + for (size_t i = kNumBackendsInResolution; + i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + + // Wait until the resolution update has been processed and all the new + // fallback backends are reachable. + for (size_t i = kNumBackendsInResolution; + i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) { + WaitForBackend(i); + } + + // Send out the second request. + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(kNumBackendsInResolutionUpdate); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + + // The resolution update is used: each backend in the resolution update should + // have gotten one request. + for (size_t i = 0; i < kNumBackendsInResolution; ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution; + i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate; + i < backends_.size(); ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + + // Wait until the serverlist reception has been processed and all backends + // in the serverlist are reachable. + for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate; + i < backends_.size(); ++i) { + WaitForBackend(i); + } + + // Send out the third request. + gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH =========="); + CheckRpcSendOk(backends_.size() - kNumBackendsInResolution - + kNumBackendsInResolutionUpdate); + gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH =========="); + + // Serverlist is used: each backend returned by the balancer should + // have gotten one request. + for (size_t i = 0; + i < kNumBackendsInResolution + kNumBackendsInResolutionUpdate; ++i) { + EXPECT_EQ(0U, backends_[i]->service_.request_count()); + } + for (size_t i = kNumBackendsInResolution + kNumBackendsInResolutionUpdate; + i < backends_.size(); ++i) { + EXPECT_EQ(1U, backends_[i]->service_.request_count()); + } + + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, + FallbackAfterStartup_LoseContactWithBalancerThenBackends) { + // First two backends are fallback, last two are pointed to by balancer. + const size_t kNumFallbackBackends = 2; + const size_t kNumBalancerBackends = backends_.size() - kNumFallbackBackends; + std::vector<AddressData> backend_addresses; + for (size_t i = 0; i < kNumFallbackBackends; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + std::vector<AddressData> balancer_addresses; + for (size_t i = 0; i < balancers_.size(); ++i) { + balancer_addresses.emplace_back(AddressData{balancers_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + ScheduleResponseForBalancer(0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(kNumFallbackBackends), {}), + 0); + // Try to connect. + channel_->GetState(true /* try_to_connect */); + WaitForAllBackends(1 /* num_requests_multiple_of */, + kNumFallbackBackends /* start_index */); + // Stop balancer. RPCs should continue going to backends from balancer. + balancers_[0]->Shutdown(); + CheckRpcSendOk(100 * kNumBalancerBackends); + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + EXPECT_EQ(100UL, backends_[i]->service_.request_count()); + } + // Stop backends from balancer. This should put us in fallback mode. + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + ShutdownBackend(i); + } + WaitForAllBackends(1 /* num_requests_multiple_of */, 0 /* start_index */, + kNumFallbackBackends /* stop_index */); + // Restart the backends from the balancer. We should *not* start + // sending traffic back to them at this point (although the behavior + // in xds may be different). + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + StartBackend(i); + } + CheckRpcSendOk(100 * kNumBalancerBackends); + for (size_t i = 0; i < kNumFallbackBackends; ++i) { + EXPECT_EQ(100UL, backends_[i]->service_.request_count()); + } + // Now start the balancer again. This should cause us to exit + // fallback mode. + balancers_[0]->Start(server_host_); + ScheduleResponseForBalancer(0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(kNumFallbackBackends), {}), + 0); + WaitForAllBackends(1 /* num_requests_multiple_of */, + kNumFallbackBackends /* start_index */); +} + +TEST_F(SingleBalancerTest, + FallbackAfterStartup_LoseContactWithBackendsThenBalancer) { + // First two backends are fallback, last two are pointed to by balancer. + const size_t kNumFallbackBackends = 2; + const size_t kNumBalancerBackends = backends_.size() - kNumFallbackBackends; + std::vector<AddressData> backend_addresses; + for (size_t i = 0; i < kNumFallbackBackends; ++i) { + backend_addresses.emplace_back(AddressData{backends_[i]->port_, ""}); + } + std::vector<AddressData> balancer_addresses; + for (size_t i = 0; i < balancers_.size(); ++i) { + balancer_addresses.emplace_back(AddressData{balancers_[i]->port_, ""}); + } + SetNextResolution(balancer_addresses, backend_addresses); + ScheduleResponseForBalancer(0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(kNumFallbackBackends), {}), + 0); + // Try to connect. + channel_->GetState(true /* try_to_connect */); + WaitForAllBackends(1 /* num_requests_multiple_of */, + kNumFallbackBackends /* start_index */); + // Stop backends from balancer. Since we are still in contact with + // the balancer at this point, RPCs should be failing. + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + ShutdownBackend(i); + } + CheckRpcSendFailure(); + // Stop balancer. This should put us in fallback mode. + balancers_[0]->Shutdown(); + WaitForAllBackends(1 /* num_requests_multiple_of */, 0 /* start_index */, + kNumFallbackBackends /* stop_index */); + // Restart the backends from the balancer. We should *not* start + // sending traffic back to them at this point (although the behavior + // in xds may be different). + for (size_t i = kNumFallbackBackends; i < backends_.size(); ++i) { + StartBackend(i); + } + CheckRpcSendOk(100 * kNumBalancerBackends); + for (size_t i = 0; i < kNumFallbackBackends; ++i) { + EXPECT_EQ(100UL, backends_[i]->service_.request_count()); + } + // Now start the balancer again. This should cause us to exit + // fallback mode. + balancers_[0]->Start(server_host_); + ScheduleResponseForBalancer(0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(kNumFallbackBackends), {}), + 0); + WaitForAllBackends(1 /* num_requests_multiple_of */, + kNumFallbackBackends /* start_index */); +} + +TEST_F(SingleBalancerTest, FallbackEarlyWhenBalancerChannelFails) { + const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor(); + ResetStub(kFallbackTimeoutMs); + // Return an unreachable balancer and one fallback backend. + std::vector<AddressData> balancer_addresses; + balancer_addresses.emplace_back( + AddressData{grpc_pick_unused_port_or_die(), ""}); + std::vector<AddressData> backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Send RPC with deadline less than the fallback timeout and make sure it + // succeeds. + CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000, + /* wait_for_ready */ false); +} + +TEST_F(SingleBalancerTest, FallbackEarlyWhenBalancerCallFails) { + const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor(); + ResetStub(kFallbackTimeoutMs); + // Return one balancer and one fallback backend. + std::vector<AddressData> balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector<AddressData> backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Balancer drops call without sending a serverlist. + balancers_[0]->service_.NotifyDoneWithServerlists(); + // Send RPC with deadline less than the fallback timeout and make sure it + // succeeds. + CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000, + /* wait_for_ready */ false); +} + +TEST_F(SingleBalancerTest, FallbackControlledByBalancer_BeforeFirstServerlist) { + const int kFallbackTimeoutMs = 10000 * grpc_test_slowdown_factor(); + ResetStub(kFallbackTimeoutMs); + // Return one balancer and one fallback backend. + std::vector<AddressData> balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector<AddressData> backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Balancer explicitly tells client to fallback. + LoadBalanceResponse resp; + resp.mutable_fallback_response(); + ScheduleResponseForBalancer(0, resp, 0); + // Send RPC with deadline less than the fallback timeout and make sure it + // succeeds. + CheckRpcSendOk(/* times */ 1, /* timeout_ms */ 1000, + /* wait_for_ready */ false); +} + +TEST_F(SingleBalancerTest, FallbackControlledByBalancer_AfterFirstServerlist) { + // Return one balancer and one fallback backend (backend 0). + std::vector<AddressData> balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector<AddressData> backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Balancer initially sends serverlist, then tells client to fall back, + // then sends the serverlist again. + // The serverlist points to backend 1. + LoadBalanceResponse serverlist_resp = + BalancerServiceImpl::BuildResponseForBackends({backends_[1]->port_}, {}); + LoadBalanceResponse fallback_resp; + fallback_resp.mutable_fallback_response(); + ScheduleResponseForBalancer(0, serverlist_resp, 0); + ScheduleResponseForBalancer(0, fallback_resp, 100); + ScheduleResponseForBalancer(0, serverlist_resp, 100); + // Requests initially go to backend 1, then go to backend 0 in + // fallback mode, then go back to backend 1 when we exit fallback. + WaitForBackend(1); + WaitForBackend(0); + WaitForBackend(1); +} + +TEST_F(SingleBalancerTest, BackendsRestart) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + // Stop backends. RPCs should fail. + ShutdownAllBackends(); + CheckRpcSendFailure(); + // Restart backends. RPCs should start succeeding again. + StartAllBackends(); + CheckRpcSendOk(1 /* times */, 2000 /* timeout_ms */, + true /* wait_for_ready */); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, ServiceNameFromLbPolicyConfig) { + constexpr char kServiceConfigWithTarget[] = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"grpclb\":{\n" + " \"serviceName\":\"test_service\"\n" + " }}\n" + " ]\n" + "}"; + + SetNextResolutionAllBalancers(kServiceConfigWithTarget); + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(); + EXPECT_EQ(balancers_[0]->service_.service_names().back(), "test_service"); +} + +class UpdatesTest : public GrpclbEnd2endTest { + public: + UpdatesTest() : GrpclbEnd2endTest(4, 3, 0) {} +}; + +TEST_F(UpdatesTest, UpdateBalancersButKeepUsingOriginalBalancer) { + SetNextResolutionAllBalancers(); + const std::vector<int> first_backend{GetBackendPorts()[0]}; + const std::vector<int> second_backend{GetBackendPorts()[1]}; + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(first_backend, {}), 0); + ScheduleResponseForBalancer( + 1, BalancerServiceImpl::BuildResponseForBackends(second_backend, {}), 0); + + // Wait until the first backend is ready. + WaitForBackend(0); + + // Send 10 requests. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + // Balancer 0 got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); + + std::vector<AddressData> addresses; + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolution(addresses); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + gpr_timespec deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // The current LB call is still working, so grpclb continued using it to the + // first balancer, which doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); +} + +// Send an update with the same set of LBs as the one in SetUp() in order to +// verify that the LB channel inside grpclb keeps the initial connection (which +// by definition is also present in the update). +TEST_F(UpdatesTest, UpdateBalancersRepeated) { + SetNextResolutionAllBalancers(); + const std::vector<int> first_backend{GetBackendPorts()[0]}; + const std::vector<int> second_backend{GetBackendPorts()[0]}; + + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(first_backend, {}), 0); + ScheduleResponseForBalancer( + 1, BalancerServiceImpl::BuildResponseForBackends(second_backend, {}), 0); + + // Wait until the first backend is ready. + WaitForBackend(0); + + // Send 10 requests. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + balancers_[0]->service_.NotifyDoneWithServerlists(); + // Balancer 0 got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); + + std::vector<AddressData> addresses; + addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + addresses.emplace_back(AddressData{balancers_[2]->port_, ""}); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolution(addresses); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + gpr_timespec deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // grpclb continued using the original LB call to the first balancer, which + // doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + balancers_[0]->service_.NotifyDoneWithServerlists(); + + addresses.clear(); + addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 2 =========="); + SetNextResolution(addresses); + gpr_log(GPR_INFO, "========= UPDATE 2 DONE =========="); + + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // grpclb continued using the original LB call to the first balancer, which + // doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + balancers_[0]->service_.NotifyDoneWithServerlists(); +} + +TEST_F(UpdatesTest, UpdateBalancersDeadUpdate) { + std::vector<AddressData> addresses; + addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + SetNextResolution(addresses); + const std::vector<int> first_backend{GetBackendPorts()[0]}; + const std::vector<int> second_backend{GetBackendPorts()[1]}; + + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(first_backend, {}), 0); + ScheduleResponseForBalancer( + 1, BalancerServiceImpl::BuildResponseForBackends(second_backend, {}), 0); + + // Start servers and send 10 RPCs per server. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + // Kill balancer 0 + gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************"); + balancers_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************"); + + // This is serviced by the existing RR policy + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // All 10 requests should again have gone to the first backend. + EXPECT_EQ(20U, backends_[0]->service_.request_count()); + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + + // Balancer 0 got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); + + addresses.clear(); + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolution(addresses); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + + // Wait until update has been processed, as signaled by the second backend + // receiving a request. In the meantime, the client continues to be serviced + // (by the first backend) without interruption. + EXPECT_EQ(0U, backends_[1]->service_.request_count()); + WaitForBackend(1); + + // This is serviced by the updated RR policy + backends_[1]->service_.ResetCounters(); + gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH =========="); + // All 10 requests should have gone to the second backend. + EXPECT_EQ(10U, backends_[1]->service_.request_count()); + + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // The second balancer, published as part of the first update, may end up + // getting two requests (that is, 1 <= #req <= 2) if the LB call retry timer + // firing races with the arrival of the update containing the second + // balancer. + EXPECT_GE(balancers_[1]->service_.request_count(), 1U); + EXPECT_GE(balancers_[1]->service_.response_count(), 1U); + EXPECT_LE(balancers_[1]->service_.request_count(), 2U); + EXPECT_LE(balancers_[1]->service_.response_count(), 2U); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); +} + +TEST_F(UpdatesTest, ReresolveDeadBackend) { + ResetStub(500); + // The first resolution contains the addresses of a balancer that never + // responds, and a fallback backend. + std::vector<AddressData> balancer_addresses; + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + std::vector<AddressData> backend_addresses; + backend_addresses.emplace_back(AddressData{backends_[0]->port_, ""}); + SetNextResolution(balancer_addresses, backend_addresses); + // Ask channel to connect to trigger resolver creation. + channel_->GetState(true); + // The re-resolution result will contain the addresses of the same balancer + // and a new fallback backend. + balancer_addresses.clear(); + balancer_addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + backend_addresses.clear(); + backend_addresses.emplace_back(AddressData{backends_[1]->port_, ""}); + SetNextReresolutionResponse(balancer_addresses, backend_addresses); + + // Start servers and send 10 RPCs per server. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the fallback backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + // Kill backend 0. + gpr_log(GPR_INFO, "********** ABOUT TO KILL BACKEND 0 *************"); + backends_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BACKEND 0 *************"); + + // Wait until re-resolution has finished, as signaled by the second backend + // receiving a request. + WaitForBackend(1); + + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // All 10 requests should have gone to the second backend. + EXPECT_EQ(10U, backends_[1]->service_.request_count()); + + balancers_[0]->service_.NotifyDoneWithServerlists(); + balancers_[1]->service_.NotifyDoneWithServerlists(); + balancers_[2]->service_.NotifyDoneWithServerlists(); + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + EXPECT_EQ(0U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); +} + +// TODO(juanlishen): Should be removed when the first response is always the +// initial response. Currently, if client load reporting is not enabled, the +// balancer doesn't send initial response. When the backend shuts down, an +// unexpected re-resolution will happen. This test configuration is a workaround +// for test ReresolveDeadBalancer. +class UpdatesWithClientLoadReportingTest : public GrpclbEnd2endTest { + public: + UpdatesWithClientLoadReportingTest() : GrpclbEnd2endTest(4, 3, 2) {} +}; + +TEST_F(UpdatesWithClientLoadReportingTest, ReresolveDeadBalancer) { + const std::vector<int> first_backend{GetBackendPorts()[0]}; + const std::vector<int> second_backend{GetBackendPorts()[1]}; + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(first_backend, {}), 0); + ScheduleResponseForBalancer( + 1, BalancerServiceImpl::BuildResponseForBackends(second_backend, {}), 0); + + // Ask channel to connect to trigger resolver creation. + channel_->GetState(true); + std::vector<AddressData> addresses; + addresses.emplace_back(AddressData{balancers_[0]->port_, ""}); + SetNextResolution(addresses); + addresses.clear(); + addresses.emplace_back(AddressData{balancers_[1]->port_, ""}); + SetNextReresolutionResponse(addresses); + + // Start servers and send 10 RPCs per server. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->service_.request_count()); + + // Kill backend 0. + gpr_log(GPR_INFO, "********** ABOUT TO KILL BACKEND 0 *************"); + backends_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BACKEND 0 *************"); + + CheckRpcSendFailure(); + + // Balancer 0 got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + EXPECT_EQ(0U, balancers_[1]->service_.request_count()); + EXPECT_EQ(0U, balancers_[1]->service_.response_count()); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); + + // Kill balancer 0. + gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************"); + balancers_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************"); + + // Wait until re-resolution has finished, as signaled by the second backend + // receiving a request. + WaitForBackend(1); + + // This is serviced by the new serverlist. + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // All 10 requests should have gone to the second backend. + EXPECT_EQ(10U, backends_[1]->service_.request_count()); + + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + // After balancer 0 is killed, we restart an LB call immediately (because we + // disconnect to a previously connected balancer). Although we will cancel + // this call when the re-resolution update is done and another LB call restart + // is needed, this old call may still succeed reaching the LB server if + // re-resolution is slow. So balancer 1 may have received 2 requests and sent + // 2 responses. + EXPECT_GE(balancers_[1]->service_.request_count(), 1U); + EXPECT_GE(balancers_[1]->service_.response_count(), 1U); + EXPECT_LE(balancers_[1]->service_.request_count(), 2U); + EXPECT_LE(balancers_[1]->service_.response_count(), 2U); + EXPECT_EQ(0U, balancers_[2]->service_.request_count()); + EXPECT_EQ(0U, balancers_[2]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, Drop) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + const int num_of_drop_by_rate_limiting_addresses = 1; + const int num_of_drop_by_load_balancing_addresses = 2; + const int num_of_drop_addresses = num_of_drop_by_rate_limiting_addresses + + num_of_drop_by_load_balancing_addresses; + const int num_total_addresses = num_backends_ + num_of_drop_addresses; + ScheduleResponseForBalancer( + 0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(), + {{"rate_limiting", num_of_drop_by_rate_limiting_addresses}, + {"load_balancing", num_of_drop_by_load_balancing_addresses}}), + 0); + // Wait until all backends are ready. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs for each server and drop address. + size_t num_drops = 0; + for (size_t i = 0; i < kNumRpcsPerAddress * num_total_addresses; ++i) { + EchoResponse response; + const Status status = SendRpc(&response); + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage_); + } + } + EXPECT_EQ(kNumRpcsPerAddress * num_of_drop_addresses, num_drops); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); +} + +TEST_F(SingleBalancerTest, DropAllFirst) { + SetNextResolutionAllBalancers(); + // All registered addresses are marked as "drop". + const int num_of_drop_by_rate_limiting_addresses = 1; + const int num_of_drop_by_load_balancing_addresses = 1; + ScheduleResponseForBalancer( + 0, + BalancerServiceImpl::BuildResponseForBackends( + {}, {{"rate_limiting", num_of_drop_by_rate_limiting_addresses}, + {"load_balancing", num_of_drop_by_load_balancing_addresses}}), + 0); + const Status status = SendRpc(nullptr, 1000, true); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), "Call dropped by load balancing policy"); +} + +TEST_F(SingleBalancerTest, DropAll) { + SetNextResolutionAllBalancers(); + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + const int num_of_drop_by_rate_limiting_addresses = 1; + const int num_of_drop_by_load_balancing_addresses = 1; + ScheduleResponseForBalancer( + 0, + BalancerServiceImpl::BuildResponseForBackends( + {}, {{"rate_limiting", num_of_drop_by_rate_limiting_addresses}, + {"load_balancing", num_of_drop_by_load_balancing_addresses}}), + 1000); + + // First call succeeds. + CheckRpcSendOk(); + // But eventually, the update with only dropped servers is processed and calls + // fail. + Status status; + do { + status = SendRpc(nullptr, 1000, true); + } while (status.ok()); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), "Call dropped by load balancing policy"); +} + +class SingleBalancerWithClientLoadReportingTest : public GrpclbEnd2endTest { + public: + SingleBalancerWithClientLoadReportingTest() : GrpclbEnd2endTest(4, 1, 3) {} +}; + +TEST_F(SingleBalancerWithClientLoadReportingTest, Vanilla) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + ScheduleResponseForBalancer( + 0, BalancerServiceImpl::BuildResponseForBackends(GetBackendPorts(), {}), + 0); + // Wait until all backends are ready. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + + ClientStats client_stats; + do { + client_stats += WaitForLoadReports(); + } while (client_stats.num_calls_finished != + kNumRpcsPerAddress * num_backends_ + num_ok); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok, + client_stats.num_calls_started); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok, + client_stats.num_calls_finished); + EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + (num_ok + num_drops), + client_stats.num_calls_finished_known_received); + EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre()); +} + +TEST_F(SingleBalancerWithClientLoadReportingTest, BalancerRestart) { + SetNextResolutionAllBalancers(); + const size_t kNumBackendsFirstPass = 2; + const size_t kNumBackendsSecondPass = + backends_.size() - kNumBackendsFirstPass; + // Balancer returns backends starting at index 1. + ScheduleResponseForBalancer( + 0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(0, kNumBackendsFirstPass), {}), + 0); + // Wait until all backends returned by the balancer are ready. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = + WaitForAllBackends(/* num_requests_multiple_of */ 1, /* start_index */ 0, + /* stop_index */ kNumBackendsFirstPass); + balancers_[0]->service_.NotifyDoneWithServerlists(); + ClientStats client_stats = WaitForLoadReports(); + EXPECT_EQ(static_cast<size_t>(num_ok), client_stats.num_calls_started); + EXPECT_EQ(static_cast<size_t>(num_ok), client_stats.num_calls_finished); + EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send); + EXPECT_EQ(static_cast<size_t>(num_ok), + client_stats.num_calls_finished_known_received); + EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre()); + // Shut down the balancer. + balancers_[0]->Shutdown(); + // Send 10 more requests per backend. This will continue using the + // last serverlist we received from the balancer before it was shut down. + ResetBackendCounters(); + CheckRpcSendOk(kNumBackendsFirstPass); + // Each backend should have gotten 1 request. + for (size_t i = 0; i < kNumBackendsFirstPass; ++i) { + EXPECT_EQ(1UL, backends_[i]->service_.request_count()); + } + // Now restart the balancer, this time pointing to all backends. + balancers_[0]->Start(server_host_); + ScheduleResponseForBalancer(0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(kNumBackendsFirstPass), {}), + 0); + // Wait for queries to start going to one of the new backends. + // This tells us that we're now using the new serverlist. + do { + CheckRpcSendOk(); + } while (backends_[2]->service_.request_count() == 0 && + backends_[3]->service_.request_count() == 0); + // Send one RPC per backend. + CheckRpcSendOk(kNumBackendsSecondPass); + balancers_[0]->service_.NotifyDoneWithServerlists(); + // Check client stats. + client_stats = WaitForLoadReports(); + EXPECT_EQ(kNumBackendsSecondPass + 1, client_stats.num_calls_started); + EXPECT_EQ(kNumBackendsSecondPass + 1, client_stats.num_calls_finished); + EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send); + EXPECT_EQ(kNumBackendsSecondPass + 1, + client_stats.num_calls_finished_known_received); + EXPECT_THAT(client_stats.drop_token_counts, ::testing::ElementsAre()); +} + +TEST_F(SingleBalancerWithClientLoadReportingTest, Drop) { + SetNextResolutionAllBalancers(); + const size_t kNumRpcsPerAddress = 3; + const int num_of_drop_by_rate_limiting_addresses = 2; + const int num_of_drop_by_load_balancing_addresses = 1; + const int num_of_drop_addresses = num_of_drop_by_rate_limiting_addresses + + num_of_drop_by_load_balancing_addresses; + const int num_total_addresses = num_backends_ + num_of_drop_addresses; + ScheduleResponseForBalancer( + 0, + BalancerServiceImpl::BuildResponseForBackends( + GetBackendPorts(), + {{"rate_limiting", num_of_drop_by_rate_limiting_addresses}, + {"load_balancing", num_of_drop_by_load_balancing_addresses}}), + 0); + // Wait until all backends are ready. + int num_warmup_ok = 0; + int num_warmup_failure = 0; + int num_warmup_drops = 0; + std::tie(num_warmup_ok, num_warmup_failure, num_warmup_drops) = + WaitForAllBackends(num_total_addresses /* num_requests_multiple_of */); + const int num_total_warmup_requests = + num_warmup_ok + num_warmup_failure + num_warmup_drops; + size_t num_drops = 0; + for (size_t i = 0; i < kNumRpcsPerAddress * num_total_addresses; ++i) { + EchoResponse response; + const Status status = SendRpc(&response); + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage_); + } + } + EXPECT_EQ(kNumRpcsPerAddress * num_of_drop_addresses, num_drops); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, backends_[i]->service_.request_count()); + } + balancers_[0]->service_.NotifyDoneWithServerlists(); + // The balancer got a single request. + EXPECT_EQ(1U, balancers_[0]->service_.request_count()); + // and sent a single response. + EXPECT_EQ(1U, balancers_[0]->service_.response_count()); + + const ClientStats client_stats = WaitForLoadReports(); + EXPECT_EQ( + kNumRpcsPerAddress * num_total_addresses + num_total_warmup_requests, + client_stats.num_calls_started); + EXPECT_EQ( + kNumRpcsPerAddress * num_total_addresses + num_total_warmup_requests, + client_stats.num_calls_finished); + EXPECT_EQ(0U, client_stats.num_calls_finished_with_client_failed_to_send); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_warmup_ok, + client_stats.num_calls_finished_known_received); + // The number of warmup request is a multiple of the number of addresses. + // Therefore, all addresses in the scheduled balancer response are hit the + // same number of times. + const int num_times_drop_addresses_hit = + num_warmup_drops / num_of_drop_addresses; + EXPECT_THAT( + client_stats.drop_token_counts, + ::testing::ElementsAre( + ::testing::Pair("load_balancing", + (kNumRpcsPerAddress + num_times_drop_addresses_hit)), + ::testing::Pair( + "rate_limiting", + (kNumRpcsPerAddress + num_times_drop_addresses_hit) * 2))); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/health/ya.make b/contrib/libs/grpc/test/cpp/end2end/health/ya.make new file mode 100644 index 0000000000..7330129b73 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/health/ya.make @@ -0,0 +1,33 @@ +GTEST_UGLY() + +OWNER( + dvshkurko + g:ymake +) + +ADDINCL( + ${ARCADIA_BUILD_ROOT}/contrib/libs/grpc + ${ARCADIA_ROOT}/contrib/libs/grpc +) + +PEERDIR( + contrib/libs/grpc/src/proto/grpc/health/v1 + contrib/libs/grpc/src/proto/grpc/core + contrib/libs/grpc/src/proto/grpc/testing + contrib/libs/grpc/src/proto/grpc/testing/duplicate + contrib/libs/grpc/test/core/util + contrib/libs/grpc/test/cpp/end2end + contrib/libs/grpc/test/cpp/util +) + +NO_COMPILER_WARNINGS() + +SRCDIR( + contrib/libs/grpc/test/cpp/end2end +) + +SRCS( + health_service_end2end_test.cc +) + +END() diff --git a/contrib/libs/grpc/test/cpp/end2end/health_service_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/health_service_end2end_test.cc new file mode 100644 index 0000000000..516b3a4c81 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/health_service_end2end_test.cc @@ -0,0 +1,374 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <mutex> +#include <thread> +#include <vector> + +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/ext/health_check_service_server_builder_option.h> +#include <grpcpp/health_check_service_interface.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/health/v1/health.grpc.pb.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_health_check_service_impl.h" +#include "test/cpp/end2end/test_service_impl.h" + +#include <gtest/gtest.h> + +using grpc::health::v1::Health; +using grpc::health::v1::HealthCheckRequest; +using grpc::health::v1::HealthCheckResponse; + +namespace grpc { +namespace testing { +namespace { + +// A custom implementation of the health checking service interface. This is +// used to test that it prevents the server from creating a default service and +// also serves as an example of how to override the default service. +class CustomHealthCheckService : public HealthCheckServiceInterface { + public: + explicit CustomHealthCheckService(HealthCheckServiceImpl* impl) + : impl_(impl) { + impl_->SetStatus("", HealthCheckResponse::SERVING); + } + void SetServingStatus(const TString& service_name, + bool serving) override { + impl_->SetStatus(service_name, serving ? HealthCheckResponse::SERVING + : HealthCheckResponse::NOT_SERVING); + } + + void SetServingStatus(bool serving) override { + impl_->SetAll(serving ? HealthCheckResponse::SERVING + : HealthCheckResponse::NOT_SERVING); + } + + void Shutdown() override { impl_->Shutdown(); } + + private: + HealthCheckServiceImpl* impl_; // not owned +}; + +class HealthServiceEnd2endTest : public ::testing::Test { + protected: + HealthServiceEnd2endTest() {} + + void SetUpServer(bool register_sync_test_service, bool add_async_cq, + bool explicit_health_service, + std::unique_ptr<HealthCheckServiceInterface> service) { + int port = 5001; // grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + + bool register_sync_health_service_impl = + explicit_health_service && service != nullptr; + + // Setup server + ServerBuilder builder; + if (explicit_health_service) { + std::unique_ptr<ServerBuilderOption> option( + new HealthCheckServiceServerBuilderOption(std::move(service))); + builder.SetOption(std::move(option)); + } + builder.AddListeningPort(server_address_.str(), + grpc::InsecureServerCredentials()); + if (register_sync_test_service) { + // Register a sync service. + builder.RegisterService(&echo_test_service_); + } + if (register_sync_health_service_impl) { + builder.RegisterService(&health_check_service_impl_); + } + if (add_async_cq) { + cq_ = builder.AddCompletionQueue(); + } + server_ = builder.BuildAndStart(); + } + + void TearDown() override { + if (server_) { + server_->Shutdown(); + if (cq_ != nullptr) { + cq_->Shutdown(); + } + if (cq_thread_.joinable()) { + cq_thread_.join(); + } + } + } + + void ResetStubs() { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + hc_stub_ = grpc::health::v1::Health::NewStub(channel); + } + + // When the expected_status is NOT OK, we do not care about the response. + void SendHealthCheckRpc(const TString& service_name, + const Status& expected_status) { + EXPECT_FALSE(expected_status.ok()); + SendHealthCheckRpc(service_name, expected_status, + HealthCheckResponse::UNKNOWN); + } + + void SendHealthCheckRpc( + const TString& service_name, const Status& expected_status, + HealthCheckResponse::ServingStatus expected_serving_status) { + HealthCheckRequest request; + request.set_service(service_name); + HealthCheckResponse response; + ClientContext context; + Status s = hc_stub_->Check(&context, request, &response); + EXPECT_EQ(expected_status.error_code(), s.error_code()); + if (s.ok()) { + EXPECT_EQ(expected_serving_status, response.status()); + } + } + + void VerifyHealthCheckService() { + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service != nullptr); + const TString kHealthyService("healthy_service"); + const TString kUnhealthyService("unhealthy_service"); + const TString kNotRegisteredService("not_registered"); + service->SetServingStatus(kHealthyService, true); + service->SetServingStatus(kUnhealthyService, false); + + ResetStubs(); + + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + + service->SetServingStatus(false); + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + } + + void VerifyHealthCheckServiceStreaming() { + const TString kServiceName("service_name"); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + // Start Watch for service. + ClientContext context; + HealthCheckRequest request; + request.set_service(kServiceName); + std::unique_ptr<::grpc::ClientReaderInterface<HealthCheckResponse>> reader = + hc_stub_->Watch(&context, request); + // Initial response will be SERVICE_UNKNOWN. + HealthCheckResponse response; + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.SERVICE_UNKNOWN, response.status()); + response.Clear(); + // Now set service to NOT_SERVING and make sure we get an update. + service->SetServingStatus(kServiceName, false); + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.NOT_SERVING, response.status()); + response.Clear(); + // Now set service to SERVING and make sure we get another update. + service->SetServingStatus(kServiceName, true); + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.SERVING, response.status()); + // Finish call. + context.TryCancel(); + } + + // Verify that after HealthCheckServiceInterface::Shutdown is called + // 1. unary client will see NOT_SERVING. + // 2. unary client still sees NOT_SERVING after a SetServing(true) is called. + // 3. streaming (Watch) client will see an update. + // 4. setting a new service to serving after shutdown will add the service + // name but return NOT_SERVING to client. + // This has to be called last. + void VerifyHealthCheckServiceShutdown() { + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service != nullptr); + const TString kHealthyService("healthy_service"); + const TString kUnhealthyService("unhealthy_service"); + const TString kNotRegisteredService("not_registered"); + const TString kNewService("add_after_shutdown"); + service->SetServingStatus(kHealthyService, true); + service->SetServingStatus(kUnhealthyService, false); + + ResetStubs(); + + // Start Watch for service. + ClientContext context; + HealthCheckRequest request; + request.set_service(kHealthyService); + std::unique_ptr<::grpc::ClientReaderInterface<HealthCheckResponse>> reader = + hc_stub_->Watch(&context, request); + + HealthCheckResponse response; + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.SERVING, response.status()); + + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + SendHealthCheckRpc(kNewService, Status(StatusCode::NOT_FOUND, "")); + + // Shutdown health check service. + service->Shutdown(); + + // Watch client gets another update. + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.NOT_SERVING, response.status()); + // Finish Watch call. + context.TryCancel(); + + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + + // Setting status after Shutdown has no effect. + service->SetServingStatus(kHealthyService, true); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + + // Adding serving status for a new service after shutdown will return + // NOT_SERVING. + service->SetServingStatus(kNewService, true); + SendHealthCheckRpc(kNewService, Status::OK, + HealthCheckResponse::NOT_SERVING); + } + + TestServiceImpl echo_test_service_; + HealthCheckServiceImpl health_check_service_impl_; + std::unique_ptr<Health::Stub> hc_stub_; + std::unique_ptr<ServerCompletionQueue> cq_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; + std::thread cq_thread_; +}; + +TEST_F(HealthServiceEnd2endTest, DefaultHealthServiceDisabled) { + EnableDefaultHealthCheckService(false); + EXPECT_FALSE(DefaultHealthCheckServiceEnabled()); + SetUpServer(true, false, false, nullptr); + HealthCheckServiceInterface* default_service = + server_->GetHealthCheckService(); + EXPECT_TRUE(default_service == nullptr); + + ResetStubs(); + + SendHealthCheckRpc("", Status(StatusCode::UNIMPLEMENTED, "")); +} + +TEST_F(HealthServiceEnd2endTest, DefaultHealthService) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + SetUpServer(true, false, false, nullptr); + VerifyHealthCheckService(); + VerifyHealthCheckServiceStreaming(); + + // The default service has a size limit of the service name. + const TString kTooLongServiceName(201, 'x'); + SendHealthCheckRpc(kTooLongServiceName, + Status(StatusCode::INVALID_ARGUMENT, "")); +} + +TEST_F(HealthServiceEnd2endTest, DefaultHealthServiceShutdown) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + SetUpServer(true, false, false, nullptr); + VerifyHealthCheckServiceShutdown(); +} + +// Provide an empty service to disable the default service. +TEST_F(HealthServiceEnd2endTest, ExplicitlyDisableViaOverride) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + std::unique_ptr<HealthCheckServiceInterface> empty_service; + SetUpServer(true, false, true, std::move(empty_service)); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service == nullptr); + + ResetStubs(); + + SendHealthCheckRpc("", Status(StatusCode::UNIMPLEMENTED, "")); +} + +// Provide an explicit override of health checking service interface. +TEST_F(HealthServiceEnd2endTest, ExplicitlyOverride) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + std::unique_ptr<HealthCheckServiceInterface> override_service( + new CustomHealthCheckService(&health_check_service_impl_)); + HealthCheckServiceInterface* underlying_service = override_service.get(); + SetUpServer(false, false, true, std::move(override_service)); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service == underlying_service); + + ResetStubs(); + + VerifyHealthCheckService(); + VerifyHealthCheckServiceStreaming(); +} + +TEST_F(HealthServiceEnd2endTest, ExplicitlyHealthServiceShutdown) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + std::unique_ptr<HealthCheckServiceInterface> override_service( + new CustomHealthCheckService(&health_check_service_impl_)); + HealthCheckServiceInterface* underlying_service = override_service.get(); + SetUpServer(false, false, true, std::move(override_service)); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service == underlying_service); + + ResetStubs(); + + VerifyHealthCheckServiceShutdown(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/hybrid_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/hybrid_end2end_test.cc new file mode 100644 index 0000000000..e4ebee8e93 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/hybrid_end2end_test.cc @@ -0,0 +1,987 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <thread> + +#include <grpc/grpc.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/async_generic_service.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +namespace { + +#ifndef GRPC_CALLBACK_API_NONEXPERIMENTAL +using ::grpc::experimental::CallbackGenericService; +using ::grpc::experimental::GenericCallbackServerContext; +using ::grpc::experimental::ServerGenericBidiReactor; +#endif + +void* tag(int i) { return (void*)static_cast<intptr_t>(i); } + +bool VerifyReturnSuccess(CompletionQueue* cq, int i) { + void* got_tag; + bool ok; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + EXPECT_EQ(tag(i), got_tag); + return ok; +} + +void Verify(CompletionQueue* cq, int i, bool expect_ok) { + EXPECT_EQ(expect_ok, VerifyReturnSuccess(cq, i)); +} + +// Handlers to handle async request at a server. To be run in a separate thread. +template <class Service> +void HandleEcho(Service* service, ServerCompletionQueue* cq, bool dup_service) { + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + EchoRequest recv_request; + EchoResponse send_response; + service->RequestEcho(&srv_ctx, &recv_request, &response_writer, cq, cq, + tag(1)); + Verify(cq, 1, true); + send_response.set_message(recv_request.message()); + if (dup_service) { + send_response.mutable_message()->append("_dup"); + } + response_writer.Finish(send_response, Status::OK, tag(2)); + Verify(cq, 2, true); +} + +// Handlers to handle raw request at a server. To be run in a +// separate thread. Note that this is the same as the async version, except +// that the req/resp are ByteBuffers +template <class Service> +void HandleRawEcho(Service* service, ServerCompletionQueue* cq, + bool /*dup_service*/) { + ServerContext srv_ctx; + GenericServerAsyncResponseWriter response_writer(&srv_ctx); + ByteBuffer recv_buffer; + service->RequestEcho(&srv_ctx, &recv_buffer, &response_writer, cq, cq, + tag(1)); + Verify(cq, 1, true); + EchoRequest recv_request; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EchoResponse send_response; + send_response.set_message(recv_request.message()); + auto send_buffer = SerializeToByteBuffer(&send_response); + response_writer.Finish(*send_buffer, Status::OK, tag(2)); + Verify(cq, 2, true); +} + +template <class Service> +void HandleClientStreaming(Service* service, ServerCompletionQueue* cq) { + ServerContext srv_ctx; + EchoRequest recv_request; + EchoResponse send_response; + ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + service->RequestRequestStream(&srv_ctx, &srv_stream, cq, cq, tag(1)); + Verify(cq, 1, true); + int i = 1; + do { + i++; + send_response.mutable_message()->append(recv_request.message()); + srv_stream.Read(&recv_request, tag(i)); + } while (VerifyReturnSuccess(cq, i)); + srv_stream.Finish(send_response, Status::OK, tag(100)); + Verify(cq, 100, true); +} + +template <class Service> +void HandleRawClientStreaming(Service* service, ServerCompletionQueue* cq) { + ServerContext srv_ctx; + ByteBuffer recv_buffer; + EchoRequest recv_request; + EchoResponse send_response; + GenericServerAsyncReader srv_stream(&srv_ctx); + service->RequestRequestStream(&srv_ctx, &srv_stream, cq, cq, tag(1)); + Verify(cq, 1, true); + int i = 1; + while (true) { + i++; + srv_stream.Read(&recv_buffer, tag(i)); + if (!VerifyReturnSuccess(cq, i)) { + break; + } + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + send_response.mutable_message()->append(recv_request.message()); + } + auto send_buffer = SerializeToByteBuffer(&send_response); + srv_stream.Finish(*send_buffer, Status::OK, tag(100)); + Verify(cq, 100, true); +} + +template <class Service> +void HandleServerStreaming(Service* service, ServerCompletionQueue* cq) { + ServerContext srv_ctx; + EchoRequest recv_request; + EchoResponse send_response; + ServerAsyncWriter<EchoResponse> srv_stream(&srv_ctx); + service->RequestResponseStream(&srv_ctx, &recv_request, &srv_stream, cq, cq, + tag(1)); + Verify(cq, 1, true); + send_response.set_message(recv_request.message() + "0"); + srv_stream.Write(send_response, tag(2)); + Verify(cq, 2, true); + send_response.set_message(recv_request.message() + "1"); + srv_stream.Write(send_response, tag(3)); + Verify(cq, 3, true); + send_response.set_message(recv_request.message() + "2"); + srv_stream.Write(send_response, tag(4)); + Verify(cq, 4, true); + srv_stream.Finish(Status::OK, tag(5)); + Verify(cq, 5, true); +} + +void HandleGenericEcho(GenericServerAsyncReaderWriter* stream, + CompletionQueue* cq) { + ByteBuffer recv_buffer; + stream->Read(&recv_buffer, tag(2)); + Verify(cq, 2, true); + EchoRequest recv_request; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EchoResponse send_response; + send_response.set_message(recv_request.message()); + auto send_buffer = SerializeToByteBuffer(&send_response); + stream->Write(*send_buffer, tag(3)); + Verify(cq, 3, true); + stream->Finish(Status::OK, tag(4)); + Verify(cq, 4, true); +} + +void HandleGenericRequestStream(GenericServerAsyncReaderWriter* stream, + CompletionQueue* cq) { + ByteBuffer recv_buffer; + EchoRequest recv_request; + EchoResponse send_response; + int i = 1; + while (true) { + i++; + stream->Read(&recv_buffer, tag(i)); + if (!VerifyReturnSuccess(cq, i)) { + break; + } + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + send_response.mutable_message()->append(recv_request.message()); + } + auto send_buffer = SerializeToByteBuffer(&send_response); + stream->Write(*send_buffer, tag(99)); + Verify(cq, 99, true); + stream->Finish(Status::OK, tag(100)); + Verify(cq, 100, true); +} + +// Request and handle one generic call. +void HandleGenericCall(AsyncGenericService* service, + ServerCompletionQueue* cq) { + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + service->RequestCall(&srv_ctx, &stream, cq, cq, tag(1)); + Verify(cq, 1, true); + if (srv_ctx.method() == "/grpc.testing.EchoTestService/Echo") { + HandleGenericEcho(&stream, cq); + } else if (srv_ctx.method() == + "/grpc.testing.EchoTestService/RequestStream") { + HandleGenericRequestStream(&stream, cq); + } else { // other methods not handled yet. + gpr_log(GPR_ERROR, "method: %s", srv_ctx.method().c_str()); + GPR_ASSERT(0); + } +} + +class TestServiceImplDupPkg + : public ::grpc::testing::duplicate::EchoTestService::Service { + public: + Status Echo(ServerContext* /*context*/, const EchoRequest* request, + EchoResponse* response) override { + response->set_message(request->message() + "_dup"); + return Status::OK; + } +}; + +class HybridEnd2endTest : public ::testing::TestWithParam<bool> { + protected: + HybridEnd2endTest() {} + + static void SetUpTestCase() { +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + } + + void SetUp() override { + inproc_ = (::testing::UnitTest::GetInstance() + ->current_test_info() + ->value_param() != nullptr) + ? GetParam() + : false; + } + + bool SetUpServer(::grpc::Service* service1, ::grpc::Service* service2, + AsyncGenericService* generic_service, + CallbackGenericService* callback_generic_service, + int max_message_size = 0) { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + grpc::InsecureServerCredentials()); + // Always add a sync unimplemented service: we rely on having at least one + // synchronous method to get a listening cq + builder.RegisterService(&unimplemented_service_); + builder.RegisterService(service1); + if (service2) { + builder.RegisterService(service2); + } + if (generic_service) { + builder.RegisterAsyncGenericService(generic_service); + } + if (callback_generic_service) { +#ifdef GRPC_CALLBACK_API_NONEXPERIMENTAL + builder.RegisterCallbackGenericService(callback_generic_service); +#else + builder.experimental().RegisterCallbackGenericService( + callback_generic_service); +#endif + } + + if (max_message_size != 0) { + builder.SetMaxMessageSize(max_message_size); + } + + // Create a separate cq for each potential handler. + for (int i = 0; i < 5; i++) { + cqs_.push_back(builder.AddCompletionQueue(false)); + } + server_ = builder.BuildAndStart(); + + // If there is a generic callback service, this setup is only successful if + // we have an iomgr that can run in the background or are inprocess + return !callback_generic_service || grpc_iomgr_run_in_background() || + inproc_; + } + + void TearDown() override { + if (server_) { + server_->Shutdown(); + } + void* ignored_tag; + bool ignored_ok; + for (auto it = cqs_.begin(); it != cqs_.end(); ++it) { + (*it)->Shutdown(); + while ((*it)->Next(&ignored_tag, &ignored_ok)) + ; + } + } + + void ResetStub() { + std::shared_ptr<Channel> channel = + inproc_ ? server_->InProcessChannel(ChannelArguments()) + : grpc::CreateChannel(server_address_.str(), + InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + // Test all rpc methods. + void TestAllMethods() { + SendEcho(); + SendSimpleClientStreaming(); + SendSimpleServerStreaming(); + SendBidiStreaming(); + } + + void SendEcho() { + EchoRequest send_request; + EchoResponse recv_response; + ClientContext cli_ctx; + cli_ctx.set_wait_for_ready(true); + send_request.set_message("Hello"); + Status recv_status = stub_->Echo(&cli_ctx, send_request, &recv_response); + EXPECT_EQ(send_request.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + + void SendEchoToDupService() { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + auto stub = grpc::testing::duplicate::EchoTestService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + ClientContext cli_ctx; + cli_ctx.set_wait_for_ready(true); + send_request.set_message("Hello"); + Status recv_status = stub->Echo(&cli_ctx, send_request, &recv_response); + EXPECT_EQ(send_request.message() + "_dup", recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + + void SendSimpleClientStreaming() { + EchoRequest send_request; + EchoResponse recv_response; + TString expected_message; + ClientContext cli_ctx; + cli_ctx.set_wait_for_ready(true); + send_request.set_message("Hello"); + auto stream = stub_->RequestStream(&cli_ctx, &recv_response); + for (int i = 0; i < 5; i++) { + EXPECT_TRUE(stream->Write(send_request)); + expected_message.append(send_request.message()); + } + stream->WritesDone(); + Status recv_status = stream->Finish(); + EXPECT_EQ(expected_message, recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + + void SendSimpleServerStreaming() { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + request.set_message("hello"); + + auto stream = stub_->ResponseStream(&context, request); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "0"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "1"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "2"); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); + } + + void SendSimpleServerStreamingToDupService() { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + auto stub = grpc::testing::duplicate::EchoTestService::NewStub(channel); + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + request.set_message("hello"); + + auto stream = stub->ResponseStream(&context, request); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "0_dup"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "1_dup"); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message() + "2_dup"); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); + } + + void SendBidiStreaming() { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_wait_for_ready(true); + TString msg("hello"); + + auto stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "1"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "2"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + stream->WritesDone(); + EXPECT_FALSE(stream->Read(&response)); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); + } + + grpc::testing::UnimplementedEchoService::Service unimplemented_service_; + std::vector<std::unique_ptr<ServerCompletionQueue>> cqs_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; + bool inproc_; +}; + +TEST_F(HybridEnd2endTest, AsyncEcho) { + typedef EchoTestService::WithAsyncMethod_Echo<TestServiceImpl> SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread echo_handler_thread(HandleEcho<SType>, &service, cqs_[0].get(), + false); + TestAllMethods(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, RawEcho) { + typedef EchoTestService::WithRawMethod_Echo<TestServiceImpl> SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread echo_handler_thread(HandleRawEcho<SType>, &service, cqs_[0].get(), + false); + TestAllMethods(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, RawRequestStream) { + typedef EchoTestService::WithRawMethod_RequestStream<TestServiceImpl> SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread request_stream_handler_thread(HandleRawClientStreaming<SType>, + &service, cqs_[0].get()); + TestAllMethods(); + request_stream_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, AsyncEchoRawRequestStream) { + typedef EchoTestService::WithRawMethod_RequestStream< + EchoTestService::WithAsyncMethod_Echo<TestServiceImpl>> + SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread echo_handler_thread(HandleEcho<SType>, &service, cqs_[0].get(), + false); + std::thread request_stream_handler_thread(HandleRawClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + request_stream_handler_thread.join(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, GenericEchoRawRequestStream) { + typedef EchoTestService::WithRawMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo<TestServiceImpl>> + SType; + SType service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleRawClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, AsyncEchoRequestStream) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_Echo<TestServiceImpl>> + SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread echo_handler_thread(HandleEcho<SType>, &service, cqs_[0].get(), + false); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + echo_handler_thread.join(); + request_stream_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>> + SType; + SType service; + SetUpServer(&service, nullptr, nullptr, nullptr); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one sync method. +TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream_SyncDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>> + SType; + SType service; + TestServiceImplDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one sync streamed unary method. +class StreamedUnaryDupPkg + : public duplicate::EchoTestService::WithStreamedUnaryMethod_Echo< + TestServiceImplDupPkg> { + public: + Status StreamedEcho( + ServerContext* /*context*/, + ServerUnaryStreamer<EchoRequest, EchoResponse>* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + resp.set_message(req.message() + "_dup"); + GPR_ASSERT(stream->Write(resp)); + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_SyncStreamedUnaryDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>> + SType; + SType service; + StreamedUnaryDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service that is fully Streamed Unary +class FullyStreamedUnaryDupPkg + : public duplicate::EchoTestService::StreamedUnaryService { + public: + Status StreamedEcho( + ServerContext* /*context*/, + ServerUnaryStreamer<EchoRequest, EchoResponse>* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + resp.set_message(req.message() + "_dup"); + GPR_ASSERT(stream->Write(resp)); + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_SyncFullyStreamedUnaryDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>> + SType; + SType service; + FullyStreamedUnaryDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one sync split server streaming method. +class SplitResponseStreamDupPkg + : public duplicate::EchoTestService:: + WithSplitStreamingMethod_ResponseStream<TestServiceImplDupPkg> { + public: + Status StreamedResponseStream( + ServerContext* /*context*/, + ServerSplitStreamer<EchoRequest, EchoResponse>* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Split Streamed Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + resp.set_message(req.message() + ToString(i) + "_dup"); + GPR_ASSERT(stream->Write(resp)); + } + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_SyncSplitStreamedDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>> + SType; + SType service; + SplitResponseStreamDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + SendSimpleServerStreamingToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service that is fully split server streamed +class FullySplitStreamedDupPkg + : public duplicate::EchoTestService::SplitStreamedService { + public: + Status StreamedResponseStream( + ServerContext* /*context*/, + ServerSplitStreamer<EchoRequest, EchoResponse>* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Split Streamed Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + resp.set_message(req.message() + ToString(i) + "_dup"); + GPR_ASSERT(stream->Write(resp)); + } + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_FullySplitStreamedDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>> + SType; + SType service; + FullySplitStreamedDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + SendSimpleServerStreamingToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service that is fully server streamed +class FullyStreamedDupPkg : public duplicate::EchoTestService::StreamedService { + public: + Status StreamedEcho( + ServerContext* /*context*/, + ServerUnaryStreamer<EchoRequest, EchoResponse>* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Streamed Unary Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + resp.set_message(req.message() + "_dup"); + GPR_ASSERT(stream->Write(resp)); + return Status::OK; + } + Status StreamedResponseStream( + ServerContext* /*context*/, + ServerSplitStreamer<EchoRequest, EchoResponse>* stream) override { + EchoRequest req; + EchoResponse resp; + uint32_t next_msg_sz; + stream->NextMessageSize(&next_msg_sz); + gpr_log(GPR_INFO, "Split Streamed Next Message Size is %u", next_msg_sz); + GPR_ASSERT(stream->Read(&req)); + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + resp.set_message(req.message() + ToString(i) + "_dup"); + GPR_ASSERT(stream->Write(resp)); + } + return Status::OK; + } +}; + +TEST_F(HybridEnd2endTest, + AsyncRequestStreamResponseStream_FullyStreamedDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>> + SType; + SType service; + FullyStreamedDupPkg dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr, 8192); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + SendSimpleServerStreamingToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one async method. +TEST_F(HybridEnd2endTest, AsyncRequestStreamResponseStream_AsyncDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>> + SType; + SType service; + duplicate::EchoTestService::AsyncService dup_service; + SetUpServer(&service, &dup_service, nullptr, nullptr); + ResetStub(); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + std::thread echo_handler_thread( + HandleEcho<duplicate::EchoTestService::AsyncService>, &dup_service, + cqs_[2].get(), true); + TestAllMethods(); + SendEchoToDupService(); + response_stream_handler_thread.join(); + request_stream_handler_thread.join(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, GenericEcho) { + EchoTestService::WithGenericMethod_Echo<TestServiceImpl> service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + TestAllMethods(); + generic_handler_thread.join(); +} + +TEST_P(HybridEnd2endTest, CallbackGenericEcho) { + EchoTestService::WithGenericMethod_Echo<TestServiceImpl> service; + class GenericEchoService : public CallbackGenericService { + private: + ServerGenericBidiReactor* CreateReactor( + GenericCallbackServerContext* context) override { + EXPECT_EQ(context->method(), "/grpc.testing.EchoTestService/Echo"); + gpr_log(GPR_DEBUG, "Constructor of generic service %d", + static_cast<int>(context->deadline().time_since_epoch().count())); + + class Reactor : public ServerGenericBidiReactor { + public: + Reactor() { StartRead(&request_); } + + private: + void OnDone() override { delete this; } + void OnReadDone(bool ok) override { + if (!ok) { + EXPECT_EQ(reads_complete_, 1); + } else { + EXPECT_EQ(reads_complete_++, 0); + response_ = request_; + StartWrite(&response_); + StartRead(&request_); + } + } + void OnWriteDone(bool ok) override { + Finish(ok ? Status::OK + : Status(StatusCode::UNKNOWN, "Unexpected failure")); + } + ByteBuffer request_; + ByteBuffer response_; + std::atomic_int reads_complete_{0}; + }; + return new Reactor; + } + } generic_service; + + if (!SetUpServer(&service, nullptr, nullptr, &generic_service)) { + return; + } + ResetStub(); + TestAllMethods(); +} + +TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo<TestServiceImpl>> + SType; + SType service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one sync method. +TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream_SyncDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo<TestServiceImpl>> + SType; + SType service; + AsyncGenericService generic_service; + TestServiceImplDupPkg dup_service; + SetUpServer(&service, &dup_service, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + TestAllMethods(); + SendEchoToDupService(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); +} + +// Add a second service with one async method. +TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStream_AsyncDupService) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo<TestServiceImpl>> + SType; + SType service; + AsyncGenericService generic_service; + duplicate::EchoTestService::AsyncService dup_service; + SetUpServer(&service, &dup_service, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + std::thread echo_handler_thread( + HandleEcho<duplicate::EchoTestService::AsyncService>, &dup_service, + cqs_[2].get(), true); + TestAllMethods(); + SendEchoToDupService(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); + echo_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, GenericEchoAsyncRequestStreamResponseStream) { + typedef EchoTestService::WithAsyncMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>> + SType; + SType service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread request_stream_handler_thread(HandleClientStreaming<SType>, + &service, cqs_[1].get()); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[2].get()); + TestAllMethods(); + generic_handler_thread.join(); + request_stream_handler_thread.join(); + response_stream_handler_thread.join(); +} + +TEST_F(HybridEnd2endTest, GenericEchoRequestStreamAsyncResponseStream) { + typedef EchoTestService::WithGenericMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>> + SType; + SType service; + AsyncGenericService generic_service; + SetUpServer(&service, nullptr, &generic_service, nullptr); + ResetStub(); + std::thread generic_handler_thread(HandleGenericCall, &generic_service, + cqs_[0].get()); + std::thread generic_handler_thread2(HandleGenericCall, &generic_service, + cqs_[1].get()); + std::thread response_stream_handler_thread(HandleServerStreaming<SType>, + &service, cqs_[2].get()); + TestAllMethods(); + generic_handler_thread.join(); + generic_handler_thread2.join(); + response_stream_handler_thread.join(); +} + +// If WithGenericMethod is called and no generic service is registered, the +// server will fail to build. +TEST_F(HybridEnd2endTest, GenericMethodWithoutGenericService) { + EchoTestService::WithGenericMethod_RequestStream< + EchoTestService::WithGenericMethod_Echo< + EchoTestService::WithAsyncMethod_ResponseStream<TestServiceImpl>>> + service; + SetUpServer(&service, nullptr, nullptr, nullptr); + EXPECT_EQ(nullptr, server_.get()); +} + +INSTANTIATE_TEST_SUITE_P(HybridEnd2endTest, HybridEnd2endTest, + ::testing::Bool()); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/interceptors_util.cc b/contrib/libs/grpc/test/cpp/end2end/interceptors_util.cc new file mode 100644 index 0000000000..ff88953651 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/interceptors_util.cc @@ -0,0 +1,214 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/end2end/interceptors_util.h" +#include <util/string/cast.h> + +namespace grpc { +namespace testing { + +std::atomic<int> DummyInterceptor::num_times_run_; +std::atomic<int> DummyInterceptor::num_times_run_reverse_; +std::atomic<int> DummyInterceptor::num_times_cancel_; + +void MakeCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < kNumStreamingMessages; i++) { + writer->Write(req); + expected_resp += "Hello"; + } + writer->WritesDone(); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), expected_resp); +} + +void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + auto reader = stub->ResponseStream(&ctx, req); + int count = 0; + while (reader->Read(&resp)) { + EXPECT_EQ(resp.message(), "Hello"); + count++; + } + ASSERT_EQ(count, kNumStreamingMessages); + Status s = reader->Finish(); + EXPECT_EQ(s.ok(), true); +} + +void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + ctx.AddMetadata("testkey", "testvalue"); + req.mutable_param()->set_echo_metadata(true); + auto stream = stub->BidiStream(&ctx); + for (auto i = 0; i < kNumStreamingMessages; i++) { + req.set_message(TString("Hello") + ::ToString(i)); + stream->Write(req); + stream->Read(&resp); + EXPECT_EQ(req.message(), resp.message()); + } + ASSERT_TRUE(stream->WritesDone()); + Status s = stream->Finish(); + EXPECT_EQ(s.ok(), true); +} + +void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + CompletionQueue cq; + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, &cq)); + response_reader->Finish(&recv_response, &recv_status, tag(1)); + Verifier().Expect(1, true).Verify(&cq); + EXPECT_EQ(send_request.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +void MakeAsyncCQClientStreamingCall( + const std::shared_ptr<Channel>& /*channel*/) { + // TODO(yashykt) : Fill this out +} + +void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + CompletionQueue cq; + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + + cli_ctx.AddMetadata("testkey", "testvalue"); + send_request.set_message("Hello"); + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub->AsyncResponseStream(&cli_ctx, send_request, &cq, tag(1))); + Verifier().Expect(1, true).Verify(&cq); + // Read the expected number of messages + for (int i = 0; i < kNumStreamingMessages; i++) { + cli_stream->Read(&recv_response, tag(2)); + Verifier().Expect(2, true).Verify(&cq); + ASSERT_EQ(recv_response.message(), send_request.message()); + } + // The next read should fail + cli_stream->Read(&recv_response, tag(3)); + Verifier().Expect(3, false).Verify(&cq); + // Get the status + cli_stream->Finish(&recv_status, tag(4)); + Verifier().Expect(4, true).Verify(&cq); + EXPECT_TRUE(recv_status.ok()); +} + +void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& /*channel*/) { + // TODO(yashykt) : Fill this out +} + +void MakeCallbackCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + std::mutex mu; + std::condition_variable cv; + bool done = false; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + stub->experimental_async()->Echo(&ctx, &req, &resp, + [&resp, &mu, &done, &cv](Status s) { + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } +} + +bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first.starts_with(key) && pair.second.starts_with(value)) { + return true; + } + } + return false; +} + +bool CheckMetadata(const std::multimap<TString, TString>& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first == key.c_str() && pair.second == value.c_str()) { + return true; + } + } + return false; +} + +std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> +CreateDummyClientInterceptors() { + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + // Add 20 dummy interceptors before hijacking interceptor + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + return creators; +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/end2end/interceptors_util.h b/contrib/libs/grpc/test/cpp/end2end/interceptors_util.h new file mode 100644 index 0000000000..c95170bbbc --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/interceptors_util.h @@ -0,0 +1,317 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <condition_variable> + +#include <grpcpp/channel.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/string_ref_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +/* This interceptor does nothing. Just keeps a global count on the number of + * times it was invoked. */ +class DummyInterceptor : public experimental::Interceptor { + public: + DummyInterceptor() {} + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + num_times_run_++; + } else if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints:: + POST_RECV_INITIAL_METADATA)) { + num_times_run_reverse_++; + } else if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CANCEL)) { + num_times_cancel_++; + } + methods->Proceed(); + } + + static void Reset() { + num_times_run_.store(0); + num_times_run_reverse_.store(0); + num_times_cancel_.store(0); + } + + static int GetNumTimesRun() { + EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load()); + return num_times_run_.load(); + } + + static int GetNumTimesCancel() { return num_times_cancel_.load(); } + + private: + static std::atomic<int> num_times_run_; + static std::atomic<int> num_times_run_reverse_; + static std::atomic<int> num_times_cancel_; +}; + +class DummyInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface, + public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* /*info*/) override { + return new DummyInterceptor(); + } + + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* /*info*/) override { + return new DummyInterceptor(); + } +}; + +/* This interceptor factory returns nullptr on interceptor creation */ +class NullInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface, + public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* /*info*/) override { + return nullptr; + } + + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* /*info*/) override { + return nullptr; + } +}; + +class EchoTestServiceStreamingImpl : public EchoTestService::Service { + public: + ~EchoTestServiceStreamingImpl() override {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + response->set_message(request->message()); + return Status::OK; + } + + Status BidiStream( + ServerContext* context, + grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + EchoRequest req; + EchoResponse resp; + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + while (stream->Read(&req)) { + resp.set_message(req.message()); + EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions())); + } + return Status::OK; + } + + Status RequestStream(ServerContext* context, + ServerReader<EchoRequest>* reader, + EchoResponse* resp) override { + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + EchoRequest req; + string response_str = ""; + while (reader->Read(&req)) { + response_str += req.message(); + } + resp->set_message(response_str); + return Status::OK; + } + + Status ResponseStream(ServerContext* context, const EchoRequest* req, + ServerWriter<EchoResponse>* writer) override { + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + EchoResponse resp; + resp.set_message(req->message()); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(resp)); + } + return Status::OK; + } +}; + +constexpr int kNumStreamingMessages = 10; + +void MakeCall(const std::shared_ptr<Channel>& channel); + +void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel); + +void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel); + +void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel); + +void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel); + +void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel); + +void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel); + +void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel); + +void MakeCallbackCall(const std::shared_ptr<Channel>& channel); + +bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map, + const string& key, const string& value); + +bool CheckMetadata(const std::multimap<TString, TString>& map, + const string& key, const string& value); + +std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> +CreateDummyClientInterceptors(); + +inline void* tag(int i) { return (void*)static_cast<intptr_t>(i); } +inline int detag(void* p) { + return static_cast<int>(reinterpret_cast<intptr_t>(p)); +} + +class Verifier { + public: + Verifier() : lambda_run_(false) {} + // Expect sets the expected ok value for a specific tag + Verifier& Expect(int i, bool expect_ok) { + return ExpectUnless(i, expect_ok, false); + } + // ExpectUnless sets the expected ok value for a specific tag + // unless the tag was already marked seen (as a result of ExpectMaybe) + Verifier& ExpectUnless(int i, bool expect_ok, bool seen) { + if (!seen) { + expectations_[tag(i)] = expect_ok; + } + return *this; + } + // ExpectMaybe sets the expected ok value for a specific tag, but does not + // require it to appear + // If it does, sets *seen to true + Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) { + if (!*seen) { + maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen}; + } + return *this; + } + + // Next waits for 1 async tag to complete, checks its + // expectations, and returns the tag + int Next(CompletionQueue* cq, bool ignore_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + GotTag(got_tag, ok, ignore_ok); + return detag(got_tag); + } + + template <typename T> + CompletionQueue::NextStatus DoOnceThenAsyncNext( + CompletionQueue* cq, void** got_tag, bool* ok, T deadline, + std::function<void(void)> lambda) { + if (lambda_run_) { + return cq->AsyncNext(got_tag, ok, deadline); + } else { + lambda_run_ = true; + return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline); + } + } + + // Verify keeps calling Next until all currently set + // expected tags are complete + void Verify(CompletionQueue* cq) { Verify(cq, false); } + + // This version of Verify allows optionally ignoring the + // outcome of the expectation + void Verify(CompletionQueue* cq, bool ignore_ok) { + GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty()); + while (!expectations_.empty()) { + Next(cq, ignore_ok); + } + } + + // This version of Verify stops after a certain deadline, and uses the + // DoThenAsyncNext API + // to call the lambda + void Verify(CompletionQueue* cq, + std::chrono::system_clock::time_point deadline, + const std::function<void(void)>& lambda) { + if (expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::TIMEOUT); + } else { + while (!expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::GOT_EVENT); + GotTag(got_tag, ok, false); + } + } + } + + private: + void GotTag(void* got_tag, bool ok, bool ignore_ok) { + auto it = expectations_.find(got_tag); + if (it != expectations_.end()) { + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } + expectations_.erase(it); + } else { + auto it2 = maybe_expectations_.find(got_tag); + if (it2 != maybe_expectations_.end()) { + if (it2->second.seen != nullptr) { + EXPECT_FALSE(*it2->second.seen); + *it2->second.seen = true; + } + if (!ignore_ok) { + EXPECT_EQ(it2->second.ok, ok); + } + } else { + gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag); + abort(); + } + } + } + + struct MaybeExpect { + bool ok; + bool* seen; + }; + + std::map<void*, bool> expectations_; + std::map<void*, MaybeExpect> maybe_expectations_; + bool lambda_run_; +}; + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/end2end/message_allocator_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/message_allocator_end2end_test.cc new file mode 100644 index 0000000000..4bf755206e --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/message_allocator_end2end_test.cc @@ -0,0 +1,438 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <algorithm> +#include <atomic> +#include <condition_variable> +#include <functional> +#include <memory> +#include <mutex> +#include <sstream> +#include <thread> + +#include <google/protobuf/arena.h> + +#include <grpc/impl/codegen/log.h> +#include <gtest/gtest.h> + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/support/client_callback.h> +#include <grpcpp/support/message_allocator.h> + +#include "src/core/lib/iomgr/iomgr.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +// MAYBE_SKIP_TEST is a macro to determine if this particular test configuration +// should be skipped based on a decision made at SetUp time. In particular, any +// callback tests can only be run if the iomgr can run in the background or if +// the transport is in-process. +#define MAYBE_SKIP_TEST \ + do { \ + if (do_not_test_) { \ + return; \ + } \ + } while (0) + +namespace grpc { +namespace testing { +namespace { + +class CallbackTestServiceImpl + : public EchoTestService::ExperimentalCallbackService { + public: + explicit CallbackTestServiceImpl() {} + + void SetAllocatorMutator( + std::function<void(experimental::RpcAllocatorState* allocator_state, + const EchoRequest* req, EchoResponse* resp)> + mutator) { + allocator_mutator_ = mutator; + } + + experimental::ServerUnaryReactor* Echo( + experimental::CallbackServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + response->set_message(request->message()); + if (allocator_mutator_) { + allocator_mutator_(context->GetRpcAllocatorState(), request, response); + } + auto* reactor = context->DefaultReactor(); + reactor->Finish(Status::OK); + return reactor; + } + + private: + std::function<void(experimental::RpcAllocatorState* allocator_state, + const EchoRequest* req, EchoResponse* resp)> + allocator_mutator_; +}; + +enum class Protocol { INPROC, TCP }; + +class TestScenario { + public: + TestScenario(Protocol protocol, const TString& creds_type) + : protocol(protocol), credentials_type(creds_type) {} + void Log() const; + Protocol protocol; + const TString credentials_type; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{protocol=" + << (scenario.protocol == Protocol::INPROC ? "INPROC" : "TCP") + << "," << scenario.credentials_type << "}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_INFO, "%s", out.str().c_str()); +} + +class MessageAllocatorEnd2endTestBase + : public ::testing::TestWithParam<TestScenario> { + protected: + MessageAllocatorEnd2endTestBase() { + GetParam().Log(); + if (GetParam().protocol == Protocol::TCP) { + if (!grpc_iomgr_run_in_background()) { + do_not_test_ = true; + return; + } + } + } + + ~MessageAllocatorEnd2endTestBase() = default; + + void CreateServer( + experimental::MessageAllocator<EchoRequest, EchoResponse>* allocator) { + ServerBuilder builder; + + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + if (GetParam().protocol == Protocol::TCP) { + picked_port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << picked_port_; + builder.AddListeningPort(server_address_.str(), server_creds); + } + callback_service_.SetMessageAllocatorFor_Echo(allocator); + builder.RegisterService(&callback_service_); + + server_ = builder.BuildAndStart(); + } + + void DestroyServer() { + if (server_) { + server_->Shutdown(); + server_.reset(); + } + } + + void ResetStub() { + ChannelArguments args; + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + switch (GetParam().protocol) { + case Protocol::TCP: + channel_ = ::grpc::CreateCustomChannel(server_address_.str(), + channel_creds, args); + break; + case Protocol::INPROC: + channel_ = server_->InProcessChannel(args); + break; + default: + assert(false); + } + stub_ = EchoTestService::NewStub(channel_); + } + + void TearDown() override { + DestroyServer(); + if (picked_port_ > 0) { + grpc_recycle_unused_port(picked_port_); + } + } + + void SendRpcs(int num_rpcs) { + TString test_string(""); + for (int i = 0; i < num_rpcs; i++) { + EchoRequest request; + EchoResponse response; + ClientContext cli_ctx; + + test_string += TString(1024, 'x'); + request.set_message(test_string); + TString val; + cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); + + std::mutex mu; + std::condition_variable cv; + bool done = false; + stub_->experimental_async()->Echo( + &cli_ctx, &request, &response, + [&request, &response, &done, &mu, &cv, val](Status s) { + GPR_ASSERT(s.ok()); + + EXPECT_EQ(request.message(), response.message()); + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } + } + } + + bool do_not_test_{false}; + int picked_port_{0}; + std::shared_ptr<Channel> channel_; + std::unique_ptr<EchoTestService::Stub> stub_; + CallbackTestServiceImpl callback_service_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; +}; + +class NullAllocatorTest : public MessageAllocatorEnd2endTestBase {}; + +TEST_P(NullAllocatorTest, SimpleRpc) { + MAYBE_SKIP_TEST; + CreateServer(nullptr); + ResetStub(); + SendRpcs(1); +} + +class SimpleAllocatorTest : public MessageAllocatorEnd2endTestBase { + public: + class SimpleAllocator + : public experimental::MessageAllocator<EchoRequest, EchoResponse> { + public: + class MessageHolderImpl + : public experimental::MessageHolder<EchoRequest, EchoResponse> { + public: + MessageHolderImpl(std::atomic_int* request_deallocation_count, + std::atomic_int* messages_deallocation_count) + : request_deallocation_count_(request_deallocation_count), + messages_deallocation_count_(messages_deallocation_count) { + set_request(new EchoRequest); + set_response(new EchoResponse); + } + void Release() override { + (*messages_deallocation_count_)++; + delete request(); + delete response(); + delete this; + } + void FreeRequest() override { + (*request_deallocation_count_)++; + delete request(); + set_request(nullptr); + } + + EchoRequest* ReleaseRequest() { + auto* ret = request(); + set_request(nullptr); + return ret; + } + + private: + std::atomic_int* const request_deallocation_count_; + std::atomic_int* const messages_deallocation_count_; + }; + experimental::MessageHolder<EchoRequest, EchoResponse>* AllocateMessages() + override { + allocation_count++; + return new MessageHolderImpl(&request_deallocation_count, + &messages_deallocation_count); + } + int allocation_count = 0; + std::atomic_int request_deallocation_count{0}; + std::atomic_int messages_deallocation_count{0}; + }; +}; + +TEST_P(SimpleAllocatorTest, SimpleRpc) { + MAYBE_SKIP_TEST; + const int kRpcCount = 10; + std::unique_ptr<SimpleAllocator> allocator(new SimpleAllocator); + CreateServer(allocator.get()); + ResetStub(); + SendRpcs(kRpcCount); + // messages_deallocaton_count is updated in Release after server side OnDone. + // Destroy server to make sure it has been updated. + DestroyServer(); + EXPECT_EQ(kRpcCount, allocator->allocation_count); + EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count); + EXPECT_EQ(0, allocator->request_deallocation_count); +} + +TEST_P(SimpleAllocatorTest, RpcWithEarlyFreeRequest) { + MAYBE_SKIP_TEST; + const int kRpcCount = 10; + std::unique_ptr<SimpleAllocator> allocator(new SimpleAllocator); + auto mutator = [](experimental::RpcAllocatorState* allocator_state, + const EchoRequest* req, EchoResponse* resp) { + auto* info = + static_cast<SimpleAllocator::MessageHolderImpl*>(allocator_state); + EXPECT_EQ(req, info->request()); + EXPECT_EQ(resp, info->response()); + allocator_state->FreeRequest(); + EXPECT_EQ(nullptr, info->request()); + }; + callback_service_.SetAllocatorMutator(mutator); + CreateServer(allocator.get()); + ResetStub(); + SendRpcs(kRpcCount); + // messages_deallocaton_count is updated in Release after server side OnDone. + // Destroy server to make sure it has been updated. + DestroyServer(); + EXPECT_EQ(kRpcCount, allocator->allocation_count); + EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count); + EXPECT_EQ(kRpcCount, allocator->request_deallocation_count); +} + +TEST_P(SimpleAllocatorTest, RpcWithReleaseRequest) { + MAYBE_SKIP_TEST; + const int kRpcCount = 10; + std::unique_ptr<SimpleAllocator> allocator(new SimpleAllocator); + std::vector<EchoRequest*> released_requests; + auto mutator = [&released_requests]( + experimental::RpcAllocatorState* allocator_state, + const EchoRequest* req, EchoResponse* resp) { + auto* info = + static_cast<SimpleAllocator::MessageHolderImpl*>(allocator_state); + EXPECT_EQ(req, info->request()); + EXPECT_EQ(resp, info->response()); + released_requests.push_back(info->ReleaseRequest()); + EXPECT_EQ(nullptr, info->request()); + }; + callback_service_.SetAllocatorMutator(mutator); + CreateServer(allocator.get()); + ResetStub(); + SendRpcs(kRpcCount); + // messages_deallocaton_count is updated in Release after server side OnDone. + // Destroy server to make sure it has been updated. + DestroyServer(); + EXPECT_EQ(kRpcCount, allocator->allocation_count); + EXPECT_EQ(kRpcCount, allocator->messages_deallocation_count); + EXPECT_EQ(0, allocator->request_deallocation_count); + EXPECT_EQ(static_cast<unsigned>(kRpcCount), released_requests.size()); + for (auto* req : released_requests) { + delete req; + } +} + +class ArenaAllocatorTest : public MessageAllocatorEnd2endTestBase { + public: + class ArenaAllocator + : public experimental::MessageAllocator<EchoRequest, EchoResponse> { + public: + class MessageHolderImpl + : public experimental::MessageHolder<EchoRequest, EchoResponse> { + public: + MessageHolderImpl() { + set_request( + google::protobuf::Arena::CreateMessage<EchoRequest>(&arena_)); + set_response( + google::protobuf::Arena::CreateMessage<EchoResponse>(&arena_)); + } + void Release() override { delete this; } + void FreeRequest() override { GPR_ASSERT(0); } + + private: + google::protobuf::Arena arena_; + }; + experimental::MessageHolder<EchoRequest, EchoResponse>* AllocateMessages() + override { + allocation_count++; + return new MessageHolderImpl; + } + int allocation_count = 0; + }; +}; + +TEST_P(ArenaAllocatorTest, SimpleRpc) { + MAYBE_SKIP_TEST; + const int kRpcCount = 10; + std::unique_ptr<ArenaAllocator> allocator(new ArenaAllocator); + CreateServer(allocator.get()); + ResetStub(); + SendRpcs(kRpcCount); + EXPECT_EQ(kRpcCount, allocator->allocation_count); +} + +std::vector<TestScenario> CreateTestScenarios(bool test_insecure) { + std::vector<TestScenario> scenarios; + std::vector<TString> credentials_types{ + GetCredentialsProvider()->GetSecureCredentialsTypeList()}; + auto insec_ok = [] { + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. + return GetCredentialsProvider()->GetChannelCredentials( + kInsecureCredentialsType, nullptr) != nullptr; + }; + if (test_insecure && insec_ok()) { + credentials_types.push_back(kInsecureCredentialsType); + } + GPR_ASSERT(!credentials_types.empty()); + + Protocol parr[]{Protocol::INPROC, Protocol::TCP}; + for (Protocol p : parr) { + for (const auto& cred : credentials_types) { + // TODO(vjpai): Test inproc with secure credentials when feasible + if (p == Protocol::INPROC && + (cred != kInsecureCredentialsType || !insec_ok())) { + continue; + } + scenarios.emplace_back(p, cred); + } + } + return scenarios; +} + +INSTANTIATE_TEST_SUITE_P(NullAllocatorTest, NullAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); +INSTANTIATE_TEST_SUITE_P(SimpleAllocatorTest, SimpleAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); +INSTANTIATE_TEST_SUITE_P(ArenaAllocatorTest, ArenaAllocatorTest, + ::testing::ValuesIn(CreateTestScenarios(true))); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + // The grpc_init is to cover the MAYBE_SKIP_TEST. + grpc_init(); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/mock_test.cc b/contrib/libs/grpc/test/cpp/end2end/mock_test.cc new file mode 100644 index 0000000000..a3d61c4e98 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/mock_test.cc @@ -0,0 +1,434 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <climits> + +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/test/default_reactor_test_peer.h> + +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "src/proto/grpc/testing/echo_mock.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#include <grpcpp/test/mock_stream.h> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <iostream> + +using grpc::testing::DefaultReactorTestPeer; +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using grpc::testing::EchoTestService; +using grpc::testing::MockClientReaderWriter; +using std::vector; +using std::chrono::system_clock; +using ::testing::_; +using ::testing::AtLeast; +using ::testing::DoAll; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::SaveArg; +using ::testing::SetArgPointee; +using ::testing::WithArg; + +namespace grpc { +namespace testing { + +namespace { +class FakeClient { + public: + explicit FakeClient(EchoTestService::StubInterface* stub) : stub_(stub) {} + + void DoEcho() { + ClientContext context; + EchoRequest request; + EchoResponse response; + request.set_message("hello world"); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(request.message(), response.message()); + EXPECT_TRUE(s.ok()); + } + + void DoRequestStream() { + EchoRequest request; + EchoResponse response; + + ClientContext context; + TString msg("hello"); + TString exp(msg); + + std::unique_ptr<ClientWriterInterface<EchoRequest>> cstream = + stub_->RequestStream(&context, &response); + + request.set_message(msg); + EXPECT_TRUE(cstream->Write(request)); + + msg = ", world"; + request.set_message(msg); + exp.append(msg); + EXPECT_TRUE(cstream->Write(request)); + + cstream->WritesDone(); + Status s = cstream->Finish(); + + EXPECT_EQ(exp, response.message()); + EXPECT_TRUE(s.ok()); + } + + void DoResponseStream() { + EchoRequest request; + EchoResponse response; + request.set_message("hello world"); + + ClientContext context; + std::unique_ptr<ClientReaderInterface<EchoResponse>> cstream = + stub_->ResponseStream(&context, request); + + TString exp = ""; + EXPECT_TRUE(cstream->Read(&response)); + exp.append(response.message() + " "); + + EXPECT_TRUE(cstream->Read(&response)); + exp.append(response.message()); + + EXPECT_FALSE(cstream->Read(&response)); + EXPECT_EQ(request.message(), exp); + + Status s = cstream->Finish(); + EXPECT_TRUE(s.ok()); + } + + void DoBidiStream() { + EchoRequest request; + EchoResponse response; + ClientContext context; + TString msg("hello"); + + std::unique_ptr<ClientReaderWriterInterface<EchoRequest, EchoResponse>> + stream = stub_->BidiStream(&context); + + request.set_message(msg + "0"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "1"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + request.set_message(msg + "2"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(response.message(), request.message()); + + stream->WritesDone(); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + EXPECT_TRUE(s.ok()); + } + + void ResetStub(EchoTestService::StubInterface* stub) { stub_ = stub; } + + private: + EchoTestService::StubInterface* stub_; +}; + +class CallbackTestServiceImpl + : public EchoTestService::ExperimentalCallbackService { + public: + experimental::ServerUnaryReactor* Echo( + experimental::CallbackServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + // Make the mock service explicitly treat empty input messages as invalid + // arguments so that we can test various results of status. In general, a + // mocked service should just use the original service methods, but we are + // adding this variance in Status return value just to improve coverage in + // this test. + auto* reactor = context->DefaultReactor(); + if (request->message().length() > 0) { + response->set_message(request->message()); + reactor->Finish(Status::OK); + } else { + reactor->Finish(Status(StatusCode::INVALID_ARGUMENT, "Invalid request")); + } + return reactor; + } +}; + +class MockCallbackTest : public ::testing::Test { + protected: + CallbackTestServiceImpl service_; + ServerContext context_; +}; + +TEST_F(MockCallbackTest, MockedCallSucceedsWithWait) { + experimental::CallbackServerContext ctx; + EchoRequest req; + EchoResponse resp; + grpc::internal::Mutex mu; + grpc::internal::CondVar cv; + grpc::Status status; + bool status_set = false; + DefaultReactorTestPeer peer(&ctx, [&](::grpc::Status s) { + grpc::internal::MutexLock l(&mu); + status_set = true; + status = std::move(s); + cv.Signal(); + }); + + req.set_message("mock 1"); + auto* reactor = service_.Echo(&ctx, &req, &resp); + cv.WaitUntil(&mu, [&] { + grpc::internal::MutexLock l(&mu); + return status_set; + }); + EXPECT_EQ(reactor, peer.reactor()); + EXPECT_TRUE(peer.test_status_set()); + EXPECT_TRUE(peer.test_status().ok()); + EXPECT_TRUE(status_set); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(req.message(), resp.message()); +} + +TEST_F(MockCallbackTest, MockedCallSucceeds) { + experimental::CallbackServerContext ctx; + EchoRequest req; + EchoResponse resp; + DefaultReactorTestPeer peer(&ctx); + + req.set_message("ha ha, consider yourself mocked."); + auto* reactor = service_.Echo(&ctx, &req, &resp); + EXPECT_EQ(reactor, peer.reactor()); + EXPECT_TRUE(peer.test_status_set()); + EXPECT_TRUE(peer.test_status().ok()); +} + +TEST_F(MockCallbackTest, MockedCallFails) { + experimental::CallbackServerContext ctx; + EchoRequest req; + EchoResponse resp; + DefaultReactorTestPeer peer(&ctx); + + auto* reactor = service_.Echo(&ctx, &req, &resp); + EXPECT_EQ(reactor, peer.reactor()); + EXPECT_TRUE(peer.test_status_set()); + EXPECT_EQ(peer.test_status().error_code(), StatusCode::INVALID_ARGUMENT); +} + +class TestServiceImpl : public EchoTestService::Service { + public: + Status Echo(ServerContext* /*context*/, const EchoRequest* request, + EchoResponse* response) override { + response->set_message(request->message()); + return Status::OK; + } + + Status RequestStream(ServerContext* /*context*/, + ServerReader<EchoRequest>* reader, + EchoResponse* response) override { + EchoRequest request; + TString resp(""); + while (reader->Read(&request)) { + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + resp.append(request.message()); + } + response->set_message(resp); + return Status::OK; + } + + Status ResponseStream(ServerContext* /*context*/, const EchoRequest* request, + ServerWriter<EchoResponse>* writer) override { + EchoResponse response; + vector<TString> tokens = split(request->message()); + for (const TString& token : tokens) { + response.set_message(token); + writer->Write(response); + } + return Status::OK; + } + + Status BidiStream( + ServerContext* /*context*/, + ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + EchoRequest request; + EchoResponse response; + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + response.set_message(request.message()); + stream->Write(response); + } + return Status::OK; + } + + private: + const vector<TString> split(const TString& input) { + TString buff(""); + vector<TString> result; + + for (auto n : input) { + if (n != ' ') { + buff += n; + continue; + } + if (buff == "") continue; + result.push_back(buff); + buff = ""; + } + if (buff != "") result.push_back(buff); + + return result; + } +}; + +class MockTest : public ::testing::Test { + protected: + MockTest() {} + + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { server_->Shutdown(); } + + void ResetStub() { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; + TestServiceImpl service_; +}; + +// Do one real rpc and one mocked one +TEST_F(MockTest, SimpleRpc) { + ResetStub(); + FakeClient client(stub_.get()); + client.DoEcho(); + MockEchoTestServiceStub stub; + EchoResponse resp; + resp.set_message("hello world"); + EXPECT_CALL(stub, Echo(_, _, _)) + .Times(AtLeast(1)) + .WillOnce(DoAll(SetArgPointee<2>(resp), Return(Status::OK))); + client.ResetStub(&stub); + client.DoEcho(); +} + +TEST_F(MockTest, ClientStream) { + ResetStub(); + FakeClient client(stub_.get()); + client.DoRequestStream(); + + MockEchoTestServiceStub stub; + auto w = new MockClientWriter<EchoRequest>(); + EchoResponse resp; + resp.set_message("hello, world"); + + EXPECT_CALL(*w, Write(_, _)).Times(2).WillRepeatedly(Return(true)); + EXPECT_CALL(*w, WritesDone()); + EXPECT_CALL(*w, Finish()).WillOnce(Return(Status::OK)); + + EXPECT_CALL(stub, RequestStreamRaw(_, _)) + .WillOnce(DoAll(SetArgPointee<1>(resp), Return(w))); + client.ResetStub(&stub); + client.DoRequestStream(); +} + +TEST_F(MockTest, ServerStream) { + ResetStub(); + FakeClient client(stub_.get()); + client.DoResponseStream(); + + MockEchoTestServiceStub stub; + auto r = new MockClientReader<EchoResponse>(); + EchoResponse resp1; + resp1.set_message("hello"); + EchoResponse resp2; + resp2.set_message("world"); + + EXPECT_CALL(*r, Read(_)) + .WillOnce(DoAll(SetArgPointee<0>(resp1), Return(true))) + .WillOnce(DoAll(SetArgPointee<0>(resp2), Return(true))) + .WillOnce(Return(false)); + EXPECT_CALL(*r, Finish()).WillOnce(Return(Status::OK)); + + EXPECT_CALL(stub, ResponseStreamRaw(_, _)).WillOnce(Return(r)); + + client.ResetStub(&stub); + client.DoResponseStream(); +} + +ACTION_P(copy, msg) { arg0->set_message(msg->message()); } + +TEST_F(MockTest, BidiStream) { + ResetStub(); + FakeClient client(stub_.get()); + client.DoBidiStream(); + MockEchoTestServiceStub stub; + auto rw = new MockClientReaderWriter<EchoRequest, EchoResponse>(); + EchoRequest msg; + + EXPECT_CALL(*rw, Write(_, _)) + .Times(3) + .WillRepeatedly(DoAll(SaveArg<0>(&msg), Return(true))); + EXPECT_CALL(*rw, Read(_)) + .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true))) + .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true))) + .WillOnce(DoAll(WithArg<0>(copy(&msg)), Return(true))) + .WillOnce(Return(false)); + EXPECT_CALL(*rw, WritesDone()); + EXPECT_CALL(*rw, Finish()).WillOnce(Return(Status::OK)); + + EXPECT_CALL(stub, BidiStreamRaw(_)).WillOnce(Return(rw)); + client.ResetStub(&stub); + client.DoBidiStream(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/nonblocking_test.cc b/contrib/libs/grpc/test/cpp/end2end/nonblocking_test.cc new file mode 100644 index 0000000000..4be070ec71 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/nonblocking_test.cc @@ -0,0 +1,214 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/core/lib/gpr/tls.h" +#include "src/core/lib/iomgr/port.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#ifdef GRPC_POSIX_SOCKET +#include "src/core/lib/iomgr/ev_posix.h" +#endif // GRPC_POSIX_SOCKET + +#include <gtest/gtest.h> + +#ifdef GRPC_POSIX_SOCKET +// Thread-local variable to so that only polls from this test assert +// non-blocking (not polls from resolver, timer thread, etc), and only when the +// thread is waiting on polls caused by CompletionQueue::AsyncNext (not for +// picking a port or other reasons). +GPR_TLS_DECL(g_is_nonblocking_poll); + +namespace { + +int maybe_assert_non_blocking_poll(struct pollfd* pfds, nfds_t nfds, + int timeout) { + // Only assert that this poll should have zero timeout if we're in the + // middle of a zero-timeout CQ Next. + if (gpr_tls_get(&g_is_nonblocking_poll)) { + GPR_ASSERT(timeout == 0); + } + return poll(pfds, nfds, timeout); +} + +} // namespace + +namespace grpc { +namespace testing { +namespace { + +void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); } +int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); } + +class NonblockingTest : public ::testing::Test { + protected: + NonblockingTest() {} + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port_; + + // Setup server + BuildAndStartServer(); + } + + bool LoopForTag(void** tag, bool* ok) { + // Temporarily set the thread-local nonblocking poll flag so that the polls + // caused by this loop are indeed sent by the library with zero timeout. + intptr_t orig_val = gpr_tls_get(&g_is_nonblocking_poll); + gpr_tls_set(&g_is_nonblocking_poll, static_cast<intptr_t>(true)); + for (;;) { + auto r = cq_->AsyncNext(tag, ok, gpr_time_0(GPR_CLOCK_REALTIME)); + if (r == CompletionQueue::SHUTDOWN) { + gpr_tls_set(&g_is_nonblocking_poll, orig_val); + return false; + } else if (r == CompletionQueue::GOT_EVENT) { + gpr_tls_set(&g_is_nonblocking_poll, orig_val); + return true; + } + } + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (LoopForTag(&ignored_tag, &ignored_ok)) + ; + stub_.reset(); + grpc_recycle_unused_port(port_); + } + + void BuildAndStartServer() { + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + grpc::InsecureServerCredentials()); + service_.reset(new grpc::testing::EchoTestService::AsyncService()); + builder.RegisterService(service_.get()); + cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + } + + void ResetStub() { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), grpc::InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + void SendRpc(int num_rpcs) { + for (int i = 0; i < num_rpcs; i++) { + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message("hello non-blocking world"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->PrepareAsyncEcho(&cli_ctx, send_request, cq_.get())); + + response_reader->StartCall(); + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + service_->RequestEcho(&srv_ctx, &recv_request, &response_writer, + cq_.get(), cq_.get(), tag(2)); + + void* got_tag; + bool ok; + EXPECT_TRUE(LoopForTag(&got_tag, &ok)); + EXPECT_TRUE(ok); + EXPECT_EQ(detag(got_tag), 2); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + + int tagsum = 0; + int tagprod = 1; + EXPECT_TRUE(LoopForTag(&got_tag, &ok)); + EXPECT_TRUE(ok); + tagsum += detag(got_tag); + tagprod *= detag(got_tag); + + EXPECT_TRUE(LoopForTag(&got_tag, &ok)); + EXPECT_TRUE(ok); + tagsum += detag(got_tag); + tagprod *= detag(got_tag); + + EXPECT_EQ(tagsum, 7); + EXPECT_EQ(tagprod, 12); + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + } + } + + std::unique_ptr<ServerCompletionQueue> cq_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::unique_ptr<grpc::testing::EchoTestService::AsyncService> service_; + std::ostringstream server_address_; + int port_; +}; + +TEST_F(NonblockingTest, SimpleRpc) { + ResetStub(); + SendRpc(10); +} + +} // namespace +} // namespace testing +} // namespace grpc + +#endif // GRPC_POSIX_SOCKET + +int main(int argc, char** argv) { +#ifdef GRPC_POSIX_SOCKET + // Override the poll function before anything else can happen + grpc_poll_function = maybe_assert_non_blocking_poll; + + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + gpr_tls_init(&g_is_nonblocking_poll); + + // Start the nonblocking poll thread-local variable as false because the + // thread that issues RPCs starts by picking a port (which has non-zero + // timeout). + gpr_tls_set(&g_is_nonblocking_poll, static_cast<intptr_t>(false)); + + int ret = RUN_ALL_TESTS(); + gpr_tls_destroy(&g_is_nonblocking_poll); + return ret; +#else // GRPC_POSIX_SOCKET + return 0; +#endif // GRPC_POSIX_SOCKET +} diff --git a/contrib/libs/grpc/test/cpp/end2end/port_sharing_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/port_sharing_end2end_test.cc new file mode 100644 index 0000000000..b69d1dd2be --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/port_sharing_end2end_test.cc @@ -0,0 +1,374 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/security/server_credentials.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <gtest/gtest.h> + +#include <mutex> +#include <thread> + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/endpoint.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/port.h" +#include "src/core/lib/iomgr/tcp_server.h" +#include "src/core/lib/security/credentials/credentials.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/core/util/test_tcp_server.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/test_credentials_provider.h" + +#ifdef GRPC_POSIX_SOCKET_TCP_SERVER + +#include "src/core/lib/iomgr/tcp_posix.h" + +namespace grpc { +namespace testing { +namespace { + +class TestScenario { + public: + TestScenario(bool server_port, bool pending_data, + const TString& creds_type) + : server_has_port(server_port), + queue_pending_data(pending_data), + credentials_type(creds_type) {} + void Log() const; + // server has its own port or not + bool server_has_port; + // whether tcp server should read some data before handoff + bool queue_pending_data; + const TString credentials_type; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{server_has_port=" + << (scenario.server_has_port ? "true" : "false") + << ", queue_pending_data=" + << (scenario.queue_pending_data ? "true" : "false") + << ", credentials='" << scenario.credentials_type << "'}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_ERROR, "%s", out.str().c_str()); +} + +// Set up a test tcp server which is in charge of accepting connections and +// handing off the connections as fds. +class TestTcpServer { + public: + TestTcpServer() + : shutdown_(false), + queue_data_(false), + port_(grpc_pick_unused_port_or_die()) { + std::ostringstream server_address; + server_address << "localhost:" << port_; + address_ = server_address.str(); + test_tcp_server_init(&tcp_server_, &TestTcpServer::OnConnect, this); + GRPC_CLOSURE_INIT(&on_fd_released_, &TestTcpServer::OnFdReleased, this, + grpc_schedule_on_exec_ctx); + } + + ~TestTcpServer() { + running_thread_.join(); + test_tcp_server_destroy(&tcp_server_); + grpc_recycle_unused_port(port_); + } + + // Read some data before handing off the connection. + void SetQueueData() { queue_data_ = true; } + + void Start() { + test_tcp_server_start(&tcp_server_, port_); + gpr_log(GPR_INFO, "Test TCP server started at %s", address_.c_str()); + } + + const TString& address() { return address_; } + + void SetAcceptor( + std::unique_ptr<experimental::ExternalConnectionAcceptor> acceptor) { + connection_acceptor_ = std::move(acceptor); + } + + void Run() { + running_thread_ = std::thread([this]() { + while (true) { + { + std::lock_guard<std::mutex> lock(mu_); + if (shutdown_) { + return; + } + } + test_tcp_server_poll(&tcp_server_, 1); + } + }); + } + + void Shutdown() { + std::lock_guard<std::mutex> lock(mu_); + shutdown_ = true; + } + + static void OnConnect(void* arg, grpc_endpoint* tcp, + grpc_pollset* accepting_pollset, + grpc_tcp_server_acceptor* acceptor) { + auto* self = static_cast<TestTcpServer*>(arg); + self->OnConnect(tcp, accepting_pollset, acceptor); + } + + static void OnFdReleased(void* arg, grpc_error* err) { + auto* self = static_cast<TestTcpServer*>(arg); + self->OnFdReleased(err); + } + + private: + void OnConnect(grpc_endpoint* tcp, grpc_pollset* /*accepting_pollset*/, + grpc_tcp_server_acceptor* acceptor) { + TString peer(grpc_endpoint_get_peer(tcp)); + gpr_log(GPR_INFO, "Got incoming connection! from %s", peer.c_str()); + EXPECT_FALSE(acceptor->external_connection); + listener_fd_ = grpc_tcp_server_port_fd( + acceptor->from_server, acceptor->port_index, acceptor->fd_index); + gpr_free(acceptor); + grpc_tcp_destroy_and_release_fd(tcp, &fd_, &on_fd_released_); + } + + void OnFdReleased(grpc_error* err) { + EXPECT_EQ(GRPC_ERROR_NONE, err); + experimental::ExternalConnectionAcceptor::NewConnectionParameters p; + p.listener_fd = listener_fd_; + p.fd = fd_; + if (queue_data_) { + char buf[1024]; + ssize_t read_bytes = 0; + while (read_bytes <= 0) { + read_bytes = read(fd_, buf, 1024); + } + Slice data(buf, read_bytes); + p.read_buffer = ByteBuffer(&data, 1); + } + gpr_log(GPR_INFO, "Handing off fd %d with data size %d from listener fd %d", + fd_, static_cast<int>(p.read_buffer.Length()), listener_fd_); + connection_acceptor_->HandleNewConnection(&p); + } + + std::mutex mu_; + bool shutdown_; + + int listener_fd_ = -1; + int fd_ = -1; + bool queue_data_ = false; + + grpc_closure on_fd_released_; + std::thread running_thread_; + int port_ = -1; + TString address_; + std::unique_ptr<experimental::ExternalConnectionAcceptor> + connection_acceptor_; + test_tcp_server tcp_server_; +}; + +class PortSharingEnd2endTest : public ::testing::TestWithParam<TestScenario> { + protected: + PortSharingEnd2endTest() : is_server_started_(false), first_picked_port_(0) { + GetParam().Log(); + } + + void SetUp() override { + if (GetParam().queue_pending_data) { + tcp_server1_.SetQueueData(); + tcp_server2_.SetQueueData(); + } + tcp_server1_.Start(); + tcp_server2_.Start(); + ServerBuilder builder; + if (GetParam().server_has_port) { + int port = grpc_pick_unused_port_or_die(); + first_picked_port_ = port; + server_address_ << "localhost:" << port; + auto creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + builder.AddListeningPort(server_address_.str(), creds); + gpr_log(GPR_INFO, "gRPC server listening on %s", + server_address_.str().c_str()); + } + auto server_creds = GetCredentialsProvider()->GetServerCredentials( + GetParam().credentials_type); + auto acceptor1 = builder.experimental().AddExternalConnectionAcceptor( + ServerBuilder::experimental_type::ExternalConnectionType::FROM_FD, + server_creds); + tcp_server1_.SetAcceptor(std::move(acceptor1)); + auto acceptor2 = builder.experimental().AddExternalConnectionAcceptor( + ServerBuilder::experimental_type::ExternalConnectionType::FROM_FD, + server_creds); + tcp_server2_.SetAcceptor(std::move(acceptor2)); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + is_server_started_ = true; + + tcp_server1_.Run(); + tcp_server2_.Run(); + } + + void TearDown() override { + tcp_server1_.Shutdown(); + tcp_server2_.Shutdown(); + if (is_server_started_) { + server_->Shutdown(); + } + if (first_picked_port_ > 0) { + grpc_recycle_unused_port(first_picked_port_); + } + } + + void ResetStubs() { + EXPECT_TRUE(is_server_started_); + ChannelArguments args; + args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1); + auto channel_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &args); + channel_handoff1_ = + CreateCustomChannel(tcp_server1_.address(), channel_creds, args); + stub_handoff1_ = EchoTestService::NewStub(channel_handoff1_); + channel_handoff2_ = + CreateCustomChannel(tcp_server2_.address(), channel_creds, args); + stub_handoff2_ = EchoTestService::NewStub(channel_handoff2_); + if (GetParam().server_has_port) { + ChannelArguments direct_args; + direct_args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, 1); + auto direct_creds = GetCredentialsProvider()->GetChannelCredentials( + GetParam().credentials_type, &direct_args); + channel_direct_ = + CreateCustomChannel(server_address_.str(), direct_creds, direct_args); + stub_direct_ = EchoTestService::NewStub(channel_direct_); + } + } + + bool is_server_started_; + // channel/stub to the test tcp server, the connection will be handed to the + // grpc server. + std::shared_ptr<Channel> channel_handoff1_; + std::unique_ptr<EchoTestService::Stub> stub_handoff1_; + std::shared_ptr<Channel> channel_handoff2_; + std::unique_ptr<EchoTestService::Stub> stub_handoff2_; + // channel/stub to talk to the grpc server directly, if applicable. + std::shared_ptr<Channel> channel_direct_; + std::unique_ptr<EchoTestService::Stub> stub_direct_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; + TestServiceImpl service_; + TestTcpServer tcp_server1_; + TestTcpServer tcp_server2_; + int first_picked_port_; +}; + +static void SendRpc(EchoTestService::Stub* stub, int num_rpcs) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + + for (int i = 0; i < num_rpcs; ++i) { + ClientContext context; + Status s = stub->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + } +} + +std::vector<TestScenario> CreateTestScenarios() { + std::vector<TestScenario> scenarios; + std::vector<TString> credentials_types; + +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + + credentials_types = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + // Only allow insecure credentials type when it is registered with the + // provider. User may create providers that do not have insecure. + if (GetCredentialsProvider()->GetChannelCredentials(kInsecureCredentialsType, + nullptr) != nullptr) { + credentials_types.push_back(kInsecureCredentialsType); + } + + GPR_ASSERT(!credentials_types.empty()); + for (const auto& cred : credentials_types) { + for (auto server_has_port : {true, false}) { + for (auto queue_pending_data : {true, false}) { + scenarios.emplace_back(server_has_port, queue_pending_data, cred); + } + } + } + return scenarios; +} + +TEST_P(PortSharingEnd2endTest, HandoffAndDirectCalls) { + ResetStubs(); + SendRpc(stub_handoff1_.get(), 5); + if (GetParam().server_has_port) { + SendRpc(stub_direct_.get(), 5); + } +} + +TEST_P(PortSharingEnd2endTest, MultipleHandoff) { + for (int i = 0; i < 3; i++) { + ResetStubs(); + SendRpc(stub_handoff2_.get(), 1); + } +} + +TEST_P(PortSharingEnd2endTest, TwoHandoffPorts) { + for (int i = 0; i < 3; i++) { + ResetStubs(); + SendRpc(stub_handoff1_.get(), 5); + SendRpc(stub_handoff2_.get(), 5); + } +} + +INSTANTIATE_TEST_SUITE_P(PortSharingEnd2end, PortSharingEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios())); + +} // namespace +} // namespace testing +} // namespace grpc + +#endif // GRPC_POSIX_SOCKET_TCP_SERVER + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/proto_server_reflection_test.cc b/contrib/libs/grpc/test/cpp/end2end/proto_server_reflection_test.cc new file mode 100644 index 0000000000..d79b33da70 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/proto_server_reflection_test.cc @@ -0,0 +1,150 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/ext/proto_server_reflection_plugin.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/security/server_credentials.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/proto_reflection_descriptor_database.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { + +class ProtoServerReflectionTest : public ::testing::Test { + public: + ProtoServerReflectionTest() {} + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + ref_desc_pool_ = protobuf::DescriptorPool::generated_pool(); + + ServerBuilder builder; + TString server_address = "localhost:" + to_string(port_); + builder.AddListeningPort(server_address, InsecureServerCredentials()); + server_ = builder.BuildAndStart(); + } + + void ResetStub() { + string target = "dns:localhost:" + to_string(port_); + std::shared_ptr<Channel> channel = + grpc::CreateChannel(target, InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + desc_db_.reset(new ProtoReflectionDescriptorDatabase(channel)); + desc_pool_.reset(new protobuf::DescriptorPool(desc_db_.get())); + } + + string to_string(const int number) { + std::stringstream strs; + strs << number; + return strs.str(); + } + + void CompareService(const TString& service) { + const protobuf::ServiceDescriptor* service_desc = + desc_pool_->FindServiceByName(service); + const protobuf::ServiceDescriptor* ref_service_desc = + ref_desc_pool_->FindServiceByName(service); + EXPECT_TRUE(service_desc != nullptr); + EXPECT_TRUE(ref_service_desc != nullptr); + EXPECT_EQ(service_desc->DebugString(), ref_service_desc->DebugString()); + + const protobuf::FileDescriptor* file_desc = service_desc->file(); + if (known_files_.find(file_desc->package() + "/" + file_desc->name()) != + known_files_.end()) { + EXPECT_EQ(file_desc->DebugString(), + ref_service_desc->file()->DebugString()); + known_files_.insert(file_desc->package() + "/" + file_desc->name()); + } + + for (int i = 0; i < service_desc->method_count(); ++i) { + CompareMethod(service_desc->method(i)->full_name()); + } + } + + void CompareMethod(const TString& method) { + const protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(method); + const protobuf::MethodDescriptor* ref_method_desc = + ref_desc_pool_->FindMethodByName(method); + EXPECT_TRUE(method_desc != nullptr); + EXPECT_TRUE(ref_method_desc != nullptr); + EXPECT_EQ(method_desc->DebugString(), ref_method_desc->DebugString()); + + CompareType(method_desc->input_type()->full_name()); + CompareType(method_desc->output_type()->full_name()); + } + + void CompareType(const TString& type) { + if (known_types_.find(type) != known_types_.end()) { + return; + } + + const protobuf::Descriptor* desc = desc_pool_->FindMessageTypeByName(type); + const protobuf::Descriptor* ref_desc = + ref_desc_pool_->FindMessageTypeByName(type); + EXPECT_TRUE(desc != nullptr); + EXPECT_TRUE(ref_desc != nullptr); + EXPECT_EQ(desc->DebugString(), ref_desc->DebugString()); + } + + protected: + std::unique_ptr<Server> server_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<ProtoReflectionDescriptorDatabase> desc_db_; + std::unique_ptr<protobuf::DescriptorPool> desc_pool_; + std::unordered_set<string> known_files_; + std::unordered_set<string> known_types_; + const protobuf::DescriptorPool* ref_desc_pool_; + int port_; + reflection::ProtoServerReflectionPlugin plugin_; +}; + +TEST_F(ProtoServerReflectionTest, CheckResponseWithLocalDescriptorPool) { + ResetStub(); + + std::vector<TString> services; + desc_db_->GetServices(&services); + // The service list has at least one service (reflection servcie). + EXPECT_TRUE(services.size() > 0); + + for (auto it = services.begin(); it != services.end(); ++it) { + CompareService(*it); + } +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/raw_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/raw_end2end_test.cc new file mode 100644 index 0000000000..184dc1e5f5 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/raw_end2end_test.cc @@ -0,0 +1,370 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <cinttypes> +#include <memory> +#include <thread> + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/port.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { + +namespace { + +void* tag(int i) { return (void*)static_cast<intptr_t>(i); } +int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); } + +class Verifier { + public: + Verifier() {} + + // Expect sets the expected ok value for a specific tag + Verifier& Expect(int i, bool expect_ok) { + expectations_[tag(i)] = expect_ok; + return *this; + } + + // Next waits for 1 async tag to complete, checks its + // expectations, and returns the tag + int Next(CompletionQueue* cq, bool ignore_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + GotTag(got_tag, ok, ignore_ok); + return detag(got_tag); + } + + // Verify keeps calling Next until all currently set + // expected tags are complete + void Verify(CompletionQueue* cq) { + GPR_ASSERT(!expectations_.empty()); + while (!expectations_.empty()) { + Next(cq, false); + } + } + + private: + void GotTag(void* got_tag, bool ok, bool ignore_ok) { + auto it = expectations_.find(got_tag); + if (it != expectations_.end()) { + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } + expectations_.erase(it); + } + } + + std::map<void*, bool> expectations_; +}; + +class RawEnd2EndTest : public ::testing::Test { + protected: + RawEnd2EndTest() {} + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port_; + } + + void TearDown() override { + server_->Shutdown(); + void* ignored_tag; + bool ignored_ok; + cq_->Shutdown(); + while (cq_->Next(&ignored_tag, &ignored_ok)) + ; + stub_.reset(); + grpc_recycle_unused_port(port_); + } + + template <typename ServerType> + std::unique_ptr<ServerType> BuildAndStartServer() { + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + grpc::InsecureServerCredentials()); + std::unique_ptr<ServerType> service(new ServerType()); + builder.RegisterService(service.get()); + cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + return service; + } + + void ResetStub() { + ChannelArguments args; + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), grpc::InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + std::unique_ptr<ServerCompletionQueue> cq_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; + int port_; + + // For the client application to populate and send to server. + EchoRequest send_request_; + ::grpc::ByteBuffer send_request_buffer_; + + // For the server to give to gRPC to be populated by incoming request + // from client. + EchoRequest recv_request_; + ::grpc::ByteBuffer recv_request_buffer_; + + // For the server application to populate and send back to client. + EchoResponse send_response_; + ::grpc::ByteBuffer send_response_buffer_; + + // For the client to give to gRPC to be populated by incoming response + // from server. + EchoResponse recv_response_; + ::grpc::ByteBuffer recv_response_buffer_; + Status recv_status_; + + // Both sides need contexts + ClientContext cli_ctx_; + ServerContext srv_ctx_; +}; + +// Regular Async, both peers use proto +TEST_F(RawEnd2EndTest, PureAsyncService) { + typedef grpc::testing::EchoTestService::AsyncService SType; + ResetStub(); + auto service = BuildAndStartServer<SType>(); + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx_); + + send_request_.set_message("hello"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx_, send_request_, cq_.get())); + service->RequestEcho(&srv_ctx_, &recv_request_, &response_writer, cq_.get(), + cq_.get(), tag(2)); + response_reader->Finish(&recv_response_, &recv_status_, tag(4)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + send_response_.set_message(recv_request_.message()); + response_writer.Finish(send_response_, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response_.message(), recv_response_.message()); + EXPECT_TRUE(recv_status_.ok()); +} + +// Client uses proto, server uses generic codegen, unary +TEST_F(RawEnd2EndTest, RawServerUnary) { + typedef grpc::testing::EchoTestService::WithRawMethod_Echo< + grpc::testing::EchoTestService::Service> + SType; + ResetStub(); + auto service = BuildAndStartServer<SType>(); + grpc::GenericServerAsyncResponseWriter response_writer(&srv_ctx_); + + send_request_.set_message("hello unary"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub_->AsyncEcho(&cli_ctx_, send_request_, cq_.get())); + service->RequestEcho(&srv_ctx_, &recv_request_buffer_, &response_writer, + cq_.get(), cq_.get(), tag(2)); + response_reader->Finish(&recv_response_, &recv_status_, tag(4)); + Verifier().Expect(2, true).Verify(cq_.get()); + EXPECT_TRUE(ParseFromByteBuffer(&recv_request_buffer_, &recv_request_)); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + send_response_.set_message(recv_request_.message()); + EXPECT_TRUE( + SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_)); + response_writer.Finish(send_response_buffer_, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + + EXPECT_EQ(send_response_.message(), recv_response_.message()); + EXPECT_TRUE(recv_status_.ok()); +} + +// Client uses proto, server uses generic codegen, client streaming +TEST_F(RawEnd2EndTest, RawServerClientStreaming) { + typedef grpc::testing::EchoTestService::WithRawMethod_RequestStream< + grpc::testing::EchoTestService::Service> + SType; + ResetStub(); + auto service = BuildAndStartServer<SType>(); + + grpc::GenericServerAsyncReader srv_stream(&srv_ctx_); + + send_request_.set_message("hello client streaming"); + std::unique_ptr<ClientAsyncWriter<EchoRequest>> cli_stream( + stub_->AsyncRequestStream(&cli_ctx_, &recv_response_, cq_.get(), tag(1))); + + service->RequestRequestStream(&srv_ctx_, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + Verifier().Expect(2, true).Expect(1, true).Verify(cq_.get()); + + cli_stream->Write(send_request_, tag(3)); + srv_stream.Read(&recv_request_buffer_, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + + cli_stream->Write(send_request_, tag(5)); + srv_stream.Read(&recv_request_buffer_, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request_buffer_, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + send_response_.set_message(recv_request_.message()); + SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_); + srv_stream.Finish(send_response_buffer_, Status::OK, tag(9)); + cli_stream->Finish(&recv_status_, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get()); + + EXPECT_EQ(send_response_.message(), recv_response_.message()); + EXPECT_TRUE(recv_status_.ok()); +} + +// Client uses proto, server uses generic codegen, server streaming +TEST_F(RawEnd2EndTest, RawServerServerStreaming) { + typedef grpc::testing::EchoTestService::WithRawMethod_ResponseStream< + grpc::testing::EchoTestService::Service> + SType; + ResetStub(); + auto service = BuildAndStartServer<SType>(); + grpc::GenericServerAsyncWriter srv_stream(&srv_ctx_); + + send_request_.set_message("hello server streaming"); + std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream( + stub_->AsyncResponseStream(&cli_ctx_, send_request_, cq_.get(), tag(1))); + + service->RequestResponseStream(&srv_ctx_, &recv_request_buffer_, &srv_stream, + cq_.get(), cq_.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + + send_response_.set_message(recv_request_.message()); + SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_); + srv_stream.Write(send_response_buffer_, tag(3)); + cli_stream->Read(&recv_response_, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + EXPECT_EQ(send_response_.message(), recv_response_.message()); + + srv_stream.Write(send_response_buffer_, tag(5)); + cli_stream->Read(&recv_response_, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response_.message(), recv_response_.message()); + + srv_stream.Finish(Status::OK, tag(7)); + cli_stream->Read(&recv_response_, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + cli_stream->Finish(&recv_status_, tag(9)); + Verifier().Expect(9, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status_.ok()); +} + +// Client uses proto, server uses generic codegen, bidi streaming +TEST_F(RawEnd2EndTest, RawServerBidiStreaming) { + typedef grpc::testing::EchoTestService::WithRawMethod_BidiStream< + grpc::testing::EchoTestService::Service> + SType; + ResetStub(); + auto service = BuildAndStartServer<SType>(); + + grpc::GenericServerAsyncReaderWriter srv_stream(&srv_ctx_); + + send_request_.set_message("hello bidi streaming"); + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub_->AsyncBidiStream(&cli_ctx_, cq_.get(), tag(1))); + + service->RequestBidiStream(&srv_ctx_, &srv_stream, cq_.get(), cq_.get(), + tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq_.get()); + + cli_stream->Write(send_request_, tag(3)); + srv_stream.Read(&recv_request_buffer_, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq_.get()); + ParseFromByteBuffer(&recv_request_buffer_, &recv_request_); + EXPECT_EQ(send_request_.message(), recv_request_.message()); + + send_response_.set_message(recv_request_.message()); + SerializeToByteBufferInPlace(&send_response_, &send_response_buffer_); + srv_stream.Write(send_response_buffer_, tag(5)); + cli_stream->Read(&recv_response_, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq_.get()); + EXPECT_EQ(send_response_.message(), recv_response_.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request_buffer_, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq_.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status_, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq_.get()); + + EXPECT_TRUE(recv_status_.ok()); +} + +// Testing that this pattern compiles +TEST_F(RawEnd2EndTest, CompileTest) { + typedef grpc::testing::EchoTestService::WithRawMethod_Echo< + grpc::testing::EchoTestService::AsyncService> + SType; + ResetStub(); + auto service = BuildAndStartServer<SType>(); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + // Change the backup poll interval from 5s to 100ms to speed up the + // ReconnectChannel test + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/server_builder_plugin_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_builder_plugin_test.cc new file mode 100644 index 0000000000..004902cad3 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/server_builder_plugin_test.cc @@ -0,0 +1,265 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <thread> + +#include <grpc/grpc.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/impl/server_builder_option.h> +#include <grpcpp/impl/server_builder_plugin.h> +#include <grpcpp/impl/server_initializer.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/security/server_credentials.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +#include <gtest/gtest.h> + +#define PLUGIN_NAME "TestServerBuilderPlugin" + +namespace grpc { +namespace testing { + +class TestServerBuilderPlugin : public ServerBuilderPlugin { + public: + TestServerBuilderPlugin() : service_(new TestServiceImpl()) { + init_server_is_called_ = false; + finish_is_called_ = false; + change_arguments_is_called_ = false; + register_service_ = false; + } + + TString name() override { return PLUGIN_NAME; } + + void InitServer(ServerInitializer* si) override { + init_server_is_called_ = true; + if (register_service_) { + si->RegisterService(service_); + } + } + + void Finish(ServerInitializer* /*si*/) override { finish_is_called_ = true; } + + void ChangeArguments(const TString& /*name*/, void* /*value*/) override { + change_arguments_is_called_ = true; + } + + bool has_async_methods() const override { + if (register_service_) { + return service_->has_async_methods(); + } + return false; + } + + bool has_sync_methods() const override { + if (register_service_) { + return service_->has_synchronous_methods(); + } + return false; + } + + void SetRegisterService() { register_service_ = true; } + + bool init_server_is_called() { return init_server_is_called_; } + bool finish_is_called() { return finish_is_called_; } + bool change_arguments_is_called() { return change_arguments_is_called_; } + + private: + bool init_server_is_called_; + bool finish_is_called_; + bool change_arguments_is_called_; + bool register_service_; + std::shared_ptr<TestServiceImpl> service_; +}; + +class InsertPluginServerBuilderOption : public ServerBuilderOption { + public: + InsertPluginServerBuilderOption() { register_service_ = false; } + + void UpdateArguments(ChannelArguments* /*arg*/) override {} + + void UpdatePlugins( + std::vector<std::unique_ptr<ServerBuilderPlugin>>* plugins) override { + plugins->clear(); + + std::unique_ptr<TestServerBuilderPlugin> plugin( + new TestServerBuilderPlugin()); + if (register_service_) plugin->SetRegisterService(); + plugins->emplace_back(std::move(plugin)); + } + + void SetRegisterService() { register_service_ = true; } + + private: + bool register_service_; +}; + +std::unique_ptr<ServerBuilderPlugin> CreateTestServerBuilderPlugin() { + return std::unique_ptr<ServerBuilderPlugin>(new TestServerBuilderPlugin()); +} + +// Force AddServerBuilderPlugin() to be called at static initialization time. +struct StaticTestPluginInitializer { + StaticTestPluginInitializer() { + ::grpc::ServerBuilder::InternalAddPluginFactory( + &CreateTestServerBuilderPlugin); + } +} static_plugin_initializer_test_; + +// When the param boolean is true, the ServerBuilder plugin will be added at the +// time of static initialization. When it's false, the ServerBuilder plugin will +// be added using ServerBuilder::SetOption(). +class ServerBuilderPluginTest : public ::testing::TestWithParam<bool> { + public: + ServerBuilderPluginTest() {} + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + builder_.reset(new ServerBuilder()); + } + + void InsertPlugin() { + if (GetParam()) { + // Add ServerBuilder plugin in static initialization + CheckPresent(); + } else { + // Add ServerBuilder plugin using ServerBuilder::SetOption() + builder_->SetOption(std::unique_ptr<ServerBuilderOption>( + new InsertPluginServerBuilderOption())); + } + } + + void InsertPluginWithTestService() { + if (GetParam()) { + // Add ServerBuilder plugin in static initialization + auto plugin = CheckPresent(); + EXPECT_TRUE(plugin); + plugin->SetRegisterService(); + } else { + // Add ServerBuilder plugin using ServerBuilder::SetOption() + std::unique_ptr<InsertPluginServerBuilderOption> option( + new InsertPluginServerBuilderOption()); + option->SetRegisterService(); + builder_->SetOption(std::move(option)); + } + } + + void StartServer() { + TString server_address = "localhost:" + to_string(port_); + builder_->AddListeningPort(server_address, InsecureServerCredentials()); + // we run some tests without a service, and for those we need to supply a + // frequently polled completion queue + cq_ = builder_->AddCompletionQueue(); + cq_thread_ = new std::thread(&ServerBuilderPluginTest::RunCQ, this); + server_ = builder_->BuildAndStart(); + EXPECT_TRUE(CheckPresent()); + } + + void ResetStub() { + string target = "dns:localhost:" + to_string(port_); + channel_ = grpc::CreateChannel(target, InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void TearDown() override { + auto plugin = CheckPresent(); + EXPECT_TRUE(plugin); + EXPECT_TRUE(plugin->init_server_is_called()); + EXPECT_TRUE(plugin->finish_is_called()); + server_->Shutdown(); + cq_->Shutdown(); + cq_thread_->join(); + delete cq_thread_; + } + + string to_string(const int number) { + std::stringstream strs; + strs << number; + return strs.str(); + } + + protected: + std::shared_ptr<Channel> channel_; + std::unique_ptr<ServerBuilder> builder_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<ServerCompletionQueue> cq_; + std::unique_ptr<Server> server_; + std::thread* cq_thread_; + TestServiceImpl service_; + int port_; + + private: + TestServerBuilderPlugin* CheckPresent() { + auto it = builder_->plugins_.begin(); + for (; it != builder_->plugins_.end(); it++) { + if ((*it)->name() == PLUGIN_NAME) break; + } + if (it != builder_->plugins_.end()) { + return static_cast<TestServerBuilderPlugin*>(it->get()); + } else { + return nullptr; + } + } + + void RunCQ() { + void* tag; + bool ok; + while (cq_->Next(&tag, &ok)) + ; + } +}; + +TEST_P(ServerBuilderPluginTest, PluginWithoutServiceTest) { + InsertPlugin(); + StartServer(); +} + +TEST_P(ServerBuilderPluginTest, PluginWithServiceTest) { + InsertPluginWithTestService(); + StartServer(); + ResetStub(); + + EchoRequest request; + EchoResponse response; + request.set_message("Hello hello hello hello"); + ClientContext context; + context.set_compression_algorithm(GRPC_COMPRESS_GZIP); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); +} + +INSTANTIATE_TEST_SUITE_P(ServerBuilderPluginTest, ServerBuilderPluginTest, + ::testing::Values(false, true)); + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/server_crash_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_crash_test.cc new file mode 100644 index 0000000000..3616d680f9 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/server_crash_test.cc @@ -0,0 +1,160 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/subprocess.h" + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +static TString g_root; + +namespace grpc { +namespace testing { + +namespace { + +class ServiceImpl final : public ::grpc::testing::EchoTestService::Service { + public: + ServiceImpl() : bidi_stream_count_(0), response_stream_count_(0) {} + + Status BidiStream( + ServerContext* /*context*/, + ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + bidi_stream_count_++; + EchoRequest request; + EchoResponse response; + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + response.set_message(request.message()); + stream->Write(response); + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(1, GPR_TIMESPAN))); + } + return Status::OK; + } + + Status ResponseStream(ServerContext* /*context*/, + const EchoRequest* /*request*/, + ServerWriter<EchoResponse>* writer) override { + EchoResponse response; + response_stream_count_++; + for (int i = 0;; i++) { + std::ostringstream msg; + msg << "Hello " << i; + response.set_message(msg.str()); + if (!writer->Write(response)) break; + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(1, GPR_TIMESPAN))); + } + return Status::OK; + } + + int bidi_stream_count() { return bidi_stream_count_; } + + int response_stream_count() { return response_stream_count_; } + + private: + int bidi_stream_count_; + int response_stream_count_; +}; + +class CrashTest : public ::testing::Test { + protected: + CrashTest() {} + + std::unique_ptr<Server> CreateServerAndClient(const TString& mode) { + auto port = grpc_pick_unused_port_or_die(); + std::ostringstream addr_stream; + addr_stream << "localhost:" << port; + auto addr = addr_stream.str(); + client_.reset(new SubProcess({g_root + "/server_crash_test_client", + "--address=" + addr, "--mode=" + mode})); + GPR_ASSERT(client_); + + ServerBuilder builder; + builder.AddListeningPort(addr, grpc::InsecureServerCredentials()); + builder.RegisterService(&service_); + return builder.BuildAndStart(); + } + + void KillClient() { client_.reset(); } + + bool HadOneBidiStream() { return service_.bidi_stream_count() == 1; } + + bool HadOneResponseStream() { return service_.response_stream_count() == 1; } + + private: + std::unique_ptr<SubProcess> client_; + ServiceImpl service_; +}; + +TEST_F(CrashTest, ResponseStream) { + auto server = CreateServerAndClient("response"); + + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(60, GPR_TIMESPAN))); + KillClient(); + server->Shutdown(); + GPR_ASSERT(HadOneResponseStream()); +} + +TEST_F(CrashTest, BidiStream) { + auto server = CreateServerAndClient("bidi"); + + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_seconds(60, GPR_TIMESPAN))); + KillClient(); + server->Shutdown(); + GPR_ASSERT(HadOneBidiStream()); +} + +} // namespace + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + TString me = argv[0]; + auto lslash = me.rfind('/'); + if (lslash != TString::npos) { + g_root = me.substr(0, lslash); + } else { + g_root = "."; + } + + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/server_crash_test_client.cc b/contrib/libs/grpc/test/cpp/end2end/server_crash_test_client.cc new file mode 100644 index 0000000000..202fb2836c --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/server_crash_test_client.cc @@ -0,0 +1,72 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <gflags/gflags.h> +#include <iostream> +#include <memory> +#include <sstream> +#include <util/generic/string.h> + +#include <grpc/support/log.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/test_config.h" + +DEFINE_string(address, "", "Address to connect to"); +DEFINE_string(mode, "", "Test mode to use"); + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + auto stub = grpc::testing::EchoTestService::NewStub( + grpc::CreateChannel(FLAGS_address, grpc::InsecureChannelCredentials())); + + EchoRequest request; + EchoResponse response; + grpc::ClientContext context; + context.set_wait_for_ready(true); + + if (FLAGS_mode == "bidi") { + auto stream = stub->BidiStream(&context); + for (int i = 0;; i++) { + std::ostringstream msg; + msg << "Hello " << i; + request.set_message(msg.str()); + GPR_ASSERT(stream->Write(request)); + GPR_ASSERT(stream->Read(&response)); + GPR_ASSERT(response.message() == request.message()); + } + } else if (FLAGS_mode == "response") { + EchoRequest request; + request.set_message("Hello"); + auto stream = stub->ResponseStream(&context, request); + for (;;) { + GPR_ASSERT(stream->Read(&response)); + } + } else { + gpr_log(GPR_ERROR, "invalid test mode '%s'", FLAGS_mode.c_str()); + return 1; + } + + return 0; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/server_early_return_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_early_return_test.cc new file mode 100644 index 0000000000..0f340516b0 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/server_early_return_test.cc @@ -0,0 +1,232 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/security/server_credentials.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/string_ref_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +namespace { + +const char kServerReturnStatusCode[] = "server_return_status_code"; +const char kServerDelayBeforeReturnUs[] = "server_delay_before_return_us"; +const char kServerReturnAfterNReads[] = "server_return_after_n_reads"; + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + // Unused methods are not implemented. + + Status RequestStream(ServerContext* context, + ServerReader<EchoRequest>* reader, + EchoResponse* response) override { + int server_return_status_code = + GetIntValueFromMetadata(context, kServerReturnStatusCode, 0); + int server_delay_before_return_us = + GetIntValueFromMetadata(context, kServerDelayBeforeReturnUs, 0); + int server_return_after_n_reads = + GetIntValueFromMetadata(context, kServerReturnAfterNReads, 0); + + EchoRequest request; + while (server_return_after_n_reads--) { + EXPECT_TRUE(reader->Read(&request)); + } + + response->set_message("response msg"); + + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(server_delay_before_return_us, GPR_TIMESPAN))); + + return Status(static_cast<StatusCode>(server_return_status_code), ""); + } + + Status BidiStream( + ServerContext* context, + ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + int server_return_status_code = + GetIntValueFromMetadata(context, kServerReturnStatusCode, 0); + int server_delay_before_return_us = + GetIntValueFromMetadata(context, kServerDelayBeforeReturnUs, 0); + int server_return_after_n_reads = + GetIntValueFromMetadata(context, kServerReturnAfterNReads, 0); + + EchoRequest request; + EchoResponse response; + while (server_return_after_n_reads--) { + EXPECT_TRUE(stream->Read(&request)); + response.set_message(request.message()); + EXPECT_TRUE(stream->Write(response)); + } + + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(server_delay_before_return_us, GPR_TIMESPAN))); + + return Status(static_cast<StatusCode>(server_return_status_code), ""); + } + + int GetIntValueFromMetadata(ServerContext* context, const char* key, + int default_value) { + auto metadata = context->client_metadata(); + if (metadata.find(key) != metadata.end()) { + std::istringstream iss(ToString(metadata.find(key)->second)); + iss >> default_value; + } + return default_value; + } +}; + +class ServerEarlyReturnTest : public ::testing::Test { + protected: + ServerEarlyReturnTest() : picked_port_(0) {} + + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + picked_port_ = port; + server_address_ << "127.0.0.1:" << port; + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + + channel_ = grpc::CreateChannel(server_address_.str(), + InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void TearDown() override { + server_->Shutdown(); + if (picked_port_ > 0) { + grpc_recycle_unused_port(picked_port_); + } + } + + // Client sends 20 requests and the server returns after reading 10 requests. + // If return_cancel is true, server returns CANCELLED status. Otherwise it + // returns OK. + void DoBidiStream(bool return_cancelled) { + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerReturnAfterNReads, "10"); + if (return_cancelled) { + // "1" means CANCELLED + context.AddMetadata(kServerReturnStatusCode, "1"); + } + context.AddMetadata(kServerDelayBeforeReturnUs, "10000"); + + auto stream = stub_->BidiStream(&context); + + for (int i = 0; i < 20; i++) { + request.set_message(TString("hello") + ToString(i)); + bool write_ok = stream->Write(request); + bool read_ok = stream->Read(&response); + if (i < 10) { + EXPECT_TRUE(write_ok); + EXPECT_TRUE(read_ok); + EXPECT_EQ(response.message(), request.message()); + } else { + EXPECT_FALSE(read_ok); + } + } + + stream->WritesDone(); + EXPECT_FALSE(stream->Read(&response)); + + Status s = stream->Finish(); + if (return_cancelled) { + EXPECT_EQ(s.error_code(), StatusCode::CANCELLED); + } else { + EXPECT_TRUE(s.ok()); + } + } + + void DoRequestStream(bool return_cancelled) { + EchoRequest request; + EchoResponse response; + ClientContext context; + + context.AddMetadata(kServerReturnAfterNReads, "10"); + if (return_cancelled) { + // "1" means CANCELLED + context.AddMetadata(kServerReturnStatusCode, "1"); + } + context.AddMetadata(kServerDelayBeforeReturnUs, "10000"); + + auto stream = stub_->RequestStream(&context, &response); + for (int i = 0; i < 20; i++) { + request.set_message(TString("hello") + ToString(i)); + bool written = stream->Write(request); + if (i < 10) { + EXPECT_TRUE(written); + } + } + stream->WritesDone(); + Status s = stream->Finish(); + if (return_cancelled) { + EXPECT_EQ(s.error_code(), StatusCode::CANCELLED); + } else { + EXPECT_TRUE(s.ok()); + } + } + + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; + TestServiceImpl service_; + int picked_port_; +}; + +TEST_F(ServerEarlyReturnTest, BidiStreamEarlyOk) { DoBidiStream(false); } + +TEST_F(ServerEarlyReturnTest, BidiStreamEarlyCancel) { DoBidiStream(true); } + +TEST_F(ServerEarlyReturnTest, RequestStreamEarlyOK) { DoRequestStream(false); } +TEST_F(ServerEarlyReturnTest, RequestStreamEarlyCancel) { + DoRequestStream(true); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/server_interceptors/ya.make b/contrib/libs/grpc/test/cpp/end2end/server_interceptors/ya.make new file mode 100644 index 0000000000..161176f141 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/server_interceptors/ya.make @@ -0,0 +1,32 @@ +GTEST_UGLY() + +OWNER( + dvshkurko + g:ymake +) + +ADDINCL( + ${ARCADIA_BUILD_ROOT}/contrib/libs/grpc + ${ARCADIA_ROOT}/contrib/libs/grpc +) + +PEERDIR( + contrib/libs/grpc/src/proto/grpc/core + contrib/libs/grpc/src/proto/grpc/testing + contrib/libs/grpc/src/proto/grpc/testing/duplicate + contrib/libs/grpc/test/core/util + contrib/libs/grpc/test/cpp/end2end + contrib/libs/grpc/test/cpp/util +) + +NO_COMPILER_WARNINGS() + +SRCDIR( + contrib/libs/grpc/test/cpp/end2end +) + +SRCS( + server_interceptors_end2end_test.cc +) + +END() diff --git a/contrib/libs/grpc/test/cpp/end2end/server_interceptors_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_interceptors_end2end_test.cc new file mode 100644 index 0000000000..6d2dc772ef --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -0,0 +1,708 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <vector> + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <grpcpp/support/server_interceptor.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +namespace { + +class LoggingInterceptor : public experimental::Interceptor { + public: + LoggingInterceptor(experimental::ServerRpcInfo* info) { + info_ = info; + + // Check the method name and compare to the type + const char* method = info->method(); + experimental::ServerRpcInfo::Type type = info->type(); + + // Check that we use one of our standard methods with expected type. + // Also allow the health checking service. + // We accept BIDI_STREAMING for Echo in case it's an AsyncGenericService + // being tested (the GenericRpc test). + // The empty method is for the Unimplemented requests that arise + // when draining the CQ. + EXPECT_TRUE( + strstr(method, "/grpc.health") == method || + (strcmp(method, "/grpc.testing.EchoTestService/Echo") == 0 && + (type == experimental::ServerRpcInfo::Type::UNARY || + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)) || + (strcmp(method, "/grpc.testing.EchoTestService/RequestStream") == 0 && + type == experimental::ServerRpcInfo::Type::CLIENT_STREAMING) || + (strcmp(method, "/grpc.testing.EchoTestService/ResponseStream") == 0 && + type == experimental::ServerRpcInfo::Type::SERVER_STREAMING) || + (strcmp(method, "/grpc.testing.EchoTestService/BidiStream") == 0 && + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING) || + strcmp(method, "/grpc.testing.EchoTestService/Unimplemented") == 0 || + (strcmp(method, "") == 0 && + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)); + } + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_TRUE(req.message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS)) { + auto* map = methods->GetSendTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.find("testkey") == 0 && + pair.second.find("testvalue") == 0; + if (found) break; + } + EXPECT_EQ(found, true); + auto status = methods->GetSendStatus(); + EXPECT_EQ(status.ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.find("testkey") == 0 && + pair.second.find("testvalue") == 0; + if (found) break; + } + EXPECT_EQ(found, true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + if (resp != nullptr) { + EXPECT_TRUE(resp->message().find("Hello") == 0); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE)) { + // Got nothing interesting to do here + } + methods->Proceed(); + } + + private: + experimental::ServerRpcInfo* info_; +}; + +class LoggingInterceptorFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +// Test if SendMessage function family works as expected for sync/callback apis +class SyncSendMessageTester : public experimental::Interceptor { + public: + SyncSendMessageTester(experimental::ServerRpcInfo* /*info*/) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + string old_msg = + static_cast<const EchoRequest*>(methods->GetSendMessage())->message(); + EXPECT_EQ(old_msg.find("Hello"), 0u); + new_msg_.set_message(TString("World" + old_msg).c_str()); + methods->ModifySendMessage(&new_msg_); + } + methods->Proceed(); + } + + private: + EchoRequest new_msg_; +}; + +class SyncSendMessageTesterFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new SyncSendMessageTester(info); + } +}; + +// Test if SendMessage function family works as expected for sync/callback apis +class SyncSendMessageVerifier : public experimental::Interceptor { + public: + SyncSendMessageVerifier(experimental::ServerRpcInfo* /*info*/) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + // Make sure that the changes made in SyncSendMessageTester persisted + string old_msg = + static_cast<const EchoRequest*>(methods->GetSendMessage())->message(); + EXPECT_EQ(old_msg.find("World"), 0u); + + // Remove the "World" part of the string that we added earlier + new_msg_.set_message(old_msg.erase(0, 5)); + methods->ModifySendMessage(&new_msg_); + + // LoggingInterceptor verifies that changes got reverted + } + methods->Proceed(); + } + + private: + EchoRequest new_msg_; +}; + +class SyncSendMessageVerifierFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new SyncSendMessageVerifier(info); + } +}; + +void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + ctx.AddMetadata("testkey", "testvalue"); + auto stream = stub->BidiStream(&ctx); + for (auto i = 0; i < 10; i++) { + req.set_message("Hello" + ::ToString(i)); + stream->Write(req); + stream->Read(&resp); + EXPECT_EQ(req.message(), resp.message()); + } + ASSERT_TRUE(stream->WritesDone()); + Status s = stream->Finish(); + EXPECT_EQ(s.ok(), true); +} + +class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test { + protected: + ServerInterceptorsEnd2endSyncUnaryTest() { + int port = 5004; // grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + ::ToString(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new SyncSendMessageTesterFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new SyncSendMessageVerifierFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptor factories and null interceptor factories + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + creators.push_back(std::unique_ptr<NullInterceptorFactory>( + new NullInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + server_ = builder.BuildAndStart(); + } + TString server_address_; + TestServiceImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(ServerInterceptorsEnd2endSyncUnaryTest, UnaryTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + MakeCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test { + protected: + ServerInterceptorsEnd2endSyncStreamingTest() { + int port = 5005; // grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + ::ToString(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new SyncSendMessageTesterFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new SyncSendMessageVerifierFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + server_ = builder.BuildAndStart(); + } + TString server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ClientStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + MakeClientStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ServerStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + MakeServerStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, BidiStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + MakeBidiStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +class ServerInterceptorsAsyncEnd2endTest : public ::testing::Test {}; + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnaryTest) { + DummyInterceptor::Reset(); + int port = 5006; // grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + ::ToString(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, cq.get())); + + service.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq.get(), + cq.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + // grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, BidiStreamingTest) { + DummyInterceptor::Reset(); + int port = 5007; // grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + ::ToString(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub->AsyncBidiStream(&cli_ctx, cq.get(), tag(1))); + + service.RequestBidiStream(&srv_ctx, &srv_stream, cq.get(), cq.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq.get()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq.get()); + + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + // grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, GenericRPCTest) { + DummyInterceptor::Reset(); + int port = 5008; // grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + ::ToString(port); + ServerBuilder builder; + AsyncGenericService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterAsyncGenericService(&service); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto srv_cq = builder.AddCompletionQueue(); + CompletionQueue cli_cq; + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + GenericStub generic_stub(channel); + + const TString kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + + CompletionQueue* cq = srv_cq.get(); + std::thread request_call([cq]() { Verifier().Expect(4, true).Verify(cq); }); + std::unique_ptr<GenericClientAsyncReaderWriter> call = + generic_stub.PrepareCall(&cli_ctx, kMethodName, &cli_cq); + call->StartCall(tag(1)); + Verifier().Expect(1, true).Verify(&cli_cq); + std::unique_ptr<ByteBuffer> send_buffer = + SerializeToByteBuffer(&send_request); + call->Write(*send_buffer, tag(2)); + // Send ByteBuffer can be destroyed after calling Write. + send_buffer.reset(); + Verifier().Expect(2, true).Verify(&cli_cq); + call->WritesDone(tag(3)); + Verifier().Expect(3, true).Verify(&cli_cq); + + service.RequestCall(&srv_ctx, &stream, srv_cq.get(), srv_cq.get(), tag(4)); + + request_call.join(); + EXPECT_EQ(kMethodName, srv_ctx.method()); + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + ByteBuffer recv_buffer; + stream.Read(&recv_buffer, tag(5)); + Verifier().Expect(5, true).Verify(srv_cq.get()); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + stream.Write(*send_buffer, tag(6)); + send_buffer.reset(); + Verifier().Expect(6, true).Verify(srv_cq.get()); + + stream.Finish(Status::OK, tag(7)); + // Shutdown srv_cq before we try to get the tag back, to verify that the + // interception API handles completion queue shutdowns that take place before + // all the tags are returned + srv_cq->Shutdown(); + Verifier().Expect(7, true).Verify(srv_cq.get()); + + recv_buffer.Clear(); + call->Read(&recv_buffer, tag(8)); + Verifier().Expect(8, true).Verify(&cli_cq); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + + call->Finish(&recv_status, tag(9)); + cli_cq.Shutdown(); + Verifier().Expect(9, true).Verify(&cli_cq); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cli_cq.Next(&ignored_tag, &ignored_ok)) + ; + while (srv_cq->Next(&ignored_tag, &ignored_ok)) + ; + // grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnimplementedRpcTest) { + DummyInterceptor::Reset(); + int port = 5009; // grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + ::ToString(port); + ServerBuilder builder; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + std::shared_ptr<Channel> channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + send_request.set_message("Hello"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub->AsyncUnimplemented(&cli_ctx, send_request, cq.get())); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + Verifier().Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + // grpc_recycle_unused_port(port); +} + +class ServerInterceptorsSyncUnimplementedEnd2endTest : public ::testing::Test { +}; + +TEST_F(ServerInterceptorsSyncUnimplementedEnd2endTest, UnimplementedRpcTest) { + DummyInterceptor::Reset(); + int port = 5010; // grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + ::ToString(port); + ServerBuilder builder; + TestServiceImpl service; + builder.RegisterService(&service); + builder.AddListeningPort(server_address, InsecureServerCredentials()); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.reserve(20); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + std::shared_ptr<Channel> channel = + grpc::CreateChannel(server_address, InsecureChannelCredentials()); + std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + + ClientContext cli_ctx; + send_request.set_message("Hello"); + Status recv_status = + stub->Unimplemented(&cli_ctx, send_request, &recv_response); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + // grpc_recycle_unused_port(port); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/server_load_reporting_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/server_load_reporting_end2end_test.cc new file mode 100644 index 0000000000..13833cf66c --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/server_load_reporting_end2end_test.cc @@ -0,0 +1,192 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include <thread> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <grpc++/grpc++.h> +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> +#include <grpcpp/ext/server_load_reporting.h> +#include <grpcpp/server_builder.h> + +#include "src/proto/grpc/lb/v1/load_reporter.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace testing { +namespace { + +constexpr double kMetricValue = 3.1415; +constexpr char kMetricName[] = "METRIC_PI"; + +// Different messages result in different response statuses. For simplicity in +// computing request bytes, the message sizes should be the same. +const char kOkMessage[] = "hello"; +const char kServerErrorMessage[] = "sverr"; +const char kClientErrorMessage[] = "clerr"; + +class EchoTestServiceImpl : public EchoTestService::Service { + public: + ~EchoTestServiceImpl() override {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + if (request->message() == kServerErrorMessage) { + return Status(StatusCode::UNKNOWN, "Server error requested"); + } + if (request->message() == kClientErrorMessage) { + return Status(StatusCode::FAILED_PRECONDITION, "Client error requested"); + } + response->set_message(request->message()); + ::grpc::load_reporter::experimental::AddLoadReportingCost( + context, kMetricName, kMetricValue); + return Status::OK; + } +}; + +class ServerLoadReportingEnd2endTest : public ::testing::Test { + protected: + void SetUp() override { + server_address_ = + "localhost:" + ToString(grpc_pick_unused_port_or_die()); + server_ = + ServerBuilder() + .AddListeningPort(server_address_, InsecureServerCredentials()) + .RegisterService(&echo_service_) + .SetOption(std::unique_ptr<::grpc::ServerBuilderOption>( + new ::grpc::load_reporter::experimental:: + LoadReportingServiceServerBuilderOption())) + .BuildAndStart(); + server_thread_ = + std::thread(&ServerLoadReportingEnd2endTest::RunServerLoop, this); + } + + void RunServerLoop() { server_->Wait(); } + + void TearDown() override { + server_->Shutdown(); + server_thread_.join(); + } + + void ClientMakeEchoCalls(const TString& lb_id, const TString& lb_tag, + const TString& message, size_t num_requests) { + auto stub = EchoTestService::NewStub( + grpc::CreateChannel(server_address_, InsecureChannelCredentials())); + TString lb_token = lb_id + lb_tag; + for (int i = 0; i < num_requests; ++i) { + ClientContext ctx; + if (!lb_token.empty()) ctx.AddMetadata(GRPC_LB_TOKEN_MD_KEY, lb_token); + EchoRequest request; + EchoResponse response; + request.set_message(message); + Status status = stub->Echo(&ctx, request, &response); + if (message == kOkMessage) { + ASSERT_EQ(status.error_code(), StatusCode::OK); + ASSERT_EQ(request.message(), response.message()); + } else if (message == kServerErrorMessage) { + ASSERT_EQ(status.error_code(), StatusCode::UNKNOWN); + } else if (message == kClientErrorMessage) { + ASSERT_EQ(status.error_code(), StatusCode::FAILED_PRECONDITION); + } + } + } + + TString server_address_; + std::unique_ptr<Server> server_; + std::thread server_thread_; + EchoTestServiceImpl echo_service_; +}; + +TEST_F(ServerLoadReportingEnd2endTest, NoCall) {} + +TEST_F(ServerLoadReportingEnd2endTest, BasicReport) { + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + auto stub = ::grpc::lb::v1::LoadReporter::NewStub(channel); + ClientContext ctx; + auto stream = stub->ReportLoad(&ctx); + ::grpc::lb::v1::LoadReportRequest request; + request.mutable_initial_request()->set_load_balanced_hostname( + server_address_); + request.mutable_initial_request()->set_load_key("LOAD_KEY"); + request.mutable_initial_request() + ->mutable_load_report_interval() + ->set_seconds(5); + stream->Write(request); + gpr_log(GPR_INFO, "Initial request sent."); + ::grpc::lb::v1::LoadReportResponse response; + stream->Read(&response); + const TString& lb_id = response.initial_response().load_balancer_id(); + gpr_log(GPR_INFO, "Initial response received (lb_id: %s).", lb_id.c_str()); + ClientMakeEchoCalls(lb_id, "LB_TAG", kOkMessage, 1); + while (true) { + stream->Read(&response); + if (!response.load().empty()) { + ASSERT_EQ(response.load().size(), 3); + for (const auto& load : response.load()) { + if (load.in_progress_report_case()) { + // The special load record that reports the number of in-progress + // calls. + ASSERT_EQ(load.num_calls_in_progress(), 1); + } else if (load.orphaned_load_case()) { + // The call from the balancer doesn't have any valid LB token. + ASSERT_EQ(load.orphaned_load_case(), load.kLoadKeyUnknown); + ASSERT_EQ(load.num_calls_started(), 1); + ASSERT_EQ(load.num_calls_finished_without_error(), 0); + ASSERT_EQ(load.num_calls_finished_with_error(), 0); + } else { + // This corresponds to the calls from the client. + ASSERT_EQ(load.num_calls_started(), 1); + ASSERT_EQ(load.num_calls_finished_without_error(), 1); + ASSERT_EQ(load.num_calls_finished_with_error(), 0); + ASSERT_GE(load.total_bytes_received(), sizeof(kOkMessage)); + ASSERT_GE(load.total_bytes_sent(), sizeof(kOkMessage)); + ASSERT_EQ(load.metric_data().size(), 1); + ASSERT_EQ(load.metric_data().Get(0).metric_name(), kMetricName); + ASSERT_EQ(load.metric_data().Get(0).num_calls_finished_with_metric(), + 1); + ASSERT_EQ(load.metric_data().Get(0).total_metric_value(), + kMetricValue); + } + } + break; + } + } + stream->WritesDone(); + ASSERT_EQ(stream->Finish().error_code(), StatusCode::CANCELLED); +} + +// TODO(juanlishen): Add more tests. + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/service_config_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/service_config_end2end_test.cc new file mode 100644 index 0000000000..cee33343c1 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/service_config_end2end_test.cc @@ -0,0 +1,613 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <algorithm> +#include <memory> +#include <mutex> +#include <random> +#include <set> +#include <util/generic/string.h> +#include <thread> + +#include "y_absl/strings/str_cat.h" + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/atm.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/health_check_service_interface.h> +#include <grpcpp/impl/codegen/sync.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/support/validate_service_config.h> + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/global_subchannel_pool.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/debug_location.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/parse_address.h" +#include "src/core/lib/iomgr/tcp_client.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +namespace grpc { +namespace testing { +namespace { + +// Subclass of TestServiceImpl that increments a request counter for +// every call to the Echo RPC. +class MyTestServiceImpl : public TestServiceImpl { + public: + MyTestServiceImpl() : request_count_(0) {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + { + grpc::internal::MutexLock lock(&mu_); + ++request_count_; + } + AddClient(context->peer()); + return TestServiceImpl::Echo(context, request, response); + } + + int request_count() { + grpc::internal::MutexLock lock(&mu_); + return request_count_; + } + + void ResetCounters() { + grpc::internal::MutexLock lock(&mu_); + request_count_ = 0; + } + + std::set<TString> clients() { + grpc::internal::MutexLock lock(&clients_mu_); + return clients_; + } + + private: + void AddClient(const TString& client) { + grpc::internal::MutexLock lock(&clients_mu_); + clients_.insert(client); + } + + grpc::internal::Mutex mu_; + int request_count_; + grpc::internal::Mutex clients_mu_; + std::set<TString> clients_; +}; + +class ServiceConfigEnd2endTest : public ::testing::Test { + protected: + ServiceConfigEnd2endTest() + : server_host_("localhost"), + kRequestMessage_("Live long and prosper."), + creds_(new SecureChannelCredentials( + grpc_fake_transport_security_credentials_create())) {} + + static void SetUpTestCase() { + // Make the backup poller poll very frequently in order to pick up + // updates from all the subchannels's FDs. + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); + } + + void SetUp() override { + grpc_init(); + response_generator_ = + grpc_core::MakeRefCounted<grpc_core::FakeResolverResponseGenerator>(); + } + + void TearDown() override { + for (size_t i = 0; i < servers_.size(); ++i) { + servers_[i]->Shutdown(); + } + // Explicitly destroy all the members so that we can make sure grpc_shutdown + // has finished by the end of this function, and thus all the registered + // LB policy factories are removed. + stub_.reset(); + servers_.clear(); + creds_.reset(); + grpc_shutdown_blocking(); + } + + void CreateServers(size_t num_servers, + std::vector<int> ports = std::vector<int>()) { + servers_.clear(); + for (size_t i = 0; i < num_servers; ++i) { + int port = 0; + if (ports.size() == num_servers) port = ports[i]; + servers_.emplace_back(new ServerData(port)); + } + } + + void StartServer(size_t index) { servers_[index]->Start(server_host_); } + + void StartServers(size_t num_servers, + std::vector<int> ports = std::vector<int>()) { + CreateServers(num_servers, std::move(ports)); + for (size_t i = 0; i < num_servers; ++i) { + StartServer(i); + } + } + + grpc_core::Resolver::Result BuildFakeResults(const std::vector<int>& ports) { + grpc_core::Resolver::Result result; + for (const int& port : ports) { + TString lb_uri_str = y_absl::StrCat("ipv4:127.0.0.1:", port); + grpc_uri* lb_uri = grpc_uri_parse(lb_uri_str.c_str(), true); + GPR_ASSERT(lb_uri != nullptr); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(lb_uri, &address)); + result.addresses.emplace_back(address.addr, address.len, + nullptr /* args */); + grpc_uri_destroy(lb_uri); + } + return result; + } + + void SetNextResolutionNoServiceConfig(const std::vector<int>& ports) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = BuildFakeResults(ports); + response_generator_->SetResponse(result); + } + + void SetNextResolutionValidServiceConfig(const std::vector<int>& ports) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = BuildFakeResults(ports); + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, "{}", &result.service_config_error); + response_generator_->SetResponse(result); + } + + void SetNextResolutionInvalidServiceConfig(const std::vector<int>& ports) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = BuildFakeResults(ports); + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, "{", &result.service_config_error); + response_generator_->SetResponse(result); + } + + void SetNextResolutionWithServiceConfig(const std::vector<int>& ports, + const char* svc_cfg) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result = BuildFakeResults(ports); + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, svc_cfg, &result.service_config_error); + response_generator_->SetResponse(result); + } + + std::vector<int> GetServersPorts(size_t start_index = 0) { + std::vector<int> ports; + for (size_t i = start_index; i < servers_.size(); ++i) { + ports.push_back(servers_[i]->port_); + } + return ports; + } + + std::unique_ptr<grpc::testing::EchoTestService::Stub> BuildStub( + const std::shared_ptr<Channel>& channel) { + return grpc::testing::EchoTestService::NewStub(channel); + } + + std::shared_ptr<Channel> BuildChannel() { + ChannelArguments args; + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + return ::grpc::CreateCustomChannel("fake:///", creds_, args); + } + + std::shared_ptr<Channel> BuildChannelWithDefaultServiceConfig() { + ChannelArguments args; + EXPECT_THAT(grpc::experimental::ValidateServiceConfigJSON( + ValidDefaultServiceConfig()), + ::testing::StrEq("")); + args.SetServiceConfigJSON(ValidDefaultServiceConfig()); + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + return ::grpc::CreateCustomChannel("fake:///", creds_, args); + } + + std::shared_ptr<Channel> BuildChannelWithInvalidDefaultServiceConfig() { + ChannelArguments args; + EXPECT_THAT(grpc::experimental::ValidateServiceConfigJSON( + InvalidDefaultServiceConfig()), + ::testing::HasSubstr("JSON parse error")); + args.SetServiceConfigJSON(InvalidDefaultServiceConfig()); + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + return ::grpc::CreateCustomChannel("fake:///", creds_, args); + } + + bool SendRpc( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + EchoResponse* response = nullptr, int timeout_ms = 1000, + Status* result = nullptr, bool wait_for_ready = false) { + const bool local_response = (response == nullptr); + if (local_response) response = new EchoResponse; + EchoRequest request; + request.set_message(kRequestMessage_); + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(timeout_ms)); + if (wait_for_ready) context.set_wait_for_ready(true); + Status status = stub->Echo(&context, request, response); + if (result != nullptr) *result = status; + if (local_response) delete response; + return status.ok(); + } + + void CheckRpcSendOk( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + const grpc_core::DebugLocation& location, bool wait_for_ready = false) { + EchoResponse response; + Status status; + const bool success = + SendRpc(stub, &response, 2000, &status, wait_for_ready); + ASSERT_TRUE(success) << "From " << location.file() << ":" << location.line() + << "\n" + << "Error: " << status.error_message() << " " + << status.error_details(); + ASSERT_EQ(response.message(), kRequestMessage_) + << "From " << location.file() << ":" << location.line(); + if (!success) abort(); + } + + void CheckRpcSendFailure( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub) { + const bool success = SendRpc(stub); + EXPECT_FALSE(success); + } + + struct ServerData { + int port_; + std::unique_ptr<Server> server_; + MyTestServiceImpl service_; + std::unique_ptr<std::thread> thread_; + bool server_ready_ = false; + bool started_ = false; + + explicit ServerData(int port = 0) { + port_ = port > 0 ? port : grpc_pick_unused_port_or_die(); + } + + void Start(const TString& server_host) { + gpr_log(GPR_INFO, "starting server on port %d", port_); + started_ = true; + grpc::internal::Mutex mu; + grpc::internal::MutexLock lock(&mu); + grpc::internal::CondVar cond; + thread_.reset(new std::thread( + std::bind(&ServerData::Serve, this, server_host, &mu, &cond))); + cond.WaitUntil(&mu, [this] { return server_ready_; }); + server_ready_ = false; + gpr_log(GPR_INFO, "server startup complete"); + } + + void Serve(const TString& server_host, grpc::internal::Mutex* mu, + grpc::internal::CondVar* cond) { + std::ostringstream server_address; + server_address << server_host << ":" << port_; + ServerBuilder builder; + std::shared_ptr<ServerCredentials> creds(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create())); + builder.AddListeningPort(server_address.str(), std::move(creds)); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + grpc::internal::MutexLock lock(mu); + server_ready_ = true; + cond->Signal(); + } + + void Shutdown() { + if (!started_) return; + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + started_ = false; + } + + void SetServingStatus(const TString& service, bool serving) { + server_->GetHealthCheckService()->SetServingStatus(service, serving); + } + }; + + void ResetCounters() { + for (const auto& server : servers_) server->service_.ResetCounters(); + } + + void WaitForServer( + const std::unique_ptr<grpc::testing::EchoTestService::Stub>& stub, + size_t server_idx, const grpc_core::DebugLocation& location, + bool ignore_failure = false) { + do { + if (ignore_failure) { + SendRpc(stub); + } else { + CheckRpcSendOk(stub, location, true); + } + } while (servers_[server_idx]->service_.request_count() == 0); + ResetCounters(); + } + + bool WaitForChannelNotReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(false /* try_to_connect */)) == + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool WaitForChannelReady(Channel* channel, int timeout_seconds = 5) { + const gpr_timespec deadline = + grpc_timeout_seconds_to_deadline(timeout_seconds); + grpc_connectivity_state state; + while ((state = channel->GetState(true /* try_to_connect */)) != + GRPC_CHANNEL_READY) { + if (!channel->WaitForStateChange(state, deadline)) return false; + } + return true; + } + + bool SeenAllServers() { + for (const auto& server : servers_) { + if (server->service_.request_count() == 0) return false; + } + return true; + } + + // Updates \a connection_order by appending to it the index of the newly + // connected server. Must be called after every single RPC. + void UpdateConnectionOrder( + const std::vector<std::unique_ptr<ServerData>>& servers, + std::vector<int>* connection_order) { + for (size_t i = 0; i < servers.size(); ++i) { + if (servers[i]->service_.request_count() == 1) { + // Was the server index known? If not, update connection_order. + const auto it = + std::find(connection_order->begin(), connection_order->end(), i); + if (it == connection_order->end()) { + connection_order->push_back(i); + return; + } + } + } + } + + const char* ValidServiceConfigV1() { return "{\"version\": \"1\"}"; } + + const char* ValidServiceConfigV2() { return "{\"version\": \"2\"}"; } + + const char* ValidDefaultServiceConfig() { + return "{\"version\": \"valid_default\"}"; + } + + const char* InvalidDefaultServiceConfig() { + return "{\"version\": \"invalid_default\""; + } + + const TString server_host_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::vector<std::unique_ptr<ServerData>> servers_; + grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator> + response_generator_; + const TString kRequestMessage_; + std::shared_ptr<ChannelCredentials> creds_; +}; + +TEST_F(ServiceConfigEnd2endTest, NoServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, NoServiceConfigWithDefaultConfigTest) { + StartServers(1); + auto channel = BuildChannelWithDefaultServiceConfig(); + auto stub = BuildStub(channel); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidDefaultServiceConfig(), + channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, InvalidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, ValidServiceConfigUpdatesTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV2()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV2(), channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + NoServiceConfigUpdateAfterValidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + NoServiceConfigUpdateAfterValidServiceConfigWithDefaultConfigTest) { + StartServers(1); + auto channel = BuildChannelWithDefaultServiceConfig(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidDefaultServiceConfig(), + channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidServiceConfigUpdateAfterValidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidServiceConfigUpdateAfterValidServiceConfigWithDefaultConfigTest) { + StartServers(1); + auto channel = BuildChannelWithDefaultServiceConfig(); + auto stub = BuildStub(channel); + SetNextResolutionWithServiceConfig(GetServersPorts(), ValidServiceConfigV1()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ(ValidServiceConfigV1(), channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + ValidServiceConfigAfterInvalidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); + SetNextResolutionValidServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); +} + +TEST_F(ServiceConfigEnd2endTest, NoServiceConfigAfterInvalidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendOk(stub, DEBUG_LOCATION); + EXPECT_STREQ("{}", channel->GetServiceConfigJSON().c_str()); +} + +TEST_F(ServiceConfigEnd2endTest, + AnotherInvalidServiceConfigAfterInvalidServiceConfigTest) { + StartServers(1); + auto channel = BuildChannel(); + auto stub = BuildStub(channel); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, InvalidDefaultServiceConfigTest) { + StartServers(1); + auto channel = BuildChannelWithInvalidDefaultServiceConfig(); + auto stub = BuildStub(channel); + // An invalid default service config results in a lame channel which fails all + // RPCs + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidDefaultServiceConfigTestWithValidServiceConfig) { + StartServers(1); + auto channel = BuildChannelWithInvalidDefaultServiceConfig(); + auto stub = BuildStub(channel); + CheckRpcSendFailure(stub); + // An invalid default service config results in a lame channel which fails all + // RPCs + SetNextResolutionValidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidDefaultServiceConfigTestWithInvalidServiceConfig) { + StartServers(1); + auto channel = BuildChannelWithInvalidDefaultServiceConfig(); + auto stub = BuildStub(channel); + CheckRpcSendFailure(stub); + // An invalid default service config results in a lame channel which fails all + // RPCs + SetNextResolutionInvalidServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +TEST_F(ServiceConfigEnd2endTest, + InvalidDefaultServiceConfigTestWithNoServiceConfig) { + StartServers(1); + auto channel = BuildChannelWithInvalidDefaultServiceConfig(); + auto stub = BuildStub(channel); + CheckRpcSendFailure(stub); + // An invalid default service config results in a lame channel which fails all + // RPCs + SetNextResolutionNoServiceConfig(GetServersPorts()); + CheckRpcSendFailure(stub); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::TestEnvironment env(argc, argv); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/shutdown_test.cc b/contrib/libs/grpc/test/cpp/end2end/shutdown_test.cc new file mode 100644 index 0000000000..3aa7a766c4 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/shutdown_test.cc @@ -0,0 +1,170 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <thread> + +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpc/support/sync.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/core/lib/gpr/env.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + explicit TestServiceImpl(gpr_event* ev) : ev_(ev) {} + + Status Echo(ServerContext* context, const EchoRequest* /*request*/, + EchoResponse* /*response*/) override { + gpr_event_set(ev_, (void*)1); + while (!context->IsCancelled()) { + } + return Status::OK; + } + + private: + gpr_event* ev_; +}; + +class ShutdownTest : public ::testing::TestWithParam<string> { + public: + ShutdownTest() : shutdown_(false), service_(&ev_) { gpr_event_init(&ev_); } + + void SetUp() override { + port_ = grpc_pick_unused_port_or_die(); + server_ = SetUpServer(port_); + } + + std::unique_ptr<Server> SetUpServer(const int port) { + TString server_address = "localhost:" + to_string(port); + + ServerBuilder builder; + auto server_creds = + GetCredentialsProvider()->GetServerCredentials(GetParam()); + builder.AddListeningPort(server_address, server_creds); + builder.RegisterService(&service_); + std::unique_ptr<Server> server = builder.BuildAndStart(); + return server; + } + + void TearDown() override { GPR_ASSERT(shutdown_); } + + void ResetStub() { + string target = "dns:localhost:" + to_string(port_); + ChannelArguments args; + auto channel_creds = + GetCredentialsProvider()->GetChannelCredentials(GetParam(), &args); + channel_ = ::grpc::CreateCustomChannel(target, channel_creds, args); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + string to_string(const int number) { + std::stringstream strs; + strs << number; + return strs.str(); + } + + void SendRequest() { + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + ClientContext context; + GPR_ASSERT(!shutdown_); + Status s = stub_->Echo(&context, request, &response); + GPR_ASSERT(shutdown_); + } + + protected: + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + bool shutdown_; + int port_; + gpr_event ev_; + TestServiceImpl service_; +}; + +std::vector<string> GetAllCredentialsTypeList() { + std::vector<TString> credentials_types; + if (GetCredentialsProvider()->GetChannelCredentials(kInsecureCredentialsType, + nullptr) != nullptr) { + credentials_types.push_back(kInsecureCredentialsType); + } + auto sec_list = GetCredentialsProvider()->GetSecureCredentialsTypeList(); + for (auto sec = sec_list.begin(); sec != sec_list.end(); sec++) { + credentials_types.push_back(*sec); + } + GPR_ASSERT(!credentials_types.empty()); + + TString credentials_type_list("credentials types:"); + for (const string& type : credentials_types) { + credentials_type_list.append(" " + type); + } + gpr_log(GPR_INFO, "%s", credentials_type_list.c_str()); + return credentials_types; +} + +INSTANTIATE_TEST_SUITE_P(End2EndShutdown, ShutdownTest, + ::testing::ValuesIn(GetAllCredentialsTypeList())); + +// TODO(ctiller): leaked objects in this test +TEST_P(ShutdownTest, ShutdownTest) { + ResetStub(); + + // send the request in a background thread + std::thread thr(std::bind(&ShutdownTest::SendRequest, this)); + + // wait for the server to get the event + gpr_event_wait(&ev_, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + + shutdown_ = true; + + // shutdown should trigger cancellation causing everything to shutdown + auto deadline = + std::chrono::system_clock::now() + std::chrono::microseconds(100); + server_->Shutdown(deadline); + EXPECT_GE(std::chrono::system_clock::now(), deadline); + + thr.join(); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/streaming_throughput_test.cc b/contrib/libs/grpc/test/cpp/end2end/streaming_throughput_test.cc new file mode 100644 index 0000000000..f2252063fb --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/streaming_throughput_test.cc @@ -0,0 +1,193 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <time.h> +#include <mutex> +#include <thread> + +#include <grpc/grpc.h> +#include <grpc/support/atm.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/security/server_credentials.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +const char* kLargeString = + "(" + "To be, or not to be- that is the question:" + "Whether 'tis nobler in the mind to suffer" + "The slings and arrows of outrageous fortune" + "Or to take arms against a sea of troubles," + "And by opposing end them. To die- to sleep-" + "No more; and by a sleep to say we end" + "The heartache, and the thousand natural shock" + "That flesh is heir to. 'Tis a consummation" + "Devoutly to be wish'd. To die- to sleep." + "To sleep- perchance to dream: ay, there's the rub!" + "For in that sleep of death what dreams may come" + "When we have shuffled off this mortal coil," + "Must give us pause. There's the respect" + "That makes calamity of so long life." + "For who would bear the whips and scorns of time," + "Th' oppressor's wrong, the proud man's contumely," + "The pangs of despis'd love, the law's delay," + "The insolence of office, and the spurns" + "That patient merit of th' unworthy takes," + "When he himself might his quietus make" + "With a bare bodkin? Who would these fardels bear," + "To grunt and sweat under a weary life," + "But that the dread of something after death-" + "The undiscover'd country, from whose bourn" + "No traveller returns- puzzles the will," + "And makes us rather bear those ills we have" + "Than fly to others that we know not of?" + "Thus conscience does make cowards of us all," + "And thus the native hue of resolution" + "Is sicklied o'er with the pale cast of thought," + "And enterprises of great pith and moment" + "With this regard their currents turn awry" + "And lose the name of action.- Soft you now!" + "The fair Ophelia!- Nymph, in thy orisons" + "Be all my sins rememb'red."; + +namespace grpc { +namespace testing { + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + static void BidiStream_Sender( + ServerReaderWriter<EchoResponse, EchoRequest>* stream, + gpr_atm* should_exit) { + EchoResponse response; + response.set_message(kLargeString); + while (gpr_atm_acq_load(should_exit) == static_cast<gpr_atm>(0)) { + struct timespec tv = {0, 1000000}; // 1 ms + struct timespec rem; + // TODO (vpai): Mark this blocking + while (nanosleep(&tv, &rem) != 0) { + tv = rem; + }; + + stream->Write(response); + } + } + + // Only implement the one method we will be calling for brevity. + Status BidiStream( + ServerContext* /*context*/, + ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + EchoRequest request; + gpr_atm should_exit; + gpr_atm_rel_store(&should_exit, static_cast<gpr_atm>(0)); + + std::thread sender( + std::bind(&TestServiceImpl::BidiStream_Sender, stream, &should_exit)); + + while (stream->Read(&request)) { + struct timespec tv = {0, 3000000}; // 3 ms + struct timespec rem; + // TODO (vpai): Mark this blocking + while (nanosleep(&tv, &rem) != 0) { + tv = rem; + }; + } + gpr_atm_rel_store(&should_exit, static_cast<gpr_atm>(1)); + sender.join(); + return Status::OK; + } +}; + +class End2endTest : public ::testing::Test { + protected: + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { server_->Shutdown(); } + + void ResetStub() { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; + TestServiceImpl service_; +}; + +static void Drainer(ClientReaderWriter<EchoRequest, EchoResponse>* reader) { + EchoResponse response; + while (reader->Read(&response)) { + // Just drain out the responses as fast as possible. + } +} + +TEST_F(End2endTest, StreamingThroughput) { + ResetStub(); + grpc::ClientContext context; + auto stream = stub_->BidiStream(&context); + + auto reader = stream.get(); + std::thread receiver(std::bind(Drainer, reader)); + + for (int i = 0; i < 10000; i++) { + EchoRequest request; + request.set_message(kLargeString); + ASSERT_TRUE(stream->Write(request)); + if (i % 1000 == 0) { + gpr_log(GPR_INFO, "Send count = %d", i); + } + } + stream->WritesDone(); + receiver.join(); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.cc b/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.cc new file mode 100644 index 0000000000..5b212cba31 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.cc @@ -0,0 +1,98 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/end2end/test_health_check_service_impl.h" + +#include <grpc/grpc.h> + +using grpc::health::v1::HealthCheckRequest; +using grpc::health::v1::HealthCheckResponse; + +namespace grpc { +namespace testing { + +Status HealthCheckServiceImpl::Check(ServerContext* /*context*/, + const HealthCheckRequest* request, + HealthCheckResponse* response) { + std::lock_guard<std::mutex> lock(mu_); + auto iter = status_map_.find(request->service()); + if (iter == status_map_.end()) { + return Status(StatusCode::NOT_FOUND, ""); + } + response->set_status(iter->second); + return Status::OK; +} + +Status HealthCheckServiceImpl::Watch( + ServerContext* context, const HealthCheckRequest* request, + ::grpc::ServerWriter<HealthCheckResponse>* writer) { + auto last_state = HealthCheckResponse::UNKNOWN; + while (!context->IsCancelled()) { + { + std::lock_guard<std::mutex> lock(mu_); + HealthCheckResponse response; + auto iter = status_map_.find(request->service()); + if (iter == status_map_.end()) { + response.set_status(response.SERVICE_UNKNOWN); + } else { + response.set_status(iter->second); + } + if (response.status() != last_state) { + writer->Write(response, ::grpc::WriteOptions()); + last_state = response.status(); + } + } + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(1000, GPR_TIMESPAN))); + } + return Status::OK; +} + +void HealthCheckServiceImpl::SetStatus( + const TString& service_name, + HealthCheckResponse::ServingStatus status) { + std::lock_guard<std::mutex> lock(mu_); + if (shutdown_) { + status = HealthCheckResponse::NOT_SERVING; + } + status_map_[service_name] = status; +} + +void HealthCheckServiceImpl::SetAll(HealthCheckResponse::ServingStatus status) { + std::lock_guard<std::mutex> lock(mu_); + if (shutdown_) { + return; + } + for (auto iter = status_map_.begin(); iter != status_map_.end(); ++iter) { + iter->second = status; + } +} + +void HealthCheckServiceImpl::Shutdown() { + std::lock_guard<std::mutex> lock(mu_); + if (shutdown_) { + return; + } + shutdown_ = true; + for (auto iter = status_map_.begin(); iter != status_map_.end(); ++iter) { + iter->second = HealthCheckResponse::NOT_SERVING; + } +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.h b/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.h new file mode 100644 index 0000000000..d370e4693a --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/test_health_check_service_impl.h @@ -0,0 +1,58 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#ifndef GRPC_TEST_CPP_END2END_TEST_HEALTH_CHECK_SERVICE_IMPL_H +#define GRPC_TEST_CPP_END2END_TEST_HEALTH_CHECK_SERVICE_IMPL_H + +#include <map> +#include <mutex> + +#include <grpcpp/server_context.h> +#include <grpcpp/support/status.h> + +#include "src/proto/grpc/health/v1/health.grpc.pb.h" + +namespace grpc { +namespace testing { + +// A sample sync implementation of the health checking service. This does the +// same thing as the default one. +class HealthCheckServiceImpl : public health::v1::Health::Service { + public: + Status Check(ServerContext* context, + const health::v1::HealthCheckRequest* request, + health::v1::HealthCheckResponse* response) override; + Status Watch(ServerContext* context, + const health::v1::HealthCheckRequest* request, + ServerWriter<health::v1::HealthCheckResponse>* writer) override; + void SetStatus(const TString& service_name, + health::v1::HealthCheckResponse::ServingStatus status); + void SetAll(health::v1::HealthCheckResponse::ServingStatus status); + + void Shutdown(); + + private: + std::mutex mu_; + bool shutdown_ = false; + std::map<const TString, health::v1::HealthCheckResponse::ServingStatus> + status_map_; +}; + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_END2END_TEST_HEALTH_CHECK_SERVICE_IMPL_H diff --git a/contrib/libs/grpc/test/cpp/end2end/test_service_impl.cc b/contrib/libs/grpc/test/cpp/end2end/test_service_impl.cc new file mode 100644 index 0000000000..078977e824 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/test_service_impl.cc @@ -0,0 +1,638 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/end2end/test_service_impl.h" + +#include <grpc/support/log.h> +#include <grpcpp/alarm.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/server_context.h> +#include <gtest/gtest.h> + +#include <util/generic/string.h> +#include <thread> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/string_ref_helper.h" + +using std::chrono::system_clock; + +namespace grpc { +namespace testing { +namespace internal { + +// When echo_deadline is requested, deadline seen in the ServerContext is set in +// the response in seconds. +void MaybeEchoDeadline(experimental::ServerContextBase* context, + const EchoRequest* request, EchoResponse* response) { + if (request->has_param() && request->param().echo_deadline()) { + gpr_timespec deadline = gpr_inf_future(GPR_CLOCK_REALTIME); + if (context->deadline() != system_clock::time_point::max()) { + Timepoint2Timespec(context->deadline(), &deadline); + } + response->mutable_param()->set_request_deadline(deadline.tv_sec); + } +} + +void CheckServerAuthContext(const experimental::ServerContextBase* context, + const TString& expected_transport_security_type, + const TString& expected_client_identity) { + std::shared_ptr<const AuthContext> auth_ctx = context->auth_context(); + std::vector<grpc::string_ref> tst = + auth_ctx->FindPropertyValues("transport_security_type"); + EXPECT_EQ(1u, tst.size()); + EXPECT_EQ(expected_transport_security_type.c_str(), ToString(tst[0])); + if (expected_client_identity.empty()) { + EXPECT_TRUE(auth_ctx->GetPeerIdentityPropertyName().empty()); + EXPECT_TRUE(auth_ctx->GetPeerIdentity().empty()); + EXPECT_FALSE(auth_ctx->IsPeerAuthenticated()); + } else { + auto identity = auth_ctx->GetPeerIdentity(); + EXPECT_TRUE(auth_ctx->IsPeerAuthenticated()); + EXPECT_EQ(1u, identity.size()); + EXPECT_EQ(expected_client_identity.c_str(), ToString(identity[0])); + } +} + +// Returns the number of pairs in metadata that exactly match the given +// key-value pair. Returns -1 if the pair wasn't found. +int MetadataMatchCount( + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata, + const TString& key, const TString& value) { + int count = 0; + for (const auto& metadatum : metadata) { + if (ToString(metadatum.first) == key && + ToString(metadatum.second) == value) { + count++; + } + } + return count; +} + +int GetIntValueFromMetadataHelper( + const char* key, + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata, + int default_value) { + if (metadata.find(key) != metadata.end()) { + std::istringstream iss(ToString(metadata.find(key)->second)); + iss >> default_value; + gpr_log(GPR_INFO, "%s : %d", key, default_value); + } + + return default_value; +} + +int GetIntValueFromMetadata( + const char* key, + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata, + int default_value) { + return GetIntValueFromMetadataHelper(key, metadata, default_value); +} + +void ServerTryCancel(ServerContext* context) { + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); + // Now wait until it's really canceled + while (!context->IsCancelled()) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1000, GPR_TIMESPAN))); + } +} + +void ServerTryCancelNonblocking(experimental::CallbackServerContext* context) { + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, + "Server called TryCancelNonblocking() to cancel the request"); +} + +} // namespace internal + +experimental::ServerUnaryReactor* CallbackTestServiceImpl::Echo( + experimental::CallbackServerContext* context, const EchoRequest* request, + EchoResponse* response) { + class Reactor : public ::grpc::experimental::ServerUnaryReactor { + public: + Reactor(CallbackTestServiceImpl* service, + experimental::CallbackServerContext* ctx, + const EchoRequest* request, EchoResponse* response) + : service_(service), ctx_(ctx), req_(request), resp_(response) { + // It should be safe to call IsCancelled here, even though we don't know + // the result. Call it asynchronously to see if we trigger any data races. + // Join it in OnDone (technically that could be blocking but shouldn't be + // for very long). + async_cancel_check_ = std::thread([this] { (void)ctx_->IsCancelled(); }); + + started_ = true; + + if (request->has_param() && + request->param().server_notify_client_when_started()) { + service->signaller_.SignalClientThatRpcStarted(); + // Block on the "wait to continue" decision in a different thread since + // we can't tie up an EM thread with blocking events. We can join it in + // OnDone since it would definitely be done by then. + rpc_wait_thread_ = std::thread([this] { + service_->signaller_.ServerWaitToContinue(); + StartRpc(); + }); + } else { + StartRpc(); + } + } + + void StartRpc() { + if (req_->has_param() && req_->param().server_sleep_us() > 0) { + // Set an alarm for that much time + alarm_.experimental().Set( + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(req_->param().server_sleep_us(), + GPR_TIMESPAN)), + [this](bool ok) { NonDelayed(ok); }); + return; + } + NonDelayed(true); + } + void OnSendInitialMetadataDone(bool ok) override { + EXPECT_TRUE(ok); + initial_metadata_sent_ = true; + } + void OnCancel() override { + EXPECT_TRUE(started_); + EXPECT_TRUE(ctx_->IsCancelled()); + on_cancel_invoked_ = true; + std::lock_guard<std::mutex> l(cancel_mu_); + cancel_cv_.notify_one(); + } + void OnDone() override { + if (req_->has_param() && req_->param().echo_metadata_initially()) { + EXPECT_TRUE(initial_metadata_sent_); + } + EXPECT_EQ(ctx_->IsCancelled(), on_cancel_invoked_); + // Validate that finishing with a non-OK status doesn't cause cancellation + if (req_->has_param() && req_->param().has_expected_error()) { + EXPECT_FALSE(on_cancel_invoked_); + } + async_cancel_check_.join(); + if (rpc_wait_thread_.joinable()) { + rpc_wait_thread_.join(); + } + if (finish_when_cancelled_.joinable()) { + finish_when_cancelled_.join(); + } + delete this; + } + + private: + void NonDelayed(bool ok) { + if (!ok) { + EXPECT_TRUE(ctx_->IsCancelled()); + Finish(Status::CANCELLED); + return; + } + if (req_->has_param() && req_->param().server_die()) { + gpr_log(GPR_ERROR, "The request should not reach application handler."); + GPR_ASSERT(0); + } + if (req_->has_param() && req_->param().has_expected_error()) { + const auto& error = req_->param().expected_error(); + Finish(Status(static_cast<StatusCode>(error.code()), + error.error_message(), error.binary_error_details())); + return; + } + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, ctx_->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel != DO_NOT_CANCEL) { + // Since this is a unary RPC, by the time this server handler is called, + // the 'request' message is already read from the client. So the + // scenarios in server_try_cancel don't make much sense. Just cancel the + // RPC as long as server_try_cancel is not DO_NOT_CANCEL + EXPECT_FALSE(ctx_->IsCancelled()); + ctx_->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); + FinishWhenCancelledAsync(); + return; + } + resp_->set_message(req_->message()); + internal::MaybeEchoDeadline(ctx_, req_, resp_); + if (service_->host_) { + resp_->mutable_param()->set_host(*service_->host_); + } + if (req_->has_param() && req_->param().client_cancel_after_us()) { + { + std::unique_lock<std::mutex> lock(service_->mu_); + service_->signal_client_ = true; + } + FinishWhenCancelledAsync(); + return; + } else if (req_->has_param() && req_->param().server_cancel_after_us()) { + alarm_.experimental().Set( + gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(req_->param().server_cancel_after_us(), + GPR_TIMESPAN)), + [this](bool) { Finish(Status::CANCELLED); }); + return; + } else if (!req_->has_param() || !req_->param().skip_cancelled_check()) { + EXPECT_FALSE(ctx_->IsCancelled()); + } + + if (req_->has_param() && req_->param().echo_metadata_initially()) { + const std::multimap<grpc::string_ref, grpc::string_ref>& + client_metadata = ctx_->client_metadata(); + for (const auto& metadatum : client_metadata) { + ctx_->AddInitialMetadata(ToString(metadatum.first), + ToString(metadatum.second)); + } + StartSendInitialMetadata(); + } + + if (req_->has_param() && req_->param().echo_metadata()) { + const std::multimap<grpc::string_ref, grpc::string_ref>& + client_metadata = ctx_->client_metadata(); + for (const auto& metadatum : client_metadata) { + ctx_->AddTrailingMetadata(ToString(metadatum.first), + ToString(metadatum.second)); + } + // Terminate rpc with error and debug info in trailer. + if (req_->param().debug_info().stack_entries_size() || + !req_->param().debug_info().detail().empty()) { + TString serialized_debug_info = + req_->param().debug_info().SerializeAsString(); + ctx_->AddTrailingMetadata(kDebugInfoTrailerKey, + serialized_debug_info); + Finish(Status::CANCELLED); + return; + } + } + if (req_->has_param() && + (req_->param().expected_client_identity().length() > 0 || + req_->param().check_auth_context())) { + internal::CheckServerAuthContext( + ctx_, req_->param().expected_transport_security_type(), + req_->param().expected_client_identity()); + } + if (req_->has_param() && req_->param().response_message_length() > 0) { + resp_->set_message( + TString(req_->param().response_message_length(), '\0')); + } + if (req_->has_param() && req_->param().echo_peer()) { + resp_->mutable_param()->set_peer(ctx_->peer().c_str()); + } + Finish(Status::OK); + } + void FinishWhenCancelledAsync() { + finish_when_cancelled_ = std::thread([this] { + std::unique_lock<std::mutex> l(cancel_mu_); + cancel_cv_.wait(l, [this] { return ctx_->IsCancelled(); }); + Finish(Status::CANCELLED); + }); + } + + CallbackTestServiceImpl* const service_; + experimental::CallbackServerContext* const ctx_; + const EchoRequest* const req_; + EchoResponse* const resp_; + Alarm alarm_; + std::mutex cancel_mu_; + std::condition_variable cancel_cv_; + bool initial_metadata_sent_ = false; + bool started_ = false; + bool on_cancel_invoked_ = false; + std::thread async_cancel_check_; + std::thread rpc_wait_thread_; + std::thread finish_when_cancelled_; + }; + + return new Reactor(this, context, request, response); +} + +experimental::ServerUnaryReactor* +CallbackTestServiceImpl::CheckClientInitialMetadata( + experimental::CallbackServerContext* context, const SimpleRequest42*, + SimpleResponse42*) { + class Reactor : public ::grpc::experimental::ServerUnaryReactor { + public: + explicit Reactor(experimental::CallbackServerContext* ctx) { + EXPECT_EQ(internal::MetadataMatchCount(ctx->client_metadata(), + kCheckClientInitialMetadataKey, + kCheckClientInitialMetadataVal), + 1); + EXPECT_EQ(ctx->client_metadata().count(kCheckClientInitialMetadataKey), + 1u); + Finish(Status::OK); + } + void OnDone() override { delete this; } + }; + + return new Reactor(context); +} + +experimental::ServerReadReactor<EchoRequest>* +CallbackTestServiceImpl::RequestStream( + experimental::CallbackServerContext* context, EchoResponse* response) { + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancelNonblocking(context); + // Don't need to provide a reactor since the RPC is canceled + return nullptr; + } + + class Reactor : public ::grpc::experimental::ServerReadReactor<EchoRequest> { + public: + Reactor(experimental::CallbackServerContext* ctx, EchoResponse* response, + int server_try_cancel) + : ctx_(ctx), + response_(response), + server_try_cancel_(server_try_cancel) { + EXPECT_NE(server_try_cancel, CANCEL_BEFORE_PROCESSING); + response->set_message(""); + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx->TryCancel(); + // Don't wait for it here + } + StartRead(&request_); + setup_done_ = true; + } + void OnDone() override { delete this; } + void OnCancel() override { + EXPECT_TRUE(setup_done_); + EXPECT_TRUE(ctx_->IsCancelled()); + FinishOnce(Status::CANCELLED); + } + void OnReadDone(bool ok) override { + if (ok) { + response_->mutable_message()->append(request_.message()); + num_msgs_read_++; + StartRead(&request_); + } else { + gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read_); + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel recover this + return; + } + if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancelNonblocking(ctx_); + return; + } + FinishOnce(Status::OK); + } + } + + private: + void FinishOnce(const Status& s) { + std::lock_guard<std::mutex> l(finish_mu_); + if (!finished_) { + Finish(s); + finished_ = true; + } + } + + experimental::CallbackServerContext* const ctx_; + EchoResponse* const response_; + EchoRequest request_; + int num_msgs_read_{0}; + int server_try_cancel_; + std::mutex finish_mu_; + bool finished_{false}; + bool setup_done_{false}; + }; + + return new Reactor(context, response, server_try_cancel); +} + +// Return 'kNumResponseStreamMsgs' messages. +// TODO(yangg) make it generic by adding a parameter into EchoRequest +experimental::ServerWriteReactor<EchoResponse>* +CallbackTestServiceImpl::ResponseStream( + experimental::CallbackServerContext* context, const EchoRequest* request) { + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancelNonblocking(context); + } + + class Reactor + : public ::grpc::experimental::ServerWriteReactor<EchoResponse> { + public: + Reactor(experimental::CallbackServerContext* ctx, + const EchoRequest* request, int server_try_cancel) + : ctx_(ctx), request_(request), server_try_cancel_(server_try_cancel) { + server_coalescing_api_ = internal::GetIntValueFromMetadata( + kServerUseCoalescingApi, ctx->client_metadata(), 0); + server_responses_to_send_ = internal::GetIntValueFromMetadata( + kServerResponseStreamsToSend, ctx->client_metadata(), + kServerDefaultResponseStreamsToSend); + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx->TryCancel(); + } + if (server_try_cancel_ != CANCEL_BEFORE_PROCESSING) { + if (num_msgs_sent_ < server_responses_to_send_) { + NextWrite(); + } + } + setup_done_ = true; + } + void OnDone() override { delete this; } + void OnCancel() override { + EXPECT_TRUE(setup_done_); + EXPECT_TRUE(ctx_->IsCancelled()); + FinishOnce(Status::CANCELLED); + } + void OnWriteDone(bool /*ok*/) override { + if (num_msgs_sent_ < server_responses_to_send_) { + NextWrite(); + } else if (server_coalescing_api_ != 0) { + // We would have already done Finish just after the WriteLast + } else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel recover this + } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancelNonblocking(ctx_); + } else { + FinishOnce(Status::OK); + } + } + + private: + void FinishOnce(const Status& s) { + std::lock_guard<std::mutex> l(finish_mu_); + if (!finished_) { + Finish(s); + finished_ = true; + } + } + + void NextWrite() { + response_.set_message(request_->message() + + ::ToString(num_msgs_sent_)); + if (num_msgs_sent_ == server_responses_to_send_ - 1 && + server_coalescing_api_ != 0) { + { + std::lock_guard<std::mutex> l(finish_mu_); + if (!finished_) { + num_msgs_sent_++; + StartWriteLast(&response_, WriteOptions()); + } + } + // If we use WriteLast, we shouldn't wait before attempting Finish + FinishOnce(Status::OK); + } else { + std::lock_guard<std::mutex> l(finish_mu_); + if (!finished_) { + num_msgs_sent_++; + StartWrite(&response_); + } + } + } + experimental::CallbackServerContext* const ctx_; + const EchoRequest* const request_; + EchoResponse response_; + int num_msgs_sent_{0}; + int server_try_cancel_; + int server_coalescing_api_; + int server_responses_to_send_; + std::mutex finish_mu_; + bool finished_{false}; + bool setup_done_{false}; + }; + return new Reactor(context, request, server_try_cancel); +} + +experimental::ServerBidiReactor<EchoRequest, EchoResponse>* +CallbackTestServiceImpl::BidiStream( + experimental::CallbackServerContext* context) { + class Reactor : public ::grpc::experimental::ServerBidiReactor<EchoRequest, + EchoResponse> { + public: + explicit Reactor(experimental::CallbackServerContext* ctx) : ctx_(ctx) { + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + server_try_cancel_ = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, ctx->client_metadata(), DO_NOT_CANCEL); + server_write_last_ = internal::GetIntValueFromMetadata( + kServerFinishAfterNReads, ctx->client_metadata(), 0); + if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancelNonblocking(ctx); + } else { + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx->TryCancel(); + } + StartRead(&request_); + } + setup_done_ = true; + } + void OnDone() override { + { + // Use the same lock as finish to make sure that OnDone isn't inlined. + std::lock_guard<std::mutex> l(finish_mu_); + EXPECT_TRUE(finished_); + finish_thread_.join(); + } + delete this; + } + void OnCancel() override { + EXPECT_TRUE(setup_done_); + EXPECT_TRUE(ctx_->IsCancelled()); + FinishOnce(Status::CANCELLED); + } + void OnReadDone(bool ok) override { + if (ok) { + num_msgs_read_++; + response_.set_message(request_.message()); + std::lock_guard<std::mutex> l(finish_mu_); + if (!finished_) { + if (num_msgs_read_ == server_write_last_) { + StartWriteLast(&response_, WriteOptions()); + // If we use WriteLast, we shouldn't wait before attempting Finish + } else { + StartWrite(&response_); + return; + } + } + } + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel handle this + } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancelNonblocking(ctx_); + } else { + FinishOnce(Status::OK); + } + } + void OnWriteDone(bool /*ok*/) override { + std::lock_guard<std::mutex> l(finish_mu_); + if (!finished_) { + StartRead(&request_); + } + } + + private: + void FinishOnce(const Status& s) { + std::lock_guard<std::mutex> l(finish_mu_); + if (!finished_) { + finished_ = true; + // Finish asynchronously to make sure that there are no deadlocks. + finish_thread_ = std::thread([this, s] { + std::lock_guard<std::mutex> l(finish_mu_); + Finish(s); + }); + } + } + + experimental::CallbackServerContext* const ctx_; + EchoRequest request_; + EchoResponse response_; + int num_msgs_read_{0}; + int server_try_cancel_; + int server_write_last_; + std::mutex finish_mu_; + bool finished_{false}; + bool setup_done_{false}; + std::thread finish_thread_; + }; + + return new Reactor(context); +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/end2end/test_service_impl.h b/contrib/libs/grpc/test/cpp/end2end/test_service_impl.h new file mode 100644 index 0000000000..5f207f1979 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/test_service_impl.h @@ -0,0 +1,495 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_END2END_TEST_SERVICE_IMPL_H +#define GRPC_TEST_CPP_END2END_TEST_SERVICE_IMPL_H + +#include <condition_variable> +#include <memory> +#include <mutex> + +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpcpp/alarm.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/server_context.h> +#include <gtest/gtest.h> + +#include <util/generic/string.h> +#include <thread> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/string_ref_helper.h" + +#include <util/string/cast.h> + +using std::chrono::system_clock; + +namespace grpc { +namespace testing { + +const int kServerDefaultResponseStreamsToSend = 3; +const char* const kServerResponseStreamsToSend = "server_responses_to_send"; +const char* const kServerTryCancelRequest = "server_try_cancel"; +const char* const kDebugInfoTrailerKey = "debug-info-bin"; +const char* const kServerFinishAfterNReads = "server_finish_after_n_reads"; +const char* const kServerUseCoalescingApi = "server_use_coalescing_api"; +const char* const kCheckClientInitialMetadataKey = "custom_client_metadata"; +const char* const kCheckClientInitialMetadataVal = "Value for client metadata"; + +typedef enum { + DO_NOT_CANCEL = 0, + CANCEL_BEFORE_PROCESSING, + CANCEL_DURING_PROCESSING, + CANCEL_AFTER_PROCESSING +} ServerTryCancelRequestPhase; + +namespace internal { +// When echo_deadline is requested, deadline seen in the ServerContext is set in +// the response in seconds. +void MaybeEchoDeadline(experimental::ServerContextBase* context, + const EchoRequest* request, EchoResponse* response); + +void CheckServerAuthContext(const experimental::ServerContextBase* context, + const TString& expected_transport_security_type, + const TString& expected_client_identity); + +// Returns the number of pairs in metadata that exactly match the given +// key-value pair. Returns -1 if the pair wasn't found. +int MetadataMatchCount( + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata, + const TString& key, const TString& value); + +int GetIntValueFromMetadataHelper( + const char* key, + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata, + int default_value); + +int GetIntValueFromMetadata( + const char* key, + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata, + int default_value); + +void ServerTryCancel(ServerContext* context); +} // namespace internal + +class TestServiceSignaller { + public: + void ClientWaitUntilRpcStarted() { + std::unique_lock<std::mutex> lock(mu_); + cv_rpc_started_.wait(lock, [this] { return rpc_started_; }); + } + void ServerWaitToContinue() { + std::unique_lock<std::mutex> lock(mu_); + cv_server_continue_.wait(lock, [this] { return server_should_continue_; }); + } + void SignalClientThatRpcStarted() { + std::unique_lock<std::mutex> lock(mu_); + rpc_started_ = true; + cv_rpc_started_.notify_one(); + } + void SignalServerToContinue() { + std::unique_lock<std::mutex> lock(mu_); + server_should_continue_ = true; + cv_server_continue_.notify_one(); + } + + private: + std::mutex mu_; + std::condition_variable cv_rpc_started_; + bool rpc_started_ /* GUARDED_BY(mu_) */ = false; + std::condition_variable cv_server_continue_; + bool server_should_continue_ /* GUARDED_BY(mu_) */ = false; +}; + +template <typename RpcService> +class TestMultipleServiceImpl : public RpcService { + public: + TestMultipleServiceImpl() : signal_client_(false), host_() {} + explicit TestMultipleServiceImpl(const TString& host) + : signal_client_(false), host_(new TString(host)) {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) { + if (request->has_param() && + request->param().server_notify_client_when_started()) { + signaller_.SignalClientThatRpcStarted(); + signaller_.ServerWaitToContinue(); + } + + // A bit of sleep to make sure that short deadline tests fail + if (request->has_param() && request->param().server_sleep_us() > 0) { + gpr_sleep_until( + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(request->param().server_sleep_us(), + GPR_TIMESPAN))); + } + + if (request->has_param() && request->param().server_die()) { + gpr_log(GPR_ERROR, "The request should not reach application handler."); + GPR_ASSERT(0); + } + if (request->has_param() && request->param().has_expected_error()) { + const auto& error = request->param().expected_error(); + return Status(static_cast<StatusCode>(error.code()), + error.error_message(), error.binary_error_details()); + } + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel > DO_NOT_CANCEL) { + // Since this is a unary RPC, by the time this server handler is called, + // the 'request' message is already read from the client. So the scenarios + // in server_try_cancel don't make much sense. Just cancel the RPC as long + // as server_try_cancel is not DO_NOT_CANCEL + internal::ServerTryCancel(context); + return Status::CANCELLED; + } + + response->set_message(request->message()); + internal::MaybeEchoDeadline(context, request, response); + if (host_) { + response->mutable_param()->set_host(*host_); + } + if (request->has_param() && request->param().client_cancel_after_us()) { + { + std::unique_lock<std::mutex> lock(mu_); + signal_client_ = true; + ++rpcs_waiting_for_client_cancel_; + } + while (!context->IsCancelled()) { + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(request->param().client_cancel_after_us(), + GPR_TIMESPAN))); + } + { + std::unique_lock<std::mutex> lock(mu_); + --rpcs_waiting_for_client_cancel_; + } + return Status::CANCELLED; + } else if (request->has_param() && + request->param().server_cancel_after_us()) { + gpr_sleep_until(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(request->param().server_cancel_after_us(), + GPR_TIMESPAN))); + return Status::CANCELLED; + } else if (!request->has_param() || + !request->param().skip_cancelled_check()) { + EXPECT_FALSE(context->IsCancelled()); + } + + if (request->has_param() && request->param().echo_metadata_initially()) { + const std::multimap<grpc::string_ref, grpc::string_ref>& client_metadata = + context->client_metadata(); + for (const auto& metadatum : client_metadata) { + context->AddInitialMetadata(::ToString(metadatum.first), + ::ToString(metadatum.second)); + } + } + + if (request->has_param() && request->param().echo_metadata()) { + const std::multimap<grpc::string_ref, grpc::string_ref>& client_metadata = + context->client_metadata(); + for (const auto& metadatum : client_metadata) { + context->AddTrailingMetadata(::ToString(metadatum.first), + ::ToString(metadatum.second)); + } + // Terminate rpc with error and debug info in trailer. + if (request->param().debug_info().stack_entries_size() || + !request->param().debug_info().detail().empty()) { + TString serialized_debug_info = + request->param().debug_info().SerializeAsString(); + context->AddTrailingMetadata(kDebugInfoTrailerKey, + serialized_debug_info); + return Status::CANCELLED; + } + } + if (request->has_param() && + (request->param().expected_client_identity().length() > 0 || + request->param().check_auth_context())) { + internal::CheckServerAuthContext( + context, request->param().expected_transport_security_type(), + request->param().expected_client_identity()); + } + if (request->has_param() && + request->param().response_message_length() > 0) { + response->set_message( + TString(request->param().response_message_length(), '\0')); + } + if (request->has_param() && request->param().echo_peer()) { + response->mutable_param()->set_peer(context->peer()); + } + return Status::OK; + } + + Status Echo1(ServerContext* context, const EchoRequest* request, + EchoResponse* response) { + return Echo(context, request, response); + } + + Status Echo2(ServerContext* context, const EchoRequest* request, + EchoResponse* response) { + return Echo(context, request, response); + } + + Status CheckClientInitialMetadata(ServerContext* context, + const SimpleRequest42* /*request*/, + SimpleResponse42* /*response*/) { + EXPECT_EQ(internal::MetadataMatchCount(context->client_metadata(), + kCheckClientInitialMetadataKey, + kCheckClientInitialMetadataVal), + 1); + EXPECT_EQ(1u, + context->client_metadata().count(kCheckClientInitialMetadataKey)); + return Status::OK; + } + + // Unimplemented is left unimplemented to test the returned error. + + Status RequestStream(ServerContext* context, + ServerReader<EchoRequest>* reader, + EchoResponse* response) { + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server reads + // any message from the client + // CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is + // reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + + EchoRequest request; + response->set_message(""); + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancel(context); + return Status::CANCELLED; + } + + std::thread* server_try_cancel_thd = nullptr; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread([context] { internal::ServerTryCancel(context); }); + } + + int num_msgs_read = 0; + while (reader->Read(&request)) { + response->mutable_message()->append(request.message()); + } + gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read); + + if (server_try_cancel_thd != nullptr) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + return Status::CANCELLED; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancel(context); + return Status::CANCELLED; + } + + return Status::OK; + } + + // Return 'kNumResponseStreamMsgs' messages. + // TODO(yangg) make it generic by adding a parameter into EchoRequest + Status ResponseStream(ServerContext* context, const EchoRequest* request, + ServerWriter<EchoResponse>* writer) { + // If server_try_cancel is set in the metadata, the RPC is cancelled by the + // server by calling ServerContext::TryCancel() depending on the value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server writes + // any messages to the client + // CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is + // writing messages to the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server writes + // all the messages to the client + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + + int server_coalescing_api = internal::GetIntValueFromMetadata( + kServerUseCoalescingApi, context->client_metadata(), 0); + + int server_responses_to_send = internal::GetIntValueFromMetadata( + kServerResponseStreamsToSend, context->client_metadata(), + kServerDefaultResponseStreamsToSend); + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancel(context); + return Status::CANCELLED; + } + + EchoResponse response; + std::thread* server_try_cancel_thd = nullptr; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread([context] { internal::ServerTryCancel(context); }); + } + + for (int i = 0; i < server_responses_to_send; i++) { + response.set_message(request->message() + ::ToString(i)); + if (i == server_responses_to_send - 1 && server_coalescing_api != 0) { + writer->WriteLast(response, WriteOptions()); + } else { + writer->Write(response); + } + } + + if (server_try_cancel_thd != nullptr) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + return Status::CANCELLED; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancel(context); + return Status::CANCELLED; + } + + return Status::OK; + } + + Status BidiStream(ServerContext* context, + ServerReaderWriter<EchoResponse, EchoRequest>* stream) { + // If server_try_cancel is set in the metadata, the RPC is cancelled by the + // server by calling ServerContext::TryCancel() depending on the value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server reads/ + // writes any messages from/to the client + // CANCEL_DURING_PROCESSING: The RPC is cancelled while the server is + // reading/writing messages from/to the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server + // reads/writes all messages from/to the client + int server_try_cancel = internal::GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + + EchoRequest request; + EchoResponse response; + + if (server_try_cancel == CANCEL_BEFORE_PROCESSING) { + internal::ServerTryCancel(context); + return Status::CANCELLED; + } + + std::thread* server_try_cancel_thd = nullptr; + if (server_try_cancel == CANCEL_DURING_PROCESSING) { + server_try_cancel_thd = + new std::thread([context] { internal::ServerTryCancel(context); }); + } + + // kServerFinishAfterNReads suggests after how many reads, the server should + // write the last message and send status (coalesced using WriteLast) + int server_write_last = internal::GetIntValueFromMetadata( + kServerFinishAfterNReads, context->client_metadata(), 0); + + int read_counts = 0; + while (stream->Read(&request)) { + read_counts++; + gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); + response.set_message(request.message()); + if (read_counts == server_write_last) { + stream->WriteLast(response, WriteOptions()); + } else { + stream->Write(response); + } + } + + if (server_try_cancel_thd != nullptr) { + server_try_cancel_thd->join(); + delete server_try_cancel_thd; + return Status::CANCELLED; + } + + if (server_try_cancel == CANCEL_AFTER_PROCESSING) { + internal::ServerTryCancel(context); + return Status::CANCELLED; + } + + return Status::OK; + } + + // Unimplemented is left unimplemented to test the returned error. + bool signal_client() { + std::unique_lock<std::mutex> lock(mu_); + return signal_client_; + } + void ClientWaitUntilRpcStarted() { signaller_.ClientWaitUntilRpcStarted(); } + void SignalServerToContinue() { signaller_.SignalServerToContinue(); } + uint64_t RpcsWaitingForClientCancel() { + std::unique_lock<std::mutex> lock(mu_); + return rpcs_waiting_for_client_cancel_; + } + + private: + bool signal_client_; + std::mutex mu_; + TestServiceSignaller signaller_; + std::unique_ptr<TString> host_; + uint64_t rpcs_waiting_for_client_cancel_ = 0; +}; + +class CallbackTestServiceImpl + : public ::grpc::testing::EchoTestService::ExperimentalCallbackService { + public: + CallbackTestServiceImpl() : signal_client_(false), host_() {} + explicit CallbackTestServiceImpl(const TString& host) + : signal_client_(false), host_(new TString(host)) {} + + experimental::ServerUnaryReactor* Echo( + experimental::CallbackServerContext* context, const EchoRequest* request, + EchoResponse* response) override; + + experimental::ServerUnaryReactor* CheckClientInitialMetadata( + experimental::CallbackServerContext* context, const SimpleRequest42*, + SimpleResponse42*) override; + + experimental::ServerReadReactor<EchoRequest>* RequestStream( + experimental::CallbackServerContext* context, + EchoResponse* response) override; + + experimental::ServerWriteReactor<EchoResponse>* ResponseStream( + experimental::CallbackServerContext* context, + const EchoRequest* request) override; + + experimental::ServerBidiReactor<EchoRequest, EchoResponse>* BidiStream( + experimental::CallbackServerContext* context) override; + + // Unimplemented is left unimplemented to test the returned error. + bool signal_client() { + std::unique_lock<std::mutex> lock(mu_); + return signal_client_; + } + void ClientWaitUntilRpcStarted() { signaller_.ClientWaitUntilRpcStarted(); } + void SignalServerToContinue() { signaller_.SignalServerToContinue(); } + + private: + bool signal_client_; + std::mutex mu_; + TestServiceSignaller signaller_; + std::unique_ptr<TString> host_; +}; + +using TestServiceImpl = + TestMultipleServiceImpl<::grpc::testing::EchoTestService::Service>; + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_END2END_TEST_SERVICE_IMPL_H diff --git a/contrib/libs/grpc/test/cpp/end2end/thread/ya.make_ b/contrib/libs/grpc/test/cpp/end2end/thread/ya.make_ new file mode 100644 index 0000000000..afabda1c8f --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/thread/ya.make_ @@ -0,0 +1,31 @@ +GTEST_UGLY() + +OWNER( + dvshkurko + g:ymake +) + +ADDINCL( + ${ARCADIA_ROOT}/contrib/libs/grpc +) + +PEERDIR( + contrib/libs/grpc/src/proto/grpc/core + contrib/libs/grpc/src/proto/grpc/testing + contrib/libs/grpc/src/proto/grpc/testing/duplicate + contrib/libs/grpc/test/core/util + contrib/libs/grpc/test/cpp/end2end + contrib/libs/grpc/test/cpp/util +) + +NO_COMPILER_WARNINGS() + +SRCDIR( + contrib/libs/grpc/test/cpp/end2end +) + +SRCS( + thread_stress_test.cc +) + +END() diff --git a/contrib/libs/grpc/test/cpp/end2end/thread_stress_test.cc b/contrib/libs/grpc/test/cpp/end2end/thread_stress_test.cc new file mode 100644 index 0000000000..8acb953729 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/thread_stress_test.cc @@ -0,0 +1,442 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <cinttypes> +#include <mutex> +#include <thread> + +#include <grpc/grpc.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/impl/codegen/sync.h> +#include <grpcpp/resource_quota.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/surface/api_trace.h" +#include "src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" + +#include <gtest/gtest.h> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; +using std::chrono::system_clock; + +const int kNumThreads = 100; // Number of threads +const int kNumAsyncSendThreads = 2; +const int kNumAsyncReceiveThreads = 50; +const int kNumAsyncServerThreads = 50; +const int kNumRpcs = 1000; // Number of RPCs per thread + +namespace grpc { +namespace testing { + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + TestServiceImpl() {} + + Status Echo(ServerContext* /*context*/, const EchoRequest* request, + EchoResponse* response) override { + response->set_message(request->message()); + return Status::OK; + } +}; + +template <class Service> +class CommonStressTest { + public: + CommonStressTest() : kMaxMessageSize_(8192) { +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + } + virtual ~CommonStressTest() {} + virtual void SetUp() = 0; + virtual void TearDown() = 0; + virtual void ResetStub() = 0; + virtual bool AllowExhaustion() = 0; + grpc::testing::EchoTestService::Stub* GetStub() { return stub_.get(); } + + protected: + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + + virtual void SetUpStart(ServerBuilder* builder, Service* service) = 0; + void SetUpStartCommon(ServerBuilder* builder, Service* service) { + builder->RegisterService(service); + builder->SetMaxMessageSize( + kMaxMessageSize_); // For testing max message size. + } + void SetUpEnd(ServerBuilder* builder) { server_ = builder->BuildAndStart(); } + void TearDownStart() { server_->Shutdown(); } + void TearDownEnd() {} + + private: + const int kMaxMessageSize_; +}; + +template <class Service> +class CommonStressTestInsecure : public CommonStressTest<Service> { + public: + void ResetStub() override { + std::shared_ptr<Channel> channel = grpc::CreateChannel( + server_address_.str(), InsecureChannelCredentials()); + this->stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + bool AllowExhaustion() override { return false; } + + protected: + void SetUpStart(ServerBuilder* builder, Service* service) override { + int port = 5003; // grpc_pick_unused_port_or_die(); + this->server_address_ << "localhost:" << port; + // Setup server + builder->AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + this->SetUpStartCommon(builder, service); + } + + private: + std::ostringstream server_address_; +}; + +template <class Service, bool allow_resource_exhaustion> +class CommonStressTestInproc : public CommonStressTest<Service> { + public: + void ResetStub() override { + ChannelArguments args; + std::shared_ptr<Channel> channel = this->server_->InProcessChannel(args); + this->stub_ = grpc::testing::EchoTestService::NewStub(channel); + } + bool AllowExhaustion() override { return allow_resource_exhaustion; } + + protected: + void SetUpStart(ServerBuilder* builder, Service* service) override { + this->SetUpStartCommon(builder, service); + } +}; + +template <class BaseClass> +class CommonStressTestSyncServer : public BaseClass { + public: + void SetUp() override { + ServerBuilder builder; + this->SetUpStart(&builder, &service_); + this->SetUpEnd(&builder); + } + void TearDown() override { + this->TearDownStart(); + this->TearDownEnd(); + } + + private: + TestServiceImpl service_; +}; + +template <class BaseClass> +class CommonStressTestSyncServerLowThreadCount : public BaseClass { + public: + void SetUp() override { + ServerBuilder builder; + ResourceQuota quota; + this->SetUpStart(&builder, &service_); + quota.SetMaxThreads(4); + builder.SetResourceQuota(quota); + this->SetUpEnd(&builder); + } + void TearDown() override { + this->TearDownStart(); + this->TearDownEnd(); + } + + private: + TestServiceImpl service_; +}; + +template <class BaseClass> +class CommonStressTestAsyncServer : public BaseClass { + public: + CommonStressTestAsyncServer() : contexts_(kNumAsyncServerThreads * 100) {} + void SetUp() override { + shutting_down_ = false; + ServerBuilder builder; + this->SetUpStart(&builder, &service_); + cq_ = builder.AddCompletionQueue(); + this->SetUpEnd(&builder); + for (int i = 0; i < kNumAsyncServerThreads * 100; i++) { + RefreshContext(i); + } + for (int i = 0; i < kNumAsyncServerThreads; i++) { + server_threads_.emplace_back(&CommonStressTestAsyncServer::ProcessRpcs, + this); + } + } + void TearDown() override { + { + grpc::internal::MutexLock l(&mu_); + this->TearDownStart(); + shutting_down_ = true; + cq_->Shutdown(); + } + + for (int i = 0; i < kNumAsyncServerThreads; i++) { + server_threads_[i].join(); + } + + void* ignored_tag; + bool ignored_ok; + while (cq_->Next(&ignored_tag, &ignored_ok)) + ; + this->TearDownEnd(); + } + + private: + void ProcessRpcs() { + void* tag; + bool ok; + while (cq_->Next(&tag, &ok)) { + if (ok) { + int i = static_cast<int>(reinterpret_cast<intptr_t>(tag)); + switch (contexts_[i].state) { + case Context::READY: { + contexts_[i].state = Context::DONE; + EchoResponse send_response; + send_response.set_message(contexts_[i].recv_request.message()); + contexts_[i].response_writer->Finish(send_response, Status::OK, + tag); + break; + } + case Context::DONE: + RefreshContext(i); + break; + } + } + } + } + void RefreshContext(int i) { + grpc::internal::MutexLock l(&mu_); + if (!shutting_down_) { + contexts_[i].state = Context::READY; + contexts_[i].srv_ctx.reset(new ServerContext); + contexts_[i].response_writer.reset( + new grpc::ServerAsyncResponseWriter<EchoResponse>( + contexts_[i].srv_ctx.get())); + service_.RequestEcho(contexts_[i].srv_ctx.get(), + &contexts_[i].recv_request, + contexts_[i].response_writer.get(), cq_.get(), + cq_.get(), (void*)static_cast<intptr_t>(i)); + } + } + struct Context { + std::unique_ptr<ServerContext> srv_ctx; + std::unique_ptr<grpc::ServerAsyncResponseWriter<EchoResponse>> + response_writer; + EchoRequest recv_request; + enum { READY, DONE } state; + }; + std::vector<Context> contexts_; + ::grpc::testing::EchoTestService::AsyncService service_; + std::unique_ptr<ServerCompletionQueue> cq_; + bool shutting_down_; + grpc::internal::Mutex mu_; + std::vector<std::thread> server_threads_; +}; + +template <class Common> +class End2endTest : public ::testing::Test { + protected: + End2endTest() {} + void SetUp() override { common_.SetUp(); } + void TearDown() override { common_.TearDown(); } + void ResetStub() { common_.ResetStub(); } + + Common common_; +}; + +static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs, + bool allow_exhaustion, gpr_atm* errors) { + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + for (int i = 0; i < num_rpcs; ++i) { + ClientContext context; + Status s = stub->Echo(&context, request, &response); + EXPECT_TRUE(s.ok() || (allow_exhaustion && + s.error_code() == StatusCode::RESOURCE_EXHAUSTED)); + if (!s.ok()) { + if (!(allow_exhaustion && + s.error_code() == StatusCode::RESOURCE_EXHAUSTED)) { + gpr_log(GPR_ERROR, "RPC error: %d: %s", s.error_code(), + s.error_message().c_str()); + } + gpr_atm_no_barrier_fetch_add(errors, static_cast<gpr_atm>(1)); + } else { + EXPECT_EQ(response.message(), request.message()); + } + } +} + +typedef ::testing::Types< + CommonStressTestSyncServer<CommonStressTestInsecure<TestServiceImpl>>, + CommonStressTestSyncServer<CommonStressTestInproc<TestServiceImpl, false>>, + CommonStressTestSyncServerLowThreadCount< + CommonStressTestInproc<TestServiceImpl, true>>, + CommonStressTestAsyncServer< + CommonStressTestInsecure<grpc::testing::EchoTestService::AsyncService>>, + CommonStressTestAsyncServer<CommonStressTestInproc< + grpc::testing::EchoTestService::AsyncService, false>>> + CommonTypes; +TYPED_TEST_SUITE(End2endTest, CommonTypes); +TYPED_TEST(End2endTest, ThreadStress) { + this->common_.ResetStub(); + std::vector<std::thread> threads; + gpr_atm errors; + gpr_atm_rel_store(&errors, static_cast<gpr_atm>(0)); + threads.reserve(kNumThreads); + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back(SendRpc, this->common_.GetStub(), kNumRpcs, + this->common_.AllowExhaustion(), &errors); + } + for (int i = 0; i < kNumThreads; ++i) { + threads[i].join(); + } + uint64_t error_cnt = static_cast<uint64_t>(gpr_atm_no_barrier_load(&errors)); + if (error_cnt != 0) { + gpr_log(GPR_INFO, "RPC error count: %" PRIu64, error_cnt); + } + // If this test allows resource exhaustion, expect that it actually sees some + if (this->common_.AllowExhaustion()) { + EXPECT_GT(error_cnt, static_cast<uint64_t>(0)); + } +} + +template <class Common> +class AsyncClientEnd2endTest : public ::testing::Test { + protected: + AsyncClientEnd2endTest() : rpcs_outstanding_(0) {} + + void SetUp() override { common_.SetUp(); } + void TearDown() override { + void* ignored_tag; + bool ignored_ok; + while (cq_.Next(&ignored_tag, &ignored_ok)) + ; + common_.TearDown(); + } + + void Wait() { + grpc::internal::MutexLock l(&mu_); + while (rpcs_outstanding_ != 0) { + cv_.Wait(&mu_); + } + + cq_.Shutdown(); + } + + struct AsyncClientCall { + EchoResponse response; + ClientContext context; + Status status; + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader; + }; + + void AsyncSendRpc(int num_rpcs) { + for (int i = 0; i < num_rpcs; ++i) { + AsyncClientCall* call = new AsyncClientCall; + EchoRequest request; + request.set_message(TString("Hello: " + grpc::to_string(i)).c_str()); + call->response_reader = + common_.GetStub()->AsyncEcho(&call->context, request, &cq_); + call->response_reader->Finish(&call->response, &call->status, + (void*)call); + + grpc::internal::MutexLock l(&mu_); + rpcs_outstanding_++; + } + } + + void AsyncCompleteRpc() { + while (true) { + void* got_tag; + bool ok = false; + if (!cq_.Next(&got_tag, &ok)) break; + AsyncClientCall* call = static_cast<AsyncClientCall*>(got_tag); + if (!ok) { + gpr_log(GPR_DEBUG, "Error: %d", call->status.error_code()); + } + delete call; + + bool notify; + { + grpc::internal::MutexLock l(&mu_); + rpcs_outstanding_--; + notify = (rpcs_outstanding_ == 0); + } + if (notify) { + cv_.Signal(); + } + } + } + + Common common_; + CompletionQueue cq_; + grpc::internal::Mutex mu_; + grpc::internal::CondVar cv_; + int rpcs_outstanding_; +}; + +TYPED_TEST_SUITE(AsyncClientEnd2endTest, CommonTypes); +TYPED_TEST(AsyncClientEnd2endTest, ThreadStress) { + this->common_.ResetStub(); + std::vector<std::thread> send_threads, completion_threads; + for (int i = 0; i < kNumAsyncReceiveThreads; ++i) { + completion_threads.emplace_back( + &AsyncClientEnd2endTest_ThreadStress_Test<TypeParam>::AsyncCompleteRpc, + this); + } + for (int i = 0; i < kNumAsyncSendThreads; ++i) { + send_threads.emplace_back( + &AsyncClientEnd2endTest_ThreadStress_Test<TypeParam>::AsyncSendRpc, + this, kNumRpcs); + } + for (int i = 0; i < kNumAsyncSendThreads; ++i) { + send_threads[i].join(); + } + + this->Wait(); + for (int i = 0; i < kNumAsyncReceiveThreads; ++i) { + completion_threads[i].join(); + } +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/end2end/time_change_test.cc b/contrib/libs/grpc/test/cpp/end2end/time_change_test.cc new file mode 100644 index 0000000000..48b9eace12 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/time_change_test.cc @@ -0,0 +1,367 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/core/lib/iomgr/timer.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/subprocess.h" + +#include <gtest/gtest.h> +#include <sys/time.h> +#include <thread> + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +static TString g_root; + +static gpr_mu g_mu; +extern gpr_timespec (*gpr_now_impl)(gpr_clock_type clock_type); +gpr_timespec (*gpr_now_impl_orig)(gpr_clock_type clock_type) = gpr_now_impl; +static int g_time_shift_sec = 0; +static int g_time_shift_nsec = 0; +static gpr_timespec now_impl(gpr_clock_type clock) { + auto ts = gpr_now_impl_orig(clock); + // We only manipulate the realtime clock to simulate changes in wall-clock + // time + if (clock != GPR_CLOCK_REALTIME) { + return ts; + } + GPR_ASSERT(ts.tv_nsec >= 0); + GPR_ASSERT(ts.tv_nsec < GPR_NS_PER_SEC); + gpr_mu_lock(&g_mu); + ts.tv_sec += g_time_shift_sec; + ts.tv_nsec += g_time_shift_nsec; + gpr_mu_unlock(&g_mu); + if (ts.tv_nsec >= GPR_NS_PER_SEC) { + ts.tv_nsec -= GPR_NS_PER_SEC; + ++ts.tv_sec; + } else if (ts.tv_nsec < 0) { + --ts.tv_sec; + ts.tv_nsec = GPR_NS_PER_SEC + ts.tv_nsec; + } + return ts; +} + +// offset the value returned by gpr_now(GPR_CLOCK_REALTIME) by msecs +// milliseconds +static void set_now_offset(int msecs) { + gpr_mu_lock(&g_mu); + g_time_shift_sec = msecs / 1000; + g_time_shift_nsec = (msecs % 1000) * 1e6; + gpr_mu_unlock(&g_mu); +} + +// restore the original implementation of gpr_now() +static void reset_now_offset() { + gpr_mu_lock(&g_mu); + g_time_shift_sec = 0; + g_time_shift_nsec = 0; + gpr_mu_unlock(&g_mu); +} + +namespace grpc { +namespace testing { + +namespace { + +// gpr_now() is called with invalid clock_type +TEST(TimespecTest, GprNowInvalidClockType) { + // initialize to some junk value + gpr_clock_type invalid_clock_type = (gpr_clock_type)32641; + EXPECT_DEATH(gpr_now(invalid_clock_type), ".*"); +} + +// Add timespan with negative nanoseconds +TEST(TimespecTest, GprTimeAddNegativeNs) { + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + gpr_timespec bad_ts = {1, -1000, GPR_TIMESPAN}; + EXPECT_DEATH(gpr_time_add(now, bad_ts), ".*"); +} + +// Subtract timespan with negative nanoseconds +TEST(TimespecTest, GprTimeSubNegativeNs) { + // Nanoseconds must always be positive. Negative timestamps are represented by + // (negative seconds, positive nanoseconds) + gpr_timespec now = gpr_now(GPR_CLOCK_MONOTONIC); + gpr_timespec bad_ts = {1, -1000, GPR_TIMESPAN}; + EXPECT_DEATH(gpr_time_sub(now, bad_ts), ".*"); +} + +// Add negative milliseconds to gpr_timespec +TEST(TimespecTest, GrpcNegativeMillisToTimespec) { + // -1500 milliseconds converts to timespec (-2 secs, 5 * 10^8 nsec) + gpr_timespec ts = grpc_millis_to_timespec(-1500, GPR_CLOCK_MONOTONIC); + GPR_ASSERT(ts.tv_sec = -2); + GPR_ASSERT(ts.tv_nsec = 5e8); + GPR_ASSERT(ts.clock_type == GPR_CLOCK_MONOTONIC); +} + +class TimeChangeTest : public ::testing::Test { + protected: + TimeChangeTest() {} + + static void SetUpTestCase() { + auto port = grpc_pick_unused_port_or_die(); + std::ostringstream addr_stream; + addr_stream << "localhost:" << port; + server_address_ = addr_stream.str(); + server_.reset(new SubProcess({ + g_root + "/client_crash_test_server", + "--address=" + server_address_, + })); + GPR_ASSERT(server_); + // connect to server and make sure it's reachable. + auto channel = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + GPR_ASSERT(channel); + EXPECT_TRUE(channel->WaitForConnected( + grpc_timeout_milliseconds_to_deadline(30000))); + } + + static void TearDownTestCase() { server_.reset(); } + + void SetUp() { + channel_ = + grpc::CreateChannel(server_address_, InsecureChannelCredentials()); + GPR_ASSERT(channel_); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + void TearDown() { reset_now_offset(); } + + std::unique_ptr<grpc::testing::EchoTestService::Stub> CreateStub() { + return grpc::testing::EchoTestService::NewStub(channel_); + } + + std::shared_ptr<Channel> GetChannel() { return channel_; } + // time jump offsets in milliseconds + const int TIME_OFFSET1 = 20123; + const int TIME_OFFSET2 = 5678; + + private: + static TString server_address_; + static std::unique_ptr<SubProcess> server_; + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; +}; +TString TimeChangeTest::server_address_; +std::unique_ptr<SubProcess> TimeChangeTest::server_; + +// Wall-clock time jumps forward on client before bidi stream is created +TEST_F(TimeChangeTest, TimeJumpForwardBeforeStreamCreated) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "1"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + + // time jumps forward by TIME_OFFSET1 milliseconds + set_now_offset(TIME_OFFSET1); + auto stream = stub->BidiStream(&context); + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + + EXPECT_TRUE(stream->WritesDone()); + EXPECT_TRUE(stream->Read(&response)); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +// Wall-clock time jumps back on client before bidi stream is created +TEST_F(TimeChangeTest, TimeJumpBackBeforeStreamCreated) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "1"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + + // time jumps back by TIME_OFFSET1 milliseconds + set_now_offset(-TIME_OFFSET1); + auto stream = stub->BidiStream(&context); + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + + EXPECT_TRUE(stream->WritesDone()); + EXPECT_TRUE(stream->Read(&response)); + EXPECT_EQ(request.message(), response.message()); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +// Wall-clock time jumps forward on client while call is in progress +TEST_F(TimeChangeTest, TimeJumpForwardAfterStreamCreated) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "2"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + + // time jumps forward by TIME_OFFSET1 milliseconds. + set_now_offset(TIME_OFFSET1); + + request.set_message("World"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->WritesDone()); + EXPECT_TRUE(stream->Read(&response)); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +// Wall-clock time jumps back on client while call is in progress +TEST_F(TimeChangeTest, TimeJumpBackAfterStreamCreated) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "2"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->Read(&response)); + + // time jumps back TIME_OFFSET1 milliseconds. + set_now_offset(-TIME_OFFSET1); + + request.set_message("World"); + EXPECT_TRUE(stream->Write(request)); + EXPECT_TRUE(stream->WritesDone()); + EXPECT_TRUE(stream->Read(&response)); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +// Wall-clock time jumps forward and backwards during call +TEST_F(TimeChangeTest, TimeJumpForwardAndBackDuringCall) { + EchoRequest request; + EchoResponse response; + ClientContext context; + context.set_deadline(grpc_timeout_milliseconds_to_deadline(5000)); + context.AddMetadata(kServerResponseStreamsToSend, "2"); + + auto channel = GetChannel(); + GPR_ASSERT(channel); + + EXPECT_TRUE( + channel->WaitForConnected(grpc_timeout_milliseconds_to_deadline(5000))); + auto stub = CreateStub(); + auto stream = stub->BidiStream(&context); + + request.set_message("Hello"); + EXPECT_TRUE(stream->Write(request)); + + // time jumps back by TIME_OFFSET2 milliseconds + set_now_offset(-TIME_OFFSET2); + + EXPECT_TRUE(stream->Read(&response)); + request.set_message("World"); + + // time jumps forward by TIME_OFFSET milliseconds + set_now_offset(TIME_OFFSET1); + + EXPECT_TRUE(stream->Write(request)); + + // time jumps back by TIME_OFFSET2 milliseconds + set_now_offset(-TIME_OFFSET2); + + EXPECT_TRUE(stream->WritesDone()); + + // time jumps back by TIME_OFFSET2 milliseconds + set_now_offset(-TIME_OFFSET2); + + EXPECT_TRUE(stream->Read(&response)); + + // time jumps back by TIME_OFFSET2 milliseconds + set_now_offset(-TIME_OFFSET2); + + auto status = stream->Finish(); + EXPECT_TRUE(status.ok()); +} + +} // namespace + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + TString me = argv[0]; + // get index of last slash in path to test binary + auto lslash = me.rfind('/'); + // set g_root = path to directory containing test binary + if (lslash != TString::npos) { + g_root = me.substr(0, lslash); + } else { + g_root = "."; + } + + gpr_mu_init(&g_mu); + gpr_now_impl = now_impl; + + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + auto ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/xds_end2end_test.cc b/contrib/libs/grpc/test/cpp/end2end/xds_end2end_test.cc new file mode 100644 index 0000000000..603e6186bf --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/xds_end2end_test.cc @@ -0,0 +1,5832 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <deque> +#include <memory> +#include <mutex> +#include <numeric> +#include <set> +#include <sstream> +#include <util/generic/string.h> +#include <thread> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "y_absl/strings/str_cat.h" +#include "y_absl/types/optional.h" + +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/time.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/xds/xds_api.h" +#include "src/core/ext/xds/xds_channel_args.h" +#include "src/core/ext/xds/xds_client.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/map.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/iomgr/parse_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/security/credentials/fake/fake_credentials.h" +#include "src/cpp/client/secure_credentials.h" +#include "src/cpp/server/secure_server_credentials.h" + +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/test_service_impl.h" + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "src/proto/grpc/testing/xds/ads_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/cds_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/eds_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/lds_rds_for_test.grpc.pb.h" +#include "src/proto/grpc/testing/xds/lrs_for_test.grpc.pb.h" + +#include "src/proto/grpc/testing/xds/v3/ads.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/cluster.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/discovery.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/endpoint.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/http_connection_manager.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/listener.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/lrs.grpc.pb.h" +#include "src/proto/grpc/testing/xds/v3/route.grpc.pb.h" + +namespace grpc { +namespace testing { +namespace { + +using std::chrono::system_clock; + +using ::envoy::config::cluster::v3::CircuitBreakers; +using ::envoy::config::cluster::v3::Cluster; +using ::envoy::config::cluster::v3::RoutingPriority; +using ::envoy::config::endpoint::v3::ClusterLoadAssignment; +using ::envoy::config::endpoint::v3::HealthStatus; +using ::envoy::config::listener::v3::Listener; +using ::envoy::config::route::v3::RouteConfiguration; +using ::envoy::extensions::filters::network::http_connection_manager::v3:: + HttpConnectionManager; +using ::envoy::type::v3::FractionalPercent; + +constexpr char kLdsTypeUrl[] = + "type.googleapis.com/envoy.config.listener.v3.Listener"; +constexpr char kRdsTypeUrl[] = + "type.googleapis.com/envoy.config.route.v3.RouteConfiguration"; +constexpr char kCdsTypeUrl[] = + "type.googleapis.com/envoy.config.cluster.v3.Cluster"; +constexpr char kEdsTypeUrl[] = + "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment"; + +constexpr char kLdsV2TypeUrl[] = "type.googleapis.com/envoy.api.v2.Listener"; +constexpr char kRdsV2TypeUrl[] = + "type.googleapis.com/envoy.api.v2.RouteConfiguration"; +constexpr char kCdsV2TypeUrl[] = "type.googleapis.com/envoy.api.v2.Cluster"; +constexpr char kEdsV2TypeUrl[] = + "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment"; + +constexpr char kDefaultLocalityRegion[] = "xds_default_locality_region"; +constexpr char kDefaultLocalityZone[] = "xds_default_locality_zone"; +constexpr char kLbDropType[] = "lb"; +constexpr char kThrottleDropType[] = "throttle"; +constexpr char kServerName[] = "server.example.com"; +constexpr char kDefaultRouteConfigurationName[] = "route_config_name"; +constexpr char kDefaultClusterName[] = "cluster_name"; +constexpr char kDefaultEdsServiceName[] = "eds_service_name"; +constexpr int kDefaultLocalityWeight = 3; +constexpr int kDefaultLocalityPriority = 0; + +constexpr char kRequestMessage[] = "Live long and prosper."; +constexpr char kDefaultServiceConfig[] = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"eds_experimental\":{\n" + " \"clusterName\": \"server.example.com\",\n" + " \"lrsLoadReportingServerName\": \"\"\n" + " } }\n" + " ]\n" + "}"; +constexpr char kDefaultServiceConfigWithoutLoadReporting[] = + "{\n" + " \"loadBalancingConfig\":[\n" + " { \"does_not_exist\":{} },\n" + " { \"eds_experimental\":{\n" + " \"clusterName\": \"server.example.com\"\n" + " } }\n" + " ]\n" + "}"; + +constexpr char kBootstrapFileV3[] = + "{\n" + " \"xds_servers\": [\n" + " {\n" + " \"server_uri\": \"fake:///xds_server\",\n" + " \"channel_creds\": [\n" + " {\n" + " \"type\": \"fake\"\n" + " }\n" + " ],\n" + " \"server_features\": [\"xds_v3\"]\n" + " }\n" + " ],\n" + " \"node\": {\n" + " \"id\": \"xds_end2end_test\",\n" + " \"cluster\": \"test\",\n" + " \"metadata\": {\n" + " \"foo\": \"bar\"\n" + " },\n" + " \"locality\": {\n" + " \"region\": \"corp\",\n" + " \"zone\": \"svl\",\n" + " \"subzone\": \"mp3\"\n" + " }\n" + " }\n" + "}\n"; + +constexpr char kBootstrapFileV2[] = + "{\n" + " \"xds_servers\": [\n" + " {\n" + " \"server_uri\": \"fake:///xds_server\",\n" + " \"channel_creds\": [\n" + " {\n" + " \"type\": \"fake\"\n" + " }\n" + " ]\n" + " }\n" + " ],\n" + " \"node\": {\n" + " \"id\": \"xds_end2end_test\",\n" + " \"cluster\": \"test\",\n" + " \"metadata\": {\n" + " \"foo\": \"bar\"\n" + " },\n" + " \"locality\": {\n" + " \"region\": \"corp\",\n" + " \"zone\": \"svl\",\n" + " \"subzone\": \"mp3\"\n" + " }\n" + " }\n" + "}\n"; + +char* g_bootstrap_file_v3; +char* g_bootstrap_file_v2; + +void WriteBootstrapFiles() { + char* bootstrap_file; + FILE* out = gpr_tmpfile("xds_bootstrap_v3", &bootstrap_file); + fputs(kBootstrapFileV3, out); + fclose(out); + g_bootstrap_file_v3 = bootstrap_file; + out = gpr_tmpfile("xds_bootstrap_v2", &bootstrap_file); + fputs(kBootstrapFileV2, out); + fclose(out); + g_bootstrap_file_v2 = bootstrap_file; +} + +// Helper class to minimize the number of unique ports we use for this test. +class PortSaver { + public: + int GetPort() { + if (idx_ >= ports_.size()) { + ports_.push_back(grpc_pick_unused_port_or_die()); + } + return ports_[idx_++]; + } + + void Reset() { idx_ = 0; } + + private: + std::vector<int> ports_; + size_t idx_ = 0; +}; + +PortSaver* g_port_saver = nullptr; + +template <typename ServiceType> +class CountedService : public ServiceType { + public: + size_t request_count() { + grpc_core::MutexLock lock(&mu_); + return request_count_; + } + + size_t response_count() { + grpc_core::MutexLock lock(&mu_); + return response_count_; + } + + void IncreaseResponseCount() { + grpc_core::MutexLock lock(&mu_); + ++response_count_; + } + void IncreaseRequestCount() { + grpc_core::MutexLock lock(&mu_); + ++request_count_; + } + + void ResetCounters() { + grpc_core::MutexLock lock(&mu_); + request_count_ = 0; + response_count_ = 0; + } + + private: + grpc_core::Mutex mu_; + size_t request_count_ = 0; + size_t response_count_ = 0; +}; + +const char g_kCallCredsMdKey[] = "Balancer should not ..."; +const char g_kCallCredsMdValue[] = "... receive me"; + +template <typename RpcService> +class BackendServiceImpl + : public CountedService<TestMultipleServiceImpl<RpcService>> { + public: + BackendServiceImpl() {} + + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + // Backend should receive the call credentials metadata. + auto call_credentials_entry = + context->client_metadata().find(g_kCallCredsMdKey); + EXPECT_NE(call_credentials_entry, context->client_metadata().end()); + if (call_credentials_entry != context->client_metadata().end()) { + EXPECT_EQ(call_credentials_entry->second, g_kCallCredsMdValue); + } + CountedService<TestMultipleServiceImpl<RpcService>>::IncreaseRequestCount(); + const auto status = + TestMultipleServiceImpl<RpcService>::Echo(context, request, response); + CountedService< + TestMultipleServiceImpl<RpcService>>::IncreaseResponseCount(); + AddClient(context->peer()); + return status; + } + + Status Echo1(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + return Echo(context, request, response); + } + + Status Echo2(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + return Echo(context, request, response); + } + + void Start() {} + void Shutdown() {} + + std::set<TString> clients() { + grpc_core::MutexLock lock(&clients_mu_); + return clients_; + } + + private: + void AddClient(const TString& client) { + grpc_core::MutexLock lock(&clients_mu_); + clients_.insert(client); + } + + grpc_core::Mutex clients_mu_; + std::set<TString> clients_; +}; + +class ClientStats { + public: + struct LocalityStats { + LocalityStats() {} + + // Converts from proto message class. + template <class UpstreamLocalityStats> + LocalityStats(const UpstreamLocalityStats& upstream_locality_stats) + : total_successful_requests( + upstream_locality_stats.total_successful_requests()), + total_requests_in_progress( + upstream_locality_stats.total_requests_in_progress()), + total_error_requests(upstream_locality_stats.total_error_requests()), + total_issued_requests( + upstream_locality_stats.total_issued_requests()) {} + + LocalityStats& operator+=(const LocalityStats& other) { + total_successful_requests += other.total_successful_requests; + total_requests_in_progress += other.total_requests_in_progress; + total_error_requests += other.total_error_requests; + total_issued_requests += other.total_issued_requests; + return *this; + } + + uint64_t total_successful_requests = 0; + uint64_t total_requests_in_progress = 0; + uint64_t total_error_requests = 0; + uint64_t total_issued_requests = 0; + }; + + ClientStats() {} + + // Converts from proto message class. + template <class ClusterStats> + explicit ClientStats(const ClusterStats& cluster_stats) + : cluster_name_(cluster_stats.cluster_name()), + total_dropped_requests_(cluster_stats.total_dropped_requests()) { + for (const auto& input_locality_stats : + cluster_stats.upstream_locality_stats()) { + locality_stats_.emplace(input_locality_stats.locality().sub_zone(), + LocalityStats(input_locality_stats)); + } + for (const auto& input_dropped_requests : + cluster_stats.dropped_requests()) { + dropped_requests_.emplace(input_dropped_requests.category(), + input_dropped_requests.dropped_count()); + } + } + + const TString& cluster_name() const { return cluster_name_; } + + const std::map<TString, LocalityStats>& locality_stats() const { + return locality_stats_; + } + uint64_t total_successful_requests() const { + uint64_t sum = 0; + for (auto& p : locality_stats_) { + sum += p.second.total_successful_requests; + } + return sum; + } + uint64_t total_requests_in_progress() const { + uint64_t sum = 0; + for (auto& p : locality_stats_) { + sum += p.second.total_requests_in_progress; + } + return sum; + } + uint64_t total_error_requests() const { + uint64_t sum = 0; + for (auto& p : locality_stats_) { + sum += p.second.total_error_requests; + } + return sum; + } + uint64_t total_issued_requests() const { + uint64_t sum = 0; + for (auto& p : locality_stats_) { + sum += p.second.total_issued_requests; + } + return sum; + } + + uint64_t total_dropped_requests() const { return total_dropped_requests_; } + + uint64_t dropped_requests(const TString& category) const { + auto iter = dropped_requests_.find(category); + GPR_ASSERT(iter != dropped_requests_.end()); + return iter->second; + } + + ClientStats& operator+=(const ClientStats& other) { + for (const auto& p : other.locality_stats_) { + locality_stats_[p.first] += p.second; + } + total_dropped_requests_ += other.total_dropped_requests_; + for (const auto& p : other.dropped_requests_) { + dropped_requests_[p.first] += p.second; + } + return *this; + } + + private: + TString cluster_name_; + std::map<TString, LocalityStats> locality_stats_; + uint64_t total_dropped_requests_ = 0; + std::map<TString, uint64_t> dropped_requests_; +}; + +class AdsServiceImpl : public std::enable_shared_from_this<AdsServiceImpl> { + public: + struct ResponseState { + enum State { NOT_SENT, SENT, ACKED, NACKED }; + State state = NOT_SENT; + TString error_message; + }; + + struct EdsResourceArgs { + struct Locality { + Locality(const TString& sub_zone, std::vector<int> ports, + int lb_weight = kDefaultLocalityWeight, + int priority = kDefaultLocalityPriority, + std::vector<HealthStatus> health_statuses = {}) + : sub_zone(std::move(sub_zone)), + ports(std::move(ports)), + lb_weight(lb_weight), + priority(priority), + health_statuses(std::move(health_statuses)) {} + + const TString sub_zone; + std::vector<int> ports; + int lb_weight; + int priority; + std::vector<HealthStatus> health_statuses; + }; + + EdsResourceArgs() = default; + explicit EdsResourceArgs(std::vector<Locality> locality_list) + : locality_list(std::move(locality_list)) {} + + std::vector<Locality> locality_list; + std::map<TString, uint32_t> drop_categories; + FractionalPercent::DenominatorType drop_denominator = + FractionalPercent::MILLION; + }; + + explicit AdsServiceImpl(bool enable_load_reporting) + : v2_rpc_service_(this, /*is_v2=*/true), + v3_rpc_service_(this, /*is_v2=*/false) { + // Construct RDS response data. + default_route_config_.set_name(kDefaultRouteConfigurationName); + auto* virtual_host = default_route_config_.add_virtual_hosts(); + virtual_host->add_domains("*"); + auto* route = virtual_host->add_routes(); + route->mutable_match()->set_prefix(""); + route->mutable_route()->set_cluster(kDefaultClusterName); + SetRdsResource(default_route_config_); + // Construct LDS response data (with inlined RDS result). + default_listener_ = BuildListener(default_route_config_); + SetLdsResource(default_listener_); + // Construct CDS response data. + default_cluster_.set_name(kDefaultClusterName); + default_cluster_.set_type(Cluster::EDS); + auto* eds_config = default_cluster_.mutable_eds_cluster_config(); + eds_config->mutable_eds_config()->mutable_ads(); + eds_config->set_service_name(kDefaultEdsServiceName); + default_cluster_.set_lb_policy(Cluster::ROUND_ROBIN); + if (enable_load_reporting) { + default_cluster_.mutable_lrs_server()->mutable_self(); + } + SetCdsResource(default_cluster_); + } + + bool seen_v2_client() const { return seen_v2_client_; } + bool seen_v3_client() const { return seen_v3_client_; } + + ::envoy::service::discovery::v2::AggregatedDiscoveryService::Service* + v2_rpc_service() { + return &v2_rpc_service_; + } + + ::envoy::service::discovery::v3::AggregatedDiscoveryService::Service* + v3_rpc_service() { + return &v3_rpc_service_; + } + + Listener default_listener() const { return default_listener_; } + RouteConfiguration default_route_config() const { + return default_route_config_; + } + Cluster default_cluster() const { return default_cluster_; } + + ResponseState lds_response_state() { + grpc_core::MutexLock lock(&ads_mu_); + return resource_type_response_state_[kLdsTypeUrl]; + } + + ResponseState rds_response_state() { + grpc_core::MutexLock lock(&ads_mu_); + return resource_type_response_state_[kRdsTypeUrl]; + } + + ResponseState cds_response_state() { + grpc_core::MutexLock lock(&ads_mu_); + return resource_type_response_state_[kCdsTypeUrl]; + } + + ResponseState eds_response_state() { + grpc_core::MutexLock lock(&ads_mu_); + return resource_type_response_state_[kEdsTypeUrl]; + } + + void SetResourceIgnore(const TString& type_url) { + grpc_core::MutexLock lock(&ads_mu_); + resource_types_to_ignore_.emplace(type_url); + } + + void UnsetResource(const TString& type_url, const TString& name) { + grpc_core::MutexLock lock(&ads_mu_); + ResourceState& state = resource_map_[type_url][name]; + ++state.version; + state.resource.reset(); + gpr_log(GPR_INFO, "ADS[%p]: Unsetting %s resource %s to version %u", this, + type_url.c_str(), name.c_str(), state.version); + for (SubscriptionState* subscription : state.subscriptions) { + subscription->update_queue->emplace_back(type_url, name); + } + } + + void SetResource(google::protobuf::Any resource, const TString& type_url, + const TString& name) { + grpc_core::MutexLock lock(&ads_mu_); + ResourceState& state = resource_map_[type_url][name]; + ++state.version; + state.resource = std::move(resource); + gpr_log(GPR_INFO, "ADS[%p]: Updating %s resource %s to version %u", this, + type_url.c_str(), name.c_str(), state.version); + for (SubscriptionState* subscription : state.subscriptions) { + subscription->update_queue->emplace_back(type_url, name); + } + } + + void SetLdsResource(const Listener& listener) { + google::protobuf::Any resource; + resource.PackFrom(listener); + SetResource(std::move(resource), kLdsTypeUrl, listener.name()); + } + + void SetRdsResource(const RouteConfiguration& route) { + google::protobuf::Any resource; + resource.PackFrom(route); + SetResource(std::move(resource), kRdsTypeUrl, route.name()); + } + + void SetCdsResource(const Cluster& cluster) { + google::protobuf::Any resource; + resource.PackFrom(cluster); + SetResource(std::move(resource), kCdsTypeUrl, cluster.name()); + } + + void SetEdsResource(const ClusterLoadAssignment& assignment) { + google::protobuf::Any resource; + resource.PackFrom(assignment); + SetResource(std::move(resource), kEdsTypeUrl, assignment.cluster_name()); + } + + void SetLdsToUseDynamicRds() { + auto listener = default_listener_; + HttpConnectionManager http_connection_manager; + auto* rds = http_connection_manager.mutable_rds(); + rds->set_route_config_name(kDefaultRouteConfigurationName); + rds->mutable_config_source()->mutable_ads(); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + SetLdsResource(listener); + } + + static Listener BuildListener(const RouteConfiguration& route_config) { + HttpConnectionManager http_connection_manager; + *(http_connection_manager.mutable_route_config()) = route_config; + Listener listener; + listener.set_name(kServerName); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + return listener; + } + + static ClusterLoadAssignment BuildEdsResource( + const EdsResourceArgs& args, + const char* eds_service_name = kDefaultEdsServiceName) { + ClusterLoadAssignment assignment; + assignment.set_cluster_name(eds_service_name); + for (const auto& locality : args.locality_list) { + auto* endpoints = assignment.add_endpoints(); + endpoints->mutable_load_balancing_weight()->set_value(locality.lb_weight); + endpoints->set_priority(locality.priority); + endpoints->mutable_locality()->set_region(kDefaultLocalityRegion); + endpoints->mutable_locality()->set_zone(kDefaultLocalityZone); + endpoints->mutable_locality()->set_sub_zone(locality.sub_zone); + for (size_t i = 0; i < locality.ports.size(); ++i) { + const int& port = locality.ports[i]; + auto* lb_endpoints = endpoints->add_lb_endpoints(); + if (locality.health_statuses.size() > i && + locality.health_statuses[i] != HealthStatus::UNKNOWN) { + lb_endpoints->set_health_status(locality.health_statuses[i]); + } + auto* endpoint = lb_endpoints->mutable_endpoint(); + auto* address = endpoint->mutable_address(); + auto* socket_address = address->mutable_socket_address(); + socket_address->set_address("127.0.0.1"); + socket_address->set_port_value(port); + } + } + if (!args.drop_categories.empty()) { + auto* policy = assignment.mutable_policy(); + for (const auto& p : args.drop_categories) { + const TString& name = p.first; + const uint32_t parts_per_million = p.second; + auto* drop_overload = policy->add_drop_overloads(); + drop_overload->set_category(name); + auto* drop_percentage = drop_overload->mutable_drop_percentage(); + drop_percentage->set_numerator(parts_per_million); + drop_percentage->set_denominator(args.drop_denominator); + } + } + return assignment; + } + + void Start() { + grpc_core::MutexLock lock(&ads_mu_); + ads_done_ = false; + } + + void Shutdown() { + { + grpc_core::MutexLock lock(&ads_mu_); + NotifyDoneWithAdsCallLocked(); + resource_type_response_state_.clear(); + } + gpr_log(GPR_INFO, "ADS[%p]: shut down", this); + } + + void NotifyDoneWithAdsCall() { + grpc_core::MutexLock lock(&ads_mu_); + NotifyDoneWithAdsCallLocked(); + } + + void NotifyDoneWithAdsCallLocked() { + if (!ads_done_) { + ads_done_ = true; + ads_cond_.Broadcast(); + } + } + + std::set<TString> clients() { + grpc_core::MutexLock lock(&clients_mu_); + return clients_; + } + + private: + // A queue of resource type/name pairs that have changed since the client + // subscribed to them. + using UpdateQueue = std::deque< + std::pair<TString /* type url */, TString /* resource name */>>; + + // A struct representing a client's subscription to a particular resource. + struct SubscriptionState { + // Version that the client currently knows about. + int current_version = 0; + // The queue upon which to place updates when the resource is updated. + UpdateQueue* update_queue; + }; + + // A struct representing the a client's subscription to all the resources. + using SubscriptionNameMap = + std::map<TString /* resource_name */, SubscriptionState>; + using SubscriptionMap = + std::map<TString /* type_url */, SubscriptionNameMap>; + + // A struct representing the current state for a resource: + // - the version of the resource that is set by the SetResource() methods. + // - a list of subscriptions interested in this resource. + struct ResourceState { + int version = 0; + y_absl::optional<google::protobuf::Any> resource; + std::set<SubscriptionState*> subscriptions; + }; + + // A struct representing the current state for all resources: + // LDS, CDS, EDS, and RDS for the class as a whole. + using ResourceNameMap = + std::map<TString /* resource_name */, ResourceState>; + using ResourceMap = std::map<TString /* type_url */, ResourceNameMap>; + + template <class RpcApi, class DiscoveryRequest, class DiscoveryResponse> + class RpcService : public RpcApi::Service { + public: + using Stream = ServerReaderWriter<DiscoveryResponse, DiscoveryRequest>; + + RpcService(AdsServiceImpl* parent, bool is_v2) + : parent_(parent), is_v2_(is_v2) {} + + Status StreamAggregatedResources(ServerContext* context, + Stream* stream) override { + gpr_log(GPR_INFO, "ADS[%p]: StreamAggregatedResources starts", this); + parent_->AddClient(context->peer()); + if (is_v2_) { + parent_->seen_v2_client_ = true; + } else { + parent_->seen_v3_client_ = true; + } + // Resources (type/name pairs) that have changed since the client + // subscribed to them. + UpdateQueue update_queue; + // Resources that the client will be subscribed to keyed by resource type + // url. + SubscriptionMap subscription_map; + [&]() { + { + grpc_core::MutexLock lock(&parent_->ads_mu_); + if (parent_->ads_done_) return; + } + // Balancer shouldn't receive the call credentials metadata. + EXPECT_EQ(context->client_metadata().find(g_kCallCredsMdKey), + context->client_metadata().end()); + // Current Version map keyed by resource type url. + std::map<TString, int> resource_type_version; + // Creating blocking thread to read from stream. + std::deque<DiscoveryRequest> requests; + bool stream_closed = false; + // Take a reference of the AdsServiceImpl object, reference will go + // out of scope after the reader thread is joined. + std::shared_ptr<AdsServiceImpl> ads_service_impl = + parent_->shared_from_this(); + std::thread reader(std::bind(&RpcService::BlockingRead, this, stream, + &requests, &stream_closed)); + // Main loop to look for requests and updates. + while (true) { + // Look for new requests and and decide what to handle. + y_absl::optional<DiscoveryResponse> response; + // Boolean to keep track if the loop received any work to do: a + // request or an update; regardless whether a response was actually + // sent out. + bool did_work = false; + { + grpc_core::MutexLock lock(&parent_->ads_mu_); + if (stream_closed) break; + if (!requests.empty()) { + DiscoveryRequest request = std::move(requests.front()); + requests.pop_front(); + did_work = true; + gpr_log(GPR_INFO, + "ADS[%p]: Received request for type %s with content %s", + this, request.type_url().c_str(), + request.DebugString().c_str()); + const TString v3_resource_type = + TypeUrlToV3(request.type_url()); + // As long as we are not in shutdown, identify ACK and NACK by + // looking for version information and comparing it to nonce (this + // server ensures they are always set to the same in a response.) + auto it = + parent_->resource_type_response_state_.find(v3_resource_type); + if (it != parent_->resource_type_response_state_.end()) { + if (!request.response_nonce().empty()) { + it->second.state = + (!request.version_info().empty() && + request.version_info() == request.response_nonce()) + ? ResponseState::ACKED + : ResponseState::NACKED; + } + if (request.has_error_detail()) { + it->second.error_message = request.error_detail().message(); + } + } + // As long as the test did not tell us to ignore this type of + // request, look at all the resource names. + if (parent_->resource_types_to_ignore_.find(v3_resource_type) == + parent_->resource_types_to_ignore_.end()) { + auto& subscription_name_map = + subscription_map[v3_resource_type]; + auto& resource_name_map = + parent_->resource_map_[v3_resource_type]; + std::set<TString> resources_in_current_request; + std::set<TString> resources_added_to_response; + for (const TString& resource_name : + request.resource_names()) { + resources_in_current_request.emplace(resource_name); + auto& subscription_state = + subscription_name_map[resource_name]; + auto& resource_state = resource_name_map[resource_name]; + // Subscribe if needed. + parent_->MaybeSubscribe(v3_resource_type, resource_name, + &subscription_state, &resource_state, + &update_queue); + // Send update if needed. + if (ClientNeedsResourceUpdate(resource_state, + &subscription_state)) { + gpr_log(GPR_INFO, + "ADS[%p]: Sending update for type=%s name=%s " + "version=%d", + this, request.type_url().c_str(), + resource_name.c_str(), resource_state.version); + resources_added_to_response.emplace(resource_name); + if (!response.has_value()) response.emplace(); + if (resource_state.resource.has_value()) { + auto* resource = response->add_resources(); + resource->CopyFrom(resource_state.resource.value()); + if (is_v2_) { + resource->set_type_url(request.type_url()); + } + } + } else { + gpr_log(GPR_INFO, + "ADS[%p]: client does not need update for " + "type=%s name=%s version=%d", + this, request.type_url().c_str(), + resource_name.c_str(), resource_state.version); + } + } + // Process unsubscriptions for any resource no longer + // present in the request's resource list. + parent_->ProcessUnsubscriptions( + v3_resource_type, resources_in_current_request, + &subscription_name_map, &resource_name_map); + // Send response if needed. + if (!resources_added_to_response.empty()) { + CompleteBuildingDiscoveryResponse( + v3_resource_type, request.type_url(), + ++resource_type_version[v3_resource_type], + subscription_name_map, resources_added_to_response, + &response.value()); + } + } + } + } + if (response.has_value()) { + gpr_log(GPR_INFO, "ADS[%p]: Sending response: %s", this, + response->DebugString().c_str()); + stream->Write(response.value()); + } + response.reset(); + // Look for updates and decide what to handle. + { + grpc_core::MutexLock lock(&parent_->ads_mu_); + if (!update_queue.empty()) { + const TString resource_type = + std::move(update_queue.front().first); + const TString resource_name = + std::move(update_queue.front().second); + update_queue.pop_front(); + const TString v2_resource_type = TypeUrlToV2(resource_type); + did_work = true; + gpr_log(GPR_INFO, "ADS[%p]: Received update for type=%s name=%s", + this, resource_type.c_str(), resource_name.c_str()); + auto& subscription_name_map = subscription_map[resource_type]; + auto& resource_name_map = parent_->resource_map_[resource_type]; + auto it = subscription_name_map.find(resource_name); + if (it != subscription_name_map.end()) { + SubscriptionState& subscription_state = it->second; + ResourceState& resource_state = + resource_name_map[resource_name]; + if (ClientNeedsResourceUpdate(resource_state, + &subscription_state)) { + gpr_log( + GPR_INFO, + "ADS[%p]: Sending update for type=%s name=%s version=%d", + this, resource_type.c_str(), resource_name.c_str(), + resource_state.version); + response.emplace(); + if (resource_state.resource.has_value()) { + auto* resource = response->add_resources(); + resource->CopyFrom(resource_state.resource.value()); + if (is_v2_) { + resource->set_type_url(v2_resource_type); + } + } + CompleteBuildingDiscoveryResponse( + resource_type, v2_resource_type, + ++resource_type_version[resource_type], + subscription_name_map, {resource_name}, + &response.value()); + } + } + } + } + if (response.has_value()) { + gpr_log(GPR_INFO, "ADS[%p]: Sending update response: %s", this, + response->DebugString().c_str()); + stream->Write(response.value()); + } + // If we didn't find anything to do, delay before the next loop + // iteration; otherwise, check whether we should exit and then + // immediately continue. + gpr_timespec deadline = + grpc_timeout_milliseconds_to_deadline(did_work ? 0 : 10); + { + grpc_core::MutexLock lock(&parent_->ads_mu_); + if (!parent_->ads_cond_.WaitUntil( + &parent_->ads_mu_, [this] { return parent_->ads_done_; }, + deadline)) { + break; + } + } + } + reader.join(); + }(); + // Clean up any subscriptions that were still active when the call + // finished. + { + grpc_core::MutexLock lock(&parent_->ads_mu_); + for (auto& p : subscription_map) { + const TString& type_url = p.first; + SubscriptionNameMap& subscription_name_map = p.second; + for (auto& q : subscription_name_map) { + const TString& resource_name = q.first; + SubscriptionState& subscription_state = q.second; + ResourceState& resource_state = + parent_->resource_map_[type_url][resource_name]; + resource_state.subscriptions.erase(&subscription_state); + } + } + } + gpr_log(GPR_INFO, "ADS[%p]: StreamAggregatedResources done", this); + parent_->RemoveClient(context->peer()); + return Status::OK; + } + + private: + static TString TypeUrlToV2(const TString& resource_type) { + if (resource_type == kLdsTypeUrl) return kLdsV2TypeUrl; + if (resource_type == kRdsTypeUrl) return kRdsV2TypeUrl; + if (resource_type == kCdsTypeUrl) return kCdsV2TypeUrl; + if (resource_type == kEdsTypeUrl) return kEdsV2TypeUrl; + return resource_type; + } + + static TString TypeUrlToV3(const TString& resource_type) { + if (resource_type == kLdsV2TypeUrl) return kLdsTypeUrl; + if (resource_type == kRdsV2TypeUrl) return kRdsTypeUrl; + if (resource_type == kCdsV2TypeUrl) return kCdsTypeUrl; + if (resource_type == kEdsV2TypeUrl) return kEdsTypeUrl; + return resource_type; + } + + // Starting a thread to do blocking read on the stream until cancel. + void BlockingRead(Stream* stream, std::deque<DiscoveryRequest>* requests, + bool* stream_closed) { + DiscoveryRequest request; + bool seen_first_request = false; + while (stream->Read(&request)) { + if (!seen_first_request) { + EXPECT_TRUE(request.has_node()); + ASSERT_FALSE(request.node().client_features().empty()); + EXPECT_EQ(request.node().client_features(0), + "envoy.lb.does_not_support_overprovisioning"); + CheckBuildVersion(request); + seen_first_request = true; + } + { + grpc_core::MutexLock lock(&parent_->ads_mu_); + requests->emplace_back(std::move(request)); + } + } + gpr_log(GPR_INFO, "ADS[%p]: Null read, stream closed", this); + grpc_core::MutexLock lock(&parent_->ads_mu_); + *stream_closed = true; + } + + static void CheckBuildVersion( + const ::envoy::api::v2::DiscoveryRequest& request) { + EXPECT_FALSE(request.node().build_version().empty()); + } + + static void CheckBuildVersion( + const ::envoy::service::discovery::v3::DiscoveryRequest& request) {} + + // Completing the building a DiscoveryResponse by adding common information + // for all resources and by adding all subscribed resources for LDS and CDS. + void CompleteBuildingDiscoveryResponse( + const TString& resource_type, const TString& v2_resource_type, + const int version, const SubscriptionNameMap& subscription_name_map, + const std::set<TString>& resources_added_to_response, + DiscoveryResponse* response) { + auto& response_state = + parent_->resource_type_response_state_[resource_type]; + if (response_state.state == ResponseState::NOT_SENT) { + response_state.state = ResponseState::SENT; + } + response->set_type_url(is_v2_ ? v2_resource_type : resource_type); + response->set_version_info(y_absl::StrCat(version)); + response->set_nonce(y_absl::StrCat(version)); + if (resource_type == kLdsTypeUrl || resource_type == kCdsTypeUrl) { + // For LDS and CDS we must send back all subscribed resources + // (even the unchanged ones) + for (const auto& p : subscription_name_map) { + const TString& resource_name = p.first; + if (resources_added_to_response.find(resource_name) == + resources_added_to_response.end()) { + const ResourceState& resource_state = + parent_->resource_map_[resource_type][resource_name]; + if (resource_state.resource.has_value()) { + auto* resource = response->add_resources(); + resource->CopyFrom(resource_state.resource.value()); + if (is_v2_) { + resource->set_type_url(v2_resource_type); + } + } + } + } + } + } + + AdsServiceImpl* parent_; + const bool is_v2_; + }; + + // Checks whether the client needs to receive a newer version of + // the resource. If so, updates subscription_state->current_version and + // returns true. + static bool ClientNeedsResourceUpdate(const ResourceState& resource_state, + SubscriptionState* subscription_state) { + if (subscription_state->current_version < resource_state.version) { + subscription_state->current_version = resource_state.version; + return true; + } + return false; + } + + // Subscribes to a resource if not already subscribed: + // 1. Sets the update_queue field in subscription_state. + // 2. Adds subscription_state to resource_state->subscriptions. + void MaybeSubscribe(const TString& resource_type, + const TString& resource_name, + SubscriptionState* subscription_state, + ResourceState* resource_state, + UpdateQueue* update_queue) { + // The update_queue will be null if we were not previously subscribed. + if (subscription_state->update_queue != nullptr) return; + subscription_state->update_queue = update_queue; + resource_state->subscriptions.emplace(subscription_state); + gpr_log(GPR_INFO, "ADS[%p]: subscribe to resource type %s name %s state %p", + this, resource_type.c_str(), resource_name.c_str(), + &subscription_state); + } + + // Removes subscriptions for resources no longer present in the + // current request. + void ProcessUnsubscriptions( + const TString& resource_type, + const std::set<TString>& resources_in_current_request, + SubscriptionNameMap* subscription_name_map, + ResourceNameMap* resource_name_map) { + for (auto it = subscription_name_map->begin(); + it != subscription_name_map->end();) { + const TString& resource_name = it->first; + SubscriptionState& subscription_state = it->second; + if (resources_in_current_request.find(resource_name) != + resources_in_current_request.end()) { + ++it; + continue; + } + gpr_log(GPR_INFO, "ADS[%p]: Unsubscribe to type=%s name=%s state=%p", + this, resource_type.c_str(), resource_name.c_str(), + &subscription_state); + auto resource_it = resource_name_map->find(resource_name); + GPR_ASSERT(resource_it != resource_name_map->end()); + auto& resource_state = resource_it->second; + resource_state.subscriptions.erase(&subscription_state); + if (resource_state.subscriptions.empty() && + !resource_state.resource.has_value()) { + resource_name_map->erase(resource_it); + } + it = subscription_name_map->erase(it); + } + } + + void AddClient(const TString& client) { + grpc_core::MutexLock lock(&clients_mu_); + clients_.insert(client); + } + + void RemoveClient(const TString& client) { + grpc_core::MutexLock lock(&clients_mu_); + clients_.erase(client); + } + + RpcService<::envoy::service::discovery::v2::AggregatedDiscoveryService, + ::envoy::api::v2::DiscoveryRequest, + ::envoy::api::v2::DiscoveryResponse> + v2_rpc_service_; + RpcService<::envoy::service::discovery::v3::AggregatedDiscoveryService, + ::envoy::service::discovery::v3::DiscoveryRequest, + ::envoy::service::discovery::v3::DiscoveryResponse> + v3_rpc_service_; + + std::atomic_bool seen_v2_client_{false}; + std::atomic_bool seen_v3_client_{false}; + + grpc_core::CondVar ads_cond_; + // Protect the members below. + grpc_core::Mutex ads_mu_; + bool ads_done_ = false; + Listener default_listener_; + RouteConfiguration default_route_config_; + Cluster default_cluster_; + std::map<TString /* type_url */, ResponseState> + resource_type_response_state_; + std::set<TString /*resource_type*/> resource_types_to_ignore_; + // An instance data member containing the current state of all resources. + // Note that an entry will exist whenever either of the following is true: + // - The resource exists (i.e., has been created by SetResource() and has not + // yet been destroyed by UnsetResource()). + // - There is at least one subscription for the resource. + ResourceMap resource_map_; + + grpc_core::Mutex clients_mu_; + std::set<TString> clients_; +}; + +class LrsServiceImpl : public std::enable_shared_from_this<LrsServiceImpl> { + public: + explicit LrsServiceImpl(int client_load_reporting_interval_seconds) + : v2_rpc_service_(this), + v3_rpc_service_(this), + client_load_reporting_interval_seconds_( + client_load_reporting_interval_seconds), + cluster_names_({kDefaultClusterName}) {} + + ::envoy::service::load_stats::v2::LoadReportingService::Service* + v2_rpc_service() { + return &v2_rpc_service_; + } + + ::envoy::service::load_stats::v3::LoadReportingService::Service* + v3_rpc_service() { + return &v3_rpc_service_; + } + + size_t request_count() { + return v2_rpc_service_.request_count() + v3_rpc_service_.request_count(); + } + + size_t response_count() { + return v2_rpc_service_.response_count() + v3_rpc_service_.response_count(); + } + + // Must be called before the LRS call is started. + void set_send_all_clusters(bool send_all_clusters) { + send_all_clusters_ = send_all_clusters; + } + void set_cluster_names(const std::set<TString>& cluster_names) { + cluster_names_ = cluster_names; + } + + void Start() { + lrs_done_ = false; + result_queue_.clear(); + } + + void Shutdown() { + { + grpc_core::MutexLock lock(&lrs_mu_); + NotifyDoneWithLrsCallLocked(); + } + gpr_log(GPR_INFO, "LRS[%p]: shut down", this); + } + + std::vector<ClientStats> WaitForLoadReport() { + grpc_core::MutexLock lock(&load_report_mu_); + grpc_core::CondVar cv; + if (result_queue_.empty()) { + load_report_cond_ = &cv; + load_report_cond_->WaitUntil(&load_report_mu_, + [this] { return !result_queue_.empty(); }); + load_report_cond_ = nullptr; + } + std::vector<ClientStats> result = std::move(result_queue_.front()); + result_queue_.pop_front(); + return result; + } + + void NotifyDoneWithLrsCall() { + grpc_core::MutexLock lock(&lrs_mu_); + NotifyDoneWithLrsCallLocked(); + } + + private: + template <class RpcApi, class LoadStatsRequest, class LoadStatsResponse> + class RpcService : public CountedService<typename RpcApi::Service> { + public: + using Stream = ServerReaderWriter<LoadStatsResponse, LoadStatsRequest>; + + explicit RpcService(LrsServiceImpl* parent) : parent_(parent) {} + + Status StreamLoadStats(ServerContext* /*context*/, + Stream* stream) override { + gpr_log(GPR_INFO, "LRS[%p]: StreamLoadStats starts", this); + EXPECT_GT(parent_->client_load_reporting_interval_seconds_, 0); + // Take a reference of the LrsServiceImpl object, reference will go + // out of scope after this method exits. + std::shared_ptr<LrsServiceImpl> lrs_service_impl = + parent_->shared_from_this(); + // Read initial request. + LoadStatsRequest request; + if (stream->Read(&request)) { + CountedService<typename RpcApi::Service>::IncreaseRequestCount(); + // Verify client features. + EXPECT_THAT( + request.node().client_features(), + ::testing::Contains("envoy.lrs.supports_send_all_clusters")); + // Send initial response. + LoadStatsResponse response; + if (parent_->send_all_clusters_) { + response.set_send_all_clusters(true); + } else { + for (const TString& cluster_name : parent_->cluster_names_) { + response.add_clusters(cluster_name); + } + } + response.mutable_load_reporting_interval()->set_seconds( + parent_->client_load_reporting_interval_seconds_); + stream->Write(response); + CountedService<typename RpcApi::Service>::IncreaseResponseCount(); + // Wait for report. + request.Clear(); + while (stream->Read(&request)) { + gpr_log(GPR_INFO, "LRS[%p]: received client load report message: %s", + this, request.DebugString().c_str()); + std::vector<ClientStats> stats; + for (const auto& cluster_stats : request.cluster_stats()) { + stats.emplace_back(cluster_stats); + } + grpc_core::MutexLock lock(&parent_->load_report_mu_); + parent_->result_queue_.emplace_back(std::move(stats)); + if (parent_->load_report_cond_ != nullptr) { + parent_->load_report_cond_->Signal(); + } + } + // Wait until notified done. + grpc_core::MutexLock lock(&parent_->lrs_mu_); + parent_->lrs_cv_.WaitUntil(&parent_->lrs_mu_, + [this] { return parent_->lrs_done_; }); + } + gpr_log(GPR_INFO, "LRS[%p]: StreamLoadStats done", this); + return Status::OK; + } + + private: + LrsServiceImpl* parent_; + }; + + void NotifyDoneWithLrsCallLocked() { + if (!lrs_done_) { + lrs_done_ = true; + lrs_cv_.Broadcast(); + } + } + + RpcService<::envoy::service::load_stats::v2::LoadReportingService, + ::envoy::service::load_stats::v2::LoadStatsRequest, + ::envoy::service::load_stats::v2::LoadStatsResponse> + v2_rpc_service_; + RpcService<::envoy::service::load_stats::v3::LoadReportingService, + ::envoy::service::load_stats::v3::LoadStatsRequest, + ::envoy::service::load_stats::v3::LoadStatsResponse> + v3_rpc_service_; + + const int client_load_reporting_interval_seconds_; + bool send_all_clusters_ = false; + std::set<TString> cluster_names_; + + grpc_core::CondVar lrs_cv_; + grpc_core::Mutex lrs_mu_; // Protects lrs_done_. + bool lrs_done_ = false; + + grpc_core::Mutex load_report_mu_; // Protects the members below. + grpc_core::CondVar* load_report_cond_ = nullptr; + std::deque<std::vector<ClientStats>> result_queue_; +}; + +class TestType { + public: + TestType(bool use_xds_resolver, bool enable_load_reporting, + bool enable_rds_testing = false, bool use_v2 = false) + : use_xds_resolver_(use_xds_resolver), + enable_load_reporting_(enable_load_reporting), + enable_rds_testing_(enable_rds_testing), + use_v2_(use_v2) {} + + bool use_xds_resolver() const { return use_xds_resolver_; } + bool enable_load_reporting() const { return enable_load_reporting_; } + bool enable_rds_testing() const { return enable_rds_testing_; } + bool use_v2() const { return use_v2_; } + + TString AsString() const { + TString retval = (use_xds_resolver_ ? "XdsResolver" : "FakeResolver"); + retval += (use_v2_ ? "V2" : "V3"); + if (enable_load_reporting_) retval += "WithLoadReporting"; + if (enable_rds_testing_) retval += "Rds"; + return retval; + } + + private: + const bool use_xds_resolver_; + const bool enable_load_reporting_; + const bool enable_rds_testing_; + const bool use_v2_; +}; + +class XdsEnd2endTest : public ::testing::TestWithParam<TestType> { + protected: + XdsEnd2endTest(size_t num_backends, size_t num_balancers, + int client_load_reporting_interval_seconds = 100) + : num_backends_(num_backends), + num_balancers_(num_balancers), + client_load_reporting_interval_seconds_( + client_load_reporting_interval_seconds) {} + + static void SetUpTestCase() { + // Make the backup poller poll very frequently in order to pick up + // updates from all the subchannels's FDs. + GPR_GLOBAL_CONFIG_SET(grpc_client_channel_backup_poll_interval_ms, 1); +#if TARGET_OS_IPHONE + // Workaround Apple CFStream bug + gpr_setenv("grpc_cfstream", "0"); +#endif + grpc_init(); + } + + static void TearDownTestCase() { grpc_shutdown(); } + + void SetUp() override { + gpr_setenv("GRPC_XDS_EXPERIMENTAL_V3_SUPPORT", "true"); + gpr_setenv("GRPC_XDS_BOOTSTRAP", + GetParam().use_v2() ? g_bootstrap_file_v2 : g_bootstrap_file_v3); + g_port_saver->Reset(); + response_generator_ = + grpc_core::MakeRefCounted<grpc_core::FakeResolverResponseGenerator>(); + // Inject xDS channel response generator. + lb_channel_response_generator_ = + grpc_core::MakeRefCounted<grpc_core::FakeResolverResponseGenerator>(); + xds_channel_args_to_add_.emplace_back( + grpc_core::FakeResolverResponseGenerator::MakeChannelArg( + lb_channel_response_generator_.get())); + if (xds_resource_does_not_exist_timeout_ms_ > 0) { + xds_channel_args_to_add_.emplace_back(grpc_channel_arg_integer_create( + const_cast<char*>(GRPC_ARG_XDS_RESOURCE_DOES_NOT_EXIST_TIMEOUT_MS), + xds_resource_does_not_exist_timeout_ms_)); + } + xds_channel_args_.num_args = xds_channel_args_to_add_.size(); + xds_channel_args_.args = xds_channel_args_to_add_.data(); + grpc_core::internal::SetXdsChannelArgsForTest(&xds_channel_args_); + // Make sure each test creates a new XdsClient instance rather than + // reusing the one from the previous test. This avoids spurious failures + // caused when a load reporting test runs after a non-load reporting test + // and the XdsClient is still talking to the old LRS server, which fails + // because it's not expecting the client to connect. It also + // ensures that each test can independently set the global channel + // args for the xDS channel. + grpc_core::internal::UnsetGlobalXdsClientForTest(); + // Start the backends. + for (size_t i = 0; i < num_backends_; ++i) { + backends_.emplace_back(new BackendServerThread); + backends_.back()->Start(); + } + // Start the load balancers. + for (size_t i = 0; i < num_balancers_; ++i) { + balancers_.emplace_back( + new BalancerServerThread(GetParam().enable_load_reporting() + ? client_load_reporting_interval_seconds_ + : 0)); + balancers_.back()->Start(); + if (GetParam().enable_rds_testing()) { + balancers_[i]->ads_service()->SetLdsToUseDynamicRds(); + } + } + ResetStub(); + } + + const char* DefaultEdsServiceName() const { + return GetParam().use_xds_resolver() ? kDefaultEdsServiceName : kServerName; + } + + void TearDown() override { + ShutdownAllBackends(); + for (auto& balancer : balancers_) balancer->Shutdown(); + // Clear global xDS channel args, since they will go out of scope + // when this test object is destroyed. + grpc_core::internal::SetXdsChannelArgsForTest(nullptr); + } + + void StartAllBackends() { + for (auto& backend : backends_) backend->Start(); + } + + void StartBackend(size_t index) { backends_[index]->Start(); } + + void ShutdownAllBackends() { + for (auto& backend : backends_) backend->Shutdown(); + } + + void ShutdownBackend(size_t index) { backends_[index]->Shutdown(); } + + void ResetStub(int failover_timeout = 0) { + channel_ = CreateChannel(failover_timeout); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + stub1_ = grpc::testing::EchoTest1Service::NewStub(channel_); + stub2_ = grpc::testing::EchoTest2Service::NewStub(channel_); + } + + std::shared_ptr<Channel> CreateChannel( + int failover_timeout = 0, const char* server_name = kServerName) { + ChannelArguments args; + if (failover_timeout > 0) { + args.SetInt(GRPC_ARG_PRIORITY_FAILOVER_TIMEOUT_MS, failover_timeout); + } + // If the parent channel is using the fake resolver, we inject the + // response generator here. + if (!GetParam().use_xds_resolver()) { + args.SetPointer(GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, + response_generator_.get()); + } + TString uri = y_absl::StrCat( + GetParam().use_xds_resolver() ? "xds" : "fake", ":///", server_name); + // TODO(dgq): templatize tests to run everything using both secure and + // insecure channel credentials. + grpc_channel_credentials* channel_creds = + grpc_fake_transport_security_credentials_create(); + grpc_call_credentials* call_creds = grpc_md_only_test_credentials_create( + g_kCallCredsMdKey, g_kCallCredsMdValue, false); + std::shared_ptr<ChannelCredentials> creds( + new SecureChannelCredentials(grpc_composite_channel_credentials_create( + channel_creds, call_creds, nullptr))); + call_creds->Unref(); + channel_creds->Unref(); + return ::grpc::CreateCustomChannel(uri, creds, args); + } + + enum RpcService { + SERVICE_ECHO, + SERVICE_ECHO1, + SERVICE_ECHO2, + }; + + enum RpcMethod { + METHOD_ECHO, + METHOD_ECHO1, + METHOD_ECHO2, + }; + + struct RpcOptions { + RpcService service = SERVICE_ECHO; + RpcMethod method = METHOD_ECHO; + int timeout_ms = 1000; + bool wait_for_ready = false; + bool server_fail = false; + std::vector<std::pair<TString, TString>> metadata; + + RpcOptions() {} + + RpcOptions& set_rpc_service(RpcService rpc_service) { + service = rpc_service; + return *this; + } + + RpcOptions& set_rpc_method(RpcMethod rpc_method) { + method = rpc_method; + return *this; + } + + RpcOptions& set_timeout_ms(int rpc_timeout_ms) { + timeout_ms = rpc_timeout_ms; + return *this; + } + + RpcOptions& set_wait_for_ready(bool rpc_wait_for_ready) { + wait_for_ready = rpc_wait_for_ready; + return *this; + } + + RpcOptions& set_server_fail(bool rpc_server_fail) { + server_fail = rpc_server_fail; + return *this; + } + + RpcOptions& set_metadata( + std::vector<std::pair<TString, TString>> rpc_metadata) { + metadata = rpc_metadata; + return *this; + } + }; + + template <typename Stub> + Status SendRpcMethod(Stub* stub, const RpcOptions& rpc_options, + ClientContext* context, EchoRequest& request, + EchoResponse* response) { + switch (rpc_options.method) { + case METHOD_ECHO: + return (*stub)->Echo(context, request, response); + case METHOD_ECHO1: + return (*stub)->Echo1(context, request, response); + case METHOD_ECHO2: + return (*stub)->Echo2(context, request, response); + } + } + + void ResetBackendCounters(size_t start_index = 0, size_t stop_index = 0) { + if (stop_index == 0) stop_index = backends_.size(); + for (size_t i = start_index; i < stop_index; ++i) { + backends_[i]->backend_service()->ResetCounters(); + backends_[i]->backend_service1()->ResetCounters(); + backends_[i]->backend_service2()->ResetCounters(); + } + } + + bool SeenAllBackends(size_t start_index = 0, size_t stop_index = 0, + const RpcOptions& rpc_options = RpcOptions()) { + if (stop_index == 0) stop_index = backends_.size(); + for (size_t i = start_index; i < stop_index; ++i) { + switch (rpc_options.service) { + case SERVICE_ECHO: + if (backends_[i]->backend_service()->request_count() == 0) + return false; + break; + case SERVICE_ECHO1: + if (backends_[i]->backend_service1()->request_count() == 0) + return false; + break; + case SERVICE_ECHO2: + if (backends_[i]->backend_service2()->request_count() == 0) + return false; + break; + } + } + return true; + } + + void SendRpcAndCount(int* num_total, int* num_ok, int* num_failure, + int* num_drops, + const RpcOptions& rpc_options = RpcOptions()) { + const Status status = SendRpc(rpc_options); + if (status.ok()) { + ++*num_ok; + } else { + if (status.error_message() == "Call dropped by load balancing policy") { + ++*num_drops; + } else { + ++*num_failure; + } + } + ++*num_total; + } + + std::tuple<int, int, int> WaitForAllBackends( + size_t start_index = 0, size_t stop_index = 0, bool reset_counters = true, + const RpcOptions& rpc_options = RpcOptions(), + bool allow_failures = false) { + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + int num_total = 0; + while (!SeenAllBackends(start_index, stop_index, rpc_options)) { + SendRpcAndCount(&num_total, &num_ok, &num_failure, &num_drops, + rpc_options); + } + if (reset_counters) ResetBackendCounters(); + gpr_log(GPR_INFO, + "Performed %d warm up requests against the backends. " + "%d succeeded, %d failed, %d dropped.", + num_total, num_ok, num_failure, num_drops); + if (!allow_failures) EXPECT_EQ(num_failure, 0); + return std::make_tuple(num_ok, num_failure, num_drops); + } + + void WaitForBackend(size_t backend_idx, bool reset_counters = true, + bool require_success = false) { + gpr_log(GPR_INFO, "========= WAITING FOR BACKEND %lu ==========", + static_cast<unsigned long>(backend_idx)); + do { + Status status = SendRpc(); + if (require_success) { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + } + } while (backends_[backend_idx]->backend_service()->request_count() == 0); + if (reset_counters) ResetBackendCounters(); + gpr_log(GPR_INFO, "========= BACKEND %lu READY ==========", + static_cast<unsigned long>(backend_idx)); + } + + grpc_core::ServerAddressList CreateAddressListFromPortList( + const std::vector<int>& ports) { + grpc_core::ServerAddressList addresses; + for (int port : ports) { + TString lb_uri_str = y_absl::StrCat("ipv4:127.0.0.1:", port); + grpc_uri* lb_uri = grpc_uri_parse(lb_uri_str.c_str(), true); + GPR_ASSERT(lb_uri != nullptr); + grpc_resolved_address address; + GPR_ASSERT(grpc_parse_uri(lb_uri, &address)); + addresses.emplace_back(address.addr, address.len, nullptr); + grpc_uri_destroy(lb_uri); + } + return addresses; + } + + void SetNextResolution(const std::vector<int>& ports) { + if (GetParam().use_xds_resolver()) return; // Not used with xds resolver. + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result; + result.addresses = CreateAddressListFromPortList(ports); + grpc_error* error = GRPC_ERROR_NONE; + const char* service_config_json = + GetParam().enable_load_reporting() + ? kDefaultServiceConfig + : kDefaultServiceConfigWithoutLoadReporting; + result.service_config = + grpc_core::ServiceConfig::Create(nullptr, service_config_json, &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error); + ASSERT_NE(result.service_config.get(), nullptr); + response_generator_->SetResponse(std::move(result)); + } + + void SetNextResolutionForLbChannelAllBalancers( + const char* service_config_json = nullptr, + const char* expected_targets = nullptr) { + std::vector<int> ports; + for (size_t i = 0; i < balancers_.size(); ++i) { + ports.emplace_back(balancers_[i]->port()); + } + SetNextResolutionForLbChannel(ports, service_config_json, expected_targets); + } + + void SetNextResolutionForLbChannel(const std::vector<int>& ports, + const char* service_config_json = nullptr, + const char* expected_targets = nullptr) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result; + result.addresses = CreateAddressListFromPortList(ports); + if (service_config_json != nullptr) { + grpc_error* error = GRPC_ERROR_NONE; + result.service_config = grpc_core::ServiceConfig::Create( + nullptr, service_config_json, &error); + ASSERT_NE(result.service_config.get(), nullptr); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error); + } + if (expected_targets != nullptr) { + grpc_arg expected_targets_arg = grpc_channel_arg_string_create( + const_cast<char*>(GRPC_ARG_FAKE_SECURITY_EXPECTED_TARGETS), + const_cast<char*>(expected_targets)); + result.args = + grpc_channel_args_copy_and_add(nullptr, &expected_targets_arg, 1); + } + lb_channel_response_generator_->SetResponse(std::move(result)); + } + + void SetNextReresolutionResponse(const std::vector<int>& ports) { + grpc_core::ExecCtx exec_ctx; + grpc_core::Resolver::Result result; + result.addresses = CreateAddressListFromPortList(ports); + response_generator_->SetReresolutionResponse(std::move(result)); + } + + const std::vector<int> GetBackendPorts(size_t start_index = 0, + size_t stop_index = 0) const { + if (stop_index == 0) stop_index = backends_.size(); + std::vector<int> backend_ports; + for (size_t i = start_index; i < stop_index; ++i) { + backend_ports.push_back(backends_[i]->port()); + } + return backend_ports; + } + + Status SendRpc(const RpcOptions& rpc_options = RpcOptions(), + EchoResponse* response = nullptr) { + const bool local_response = (response == nullptr); + if (local_response) response = new EchoResponse; + EchoRequest request; + ClientContext context; + for (const auto& metadata : rpc_options.metadata) { + context.AddMetadata(metadata.first, metadata.second); + } + if (rpc_options.timeout_ms != 0) { + context.set_deadline( + grpc_timeout_milliseconds_to_deadline(rpc_options.timeout_ms)); + } + if (rpc_options.wait_for_ready) context.set_wait_for_ready(true); + request.set_message(kRequestMessage); + if (rpc_options.server_fail) { + request.mutable_param()->mutable_expected_error()->set_code( + GRPC_STATUS_FAILED_PRECONDITION); + } + Status status; + switch (rpc_options.service) { + case SERVICE_ECHO: + status = + SendRpcMethod(&stub_, rpc_options, &context, request, response); + break; + case SERVICE_ECHO1: + status = + SendRpcMethod(&stub1_, rpc_options, &context, request, response); + break; + case SERVICE_ECHO2: + status = + SendRpcMethod(&stub2_, rpc_options, &context, request, response); + break; + } + if (local_response) delete response; + return status; + } + + void CheckRpcSendOk(const size_t times = 1, + const RpcOptions& rpc_options = RpcOptions()) { + for (size_t i = 0; i < times; ++i) { + EchoResponse response; + const Status status = SendRpc(rpc_options, &response); + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + } + + void CheckRpcSendFailure(const size_t times = 1, + const RpcOptions& rpc_options = RpcOptions()) { + for (size_t i = 0; i < times; ++i) { + const Status status = SendRpc(rpc_options); + EXPECT_FALSE(status.ok()); + } + } + + void SetRouteConfiguration(int idx, const RouteConfiguration& route_config) { + if (GetParam().enable_rds_testing()) { + balancers_[idx]->ads_service()->SetRdsResource(route_config); + } else { + balancers_[idx]->ads_service()->SetLdsResource( + AdsServiceImpl::BuildListener(route_config)); + } + } + + AdsServiceImpl::ResponseState RouteConfigurationResponseState(int idx) const { + AdsServiceImpl* ads_service = balancers_[idx]->ads_service(); + if (GetParam().enable_rds_testing()) { + return ads_service->rds_response_state(); + } + return ads_service->lds_response_state(); + } + + public: + // This method could benefit test subclasses; to make it accessible + // via bind with a qualified name, it needs to be public. + void SetEdsResourceWithDelay(size_t i, + const ClusterLoadAssignment& assignment, + int delay_ms) { + GPR_ASSERT(delay_ms > 0); + gpr_sleep_until(grpc_timeout_milliseconds_to_deadline(delay_ms)); + balancers_[i]->ads_service()->SetEdsResource(assignment); + } + + protected: + class ServerThread { + public: + ServerThread() : port_(g_port_saver->GetPort()) {} + virtual ~ServerThread(){}; + + void Start() { + gpr_log(GPR_INFO, "starting %s server on port %d", Type(), port_); + GPR_ASSERT(!running_); + running_ = true; + StartAllServices(); + grpc_core::Mutex mu; + // We need to acquire the lock here in order to prevent the notify_one + // by ServerThread::Serve from firing before the wait below is hit. + grpc_core::MutexLock lock(&mu); + grpc_core::CondVar cond; + thread_.reset( + new std::thread(std::bind(&ServerThread::Serve, this, &mu, &cond))); + cond.Wait(&mu); + gpr_log(GPR_INFO, "%s server startup complete", Type()); + } + + void Serve(grpc_core::Mutex* mu, grpc_core::CondVar* cond) { + // We need to acquire the lock here in order to prevent the notify_one + // below from firing before its corresponding wait is executed. + grpc_core::MutexLock lock(mu); + std::ostringstream server_address; + server_address << "localhost:" << port_; + ServerBuilder builder; + std::shared_ptr<ServerCredentials> creds(new SecureServerCredentials( + grpc_fake_transport_security_server_credentials_create())); + builder.AddListeningPort(server_address.str(), creds); + RegisterAllServices(&builder); + server_ = builder.BuildAndStart(); + cond->Signal(); + } + + void Shutdown() { + if (!running_) return; + gpr_log(GPR_INFO, "%s about to shutdown", Type()); + ShutdownAllServices(); + server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0)); + thread_->join(); + gpr_log(GPR_INFO, "%s shutdown completed", Type()); + running_ = false; + } + + int port() const { return port_; } + + private: + virtual void RegisterAllServices(ServerBuilder* builder) = 0; + virtual void StartAllServices() = 0; + virtual void ShutdownAllServices() = 0; + + virtual const char* Type() = 0; + + const int port_; + std::unique_ptr<Server> server_; + std::unique_ptr<std::thread> thread_; + bool running_ = false; + }; + + class BackendServerThread : public ServerThread { + public: + BackendServiceImpl<::grpc::testing::EchoTestService::Service>* + backend_service() { + return &backend_service_; + } + BackendServiceImpl<::grpc::testing::EchoTest1Service::Service>* + backend_service1() { + return &backend_service1_; + } + BackendServiceImpl<::grpc::testing::EchoTest2Service::Service>* + backend_service2() { + return &backend_service2_; + } + + private: + void RegisterAllServices(ServerBuilder* builder) override { + builder->RegisterService(&backend_service_); + builder->RegisterService(&backend_service1_); + builder->RegisterService(&backend_service2_); + } + + void StartAllServices() override { + backend_service_.Start(); + backend_service1_.Start(); + backend_service2_.Start(); + } + + void ShutdownAllServices() override { + backend_service_.Shutdown(); + backend_service1_.Shutdown(); + backend_service2_.Shutdown(); + } + + const char* Type() override { return "Backend"; } + + BackendServiceImpl<::grpc::testing::EchoTestService::Service> + backend_service_; + BackendServiceImpl<::grpc::testing::EchoTest1Service::Service> + backend_service1_; + BackendServiceImpl<::grpc::testing::EchoTest2Service::Service> + backend_service2_; + }; + + class BalancerServerThread : public ServerThread { + public: + explicit BalancerServerThread(int client_load_reporting_interval = 0) + : ads_service_(new AdsServiceImpl(client_load_reporting_interval > 0)), + lrs_service_(new LrsServiceImpl(client_load_reporting_interval)) {} + + AdsServiceImpl* ads_service() { return ads_service_.get(); } + LrsServiceImpl* lrs_service() { return lrs_service_.get(); } + + private: + void RegisterAllServices(ServerBuilder* builder) override { + builder->RegisterService(ads_service_->v2_rpc_service()); + builder->RegisterService(ads_service_->v3_rpc_service()); + builder->RegisterService(lrs_service_->v2_rpc_service()); + builder->RegisterService(lrs_service_->v3_rpc_service()); + } + + void StartAllServices() override { + ads_service_->Start(); + lrs_service_->Start(); + } + + void ShutdownAllServices() override { + ads_service_->Shutdown(); + lrs_service_->Shutdown(); + } + + const char* Type() override { return "Balancer"; } + + std::shared_ptr<AdsServiceImpl> ads_service_; + std::shared_ptr<LrsServiceImpl> lrs_service_; + }; + + const size_t num_backends_; + const size_t num_balancers_; + const int client_load_reporting_interval_seconds_; + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<grpc::testing::EchoTest1Service::Stub> stub1_; + std::unique_ptr<grpc::testing::EchoTest2Service::Stub> stub2_; + std::vector<std::unique_ptr<BackendServerThread>> backends_; + std::vector<std::unique_ptr<BalancerServerThread>> balancers_; + grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator> + response_generator_; + grpc_core::RefCountedPtr<grpc_core::FakeResolverResponseGenerator> + lb_channel_response_generator_; + int xds_resource_does_not_exist_timeout_ms_ = 0; + y_absl::InlinedVector<grpc_arg, 2> xds_channel_args_to_add_; + grpc_channel_args xds_channel_args_; +}; + +class BasicTest : public XdsEnd2endTest { + public: + BasicTest() : XdsEnd2endTest(4, 1) {} +}; + +// Tests that the balancer sends the correct response to the client, and the +// client sends RPCs to the backends using the default child policy. +TEST_P(BasicTest, Vanilla) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, + backends_[i]->backend_service()->request_count()); + } + // Check LB policy name for the channel. + EXPECT_EQ((GetParam().use_xds_resolver() ? "xds_cluster_manager_experimental" + : "eds_experimental"), + channel_->GetLoadBalancingPolicyName()); +} + +TEST_P(BasicTest, IgnoresUnhealthyEndpoints) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcsPerAddress = 100; + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", + GetBackendPorts(), + kDefaultLocalityWeight, + kDefaultLocalityPriority, + {HealthStatus::DRAINING}}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Make sure that trying to connect works without a call. + channel_->GetState(true /* try_to_connect */); + // We need to wait for all backends to come online. + WaitForAllBackends(/*start_index=*/1); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * (num_backends_ - 1)); + // Each backend should have gotten 100 requests. + for (size_t i = 1; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, + backends_[i]->backend_service()->request_count()); + } +} + +// Tests that subchannel sharing works when the same backend is listed multiple +// times. +TEST_P(BasicTest, SameBackendListedMultipleTimes) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Same backend listed twice. + std::vector<int> ports(2, backends_[0]->port()); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", ports}, + }); + const size_t kNumRpcsPerAddress = 10; + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // We need to wait for the backend to come online. + WaitForBackend(0); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * ports.size()); + // Backend should have gotten 20 requests. + EXPECT_EQ(kNumRpcsPerAddress * ports.size(), + backends_[0]->backend_service()->request_count()); + // And they should have come from a single client port, because of + // subchannel sharing. + EXPECT_EQ(1UL, backends_[0]->backend_service()->clients().size()); +} + +// Tests that RPCs will be blocked until a non-empty serverlist is received. +TEST_P(BasicTest, InitiallyEmptyServerlist) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const int kServerlistDelayMs = 500 * grpc_test_slowdown_factor(); + const int kCallDeadlineMs = kServerlistDelayMs * 2; + // First response is an empty serverlist, sent right away. + AdsServiceImpl::EdsResourceArgs::Locality empty_locality("locality0", {}); + AdsServiceImpl::EdsResourceArgs args({ + empty_locality, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Send non-empty serverlist only after kServerlistDelayMs. + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", GetBackendPorts()}, + }); + std::thread delayed_resource_setter( + std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0, + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), + kServerlistDelayMs)); + const auto t0 = system_clock::now(); + // Client will block: LB will initially send empty serverlist. + CheckRpcSendOk( + 1, RpcOptions().set_timeout_ms(kCallDeadlineMs).set_wait_for_ready(true)); + const auto ellapsed_ms = + std::chrono::duration_cast<std::chrono::milliseconds>( + system_clock::now() - t0); + // but eventually, the LB sends a serverlist update that allows the call to + // proceed. The call delay must be larger than the delay in sending the + // populated serverlist but under the call's deadline (which is enforced by + // the call's deadline). + EXPECT_GT(ellapsed_ms.count(), kServerlistDelayMs); + delayed_resource_setter.join(); +} + +// Tests that RPCs will fail with UNAVAILABLE instead of DEADLINE_EXCEEDED if +// all the servers are unreachable. +TEST_P(BasicTest, AllServersUnreachableFailFast) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumUnreachableServers = 5; + std::vector<int> ports; + for (size_t i = 0; i < kNumUnreachableServers; ++i) { + ports.push_back(g_port_saver->GetPort()); + } + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", ports}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + const Status status = SendRpc(); + // The error shouldn't be DEADLINE_EXCEEDED. + EXPECT_EQ(StatusCode::UNAVAILABLE, status.error_code()); +} + +// Tests that RPCs fail when the backends are down, and will succeed again after +// the backends are restarted. +TEST_P(BasicTest, BackendsRestart) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForAllBackends(); + // Stop backends. RPCs should fail. + ShutdownAllBackends(); + // Sending multiple failed requests instead of just one to ensure that the + // client notices that all backends are down before we restart them. If we + // didn't do this, then a single RPC could fail here due to the race condition + // between the LB pick and the GOAWAY from the chosen backend being shut down, + // which would not actually prove that the client noticed that all of the + // backends are down. Then, when we send another request below (which we + // expect to succeed), if the callbacks happen in the wrong order, the same + // race condition could happen again due to the client not yet having noticed + // that the backends were all down. + CheckRpcSendFailure(num_backends_); + // Restart all backends. RPCs should start succeeding again. + StartAllBackends(); + CheckRpcSendOk(1, RpcOptions().set_timeout_ms(2000).set_wait_for_ready(true)); +} + +TEST_P(BasicTest, IgnoresDuplicateUpdates) { + const size_t kNumRpcsPerAddress = 100; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for all backends to come online. + WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server, but send an EDS update in + // between. If the update is not ignored, this will cause the + // round_robin policy to see an update, which will randomly reset its + // position in the address list. + for (size_t i = 0; i < kNumRpcsPerAddress; ++i) { + CheckRpcSendOk(2); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + CheckRpcSendOk(2); + } + // Each backend should have gotten the right number of requests. + for (size_t i = 1; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, + backends_[i]->backend_service()->request_count()); + } +} + +using XdsResolverOnlyTest = BasicTest; + +// Tests switching over from one cluster to another. +TEST_P(XdsResolverOnlyTest, ChangeClusters) { + const char* kNewClusterName = "new_cluster_name"; + const char* kNewEdsServiceName = "new_eds_service_name"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(0, 2); + // Populate new EDS resource. + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(2, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsServiceName)); + // Populate new CDS resource. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Change RDS resource to point to new cluster. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + new_route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->set_cluster(kNewClusterName); + Listener listener = + balancers_[0]->ads_service()->BuildListener(new_route_config); + balancers_[0]->ads_service()->SetLdsResource(listener); + // Wait for all new backends to be used. + std::tuple<int, int, int> counts = WaitForAllBackends(2, 4); + // Make sure no RPCs failed in the transition. + EXPECT_EQ(0, std::get<1>(counts)); +} + +// Tests that we go into TRANSIENT_FAILURE if the Cluster disappears. +TEST_P(XdsResolverOnlyTest, ClusterRemoved) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Unset CDS resource. + balancers_[0]->ads_service()->UnsetResource(kCdsTypeUrl, kDefaultClusterName); + // Wait for RPCs to start failing. + do { + } while (SendRpc(RpcOptions(), nullptr).ok()); + // Make sure RPCs are still failing. + CheckRpcSendFailure(1000); + // Make sure we ACK'ed the update. + EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that we restart all xDS requests when we reestablish the ADS call. +TEST_P(XdsResolverOnlyTest, RestartsRequestsUponReconnection) { + balancers_[0]->ads_service()->SetLdsToUseDynamicRds(); + const char* kNewClusterName = "new_cluster_name"; + const char* kNewEdsServiceName = "new_eds_service_name"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(0, 2); + // Now shut down and restart the balancer. When the client + // reconnects, it should automatically restart the requests for all + // resource types. + balancers_[0]->Shutdown(); + balancers_[0]->Start(); + // Make sure things are still working. + CheckRpcSendOk(100); + // Populate new EDS resource. + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(2, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsServiceName)); + // Populate new CDS resource. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Change RDS resource to point to new cluster. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + new_route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->set_cluster(kNewClusterName); + balancers_[0]->ads_service()->SetRdsResource(new_route_config); + // Wait for all new backends to be used. + std::tuple<int, int, int> counts = WaitForAllBackends(2, 4); + // Make sure no RPCs failed in the transition. + EXPECT_EQ(0, std::get<1>(counts)); +} + +TEST_P(XdsResolverOnlyTest, DefaultRouteSpecifiesSlashPrefix) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_match() + ->set_prefix("/"); + balancers_[0]->ads_service()->SetLdsResource( + AdsServiceImpl::BuildListener(route_config)); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(); +} + +TEST_P(XdsResolverOnlyTest, CircuitBreaking) { + class TestRpc { + public: + TestRpc() {} + + void StartRpc(grpc::testing::EchoTestService::Stub* stub) { + sender_thread_ = std::thread([this, stub]() { + EchoResponse response; + EchoRequest request; + request.mutable_param()->set_client_cancel_after_us(1 * 1000 * 1000); + request.set_message(kRequestMessage); + status_ = stub->Echo(&context_, request, &response); + }); + } + + void CancelRpc() { + context_.TryCancel(); + sender_thread_.join(); + } + + private: + std::thread sender_thread_; + ClientContext context_; + Status status_; + }; + + gpr_setenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING", "true"); + constexpr size_t kMaxConcurrentRequests = 10; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + // Update CDS resource to set max concurrent request. + CircuitBreakers circuit_breaks; + Cluster cluster = balancers_[0]->ads_service()->default_cluster(); + auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds(); + threshold->set_priority(RoutingPriority::DEFAULT); + threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Send exactly max_concurrent_requests long RPCs. + TestRpc rpcs[kMaxConcurrentRequests]; + for (size_t i = 0; i < kMaxConcurrentRequests; ++i) { + rpcs[i].StartRpc(stub_.get()); + } + // Wait for all RPCs to be in flight. + while (backends_[0]->backend_service()->RpcsWaitingForClientCancel() < + kMaxConcurrentRequests) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1 * 1000, GPR_TIMESPAN))); + } + // Sending a RPC now should fail, the error message should tell us + // we hit the max concurrent requests limit and got dropped. + Status status = SendRpc(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), "Call dropped by load balancing policy"); + // Cancel one RPC to allow another one through + rpcs[0].CancelRpc(); + status = SendRpc(); + EXPECT_TRUE(status.ok()); + for (size_t i = 1; i < kMaxConcurrentRequests; ++i) { + rpcs[i].CancelRpc(); + } + // Make sure RPCs go to the correct backend: + EXPECT_EQ(kMaxConcurrentRequests + 1, + backends_[0]->backend_service()->request_count()); + gpr_unsetenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING"); +} + +TEST_P(XdsResolverOnlyTest, CircuitBreakingDisabled) { + class TestRpc { + public: + TestRpc() {} + + void StartRpc(grpc::testing::EchoTestService::Stub* stub) { + sender_thread_ = std::thread([this, stub]() { + EchoResponse response; + EchoRequest request; + request.mutable_param()->set_client_cancel_after_us(1 * 1000 * 1000); + request.set_message(kRequestMessage); + status_ = stub->Echo(&context_, request, &response); + }); + } + + void CancelRpc() { + context_.TryCancel(); + sender_thread_.join(); + } + + private: + std::thread sender_thread_; + ClientContext context_; + Status status_; + }; + + constexpr size_t kMaxConcurrentRequests = 10; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + // Update CDS resource to set max concurrent request. + CircuitBreakers circuit_breaks; + Cluster cluster = balancers_[0]->ads_service()->default_cluster(); + auto* threshold = cluster.mutable_circuit_breakers()->add_thresholds(); + threshold->set_priority(RoutingPriority::DEFAULT); + threshold->mutable_max_requests()->set_value(kMaxConcurrentRequests); + balancers_[0]->ads_service()->SetCdsResource(cluster); + // Send exactly max_concurrent_requests long RPCs. + TestRpc rpcs[kMaxConcurrentRequests]; + for (size_t i = 0; i < kMaxConcurrentRequests; ++i) { + rpcs[i].StartRpc(stub_.get()); + } + // Wait for all RPCs to be in flight. + while (backends_[0]->backend_service()->RpcsWaitingForClientCancel() < + kMaxConcurrentRequests) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1 * 1000, GPR_TIMESPAN))); + } + // Sending a RPC now should not fail as circuit breaking is disabled. + Status status = SendRpc(); + EXPECT_TRUE(status.ok()); + for (size_t i = 0; i < kMaxConcurrentRequests; ++i) { + rpcs[i].CancelRpc(); + } + // Make sure RPCs go to the correct backend: + EXPECT_EQ(kMaxConcurrentRequests + 1, + backends_[0]->backend_service()->request_count()); +} + +TEST_P(XdsResolverOnlyTest, MultipleChannelsShareXdsClient) { + const char* kNewServerName = "new-server.example.com"; + Listener listener = balancers_[0]->ads_service()->default_listener(); + listener.set_name(kNewServerName); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + WaitForAllBackends(); + // Create second channel and tell it to connect to kNewServerName. + auto channel2 = CreateChannel(/*failover_timeout=*/0, kNewServerName); + channel2->GetState(/*try_to_connect=*/true); + ASSERT_TRUE( + channel2->WaitForConnected(grpc_timeout_milliseconds_to_deadline(100))); + // Make sure there's only one client connected. + EXPECT_EQ(1UL, balancers_[0]->ads_service()->clients().size()); +} + +class XdsResolverLoadReportingOnlyTest : public XdsEnd2endTest { + public: + XdsResolverLoadReportingOnlyTest() : XdsEnd2endTest(4, 1, 3) {} +}; + +// Tests load reporting when switching over from one cluster to another. +TEST_P(XdsResolverLoadReportingOnlyTest, ChangeClusters) { + const char* kNewClusterName = "new_cluster_name"; + const char* kNewEdsServiceName = "new_eds_service_name"; + balancers_[0]->lrs_service()->set_cluster_names( + {kDefaultClusterName, kNewClusterName}); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // cluster kDefaultClusterName -> locality0 -> backends 0 and 1 + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + // cluster kNewClusterName -> locality1 -> backends 2 and 3 + AdsServiceImpl::EdsResourceArgs args2({ + {"locality1", GetBackendPorts(2, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsServiceName)); + // CDS resource for kNewClusterName. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Wait for all backends to come online. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(0, 2); + // The load report received at the balancer should be correct. + std::vector<ClientStats> load_report = + balancers_[0]->lrs_service()->WaitForLoadReport(); + EXPECT_THAT( + load_report, + ::testing::ElementsAre(::testing::AllOf( + ::testing::Property(&ClientStats::cluster_name, kDefaultClusterName), + ::testing::Property( + &ClientStats::locality_stats, + ::testing::ElementsAre(::testing::Pair( + "locality0", + ::testing::AllOf( + ::testing::Field(&ClientStats::LocalityStats:: + total_successful_requests, + num_ok), + ::testing::Field(&ClientStats::LocalityStats:: + total_requests_in_progress, + 0UL), + ::testing::Field( + &ClientStats::LocalityStats::total_error_requests, + num_failure), + ::testing::Field( + &ClientStats::LocalityStats::total_issued_requests, + num_failure + num_ok))))), + ::testing::Property(&ClientStats::total_dropped_requests, + num_drops)))); + // Change RDS resource to point to new cluster. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + new_route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->set_cluster(kNewClusterName); + Listener listener = + balancers_[0]->ads_service()->BuildListener(new_route_config); + balancers_[0]->ads_service()->SetLdsResource(listener); + // Wait for all new backends to be used. + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(2, 4); + // The load report received at the balancer should be correct. + load_report = balancers_[0]->lrs_service()->WaitForLoadReport(); + EXPECT_THAT( + load_report, + ::testing::ElementsAre( + ::testing::AllOf( + ::testing::Property(&ClientStats::cluster_name, + kDefaultClusterName), + ::testing::Property( + &ClientStats::locality_stats, + ::testing::ElementsAre(::testing::Pair( + "locality0", + ::testing::AllOf( + ::testing::Field(&ClientStats::LocalityStats:: + total_successful_requests, + ::testing::Lt(num_ok)), + ::testing::Field(&ClientStats::LocalityStats:: + total_requests_in_progress, + 0UL), + ::testing::Field( + &ClientStats::LocalityStats::total_error_requests, + ::testing::Le(num_failure)), + ::testing::Field( + &ClientStats::LocalityStats:: + total_issued_requests, + ::testing::Le(num_failure + num_ok)))))), + ::testing::Property(&ClientStats::total_dropped_requests, + num_drops)), + ::testing::AllOf( + ::testing::Property(&ClientStats::cluster_name, kNewClusterName), + ::testing::Property( + &ClientStats::locality_stats, + ::testing::ElementsAre(::testing::Pair( + "locality1", + ::testing::AllOf( + ::testing::Field(&ClientStats::LocalityStats:: + total_successful_requests, + ::testing::Le(num_ok)), + ::testing::Field(&ClientStats::LocalityStats:: + total_requests_in_progress, + 0UL), + ::testing::Field( + &ClientStats::LocalityStats::total_error_requests, + ::testing::Le(num_failure)), + ::testing::Field( + &ClientStats::LocalityStats:: + total_issued_requests, + ::testing::Le(num_failure + num_ok)))))), + ::testing::Property(&ClientStats::total_dropped_requests, + num_drops)))); + int total_ok = 0; + int total_failure = 0; + for (const ClientStats& client_stats : load_report) { + total_ok += client_stats.total_successful_requests(); + total_failure += client_stats.total_error_requests(); + } + EXPECT_EQ(total_ok, num_ok); + EXPECT_EQ(total_failure, num_failure); + // The LRS service got a single request, and sent a single response. + EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count()); + EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count()); +} + +using SecureNamingTest = BasicTest; + +// Tests that secure naming check passes if target name is expected. +TEST_P(SecureNamingTest, TargetNameIsExpected) { + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}, nullptr, "xds_server"); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + CheckRpcSendOk(); +} + +// Tests that secure naming check fails if target name is unexpected. +TEST_P(SecureNamingTest, TargetNameIsUnexpected) { + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}, nullptr, + "incorrect_server_name"); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Make sure that we blow up (via abort() from the security connector) when + // the name from the balancer doesn't match expectations. + ASSERT_DEATH_IF_SUPPORTED({ CheckRpcSendOk(); }, ""); +} + +using LdsTest = BasicTest; + +// Tests that LDS client should send a NACK if there is no API listener in the +// Listener in the LDS response. +TEST_P(LdsTest, NoApiListener) { + auto listener = balancers_[0]->ads_service()->default_listener(); + listener.clear_api_listener(); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "Listener has no ApiListener."); +} + +// Tests that LDS client should send a NACK if the route_specifier in the +// http_connection_manager is neither inlined route_config nor RDS. +TEST_P(LdsTest, WrongRouteSpecifier) { + auto listener = balancers_[0]->ads_service()->default_listener(); + HttpConnectionManager http_connection_manager; + http_connection_manager.mutable_scoped_routes(); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "HttpConnectionManager neither has inlined route_config nor RDS."); +} + +// Tests that LDS client should send a NACK if the rds message in the +// http_connection_manager is missing the config_source field. +TEST_P(LdsTest, RdsMissingConfigSource) { + auto listener = balancers_[0]->ads_service()->default_listener(); + HttpConnectionManager http_connection_manager; + http_connection_manager.mutable_rds()->set_route_config_name( + kDefaultRouteConfigurationName); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "HttpConnectionManager missing config_source for RDS."); +} + +// Tests that LDS client should send a NACK if the rds message in the +// http_connection_manager has a config_source field that does not specify ADS. +TEST_P(LdsTest, RdsConfigSourceDoesNotSpecifyAds) { + auto listener = balancers_[0]->ads_service()->default_listener(); + HttpConnectionManager http_connection_manager; + auto* rds = http_connection_manager.mutable_rds(); + rds->set_route_config_name(kDefaultRouteConfigurationName); + rds->mutable_config_source()->mutable_self(); + listener.mutable_api_listener()->mutable_api_listener()->PackFrom( + http_connection_manager); + balancers_[0]->ads_service()->SetLdsResource(listener); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->lds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "HttpConnectionManager ConfigSource for RDS does not specify ADS."); +} + +using LdsRdsTest = BasicTest; + +// Tests that LDS client should send an ACK upon correct LDS response (with +// inlined RDS result). +TEST_P(LdsRdsTest, Vanilla) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); + // Make sure we actually used the RPC service for the right version of xDS. + EXPECT_EQ(balancers_[0]->ads_service()->seen_v2_client(), + GetParam().use_v2()); + EXPECT_NE(balancers_[0]->ads_service()->seen_v3_client(), + GetParam().use_v2()); +} + +// Tests that we go into TRANSIENT_FAILURE if the Listener is removed. +TEST_P(LdsRdsTest, ListenerRemoved) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + // We need to wait for all backends to come online. + WaitForAllBackends(); + // Unset LDS resource. + balancers_[0]->ads_service()->UnsetResource(kLdsTypeUrl, kServerName); + // Wait for RPCs to start failing. + do { + } while (SendRpc(RpcOptions(), nullptr).ok()); + // Make sure RPCs are still failing. + CheckRpcSendFailure(1000); + // Make sure we ACK'ed the update. + EXPECT_EQ(balancers_[0]->ads_service()->lds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client ACKs but fails if matching domain can't be found in +// the LDS response. +TEST_P(LdsRdsTest, NoMatchedDomain) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + route_config.mutable_virtual_hosts(0)->clear_domains(); + route_config.mutable_virtual_hosts(0)->add_domains("unmatched_domain"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + // Do a bit of polling, to allow the ACK to get to the ADS server. + channel_->WaitForConnected(grpc_timeout_milliseconds_to_deadline(100)); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client should choose the virtual host with matching domain if +// multiple virtual hosts exist in the LDS response. +TEST_P(LdsRdsTest, ChooseMatchedDomain) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + *(route_config.add_virtual_hosts()) = route_config.virtual_hosts(0); + route_config.mutable_virtual_hosts(0)->clear_domains(); + route_config.mutable_virtual_hosts(0)->add_domains("unmatched_domain"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client should choose the last route in the virtual host if +// multiple routes exist in the LDS response. +TEST_P(LdsRdsTest, ChooseLastRoute) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + *(route_config.mutable_virtual_hosts(0)->add_routes()) = + route_config.virtual_hosts(0).routes(0); + route_config.mutable_virtual_hosts(0) + ->mutable_routes(0) + ->mutable_route() + ->mutable_cluster_header(); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + EXPECT_EQ(RouteConfigurationResponseState(0).state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client should send a NACK if route match has a case_sensitive +// set to false. +TEST_P(LdsRdsTest, RouteMatchHasCaseSensitiveFalse) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->mutable_case_sensitive()->set_value(false); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "case_sensitive if set must be set to true."); +} + +// Tests that LDS client should ignore route which has query_parameters. +TEST_P(LdsRdsTest, RouteMatchHasQueryParameters) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + route1->mutable_match()->add_query_parameters(); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should send a ACK if route match has a prefix +// that is either empty or a single slash +TEST_P(LdsRdsTest, RouteMatchHasValidPrefixEmptyOrSingleSlash) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix(""); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix("/"); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that LDS client should ignore route which has a path +// prefix string does not start with "/". +TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixNoLeadingSlash) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("grpc.testing.EchoTest1Service/"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should ignore route which has a prefix +// string with more than 2 slashes. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixExtraContent) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/Echo1/"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should ignore route which has a prefix +// string "//". +TEST_P(LdsRdsTest, RouteMatchHasInvalidPrefixDoubleSlash) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("//"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should ignore route which has path +// but it's empty. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathEmptyPath) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path(""); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should ignore route which has path +// string does not start with "/". +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathNoLeadingSlash) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("grpc.testing.EchoTest1Service/Echo1"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should ignore route which has path +// string that has too many slashes; for example, ends with "/". +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathTooManySlashes) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/Echo1/"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should ignore route which has path +// string that has only 1 slash: missing "/" between service and method. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathOnlyOneSlash) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service.Echo1"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should ignore route which has path +// string that is missing service. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathMissingService) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("//Echo1"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Tests that LDS client should ignore route which has path +// string that is missing method. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathMissingMethod) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/"); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No valid routes specified."); +} + +// Test that LDS client should reject route which has invalid path regex. +TEST_P(LdsRdsTest, RouteMatchHasInvalidPathRegex) { + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->mutable_safe_regex()->set_regex("a[z-a]"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "Invalid regex string specified in path matcher."); +} + +// Tests that LDS client should send a NACK if route has an action other than +// RouteAction in the LDS response. +TEST_P(LdsRdsTest, RouteHasNoRouteAction) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + route_config.mutable_virtual_hosts(0)->mutable_routes(0)->mutable_redirect(); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "No RouteAction found in route."); +} + +TEST_P(LdsRdsTest, RouteActionClusterHasEmptyClusterName) { + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + route1->mutable_route()->set_cluster(""); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "RouteAction cluster contains empty cluster name."); +} + +TEST_P(LdsRdsTest, RouteActionWeightedTargetHasIncorrectTotalWeightSet) { + const size_t kWeight75 = 75; + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75 + 1); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "RouteAction weighted_cluster has incorrect total weight"); +} + +TEST_P(LdsRdsTest, RouteActionWeightedTargetClusterHasEmptyClusterName) { + const size_t kWeight75 = 75; + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(""); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ( + response_state.error_message, + "RouteAction weighted_cluster cluster contains empty cluster name."); +} + +TEST_P(LdsRdsTest, RouteActionWeightedTargetClusterHasNoWeight) { + const size_t kWeight75 = 75; + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "RouteAction weighted_cluster cluster missing weight"); +} + +TEST_P(LdsRdsTest, RouteHeaderMatchInvalidRegex) { + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("header1"); + header_matcher1->mutable_safe_regex_match()->set_regex("a[z-a]"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "Invalid regex string specified in header matcher."); +} + +TEST_P(LdsRdsTest, RouteHeaderMatchInvalidRange) { + const char* kNewCluster1Name = "new_cluster_1"; + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("header1"); + header_matcher1->mutable_range_match()->set_start(1001); + header_matcher1->mutable_range_match()->set_end(1000); + route1->mutable_route()->set_cluster(kNewCluster1Name); + SetRouteConfiguration(0, route_config); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "Invalid range header matcher specifier specified: end " + "cannot be smaller than start."); +} + +// Tests that LDS client should choose the default route (with no matching +// specified) after unable to find a match with previous routes. +TEST_P(LdsRdsTest, XdsRoutingPathMatching) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 10; + const size_t kNumEcho2Rpcs = 20; + const size_t kNumEchoRpcs = 30; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(2, 3)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_path("/grpc.testing.EchoTest1Service/Echo1"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_path("/grpc.testing.EchoTest2Service/Echo2"); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto* route3 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route3->mutable_match()->set_path("/grpc.testing.EchoTest3Service/Echo3"); + route3->mutable_route()->set_cluster(kDefaultClusterName); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 2); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true)); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions() + .set_rpc_service(SERVICE_ECHO1) + .set_rpc_method(METHOD_ECHO1) + .set_wait_for_ready(true)); + CheckRpcSendOk(kNumEcho2Rpcs, RpcOptions() + .set_rpc_service(SERVICE_ECHO2) + .set_rpc_method(METHOD_ECHO2) + .set_wait_for_ready(true)); + // Make sure RPCs all go to the correct backend. + for (size_t i = 0; i < 2; ++i) { + EXPECT_EQ(kNumEchoRpcs / 2, + backends_[i]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service2()->request_count()); + } + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service2()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingPrefixMatching) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 10; + const size_t kNumEcho2Rpcs = 20; + const size_t kNumEchoRpcs = 30; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(2, 3)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_prefix("/grpc.testing.EchoTest2Service/"); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 2); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true)); + CheckRpcSendOk( + kNumEcho1Rpcs, + RpcOptions().set_rpc_service(SERVICE_ECHO1).set_wait_for_ready(true)); + CheckRpcSendOk( + kNumEcho2Rpcs, + RpcOptions().set_rpc_service(SERVICE_ECHO2).set_wait_for_ready(true)); + // Make sure RPCs all go to the correct backend. + for (size_t i = 0; i < 2; ++i) { + EXPECT_EQ(kNumEchoRpcs / 2, + backends_[i]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service2()->request_count()); + } + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service2()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingPathRegexMatching) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 10; + const size_t kNumEcho2Rpcs = 20; + const size_t kNumEchoRpcs = 30; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(2, 3)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + // Will match "/grpc.testing.EchoTest1Service/" + route1->mutable_match()->mutable_safe_regex()->set_regex(".*1.*"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto* route2 = new_route_config.mutable_virtual_hosts(0)->add_routes(); + // Will match "/grpc.testing.EchoTest2Service/" + route2->mutable_match()->mutable_safe_regex()->set_regex(".*2.*"); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 2); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_wait_for_ready(true)); + CheckRpcSendOk( + kNumEcho1Rpcs, + RpcOptions().set_rpc_service(SERVICE_ECHO1).set_wait_for_ready(true)); + CheckRpcSendOk( + kNumEcho2Rpcs, + RpcOptions().set_rpc_service(SERVICE_ECHO2).set_wait_for_ready(true)); + // Make sure RPCs all go to the correct backend. + for (size_t i = 0; i < 2; ++i) { + EXPECT_EQ(kNumEchoRpcs / 2, + backends_[i]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service2()->request_count()); + } + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[2]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service2()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + EXPECT_EQ(kNumEcho2Rpcs, backends_[3]->backend_service2()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingWeightedCluster) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEcho1Rpcs = 1000; + const size_t kNumEchoRpcs = 10; + const size_t kWeight75 = 75; + const size_t kWeight25 = 25; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(2, 3)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + auto* weighted_cluster2 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster2->set_name(kNewCluster2Name); + weighted_cluster2->mutable_weight()->set_value(kWeight25); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75 + kWeight25); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 1); + WaitForAllBackends(1, 3, true, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + const int weight_75_request_count = + backends_[1]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + const int weight_25_request_count = + backends_[2]->backend_service1()->request_count(); + const double kErrorTolerance = 0.2; + EXPECT_THAT(weight_75_request_count, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight75 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEcho1Rpcs * kWeight75 / 100 * + (1 + kErrorTolerance)))); + // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the + // test from flaking while debugging potential root cause. + const double kErrorToleranceSmallLoad = 0.3; + gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs", + weight_75_request_count, weight_25_request_count); + EXPECT_THAT(weight_25_request_count, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight25 / 100 * + (1 - kErrorToleranceSmallLoad)), + ::testing::Le(kNumEcho1Rpcs * kWeight25 / 100 * + (1 + kErrorToleranceSmallLoad)))); +} + +TEST_P(LdsRdsTest, RouteActionWeightedTargetDefaultRoute) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEchoRpcs = 1000; + const size_t kWeight75 = 75; + const size_t kWeight25 = 25; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(2, 3)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix(""); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + auto* weighted_cluster2 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster2->set_name(kNewCluster2Name); + weighted_cluster2->mutable_weight()->set_value(kWeight25); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75 + kWeight25); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(1, 3); + CheckRpcSendOk(kNumEchoRpcs); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(0, backends_[0]->backend_service()->request_count()); + const int weight_75_request_count = + backends_[1]->backend_service()->request_count(); + const int weight_25_request_count = + backends_[2]->backend_service()->request_count(); + const double kErrorTolerance = 0.2; + EXPECT_THAT(weight_75_request_count, + ::testing::AllOf(::testing::Ge(kNumEchoRpcs * kWeight75 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEchoRpcs * kWeight75 / 100 * + (1 + kErrorTolerance)))); + // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the + // test from flaking while debugging potential root cause. + const double kErrorToleranceSmallLoad = 0.3; + gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs", + weight_75_request_count, weight_25_request_count); + EXPECT_THAT(weight_25_request_count, + ::testing::AllOf(::testing::Ge(kNumEchoRpcs * kWeight25 / 100 * + (1 - kErrorToleranceSmallLoad)), + ::testing::Le(kNumEchoRpcs * kWeight25 / 100 * + (1 + kErrorToleranceSmallLoad)))); +} + +TEST_P(LdsRdsTest, XdsRoutingWeightedClusterUpdateWeights) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const char* kNewCluster3Name = "new_cluster_3"; + const char* kNewEdsService3Name = "new_eds_service_name_3"; + const size_t kNumEcho1Rpcs = 1000; + const size_t kNumEchoRpcs = 10; + const size_t kWeight75 = 75; + const size_t kWeight25 = 25; + const size_t kWeight50 = 50; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(2, 3)}, + }); + AdsServiceImpl::EdsResourceArgs args3({ + {"locality0", GetBackendPorts(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args3, kNewEdsService3Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + Cluster new_cluster3 = balancers_[0]->ads_service()->default_cluster(); + new_cluster3.set_name(kNewCluster3Name); + new_cluster3.mutable_eds_cluster_config()->set_service_name( + kNewEdsService3Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster3); + // Populating Route Configurations. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + auto* weighted_cluster2 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster2->set_name(kNewCluster2Name); + weighted_cluster2->mutable_weight()->set_value(kWeight25); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75 + kWeight25); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 1); + WaitForAllBackends(1, 3, true, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + const int weight_75_request_count = + backends_[1]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[1]->backend_service2()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + const int weight_25_request_count = + backends_[2]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + const double kErrorTolerance = 0.2; + EXPECT_THAT(weight_75_request_count, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight75 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEcho1Rpcs * kWeight75 / 100 * + (1 + kErrorTolerance)))); + // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the + // test from flaking while debugging potential root cause. + const double kErrorToleranceSmallLoad = 0.3; + gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs", + weight_75_request_count, weight_25_request_count); + EXPECT_THAT(weight_25_request_count, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight25 / 100 * + (1 - kErrorToleranceSmallLoad)), + ::testing::Le(kNumEcho1Rpcs * kWeight25 / 100 * + (1 + kErrorToleranceSmallLoad)))); + // Change Route Configurations: same clusters different weights. + weighted_cluster1->mutable_weight()->set_value(kWeight50); + weighted_cluster2->mutable_weight()->set_value(kWeight50); + // Change default route to a new cluster to help to identify when new polices + // are seen by the client. + default_route->mutable_route()->set_cluster(kNewCluster3Name); + SetRouteConfiguration(0, new_route_config); + ResetBackendCounters(); + WaitForAllBackends(3, 4); + CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(0, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + const int weight_50_request_count_1 = + backends_[1]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + const int weight_50_request_count_2 = + backends_[2]->backend_service1()->request_count(); + EXPECT_EQ(kNumEchoRpcs, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + EXPECT_THAT(weight_50_request_count_1, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight50 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEcho1Rpcs * kWeight50 / 100 * + (1 + kErrorTolerance)))); + EXPECT_THAT(weight_50_request_count_2, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight50 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEcho1Rpcs * kWeight50 / 100 * + (1 + kErrorTolerance)))); +} + +TEST_P(LdsRdsTest, XdsRoutingWeightedClusterUpdateClusters) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const char* kNewCluster3Name = "new_cluster_3"; + const char* kNewEdsService3Name = "new_eds_service_name_3"; + const size_t kNumEcho1Rpcs = 1000; + const size_t kNumEchoRpcs = 10; + const size_t kWeight75 = 75; + const size_t kWeight25 = 25; + const size_t kWeight50 = 50; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(2, 3)}, + }); + AdsServiceImpl::EdsResourceArgs args3({ + {"locality0", GetBackendPorts(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args3, kNewEdsService3Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + Cluster new_cluster3 = balancers_[0]->ads_service()->default_cluster(); + new_cluster3.set_name(kNewCluster3Name); + new_cluster3.mutable_eds_cluster_config()->set_service_name( + kNewEdsService3Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster3); + // Populating Route Configurations. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* weighted_cluster1 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster1->set_name(kNewCluster1Name); + weighted_cluster1->mutable_weight()->set_value(kWeight75); + auto* weighted_cluster2 = + route1->mutable_route()->mutable_weighted_clusters()->add_clusters(); + weighted_cluster2->set_name(kDefaultClusterName); + weighted_cluster2->mutable_weight()->set_value(kWeight25); + route1->mutable_route() + ->mutable_weighted_clusters() + ->mutable_total_weight() + ->set_value(kWeight75 + kWeight25); + auto* default_route = new_route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 1); + WaitForAllBackends(1, 2, true, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + int weight_25_request_count = + backends_[0]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + int weight_75_request_count = + backends_[1]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + const double kErrorTolerance = 0.2; + EXPECT_THAT(weight_75_request_count, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight75 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEcho1Rpcs * kWeight75 / 100 * + (1 + kErrorTolerance)))); + // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the + // test from flaking while debugging potential root cause. + const double kErrorToleranceSmallLoad = 0.3; + gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs", + weight_75_request_count, weight_25_request_count); + EXPECT_THAT(weight_25_request_count, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight25 / 100 * + (1 - kErrorToleranceSmallLoad)), + ::testing::Le(kNumEcho1Rpcs * kWeight25 / 100 * + (1 + kErrorToleranceSmallLoad)))); + // Change Route Configurations: new set of clusters with different weights. + weighted_cluster1->mutable_weight()->set_value(kWeight50); + weighted_cluster2->set_name(kNewCluster2Name); + weighted_cluster2->mutable_weight()->set_value(kWeight50); + SetRouteConfiguration(0, new_route_config); + ResetBackendCounters(); + WaitForAllBackends(2, 3, true, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + const int weight_50_request_count_1 = + backends_[1]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + const int weight_50_request_count_2 = + backends_[2]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service1()->request_count()); + EXPECT_THAT(weight_50_request_count_1, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight50 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEcho1Rpcs * kWeight50 / 100 * + (1 + kErrorTolerance)))); + EXPECT_THAT(weight_50_request_count_2, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight50 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEcho1Rpcs * kWeight50 / 100 * + (1 + kErrorTolerance)))); + // Change Route Configurations. + weighted_cluster1->mutable_weight()->set_value(kWeight75); + weighted_cluster2->set_name(kNewCluster3Name); + weighted_cluster2->mutable_weight()->set_value(kWeight25); + SetRouteConfiguration(0, new_route_config); + ResetBackendCounters(); + WaitForAllBackends(3, 4, true, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + weight_75_request_count = backends_[1]->backend_service1()->request_count(); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[3]->backend_service()->request_count()); + weight_25_request_count = backends_[3]->backend_service1()->request_count(); + EXPECT_THAT(weight_75_request_count, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight75 / 100 * + (1 - kErrorTolerance)), + ::testing::Le(kNumEcho1Rpcs * kWeight75 / 100 * + (1 + kErrorTolerance)))); + // TODO: (@donnadionne) Reduce tolerance: increased the tolerance to keep the + // test from flaking while debugging potential root cause. + gpr_log(GPR_INFO, "target_75 received %d rpcs and target_25 received %d rpcs", + weight_75_request_count, weight_25_request_count); + EXPECT_THAT(weight_25_request_count, + ::testing::AllOf(::testing::Ge(kNumEcho1Rpcs * kWeight25 / 100 * + (1 - kErrorToleranceSmallLoad)), + ::testing::Le(kNumEcho1Rpcs * kWeight25 / 100 * + (1 + kErrorToleranceSmallLoad)))); +} + +TEST_P(LdsRdsTest, XdsRoutingClusterUpdateClusters) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + const size_t kNumEchoRpcs = 5; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Send Route Configuration. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(0, 1); + CheckRpcSendOk(kNumEchoRpcs); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + // Change Route Configurations: new default cluster. + auto* default_route = + new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + default_route->mutable_route()->set_cluster(kNewClusterName); + SetRouteConfiguration(0, new_route_config); + WaitForAllBackends(1, 2); + CheckRpcSendOk(kNumEchoRpcs); + // Make sure RPCs all go to the correct backend. + EXPECT_EQ(kNumEchoRpcs, backends_[1]->backend_service()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingClusterUpdateClustersWithPickingDelays) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Bring down the current backend: 0, this will delay route picking time, + // resulting in un-committed RPCs. + ShutdownBackend(0); + // Send a RouteConfiguration with a default route that points to + // backend 0. + RouteConfiguration new_route_config = + balancers_[0]->ads_service()->default_route_config(); + SetRouteConfiguration(0, new_route_config); + // Send exactly one RPC with no deadline and with wait_for_ready=true. + // This RPC will not complete until after backend 0 is started. + std::thread sending_rpc([this]() { + CheckRpcSendOk(1, RpcOptions().set_wait_for_ready(true).set_timeout_ms(0)); + }); + // Send a non-wait_for_ready RPC which should fail, this will tell us + // that the client has received the update and attempted to connect. + const Status status = SendRpc(RpcOptions().set_timeout_ms(0)); + EXPECT_FALSE(status.ok()); + // Send a update RouteConfiguration to use backend 1. + auto* default_route = + new_route_config.mutable_virtual_hosts(0)->mutable_routes(0); + default_route->mutable_route()->set_cluster(kNewClusterName); + SetRouteConfiguration(0, new_route_config); + // Wait for RPCs to go to the new backend: 1, this ensures that the client has + // processed the update. + WaitForAllBackends(1, 2, false, RpcOptions(), true); + // Bring up the previous backend: 0, this will allow the delayed RPC to + // finally call on_call_committed upon completion. + StartBackend(0); + sending_rpc.join(); + // Make sure RPCs go to the correct backend: + EXPECT_EQ(1, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(1, backends_[1]->backend_service()->request_count()); +} + +TEST_P(LdsRdsTest, XdsRoutingHeadersMatching) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + const size_t kNumEcho1Rpcs = 100; + const size_t kNumEchoRpcs = 5; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Populating Route Configurations for LDS. + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("header1"); + header_matcher1->set_exact_match("POST,PUT,GET"); + auto* header_matcher2 = route1->mutable_match()->add_headers(); + header_matcher2->set_name("header2"); + header_matcher2->mutable_safe_regex_match()->set_regex("[a-z]*"); + auto* header_matcher3 = route1->mutable_match()->add_headers(); + header_matcher3->set_name("header3"); + header_matcher3->mutable_range_match()->set_start(1); + header_matcher3->mutable_range_match()->set_end(1000); + auto* header_matcher4 = route1->mutable_match()->add_headers(); + header_matcher4->set_name("header4"); + header_matcher4->set_present_match(false); + auto* header_matcher5 = route1->mutable_match()->add_headers(); + header_matcher5->set_name("header5"); + header_matcher5->set_prefix_match("/grpc"); + auto* header_matcher6 = route1->mutable_match()->add_headers(); + header_matcher6->set_name("header6"); + header_matcher6->set_suffix_match(".cc"); + header_matcher6->set_invert_match(true); + route1->mutable_route()->set_cluster(kNewClusterName); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + std::vector<std::pair<TString, TString>> metadata = { + {"header1", "POST"}, {"header2", "blah"}, + {"header3", "1"}, {"header5", "/grpc.testing.EchoTest1Service/"}, + {"header1", "PUT"}, {"header6", "grpc.java"}, + {"header1", "GET"}, + }; + const auto header_match_rpc_options = RpcOptions() + .set_rpc_service(SERVICE_ECHO1) + .set_rpc_method(METHOD_ECHO1) + .set_metadata(std::move(metadata)); + // Make sure all backends are up. + WaitForAllBackends(0, 1); + WaitForAllBackends(1, 2, true, header_match_rpc_options); + // Send RPCs. + CheckRpcSendOk(kNumEchoRpcs); + CheckRpcSendOk(kNumEcho1Rpcs, header_match_rpc_options); + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service2()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[1]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service2()->request_count()); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingSpecialHeaderContentType) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + const size_t kNumEchoRpcs = 100; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Populating Route Configurations for LDS. + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix(""); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("content-type"); + header_matcher1->set_exact_match("notapplication/grpc"); + route1->mutable_route()->set_cluster(kNewClusterName); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + auto* header_matcher2 = default_route->mutable_match()->add_headers(); + header_matcher2->set_name("content-type"); + header_matcher2->set_exact_match("application/grpc"); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + // Make sure the backend is up. + WaitForAllBackends(0, 1); + // Send RPCs. + CheckRpcSendOk(kNumEchoRpcs); + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingSpecialCasesToIgnore) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const size_t kNumEchoRpcs = 100; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(2, 3)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + // Populating Route Configurations for LDS. + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix(""); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("grpc-foo-bin"); + header_matcher1->set_present_match(true); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto route2 = route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_prefix(""); + auto* header_matcher2 = route2->mutable_match()->add_headers(); + header_matcher2->set_name("grpc-previous-rpc-attempts"); + header_matcher2->set_present_match(true); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + // Send headers which will mismatch each route + std::vector<std::pair<TString, TString>> metadata = { + {"grpc-foo-bin", "grpc-foo-bin"}, + {"grpc-previous-rpc-attempts", "grpc-previous-rpc-attempts"}, + }; + WaitForAllBackends(0, 1); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_metadata(metadata)); + // Verify that only the default backend got RPCs since all previous routes + // were mismatched. + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[2]->backend_service()->request_count()); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingRuntimeFractionMatching) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + const size_t kNumRpcs = 1000; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Populating Route Configurations for LDS. + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match() + ->mutable_runtime_fraction() + ->mutable_default_value() + ->set_numerator(25); + route1->mutable_route()->set_cluster(kNewClusterName); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + WaitForAllBackends(0, 2); + CheckRpcSendOk(kNumRpcs); + const int default_backend_count = + backends_[0]->backend_service()->request_count(); + const int matched_backend_count = + backends_[1]->backend_service()->request_count(); + const double kErrorTolerance = 0.2; + EXPECT_THAT(default_backend_count, + ::testing::AllOf( + ::testing::Ge(kNumRpcs * 75 / 100 * (1 - kErrorTolerance)), + ::testing::Le(kNumRpcs * 75 / 100 * (1 + kErrorTolerance)))); + EXPECT_THAT(matched_backend_count, + ::testing::AllOf( + ::testing::Ge(kNumRpcs * 25 / 100 * (1 - kErrorTolerance)), + ::testing::Le(kNumRpcs * 25 / 100 * (1 + kErrorTolerance)))); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingHeadersMatchingUnmatchCases) { + const char* kNewCluster1Name = "new_cluster_1"; + const char* kNewEdsService1Name = "new_eds_service_name_1"; + const char* kNewCluster2Name = "new_cluster_2"; + const char* kNewEdsService2Name = "new_eds_service_name_2"; + const char* kNewCluster3Name = "new_cluster_3"; + const char* kNewEdsService3Name = "new_eds_service_name_3"; + const size_t kNumEcho1Rpcs = 100; + const size_t kNumEchoRpcs = 5; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + AdsServiceImpl::EdsResourceArgs args2({ + {"locality0", GetBackendPorts(2, 3)}, + }); + AdsServiceImpl::EdsResourceArgs args3({ + {"locality0", GetBackendPorts(3, 4)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsService1Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args2, kNewEdsService2Name)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args3, kNewEdsService3Name)); + // Populate new CDS resources. + Cluster new_cluster1 = balancers_[0]->ads_service()->default_cluster(); + new_cluster1.set_name(kNewCluster1Name); + new_cluster1.mutable_eds_cluster_config()->set_service_name( + kNewEdsService1Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster1); + Cluster new_cluster2 = balancers_[0]->ads_service()->default_cluster(); + new_cluster2.set_name(kNewCluster2Name); + new_cluster2.mutable_eds_cluster_config()->set_service_name( + kNewEdsService2Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster2); + Cluster new_cluster3 = balancers_[0]->ads_service()->default_cluster(); + new_cluster3.set_name(kNewCluster3Name); + new_cluster3.mutable_eds_cluster_config()->set_service_name( + kNewEdsService3Name); + balancers_[0]->ads_service()->SetCdsResource(new_cluster3); + // Populating Route Configurations for LDS. + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher1 = route1->mutable_match()->add_headers(); + header_matcher1->set_name("header1"); + header_matcher1->set_exact_match("POST"); + route1->mutable_route()->set_cluster(kNewCluster1Name); + auto route2 = route_config.mutable_virtual_hosts(0)->add_routes(); + route2->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher2 = route2->mutable_match()->add_headers(); + header_matcher2->set_name("header2"); + header_matcher2->mutable_range_match()->set_start(1); + header_matcher2->mutable_range_match()->set_end(1000); + route2->mutable_route()->set_cluster(kNewCluster2Name); + auto route3 = route_config.mutable_virtual_hosts(0)->add_routes(); + route3->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + auto* header_matcher3 = route3->mutable_match()->add_headers(); + header_matcher3->set_name("header3"); + header_matcher3->mutable_safe_regex_match()->set_regex("[a-z]*"); + route3->mutable_route()->set_cluster(kNewCluster3Name); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + // Send headers which will mismatch each route + std::vector<std::pair<TString, TString>> metadata = { + {"header1", "POST"}, + {"header2", "1000"}, + {"header3", "123"}, + {"header1", "GET"}, + }; + WaitForAllBackends(0, 1); + CheckRpcSendOk(kNumEchoRpcs, RpcOptions().set_metadata(metadata)); + CheckRpcSendOk(kNumEcho1Rpcs, RpcOptions() + .set_rpc_service(SERVICE_ECHO1) + .set_rpc_method(METHOD_ECHO1) + .set_metadata(metadata)); + // Verify that only the default backend got RPCs since all previous routes + // were mismatched. + for (size_t i = 1; i < 4; ++i) { + EXPECT_EQ(0, backends_[i]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[i]->backend_service2()->request_count()); + } + EXPECT_EQ(kNumEchoRpcs, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(kNumEcho1Rpcs, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service2()->request_count()); + const auto& response_state = RouteConfigurationResponseState(0); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::ACKED); +} + +TEST_P(LdsRdsTest, XdsRoutingChangeRoutesWithoutChangingClusters) { + const char* kNewClusterName = "new_cluster"; + const char* kNewEdsServiceName = "new_eds_service_name"; + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // Populate new EDS resources. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + AdsServiceImpl::EdsResourceArgs args1({ + {"locality0", GetBackendPorts(1, 2)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args1, kNewEdsServiceName)); + // Populate new CDS resources. + Cluster new_cluster = balancers_[0]->ads_service()->default_cluster(); + new_cluster.set_name(kNewClusterName); + new_cluster.mutable_eds_cluster_config()->set_service_name( + kNewEdsServiceName); + balancers_[0]->ads_service()->SetCdsResource(new_cluster); + // Populating Route Configurations for LDS. + RouteConfiguration route_config = + balancers_[0]->ads_service()->default_route_config(); + auto* route1 = route_config.mutable_virtual_hosts(0)->mutable_routes(0); + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest1Service/"); + route1->mutable_route()->set_cluster(kNewClusterName); + auto* default_route = route_config.mutable_virtual_hosts(0)->add_routes(); + default_route->mutable_match()->set_prefix(""); + default_route->mutable_route()->set_cluster(kDefaultClusterName); + SetRouteConfiguration(0, route_config); + // Make sure all backends are up and that requests for each RPC + // service go to the right backends. + WaitForAllBackends(0, 1, false); + WaitForAllBackends(1, 2, false, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + WaitForAllBackends(0, 1, false, RpcOptions().set_rpc_service(SERVICE_ECHO2)); + // Requests for services Echo and Echo2 should have gone to backend 0. + EXPECT_EQ(1, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(1, backends_[0]->backend_service2()->request_count()); + // Requests for service Echo1 should have gone to backend 1. + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(1, backends_[1]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service2()->request_count()); + // Now send an update that changes the first route to match a + // different RPC service, and wait for the client to make the change. + route1->mutable_match()->set_prefix("/grpc.testing.EchoTest2Service/"); + SetRouteConfiguration(0, route_config); + WaitForAllBackends(1, 2, true, RpcOptions().set_rpc_service(SERVICE_ECHO2)); + // Now repeat the earlier test, making sure all traffic goes to the + // right place. + WaitForAllBackends(0, 1, false); + WaitForAllBackends(0, 1, false, RpcOptions().set_rpc_service(SERVICE_ECHO1)); + WaitForAllBackends(1, 2, false, RpcOptions().set_rpc_service(SERVICE_ECHO2)); + // Requests for services Echo and Echo1 should have gone to backend 0. + EXPECT_EQ(1, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(1, backends_[0]->backend_service1()->request_count()); + EXPECT_EQ(0, backends_[0]->backend_service2()->request_count()); + // Requests for service Echo2 should have gone to backend 1. + EXPECT_EQ(0, backends_[1]->backend_service()->request_count()); + EXPECT_EQ(0, backends_[1]->backend_service1()->request_count()); + EXPECT_EQ(1, backends_[1]->backend_service2()->request_count()); +} + +using CdsTest = BasicTest; + +// Tests that CDS client should send an ACK upon correct CDS response. +TEST_P(CdsTest, Vanilla) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + (void)SendRpc(); + EXPECT_EQ(balancers_[0]->ads_service()->cds_response_state().state, + AdsServiceImpl::ResponseState::ACKED); +} + +// Tests that CDS client should send a NACK if the cluster type in CDS response +// is other than EDS. +TEST_P(CdsTest, WrongClusterType) { + auto cluster = balancers_[0]->ads_service()->default_cluster(); + cluster.set_type(Cluster::STATIC); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "DiscoveryType is not EDS."); +} + +// Tests that CDS client should send a NACK if the eds_config in CDS response is +// other than ADS. +TEST_P(CdsTest, WrongEdsConfig) { + auto cluster = balancers_[0]->ads_service()->default_cluster(); + cluster.mutable_eds_cluster_config()->mutable_eds_config()->mutable_self(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "EDS ConfigSource is not ADS."); +} + +// Tests that CDS client should send a NACK if the lb_policy in CDS response is +// other than ROUND_ROBIN. +TEST_P(CdsTest, WrongLbPolicy) { + auto cluster = balancers_[0]->ads_service()->default_cluster(); + cluster.set_lb_policy(Cluster::LEAST_REQUEST); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "LB policy is not ROUND_ROBIN."); +} + +// Tests that CDS client should send a NACK if the lrs_server in CDS response is +// other than SELF. +TEST_P(CdsTest, WrongLrsServer) { + auto cluster = balancers_[0]->ads_service()->default_cluster(); + cluster.mutable_lrs_server()->mutable_ads(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->cds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, "LRS ConfigSource is not self."); +} + +using EdsTest = BasicTest; + +// Tests that EDS client should send a NACK if the EDS update contains +// sparse priorities. +TEST_P(EdsTest, NacksSparsePriorityList) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(), kDefaultLocalityWeight, 1}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args)); + CheckRpcSendFailure(); + const auto& response_state = + balancers_[0]->ads_service()->eds_response_state(); + EXPECT_EQ(response_state.state, AdsServiceImpl::ResponseState::NACKED); + EXPECT_EQ(response_state.error_message, + "EDS update includes sparse priority list"); +} + +// In most of our tests, we use different names for different resource +// types, to make sure that there are no cut-and-paste errors in the code +// that cause us to look at data for the wrong resource type. So we add +// this test to make sure that the EDS resource name defaults to the +// cluster name if not specified in the CDS resource. +TEST_P(EdsTest, EdsServiceNameDefaultsToClusterName) { + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, kDefaultClusterName)); + Cluster cluster = balancers_[0]->ads_service()->default_cluster(); + cluster.mutable_eds_cluster_config()->clear_service_name(); + balancers_[0]->ads_service()->SetCdsResource(cluster); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendOk(); +} + +class TimeoutTest : public BasicTest { + protected: + void SetUp() override { + xds_resource_does_not_exist_timeout_ms_ = 500; + BasicTest::SetUp(); + } +}; + +// Tests that LDS client times out when no response received. +TEST_P(TimeoutTest, Lds) { + balancers_[0]->ads_service()->SetResourceIgnore(kLdsTypeUrl); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); +} + +TEST_P(TimeoutTest, Rds) { + balancers_[0]->ads_service()->SetResourceIgnore(kRdsTypeUrl); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); +} + +// Tests that CDS client times out when no response received. +TEST_P(TimeoutTest, Cds) { + balancers_[0]->ads_service()->SetResourceIgnore(kCdsTypeUrl); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); +} + +TEST_P(TimeoutTest, Eds) { + balancers_[0]->ads_service()->SetResourceIgnore(kEdsTypeUrl); + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + CheckRpcSendFailure(); +} + +using LocalityMapTest = BasicTest; + +// Tests that the localities in a locality map are picked according to their +// weights. +TEST_P(LocalityMapTest, WeightedRoundRobin) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 5000; + const int kLocalityWeight0 = 2; + const int kLocalityWeight1 = 8; + const int kTotalLocalityWeight = kLocalityWeight0 + kLocalityWeight1; + const double kLocalityWeightRate0 = + static_cast<double>(kLocalityWeight0) / kTotalLocalityWeight; + const double kLocalityWeightRate1 = + static_cast<double>(kLocalityWeight1) / kTotalLocalityWeight; + // ADS response contains 2 localities, each of which contains 1 backend. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), kLocalityWeight0}, + {"locality1", GetBackendPorts(1, 2), kLocalityWeight1}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for both backends to be ready. + WaitForAllBackends(0, 2); + // Send kNumRpcs RPCs. + CheckRpcSendOk(kNumRpcs); + // The locality picking rates should be roughly equal to the expectation. + const double locality_picked_rate_0 = + static_cast<double>(backends_[0]->backend_service()->request_count()) / + kNumRpcs; + const double locality_picked_rate_1 = + static_cast<double>(backends_[1]->backend_service()->request_count()) / + kNumRpcs; + const double kErrorTolerance = 0.2; + EXPECT_THAT(locality_picked_rate_0, + ::testing::AllOf( + ::testing::Ge(kLocalityWeightRate0 * (1 - kErrorTolerance)), + ::testing::Le(kLocalityWeightRate0 * (1 + kErrorTolerance)))); + EXPECT_THAT(locality_picked_rate_1, + ::testing::AllOf( + ::testing::Ge(kLocalityWeightRate1 * (1 - kErrorTolerance)), + ::testing::Le(kLocalityWeightRate1 * (1 + kErrorTolerance)))); +} + +// Tests that we correctly handle a locality containing no endpoints. +TEST_P(LocalityMapTest, LocalityContainingNoEndpoints) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 5000; + // EDS response contains 2 localities, one with no endpoints. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + {"locality1", {}}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for both backends to be ready. + WaitForAllBackends(); + // Send kNumRpcs RPCs. + CheckRpcSendOk(kNumRpcs); + // All traffic should go to the reachable locality. + EXPECT_EQ(backends_[0]->backend_service()->request_count(), + kNumRpcs / backends_.size()); + EXPECT_EQ(backends_[1]->backend_service()->request_count(), + kNumRpcs / backends_.size()); + EXPECT_EQ(backends_[2]->backend_service()->request_count(), + kNumRpcs / backends_.size()); + EXPECT_EQ(backends_[3]->backend_service()->request_count(), + kNumRpcs / backends_.size()); +} + +// EDS update with no localities. +TEST_P(LocalityMapTest, NoLocalities) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource({}, DefaultEdsServiceName())); + Status status = SendRpc(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_code(), StatusCode::UNAVAILABLE); +} + +// Tests that the locality map can work properly even when it contains a large +// number of localities. +TEST_P(LocalityMapTest, StressTest) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumLocalities = 100; + // The first ADS response contains kNumLocalities localities, each of which + // contains backend 0. + AdsServiceImpl::EdsResourceArgs args; + for (size_t i = 0; i < kNumLocalities; ++i) { + TString name = y_absl::StrCat("locality", i); + AdsServiceImpl::EdsResourceArgs::Locality locality(name, + {backends_[0]->port()}); + args.locality_list.emplace_back(std::move(locality)); + } + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // The second ADS response contains 1 locality, which contains backend 1. + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", GetBackendPorts(1, 2)}, + }); + std::thread delayed_resource_setter( + std::bind(&BasicTest::SetEdsResourceWithDelay, this, 0, + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), + 60 * 1000)); + // Wait until backend 0 is ready, before which kNumLocalities localities are + // received and handled by the xds policy. + WaitForBackend(0, /*reset_counters=*/false); + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + // Wait until backend 1 is ready, before which kNumLocalities localities are + // removed by the xds policy. + WaitForBackend(1); + delayed_resource_setter.join(); +} + +// Tests that the localities in a locality map are picked correctly after update +// (addition, modification, deletion). +TEST_P(LocalityMapTest, UpdateMap) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 3000; + // The locality weight for the first 3 localities. + const std::vector<int> kLocalityWeights0 = {2, 3, 4}; + const double kTotalLocalityWeight0 = + std::accumulate(kLocalityWeights0.begin(), kLocalityWeights0.end(), 0); + std::vector<double> locality_weight_rate_0; + for (int weight : kLocalityWeights0) { + locality_weight_rate_0.push_back(weight / kTotalLocalityWeight0); + } + // Delete the first locality, keep the second locality, change the third + // locality's weight from 4 to 2, and add a new locality with weight 6. + const std::vector<int> kLocalityWeights1 = {3, 2, 6}; + const double kTotalLocalityWeight1 = + std::accumulate(kLocalityWeights1.begin(), kLocalityWeights1.end(), 0); + std::vector<double> locality_weight_rate_1 = { + 0 /* placeholder for locality 0 */}; + for (int weight : kLocalityWeights1) { + locality_weight_rate_1.push_back(weight / kTotalLocalityWeight1); + } + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), 2}, + {"locality1", GetBackendPorts(1, 2), 3}, + {"locality2", GetBackendPorts(2, 3), 4}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for the first 3 backends to be ready. + WaitForAllBackends(0, 3); + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + // Send kNumRpcs RPCs. + CheckRpcSendOk(kNumRpcs); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // The picking rates of the first 3 backends should be roughly equal to the + // expectation. + std::vector<double> locality_picked_rates; + for (size_t i = 0; i < 3; ++i) { + locality_picked_rates.push_back( + static_cast<double>(backends_[i]->backend_service()->request_count()) / + kNumRpcs); + } + const double kErrorTolerance = 0.2; + for (size_t i = 0; i < 3; ++i) { + gpr_log(GPR_INFO, "Locality %" PRIuPTR " rate %f", i, + locality_picked_rates[i]); + EXPECT_THAT( + locality_picked_rates[i], + ::testing::AllOf( + ::testing::Ge(locality_weight_rate_0[i] * (1 - kErrorTolerance)), + ::testing::Le(locality_weight_rate_0[i] * (1 + kErrorTolerance)))); + } + args = AdsServiceImpl::EdsResourceArgs({ + {"locality1", GetBackendPorts(1, 2), 3}, + {"locality2", GetBackendPorts(2, 3), 2}, + {"locality3", GetBackendPorts(3, 4), 6}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Backend 3 hasn't received any request. + EXPECT_EQ(0U, backends_[3]->backend_service()->request_count()); + // Wait until the locality update has been processed, as signaled by backend 3 + // receiving a request. + WaitForAllBackends(3, 4); + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + // Send kNumRpcs RPCs. + CheckRpcSendOk(kNumRpcs); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // Backend 0 no longer receives any request. + EXPECT_EQ(0U, backends_[0]->backend_service()->request_count()); + // The picking rates of the last 3 backends should be roughly equal to the + // expectation. + locality_picked_rates = {0 /* placeholder for backend 0 */}; + for (size_t i = 1; i < 4; ++i) { + locality_picked_rates.push_back( + static_cast<double>(backends_[i]->backend_service()->request_count()) / + kNumRpcs); + } + for (size_t i = 1; i < 4; ++i) { + gpr_log(GPR_INFO, "Locality %" PRIuPTR " rate %f", i, + locality_picked_rates[i]); + EXPECT_THAT( + locality_picked_rates[i], + ::testing::AllOf( + ::testing::Ge(locality_weight_rate_1[i] * (1 - kErrorTolerance)), + ::testing::Le(locality_weight_rate_1[i] * (1 + kErrorTolerance)))); + } +} + +// Tests that we don't fail RPCs when replacing all of the localities in +// a given priority. +TEST_P(LocalityMapTest, ReplaceAllLocalitiesInPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + args = AdsServiceImpl::EdsResourceArgs({ + {"locality1", GetBackendPorts(1, 2)}, + }); + std::thread delayed_resource_setter(std::bind( + &BasicTest::SetEdsResourceWithDelay, this, 0, + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), 5000)); + // Wait for the first backend to be ready. + WaitForBackend(0); + // Keep sending RPCs until we switch over to backend 1, which tells us + // that we received the update. No RPCs should fail during this + // transition. + WaitForBackend(1, /*reset_counters=*/true, /*require_success=*/true); + delayed_resource_setter.join(); +} + +class FailoverTest : public BasicTest { + public: + void SetUp() override { + BasicTest::SetUp(); + ResetStub(500); + } +}; + +// Localities with the highest priority are used when multiple priority exist. +TEST_P(FailoverTest, ChooseHighestPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1}, + {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2}, + {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3}, + {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForBackend(3, false); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } +} + +// Does not choose priority with no endpoints. +TEST_P(FailoverTest, DoesNotUsePriorityWithNoEndpoints) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1}, + {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2}, + {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3}, + {"locality3", {}, kDefaultLocalityWeight, 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForBackend(0, false); + for (size_t i = 1; i < 3; ++i) { + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } +} + +// Does not choose locality with no endpoints. +TEST_P(FailoverTest, DoesNotUseLocalityWithNoEndpoints) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", {}, kDefaultLocalityWeight, 0}, + {"locality1", GetBackendPorts(), kDefaultLocalityWeight, 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for all backends to be used. + std::tuple<int, int, int> counts = WaitForAllBackends(); + // Make sure no RPCs failed in the transition. + EXPECT_EQ(0, std::get<1>(counts)); +} + +// If the higher priority localities are not reachable, failover to the highest +// priority among the rest. +TEST_P(FailoverTest, Failover) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1}, + {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2}, + {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3}, + {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 0}, + }); + ShutdownBackend(3); + ShutdownBackend(0); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForBackend(1, false); + for (size_t i = 0; i < 4; ++i) { + if (i == 1) continue; + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } +} + +// If a locality with higher priority than the current one becomes ready, +// switch to it. +TEST_P(FailoverTest, SwitchBackToHigherPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 100; + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1}, + {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2}, + {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3}, + {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 0}, + }); + ShutdownBackend(3); + ShutdownBackend(0); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForBackend(1, false); + for (size_t i = 0; i < 4; ++i) { + if (i == 1) continue; + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } + StartBackend(0); + WaitForBackend(0); + CheckRpcSendOk(kNumRpcs); + EXPECT_EQ(kNumRpcs, backends_[0]->backend_service()->request_count()); +} + +// The first update only contains unavailable priorities. The second update +// contains available priorities. +TEST_P(FailoverTest, UpdateInitialUnavailable) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0}, + {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 1}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0}, + {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 1}, + {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 2}, + {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 3}, + }); + ShutdownBackend(0); + ShutdownBackend(1); + std::thread delayed_resource_setter(std::bind( + &BasicTest::SetEdsResourceWithDelay, this, 0, + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), 1000)); + gpr_timespec deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(500, GPR_TIMESPAN)); + // Send 0.5 second worth of RPCs. + do { + CheckRpcSendFailure(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + WaitForBackend(2, false); + for (size_t i = 0; i < 4; ++i) { + if (i == 2) continue; + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } + delayed_resource_setter.join(); +} + +// Tests that after the localities' priorities are updated, we still choose the +// highest READY priority with the updated localities. +TEST_P(FailoverTest, UpdatePriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 100; + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 1}, + {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 2}, + {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 3}, + {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 0}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 2}, + {"locality1", GetBackendPorts(1, 2), kDefaultLocalityWeight, 0}, + {"locality2", GetBackendPorts(2, 3), kDefaultLocalityWeight, 1}, + {"locality3", GetBackendPorts(3, 4), kDefaultLocalityWeight, 3}, + }); + std::thread delayed_resource_setter(std::bind( + &BasicTest::SetEdsResourceWithDelay, this, 0, + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), 1000)); + WaitForBackend(3, false); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(0U, backends_[i]->backend_service()->request_count()); + } + WaitForBackend(1); + CheckRpcSendOk(kNumRpcs); + EXPECT_EQ(kNumRpcs, backends_[1]->backend_service()->request_count()); + delayed_resource_setter.join(); +} + +// Moves all localities in the current priority to a higher priority. +TEST_P(FailoverTest, MoveAllLocalitiesInCurrentPriorityToHigherPriority) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + // First update: + // - Priority 0 is locality 0, containing backend 0, which is down. + // - Priority 1 is locality 1, containing backends 1 and 2, which are up. + ShutdownBackend(0); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0}, + {"locality1", GetBackendPorts(1, 3), kDefaultLocalityWeight, 1}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Second update: + // - Priority 0 contains both localities 0 and 1. + // - Priority 1 is not present. + // - We add backend 3 to locality 1, just so we have a way to know + // when the update has been seen by the client. + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", GetBackendPorts(0, 1), kDefaultLocalityWeight, 0}, + {"locality1", GetBackendPorts(1, 4), kDefaultLocalityWeight, 0}, + }); + std::thread delayed_resource_setter(std::bind( + &BasicTest::SetEdsResourceWithDelay, this, 0, + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName()), 1000)); + // When we get the first update, all backends in priority 0 are down, + // so we will create priority 1. Backends 1 and 2 should have traffic, + // but backend 3 should not. + WaitForAllBackends(1, 3, false); + EXPECT_EQ(0UL, backends_[3]->backend_service()->request_count()); + // When backend 3 gets traffic, we know the second update has been seen. + WaitForBackend(3); + // The ADS service of balancer 0 got at least 1 response. + EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + delayed_resource_setter.join(); +} + +using DropTest = BasicTest; + +// Tests that RPCs are dropped according to the drop config. +TEST_P(DropTest, Vanilla) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 5000; + const uint32_t kDropPerMillionForLb = 100000; + const uint32_t kDropPerMillionForThrottle = 200000; + const double kDropRateForLb = kDropPerMillionForLb / 1000000.0; + const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0; + const double KDropRateForLbAndThrottle = + kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle; + // The ADS response contains two drop categories. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + args.drop_categories = {{kLbDropType, kDropPerMillionForLb}, + {kThrottleDropType, kDropPerMillionForThrottle}}; + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForAllBackends(); + // Send kNumRpcs RPCs and count the drops. + size_t num_drops = 0; + for (size_t i = 0; i < kNumRpcs; ++i) { + EchoResponse response; + const Status status = SendRpc(RpcOptions(), &response); + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + } + // The drop rate should be roughly equal to the expectation. + const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs; + const double kErrorTolerance = 0.2; + EXPECT_THAT( + seen_drop_rate, + ::testing::AllOf( + ::testing::Ge(KDropRateForLbAndThrottle * (1 - kErrorTolerance)), + ::testing::Le(KDropRateForLbAndThrottle * (1 + kErrorTolerance)))); +} + +// Tests that drop config is converted correctly from per hundred. +TEST_P(DropTest, DropPerHundred) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 5000; + const uint32_t kDropPerHundredForLb = 10; + const double kDropRateForLb = kDropPerHundredForLb / 100.0; + // The ADS response contains one drop category. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + args.drop_categories = {{kLbDropType, kDropPerHundredForLb}}; + args.drop_denominator = FractionalPercent::HUNDRED; + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForAllBackends(); + // Send kNumRpcs RPCs and count the drops. + size_t num_drops = 0; + for (size_t i = 0; i < kNumRpcs; ++i) { + EchoResponse response; + const Status status = SendRpc(RpcOptions(), &response); + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + } + // The drop rate should be roughly equal to the expectation. + const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs; + const double kErrorTolerance = 0.2; + EXPECT_THAT( + seen_drop_rate, + ::testing::AllOf(::testing::Ge(kDropRateForLb * (1 - kErrorTolerance)), + ::testing::Le(kDropRateForLb * (1 + kErrorTolerance)))); +} + +// Tests that drop config is converted correctly from per ten thousand. +TEST_P(DropTest, DropPerTenThousand) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 5000; + const uint32_t kDropPerTenThousandForLb = 1000; + const double kDropRateForLb = kDropPerTenThousandForLb / 10000.0; + // The ADS response contains one drop category. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + args.drop_categories = {{kLbDropType, kDropPerTenThousandForLb}}; + args.drop_denominator = FractionalPercent::TEN_THOUSAND; + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForAllBackends(); + // Send kNumRpcs RPCs and count the drops. + size_t num_drops = 0; + for (size_t i = 0; i < kNumRpcs; ++i) { + EchoResponse response; + const Status status = SendRpc(RpcOptions(), &response); + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + } + // The drop rate should be roughly equal to the expectation. + const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs; + const double kErrorTolerance = 0.2; + EXPECT_THAT( + seen_drop_rate, + ::testing::AllOf(::testing::Ge(kDropRateForLb * (1 - kErrorTolerance)), + ::testing::Le(kDropRateForLb * (1 + kErrorTolerance)))); +} + +// Tests that drop is working correctly after update. +TEST_P(DropTest, Update) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 3000; + const uint32_t kDropPerMillionForLb = 100000; + const uint32_t kDropPerMillionForThrottle = 200000; + const double kDropRateForLb = kDropPerMillionForLb / 1000000.0; + const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0; + const double KDropRateForLbAndThrottle = + kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle; + // The first ADS response contains one drop category. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + args.drop_categories = {{kLbDropType, kDropPerMillionForLb}}; + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + WaitForAllBackends(); + // Send kNumRpcs RPCs and count the drops. + size_t num_drops = 0; + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + for (size_t i = 0; i < kNumRpcs; ++i) { + EchoResponse response; + const Status status = SendRpc(RpcOptions(), &response); + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + } + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // The drop rate should be roughly equal to the expectation. + double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs; + gpr_log(GPR_INFO, "First batch drop rate %f", seen_drop_rate); + const double kErrorTolerance = 0.3; + EXPECT_THAT( + seen_drop_rate, + ::testing::AllOf(::testing::Ge(kDropRateForLb * (1 - kErrorTolerance)), + ::testing::Le(kDropRateForLb * (1 + kErrorTolerance)))); + // The second ADS response contains two drop categories, send an update EDS + // response. + args.drop_categories = {{kLbDropType, kDropPerMillionForLb}, + {kThrottleDropType, kDropPerMillionForThrottle}}; + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait until the drop rate increases to the middle of the two configs, which + // implies that the update has been in effect. + const double kDropRateThreshold = + (kDropRateForLb + KDropRateForLbAndThrottle) / 2; + size_t num_rpcs = kNumRpcs; + while (seen_drop_rate < kDropRateThreshold) { + EchoResponse response; + const Status status = SendRpc(RpcOptions(), &response); + ++num_rpcs; + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + seen_drop_rate = static_cast<double>(num_drops) / num_rpcs; + } + // Send kNumRpcs RPCs and count the drops. + num_drops = 0; + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + for (size_t i = 0; i < kNumRpcs; ++i) { + EchoResponse response; + const Status status = SendRpc(RpcOptions(), &response); + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + } + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // The new drop rate should be roughly equal to the expectation. + seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs; + gpr_log(GPR_INFO, "Second batch drop rate %f", seen_drop_rate); + EXPECT_THAT( + seen_drop_rate, + ::testing::AllOf( + ::testing::Ge(KDropRateForLbAndThrottle * (1 - kErrorTolerance)), + ::testing::Le(KDropRateForLbAndThrottle * (1 + kErrorTolerance)))); +} + +// Tests that all the RPCs are dropped if any drop category drops 100%. +TEST_P(DropTest, DropAll) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 1000; + const uint32_t kDropPerMillionForLb = 100000; + const uint32_t kDropPerMillionForThrottle = 1000000; + // The ADS response contains two drop categories. + AdsServiceImpl::EdsResourceArgs args; + args.drop_categories = {{kLbDropType, kDropPerMillionForLb}, + {kThrottleDropType, kDropPerMillionForThrottle}}; + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Send kNumRpcs RPCs and all of them are dropped. + for (size_t i = 0; i < kNumRpcs; ++i) { + EchoResponse response; + const Status status = SendRpc(RpcOptions(), &response); + EXPECT_EQ(status.error_code(), StatusCode::UNAVAILABLE); + EXPECT_EQ(status.error_message(), "Call dropped by load balancing policy"); + } +} + +class BalancerUpdateTest : public XdsEnd2endTest { + public: + BalancerUpdateTest() : XdsEnd2endTest(4, 3) {} +}; + +// Tests that the old LB call is still used after the balancer address update as +// long as that call is still alive. +TEST_P(BalancerUpdateTest, UpdateBalancersButKeepUsingOriginalBalancer) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", {backends_[0]->port()}}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", {backends_[1]->port()}}, + }); + balancers_[1]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait until the first backend is ready. + WaitForBackend(0); + // Send 10 requests. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->backend_service()->request_count()); + // The ADS service of balancer 0 sent at least 1 response. + EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolutionForLbChannel({balancers_[1]->port()}); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + gpr_timespec deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // The current LB call is still working, so xds continued using it to the + // first balancer, which doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + // The ADS service of balancer 0 sent at least 1 response. + EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; +} + +// Tests that the old LB call is still used after multiple balancer address +// updates as long as that call is still alive. Send an update with the same set +// of LBs as the one in SetUp() in order to verify that the LB channel inside +// xds keeps the initial connection (which by definition is also present in the +// update). +TEST_P(BalancerUpdateTest, Repeated) { + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", {backends_[0]->port()}}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", {backends_[1]->port()}}, + }); + balancers_[1]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait until the first backend is ready. + WaitForBackend(0); + // Send 10 requests. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->backend_service()->request_count()); + // The ADS service of balancer 0 sent at least 1 response. + EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; + std::vector<int> ports; + ports.emplace_back(balancers_[0]->port()); + ports.emplace_back(balancers_[1]->port()); + ports.emplace_back(balancers_[2]->port()); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolutionForLbChannel(ports); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + gpr_timespec deadline = gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // xds continued using the original LB call to the first balancer, which + // doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + ports.clear(); + ports.emplace_back(balancers_[0]->port()); + ports.emplace_back(balancers_[1]->port()); + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 2 =========="); + SetNextResolutionForLbChannel(ports); + gpr_log(GPR_INFO, "========= UPDATE 2 DONE =========="); + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + deadline = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(10000, GPR_TIMESPAN)); + // Send 10 seconds worth of RPCs + do { + CheckRpcSendOk(); + } while (gpr_time_cmp(gpr_now(GPR_CLOCK_REALTIME), deadline) < 0); + // xds continued using the original LB call to the first balancer, which + // doesn't assign the second backend. + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); +} + +// Tests that if the balancer is down, the RPCs will still be sent to the +// backends according to the last balancer response, until a new balancer is +// reachable. +TEST_P(BalancerUpdateTest, DeadUpdate) { + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}); + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", {backends_[0]->port()}}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", {backends_[1]->port()}}, + }); + balancers_[1]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Start servers and send 10 RPCs per server. + gpr_log(GPR_INFO, "========= BEFORE FIRST BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH FIRST BATCH =========="); + // All 10 requests should have gone to the first backend. + EXPECT_EQ(10U, backends_[0]->backend_service()->request_count()); + // The ADS service of balancer 0 sent at least 1 response. + EXPECT_GT(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; + // Kill balancer 0 + gpr_log(GPR_INFO, "********** ABOUT TO KILL BALANCER 0 *************"); + balancers_[0]->Shutdown(); + gpr_log(GPR_INFO, "********** KILLED BALANCER 0 *************"); + // This is serviced by the existing child policy. + gpr_log(GPR_INFO, "========= BEFORE SECOND BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH SECOND BATCH =========="); + // All 10 requests should again have gone to the first backend. + EXPECT_EQ(20U, backends_[0]->backend_service()->request_count()); + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + // The ADS service of no balancers sent anything + EXPECT_EQ(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[0]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[1]->ads_service()->eds_response_state().error_message; + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; + gpr_log(GPR_INFO, "========= ABOUT TO UPDATE 1 =========="); + SetNextResolutionForLbChannel({balancers_[1]->port()}); + gpr_log(GPR_INFO, "========= UPDATE 1 DONE =========="); + // Wait until update has been processed, as signaled by the second backend + // receiving a request. In the meantime, the client continues to be serviced + // (by the first backend) without interruption. + EXPECT_EQ(0U, backends_[1]->backend_service()->request_count()); + WaitForBackend(1); + // This is serviced by the updated RR policy + backends_[1]->backend_service()->ResetCounters(); + gpr_log(GPR_INFO, "========= BEFORE THIRD BATCH =========="); + CheckRpcSendOk(10); + gpr_log(GPR_INFO, "========= DONE WITH THIRD BATCH =========="); + // All 10 requests should have gone to the second backend. + EXPECT_EQ(10U, backends_[1]->backend_service()->request_count()); + // The ADS service of balancer 1 sent at least 1 response. + EXPECT_EQ(balancers_[0]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[0]->ads_service()->eds_response_state().error_message; + EXPECT_GT(balancers_[1]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT); + EXPECT_EQ(balancers_[2]->ads_service()->eds_response_state().state, + AdsServiceImpl::ResponseState::NOT_SENT) + << "Error Message:" + << balancers_[2]->ads_service()->eds_response_state().error_message; +} + +// The re-resolution tests are deferred because they rely on the fallback mode, +// which hasn't been supported. + +// TODO(juanlishen): Add TEST_P(BalancerUpdateTest, ReresolveDeadBackend). + +// TODO(juanlishen): Add TEST_P(UpdatesWithClientLoadReportingTest, +// ReresolveDeadBalancer) + +class ClientLoadReportingTest : public XdsEnd2endTest { + public: + ClientLoadReportingTest() : XdsEnd2endTest(4, 1, 3) {} +}; + +// Tests that the load report received at the balancer is correct. +TEST_P(ClientLoadReportingTest, Vanilla) { + if (!GetParam().use_xds_resolver()) { + balancers_[0]->lrs_service()->set_cluster_names({kServerName}); + } + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}); + const size_t kNumRpcsPerAddress = 10; + const size_t kNumFailuresPerAddress = 3; + // TODO(juanlishen): Partition the backends after multiple localities is + // tested. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait until all backends are ready. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + CheckRpcSendFailure(kNumFailuresPerAddress * num_backends_, + RpcOptions().set_server_fail(true)); + // Check that each backend got the right number of requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress + kNumFailuresPerAddress, + backends_[i]->backend_service()->request_count()); + } + // The load report received at the balancer should be correct. + std::vector<ClientStats> load_report = + balancers_[0]->lrs_service()->WaitForLoadReport(); + ASSERT_EQ(load_report.size(), 1UL); + ClientStats& client_stats = load_report.front(); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok, + client_stats.total_successful_requests()); + EXPECT_EQ(0U, client_stats.total_requests_in_progress()); + EXPECT_EQ((kNumRpcsPerAddress + kNumFailuresPerAddress) * num_backends_ + + num_ok + num_failure, + client_stats.total_issued_requests()); + EXPECT_EQ(kNumFailuresPerAddress * num_backends_ + num_failure, + client_stats.total_error_requests()); + EXPECT_EQ(0U, client_stats.total_dropped_requests()); + // The LRS service got a single request, and sent a single response. + EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count()); + EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count()); +} + +// Tests send_all_clusters. +TEST_P(ClientLoadReportingTest, SendAllClusters) { + balancers_[0]->lrs_service()->set_send_all_clusters(true); + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}); + const size_t kNumRpcsPerAddress = 10; + const size_t kNumFailuresPerAddress = 3; + // TODO(juanlishen): Partition the backends after multiple localities is + // tested. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait until all backends are ready. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + CheckRpcSendFailure(kNumFailuresPerAddress * num_backends_, + RpcOptions().set_server_fail(true)); + // Check that each backend got the right number of requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress + kNumFailuresPerAddress, + backends_[i]->backend_service()->request_count()); + } + // The load report received at the balancer should be correct. + std::vector<ClientStats> load_report = + balancers_[0]->lrs_service()->WaitForLoadReport(); + ASSERT_EQ(load_report.size(), 1UL); + ClientStats& client_stats = load_report.front(); + EXPECT_EQ(kNumRpcsPerAddress * num_backends_ + num_ok, + client_stats.total_successful_requests()); + EXPECT_EQ(0U, client_stats.total_requests_in_progress()); + EXPECT_EQ((kNumRpcsPerAddress + kNumFailuresPerAddress) * num_backends_ + + num_ok + num_failure, + client_stats.total_issued_requests()); + EXPECT_EQ(kNumFailuresPerAddress * num_backends_ + num_failure, + client_stats.total_error_requests()); + EXPECT_EQ(0U, client_stats.total_dropped_requests()); + // The LRS service got a single request, and sent a single response. + EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count()); + EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count()); +} + +// Tests that we don't include stats for clusters that are not requested +// by the LRS server. +TEST_P(ClientLoadReportingTest, HonorsClustersRequestedByLrsServer) { + balancers_[0]->lrs_service()->set_cluster_names({"bogus"}); + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}); + const size_t kNumRpcsPerAddress = 100; + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait until all backends are ready. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(); + // Send kNumRpcsPerAddress RPCs per server. + CheckRpcSendOk(kNumRpcsPerAddress * num_backends_); + // Each backend should have gotten 100 requests. + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(kNumRpcsPerAddress, + backends_[i]->backend_service()->request_count()); + } + // The LRS service got a single request, and sent a single response. + EXPECT_EQ(1U, balancers_[0]->lrs_service()->request_count()); + EXPECT_EQ(1U, balancers_[0]->lrs_service()->response_count()); + // The load report received at the balancer should be correct. + std::vector<ClientStats> load_report = + balancers_[0]->lrs_service()->WaitForLoadReport(); + ASSERT_EQ(load_report.size(), 0UL); +} + +// Tests that if the balancer restarts, the client load report contains the +// stats before and after the restart correctly. +TEST_P(ClientLoadReportingTest, BalancerRestart) { + if (!GetParam().use_xds_resolver()) { + balancers_[0]->lrs_service()->set_cluster_names({kServerName}); + } + SetNextResolution({}); + SetNextResolutionForLbChannel({balancers_[0]->port()}); + const size_t kNumBackendsFirstPass = backends_.size() / 2; + const size_t kNumBackendsSecondPass = + backends_.size() - kNumBackendsFirstPass; + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts(0, kNumBackendsFirstPass)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait until all backends returned by the balancer are ready. + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = + WaitForAllBackends(/* start_index */ 0, + /* stop_index */ kNumBackendsFirstPass); + std::vector<ClientStats> load_report = + balancers_[0]->lrs_service()->WaitForLoadReport(); + ASSERT_EQ(load_report.size(), 1UL); + ClientStats client_stats = std::move(load_report.front()); + EXPECT_EQ(static_cast<size_t>(num_ok), + client_stats.total_successful_requests()); + EXPECT_EQ(0U, client_stats.total_requests_in_progress()); + EXPECT_EQ(0U, client_stats.total_error_requests()); + EXPECT_EQ(0U, client_stats.total_dropped_requests()); + // Shut down the balancer. + balancers_[0]->Shutdown(); + // We should continue using the last EDS response we received from the + // balancer before it was shut down. + // Note: We need to use WaitForAllBackends() here instead of just + // CheckRpcSendOk(kNumBackendsFirstPass), because when the balancer + // shuts down, the XdsClient will generate an error to the + // ServiceConfigWatcher, which will cause the xds resolver to send a + // no-op update to the LB policy. When this update gets down to the + // round_robin child policy for the locality, it will generate a new + // subchannel list, which resets the start index randomly. So we need + // to be a little more permissive here to avoid spurious failures. + ResetBackendCounters(); + int num_started = std::get<0>(WaitForAllBackends( + /* start_index */ 0, /* stop_index */ kNumBackendsFirstPass)); + // Now restart the balancer, this time pointing to the new backends. + balancers_[0]->Start(); + args = AdsServiceImpl::EdsResourceArgs({ + {"locality0", GetBackendPorts(kNumBackendsFirstPass)}, + }); + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + // Wait for queries to start going to one of the new backends. + // This tells us that we're now using the new serverlist. + std::tie(num_ok, num_failure, num_drops) = + WaitForAllBackends(/* start_index */ kNumBackendsFirstPass); + num_started += num_ok + num_failure + num_drops; + // Send one RPC per backend. + CheckRpcSendOk(kNumBackendsSecondPass); + num_started += kNumBackendsSecondPass; + // Check client stats. + load_report = balancers_[0]->lrs_service()->WaitForLoadReport(); + ASSERT_EQ(load_report.size(), 1UL); + client_stats = std::move(load_report.front()); + EXPECT_EQ(num_started, client_stats.total_successful_requests()); + EXPECT_EQ(0U, client_stats.total_requests_in_progress()); + EXPECT_EQ(0U, client_stats.total_error_requests()); + EXPECT_EQ(0U, client_stats.total_dropped_requests()); +} + +class ClientLoadReportingWithDropTest : public XdsEnd2endTest { + public: + ClientLoadReportingWithDropTest() : XdsEnd2endTest(4, 1, 20) {} +}; + +// Tests that the drop stats are correctly reported by client load reporting. +TEST_P(ClientLoadReportingWithDropTest, Vanilla) { + if (!GetParam().use_xds_resolver()) { + balancers_[0]->lrs_service()->set_cluster_names({kServerName}); + } + SetNextResolution({}); + SetNextResolutionForLbChannelAllBalancers(); + const size_t kNumRpcs = 3000; + const uint32_t kDropPerMillionForLb = 100000; + const uint32_t kDropPerMillionForThrottle = 200000; + const double kDropRateForLb = kDropPerMillionForLb / 1000000.0; + const double kDropRateForThrottle = kDropPerMillionForThrottle / 1000000.0; + const double KDropRateForLbAndThrottle = + kDropRateForLb + (1 - kDropRateForLb) * kDropRateForThrottle; + // The ADS response contains two drop categories. + AdsServiceImpl::EdsResourceArgs args({ + {"locality0", GetBackendPorts()}, + }); + args.drop_categories = {{kLbDropType, kDropPerMillionForLb}, + {kThrottleDropType, kDropPerMillionForThrottle}}; + balancers_[0]->ads_service()->SetEdsResource( + AdsServiceImpl::BuildEdsResource(args, DefaultEdsServiceName())); + int num_ok = 0; + int num_failure = 0; + int num_drops = 0; + std::tie(num_ok, num_failure, num_drops) = WaitForAllBackends(); + const size_t num_warmup = num_ok + num_failure + num_drops; + // Send kNumRpcs RPCs and count the drops. + for (size_t i = 0; i < kNumRpcs; ++i) { + EchoResponse response; + const Status status = SendRpc(RpcOptions(), &response); + if (!status.ok() && + status.error_message() == "Call dropped by load balancing policy") { + ++num_drops; + } else { + EXPECT_TRUE(status.ok()) << "code=" << status.error_code() + << " message=" << status.error_message(); + EXPECT_EQ(response.message(), kRequestMessage); + } + } + // The drop rate should be roughly equal to the expectation. + const double seen_drop_rate = static_cast<double>(num_drops) / kNumRpcs; + const double kErrorTolerance = 0.2; + EXPECT_THAT( + seen_drop_rate, + ::testing::AllOf( + ::testing::Ge(KDropRateForLbAndThrottle * (1 - kErrorTolerance)), + ::testing::Le(KDropRateForLbAndThrottle * (1 + kErrorTolerance)))); + // Check client stats. + const size_t total_rpc = num_warmup + kNumRpcs; + ClientStats client_stats; + do { + std::vector<ClientStats> load_reports = + balancers_[0]->lrs_service()->WaitForLoadReport(); + for (const auto& load_report : load_reports) { + client_stats += load_report; + } + } while (client_stats.total_issued_requests() + + client_stats.total_dropped_requests() < + total_rpc); + EXPECT_EQ(num_drops, client_stats.total_dropped_requests()); + EXPECT_THAT( + client_stats.dropped_requests(kLbDropType), + ::testing::AllOf( + ::testing::Ge(total_rpc * kDropRateForLb * (1 - kErrorTolerance)), + ::testing::Le(total_rpc * kDropRateForLb * (1 + kErrorTolerance)))); + EXPECT_THAT(client_stats.dropped_requests(kThrottleDropType), + ::testing::AllOf( + ::testing::Ge(total_rpc * (1 - kDropRateForLb) * + kDropRateForThrottle * (1 - kErrorTolerance)), + ::testing::Le(total_rpc * (1 - kDropRateForLb) * + kDropRateForThrottle * (1 + kErrorTolerance)))); +} + +TString TestTypeName(const ::testing::TestParamInfo<TestType>& info) { + return info.param.AsString(); +} + +// TestType params: +// - use_xds_resolver +// - enable_load_reporting +// - enable_rds_testing = false +// - use_v2 = false + +INSTANTIATE_TEST_SUITE_P(XdsTest, BasicTest, + ::testing::Values(TestType(false, true), + TestType(false, false), + TestType(true, false), + TestType(true, true)), + &TestTypeName); + +// Run with both fake resolver and xds resolver. +// Don't run with load reporting or v2 or RDS, since they are irrelevant to +// the tests. +INSTANTIATE_TEST_SUITE_P(XdsTest, SecureNamingTest, + ::testing::Values(TestType(false, false), + TestType(true, false)), + &TestTypeName); + +// LDS depends on XdsResolver. +INSTANTIATE_TEST_SUITE_P(XdsTest, LdsTest, + ::testing::Values(TestType(true, false), + TestType(true, true)), + &TestTypeName); + +// LDS/RDS commmon tests depend on XdsResolver. +INSTANTIATE_TEST_SUITE_P(XdsTest, LdsRdsTest, + ::testing::Values(TestType(true, false), + TestType(true, true), + TestType(true, false, true), + TestType(true, true, true), + // Also test with xDS v2. + TestType(true, true, true, true)), + &TestTypeName); + +// CDS depends on XdsResolver. +INSTANTIATE_TEST_SUITE_P(XdsTest, CdsTest, + ::testing::Values(TestType(true, false), + TestType(true, true)), + &TestTypeName); + +// EDS could be tested with or without XdsResolver, but the tests would +// be the same either way, so we test it only with XdsResolver. +INSTANTIATE_TEST_SUITE_P(XdsTest, EdsTest, + ::testing::Values(TestType(true, false), + TestType(true, true)), + &TestTypeName); + +// Test initial resource timeouts for each resource type. +// Do this only for XdsResolver with RDS enabled, so that we can test +// all resource types. +// Run with V3 only, since the functionality is no different in V2. +INSTANTIATE_TEST_SUITE_P(XdsTest, TimeoutTest, + ::testing::Values(TestType(true, false, true)), + &TestTypeName); + +// XdsResolverOnlyTest depends on XdsResolver. +INSTANTIATE_TEST_SUITE_P(XdsTest, XdsResolverOnlyTest, + ::testing::Values(TestType(true, false), + TestType(true, true)), + &TestTypeName); + +// XdsResolverLoadReprtingOnlyTest depends on XdsResolver and load reporting. +INSTANTIATE_TEST_SUITE_P(XdsTest, XdsResolverLoadReportingOnlyTest, + ::testing::Values(TestType(true, true)), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P(XdsTest, LocalityMapTest, + ::testing::Values(TestType(false, true), + TestType(false, false), + TestType(true, false), + TestType(true, true)), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P(XdsTest, FailoverTest, + ::testing::Values(TestType(false, true), + TestType(false, false), + TestType(true, false), + TestType(true, true)), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P(XdsTest, DropTest, + ::testing::Values(TestType(false, true), + TestType(false, false), + TestType(true, false), + TestType(true, true)), + &TestTypeName); + +INSTANTIATE_TEST_SUITE_P(XdsTest, BalancerUpdateTest, + ::testing::Values(TestType(false, true), + TestType(false, false), + TestType(true, true)), + &TestTypeName); + +// Load reporting tests are not run with load reporting disabled. +INSTANTIATE_TEST_SUITE_P(XdsTest, ClientLoadReportingTest, + ::testing::Values(TestType(false, true), + TestType(true, true)), + &TestTypeName); + +// Load reporting tests are not run with load reporting disabled. +INSTANTIATE_TEST_SUITE_P(XdsTest, ClientLoadReportingWithDropTest, + ::testing::Values(TestType(false, true), + TestType(true, true)), + &TestTypeName); + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + grpc::testing::WriteBootstrapFiles(); + grpc::testing::g_port_saver = new grpc::testing::PortSaver(); + const auto result = RUN_ALL_TESTS(); + return result; +} diff --git a/contrib/libs/grpc/test/cpp/end2end/ya.make b/contrib/libs/grpc/test/cpp/end2end/ya.make new file mode 100644 index 0000000000..b9c1dc7fe0 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/end2end/ya.make @@ -0,0 +1,67 @@ +LIBRARY() + +LICENSE(Apache-2.0) + +LICENSE_TEXTS(.yandex_meta/licenses.list.txt) + +OWNER(dvshkurko) + +PEERDIR( + contrib/libs/grpc/src/proto/grpc/health/v1 + contrib/libs/grpc/src/proto/grpc/testing + contrib/libs/grpc/src/proto/grpc/testing/duplicate + contrib/libs/grpc/test/cpp/util + contrib/libs/grpc + contrib/restricted/googletest/googlemock + contrib/restricted/googletest/googletest +) + +ADDINCL( + ${ARCADIA_BUILD_ROOT}/contrib/libs/grpc + contrib/libs/grpc +) + +NO_COMPILER_WARNINGS() + +SRCS( + # async_end2end_test.cc + # channelz_service_test.cc + # client_callback_end2end_test.cc + # client_crash_test.cc + # client_crash_test_server.cc + # client_interceptors_end2end_test.cc + # client_lb_end2end_test.cc lb needs opencensus, not enabled. + # end2end_test.cc + # exception_test.cc + # filter_end2end_test.cc + # generic_end2end_test.cc + # grpclb_end2end_test.cc lb needs opencensus, not enabled. + # health_service_end2end_test.cc + # hybrid_end2end_test.cc + interceptors_util.cc + # mock_test.cc + # nonblocking_test.cc + # proto_server_reflection_test.cc + # raw_end2end_test.cc + # server_builder_plugin_test.cc + # server_crash_test.cc + # server_crash_test_client.cc + # server_early_return_test.cc + # server_interceptors_end2end_test.cc + # server_load_reporting_end2end_test.cc + # shutdown_test.cc + # streaming_throughput_test.cc + test_health_check_service_impl.cc + test_service_impl.cc + # thread_stress_test.cc + # time_change_test.cc +) + +END() + +RECURSE_FOR_TESTS( + health + server_interceptors + # Needs new gtest + # thread +) diff --git a/contrib/libs/grpc/test/cpp/util/.yandex_meta/licenses.list.txt b/contrib/libs/grpc/test/cpp/util/.yandex_meta/licenses.list.txt new file mode 100644 index 0000000000..d2dadabed9 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/.yandex_meta/licenses.list.txt @@ -0,0 +1,32 @@ +====================Apache-2.0==================== + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + + +====================COPYRIGHT==================== + * Copyright 2015 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2015-2016 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2016 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2017 gRPC authors. + + +====================COPYRIGHT==================== + * Copyright 2018 gRPC authors. diff --git a/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.cc b/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.cc new file mode 100644 index 0000000000..5971b53075 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.cc @@ -0,0 +1,57 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/byte_buffer_proto_helper.h" + +namespace grpc { +namespace testing { + +bool ParseFromByteBuffer(ByteBuffer* buffer, grpc::protobuf::Message* message) { + std::vector<Slice> slices; + (void)buffer->Dump(&slices); + TString buf; + buf.reserve(buffer->Length()); + for (auto s = slices.begin(); s != slices.end(); s++) { + buf.append(reinterpret_cast<const char*>(s->begin()), s->size()); + } + return message->ParseFromString(buf); +} + +std::unique_ptr<ByteBuffer> SerializeToByteBuffer( + grpc::protobuf::Message* message) { + TString buf; + message->SerializeToString(&buf); + Slice slice(buf); + return std::unique_ptr<ByteBuffer>(new ByteBuffer(&slice, 1)); +} + +bool SerializeToByteBufferInPlace(grpc::protobuf::Message* message, + ByteBuffer* buffer) { + TString buf; + if (!message->SerializeToString(&buf)) { + return false; + } + buffer->Clear(); + Slice slice(buf); + ByteBuffer tmp(&slice, 1); + buffer->Swap(&tmp); + return true; +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.h b/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.h new file mode 100644 index 0000000000..3d01fb2468 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/byte_buffer_proto_helper.h @@ -0,0 +1,42 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_BYTE_BUFFER_PROTO_HELPER_H +#define GRPC_TEST_CPP_UTIL_BYTE_BUFFER_PROTO_HELPER_H + +#include <memory> + +#include <grpcpp/impl/codegen/config_protobuf.h> +#include <grpcpp/support/byte_buffer.h> + +namespace grpc { +namespace testing { + +bool ParseFromByteBuffer(ByteBuffer* buffer, + ::grpc::protobuf::Message* message); + +std::unique_ptr<ByteBuffer> SerializeToByteBuffer( + ::grpc::protobuf::Message* message); + +bool SerializeToByteBufferInPlace(::grpc::protobuf::Message* message, + ByteBuffer* buffer); + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_BYTE_BUFFER_PROTO_HELPER_H diff --git a/contrib/libs/grpc/test/cpp/util/byte_buffer_test.cc b/contrib/libs/grpc/test/cpp/util/byte_buffer_test.cc new file mode 100644 index 0000000000..c63f351a8f --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/byte_buffer_test.cc @@ -0,0 +1,134 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc++/support/byte_buffer.h> +#include <grpcpp/impl/grpc_library.h> + +#include <cstring> +#include <vector> + +#include <grpc/grpc.h> +#include <grpc/slice.h> +#include <grpcpp/support/slice.h> +#include <gtest/gtest.h> + +#include "test/core/util/test_config.h" + +namespace grpc { + +static internal::GrpcLibraryInitializer g_gli_initializer; + +namespace { + +const char* kContent1 = "hello xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; +const char* kContent2 = "yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy world"; + +class ByteBufferTest : public ::testing::Test { + protected: + static void SetUpTestCase() { grpc_init(); } + + static void TearDownTestCase() { grpc_shutdown(); } +}; + +TEST_F(ByteBufferTest, CopyCtor) { + ByteBuffer buffer1; + EXPECT_FALSE(buffer1.Valid()); + const ByteBuffer& buffer2 = buffer1; + EXPECT_FALSE(buffer2.Valid()); +} + +TEST_F(ByteBufferTest, CreateFromSingleSlice) { + Slice s(kContent1); + ByteBuffer buffer(&s, 1); + EXPECT_EQ(strlen(kContent1), buffer.Length()); +} + +TEST_F(ByteBufferTest, CreateFromVector) { + std::vector<Slice> slices; + slices.emplace_back(kContent1); + slices.emplace_back(kContent2); + ByteBuffer buffer(&slices[0], 2); + EXPECT_EQ(strlen(kContent1) + strlen(kContent2), buffer.Length()); +} + +TEST_F(ByteBufferTest, Clear) { + Slice s(kContent1); + ByteBuffer buffer(&s, 1); + buffer.Clear(); + EXPECT_EQ(static_cast<size_t>(0), buffer.Length()); +} + +TEST_F(ByteBufferTest, Length) { + std::vector<Slice> slices; + slices.emplace_back(kContent1); + slices.emplace_back(kContent2); + ByteBuffer buffer(&slices[0], 2); + EXPECT_EQ(strlen(kContent1) + strlen(kContent2), buffer.Length()); +} + +bool SliceEqual(const Slice& a, grpc_slice b) { + if (a.size() != GRPC_SLICE_LENGTH(b)) { + return false; + } + for (size_t i = 0; i < a.size(); i++) { + if (a.begin()[i] != GRPC_SLICE_START_PTR(b)[i]) { + return false; + } + } + return true; +} + +TEST_F(ByteBufferTest, Dump) { + grpc_slice hello = grpc_slice_from_copied_string(kContent1); + grpc_slice world = grpc_slice_from_copied_string(kContent2); + std::vector<Slice> slices; + slices.push_back(Slice(hello, Slice::STEAL_REF)); + slices.push_back(Slice(world, Slice::STEAL_REF)); + ByteBuffer buffer(&slices[0], 2); + slices.clear(); + (void)buffer.Dump(&slices); + EXPECT_TRUE(SliceEqual(slices[0], hello)); + EXPECT_TRUE(SliceEqual(slices[1], world)); +} + +TEST_F(ByteBufferTest, SerializationMakesCopy) { + grpc_slice hello = grpc_slice_from_copied_string(kContent1); + grpc_slice world = grpc_slice_from_copied_string(kContent2); + std::vector<Slice> slices; + slices.push_back(Slice(hello, Slice::STEAL_REF)); + slices.push_back(Slice(world, Slice::STEAL_REF)); + ByteBuffer send_buffer; + bool owned = false; + ByteBuffer buffer(&slices[0], 2); + slices.clear(); + auto status = SerializationTraits<ByteBuffer, void>::Serialize( + buffer, &send_buffer, &owned); + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(owned); + EXPECT_TRUE(send_buffer.Valid()); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.cc b/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.cc new file mode 100644 index 0000000000..d4b4026774 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.cc @@ -0,0 +1,115 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "test/cpp/util/channel_trace_proto_helper.h" + +#include <grpc/grpc.h> +#include <grpc/support/log.h> +#include <grpcpp/impl/codegen/config.h> +#include <grpcpp/impl/codegen/config_protobuf.h> +#include <gtest/gtest.h> + +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/json/json.h" +#include "src/proto/grpc/channelz/channelz.pb.h" + +namespace grpc { + +namespace { + +// Generic helper that takes in a json string, converts it to a proto, and +// then back to json. This ensures that the json string was correctly formatted +// according to https://developers.google.com/protocol-buffers/docs/proto3#json +template <typename Message> +void VaidateProtoJsonTranslation(const TString& json_str) { + Message msg; + grpc::protobuf::json::JsonParseOptions parse_options; + // If the following line is failing, then uncomment the last line of the + // comment, and uncomment the lines that print the two strings. You can + // then compare the output, and determine what fields are missing. + // + // parse_options.ignore_unknown_fields = true; + grpc::protobuf::util::Status s = + grpc::protobuf::json::JsonStringToMessage(json_str, &msg, parse_options); + EXPECT_TRUE(s.ok()); + TString proto_json_str; + grpc::protobuf::json::JsonPrintOptions print_options; + // We usually do not want this to be true, however it can be helpful to + // uncomment and see the output produced then all fields are printed. + // print_options.always_print_primitive_fields = true; + s = grpc::protobuf::json::MessageToJsonString(msg, &proto_json_str); + EXPECT_TRUE(s.ok()); + // Parse JSON and re-dump to string, to make sure formatting is the + // same as what would be generated by our JSON library. + grpc_error* error = GRPC_ERROR_NONE; + grpc_core::Json parsed_json = + grpc_core::Json::Parse(proto_json_str.c_str(), &error); + ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error); + ASSERT_EQ(parsed_json.type(), grpc_core::Json::Type::OBJECT); + proto_json_str = parsed_json.Dump(); + // uncomment these to compare the json strings. + // gpr_log(GPR_ERROR, "tracer json: %s", json_str.c_str()); + // gpr_log(GPR_ERROR, "proto json: %s", proto_json_str.c_str()); + EXPECT_EQ(json_str, proto_json_str); +} + +} // namespace + +namespace testing { + +void ValidateChannelTraceProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation<grpc::channelz::v1::ChannelTrace>(json_c_str); +} + +void ValidateChannelProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation<grpc::channelz::v1::Channel>(json_c_str); +} + +void ValidateGetTopChannelsResponseProtoJsonTranslation( + const char* json_c_str) { + VaidateProtoJsonTranslation<grpc::channelz::v1::GetTopChannelsResponse>( + json_c_str); +} + +void ValidateGetChannelResponseProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation<grpc::channelz::v1::GetChannelResponse>( + json_c_str); +} + +void ValidateGetServerResponseProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation<grpc::channelz::v1::GetServerResponse>( + json_c_str); +} + +void ValidateSubchannelProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation<grpc::channelz::v1::Subchannel>(json_c_str); +} + +void ValidateServerProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation<grpc::channelz::v1::Server>(json_c_str); +} + +void ValidateGetServersResponseProtoJsonTranslation(const char* json_c_str) { + VaidateProtoJsonTranslation<grpc::channelz::v1::GetServersResponse>( + json_c_str); +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.h b/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.h new file mode 100644 index 0000000000..664e899deb --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/channel_trace_proto_helper.h @@ -0,0 +1,37 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_CHANNEL_TRACE_PROTO_HELPER_H +#define GRPC_TEST_CPP_UTIL_CHANNEL_TRACE_PROTO_HELPER_H + +namespace grpc { +namespace testing { + +void ValidateChannelTraceProtoJsonTranslation(const char* json_c_str); +void ValidateChannelProtoJsonTranslation(const char* json_c_str); +void ValidateGetTopChannelsResponseProtoJsonTranslation(const char* json_c_str); +void ValidateGetChannelResponseProtoJsonTranslation(const char* json_c_str); +void ValidateGetServerResponseProtoJsonTranslation(const char* json_c_str); +void ValidateSubchannelProtoJsonTranslation(const char* json_c_str); +void ValidateServerProtoJsonTranslation(const char* json_c_str); +void ValidateGetServersResponseProtoJsonTranslation(const char* json_c_str); + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_CHANNEL_TRACE_PROTO_HELPER_H diff --git a/contrib/libs/grpc/test/cpp/util/channelz_sampler.cc b/contrib/libs/grpc/test/cpp/util/channelz_sampler.cc new file mode 100644 index 0000000000..e6bde68556 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/channelz_sampler.cc @@ -0,0 +1,588 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include <unistd.h> + +#include <cstdlib> +#include <fstream> +#include <iostream> +#include <memory> +#include <ostream> +#include <queue> +#include <util/generic/string.h> + +#include "y_absl/strings/str_format.h" +#include "y_absl/strings/str_join.h" +#include "gflags/gflags.h" +#include "google/protobuf/text_format.h" +#include "grpc/grpc.h" +#include "grpc/support/port_platform.h" +#include "grpcpp/channel.h" +#include "grpcpp/client_context.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/ext/channelz_service_plugin.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" +#include "grpcpp/server_context.h" +#include "src/core/lib/json/json.h" +#include "src/cpp/server/channelz/channelz_service.h" +#include "src/proto/grpc/channelz/channelz.pb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/test_config.h" +#include "test/cpp/util/test_credentials_provider.h" + +DEFINE_string(server_address, "", "channelz server address"); +DEFINE_string(custom_credentials_type, "", "custom credentials type"); +DEFINE_int64(sampling_times, 1, "number of sampling"); +DEFINE_int64(sampling_interval_seconds, 0, "sampling interval in seconds"); +DEFINE_string(output_json, "", "output filename in json format"); + +namespace { +using grpc::ClientContext; +using grpc::Status; +using grpc::StatusCode; +using grpc::channelz::v1::GetChannelRequest; +using grpc::channelz::v1::GetChannelResponse; +using grpc::channelz::v1::GetServerRequest; +using grpc::channelz::v1::GetServerResponse; +using grpc::channelz::v1::GetServerSocketsRequest; +using grpc::channelz::v1::GetServerSocketsResponse; +using grpc::channelz::v1::GetServersRequest; +using grpc::channelz::v1::GetServersResponse; +using grpc::channelz::v1::GetSocketRequest; +using grpc::channelz::v1::GetSocketResponse; +using grpc::channelz::v1::GetSubchannelRequest; +using grpc::channelz::v1::GetSubchannelResponse; +using grpc::channelz::v1::GetTopChannelsRequest; +using grpc::channelz::v1::GetTopChannelsResponse; +} // namespace + +class ChannelzSampler final { + public: + // Get server_id of a server + int64_t GetServerID(const grpc::channelz::v1::Server& server) { + return server.ref().server_id(); + } + + // Get channel_id of a channel + inline int64_t GetChannelID(const grpc::channelz::v1::Channel& channel) { + return channel.ref().channel_id(); + } + + // Get subchannel_id of a subchannel + inline int64_t GetSubchannelID( + const grpc::channelz::v1::Subchannel& subchannel) { + return subchannel.ref().subchannel_id(); + } + + // Get socket_id of a socket + inline int64_t GetSocketID(const grpc::channelz::v1::Socket& socket) { + return socket.ref().socket_id(); + } + + // Get name of a server + inline TString GetServerName(const grpc::channelz::v1::Server& server) { + return server.ref().name(); + } + + // Get name of a channel + inline TString GetChannelName( + const grpc::channelz::v1::Channel& channel) { + return channel.ref().name(); + } + + // Get name of a subchannel + inline TString GetSubchannelName( + const grpc::channelz::v1::Subchannel& subchannel) { + return subchannel.ref().name(); + } + + // Get name of a socket + inline TString GetSocketName(const grpc::channelz::v1::Socket& socket) { + return socket.ref().name(); + } + + // Get a channel based on channel_id + grpc::channelz::v1::Channel GetChannelRPC(int64_t channel_id) { + GetChannelRequest get_channel_request; + get_channel_request.set_channel_id(channel_id); + GetChannelResponse get_channel_response; + ClientContext get_channel_context; + get_channel_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + Status status = channelz_stub_->GetChannel( + &get_channel_context, get_channel_request, &get_channel_response); + if (!status.ok()) { + gpr_log(GPR_ERROR, "GetChannelRPC failed: %s", + get_channel_context.debug_error_string().c_str()); + GPR_ASSERT(0); + } + return get_channel_response.channel(); + } + + // Get a subchannel based on subchannel_id + grpc::channelz::v1::Subchannel GetSubchannelRPC(int64_t subchannel_id) { + GetSubchannelRequest get_subchannel_request; + get_subchannel_request.set_subchannel_id(subchannel_id); + GetSubchannelResponse get_subchannel_response; + ClientContext get_subchannel_context; + get_subchannel_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + Status status = channelz_stub_->GetSubchannel(&get_subchannel_context, + get_subchannel_request, + &get_subchannel_response); + if (!status.ok()) { + gpr_log(GPR_ERROR, "GetSubchannelRPC failed: %s", + get_subchannel_context.debug_error_string().c_str()); + GPR_ASSERT(0); + } + return get_subchannel_response.subchannel(); + } + + // get a socket based on socket_id + grpc::channelz::v1::Socket GetSocketRPC(int64_t socket_id) { + GetSocketRequest get_socket_request; + get_socket_request.set_socket_id(socket_id); + GetSocketResponse get_socket_response; + ClientContext get_socket_context; + get_socket_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + Status status = channelz_stub_->GetSocket( + &get_socket_context, get_socket_request, &get_socket_response); + if (!status.ok()) { + gpr_log(GPR_ERROR, "GetSocketRPC failed: %s", + get_socket_context.debug_error_string().c_str()); + GPR_ASSERT(0); + } + return get_socket_response.socket(); + } + + // get the descedent channels/subchannels/sockets of a channel + // push descedent channels/subchannels to queue for layer traverse + // store descedent channels/subchannels/sockets for dumping data + void GetChannelDescedence( + const grpc::channelz::v1::Channel& channel, + std::queue<grpc::channelz::v1::Channel>& channel_queue, + std::queue<grpc::channelz::v1::Subchannel>& subchannel_queue) { + std::cout << " Channel ID" << GetChannelID(channel) << "_" + << GetChannelName(channel) << " descendence - "; + if (channel.channel_ref_size() > 0 || channel.subchannel_ref_size() > 0) { + if (channel.channel_ref_size() > 0) { + std::cout << "channel: "; + for (const auto& _channelref : channel.channel_ref()) { + int64_t ch_id = _channelref.channel_id(); + std::cout << "ID" << ch_id << "_" << _channelref.name() << " "; + grpc::channelz::v1::Channel ch = GetChannelRPC(ch_id); + channel_queue.push(ch); + if (CheckID(ch_id)) { + all_channels_.push_back(ch); + StoreChannelInJson(ch); + } + } + if (channel.subchannel_ref_size() > 0) { + std::cout << ", "; + } + } + if (channel.subchannel_ref_size() > 0) { + std::cout << "subchannel: "; + for (const auto& _subchannelref : channel.subchannel_ref()) { + int64_t subch_id = _subchannelref.subchannel_id(); + std::cout << "ID" << subch_id << "_" << _subchannelref.name() << " "; + grpc::channelz::v1::Subchannel subch = GetSubchannelRPC(subch_id); + subchannel_queue.push(subch); + if (CheckID(subch_id)) { + all_subchannels_.push_back(subch); + StoreSubchannelInJson(subch); + } + } + } + } else if (channel.socket_ref_size() > 0) { + std::cout << "socket: "; + for (const auto& _socketref : channel.socket_ref()) { + int64_t so_id = _socketref.socket_id(); + std::cout << "ID" << so_id << "_" << _socketref.name() << " "; + grpc::channelz::v1::Socket so = GetSocketRPC(so_id); + if (CheckID(so_id)) { + all_sockets_.push_back(so); + StoreSocketInJson(so); + } + } + } + std::cout << std::endl; + } + + // get the descedent channels/subchannels/sockets of a subchannel + // push descedent channels/subchannels to queue for layer traverse + // store descedent channels/subchannels/sockets for dumping data + void GetSubchannelDescedence( + grpc::channelz::v1::Subchannel& subchannel, + std::queue<grpc::channelz::v1::Channel>& channel_queue, + std::queue<grpc::channelz::v1::Subchannel>& subchannel_queue) { + std::cout << " Subchannel ID" << GetSubchannelID(subchannel) << "_" + << GetSubchannelName(subchannel) << " descendence - "; + if (subchannel.channel_ref_size() > 0 || + subchannel.subchannel_ref_size() > 0) { + if (subchannel.channel_ref_size() > 0) { + std::cout << "channel: "; + for (const auto& _channelref : subchannel.channel_ref()) { + int64_t ch_id = _channelref.channel_id(); + std::cout << "ID" << ch_id << "_" << _channelref.name() << " "; + grpc::channelz::v1::Channel ch = GetChannelRPC(ch_id); + channel_queue.push(ch); + if (CheckID(ch_id)) { + all_channels_.push_back(ch); + StoreChannelInJson(ch); + } + } + if (subchannel.subchannel_ref_size() > 0) { + std::cout << ", "; + } + } + if (subchannel.subchannel_ref_size() > 0) { + std::cout << "subchannel: "; + for (const auto& _subchannelref : subchannel.subchannel_ref()) { + int64_t subch_id = _subchannelref.subchannel_id(); + std::cout << "ID" << subch_id << "_" << _subchannelref.name() << " "; + grpc::channelz::v1::Subchannel subch = GetSubchannelRPC(subch_id); + subchannel_queue.push(subch); + if (CheckID(subch_id)) { + all_subchannels_.push_back(subch); + StoreSubchannelInJson(subch); + } + } + } + } else if (subchannel.socket_ref_size() > 0) { + std::cout << "socket: "; + for (const auto& _socketref : subchannel.socket_ref()) { + int64_t so_id = _socketref.socket_id(); + std::cout << "ID" << so_id << "_" << _socketref.name() << " "; + grpc::channelz::v1::Socket so = GetSocketRPC(so_id); + if (CheckID(so_id)) { + all_sockets_.push_back(so); + StoreSocketInJson(so); + } + } + } + std::cout << std::endl; + } + + // Set up the channelz sampler client + // Initialize json as an array + void Setup(const TString& custom_credentials_type, + const TString& server_address) { + json_ = grpc_core::Json::Array(); + rpc_timeout_seconds_ = 20; + grpc::ChannelArguments channel_args; + std::shared_ptr<grpc::ChannelCredentials> channel_creds = + grpc::testing::GetCredentialsProvider()->GetChannelCredentials( + custom_credentials_type, &channel_args); + if (!channel_creds) { + gpr_log(GPR_ERROR, + "Wrong user credential type: %s. Allowed credential types: " + "INSECURE_CREDENTIALS, ssl, alts, google_default_credentials.", + custom_credentials_type.c_str()); + GPR_ASSERT(0); + } + std::shared_ptr<grpc::Channel> channel = + CreateChannel(server_address, channel_creds); + channelz_stub_ = grpc::channelz::v1::Channelz::NewStub(channel); + } + + // Get all servers, keep querying until getting all + // Store servers for dumping data + // Need to check id repeating for servers + void GetServersRPC() { + int64_t server_start_id = 0; + while (true) { + GetServersRequest get_servers_request; + GetServersResponse get_servers_response; + ClientContext get_servers_context; + get_servers_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + get_servers_request.set_start_server_id(server_start_id); + Status status = channelz_stub_->GetServers( + &get_servers_context, get_servers_request, &get_servers_response); + if (!status.ok()) { + if (status.error_code() == StatusCode::UNIMPLEMENTED) { + gpr_log(GPR_ERROR, + "Error status UNIMPLEMENTED. Please check and make sure " + "channelz has been registered on the server being queried."); + } else { + gpr_log(GPR_ERROR, + "GetServers RPC with GetServersRequest.server_start_id=%d, " + "failed: %s", + int(server_start_id), + get_servers_context.debug_error_string().c_str()); + } + GPR_ASSERT(0); + } + for (const auto& _server : get_servers_response.server()) { + all_servers_.push_back(_server); + StoreServerInJson(_server); + } + if (!get_servers_response.end()) { + server_start_id = GetServerID(all_servers_.back()) + 1; + } else { + break; + } + } + std::cout << "Number of servers = " << all_servers_.size() << std::endl; + } + + // Get sockets that belongs to servers + // Store sockets for dumping data + void GetSocketsOfServers() { + for (const auto& _server : all_servers_) { + std::cout << "Server ID" << GetServerID(_server) << "_" + << GetServerName(_server) << " listen_socket - "; + for (const auto& _socket : _server.listen_socket()) { + int64_t so_id = _socket.socket_id(); + std::cout << "ID" << so_id << "_" << _socket.name() << " "; + if (CheckID(so_id)) { + grpc::channelz::v1::Socket so = GetSocketRPC(so_id); + all_sockets_.push_back(so); + StoreSocketInJson(so); + } + } + std::cout << std::endl; + } + } + + // Get all top channels, keep querying until getting all + // Store channels for dumping data + // No need to check id repeating for top channels + void GetTopChannelsRPC() { + int64_t channel_start_id = 0; + while (true) { + GetTopChannelsRequest get_top_channels_request; + GetTopChannelsResponse get_top_channels_response; + ClientContext get_top_channels_context; + get_top_channels_context.set_deadline( + grpc_timeout_seconds_to_deadline(rpc_timeout_seconds_)); + get_top_channels_request.set_start_channel_id(channel_start_id); + Status status = channelz_stub_->GetTopChannels( + &get_top_channels_context, get_top_channels_request, + &get_top_channels_response); + if (!status.ok()) { + gpr_log(GPR_ERROR, + "GetTopChannels RPC with " + "GetTopChannelsRequest.channel_start_id=%d failed: %s", + int(channel_start_id), + get_top_channels_context.debug_error_string().c_str()); + GPR_ASSERT(0); + } + for (const auto& _topchannel : get_top_channels_response.channel()) { + top_channels_.push_back(_topchannel); + all_channels_.push_back(_topchannel); + StoreChannelInJson(_topchannel); + } + if (!get_top_channels_response.end()) { + channel_start_id = GetChannelID(top_channels_.back()) + 1; + } else { + break; + } + } + std::cout << std::endl + << "Number of top channels = " << top_channels_.size() + << std::endl; + } + + // layer traverse for each top channel + void TraverseTopChannels() { + for (const auto& _topchannel : top_channels_) { + int tree_depth = 0; + std::queue<grpc::channelz::v1::Channel> channel_queue; + std::queue<grpc::channelz::v1::Subchannel> subchannel_queue; + std::cout << "Tree depth = " << tree_depth << std::endl; + GetChannelDescedence(_topchannel, channel_queue, subchannel_queue); + while (!channel_queue.empty() || !subchannel_queue.empty()) { + ++tree_depth; + std::cout << "Tree depth = " << tree_depth << std::endl; + int ch_q_size = channel_queue.size(); + int subch_q_size = subchannel_queue.size(); + for (int i = 0; i < ch_q_size; ++i) { + grpc::channelz::v1::Channel ch = channel_queue.front(); + channel_queue.pop(); + GetChannelDescedence(ch, channel_queue, subchannel_queue); + } + for (int i = 0; i < subch_q_size; ++i) { + grpc::channelz::v1::Subchannel subch = subchannel_queue.front(); + subchannel_queue.pop(); + GetSubchannelDescedence(subch, channel_queue, subchannel_queue); + } + } + std::cout << std::endl; + } + } + + // dump data of all entities to stdout + void DumpStdout() { + TString data_str; + for (const auto& _channel : all_channels_) { + std::cout << "channel ID" << GetChannelID(_channel) << "_" + << GetChannelName(_channel) << " data:" << std::endl; + // TODO(mohanli): TextFormat::PrintToString records time as seconds and + // nanos. Need a more human readable way. + ::google::protobuf::TextFormat::PrintToString(_channel.data(), &data_str); + printf("%s\n", data_str.c_str()); + } + for (const auto& _subchannel : all_subchannels_) { + std::cout << "subchannel ID" << GetSubchannelID(_subchannel) << "_" + << GetSubchannelName(_subchannel) << " data:" << std::endl; + ::google::protobuf::TextFormat::PrintToString(_subchannel.data(), + &data_str); + printf("%s\n", data_str.c_str()); + } + for (const auto& _server : all_servers_) { + std::cout << "server ID" << GetServerID(_server) << "_" + << GetServerName(_server) << " data:" << std::endl; + ::google::protobuf::TextFormat::PrintToString(_server.data(), &data_str); + printf("%s\n", data_str.c_str()); + } + for (const auto& _socket : all_sockets_) { + std::cout << "socket ID" << GetSocketID(_socket) << "_" + << GetSocketName(_socket) << " data:" << std::endl; + ::google::protobuf::TextFormat::PrintToString(_socket.data(), &data_str); + printf("%s\n", data_str.c_str()); + } + } + + // Store a channel in Json + void StoreChannelInJson(const grpc::channelz::v1::Channel& channel) { + TString id = grpc::to_string(GetChannelID(channel)); + TString type = "Channel"; + TString description; + ::google::protobuf::TextFormat::PrintToString(channel.data(), &description); + grpc_core::Json description_json = grpc_core::Json(description); + StoreEntityInJson(id, type, description_json); + } + + // Store a subchannel in Json + void StoreSubchannelInJson(const grpc::channelz::v1::Subchannel& subchannel) { + TString id = grpc::to_string(GetSubchannelID(subchannel)); + TString type = "Subchannel"; + TString description; + ::google::protobuf::TextFormat::PrintToString(subchannel.data(), + &description); + grpc_core::Json description_json = grpc_core::Json(description); + StoreEntityInJson(id, type, description_json); + } + + // Store a server in Json + void StoreServerInJson(const grpc::channelz::v1::Server& server) { + TString id = grpc::to_string(GetServerID(server)); + TString type = "Server"; + TString description; + ::google::protobuf::TextFormat::PrintToString(server.data(), &description); + grpc_core::Json description_json = grpc_core::Json(description); + StoreEntityInJson(id, type, description_json); + } + + // Store a socket in Json + void StoreSocketInJson(const grpc::channelz::v1::Socket& socket) { + TString id = grpc::to_string(GetSocketID(socket)); + TString type = "Socket"; + TString description; + ::google::protobuf::TextFormat::PrintToString(socket.data(), &description); + grpc_core::Json description_json = grpc_core::Json(description); + StoreEntityInJson(id, type, description_json); + } + + // Store an entity in Json + void StoreEntityInJson(TString& id, TString& type, + const grpc_core::Json& description) { + TString start, finish; + gpr_timespec ago = gpr_time_sub( + now_, + gpr_time_from_seconds(FLAGS_sampling_interval_seconds, GPR_TIMESPAN)); + std::stringstream ss; + const time_t time_now = now_.tv_sec; + ss << std::put_time(std::localtime(&time_now), "%F %T"); + finish = ss.str(); // example: "2019-02-01 12:12:18" + ss.str(""); + const time_t time_ago = ago.tv_sec; + ss << std::put_time(std::localtime(&time_ago), "%F %T"); + start = ss.str(); + grpc_core::Json obj = + grpc_core::Json::Object{{"Task", y_absl::StrFormat("%s_ID%s", type, id)}, + {"Start", start}, + {"Finish", finish}, + {"ID", id}, + {"Type", type}, + {"Description", description}}; + json_.mutable_array()->push_back(obj); + } + + // Dump data in json + TString DumpJson() { return json_.Dump(); } + + // Check if one entity has been recorded + bool CheckID(int64_t id) { + if (id_set_.count(id) == 0) { + id_set_.insert(id); + return true; + } else { + return false; + } + } + + // Record current time + void RecordNow() { now_ = gpr_now(GPR_CLOCK_REALTIME); } + + private: + std::unique_ptr<grpc::channelz::v1::Channelz::Stub> channelz_stub_; + std::vector<grpc::channelz::v1::Channel> top_channels_; + std::vector<grpc::channelz::v1::Server> all_servers_; + std::vector<grpc::channelz::v1::Channel> all_channels_; + std::vector<grpc::channelz::v1::Subchannel> all_subchannels_; + std::vector<grpc::channelz::v1::Socket> all_sockets_; + std::unordered_set<int64_t> id_set_; + grpc_core::Json json_; + int64_t rpc_timeout_seconds_; + gpr_timespec now_; +}; + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + grpc::testing::InitTest(&argc, &argv, true); + std::ofstream output_file(FLAGS_output_json); + for (int i = 0; i < FLAGS_sampling_times; ++i) { + ChannelzSampler channelz_sampler; + channelz_sampler.Setup(FLAGS_custom_credentials_type, FLAGS_server_address); + std::cout << "Wait for sampling interval " + << FLAGS_sampling_interval_seconds << "s..." << std::endl; + const gpr_timespec kDelay = gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_seconds(FLAGS_sampling_interval_seconds, GPR_TIMESPAN)); + gpr_sleep_until(kDelay); + std::cout << "##### " << i << "th sampling #####" << std::endl; + channelz_sampler.RecordNow(); + channelz_sampler.GetServersRPC(); + channelz_sampler.GetSocketsOfServers(); + channelz_sampler.GetTopChannelsRPC(); + channelz_sampler.TraverseTopChannels(); + channelz_sampler.DumpStdout(); + if (!FLAGS_output_json.empty()) { + output_file << channelz_sampler.DumpJson() << "\n" << std::flush; + } + } + output_file.close(); + return 0; +} diff --git a/contrib/libs/grpc/test/cpp/util/channelz_sampler_test.cc b/contrib/libs/grpc/test/cpp/util/channelz_sampler_test.cc new file mode 100644 index 0000000000..d81dbb0d05 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/channelz_sampler_test.cc @@ -0,0 +1,176 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include <stdlib.h> +#include <unistd.h> + +#include <cstdlib> +#include <iostream> +#include <memory> +#include <util/generic/string.h> +#include <thread> + +#include "grpc/grpc.h" +#include "grpc/support/alloc.h" +#include "grpc/support/port_platform.h" +#include "grpcpp/channel.h" +#include "grpcpp/client_context.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/ext/channelz_service_plugin.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" +#include "grpcpp/server_context.h" +#include "gtest/gtest.h" +#include "src/core/lib/gpr/env.h" +#include "src/cpp/server/channelz/channelz_service.h" +#include "src/proto/grpc/testing/test.grpc.pb.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/subprocess.h" +#include "test/cpp/util/test_credentials_provider.h" + +static TString g_root; + +namespace { +using grpc::ClientContext; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +} // namespace + +// Test variables +TString server_address("0.0.0.0:10000"); +TString custom_credentials_type("INSECURE_CREDENTIALS"); +TString sampling_times = "2"; +TString sampling_interval_seconds = "3"; +TString output_json("output.json"); + +// Creata an echo server +class EchoServerImpl final : public grpc::testing::TestService::Service { + Status EmptyCall(::grpc::ServerContext* context, + const grpc::testing::Empty* request, + grpc::testing::Empty* response) { + return Status::OK; + } +}; + +// Run client in a thread +void RunClient(const TString& client_id, gpr_event* done_ev) { + grpc::ChannelArguments channel_args; + std::shared_ptr<grpc::ChannelCredentials> channel_creds = + grpc::testing::GetCredentialsProvider()->GetChannelCredentials( + custom_credentials_type, &channel_args); + std::unique_ptr<grpc::testing::TestService::Stub> stub = + grpc::testing::TestService::NewStub( + grpc::CreateChannel(server_address, channel_creds)); + gpr_log(GPR_INFO, "Client %s is echoing!", client_id.c_str()); + while (true) { + if (gpr_event_wait(done_ev, grpc_timeout_seconds_to_deadline(1)) != + nullptr) { + return; + } + grpc::testing::Empty request; + grpc::testing::Empty response; + ClientContext context; + Status status = stub->EmptyCall(&context, request, &response); + if (!status.ok()) { + gpr_log(GPR_ERROR, "Client echo failed."); + GPR_ASSERT(0); + } + } +} + +// Create the channelz to test the connection to the server +bool WaitForConnection(int wait_server_seconds) { + grpc::ChannelArguments channel_args; + std::shared_ptr<grpc::ChannelCredentials> channel_creds = + grpc::testing::GetCredentialsProvider()->GetChannelCredentials( + custom_credentials_type, &channel_args); + auto channel = grpc::CreateChannel(server_address, channel_creds); + return channel->WaitForConnected( + grpc_timeout_seconds_to_deadline(wait_server_seconds)); +} + +// Test the channelz sampler +TEST(ChannelzSamplerTest, SimpleTest) { + // start server + ::grpc::channelz::experimental::InitChannelzService(); + EchoServerImpl service; + grpc::ServerBuilder builder; + auto server_creds = + grpc::testing::GetCredentialsProvider()->GetServerCredentials( + custom_credentials_type); + builder.AddListeningPort(server_address, server_creds); + builder.RegisterService(&service); + std::unique_ptr<Server> server(builder.BuildAndStart()); + gpr_log(GPR_INFO, "Server listening on %s", server_address.c_str()); + const int kWaitForServerSeconds = 10; + ASSERT_TRUE(WaitForConnection(kWaitForServerSeconds)); + // client threads + gpr_event done_ev1, done_ev2; + gpr_event_init(&done_ev1); + gpr_event_init(&done_ev2); + std::thread client_thread_1(RunClient, "1", &done_ev1); + std::thread client_thread_2(RunClient, "2", &done_ev2); + // Run the channelz sampler + grpc::SubProcess* test_driver = new grpc::SubProcess( + {g_root + "/channelz_sampler", "--server_address=" + server_address, + "--custom_credentials_type=" + custom_credentials_type, + "--sampling_times=" + sampling_times, + "--sampling_interval_seconds=" + sampling_interval_seconds, + "--output_json=" + output_json}); + int status = test_driver->Join(); + if (WIFEXITED(status)) { + if (WEXITSTATUS(status)) { + gpr_log(GPR_ERROR, + "Channelz sampler test test-runner exited with code %d", + WEXITSTATUS(status)); + GPR_ASSERT(0); // log the line number of the assertion failure + } + } else if (WIFSIGNALED(status)) { + gpr_log(GPR_ERROR, "Channelz sampler test test-runner ended from signal %d", + WTERMSIG(status)); + GPR_ASSERT(0); + } else { + gpr_log(GPR_ERROR, + "Channelz sampler test test-runner ended with unknown status %d", + status); + GPR_ASSERT(0); + } + delete test_driver; + gpr_event_set(&done_ev1, (void*)1); + gpr_event_set(&done_ev2, (void*)1); + client_thread_1.join(); + client_thread_2.join(); +} + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + TString me = argv[0]; + auto lslash = me.rfind('/'); + if (lslash != TString::npos) { + g_root = me.substr(0, lslash); + } else { + g_root = "."; + } + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/util/cli_call.cc b/contrib/libs/grpc/test/cpp/util/cli_call.cc new file mode 100644 index 0000000000..5b3631667f --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/cli_call.cc @@ -0,0 +1,229 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/cli_call.h" + +#include <grpc/grpc.h> +#include <grpc/slice.h> +#include <grpc/support/log.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/support/byte_buffer.h> + +#include <cmath> +#include <iostream> +#include <utility> + +namespace grpc { +namespace testing { +namespace { +void* tag(int i) { return (void*)static_cast<intptr_t>(i); } +} // namespace + +Status CliCall::Call(const std::shared_ptr<grpc::Channel>& channel, + const TString& method, const TString& request, + TString* response, + const OutgoingMetadataContainer& metadata, + IncomingMetadataContainer* server_initial_metadata, + IncomingMetadataContainer* server_trailing_metadata) { + CliCall call(channel, method, metadata); + call.Write(request); + call.WritesDone(); + if (!call.Read(response, server_initial_metadata)) { + fprintf(stderr, "Failed to read response.\n"); + } + return call.Finish(server_trailing_metadata); +} + +CliCall::CliCall(const std::shared_ptr<grpc::Channel>& channel, + const TString& method, + const OutgoingMetadataContainer& metadata, CliArgs args) + : stub_(new grpc::GenericStub(channel)) { + gpr_mu_init(&write_mu_); + gpr_cv_init(&write_cv_); + if (!metadata.empty()) { + for (OutgoingMetadataContainer::const_iterator iter = metadata.begin(); + iter != metadata.end(); ++iter) { + ctx_.AddMetadata(iter->first, iter->second); + } + } + + // Set deadline if timeout > 0 (default value -1 if no timeout specified) + if (args.timeout > 0) { + int64_t timeout_in_ns = ceil(args.timeout * 1e9); + + // Convert timeout (in nanoseconds) to a deadline + auto deadline = + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_nanos(timeout_in_ns, GPR_TIMESPAN)); + ctx_.set_deadline(deadline); + } else if (args.timeout != -1) { + fprintf( + stderr, + "WARNING: Non-positive timeout value, skipping setting deadline.\n"); + } + + call_ = stub_->PrepareCall(&ctx_, method, &cq_); + call_->StartCall(tag(1)); + void* got_tag; + bool ok; + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); +} + +CliCall::~CliCall() { + gpr_cv_destroy(&write_cv_); + gpr_mu_destroy(&write_mu_); +} + +void CliCall::Write(const TString& request) { + void* got_tag; + bool ok; + + gpr_slice s = gpr_slice_from_copied_buffer(request.data(), request.size()); + grpc::Slice req_slice(s, grpc::Slice::STEAL_REF); + grpc::ByteBuffer send_buffer(&req_slice, 1); + call_->Write(send_buffer, tag(2)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); +} + +bool CliCall::Read(TString* response, + IncomingMetadataContainer* server_initial_metadata) { + void* got_tag; + bool ok; + + grpc::ByteBuffer recv_buffer; + call_->Read(&recv_buffer, tag(3)); + + if (!cq_.Next(&got_tag, &ok) || !ok) { + return false; + } + std::vector<grpc::Slice> slices; + GPR_ASSERT(recv_buffer.Dump(&slices).ok()); + + response->clear(); + for (size_t i = 0; i < slices.size(); i++) { + response->append(reinterpret_cast<const char*>(slices[i].begin()), + slices[i].size()); + } + if (server_initial_metadata) { + *server_initial_metadata = ctx_.GetServerInitialMetadata(); + } + return true; +} + +void CliCall::WritesDone() { + void* got_tag; + bool ok; + + call_->WritesDone(tag(4)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); +} + +void CliCall::WriteAndWait(const TString& request) { + grpc::Slice req_slice(request); + grpc::ByteBuffer send_buffer(&req_slice, 1); + + gpr_mu_lock(&write_mu_); + call_->Write(send_buffer, tag(2)); + write_done_ = false; + while (!write_done_) { + gpr_cv_wait(&write_cv_, &write_mu_, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_unlock(&write_mu_); +} + +void CliCall::WritesDoneAndWait() { + gpr_mu_lock(&write_mu_); + call_->WritesDone(tag(4)); + write_done_ = false; + while (!write_done_) { + gpr_cv_wait(&write_cv_, &write_mu_, gpr_inf_future(GPR_CLOCK_MONOTONIC)); + } + gpr_mu_unlock(&write_mu_); +} + +bool CliCall::ReadAndMaybeNotifyWrite( + TString* response, IncomingMetadataContainer* server_initial_metadata) { + void* got_tag; + bool ok; + grpc::ByteBuffer recv_buffer; + + call_->Read(&recv_buffer, tag(3)); + bool cq_result = cq_.Next(&got_tag, &ok); + + while (got_tag != tag(3)) { + gpr_mu_lock(&write_mu_); + write_done_ = true; + gpr_cv_signal(&write_cv_); + gpr_mu_unlock(&write_mu_); + + cq_result = cq_.Next(&got_tag, &ok); + if (got_tag == tag(2)) { + GPR_ASSERT(ok); + } + } + + if (!cq_result || !ok) { + // If the RPC is ended on the server side, we should still wait for the + // pending write on the client side to be done. + if (!ok) { + gpr_mu_lock(&write_mu_); + if (!write_done_) { + cq_.Next(&got_tag, &ok); + GPR_ASSERT(got_tag != tag(2)); + write_done_ = true; + gpr_cv_signal(&write_cv_); + } + gpr_mu_unlock(&write_mu_); + } + return false; + } + + std::vector<grpc::Slice> slices; + GPR_ASSERT(recv_buffer.Dump(&slices).ok()); + response->clear(); + for (size_t i = 0; i < slices.size(); i++) { + response->append(reinterpret_cast<const char*>(slices[i].begin()), + slices[i].size()); + } + if (server_initial_metadata) { + *server_initial_metadata = ctx_.GetServerInitialMetadata(); + } + return true; +} + +Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) { + void* got_tag; + bool ok; + grpc::Status status; + + call_->Finish(&status, tag(5)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); + if (server_trailing_metadata) { + *server_trailing_metadata = ctx_.GetServerTrailingMetadata(); + } + + return status; +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/cli_call.h b/contrib/libs/grpc/test/cpp/util/cli_call.h new file mode 100644 index 0000000000..79d00d99f4 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/cli_call.h @@ -0,0 +1,109 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_CLI_CALL_H +#define GRPC_TEST_CPP_UTIL_CLI_CALL_H + +#include <grpcpp/channel.h> +#include <grpcpp/completion_queue.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/support/status.h> +#include <grpcpp/support/string_ref.h> + +#include <map> + +namespace grpc { + +class ClientContext; + +struct CliArgs { + double timeout = -1; +}; + +namespace testing { + +// CliCall handles the sending and receiving of generic messages given the name +// of the remote method. This class is only used by GrpcTool. Its thread-safe +// and thread-unsafe methods should not be used together. +class CliCall final { + public: + typedef std::multimap<TString, TString> OutgoingMetadataContainer; + typedef std::multimap<grpc::string_ref, grpc::string_ref> + IncomingMetadataContainer; + + CliCall(const std::shared_ptr<grpc::Channel>& channel, + const TString& method, const OutgoingMetadataContainer& metadata, + CliArgs args); + CliCall(const std::shared_ptr<grpc::Channel>& channel, + const TString& method, const OutgoingMetadataContainer& metadata) + : CliCall(channel, method, metadata, CliArgs{}) {} + + ~CliCall(); + + // Perform an unary generic RPC. + static Status Call(const std::shared_ptr<grpc::Channel>& channel, + const TString& method, const TString& request, + TString* response, + const OutgoingMetadataContainer& metadata, + IncomingMetadataContainer* server_initial_metadata, + IncomingMetadataContainer* server_trailing_metadata); + + // Send a generic request message in a synchronous manner. NOT thread-safe. + void Write(const TString& request); + + // Send a generic request message in a synchronous manner. NOT thread-safe. + void WritesDone(); + + // Receive a generic response message in a synchronous manner.NOT thread-safe. + bool Read(TString* response, + IncomingMetadataContainer* server_initial_metadata); + + // Thread-safe write. Must be used with ReadAndMaybeNotifyWrite. Send out a + // generic request message and wait for ReadAndMaybeNotifyWrite to finish it. + void WriteAndWait(const TString& request); + + // Thread-safe WritesDone. Must be used with ReadAndMaybeNotifyWrite. Send out + // WritesDone for gereneric request messages and wait for + // ReadAndMaybeNotifyWrite to finish it. + void WritesDoneAndWait(); + + // Thread-safe Read. Blockingly receive a generic response message. Notify + // writes if they are finished when this read is waiting for a resposne. + bool ReadAndMaybeNotifyWrite( + TString* response, + IncomingMetadataContainer* server_initial_metadata); + + // Finish the RPC. + Status Finish(IncomingMetadataContainer* server_trailing_metadata); + + TString peer() const { return ctx_.peer(); } + + private: + std::unique_ptr<grpc::GenericStub> stub_; + grpc::ClientContext ctx_; + std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_; + grpc::CompletionQueue cq_; + gpr_mu write_mu_; + gpr_cv write_cv_; // Protected by write_mu_; + bool write_done_; // Portected by write_mu_; +}; + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_CLI_CALL_H diff --git a/contrib/libs/grpc/test/cpp/util/cli_call_test.cc b/contrib/libs/grpc/test/cpp/util/cli_call_test.cc new file mode 100644 index 0000000000..4f0544b2e5 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/cli_call_test.cc @@ -0,0 +1,128 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/cli_call.h" + +#include <grpc/grpc.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <gtest/gtest.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/string_ref_helper.h" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +namespace grpc { +namespace testing { + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + if (!context->client_metadata().empty()) { + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + response->set_message(request->message()); + return Status::OK; + } +}; + +class CliCallTest : public ::testing::Test { + protected: + CliCallTest() {} + + void SetUp() override { + int port = grpc_pick_unused_port_or_die(); + server_address_ << "localhost:" << port; + // Setup server + ServerBuilder builder; + builder.AddListeningPort(server_address_.str(), + InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + void TearDown() override { server_->Shutdown(); } + + void ResetStub() { + channel_ = grpc::CreateChannel(server_address_.str(), + InsecureChannelCredentials()); + stub_ = grpc::testing::EchoTestService::NewStub(channel_); + } + + std::shared_ptr<Channel> channel_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + std::unique_ptr<Server> server_; + std::ostringstream server_address_; + TestServiceImpl service_; +}; + +// Send a rpc with a normal stub and then a CliCall. Verify they match. +TEST_F(CliCallTest, SimpleRpc) { + ResetStub(); + // Normal stub. + EchoRequest request; + EchoResponse response; + request.set_message("Hello"); + + ClientContext context; + context.AddMetadata("key1", "val1"); + Status s = stub_->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()); + + const TString kMethod("/grpc.testing.EchoTestService/Echo"); + TString request_bin, response_bin, expected_response_bin; + EXPECT_TRUE(request.SerializeToString(&request_bin)); + EXPECT_TRUE(response.SerializeToString(&expected_response_bin)); + std::multimap<TString, TString> client_metadata; + std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata, + server_trailing_metadata; + client_metadata.insert(std::pair<TString, TString>("key1", "val1")); + Status s2 = CliCall::Call(channel_, kMethod, request_bin, &response_bin, + client_metadata, &server_initial_metadata, + &server_trailing_metadata); + EXPECT_TRUE(s2.ok()); + + EXPECT_EQ(expected_response_bin, response_bin); + EXPECT_EQ(context.GetServerInitialMetadata(), server_initial_metadata); + EXPECT_EQ(context.GetServerTrailingMetadata(), server_trailing_metadata); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/util/cli_credentials.cc b/contrib/libs/grpc/test/cpp/util/cli_credentials.cc new file mode 100644 index 0000000000..efd548eb9b --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/cli_credentials.cc @@ -0,0 +1,245 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/cli_credentials.h" + +#include <gflags/gflags.h> +#include <grpc/slice.h> +#include <grpc/support/log.h> +#include <grpcpp/impl/codegen/slice.h> + +#include "src/core/lib/iomgr/load_file.h" + +DEFINE_bool( + enable_ssl, false, + "Whether to use ssl/tls. Deprecated. Use --channel_creds_type=ssl."); +DEFINE_bool(use_auth, false, + "Whether to create default google credentials. Deprecated. Use " + "--channel_creds_type=gdc."); +DEFINE_string( + access_token, "", + "The access token that will be sent to the server to authenticate RPCs. " + "Deprecated. Use --call_creds=access_token=<token>."); +DEFINE_string( + ssl_target, "", + "If not empty, treat the server host name as this for ssl/tls certificate " + "validation."); +DEFINE_string( + ssl_client_cert, "", + "If not empty, load this PEM formatted client certificate file. Requires " + "use of --ssl_client_key."); +DEFINE_string( + ssl_client_key, "", + "If not empty, load this PEM formatted private key. Requires use of " + "--ssl_client_cert"); +DEFINE_string( + local_connect_type, "local_tcp", + "The type of local connections for which local channel credentials will " + "be applied. Should be local_tcp or uds."); +DEFINE_string( + channel_creds_type, "", + "The channel creds type: insecure, ssl, gdc (Google Default Credentials), " + "alts, or local."); +DEFINE_string( + call_creds, "", + "Call credentials to use: none (default), or access_token=<token>. If " + "provided, the call creds are composited on top of channel creds."); + +namespace grpc { +namespace testing { + +namespace { + +const char ACCESS_TOKEN_PREFIX[] = "access_token="; +constexpr int ACCESS_TOKEN_PREFIX_LEN = + sizeof(ACCESS_TOKEN_PREFIX) / sizeof(*ACCESS_TOKEN_PREFIX) - 1; + +bool IsAccessToken(const TString& auth) { + return auth.length() > ACCESS_TOKEN_PREFIX_LEN && + auth.compare(0, ACCESS_TOKEN_PREFIX_LEN, ACCESS_TOKEN_PREFIX) == 0; +} + +TString AccessToken(const TString& auth) { + if (!IsAccessToken(auth)) { + return ""; + } + return TString(auth.c_str(), ACCESS_TOKEN_PREFIX_LEN); +} + +} // namespace + +TString CliCredentials::GetDefaultChannelCredsType() const { + // Compatibility logic for --enable_ssl. + if (FLAGS_enable_ssl) { + fprintf(stderr, + "warning: --enable_ssl is deprecated. Use " + "--channel_creds_type=ssl.\n"); + return "ssl"; + } + // Compatibility logic for --use_auth. + if (FLAGS_access_token.empty() && FLAGS_use_auth) { + fprintf(stderr, + "warning: --use_auth is deprecated. Use " + "--channel_creds_type=gdc.\n"); + return "gdc"; + } + return "insecure"; +} + +TString CliCredentials::GetDefaultCallCreds() const { + if (!FLAGS_access_token.empty()) { + fprintf(stderr, + "warning: --access_token is deprecated. Use " + "--call_creds=access_token=<token>.\n"); + return TString("access_token=") + FLAGS_access_token; + } + return "none"; +} + +std::shared_ptr<grpc::ChannelCredentials> +CliCredentials::GetChannelCredentials() const { + if (FLAGS_channel_creds_type.compare("insecure") == 0) { + return grpc::InsecureChannelCredentials(); + } else if (FLAGS_channel_creds_type.compare("ssl") == 0) { + grpc::SslCredentialsOptions ssl_creds_options; + // TODO(@Capstan): This won't affect Google Default Credentials using SSL. + if (!FLAGS_ssl_client_cert.empty()) { + grpc_slice cert_slice = grpc_empty_slice(); + GRPC_LOG_IF_ERROR( + "load_file", + grpc_load_file(FLAGS_ssl_client_cert.c_str(), 1, &cert_slice)); + ssl_creds_options.pem_cert_chain = + grpc::StringFromCopiedSlice(cert_slice); + grpc_slice_unref(cert_slice); + } + if (!FLAGS_ssl_client_key.empty()) { + grpc_slice key_slice = grpc_empty_slice(); + GRPC_LOG_IF_ERROR( + "load_file", + grpc_load_file(FLAGS_ssl_client_key.c_str(), 1, &key_slice)); + ssl_creds_options.pem_private_key = + grpc::StringFromCopiedSlice(key_slice); + grpc_slice_unref(key_slice); + } + return grpc::SslCredentials(ssl_creds_options); + } else if (FLAGS_channel_creds_type.compare("gdc") == 0) { + return grpc::GoogleDefaultCredentials(); + } else if (FLAGS_channel_creds_type.compare("alts") == 0) { + return grpc::experimental::AltsCredentials( + grpc::experimental::AltsCredentialsOptions()); + } else if (FLAGS_channel_creds_type.compare("local") == 0) { + if (FLAGS_local_connect_type.compare("local_tcp") == 0) { + return grpc::experimental::LocalCredentials(LOCAL_TCP); + } else if (FLAGS_local_connect_type.compare("uds") == 0) { + return grpc::experimental::LocalCredentials(UDS); + } else { + fprintf(stderr, + "--local_connect_type=%s invalid; must be local_tcp or uds.\n", + FLAGS_local_connect_type.c_str()); + } + } + fprintf(stderr, + "--channel_creds_type=%s invalid; must be insecure, ssl, gdc, " + "alts, or local.\n", + FLAGS_channel_creds_type.c_str()); + return std::shared_ptr<grpc::ChannelCredentials>(); +} + +std::shared_ptr<grpc::CallCredentials> CliCredentials::GetCallCredentials() + const { + if (IsAccessToken(FLAGS_call_creds.c_str())) { + return grpc::AccessTokenCredentials(AccessToken(FLAGS_call_creds.c_str())); + } + if (FLAGS_call_creds.compare("none") == 0) { + // Nothing to do; creds, if any, are baked into the channel. + return std::shared_ptr<grpc::CallCredentials>(); + } + fprintf(stderr, + "--call_creds=%s invalid; must be none " + "or access_token=<token>.\n", + FLAGS_call_creds.c_str()); + return std::shared_ptr<grpc::CallCredentials>(); +} + +std::shared_ptr<grpc::ChannelCredentials> CliCredentials::GetCredentials() + const { + if (FLAGS_call_creds.empty()) { + FLAGS_call_creds = GetDefaultCallCreds(); + } else if (!FLAGS_access_token.empty() && !IsAccessToken(FLAGS_call_creds.c_str())) { + fprintf(stderr, + "warning: ignoring --access_token because --call_creds " + "already set to %s.\n", + FLAGS_call_creds.c_str()); + } + if (FLAGS_channel_creds_type.empty()) { + FLAGS_channel_creds_type = GetDefaultChannelCredsType(); + } else if (FLAGS_enable_ssl && FLAGS_channel_creds_type.compare("ssl") != 0) { + fprintf(stderr, + "warning: ignoring --enable_ssl because " + "--channel_creds_type already set to %s.\n", + FLAGS_channel_creds_type.c_str()); + } else if (FLAGS_use_auth && FLAGS_channel_creds_type.compare("gdc") != 0) { + fprintf(stderr, + "warning: ignoring --use_auth because " + "--channel_creds_type already set to %s.\n", + FLAGS_channel_creds_type.c_str()); + } + // Legacy transport upgrade logic for insecure requests. + if (IsAccessToken(FLAGS_call_creds.c_str()) && + FLAGS_channel_creds_type.compare("insecure") == 0) { + fprintf(stderr, + "warning: --channel_creds_type=insecure upgraded to ssl because " + "an access token was provided.\n"); + FLAGS_channel_creds_type = "ssl"; + } + std::shared_ptr<grpc::ChannelCredentials> channel_creds = + GetChannelCredentials(); + // Composite any call-type credentials on top of the base channel. + std::shared_ptr<grpc::CallCredentials> call_creds = GetCallCredentials(); + return (channel_creds == nullptr || call_creds == nullptr) + ? channel_creds + : grpc::CompositeChannelCredentials(channel_creds, call_creds); +} + +const TString CliCredentials::GetCredentialUsage() const { + return " --enable_ssl ; Set whether to use ssl " + "(deprecated)\n" + " --use_auth ; Set whether to create default google" + " credentials\n" + " ; (deprecated)\n" + " --access_token ; Set the access token in metadata," + " overrides --use_auth\n" + " ; (deprecated)\n" + " --ssl_target ; Set server host for ssl validation\n" + " --ssl_client_cert ; Client cert for ssl\n" + " --ssl_client_key ; Client private key for ssl\n" + " --local_connect_type ; Set to local_tcp or uds\n" + " --channel_creds_type ; Set to insecure, ssl, gdc, alts, or " + "local\n" + " --call_creds ; Set to none, or" + " access_token=<token>\n"; +} + +const TString CliCredentials::GetSslTargetNameOverride() const { + bool use_ssl = FLAGS_channel_creds_type.compare("ssl") == 0 || + FLAGS_channel_creds_type.compare("gdc") == 0; + return use_ssl ? FLAGS_ssl_target : ""; +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/cli_credentials.h b/contrib/libs/grpc/test/cpp/util/cli_credentials.h new file mode 100644 index 0000000000..3e695692fa --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/cli_credentials.h @@ -0,0 +1,55 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_CLI_CREDENTIALS_H +#define GRPC_TEST_CPP_UTIL_CLI_CREDENTIALS_H + +#include <grpcpp/security/credentials.h> +#include <grpcpp/support/config.h> + +namespace grpc { +namespace testing { + +class CliCredentials { + public: + virtual ~CliCredentials() {} + std::shared_ptr<grpc::ChannelCredentials> GetCredentials() const; + virtual const TString GetCredentialUsage() const; + virtual const TString GetSslTargetNameOverride() const; + + protected: + // Returns the appropriate channel_creds_type value for the set of legacy + // flag arguments. + virtual TString GetDefaultChannelCredsType() const; + // Returns the appropriate call_creds value for the set of legacy flag + // arguments. + virtual TString GetDefaultCallCreds() const; + // Returns the base transport channel credentials. Child classes can override + // to support additional channel_creds_types unknown to this base class. + virtual std::shared_ptr<grpc::ChannelCredentials> GetChannelCredentials() + const; + // Returns call credentials to composite onto the base transport channel + // credentials. Child classes can override to support additional + // authentication flags unknown to this base class. + virtual std::shared_ptr<grpc::CallCredentials> GetCallCredentials() const; +}; + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_CLI_CREDENTIALS_H diff --git a/contrib/libs/grpc/test/cpp/util/config_grpc_cli.h b/contrib/libs/grpc/test/cpp/util/config_grpc_cli.h new file mode 100644 index 0000000000..358884196d --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/config_grpc_cli.h @@ -0,0 +1,70 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_CONFIG_GRPC_CLI_H +#define GRPC_TEST_CPP_UTIL_CONFIG_GRPC_CLI_H + +#include <grpcpp/impl/codegen/config_protobuf.h> + +#ifndef GRPC_CUSTOM_DYNAMICMESSAGEFACTORY +#include <google/protobuf/dynamic_message.h> +#define GRPC_CUSTOM_DYNAMICMESSAGEFACTORY \ + ::google::protobuf::DynamicMessageFactory +#endif + +#ifndef GRPC_CUSTOM_DESCRIPTORPOOLDATABASE +#include <google/protobuf/descriptor.h> +#define GRPC_CUSTOM_DESCRIPTORPOOLDATABASE \ + ::google::protobuf::DescriptorPoolDatabase +#define GRPC_CUSTOM_MERGEDDESCRIPTORDATABASE \ + ::google::protobuf::MergedDescriptorDatabase +#endif + +#ifndef GRPC_CUSTOM_TEXTFORMAT +#include <google/protobuf/text_format.h> +#define GRPC_CUSTOM_TEXTFORMAT ::google::protobuf::TextFormat +#endif + +#ifndef GRPC_CUSTOM_DISKSOURCETREE +#include <google/protobuf/compiler/importer.h> +#define GRPC_CUSTOM_DISKSOURCETREE ::google::protobuf::compiler::DiskSourceTree +#define GRPC_CUSTOM_IMPORTER ::google::protobuf::compiler::Importer +#define GRPC_CUSTOM_MULTIFILEERRORCOLLECTOR \ + ::google::protobuf::compiler::MultiFileErrorCollector +#endif + +namespace grpc { +namespace protobuf { + +typedef GRPC_CUSTOM_DYNAMICMESSAGEFACTORY DynamicMessageFactory; + +typedef GRPC_CUSTOM_DESCRIPTORPOOLDATABASE DescriptorPoolDatabase; +typedef GRPC_CUSTOM_MERGEDDESCRIPTORDATABASE MergedDescriptorDatabase; + +typedef GRPC_CUSTOM_TEXTFORMAT TextFormat; + +namespace compiler { +typedef GRPC_CUSTOM_DISKSOURCETREE DiskSourceTree; +typedef GRPC_CUSTOM_IMPORTER Importer; +typedef GRPC_CUSTOM_MULTIFILEERRORCOLLECTOR MultiFileErrorCollector; +} // namespace compiler + +} // namespace protobuf +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_CONFIG_GRPC_CLI_H diff --git a/contrib/libs/grpc/test/cpp/util/create_test_channel.cc b/contrib/libs/grpc/test/cpp/util/create_test_channel.cc new file mode 100644 index 0000000000..86d8e22af1 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/create_test_channel.cc @@ -0,0 +1,252 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/create_test_channel.h" + +#include <gflags/gflags.h> + +#include <grpc/support/log.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/security/credentials.h> + +#include "test/cpp/util/test_credentials_provider.h" + +DEFINE_string( + grpc_test_use_grpclb_with_child_policy, "", + "If non-empty, set a static service config on channels created by " + "grpc::CreateTestChannel, that configures the grpclb LB policy " + "with a child policy being the value of this flag (e.g. round_robin " + "or pick_first)."); + +namespace grpc { + +namespace { + +const char kProdTlsCredentialsType[] = "prod_ssl"; + +class SslCredentialProvider : public testing::CredentialTypeProvider { + public: + std::shared_ptr<ChannelCredentials> GetChannelCredentials( + grpc::ChannelArguments* /*args*/) override { + return grpc::SslCredentials(SslCredentialsOptions()); + } + std::shared_ptr<ServerCredentials> GetServerCredentials() override { + return nullptr; + } +}; + +gpr_once g_once_init_add_prod_ssl_provider = GPR_ONCE_INIT; +// Register ssl with non-test roots type to the credentials provider. +void AddProdSslType() { + testing::GetCredentialsProvider()->AddSecureType( + kProdTlsCredentialsType, std::unique_ptr<testing::CredentialTypeProvider>( + new SslCredentialProvider)); +} + +void MaybeSetCustomChannelArgs(grpc::ChannelArguments* args) { + if (FLAGS_grpc_test_use_grpclb_with_child_policy.size() > 0) { + args->SetString("grpc.service_config", + "{\"loadBalancingConfig\":[{\"grpclb\":{\"childPolicy\":[{" + "\"" + + FLAGS_grpc_test_use_grpclb_with_child_policy + + "\":{}}]}}]}"); + } +} + +} // namespace + +// When cred_type is 'ssl', if server is empty, override_hostname is used to +// create channel. Otherwise, connect to server and override hostname if +// override_hostname is provided. +// When cred_type is not 'ssl', override_hostname is ignored. +// Set use_prod_root to true to use the SSL root for connecting to google. +// In this case, path to the roots pem file must be set via environment variable +// GRPC_DEFAULT_SSL_ROOTS_FILE_PATH. +// Otherwise, root for test SSL cert will be used. +// creds will be used to create a channel when cred_type is 'ssl'. +// Use examples: +// CreateTestChannel( +// "1.1.1.1:12345", "ssl", "override.hostname.com", false, creds); +// CreateTestChannel("test.google.com:443", "ssl", "", true, creds); +// same as above +// CreateTestChannel("", "ssl", "test.google.com:443", true, creds); +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& cred_type, + const TString& override_hostname, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, + const ChannelArguments& args) { + return CreateTestChannel(server, cred_type, override_hostname, use_prod_roots, + creds, args, + /*interceptor_creators=*/{}); +} + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, + const ChannelArguments& args) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, creds, args, + /*interceptor_creators=*/{}); +} + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, creds, ChannelArguments()); +} + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, std::shared_ptr<CallCredentials>()); +} + +// Shortcut for end2end and interop tests. +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, testing::transport_security security_type) { + return CreateTestChannel(server, "foo.test.google.fr", security_type, false); +} + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& credential_type, + const std::shared_ptr<CallCredentials>& creds) { + ChannelArguments channel_args; + MaybeSetCustomChannelArgs(&channel_args); + std::shared_ptr<ChannelCredentials> channel_creds = + testing::GetCredentialsProvider()->GetChannelCredentials(credential_type, + &channel_args); + GPR_ASSERT(channel_creds != nullptr); + if (creds.get()) { + channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds); + } + return ::grpc::CreateCustomChannel(server, channel_creds, channel_args); +} + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& cred_type, + const TString& override_hostname, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators) { + ChannelArguments channel_args(args); + MaybeSetCustomChannelArgs(&channel_args); + std::shared_ptr<ChannelCredentials> channel_creds; + if (cred_type.empty()) { + if (interceptor_creators.empty()) { + return ::grpc::CreateCustomChannel(server, InsecureChannelCredentials(), + channel_args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + server, InsecureChannelCredentials(), channel_args, + std::move(interceptor_creators)); + } + } else if (cred_type == testing::kTlsCredentialsType) { // cred_type == "ssl" + if (use_prod_roots) { + gpr_once_init(&g_once_init_add_prod_ssl_provider, &AddProdSslType); + channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( + kProdTlsCredentialsType, &channel_args); + if (!server.empty() && !override_hostname.empty()) { + channel_args.SetSslTargetNameOverride(override_hostname); + } + } else { + // override_hostname is discarded as the provider handles it. + channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( + testing::kTlsCredentialsType, &channel_args); + } + GPR_ASSERT(channel_creds != nullptr); + + const TString& connect_to = server.empty() ? override_hostname : server; + if (creds.get()) { + channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds); + } + if (interceptor_creators.empty()) { + return ::grpc::CreateCustomChannel(connect_to, channel_creds, + channel_args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + connect_to, channel_creds, channel_args, + std::move(interceptor_creators)); + } + } else { + channel_creds = testing::GetCredentialsProvider()->GetChannelCredentials( + cred_type, &channel_args); + GPR_ASSERT(channel_creds != nullptr); + + if (interceptor_creators.empty()) { + return ::grpc::CreateCustomChannel(server, channel_creds, channel_args); + } else { + return experimental::CreateCustomChannelWithInterceptors( + server, channel_creds, channel_args, std::move(interceptor_creators)); + } + } +} + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators) { + TString credential_type = + security_type == testing::ALTS + ? testing::kAltsCredentialsType + : (security_type == testing::TLS ? testing::kTlsCredentialsType + : testing::kInsecureCredentialsType); + return CreateTestChannel(server, credential_type, override_hostname, + use_prod_roots, creds, args, + std::move(interceptor_creators)); +} + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators) { + return CreateTestChannel(server, override_hostname, security_type, + use_prod_roots, creds, ChannelArguments(), + std::move(interceptor_creators)); +} + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& credential_type, + const std::shared_ptr<CallCredentials>& creds, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators) { + ChannelArguments channel_args; + MaybeSetCustomChannelArgs(&channel_args); + std::shared_ptr<ChannelCredentials> channel_creds = + testing::GetCredentialsProvider()->GetChannelCredentials(credential_type, + &channel_args); + GPR_ASSERT(channel_creds != nullptr); + if (creds.get()) { + channel_creds = grpc::CompositeChannelCredentials(channel_creds, creds); + } + return experimental::CreateCustomChannelWithInterceptors( + server, channel_creds, channel_args, std::move(interceptor_creators)); +} + +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/create_test_channel.h b/contrib/libs/grpc/test/cpp/util/create_test_channel.h new file mode 100644 index 0000000000..ed4ce6c11b --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/create_test_channel.h @@ -0,0 +1,99 @@ +/* + * + * Copyright 2015-2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H +#define GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H + +#include <memory> + +#include <grpcpp/channel.h> +#include <grpcpp/impl/codegen/client_interceptor.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/support/channel_arguments.h> + +namespace grpc { +class Channel; + +namespace testing { + +typedef enum { INSECURE = 0, TLS, ALTS } transport_security; + +} // namespace testing + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, testing::transport_security security_type); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, + const ChannelArguments& args); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& cred_type, + const TString& override_hostname, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, + const ChannelArguments& args); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& credential_type, + const std::shared_ptr<CallCredentials>& creds); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& override_hostname, + testing::transport_security security_type, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& cred_type, + const TString& override_hostname, bool use_prod_roots, + const std::shared_ptr<CallCredentials>& creds, const ChannelArguments& args, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators); + +std::shared_ptr<Channel> CreateTestChannel( + const TString& server, const TString& credential_type, + const std::shared_ptr<CallCredentials>& creds, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + interceptor_creators); + +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_CREATE_TEST_CHANNEL_H diff --git a/contrib/libs/grpc/test/cpp/util/error_details_test.cc b/contrib/libs/grpc/test/cpp/util/error_details_test.cc new file mode 100644 index 0000000000..630ab1d98f --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/error_details_test.cc @@ -0,0 +1,125 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpcpp/support/error_details.h> +#include <gtest/gtest.h> + +#include "src/proto/grpc/status/status.pb.h" +#include "src/proto/grpc/testing/echo_messages.pb.h" +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +TEST(ExtractTest, Success) { + google::rpc::Status expected; + expected.set_code(13); // INTERNAL + expected.set_message("I am an error message"); + testing::EchoRequest expected_details; + expected_details.set_message(TString(100, '\0')); + expected.add_details()->PackFrom(expected_details); + + google::rpc::Status to; + TString error_details = expected.SerializeAsString(); + Status from(static_cast<StatusCode>(expected.code()), expected.message(), + error_details); + EXPECT_TRUE(ExtractErrorDetails(from, &to).ok()); + EXPECT_EQ(expected.code(), to.code()); + EXPECT_EQ(expected.message(), to.message()); + EXPECT_EQ(1, to.details_size()); + testing::EchoRequest details; + to.details(0).UnpackTo(&details); + EXPECT_EQ(expected_details.message(), details.message()); +} + +TEST(ExtractTest, NullInput) { + EXPECT_EQ(StatusCode::FAILED_PRECONDITION, + ExtractErrorDetails(Status(), nullptr).error_code()); +} + +TEST(ExtractTest, Unparsable) { + TString error_details("I am not a status object"); + Status from(StatusCode::INTERNAL, "", error_details); + google::rpc::Status to; + EXPECT_EQ(StatusCode::INVALID_ARGUMENT, + ExtractErrorDetails(from, &to).error_code()); +} + +TEST(SetTest, Success) { + google::rpc::Status expected; + expected.set_code(13); // INTERNAL + expected.set_message("I am an error message"); + testing::EchoRequest expected_details; + expected_details.set_message(TString(100, '\0')); + expected.add_details()->PackFrom(expected_details); + + Status to; + Status s = SetErrorDetails(expected, &to); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(expected.code(), to.error_code()); + EXPECT_EQ(expected.message(), to.error_message()); + EXPECT_EQ(expected.SerializeAsString(), to.error_details()); +} + +TEST(SetTest, NullInput) { + EXPECT_EQ(StatusCode::FAILED_PRECONDITION, + SetErrorDetails(google::rpc::Status(), nullptr).error_code()); +} + +TEST(SetTest, OutOfScopeErrorCode) { + google::rpc::Status expected; + expected.set_code(17); // Out of scope (UNAUTHENTICATED is 16). + expected.set_message("I am an error message"); + testing::EchoRequest expected_details; + expected_details.set_message(TString(100, '\0')); + expected.add_details()->PackFrom(expected_details); + + Status to; + Status s = SetErrorDetails(expected, &to); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(StatusCode::UNKNOWN, to.error_code()); + EXPECT_EQ(expected.message(), to.error_message()); + EXPECT_EQ(expected.SerializeAsString(), to.error_details()); +} + +TEST(SetTest, ValidScopeErrorCode) { + for (int c = StatusCode::OK; c <= StatusCode::UNAUTHENTICATED; c++) { + google::rpc::Status expected; + expected.set_code(c); + expected.set_message("I am an error message"); + testing::EchoRequest expected_details; + expected_details.set_message(TString(100, '\0')); + expected.add_details()->PackFrom(expected_details); + + Status to; + Status s = SetErrorDetails(expected, &to); + EXPECT_TRUE(s.ok()); + EXPECT_EQ(c, to.error_code()); + EXPECT_EQ(expected.message(), to.error_message()); + EXPECT_EQ(expected.SerializeAsString(), to.error_details()); + } +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/util/grpc_cli.cc b/contrib/libs/grpc/test/cpp/util/grpc_cli.cc new file mode 100644 index 0000000000..45c6b94f84 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/grpc_cli.cc @@ -0,0 +1,90 @@ +/* + + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* + A command line tool to talk to a grpc server. + Run `grpc_cli help` command to see its usage information. + + Example of talking to grpc interop server: + grpc_cli call localhost:50051 UnaryCall "response_size:10" \ + --protofiles=src/proto/grpc/testing/test.proto --enable_ssl=false + + Options: + 1. --protofiles, use this flag to provide proto files if the server does + does not have the reflection service. + 2. --proto_path, if your proto file is not under current working directory, + use this flag to provide a search root. It should work similar to the + counterpart in protoc. This option is valid only when protofiles is + provided. + 3. --metadata specifies metadata to be sent to the server, such as: + --metadata="MyHeaderKey1:Value1:MyHeaderKey2:Value2" + 4. --enable_ssl, whether to use tls. + 5. --use_auth, if set to true, attach a GoogleDefaultCredentials to the call + 6. --infile, input filename (defaults to stdin) + 7. --outfile, output filename (defaults to stdout) + 8. --binary_input, use the serialized request as input. The serialized + request can be generated by calling something like: + protoc --proto_path=src/proto/grpc/testing/ \ + --encode=grpc.testing.SimpleRequest \ + src/proto/grpc/testing/messages.proto \ + < input.txt > input.bin + If this is used and no proto file is provided in the argument list, the + method string has to be exact in the form of /package.service/method. + 9. --binary_output, use binary format response as output, it can + be later decoded using protoc: + protoc --proto_path=src/proto/grpc/testing/ \ + --decode=grpc.testing.SimpleResponse \ + src/proto/grpc/testing/messages.proto \ + < output.bin > output.txt + 10. --default_service_config, optional default service config to use + on the channel. Note that this may be ignored if the name resolver + returns a service config. + 11. --display_peer_address, on CallMethod commands, log the peer socket + address of the connection that each RPC is made on to stderr. +*/ + +#include <fstream> +#include <functional> +#include <iostream> + +#include <gflags/gflags.h> +#include <grpcpp/support/config.h> +#include "test/cpp/util/cli_credentials.h" +#include "test/cpp/util/grpc_tool.h" +#include "test/cpp/util/test_config.h" + +DEFINE_string(outfile, "", "Output file (default is stdout)"); + +static bool SimplePrint(const TString& outfile, const TString& output) { + if (outfile.empty()) { + std::cout << output << std::flush; + } else { + std::ofstream output_file(outfile, std::ios::app | std::ios::binary); + output_file << output << std::flush; + output_file.close(); + } + return true; +} + +int main(int argc, char** argv) { + grpc::testing::InitTest(&argc, &argv, true); + + return grpc::testing::GrpcToolMainLib( + argc, (const char**)argv, grpc::testing::CliCredentials(), + std::bind(SimplePrint, TString(FLAGS_outfile.c_str()), std::placeholders::_1)); +} diff --git a/contrib/libs/grpc/test/cpp/util/grpc_tool.cc b/contrib/libs/grpc/test/cpp/util/grpc_tool.cc new file mode 100644 index 0000000000..30f3024e25 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/grpc_tool.cc @@ -0,0 +1,985 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/grpc_tool.h" + +#include <gflags/gflags.h> +#include <grpc/grpc.h> +#include <grpc/support/port_platform.h> +#include <grpcpp/channel.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/grpcpp.h> +#include <grpcpp/security/credentials.h> +#include <grpcpp/support/string_ref.h> + +#include <cstdio> +#include <fstream> +#include <iostream> +#include <memory> +#include <sstream> +#include <util/generic/string.h> +#include <thread> + +#include "test/cpp/util/cli_call.h" +#include "test/cpp/util/proto_file_parser.h" +#include "test/cpp/util/proto_reflection_descriptor_database.h" +#include "test/cpp/util/service_describer.h" + +#if GPR_WINDOWS +#include <io.h> +#else +#include <unistd.h> +#endif + +namespace grpc { +namespace testing { + +DEFINE_bool(l, false, "Use a long listing format"); +DEFINE_bool(remotedb, true, "Use server types to parse and format messages"); +DEFINE_string(metadata, "", + "Metadata to send to server, in the form of key1:val1:key2:val2"); +DEFINE_string(proto_path, ".", "Path to look for the proto file."); +DEFINE_string(protofiles, "", "Name of the proto file."); +DEFINE_bool(binary_input, false, "Input in binary format"); +DEFINE_bool(binary_output, false, "Output in binary format"); +DEFINE_string( + default_service_config, "", + "Default service config to use on the channel, if non-empty. Note " + "that this will be ignored if the name resolver returns a service " + "config."); +DEFINE_bool( + display_peer_address, false, + "Log the peer socket address of the connection that each RPC is made " + "on to stderr."); +DEFINE_bool(json_input, false, "Input in json format"); +DEFINE_bool(json_output, false, "Output in json format"); +DEFINE_string(infile, "", "Input file (default is stdin)"); +DEFINE_bool(batch, false, + "Input contains multiple requests. Please do not use this to send " + "more than a few RPCs. gRPC CLI has very different performance " + "characteristics compared with normal RPC calls which make it " + "unsuitable for loadtesting or significant production traffic."); +DEFINE_double(timeout, -1, + "Specify timeout in seconds, used to set the deadline for all " + "RPCs. The default value of -1 means no deadline has been set."); + +namespace { + +class GrpcTool { + public: + explicit GrpcTool(); + virtual ~GrpcTool() {} + + bool Help(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + bool CallMethod(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + bool ListServices(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + bool PrintType(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + // TODO(zyc): implement the following methods + // bool ListServices(int argc, const char** argv, GrpcToolOutputCallback + // callback); + // bool PrintTypeId(int argc, const char** argv, GrpcToolOutputCallback + // callback); + bool ParseMessage(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + bool ToText(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + bool ToJson(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + bool ToBinary(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + + void SetPrintCommandMode(int exit_status) { + print_command_usage_ = true; + usage_exit_status_ = exit_status; + } + + private: + void CommandUsage(const TString& usage) const; + bool print_command_usage_; + int usage_exit_status_; + const TString cred_usage_; +}; + +template <typename T> +std::function<bool(GrpcTool*, int, const char**, const CliCredentials&, + GrpcToolOutputCallback)> +BindWith5Args(T&& func) { + return std::bind(std::forward<T>(func), std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, + std::placeholders::_4, std::placeholders::_5); +} + +template <typename T> +size_t ArraySize(T& a) { + return ((sizeof(a) / sizeof(*(a))) / + static_cast<size_t>(!(sizeof(a) % sizeof(*(a))))); +} + +void ParseMetadataFlag( + std::multimap<TString, TString>* client_metadata) { + if (FLAGS_metadata.empty()) { + return; + } + std::vector<TString> fields; + const char delim = ':'; + const char escape = '\\'; + size_t cur = -1; + std::stringstream ss; + while (++cur < FLAGS_metadata.length()) { + switch (FLAGS_metadata.at(cur)) { + case escape: + if (cur < FLAGS_metadata.length() - 1) { + char c = FLAGS_metadata.at(++cur); + if (c == delim || c == escape) { + ss << c; + continue; + } + } + fprintf(stderr, "Failed to parse metadata flag.\n"); + exit(1); + case delim: + fields.push_back(ss.str()); + ss.str(""); + ss.clear(); + break; + default: + ss << FLAGS_metadata.at(cur); + } + } + fields.push_back(ss.str()); + if (fields.size() % 2) { + fprintf(stderr, "Failed to parse metadata flag.\n"); + exit(1); + } + for (size_t i = 0; i < fields.size(); i += 2) { + client_metadata->insert( + std::pair<TString, TString>(fields[i], fields[i + 1])); + } +} + +template <typename T> +void PrintMetadata(const T& m, const TString& message) { + if (m.empty()) { + return; + } + fprintf(stderr, "%s\n", message.c_str()); + TString pair; + for (typename T::const_iterator iter = m.begin(); iter != m.end(); ++iter) { + pair.clear(); + pair.append(iter->first.data(), iter->first.size()); + pair.append(" : "); + pair.append(iter->second.data(), iter->second.size()); + fprintf(stderr, "%s\n", pair.c_str()); + } +} + +void ReadResponse(CliCall* call, const TString& method_name, + GrpcToolOutputCallback callback, ProtoFileParser* parser, + gpr_mu* parser_mu, bool print_mode) { + TString serialized_response_proto; + std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata; + + for (bool receive_initial_metadata = true; call->ReadAndMaybeNotifyWrite( + &serialized_response_proto, + receive_initial_metadata ? &server_initial_metadata : nullptr); + receive_initial_metadata = false) { + fprintf(stderr, "got response.\n"); + if (!FLAGS_binary_output) { + gpr_mu_lock(parser_mu); + serialized_response_proto = parser->GetFormattedStringFromMethod( + method_name, serialized_response_proto, false /* is_request */, + FLAGS_json_output); + if (parser->HasError() && print_mode) { + fprintf(stderr, "Failed to parse response.\n"); + } + gpr_mu_unlock(parser_mu); + } + if (receive_initial_metadata) { + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + } + if (!callback(serialized_response_proto) && print_mode) { + fprintf(stderr, "Failed to output response.\n"); + } + } +} + +std::shared_ptr<grpc::Channel> CreateCliChannel( + const TString& server_address, const CliCredentials& cred) { + grpc::ChannelArguments args; + if (!cred.GetSslTargetNameOverride().empty()) { + args.SetSslTargetNameOverride(cred.GetSslTargetNameOverride()); + } + if (!FLAGS_default_service_config.empty()) { + args.SetString(GRPC_ARG_SERVICE_CONFIG, + FLAGS_default_service_config.c_str()); + } + return ::grpc::CreateCustomChannel(server_address, cred.GetCredentials(), + args); +} + +struct Command { + const char* command; + std::function<bool(GrpcTool*, int, const char**, const CliCredentials&, + GrpcToolOutputCallback)> + function; + int min_args; + int max_args; +}; + +const Command ops[] = { + {"help", BindWith5Args(&GrpcTool::Help), 0, INT_MAX}, + {"ls", BindWith5Args(&GrpcTool::ListServices), 1, 3}, + {"list", BindWith5Args(&GrpcTool::ListServices), 1, 3}, + {"call", BindWith5Args(&GrpcTool::CallMethod), 2, 3}, + {"type", BindWith5Args(&GrpcTool::PrintType), 2, 2}, + {"parse", BindWith5Args(&GrpcTool::ParseMessage), 2, 3}, + {"totext", BindWith5Args(&GrpcTool::ToText), 2, 3}, + {"tobinary", BindWith5Args(&GrpcTool::ToBinary), 2, 3}, + {"tojson", BindWith5Args(&GrpcTool::ToJson), 2, 3}, +}; + +void Usage(const TString& msg) { + fprintf( + stderr, + "%s\n" + " grpc_cli ls ... ; List services\n" + " grpc_cli call ... ; Call method\n" + " grpc_cli type ... ; Print type\n" + " grpc_cli parse ... ; Parse message\n" + " grpc_cli totext ... ; Convert binary message to text\n" + " grpc_cli tojson ... ; Convert binary message to json\n" + " grpc_cli tobinary ... ; Convert text message to binary\n" + " grpc_cli help ... ; Print this message, or per-command usage\n" + "\n", + msg.c_str()); + + exit(1); +} + +const Command* FindCommand(const TString& name) { + for (int i = 0; i < (int)ArraySize(ops); i++) { + if (name == ops[i].command) { + return &ops[i]; + } + } + return nullptr; +} +} // namespace + +int GrpcToolMainLib(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback) { + if (argc < 2) { + Usage("No command specified"); + } + + TString command = argv[1]; + argc -= 2; + argv += 2; + + const Command* cmd = FindCommand(command); + if (cmd != nullptr) { + GrpcTool grpc_tool; + if (argc < cmd->min_args || argc > cmd->max_args) { + // Force the command to print its usage message + fprintf(stderr, "\nWrong number of arguments for %s\n", command.c_str()); + grpc_tool.SetPrintCommandMode(1); + return cmd->function(&grpc_tool, -1, nullptr, cred, callback); + } + const bool ok = cmd->function(&grpc_tool, argc, argv, cred, callback); + return ok ? 0 : 1; + } else { + Usage("Invalid command '" + TString(command.c_str()) + "'"); + } + return 1; +} + +GrpcTool::GrpcTool() : print_command_usage_(false), usage_exit_status_(0) {} + +void GrpcTool::CommandUsage(const TString& usage) const { + if (print_command_usage_) { + fprintf(stderr, "\n%s%s\n", usage.c_str(), + (usage.empty() || usage[usage.size() - 1] != '\n') ? "\n" : ""); + exit(usage_exit_status_); + } +} + +bool GrpcTool::Help(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback) { + CommandUsage( + "Print help\n" + " grpc_cli help [subcommand]\n"); + + if (argc == 0) { + Usage(""); + } else { + const Command* cmd = FindCommand(argv[0]); + if (cmd == nullptr) { + Usage("Unknown command '" + TString(argv[0]) + "'"); + } + SetPrintCommandMode(0); + cmd->function(this, -1, nullptr, cred, callback); + } + return true; +} + +bool GrpcTool::ListServices(int argc, const char** argv, + const CliCredentials& cred, + GrpcToolOutputCallback callback) { + CommandUsage( + "List services\n" + " grpc_cli ls <address> [<service>[/<method>]]\n" + " <address> ; host:port\n" + " <service> ; Exported service name\n" + " <method> ; Method name\n" + " --l ; Use a long listing format\n" + " --outfile ; Output filename (defaults to stdout)\n" + + cred.GetCredentialUsage()); + + TString server_address(argv[0]); + std::shared_ptr<grpc::Channel> channel = + CreateCliChannel(server_address, cred); + grpc::ProtoReflectionDescriptorDatabase desc_db(channel); + grpc::protobuf::DescriptorPool desc_pool(&desc_db); + + std::vector<TString> service_list; + if (!desc_db.GetServices(&service_list)) { + fprintf(stderr, "Received an error when querying services endpoint.\n"); + return false; + } + + // If no service is specified, dump the list of services. + TString output; + if (argc < 2) { + // List all services, if --l is passed, then include full description, + // otherwise include a summarized list only. + if (FLAGS_l) { + output = DescribeServiceList(service_list, desc_pool); + } else { + for (auto it = service_list.begin(); it != service_list.end(); it++) { + auto const& service = *it; + output.append(service); + output.append("\n"); + } + } + } else { + std::string service_name; + std::string method_name; + std::stringstream ss(argv[1]); + + // Remove leading slashes. + while (ss.peek() == '/') { + ss.get(); + } + + // Parse service and method names. Support the following patterns: + // Service + // Service Method + // Service.Method + // Service/Method + if (argc == 3) { + std::getline(ss, service_name, '/'); + method_name = argv[2]; + } else { + if (std::getline(ss, service_name, '/')) { + std::getline(ss, method_name); + } + } + + const grpc::protobuf::ServiceDescriptor* service = + desc_pool.FindServiceByName(google::protobuf::string(service_name)); + if (service != nullptr) { + if (method_name.empty()) { + output = FLAGS_l ? DescribeService(service) : SummarizeService(service); + } else { + method_name.insert(0, "."); + method_name.insert(0, service_name); + const grpc::protobuf::MethodDescriptor* method = + desc_pool.FindMethodByName(google::protobuf::string(method_name)); + if (method != nullptr) { + output = FLAGS_l ? DescribeMethod(method) : SummarizeMethod(method); + } else { + fprintf(stderr, "Method %s not found in service %s.\n", + method_name.c_str(), service_name.c_str()); + return false; + } + } + } else { + if (!method_name.empty()) { + fprintf(stderr, "Service %s not found.\n", service_name.c_str()); + return false; + } else { + const grpc::protobuf::MethodDescriptor* method = + desc_pool.FindMethodByName(google::protobuf::string(service_name)); + if (method != nullptr) { + output = FLAGS_l ? DescribeMethod(method) : SummarizeMethod(method); + } else { + fprintf(stderr, "Service or method %s not found.\n", + service_name.c_str()); + return false; + } + } + } + } + return callback(output); +} + +bool GrpcTool::PrintType(int /*argc*/, const char** argv, + const CliCredentials& cred, + GrpcToolOutputCallback callback) { + CommandUsage( + "Print type\n" + " grpc_cli type <address> <type>\n" + " <address> ; host:port\n" + " <type> ; Protocol buffer type name\n" + + cred.GetCredentialUsage()); + + TString server_address(argv[0]); + std::shared_ptr<grpc::Channel> channel = + CreateCliChannel(server_address, cred); + grpc::ProtoReflectionDescriptorDatabase desc_db(channel); + grpc::protobuf::DescriptorPool desc_pool(&desc_db); + + TString output; + const grpc::protobuf::Descriptor* descriptor = + desc_pool.FindMessageTypeByName(argv[1]); + if (descriptor != nullptr) { + output = descriptor->DebugString(); + } else { + fprintf(stderr, "Type %s not found.\n", argv[1]); + return false; + } + return callback(output); +} + +bool GrpcTool::CallMethod(int argc, const char** argv, + const CliCredentials& cred, + GrpcToolOutputCallback callback) { + CommandUsage( + "Call method\n" + " grpc_cli call <address> <service>[.<method>] <request>\n" + " <address> ; host:port\n" + " <service> ; Exported service name\n" + " <method> ; Method name\n" + " <request> ; Text protobuffer (overrides infile)\n" + " --protofiles ; Comma separated proto files used as a" + " fallback when parsing request/response\n" + " --proto_path ; The search path of proto files, valid" + " only when --protofiles is given\n" + " --noremotedb ; Don't attempt to use reflection service" + " at all\n" + " --metadata ; The metadata to be sent to the server\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n" + " --binary_input ; Input in binary format\n" + " --binary_output ; Output in binary format\n" + " --json_input ; Input in json format\n" + " --json_output ; Output in json format\n" + " --timeout ; Specify timeout (in seconds), used to " + "set the deadline for RPCs. The default value of -1 means no " + "deadline has been set.\n" + + cred.GetCredentialUsage()); + + std::stringstream output_ss; + TString request_text; + TString server_address(argv[0]); + TString method_name(argv[1]); + TString formatted_method_name; + std::unique_ptr<ProtoFileParser> parser; + TString serialized_request_proto; + CliArgs cli_args; + cli_args.timeout = FLAGS_timeout; + bool print_mode = false; + + std::shared_ptr<grpc::Channel> channel = + CreateCliChannel(server_address, cred); + + if (!FLAGS_binary_input || !FLAGS_binary_output) { + parser.reset( + new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr, + FLAGS_proto_path.c_str(), FLAGS_protofiles.c_str())); + if (parser->HasError()) { + fprintf( + stderr, + "Failed to find remote reflection service and local proto files.\n"); + return false; + } + } + + if (FLAGS_binary_input) { + formatted_method_name = method_name; + } else { + formatted_method_name = parser->GetFormattedMethodName(method_name); + if (parser->HasError()) { + fprintf(stderr, "Failed to find method %s in proto files.\n", + method_name.c_str()); + } + } + + if (argc == 3) { + request_text = argv[2]; + } + + if (parser->IsStreaming(method_name, true /* is_request */)) { + std::istream* input_stream; + std::ifstream input_file; + + if (FLAGS_batch) { + fprintf(stderr, "Batch mode for streaming RPC is not supported.\n"); + return false; + } + + std::multimap<TString, TString> client_metadata; + ParseMetadataFlag(&client_metadata); + PrintMetadata(client_metadata, "Sending client initial metadata:"); + + CliCall call(channel, formatted_method_name, client_metadata, cli_args); + if (FLAGS_display_peer_address) { + fprintf(stderr, "New call for method_name:%s has peer address:|%s|\n", + formatted_method_name.c_str(), call.peer().c_str()); + } + + if (FLAGS_infile.empty()) { + if (isatty(fileno(stdin))) { + print_mode = true; + fprintf(stderr, "reading streaming request message from stdin...\n"); + } + input_stream = &std::cin; + } else { + input_file.open(FLAGS_infile, std::ios::in | std::ios::binary); + input_stream = &input_file; + } + + gpr_mu parser_mu; + gpr_mu_init(&parser_mu); + std::thread read_thread(ReadResponse, &call, method_name, callback, + parser.get(), &parser_mu, print_mode); + + std::stringstream request_ss; + std::string line; + while (!request_text.empty() || + (!input_stream->eof() && getline(*input_stream, line))) { + if (!request_text.empty()) { + if (FLAGS_binary_input) { + serialized_request_proto = request_text; + request_text.clear(); + } else { + gpr_mu_lock(&parser_mu); + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */, + FLAGS_json_input); + request_text.clear(); + if (parser->HasError()) { + if (print_mode) { + fprintf(stderr, "Failed to parse request.\n"); + } + gpr_mu_unlock(&parser_mu); + continue; + } + gpr_mu_unlock(&parser_mu); + } + + call.WriteAndWait(serialized_request_proto); + if (print_mode) { + fprintf(stderr, "Request sent.\n"); + } + } else { + if (line.length() == 0) { + request_text = request_ss.str(); + request_ss.str(TString()); + request_ss.clear(); + } else { + request_ss << line << ' '; + } + } + } + if (input_file.is_open()) { + input_file.close(); + } + + call.WritesDoneAndWait(); + read_thread.join(); + gpr_mu_destroy(&parser_mu); + + std::multimap<grpc::string_ref, grpc::string_ref> server_trailing_metadata; + Status status = call.Finish(&server_trailing_metadata); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata from server:"); + + if (status.ok()) { + fprintf(stderr, "Stream RPC succeeded with OK status\n"); + return true; + } else { + fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); + return false; + } + + } else { // parser->IsStreaming(method_name, true /* is_request */) + if (FLAGS_batch) { + if (parser->IsStreaming(method_name, false /* is_request */)) { + fprintf(stderr, "Batch mode for streaming RPC is not supported.\n"); + return false; + } + + std::istream* input_stream; + std::ifstream input_file; + + if (FLAGS_infile.empty()) { + if (isatty(fileno(stdin))) { + print_mode = true; + fprintf(stderr, "reading request messages from stdin...\n"); + } + input_stream = &std::cin; + } else { + input_file.open(FLAGS_infile, std::ios::in | std::ios::binary); + input_stream = &input_file; + } + + std::multimap<TString, TString> client_metadata; + ParseMetadataFlag(&client_metadata); + if (print_mode) { + PrintMetadata(client_metadata, "Sending client initial metadata:"); + } + + std::stringstream request_ss; + std::string line; + while (!request_text.empty() || + (!input_stream->eof() && getline(*input_stream, line))) { + if (!request_text.empty()) { + if (FLAGS_binary_input) { + serialized_request_proto = request_text; + request_text.clear(); + } else { + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */, + FLAGS_json_input); + request_text.clear(); + if (parser->HasError()) { + if (print_mode) { + fprintf(stderr, "Failed to parse request.\n"); + } + continue; + } + } + + TString serialized_response_proto; + std::multimap<grpc::string_ref, grpc::string_ref> + server_initial_metadata, server_trailing_metadata; + CliCall call(channel, formatted_method_name, client_metadata, + cli_args); + if (FLAGS_display_peer_address) { + fprintf(stderr, + "New call for method_name:%s has peer address:|%s|\n", + formatted_method_name.c_str(), call.peer().c_str()); + } + call.Write(serialized_request_proto); + call.WritesDone(); + if (!call.Read(&serialized_response_proto, + &server_initial_metadata)) { + fprintf(stderr, "Failed to read response.\n"); + } + Status status = call.Finish(&server_trailing_metadata); + + if (status.ok()) { + if (print_mode) { + fprintf(stderr, "Rpc succeeded with OK status.\n"); + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata from server:"); + } + + if (FLAGS_binary_output) { + if (!callback(serialized_response_proto)) { + break; + } + } else { + TString response_text = parser->GetFormattedStringFromMethod( + method_name, serialized_response_proto, + false /* is_request */, FLAGS_json_output); + + if (parser->HasError() && print_mode) { + fprintf(stderr, "Failed to parse response.\n"); + } else { + if (!callback(response_text)) { + break; + } + } + } + } else { + if (print_mode) { + fprintf(stderr, + "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); + } + } + } else { + if (line.length() == 0) { + request_text = request_ss.str(); + request_ss.str(TString()); + request_ss.clear(); + } else { + request_ss << line << ' '; + } + } + } + + if (input_file.is_open()) { + input_file.close(); + } + + return true; + } + + if (argc == 3) { + if (!FLAGS_infile.empty()) { + fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); + } + } else { + std::stringstream input_stream; + if (FLAGS_infile.empty()) { + if (isatty(fileno(stdin))) { + fprintf(stderr, "reading request message from stdin...\n"); + } + input_stream << std::cin.rdbuf(); + } else { + std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary); + input_stream << input_file.rdbuf(); + input_file.close(); + } + request_text = input_stream.str(); + } + + if (FLAGS_binary_input) { + serialized_request_proto = request_text; + } else { + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */, FLAGS_json_input); + if (parser->HasError()) { + fprintf(stderr, "Failed to parse request.\n"); + return false; + } + } + fprintf(stderr, "connecting to %s\n", server_address.c_str()); + + TString serialized_response_proto; + std::multimap<TString, TString> client_metadata; + std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata, + server_trailing_metadata; + ParseMetadataFlag(&client_metadata); + PrintMetadata(client_metadata, "Sending client initial metadata:"); + + CliCall call(channel, formatted_method_name, client_metadata, cli_args); + if (FLAGS_display_peer_address) { + fprintf(stderr, "New call for method_name:%s has peer address:|%s|\n", + formatted_method_name.c_str(), call.peer().c_str()); + } + call.Write(serialized_request_proto); + call.WritesDone(); + + for (bool receive_initial_metadata = true; call.Read( + &serialized_response_proto, + receive_initial_metadata ? &server_initial_metadata : nullptr); + receive_initial_metadata = false) { + if (!FLAGS_binary_output) { + serialized_response_proto = parser->GetFormattedStringFromMethod( + method_name, serialized_response_proto, false /* is_request */, + FLAGS_json_output); + if (parser->HasError()) { + fprintf(stderr, "Failed to parse response.\n"); + return false; + } + } + + if (receive_initial_metadata) { + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + } + if (!callback(serialized_response_proto)) { + return false; + } + } + Status status = call.Finish(&server_trailing_metadata); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata from server:"); + if (status.ok()) { + fprintf(stderr, "Rpc succeeded with OK status\n"); + return true; + } else { + fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); + return false; + } + } + GPR_UNREACHABLE_CODE(return false); +} + +bool GrpcTool::ParseMessage(int argc, const char** argv, + const CliCredentials& cred, + GrpcToolOutputCallback callback) { + CommandUsage( + "Parse message\n" + " grpc_cli parse <address> <type> [<message>]\n" + " <address> ; host:port\n" + " <type> ; Protocol buffer type name\n" + " <message> ; Text protobuffer (overrides --infile)\n" + " --protofiles ; Comma separated proto files used as a" + " fallback when parsing request/response\n" + " --proto_path ; The search path of proto files, valid" + " only when --protofiles is given\n" + " --noremotedb ; Don't attempt to use reflection service" + " at all\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n" + " --binary_input ; Input in binary format\n" + " --binary_output ; Output in binary format\n" + " --json_input ; Input in json format\n" + " --json_output ; Output in json format\n" + + cred.GetCredentialUsage()); + + std::stringstream output_ss; + TString message_text; + TString server_address(argv[0]); + TString type_name(argv[1]); + std::unique_ptr<grpc::testing::ProtoFileParser> parser; + TString serialized_request_proto; + + if (argc == 3) { + message_text = argv[2]; + if (!FLAGS_infile.empty()) { + fprintf(stderr, "warning: message given in argv, ignoring --infile.\n"); + } + } else { + std::stringstream input_stream; + if (FLAGS_infile.empty()) { + if (isatty(fileno(stdin))) { + fprintf(stderr, "reading request message from stdin...\n"); + } + input_stream << std::cin.rdbuf(); + } else { + std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary); + input_stream << input_file.rdbuf(); + input_file.close(); + } + message_text = input_stream.str(); + } + + if (!FLAGS_binary_input || !FLAGS_binary_output) { + std::shared_ptr<grpc::Channel> channel = + CreateCliChannel(server_address, cred); + parser.reset( + new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr, + FLAGS_proto_path.c_str(), FLAGS_protofiles.c_str())); + if (parser->HasError()) { + fprintf( + stderr, + "Failed to find remote reflection service and local proto files.\n"); + return false; + } + } + + if (FLAGS_binary_input) { + serialized_request_proto = message_text; + } else { + serialized_request_proto = parser->GetSerializedProtoFromMessageType( + type_name, message_text, FLAGS_json_input); + if (parser->HasError()) { + fprintf(stderr, "Failed to serialize the message.\n"); + return false; + } + } + + if (FLAGS_binary_output) { + output_ss << serialized_request_proto; + } else { + TString output_text; + output_text = parser->GetFormattedStringFromMessageType( + type_name, serialized_request_proto, FLAGS_json_output); + if (parser->HasError()) { + fprintf(stderr, "Failed to deserialize the message.\n"); + return false; + } + + output_ss << output_text << std::endl; + } + + return callback(output_ss.str()); +} + +bool GrpcTool::ToText(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback) { + CommandUsage( + "Convert binary message to text\n" + " grpc_cli totext <protofiles> <type>\n" + " <protofiles> ; Comma separated list of proto files\n" + " <type> ; Protocol buffer type name\n" + " --proto_path ; The search path of proto files\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n"); + + FLAGS_protofiles = argv[0]; + FLAGS_remotedb = false; + FLAGS_binary_input = true; + FLAGS_binary_output = false; + return ParseMessage(argc, argv, cred, callback); +} + +bool GrpcTool::ToJson(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback) { + CommandUsage( + "Convert binary message to json\n" + " grpc_cli tojson <protofiles> <type>\n" + " <protofiles> ; Comma separated list of proto files\n" + " <type> ; Protocol buffer type name\n" + " --proto_path ; The search path of proto files\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n"); + + FLAGS_protofiles = argv[0]; + FLAGS_remotedb = false; + FLAGS_binary_input = true; + FLAGS_binary_output = false; + FLAGS_json_output = true; + return ParseMessage(argc, argv, cred, callback); +} + +bool GrpcTool::ToBinary(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback) { + CommandUsage( + "Convert text message to binary\n" + " grpc_cli tobinary <protofiles> <type> [<message>]\n" + " <protofiles> ; Comma separated list of proto files\n" + " <type> ; Protocol buffer type name\n" + " --proto_path ; The search path of proto files\n" + " --infile ; Input filename (defaults to stdin)\n" + " --outfile ; Output filename (defaults to stdout)\n"); + + FLAGS_protofiles = argv[0]; + FLAGS_remotedb = false; + FLAGS_binary_input = false; + FLAGS_binary_output = true; + return ParseMessage(argc, argv, cred, callback); +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/grpc_tool.h b/contrib/libs/grpc/test/cpp/util/grpc_tool.h new file mode 100644 index 0000000000..5bb43430d3 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/grpc_tool.h @@ -0,0 +1,39 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_GRPC_TOOL_H +#define GRPC_TEST_CPP_UTIL_GRPC_TOOL_H + +#include <functional> + +#include <grpcpp/support/config.h> + +#include "test/cpp/util/cli_credentials.h" + +namespace grpc { +namespace testing { + +typedef std::function<bool(const TString&)> GrpcToolOutputCallback; + +int GrpcToolMainLib(int argc, const char** argv, const CliCredentials& cred, + GrpcToolOutputCallback callback); + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_GRPC_TOOL_H diff --git a/contrib/libs/grpc/test/cpp/util/grpc_tool_test.cc b/contrib/libs/grpc/test/cpp/util/grpc_tool_test.cc new file mode 100644 index 0000000000..ff610daadd --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/grpc_tool_test.cc @@ -0,0 +1,1344 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/grpc_tool.h" + +#include <gflags/gflags.h> +#include <grpc/grpc.h> +#include <grpc/support/alloc.h> +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/ext/proto_server_reflection_plugin.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> +#include <gtest/gtest.h> + +#include <chrono> +#include <sstream> + +#include "src/core/lib/gpr/env.h" +#include "src/core/lib/iomgr/load_file.h" +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "src/proto/grpc/testing/echo.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/util/cli_credentials.h" +#include "test/cpp/util/string_ref_helper.h" + +#define CA_CERT_PATH "src/core/tsi/test_creds/ca.pem" +#define SERVER_CERT_PATH "src/core/tsi/test_creds/server1.pem" +#define SERVER_KEY_PATH "src/core/tsi/test_creds/server1.key" + +using grpc::testing::EchoRequest; +using grpc::testing::EchoResponse; + +#define USAGE_REGEX "( grpc_cli .+\n){2,10}" + +#define ECHO_TEST_SERVICE_SUMMARY \ + "Echo\n" \ + "Echo1\n" \ + "Echo2\n" \ + "CheckDeadlineUpperBound\n" \ + "CheckDeadlineSet\n" \ + "CheckClientInitialMetadata\n" \ + "RequestStream\n" \ + "ResponseStream\n" \ + "BidiStream\n" \ + "Unimplemented\n" + +#define ECHO_TEST_SERVICE_DESCRIPTION \ + "filename: src/proto/grpc/testing/echo.proto\n" \ + "package: grpc.testing;\n" \ + "service EchoTestService {\n" \ + " rpc Echo(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \ + "{}\n" \ + " rpc Echo1(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \ + "{}\n" \ + " rpc Echo2(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \ + "{}\n" \ + " rpc CheckDeadlineUpperBound(grpc.testing.SimpleRequest) returns " \ + "(grpc.testing.StringValue) {}\n" \ + " rpc CheckDeadlineSet(grpc.testing.SimpleRequest) returns " \ + "(grpc.testing.StringValue) {}\n" \ + " rpc CheckClientInitialMetadata(grpc.testing.SimpleRequest) returns " \ + "(grpc.testing.SimpleResponse) {}\n" \ + " rpc RequestStream(stream grpc.testing.EchoRequest) returns " \ + "(grpc.testing.EchoResponse) {}\n" \ + " rpc ResponseStream(grpc.testing.EchoRequest) returns (stream " \ + "grpc.testing.EchoResponse) {}\n" \ + " rpc BidiStream(stream grpc.testing.EchoRequest) returns (stream " \ + "grpc.testing.EchoResponse) {}\n" \ + " rpc Unimplemented(grpc.testing.EchoRequest) returns " \ + "(grpc.testing.EchoResponse) {}\n" \ + "}\n" \ + "\n" + +#define ECHO_METHOD_DESCRIPTION \ + " rpc Echo(grpc.testing.EchoRequest) returns (grpc.testing.EchoResponse) " \ + "{}\n" + +#define ECHO_RESPONSE_MESSAGE_TEXT_FORMAT \ + "message: \"echo\"\n" \ + "param {\n" \ + " host: \"localhost\"\n" \ + " peer: \"peer\"\n" \ + "}\n\n" + +#define ECHO_RESPONSE_MESSAGE_JSON_FORMAT \ + "{\n" \ + " \"message\": \"echo\",\n" \ + " \"param\": {\n" \ + " \"host\": \"localhost\",\n" \ + " \"peer\": \"peer\"\n" \ + " }\n" \ + "}\n\n" + +DECLARE_string(channel_creds_type); +DECLARE_string(ssl_target); + +namespace grpc { +namespace testing { + +DECLARE_bool(binary_input); +DECLARE_bool(binary_output); +DECLARE_bool(json_input); +DECLARE_bool(json_output); +DECLARE_bool(l); +DECLARE_bool(batch); +DECLARE_string(metadata); +DECLARE_string(protofiles); +DECLARE_string(proto_path); +DECLARE_string(default_service_config); +DECLARE_double(timeout); + +namespace { + +const int kServerDefaultResponseStreamsToSend = 3; + +class TestCliCredentials final : public grpc::testing::CliCredentials { + public: + TestCliCredentials(bool secure = false) : secure_(secure) {} + std::shared_ptr<grpc::ChannelCredentials> GetChannelCredentials() + const override { + if (!secure_) { + return InsecureChannelCredentials(); + } + grpc_slice ca_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR("load_file", + grpc_load_file(CA_CERT_PATH, 1, &ca_slice))); + const char* test_root_cert = + reinterpret_cast<const char*> GRPC_SLICE_START_PTR(ca_slice); + SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; + std::shared_ptr<grpc::ChannelCredentials> credential_ptr = + grpc::SslCredentials(grpc::SslCredentialsOptions(ssl_opts)); + grpc_slice_unref(ca_slice); + return credential_ptr; + } + const TString GetCredentialUsage() const override { return ""; } + + private: + const bool secure_; +}; + +bool PrintStream(std::stringstream* ss, const TString& output) { + (*ss) << output; + return true; +} + +template <typename T> +size_t ArraySize(T& a) { + return ((sizeof(a) / sizeof(*(a))) / + static_cast<size_t>(!(sizeof(a) % sizeof(*(a))))); +} + +class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { + public: + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) override { + if (!context->client_metadata().empty()) { + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + response->set_message(request->message()); + return Status::OK; + } + + Status CheckDeadlineSet(ServerContext* context, const SimpleRequest* request, + StringValue* response) override { + response->set_message(context->deadline() != + std::chrono::system_clock::time_point::max() + ? "true" + : "false"); + return Status::OK; + } + + // Check if deadline - current time <= timeout + // If deadline set, timeout + current time should be an upper bound for it + Status CheckDeadlineUpperBound(ServerContext* context, + const SimpleRequest* request, + StringValue* response) override { + auto seconds = std::chrono::duration_cast<std::chrono::seconds>( + context->deadline() - std::chrono::system_clock::now()); + + // Returning string instead of bool to avoid using embedded messages in + // proto3 + response->set_message(seconds.count() <= FLAGS_timeout ? "true" : "false"); + return Status::OK; + } + + Status RequestStream(ServerContext* context, + ServerReader<EchoRequest>* reader, + EchoResponse* response) override { + EchoRequest request; + response->set_message(""); + if (!context->client_metadata().empty()) { + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + while (reader->Read(&request)) { + response->mutable_message()->append(request.message()); + } + + return Status::OK; + } + + Status ResponseStream(ServerContext* context, const EchoRequest* request, + ServerWriter<EchoResponse>* writer) override { + if (!context->client_metadata().empty()) { + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + + EchoResponse response; + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + response.set_message(request->message() + ToString(i)); + writer->Write(response); + } + + return Status::OK; + } + + Status BidiStream( + ServerContext* context, + ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + EchoRequest request; + EchoResponse response; + if (!context->client_metadata().empty()) { + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + + while (stream->Read(&request)) { + response.set_message(request.message()); + stream->Write(response); + } + + return Status::OK; + } +}; + +} // namespace + +class GrpcToolTest : public ::testing::Test { + protected: + GrpcToolTest() {} + + // SetUpServer cannot be used with EXPECT_EXIT. grpc_pick_unused_port_or_die() + // uses atexit() to free chosen ports, and it will spawn a new thread in + // resolve_address_posix.c:192 at exit time. + const TString SetUpServer(bool secure = false) { + std::ostringstream server_address; + int port = grpc_pick_unused_port_or_die(); + server_address << "localhost:" << port; + // Setup server + ServerBuilder builder; + std::shared_ptr<grpc::ServerCredentials> creds; + grpc_slice cert_slice, key_slice; + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_CERT_PATH, 1, &cert_slice))); + GPR_ASSERT(GRPC_LOG_IF_ERROR( + "load_file", grpc_load_file(SERVER_KEY_PATH, 1, &key_slice))); + const char* server_cert = + reinterpret_cast<const char*> GRPC_SLICE_START_PTR(cert_slice); + const char* server_key = + reinterpret_cast<const char*> GRPC_SLICE_START_PTR(key_slice); + SslServerCredentialsOptions::PemKeyCertPair pkcp = {server_key, + server_cert}; + if (secure) { + SslServerCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = ""; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + creds = SslServerCredentials(ssl_opts); + } else { + creds = InsecureServerCredentials(); + } + builder.AddListeningPort(server_address.str(), creds); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + grpc_slice_unref(cert_slice); + grpc_slice_unref(key_slice); + return server_address.str(); + } + + void ShutdownServer() { server_->Shutdown(); } + + std::unique_ptr<Server> server_; + TestServiceImpl service_; + reflection::ProtoServerReflectionPlugin plugin_; +}; + +TEST_F(GrpcToolTest, NoCommand) { + // Test input "grpc_cli" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli"}; + // Exit with 1, print usage instruction in stderr + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), "No command specified\n" USAGE_REGEX); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, InvalidCommand) { + // Test input "grpc_cli" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "abc"}; + // Exit with 1, print usage instruction in stderr + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), "Invalid command 'abc'\n" USAGE_REGEX); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, HelpCommand) { + // Test input "grpc_cli help" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "help"}; + // Exit with 1, print usage instruction in stderr + EXPECT_EXIT(GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1)), + ::testing::ExitedWithCode(1), USAGE_REGEX); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, ListCommand) { + // Test input "grpc_cli list localhost:<port>" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "ls", server_address.c_str()}; + + FLAGS_l = false; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + "grpc.testing.EchoTestService\n" + "grpc.reflection.v1alpha.ServerReflection\n")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ListOneService) { + // Test input "grpc_cli list localhost:<port> grpc.testing.EchoTestService" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "ls", server_address.c_str(), + "grpc.testing.EchoTestService"}; + // without -l flag + FLAGS_l = false; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: ECHO_TEST_SERVICE_SUMMARY + EXPECT_TRUE(0 == + strcmp(output_stream.str().c_str(), ECHO_TEST_SERVICE_SUMMARY)); + + // with -l flag + output_stream.str(TString()); + output_stream.clear(); + FLAGS_l = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: ECHO_TEST_SERVICE_DESCRIPTION + EXPECT_TRUE( + 0 == strcmp(output_stream.str().c_str(), ECHO_TEST_SERVICE_DESCRIPTION)); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, TypeCommand) { + // Test input "grpc_cli type localhost:<port> grpc.testing.EchoRequest" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "type", server_address.c_str(), + "grpc.testing.EchoRequest"}; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + const grpc::protobuf::Descriptor* desc = + grpc::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + "grpc.testing.EchoRequest"); + // Expected output: the DebugString of grpc.testing.EchoRequest + EXPECT_TRUE(0 == + strcmp(output_stream.str().c_str(), desc->DebugString().c_str())); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ListOneMethod) { + // Test input "grpc_cli list localhost:<port> grpc.testing.EchoTestService" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "ls", server_address.c_str(), + "grpc.testing.EchoTestService.Echo"}; + // without -l flag + FLAGS_l = false; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "Echo" + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), "Echo\n")); + + // with -l flag + output_stream.str(TString()); + output_stream.clear(); + FLAGS_l = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: ECHO_METHOD_DESCRIPTION + EXPECT_TRUE(0 == + strcmp(output_stream.str().c_str(), ECHO_METHOD_DESCRIPTION)); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, TypeNotFound) { + // Test input "grpc_cli type localhost:<port> grpc.testing.DummyRequest" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "type", server_address.c_str(), + "grpc.testing.DummyRequest"}; + + EXPECT_TRUE(1 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommand) { + // Test input "grpc_cli call localhost:<port> Echo "message: 'Hello'" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "message: 'Hello'"}; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + + // with json_output + output_stream.str(TString()); + output_stream.clear(); + + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + + // Expected output: + // { + // "message": "Hello" + // } + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello\"\n}")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandJsonInput) { + // Test input "grpc_cli call localhost:<port> Echo "{ \"message\": \"Hello\"}" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "{ \"message\": \"Hello\"}"}; + + FLAGS_json_input = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + + // with json_output + output_stream.str(TString()); + output_stream.clear(); + + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + FLAGS_json_input = false; + + // Expected output: + // { + // "message": "Hello" + // } + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello\"\n}")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBatch) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_batch = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_batch = false; + + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: " + "\"Hello1\"\nmessage: \"Hello2\"\n")); + // with json_output + output_stream.str(TString()); + output_stream.clear(); + ss.clear(); + ss.seekg(0); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_batch = true; + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + FLAGS_batch = false; + + // Expected output: + // { + // "message": "Hello0" + // } + // { + // "message": "Hello1" + // } + // { + // "message": "Hello2" + // } + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello0\"\n}\n" + "{\n \"message\": \"Hello1\"\n}\n" + "{\n \"message\": \"Hello2\"\n}\n")); + + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBatchJsonInput) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "{\"message\": \"Hello0\"}"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss( + "{\"message\": \"Hello1\"}\n\n{\"message\": \"Hello2\" }\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_json_input = true; + FLAGS_batch = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_batch = false; + + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: " + "\"Hello1\"\nmessage: \"Hello2\"\n")); + // with json_output + output_stream.str(TString()); + output_stream.clear(); + ss.clear(); + ss.seekg(0); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_batch = true; + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + FLAGS_batch = false; + FLAGS_json_input = false; + + // Expected output: + // { + // "message": "Hello0" + // } + // { + // "message": "Hello1" + // } + // { + // "message": "Hello2" + // } + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello0\"\n}\n" + "{\n \"message\": \"Hello1\"\n}\n" + "{\n \"message\": \"Hello2\"\n}\n")); + + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBatchWithBadRequest) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "message: 'Hello0'"}; + + // Mock std::cin input "message: 1\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 1\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_batch = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_batch = false; + + // Expected output: "message: "Hello0"\nmessage: "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: \"Hello2\"\n")); + + // with json_output + output_stream.str(TString()); + output_stream.clear(); + ss.clear(); + ss.seekg(0); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_batch = true; + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + FLAGS_batch = false; + + // Expected output: + // { + // "message": "Hello0" + // } + // { + // "message": "Hello2" + // } + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello0\"\n}\n" + "{\n \"message\": \"Hello2\"\n}\n")); + + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBatchJsonInputWithBadRequest) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "{ \"message\": \"Hello0\"}"}; + + // Mock std::cin input "message: 1\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss( + "{ \"message\": 1 }\n\n { \"message\": \"Hello2\" }\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_batch = true; + FLAGS_json_input = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_input = false; + FLAGS_batch = false; + + // Expected output: "message: "Hello0"\nmessage: "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: \"Hello2\"\n")); + + // with json_output + output_stream.str(TString()); + output_stream.clear(); + ss.clear(); + ss.seekg(0); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_batch = true; + FLAGS_json_input = true; + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + FLAGS_json_input = false; + FLAGS_batch = false; + + // Expected output: + // { + // "message": "Hello0" + // } + // { + // "message": "Hello2" + // } + // Expected output: "message: "Hello0"\nmessage: "Hello1"\nmessage: + // "Hello2"\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "{\n \"message\": \"Hello0\"\n}\n" + "{\n \"message\": \"Hello2\"\n}\n")); + + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStream) { + // Test input: grpc_cli call localhost:<port> RequestStream "message: + // 'Hello0'" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0Hello1Hello2\"" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0Hello1Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStreamJsonInput) { + // Test input: grpc_cli call localhost:<port> RequestStream "{ \"message\": + // \"Hello0\"}" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "{ \"message\": \"Hello0\" }"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss( + "{ \"message\": \"Hello1\" }\n\n{ \"message\": \"Hello2\" }\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_json_input = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_input = false; + + // Expected output: "message: \"Hello0Hello1Hello2\"" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0Hello1Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStreamWithBadRequest) { + // Test input: grpc_cli call localhost:<port> RequestStream "message: + // 'Hello0'" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "message: 'Hello0'"}; + + // Mock std::cin input "bad_field: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("bad_field: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0Hello2\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello0Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStreamWithBadRequestJsonInput) { + // Test input: grpc_cli call localhost:<port> RequestStream "message: + // 'Hello0'" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "{ \"message\": \"Hello0\" }"}; + + // Mock std::cin input "bad_field: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss( + "{ \"bad_field\": \"Hello1\" }\n\n{ \"message\": \"Hello2\" }\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + FLAGS_json_input = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_input = false; + + // Expected output: "message: \"Hello0Hello2\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello0Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithTimeoutDeadlineSet) { + // Test input "grpc_cli call CheckDeadlineSet --timeout=5000.25" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "CheckDeadlineSet"}; + + // Set timeout to 5000.25 seconds + FLAGS_timeout = 5000.25; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: "true"", deadline set + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"true\"")); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithTimeoutDeadlineUpperBound) { + // Test input "grpc_cli call CheckDeadlineUpperBound --timeout=900" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "CheckDeadlineUpperBound"}; + + // Set timeout to 900 seconds + FLAGS_timeout = 900; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: "true"" + // deadline not greater than timeout + current time + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"true\"")); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithNegativeTimeoutValue) { + // Test input "grpc_cli call CheckDeadlineSet --timeout=-5" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "CheckDeadlineSet"}; + + // Set timeout to -5 (deadline not set) + FLAGS_timeout = -5; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: "false"", deadline not set + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"false\"")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithDefaultTimeoutValue) { + // Test input "grpc_cli call CheckDeadlineSet --timeout=-1" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "CheckDeadlineSet"}; + + // Set timeout to -1 (default value, deadline not set) + FLAGS_timeout = -1; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: "false"", deadline not set + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"false\"")); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandResponseStream) { + // Test input: grpc_cli call localhost:<port> ResponseStream "message: + // 'Hello'" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "ResponseStream", "message: 'Hello'"}; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello{n}\"" + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + TString expected_response_text = + "message: \"Hello" + ToString(i) + "\"\n"; + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + expected_response_text.c_str())); + } + + // with json_output + output_stream.str(TString()); + output_stream.clear(); + + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + + // Expected output: "{\n \"message\": \"Hello{n}\"\n}\n" + for (int i = 0; i < kServerDefaultResponseStreamsToSend; i++) { + TString expected_response_text = + "{\n \"message\": \"Hello" + ToString(i) + "\"\n}\n"; + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + expected_response_text.c_str())); + } + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBidiStream) { + // Test input: grpc_cli call localhost:<port> BidiStream "message: 'Hello0'" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "BidiStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0\"\nmessage: \"Hello1\"\nmessage: + // \"Hello2\"\n\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: " + "\"Hello1\"\nmessage: \"Hello2\"\n")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBidiStreamWithBadRequest) { + // Test input: grpc_cli call localhost:<port> BidiStream "message: 'Hello0'" + std::stringstream output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "BidiStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 1.0\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0\"\nmessage: \"Hello1\"\nmessage: + // \"Hello2\"\n\n" + EXPECT_TRUE(nullptr != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: \"Hello2\"\n")); + std::cin.rdbuf(orig); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ParseCommand) { + // Test input "grpc_cli parse localhost:<port> grpc.testing.EchoResponse + // ECHO_RESPONSE_MESSAGE" + std::stringstream output_stream; + std::stringstream binary_output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "parse", server_address.c_str(), + "grpc.testing.EchoResponse", + ECHO_RESPONSE_MESSAGE_TEXT_FORMAT}; + + FLAGS_binary_input = false; + FLAGS_binary_output = false; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: ECHO_RESPONSE_MESSAGE_TEXT_FORMAT + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_TEXT_FORMAT)); + + // with json_output + output_stream.str(TString()); + output_stream.clear(); + + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + + // Expected output: ECHO_RESPONSE_MESSAGE_JSON_FORMAT + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_JSON_FORMAT)); + + // Parse text message to binary message and then parse it back to text message + output_stream.str(TString()); + output_stream.clear(); + FLAGS_binary_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + TString binary_data = output_stream.str(); + output_stream.str(TString()); + output_stream.clear(); + argv[4] = binary_data.c_str(); + FLAGS_binary_input = true; + FLAGS_binary_output = false; + EXPECT_TRUE(0 == GrpcToolMainLib(5, argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: ECHO_RESPONSE_MESSAGE + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_TEXT_FORMAT)); + + FLAGS_binary_input = false; + FLAGS_binary_output = false; + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ParseCommandJsonFormat) { + // Test input "grpc_cli parse localhost:<port> grpc.testing.EchoResponse + // ECHO_RESPONSE_MESSAGE_JSON_FORMAT" + std::stringstream output_stream; + std::stringstream binary_output_stream; + + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "parse", server_address.c_str(), + "grpc.testing.EchoResponse", + ECHO_RESPONSE_MESSAGE_JSON_FORMAT}; + + FLAGS_json_input = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: ECHO_RESPONSE_MESSAGE_TEXT_FORMAT + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_TEXT_FORMAT)); + + // with json_output + output_stream.str(TString()); + output_stream.clear(); + + FLAGS_json_output = true; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_json_output = false; + FLAGS_json_input = false; + + // Expected output: ECHO_RESPONSE_MESSAGE_JSON_FORMAT + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + ECHO_RESPONSE_MESSAGE_JSON_FORMAT)); + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, TooFewArguments) { + // Test input "grpc_cli call Echo" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "call", "Echo"}; + + // Exit with 1 + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), ".*Wrong number of arguments for call.*"); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, TooManyArguments) { + // Test input "grpc_cli call localhost:<port> Echo Echo "message: 'Hello'" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "call", "localhost:10000", + "Echo", "Echo", "message: 'Hello'"}; + + // Exit with 1 + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), ".*Wrong number of arguments for call.*"); + // No output + EXPECT_TRUE(0 == output_stream.tellp()); +} + +TEST_F(GrpcToolTest, CallCommandWithMetadata) { + // Test input "grpc_cli call localhost:<port> Echo "message: 'Hello'" + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), "Echo", + "message: 'Hello'"}; + + { + std::stringstream output_stream; + FLAGS_metadata = "key0:val0:key1:valq:key2:val2"; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, + TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + } + + { + std::stringstream output_stream; + FLAGS_metadata = "key:val\\:val"; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, + TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + } + + { + std::stringstream output_stream; + FLAGS_metadata = "key:val\\\\val"; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, + TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + // Expected output: "message: \"Hello\"" + EXPECT_TRUE(nullptr != + strstr(output_stream.str().c_str(), "message: \"Hello\"")); + } + + FLAGS_metadata = ""; + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandWithBadMetadata) { + // Test input "grpc_cli call localhost:10000 Echo "message: 'Hello'" + const char* argv[] = {"grpc_cli", "call", "localhost:10000", + "grpc.testing.EchoTestService.Echo", + "message: 'Hello'"}; + FLAGS_protofiles = "src/proto/grpc/testing/echo.proto"; + char* test_srcdir = gpr_getenv("TEST_SRCDIR"); + if (test_srcdir != nullptr) { + FLAGS_proto_path = test_srcdir + TString("/com_github_grpc_grpc"); + } + + { + std::stringstream output_stream; + FLAGS_metadata = "key0:val0:key1"; + // Exit with 1 + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), ".*Failed to parse metadata flag.*"); + } + + { + std::stringstream output_stream; + FLAGS_metadata = "key:val\\val"; + // Exit with 1 + EXPECT_EXIT( + GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, std::placeholders::_1)), + ::testing::ExitedWithCode(1), ".*Failed to parse metadata flag.*"); + } + + FLAGS_metadata = ""; + FLAGS_protofiles = ""; + + gpr_free(test_srcdir); +} + +TEST_F(GrpcToolTest, ListCommand_OverrideSslHostName) { + const TString server_address = SetUpServer(true); + + // Test input "grpc_cli ls localhost:<port> --channel_creds_type=ssl + // --ssl_target=z.test.google.fr" + std::stringstream output_stream; + const char* argv[] = {"grpc_cli", "ls", server_address.c_str()}; + FLAGS_l = false; + FLAGS_channel_creds_type = "ssl"; + FLAGS_ssl_target = "z.test.google.fr"; + EXPECT_TRUE( + 0 == GrpcToolMainLib( + ArraySize(argv), argv, TestCliCredentials(true), + std::bind(PrintStream, &output_stream, std::placeholders::_1))); + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + "grpc.testing.EchoTestService\n" + "grpc.reflection.v1alpha.ServerReflection\n")); + + FLAGS_channel_creds_type = ""; + FLAGS_ssl_target = ""; + ShutdownServer(); +} + +TEST_F(GrpcToolTest, ConfiguringDefaultServiceConfig) { + // Test input "grpc_cli list localhost:<port> + // --default_service_config={\"loadBalancingConfig\":[{\"pick_first\":{}}]}" + std::stringstream output_stream; + const TString server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "ls", server_address.c_str()}; + // Just check that the tool is still operational when --default_service_config + // is configured. This particular service config is in reality redundant with + // the channel's default configuration. + FLAGS_l = false; + FLAGS_default_service_config = + "{\"loadBalancingConfig\":[{\"pick_first\":{}}]}"; + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + FLAGS_default_service_config = ""; + EXPECT_TRUE(0 == strcmp(output_stream.str().c_str(), + "grpc.testing.EchoTestService\n" + "grpc.reflection.v1alpha.ServerReflection\n")); + ShutdownServer(); +} + +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/util/metrics_server.cc b/contrib/libs/grpc/test/cpp/util/metrics_server.cc new file mode 100644 index 0000000000..0493da053e --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/metrics_server.cc @@ -0,0 +1,117 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *is % allowed in string + */ + +#include "test/cpp/util/metrics_server.h" + +#include <grpc/support/log.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> + +#include "src/proto/grpc/testing/metrics.grpc.pb.h" +#include "src/proto/grpc/testing/metrics.pb.h" + +namespace grpc { +namespace testing { + +QpsGauge::QpsGauge() + : start_time_(gpr_now(GPR_CLOCK_REALTIME)), num_queries_(0) {} + +void QpsGauge::Reset() { + std::lock_guard<std::mutex> lock(num_queries_mu_); + num_queries_ = 0; + start_time_ = gpr_now(GPR_CLOCK_REALTIME); +} + +void QpsGauge::Incr() { + std::lock_guard<std::mutex> lock(num_queries_mu_); + num_queries_++; +} + +long QpsGauge::Get() { + std::lock_guard<std::mutex> lock(num_queries_mu_); + gpr_timespec time_diff = + gpr_time_sub(gpr_now(GPR_CLOCK_REALTIME), start_time_); + long duration_secs = time_diff.tv_sec > 0 ? time_diff.tv_sec : 1; + return num_queries_ / duration_secs; +} + +grpc::Status MetricsServiceImpl::GetAllGauges( + ServerContext* /*context*/, const EmptyMessage* /*request*/, + ServerWriter<GaugeResponse>* writer) { + gpr_log(GPR_DEBUG, "GetAllGauges called"); + + std::lock_guard<std::mutex> lock(mu_); + for (auto it = qps_gauges_.begin(); it != qps_gauges_.end(); it++) { + GaugeResponse resp; + resp.set_name(it->first); // Gauge name + resp.set_long_value(it->second->Get()); // Gauge value + writer->Write(resp); + } + + return Status::OK; +} + +grpc::Status MetricsServiceImpl::GetGauge(ServerContext* /*context*/, + const GaugeRequest* request, + GaugeResponse* response) { + std::lock_guard<std::mutex> lock(mu_); + + const auto it = qps_gauges_.find(request->name()); + if (it != qps_gauges_.end()) { + response->set_name(it->first); + response->set_long_value(it->second->Get()); + } + + return Status::OK; +} + +std::shared_ptr<QpsGauge> MetricsServiceImpl::CreateQpsGauge( + const TString& name, bool* already_present) { + std::lock_guard<std::mutex> lock(mu_); + + std::shared_ptr<QpsGauge> qps_gauge(new QpsGauge()); + const auto p = qps_gauges_.insert(std::make_pair(name, qps_gauge)); + + // p.first is an iterator pointing to <name, shared_ptr<QpsGauge>> pair. + // p.second is a boolean which is set to 'true' if the QpsGauge is + // successfully inserted in the guages_ map and 'false' if it is already + // present in the map + *already_present = !p.second; + return p.first->second; +} + +// Starts the metrics server and returns the grpc::Server instance. Call Wait() +// on the returned server instance. +std::unique_ptr<grpc::Server> MetricsServiceImpl::StartServer(int port) { + gpr_log(GPR_INFO, "Building metrics server.."); + + const TString address = "0.0.0.0:" + ToString(port); + + ServerBuilder builder; + builder.AddListeningPort(address, grpc::InsecureServerCredentials()); + builder.RegisterService(this); + + std::unique_ptr<grpc::Server> server(builder.BuildAndStart()); + gpr_log(GPR_INFO, "Metrics server %s started. Ready to receive requests..", + address.c_str()); + + return server; +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/metrics_server.h b/contrib/libs/grpc/test/cpp/util/metrics_server.h new file mode 100644 index 0000000000..10ffa7b4dd --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/metrics_server.h @@ -0,0 +1,98 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *is % allowed in string + */ +#ifndef GRPC_TEST_CPP_METRICS_SERVER_H +#define GRPC_TEST_CPP_METRICS_SERVER_H + +#include <map> +#include <mutex> + +#include <grpcpp/server.h> + +#include "src/proto/grpc/testing/metrics.grpc.pb.h" +#include "src/proto/grpc/testing/metrics.pb.h" + +/* + * This implements a Metrics server defined in + * src/proto/grpc/testing/metrics.proto. Any + * test service can use this to export Metrics (TODO (sreek): Only Gauges for + * now). + * + * Example: + * MetricsServiceImpl metricsImpl; + * .. + * // Create QpsGauge(s). Note: QpsGauges can be created even after calling + * // 'StartServer'. + * QpsGauge qps_gauge1 = metricsImpl.CreateQpsGauge("foo", is_present); + * // qps_gauge1 can now be used anywhere in the program by first making a + * // one-time call qps_gauge1.Reset() and then calling qps_gauge1.Incr() + * // every time to increment a query counter + * + * ... + * // Create the metrics server + * std::unique_ptr<grpc::Server> server = metricsImpl.StartServer(port); + * server->Wait(); // Note: This is blocking. + */ +namespace grpc { +namespace testing { + +class QpsGauge { + public: + QpsGauge(); + + // Initialize the internal timer and reset the query count to 0 + void Reset(); + + // Increment the query count by 1 + void Incr(); + + // Return the current qps (i.e query count divided by the time since this + // QpsGauge object created (or Reset() was called)) + long Get(); + + private: + gpr_timespec start_time_; + long num_queries_; + std::mutex num_queries_mu_; +}; + +class MetricsServiceImpl final : public MetricsService::Service { + public: + grpc::Status GetAllGauges(ServerContext* context, const EmptyMessage* request, + ServerWriter<GaugeResponse>* writer) override; + + grpc::Status GetGauge(ServerContext* context, const GaugeRequest* request, + GaugeResponse* response) override; + + // Create a QpsGauge with name 'name'. is_present is set to true if the Gauge + // is already present in the map. + // NOTE: CreateQpsGauge can be called anytime (i.e before or after calling + // StartServer). + std::shared_ptr<QpsGauge> CreateQpsGauge(const TString& name, + bool* already_present); + + std::unique_ptr<grpc::Server> StartServer(int port); + + private: + std::map<string, std::shared_ptr<QpsGauge>> qps_gauges_; + std::mutex mu_; +}; + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_METRICS_SERVER_H diff --git a/contrib/libs/grpc/test/cpp/util/proto_file_parser.cc b/contrib/libs/grpc/test/cpp/util/proto_file_parser.cc new file mode 100644 index 0000000000..b0912a712c --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/proto_file_parser.cc @@ -0,0 +1,323 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/proto_file_parser.h" + +#include <algorithm> +#include <iostream> +#include <sstream> +#include <unordered_set> + +#include <grpcpp/support/config.h> + +namespace grpc { +namespace testing { +namespace { + +// Match the user input method string to the full_name from method descriptor. +bool MethodNameMatch(const TString& full_name, const TString& input) { + TString clean_input = input; + std::replace(clean_input.begin(), clean_input.vend(), '/', '.'); + if (clean_input.size() > full_name.size()) { + return false; + } + return full_name.compare(full_name.size() - clean_input.size(), + clean_input.size(), clean_input) == 0; +} +} // namespace + +class ErrorPrinter : public protobuf::compiler::MultiFileErrorCollector { + public: + explicit ErrorPrinter(ProtoFileParser* parser) : parser_(parser) {} + + void AddError(const google::protobuf::string& filename, int line, int column, + const google::protobuf::string& message) override { + std::ostringstream oss; + oss << "error " << filename << " " << line << " " << column << " " + << message << "\n"; + parser_->LogError(oss.str()); + } + + void AddWarning(const google::protobuf::string& filename, int line, int column, + const google::protobuf::string& message) override { + std::cerr << "warning " << filename << " " << line << " " << column << " " + << message << std::endl; + } + + private: + ProtoFileParser* parser_; // not owned +}; + +ProtoFileParser::ProtoFileParser(const std::shared_ptr<grpc::Channel>& channel, + const TString& proto_path, + const TString& protofiles) + : has_error_(false), + dynamic_factory_(new protobuf::DynamicMessageFactory()) { + std::vector<TString> service_list; + if (channel) { + reflection_db_.reset(new grpc::ProtoReflectionDescriptorDatabase(channel)); + reflection_db_->GetServices(&service_list); + } + + std::unordered_set<TString> known_services; + if (!protofiles.empty()) { + source_tree_.MapPath("", google::protobuf::string(proto_path)); + error_printer_.reset(new ErrorPrinter(this)); + importer_.reset( + new protobuf::compiler::Importer(&source_tree_, error_printer_.get())); + + std::string file_name; + std::stringstream ss(protofiles); + while (std::getline(ss, file_name, ',')) { + const auto* file_desc = importer_->Import(google::protobuf::string(file_name.c_str())); + if (file_desc) { + for (int i = 0; i < file_desc->service_count(); i++) { + service_desc_list_.push_back(file_desc->service(i)); + known_services.insert(file_desc->service(i)->full_name()); + } + } else { + std::cerr << file_name << " not found" << std::endl; + } + } + + file_db_.reset(new protobuf::DescriptorPoolDatabase(*importer_->pool())); + } + + if (!reflection_db_ && !file_db_) { + LogError("No available proto database"); + return; + } + + if (!reflection_db_) { + desc_db_ = std::move(file_db_); + } else if (!file_db_) { + desc_db_ = std::move(reflection_db_); + } else { + desc_db_.reset(new protobuf::MergedDescriptorDatabase(reflection_db_.get(), + file_db_.get())); + } + + desc_pool_.reset(new protobuf::DescriptorPool(desc_db_.get())); + + for (auto it = service_list.begin(); it != service_list.end(); it++) { + if (known_services.find(*it) == known_services.end()) { + if (const protobuf::ServiceDescriptor* service_desc = + desc_pool_->FindServiceByName(google::protobuf::string(*it))) { + service_desc_list_.push_back(service_desc); + known_services.insert(*it); + } + } + } +} + +ProtoFileParser::~ProtoFileParser() {} + +TString ProtoFileParser::GetFullMethodName(const TString& method) { + has_error_ = false; + + if (known_methods_.find(method) != known_methods_.end()) { + return known_methods_[method]; + } + + const protobuf::MethodDescriptor* method_descriptor = nullptr; + for (auto it = service_desc_list_.begin(); it != service_desc_list_.end(); + it++) { + const auto* service_desc = *it; + for (int j = 0; j < service_desc->method_count(); j++) { + const auto* method_desc = service_desc->method(j); + if (MethodNameMatch(method_desc->full_name(), method)) { + if (method_descriptor) { + std::ostringstream error_stream; + error_stream << "Ambiguous method names: "; + error_stream << method_descriptor->full_name() << " "; + error_stream << method_desc->full_name(); + LogError(error_stream.str()); + } + method_descriptor = method_desc; + } + } + } + if (!method_descriptor) { + LogError("Method name not found"); + } + if (has_error_) { + return ""; + } + + known_methods_[method] = method_descriptor->full_name(); + + return method_descriptor->full_name(); +} + +TString ProtoFileParser::GetFormattedMethodName(const TString& method) { + has_error_ = false; + TString formatted_method_name = GetFullMethodName(method); + if (has_error_) { + return ""; + } + size_t last_dot = formatted_method_name.find_last_of('.'); + if (last_dot != TString::npos) { + formatted_method_name[last_dot] = '/'; + } + formatted_method_name.insert(formatted_method_name.begin(), '/'); + return formatted_method_name; +} + +TString ProtoFileParser::GetMessageTypeFromMethod(const TString& method, + bool is_request) { + has_error_ = false; + TString full_method_name = GetFullMethodName(method); + if (has_error_) { + return ""; + } + const protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(google::protobuf::string(full_method_name)); + if (!method_desc) { + LogError("Method not found"); + return ""; + } + + return is_request ? method_desc->input_type()->full_name() + : method_desc->output_type()->full_name(); +} + +bool ProtoFileParser::IsStreaming(const TString& method, bool is_request) { + has_error_ = false; + + TString full_method_name = GetFullMethodName(method); + if (has_error_) { + return false; + } + + const protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(google::protobuf::string(full_method_name)); + if (!method_desc) { + LogError("Method not found"); + return false; + } + + return is_request ? method_desc->client_streaming() + : method_desc->server_streaming(); +} + +TString ProtoFileParser::GetSerializedProtoFromMethod( + const TString& method, const TString& formatted_proto, + bool is_request, bool is_json_format) { + has_error_ = false; + TString message_type_name = GetMessageTypeFromMethod(method, is_request); + if (has_error_) { + return ""; + } + return GetSerializedProtoFromMessageType(message_type_name, formatted_proto, + is_json_format); +} + +TString ProtoFileParser::GetFormattedStringFromMethod( + const TString& method, const TString& serialized_proto, + bool is_request, bool is_json_format) { + has_error_ = false; + TString message_type_name = GetMessageTypeFromMethod(method, is_request); + if (has_error_) { + return ""; + } + return GetFormattedStringFromMessageType(message_type_name, serialized_proto, + is_json_format); +} + +TString ProtoFileParser::GetSerializedProtoFromMessageType( + const TString& message_type_name, const TString& formatted_proto, + bool is_json_format) { + has_error_ = false; + google::protobuf::string serialized; + const protobuf::Descriptor* desc = + desc_pool_->FindMessageTypeByName(google::protobuf::string(message_type_name)); + if (!desc) { + LogError("Message type not found"); + return ""; + } + std::unique_ptr<grpc::protobuf::Message> msg( + dynamic_factory_->GetPrototype(desc)->New()); + bool ok; + if (is_json_format) { + ok = grpc::protobuf::json::JsonStringToMessage(google::protobuf::string(formatted_proto), msg.get()) + .ok(); + if (!ok) { + LogError("Failed to convert json format to proto."); + return ""; + } + } else { + ok = protobuf::TextFormat::ParseFromString(google::protobuf::string(formatted_proto), msg.get()); + if (!ok) { + LogError("Failed to convert text format to proto."); + return ""; + } + } + + ok = msg->SerializeToString(&serialized); + if (!ok) { + LogError("Failed to serialize proto."); + return ""; + } + return serialized; +} + +TString ProtoFileParser::GetFormattedStringFromMessageType( + const TString& message_type_name, const TString& serialized_proto, + bool is_json_format) { + has_error_ = false; + const protobuf::Descriptor* desc = + desc_pool_->FindMessageTypeByName(google::protobuf::string(message_type_name)); + if (!desc) { + LogError("Message type not found"); + return ""; + } + std::unique_ptr<grpc::protobuf::Message> msg( + dynamic_factory_->GetPrototype(desc)->New()); + if (!msg->ParseFromString(google::protobuf::string(serialized_proto))) { + LogError("Failed to deserialize proto."); + return ""; + } + google::protobuf::string formatted_string; + + if (is_json_format) { + grpc::protobuf::json::JsonPrintOptions jsonPrintOptions; + jsonPrintOptions.add_whitespace = true; + if (!grpc::protobuf::json::MessageToJsonString( + *msg.get(), &formatted_string, jsonPrintOptions) + .ok()) { + LogError("Failed to print proto message to json format"); + return ""; + } + } else { + if (!protobuf::TextFormat::PrintToString(*msg.get(), &formatted_string)) { + LogError("Failed to print proto message to text format"); + return ""; + } + } + return formatted_string; +} + +void ProtoFileParser::LogError(const TString& error_msg) { + if (!error_msg.empty()) { + std::cerr << error_msg << std::endl; + } + has_error_ = true; +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/proto_file_parser.h b/contrib/libs/grpc/test/cpp/util/proto_file_parser.h new file mode 100644 index 0000000000..c0445641c7 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/proto_file_parser.h @@ -0,0 +1,129 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_PROTO_FILE_PARSER_H +#define GRPC_TEST_CPP_UTIL_PROTO_FILE_PARSER_H + +#include <memory> + +#include <grpcpp/channel.h> + +#include "test/cpp/util/config_grpc_cli.h" +#include "test/cpp/util/proto_reflection_descriptor_database.h" + +namespace grpc { +namespace testing { +class ErrorPrinter; + +// Find method and associated request/response types. +class ProtoFileParser { + public: + // The parser will search proto files using the server reflection service + // provided on the given channel. The given protofiles in a source tree rooted + // from proto_path will also be searched. + ProtoFileParser(const std::shared_ptr<grpc::Channel>& channel, + const TString& proto_path, const TString& protofiles); + + ~ProtoFileParser(); + + // The input method name in the following four functions could be a partial + // string such as Service.Method or even just Method. It will log an error if + // there is ambiguity. + // Full method name is in the form of Service.Method, it's good to be used in + // descriptor database queries. + TString GetFullMethodName(const TString& method); + + // Formatted method name is in the form of /Service/Method, it's good to be + // used as the argument of Stub::Call() + TString GetFormattedMethodName(const TString& method); + + /// Converts a text or json string to its binary proto representation for the + /// given method's input or return type. + /// \param method the name of the method (does not need to be fully qualified + /// name) + /// \param formatted_proto the text- or json-formatted proto string + /// \param is_request if \c true the resolved type is that of the input + /// parameter of the method, otherwise it is the output type + /// \param is_json_format if \c true the \c formatted_proto is treated as a + /// json-formatted proto, otherwise it is treated as a text-formatted + /// proto + /// \return the serialised binary proto representation of \c formatted_proto + TString GetSerializedProtoFromMethod(const TString& method, + const TString& formatted_proto, + bool is_request, + bool is_json_format); + + /// Converts a text or json string to its proto representation for the given + /// message type. + /// \param formatted_proto the text- or json-formatted proto string + /// \return the serialised binary proto representation of \c formatted_proto + TString GetSerializedProtoFromMessageType( + const TString& message_type_name, const TString& formatted_proto, + bool is_json_format); + + /// Converts a binary proto string to its text or json string representation + /// for the given method's input or return type. + /// \param method the name of the method (does not need to be a fully + /// qualified name) + /// \param the serialised binary proto representation of type + /// \c message_type_name + /// \return the text- or json-formatted proto string of \c serialized_proto + TString GetFormattedStringFromMethod(const TString& method, + const TString& serialized_proto, + bool is_request, + bool is_json_format); + + /// Converts a binary proto string to its text or json string representation + /// for the given message type. + /// \param the serialised binary proto representation of type + /// \c message_type_name + /// \return the text- or json-formatted proto string of \c serialized_proto + TString GetFormattedStringFromMessageType( + const TString& message_type_name, const TString& serialized_proto, + bool is_json_format); + + bool IsStreaming(const TString& method, bool is_request); + + bool HasError() const { return has_error_; } + + void LogError(const TString& error_msg); + + private: + TString GetMessageTypeFromMethod(const TString& method, + bool is_request); + + bool has_error_; + TString request_text_; + protobuf::compiler::DiskSourceTree source_tree_; + std::unique_ptr<ErrorPrinter> error_printer_; + std::unique_ptr<protobuf::compiler::Importer> importer_; + std::unique_ptr<grpc::ProtoReflectionDescriptorDatabase> reflection_db_; + std::unique_ptr<protobuf::DescriptorPoolDatabase> file_db_; + std::unique_ptr<protobuf::DescriptorDatabase> desc_db_; + std::unique_ptr<protobuf::DescriptorPool> desc_pool_; + std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_; + std::unique_ptr<grpc::protobuf::Message> request_prototype_; + std::unique_ptr<grpc::protobuf::Message> response_prototype_; + std::unordered_map<TString, TString> known_methods_; + std::vector<const protobuf::ServiceDescriptor*> service_desc_list_; +}; + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_PROTO_FILE_PARSER_H diff --git a/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.cc b/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.cc new file mode 100644 index 0000000000..27a4c1e4cf --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.cc @@ -0,0 +1,333 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/proto_reflection_descriptor_database.h" + +#include <vector> + +#include <grpc/support/log.h> + +using grpc::reflection::v1alpha::ErrorResponse; +using grpc::reflection::v1alpha::ListServiceResponse; +using grpc::reflection::v1alpha::ServerReflection; +using grpc::reflection::v1alpha::ServerReflectionRequest; +using grpc::reflection::v1alpha::ServerReflectionResponse; + +namespace grpc { + +ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase( + std::unique_ptr<ServerReflection::Stub> stub) + : stub_(std::move(stub)) {} + +ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase( + const std::shared_ptr<grpc::Channel>& channel) + : stub_(ServerReflection::NewStub(channel)) {} + +ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() { + if (stream_) { + stream_->WritesDone(); + Status status = stream_->Finish(); + if (!status.ok()) { + if (status.error_code() == StatusCode::UNIMPLEMENTED) { + fprintf(stderr, + "Reflection request not implemented; " + "is the ServerReflection service enabled?\n"); + } else { + fprintf(stderr, + "ServerReflectionInfo rpc failed. Error code: %d, message: %s, " + "debug info: %s\n", + static_cast<int>(status.error_code()), + status.error_message().c_str(), + ctx_.debug_error_string().c_str()); + } + } + } +} + +bool ProtoReflectionDescriptorDatabase::FindFileByName( + const google::protobuf::string& filename, protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileByName(filename, output)) { + return true; + } + + if (known_files_.find(filename) != known_files_.end()) { + return false; + } + + ServerReflectionRequest request; + request.set_file_by_filename(filename); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + gpr_log(GPR_INFO, "NOT_FOUND from server for FindFileByName(%s)", + filename.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindFileByName(%s)\n\tError code: %d\n" + "\tError Message: %s", + filename.c_str(), error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileByName(%s) response type\n" + "\tExpecting: %d\n\tReceived: %d", + filename.c_str(), + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + + return cached_db_.FindFileByName(filename, output); +} + +bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol( + const google::protobuf::string& symbol_name, protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileContainingSymbol(symbol_name, output)) { + return true; + } + + if (missing_symbols_.find(symbol_name) != missing_symbols_.end()) { + return false; + } + + ServerReflectionRequest request; + request.set_file_containing_symbol(symbol_name); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + missing_symbols_.insert(symbol_name); + gpr_log(GPR_INFO, + "NOT_FOUND from server for FindFileContainingSymbol(%s)", + symbol_name.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindFileContainingSymbol(%s)\n" + "\tError code: %d\n\tError Message: %s", + symbol_name.c_str(), error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileContainingSymbol(%s) response type\n" + "\tExpecting: %d\n\tReceived: %d", + symbol_name.c_str(), + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + return cached_db_.FindFileContainingSymbol(symbol_name, output); +} + +bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension( + const google::protobuf::string& containing_type, int field_number, + protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileContainingExtension(containing_type, field_number, + output)) { + return true; + } + + if (missing_extensions_.find(containing_type) != missing_extensions_.end() && + missing_extensions_[containing_type].find(field_number) != + missing_extensions_[containing_type].end()) { + gpr_log(GPR_INFO, "nested map."); + return false; + } + + ServerReflectionRequest request; + request.mutable_file_containing_extension()->set_containing_type( + containing_type); + request.mutable_file_containing_extension()->set_extension_number( + field_number); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + if (missing_extensions_.find(containing_type) == + missing_extensions_.end()) { + missing_extensions_[containing_type] = {}; + } + missing_extensions_[containing_type].insert(field_number); + gpr_log(GPR_INFO, + "NOT_FOUND from server for FindFileContainingExtension(%s, %d)", + containing_type.c_str(), field_number); + } else { + gpr_log(GPR_INFO, + "Error on FindFileContainingExtension(%s, %d)\n" + "\tError code: %d\n\tError Message: %s", + containing_type.c_str(), field_number, error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileContainingExtension(%s, %d) response type\n" + "\tExpecting: %d\n\tReceived: %d", + containing_type.c_str(), field_number, + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + + return cached_db_.FindFileContainingExtension(containing_type, field_number, + output); +} + +bool ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers( + const google::protobuf::string& extendee_type, std::vector<int>* output) { + if (cached_extension_numbers_.find(extendee_type) != + cached_extension_numbers_.end()) { + *output = cached_extension_numbers_[extendee_type]; + return true; + } + + ServerReflectionRequest request; + request.set_all_extension_numbers_of_type(extendee_type); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase:: + kAllExtensionNumbersResponse) { + auto number = response.all_extension_numbers_response().extension_number(); + *output = std::vector<int>(number.begin(), number.end()); + cached_extension_numbers_[extendee_type] = *output; + return true; + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + gpr_log(GPR_INFO, "NOT_FOUND from server for FindAllExtensionNumbers(%s)", + extendee_type.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindAllExtensionNumbersExtension(%s)\n" + "\tError code: %d\n\tError Message: %s", + extendee_type.c_str(), error.error_code(), + error.error_message().c_str()); + } + } + return false; +} + +bool ProtoReflectionDescriptorDatabase::GetServices( + std::vector<TString>* output) { + ServerReflectionRequest request; + request.set_list_services(""); + ServerReflectionResponse response; + + if (!DoOneRequest(request, response)) { + return false; + } + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kListServicesResponse) { + const ListServiceResponse& ls_response = response.list_services_response(); + for (int i = 0; i < ls_response.service_size(); ++i) { + (*output).push_back(ls_response.service(i).name()); + } + return true; + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse& error = response.error_response(); + gpr_log(GPR_INFO, + "Error on GetServices()\n\tError code: %d\n" + "\tError Message: %s", + error.error_code(), error.error_message().c_str()); + } else { + gpr_log( + GPR_INFO, + "Error on GetServices() response type\n\tExpecting: %d\n\tReceived: %d", + ServerReflectionResponse::MessageResponseCase::kListServicesResponse, + response.message_response_case()); + } + return false; +} + +const protobuf::FileDescriptorProto +ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse( + const TString& byte_fd_proto) { + protobuf::FileDescriptorProto file_desc_proto; + file_desc_proto.ParseFromString(google::protobuf::string(byte_fd_proto)); + return file_desc_proto; +} + +void ProtoReflectionDescriptorDatabase::AddFileFromResponse( + const grpc::reflection::v1alpha::FileDescriptorResponse& response) { + for (int i = 0; i < response.file_descriptor_proto_size(); ++i) { + const protobuf::FileDescriptorProto file_proto = + ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i)); + if (known_files_.find(file_proto.name()) == known_files_.end()) { + known_files_.insert(file_proto.name()); + cached_db_.Add(file_proto); + } + } +} + +const std::shared_ptr<ProtoReflectionDescriptorDatabase::ClientStream> +ProtoReflectionDescriptorDatabase::GetStream() { + if (!stream_) { + stream_ = stub_->ServerReflectionInfo(&ctx_); + } + return stream_; +} + +bool ProtoReflectionDescriptorDatabase::DoOneRequest( + const ServerReflectionRequest& request, + ServerReflectionResponse& response) { + bool success = false; + stream_mutex_.lock(); + if (GetStream()->Write(request) && GetStream()->Read(&response)) { + success = true; + } + stream_mutex_.unlock(); + return success; +} + +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.h b/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.h new file mode 100644 index 0000000000..cdd6f0cccd --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/proto_reflection_descriptor_database.h @@ -0,0 +1,111 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#ifndef GRPC_TEST_CPP_PROTO_SERVER_REFLECTION_DATABSE_H +#define GRPC_TEST_CPP_PROTO_SERVER_REFLECTION_DATABSE_H + +#include <mutex> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include <grpcpp/grpcpp.h> +#include <grpcpp/impl/codegen/config_protobuf.h> +#include "src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h" + +namespace grpc { + +// ProtoReflectionDescriptorDatabase takes a stub of ServerReflection and +// provides the methods defined by DescriptorDatabase interfaces. It can be used +// to feed a DescriptorPool instance. +class ProtoReflectionDescriptorDatabase : public protobuf::DescriptorDatabase { + public: + explicit ProtoReflectionDescriptorDatabase( + std::unique_ptr<reflection::v1alpha::ServerReflection::Stub> stub); + + explicit ProtoReflectionDescriptorDatabase( + const std::shared_ptr<grpc::Channel>& channel); + + virtual ~ProtoReflectionDescriptorDatabase(); + + // The following four methods implement DescriptorDatabase interfaces. + // + // Find a file by file name. Fills in *output and returns true if found. + // Otherwise, returns false, leaving the contents of *output undefined. + bool FindFileByName(const google::protobuf::string& filename, + protobuf::FileDescriptorProto* output) override; + + // Find the file that declares the given fully-qualified symbol name. + // If found, fills in *output and returns true, otherwise returns false + // and leaves *output undefined. + bool FindFileContainingSymbol(const google::protobuf::string& symbol_name, + protobuf::FileDescriptorProto* output) override; + + // Find the file which defines an extension extending the given message type + // with the given field number. If found, fills in *output and returns true, + // otherwise returns false and leaves *output undefined. containing_type + // must be a fully-qualified type name. + bool FindFileContainingExtension( + const google::protobuf::string& containing_type, int field_number, + protobuf::FileDescriptorProto* output) override; + + // Finds the tag numbers used by all known extensions of + // extendee_type, and appends them to output in an undefined + // order. This method is best-effort: it's not guaranteed that the + // database will find all extensions, and it's not guaranteed that + // FindFileContainingExtension will return true on all of the found + // numbers. Returns true if the search was successful, otherwise + // returns false and leaves output unchanged. + bool FindAllExtensionNumbers(const google::protobuf::string& extendee_type, + std::vector<int>* output) override; + + // Provide a list of full names of registered services + bool GetServices(std::vector<TString>* output); + + private: + typedef ClientReaderWriter< + grpc::reflection::v1alpha::ServerReflectionRequest, + grpc::reflection::v1alpha::ServerReflectionResponse> + ClientStream; + + const protobuf::FileDescriptorProto ParseFileDescriptorProtoResponse( + const TString& byte_fd_proto); + + void AddFileFromResponse( + const grpc::reflection::v1alpha::FileDescriptorResponse& response); + + const std::shared_ptr<ClientStream> GetStream(); + + bool DoOneRequest( + const grpc::reflection::v1alpha::ServerReflectionRequest& request, + grpc::reflection::v1alpha::ServerReflectionResponse& response); + + std::shared_ptr<ClientStream> stream_; + grpc::ClientContext ctx_; + std::unique_ptr<grpc::reflection::v1alpha::ServerReflection::Stub> stub_; + std::unordered_set<string> known_files_; + std::unordered_set<string> missing_symbols_; + std::unordered_map<string, std::unordered_set<int>> missing_extensions_; + std::unordered_map<string, std::vector<int>> cached_extension_numbers_; + std::mutex stream_mutex_; + + protobuf::SimpleDescriptorDatabase cached_db_; +}; + +} // namespace grpc + +#endif // GRPC_TEST_CPP_METRICS_SERVER_H diff --git a/contrib/libs/grpc/test/cpp/util/service_describer.cc b/contrib/libs/grpc/test/cpp/util/service_describer.cc new file mode 100644 index 0000000000..2af1104b97 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/service_describer.cc @@ -0,0 +1,92 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/service_describer.h" + +#include <iostream> +#include <sstream> +#include <util/generic/string.h> +#include <vector> + +namespace grpc { +namespace testing { + +TString DescribeServiceList(std::vector<TString> service_list, + grpc::protobuf::DescriptorPool& desc_pool) { + std::stringstream result; + for (auto it = service_list.begin(); it != service_list.end(); it++) { + auto const& service = *it; + const grpc::protobuf::ServiceDescriptor* service_desc = + desc_pool.FindServiceByName(google::protobuf::string(service)); + if (service_desc != nullptr) { + result << DescribeService(service_desc); + } + } + return result.str(); +} + +TString DescribeService(const grpc::protobuf::ServiceDescriptor* service) { + TString result; + if (service->options().deprecated()) { + result.append("DEPRECATED\n"); + } + result.append("filename: " + service->file()->name() + "\n"); + + TString package = service->full_name(); + size_t pos = package.rfind("." + service->name()); + if (pos != TString::npos) { + package.erase(pos); + result.append("package: " + package + ";\n"); + } + result.append("service " + service->name() + " {\n"); + for (int i = 0; i < service->method_count(); ++i) { + result.append(DescribeMethod(service->method(i))); + } + result.append("}\n\n"); + return result; +} + +TString DescribeMethod(const grpc::protobuf::MethodDescriptor* method) { + std::stringstream result; + result << " rpc " << method->name() + << (method->client_streaming() ? "(stream " : "(") + << method->input_type()->full_name() << ") returns " + << (method->server_streaming() ? "(stream " : "(") + << method->output_type()->full_name() << ") {}\n"; + if (method->options().deprecated()) { + result << " DEPRECATED"; + } + return result.str(); +} + +TString SummarizeService(const grpc::protobuf::ServiceDescriptor* service) { + TString result; + for (int i = 0; i < service->method_count(); ++i) { + result.append(SummarizeMethod(service->method(i))); + } + return result; +} + +TString SummarizeMethod(const grpc::protobuf::MethodDescriptor* method) { + TString result = method->name(); + result.append("\n"); + return result; +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/service_describer.h b/contrib/libs/grpc/test/cpp/util/service_describer.h new file mode 100644 index 0000000000..a473f03744 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/service_describer.h @@ -0,0 +1,42 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_SERVICE_DESCRIBER_H +#define GRPC_TEST_CPP_UTIL_SERVICE_DESCRIBER_H + +#include <grpcpp/support/config.h> +#include "test/cpp/util/config_grpc_cli.h" + +namespace grpc { +namespace testing { + +TString DescribeServiceList(std::vector<TString> service_list, + grpc::protobuf::DescriptorPool& desc_pool); + +TString DescribeService(const grpc::protobuf::ServiceDescriptor* service); + +TString DescribeMethod(const grpc::protobuf::MethodDescriptor* method); + +TString SummarizeService(const grpc::protobuf::ServiceDescriptor* service); + +TString SummarizeMethod(const grpc::protobuf::MethodDescriptor* method); + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_SERVICE_DESCRIBER_H diff --git a/contrib/libs/grpc/test/cpp/util/slice_test.cc b/contrib/libs/grpc/test/cpp/util/slice_test.cc new file mode 100644 index 0000000000..d7e945ae38 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/slice_test.cc @@ -0,0 +1,144 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc++/support/slice.h> +#include <grpcpp/impl/grpc_library.h> + +#include <grpc/grpc.h> +#include <grpc/slice.h> +#include <gtest/gtest.h> + +#include "test/core/util/test_config.h" + +namespace grpc { + +static internal::GrpcLibraryInitializer g_gli_initializer; + +namespace { + +const char* kContent = "hello xxxxxxxxxxxxxxxxxxxx world"; + +class SliceTest : public ::testing::Test { + protected: + static void SetUpTestCase() { grpc_init(); } + + static void TearDownTestCase() { grpc_shutdown(); } + + void CheckSliceSize(const Slice& s, const TString& content) { + EXPECT_EQ(content.size(), s.size()); + } + void CheckSlice(const Slice& s, const TString& content) { + EXPECT_EQ(content.size(), s.size()); + EXPECT_EQ(content, + TString(reinterpret_cast<const char*>(s.begin()), s.size())); + } +}; + +TEST_F(SliceTest, Empty) { + Slice empty_slice; + CheckSlice(empty_slice, ""); +} + +TEST_F(SliceTest, Sized) { + Slice sized_slice(strlen(kContent)); + CheckSliceSize(sized_slice, kContent); +} + +TEST_F(SliceTest, String) { + Slice spp(kContent); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, Buf) { + Slice spp(kContent, strlen(kContent)); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, StaticBuf) { + Slice spp(kContent, strlen(kContent), Slice::STATIC_SLICE); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, SliceNew) { + char* x = new char[strlen(kContent) + 1]; + strcpy(x, kContent); + Slice spp(x, strlen(x), [](void* p) { delete[] static_cast<char*>(p); }); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, SliceNewDoNothing) { + Slice spp(const_cast<char*>(kContent), strlen(kContent), [](void* /*p*/) {}); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, SliceNewWithUserData) { + struct stest { + char* x; + int y; + }; + auto* t = new stest; + t->x = new char[strlen(kContent) + 1]; + strcpy(t->x, kContent); + Slice spp(t->x, strlen(t->x), + [](void* p) { + auto* t = static_cast<stest*>(p); + delete[] t->x; + delete t; + }, + t); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, SliceNewLen) { + Slice spp(const_cast<char*>(kContent), strlen(kContent), + [](void* /*p*/, size_t l) { EXPECT_EQ(l, strlen(kContent)); }); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, Steal) { + grpc_slice s = grpc_slice_from_copied_string(kContent); + Slice spp(s, Slice::STEAL_REF); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, Add) { + grpc_slice s = grpc_slice_from_copied_string(kContent); + Slice spp(s, Slice::ADD_REF); + grpc_slice_unref(s); + CheckSlice(spp, kContent); +} + +TEST_F(SliceTest, Cslice) { + grpc_slice s = grpc_slice_from_copied_string(kContent); + Slice spp(s, Slice::STEAL_REF); + CheckSlice(spp, kContent); + grpc_slice c_slice = spp.c_slice(); + EXPECT_EQ(GRPC_SLICE_START_PTR(s), GRPC_SLICE_START_PTR(c_slice)); + EXPECT_EQ(GRPC_SLICE_END_PTR(s), GRPC_SLICE_END_PTR(c_slice)); + grpc_slice_unref(c_slice); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/contrib/libs/grpc/test/cpp/util/string_ref_helper.cc b/contrib/libs/grpc/test/cpp/util/string_ref_helper.cc new file mode 100644 index 0000000000..e573f5d33a --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/string_ref_helper.cc @@ -0,0 +1,29 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/string_ref_helper.h" + +namespace grpc { +namespace testing { + +TString ToString(const grpc::string_ref& r) { + return TString(r.data(), r.size()); +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/string_ref_helper.h b/contrib/libs/grpc/test/cpp/util/string_ref_helper.h new file mode 100644 index 0000000000..e9e941f319 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/string_ref_helper.h @@ -0,0 +1,32 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_STRING_REF_HELPER_H +#define GRPC_TEST_CPP_UTIL_STRING_REF_HELPER_H + +#include <grpcpp/support/string_ref.h> + +namespace grpc { +namespace testing { + +TString ToString(const grpc::string_ref& r); + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_STRING_REF_HELPER_H diff --git a/contrib/libs/grpc/test/cpp/util/string_ref_test.cc b/contrib/libs/grpc/test/cpp/util/string_ref_test.cc new file mode 100644 index 0000000000..8e3259b764 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/string_ref_test.cc @@ -0,0 +1,205 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpcpp/support/string_ref.h> + +#include <string.h> + +#include <gtest/gtest.h> + +#include "test/core/util/test_config.h" + +namespace grpc { +namespace { + +const char kTestString[] = "blah"; +const char kTestStringWithEmbeddedNull[] = "blah\0foo"; +const size_t kTestStringWithEmbeddedNullLength = 8; +const char kTestUnrelatedString[] = "foo"; + +class StringRefTest : public ::testing::Test {}; + +TEST_F(StringRefTest, Empty) { + string_ref s; + EXPECT_EQ(0U, s.length()); + EXPECT_EQ(nullptr, s.data()); +} + +TEST_F(StringRefTest, FromCString) { + string_ref s(kTestString); + EXPECT_EQ(strlen(kTestString), s.length()); + EXPECT_EQ(kTestString, s.data()); +} + +TEST_F(StringRefTest, FromCStringWithLength) { + string_ref s(kTestString, 2); + EXPECT_EQ(2U, s.length()); + EXPECT_EQ(kTestString, s.data()); +} + +TEST_F(StringRefTest, FromString) { + string copy(kTestString); + string_ref s(copy); + EXPECT_EQ(copy.data(), s.data()); + EXPECT_EQ(copy.length(), s.length()); +} + +TEST_F(StringRefTest, CopyConstructor) { + string_ref s1(kTestString); + ; + const string_ref& s2(s1); + EXPECT_EQ(s1.length(), s2.length()); + EXPECT_EQ(s1.data(), s2.data()); +} + +TEST_F(StringRefTest, FromStringWithEmbeddedNull) { + string copy(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + string_ref s(copy); + EXPECT_EQ(copy.data(), s.data()); + EXPECT_EQ(copy.length(), s.length()); + EXPECT_EQ(kTestStringWithEmbeddedNullLength, s.length()); +} + +TEST_F(StringRefTest, Assignment) { + string_ref s1(kTestString); + ; + string_ref s2; + EXPECT_EQ(nullptr, s2.data()); + s2 = s1; + EXPECT_EQ(s1.length(), s2.length()); + EXPECT_EQ(s1.data(), s2.data()); +} + +TEST_F(StringRefTest, Iterator) { + string_ref s(kTestString); + size_t i = 0; + for (auto it = s.cbegin(); it != s.cend(); ++it) { + auto val = kTestString[i++]; + EXPECT_EQ(val, *it); + } + EXPECT_EQ(strlen(kTestString), i); +} + +TEST_F(StringRefTest, ReverseIterator) { + string_ref s(kTestString); + size_t i = strlen(kTestString); + for (auto rit = s.crbegin(); rit != s.crend(); ++rit) { + auto val = kTestString[--i]; + EXPECT_EQ(val, *rit); + } + EXPECT_EQ(0U, i); +} + +TEST_F(StringRefTest, Capacity) { + string_ref empty; + EXPECT_EQ(0U, empty.length()); + EXPECT_EQ(0U, empty.size()); + EXPECT_EQ(0U, empty.max_size()); + EXPECT_TRUE(empty.empty()); + + string_ref s(kTestString); + EXPECT_EQ(strlen(kTestString), s.length()); + EXPECT_EQ(s.length(), s.size()); + EXPECT_EQ(s.max_size(), s.length()); + EXPECT_FALSE(s.empty()); +} + +TEST_F(StringRefTest, Compare) { + string_ref s1(kTestString); + string s1_copy(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_EQ(0, s1.compare(s1_copy)); + EXPECT_NE(0, s1.compare(s2)); + EXPECT_NE(0, s1.compare(s3)); +} + +TEST_F(StringRefTest, StartsWith) { + string_ref s1(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_TRUE(s1.starts_with(s1)); + EXPECT_FALSE(s1.starts_with(s2)); + EXPECT_FALSE(s2.starts_with(s1)); + EXPECT_FALSE(s1.starts_with(s3)); + EXPECT_TRUE(s3.starts_with(s1)); +} + +TEST_F(StringRefTest, Endswith) { + string_ref s1(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_TRUE(s1.ends_with(s1)); + EXPECT_FALSE(s1.ends_with(s2)); + EXPECT_FALSE(s2.ends_with(s1)); + EXPECT_FALSE(s2.ends_with(s3)); + EXPECT_TRUE(s3.ends_with(s2)); +} + +TEST_F(StringRefTest, Find) { + string_ref s1(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_EQ(0U, s1.find(s1)); + EXPECT_EQ(0U, s2.find(s2)); + EXPECT_EQ(0U, s3.find(s3)); + EXPECT_EQ(string_ref::npos, s1.find(s2)); + EXPECT_EQ(string_ref::npos, s2.find(s1)); + EXPECT_EQ(string_ref::npos, s1.find(s3)); + EXPECT_EQ(0U, s3.find(s1)); + EXPECT_EQ(5U, s3.find(s2)); + EXPECT_EQ(string_ref::npos, s1.find('z')); + EXPECT_EQ(1U, s2.find('o')); +} + +TEST_F(StringRefTest, SubString) { + string_ref s(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + string_ref sub1 = s.substr(0, 4); + EXPECT_EQ(string_ref(kTestString), sub1); + string_ref sub2 = s.substr(5); + EXPECT_EQ(string_ref(kTestUnrelatedString), sub2); +} + +TEST_F(StringRefTest, ComparisonOperators) { + string_ref s1(kTestString); + string_ref s2(kTestUnrelatedString); + string_ref s3(kTestStringWithEmbeddedNull, kTestStringWithEmbeddedNullLength); + EXPECT_EQ(s1, s1); + EXPECT_EQ(s2, s2); + EXPECT_EQ(s3, s3); + EXPECT_GE(s1, s1); + EXPECT_GE(s2, s2); + EXPECT_GE(s3, s3); + EXPECT_LE(s1, s1); + EXPECT_LE(s2, s2); + EXPECT_LE(s3, s3); + EXPECT_NE(s1, s2); + EXPECT_NE(s1, s3); + EXPECT_NE(s2, s3); + EXPECT_GT(s3, s1); + EXPECT_LT(s1, s3); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/util/subprocess.cc b/contrib/libs/grpc/test/cpp/util/subprocess.cc new file mode 100644 index 0000000000..648bd50274 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/subprocess.cc @@ -0,0 +1,44 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/subprocess.h" + +#include <vector> + +#include "test/core/util/subprocess.h" + +namespace grpc { + +static gpr_subprocess* MakeProcess(const std::vector<TString>& args) { + std::vector<const char*> vargs; + for (auto it = args.begin(); it != args.end(); ++it) { + vargs.push_back(it->c_str()); + } + return gpr_subprocess_create(vargs.size(), &vargs[0]); +} + +SubProcess::SubProcess(const std::vector<TString>& args) + : subprocess_(MakeProcess(args)) {} + +SubProcess::~SubProcess() { gpr_subprocess_destroy(subprocess_); } + +int SubProcess::Join() { return gpr_subprocess_join(subprocess_); } + +void SubProcess::Interrupt() { gpr_subprocess_interrupt(subprocess_); } + +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/subprocess.h b/contrib/libs/grpc/test/cpp/util/subprocess.h new file mode 100644 index 0000000000..84dda31dd1 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/subprocess.h @@ -0,0 +1,47 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_SUBPROCESS_H +#define GRPC_TEST_CPP_UTIL_SUBPROCESS_H + +#include <initializer_list> +#include <util/generic/string.h> +#include <vector> + +struct gpr_subprocess; + +namespace grpc { + +class SubProcess { + public: + SubProcess(const std::vector<TString>& args); + ~SubProcess(); + + int Join(); + void Interrupt(); + + private: + SubProcess(const SubProcess& other); + SubProcess& operator=(const SubProcess& other); + + gpr_subprocess* const subprocess_; +}; + +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_SUBPROCESS_H diff --git a/contrib/libs/grpc/test/cpp/util/test_config.h b/contrib/libs/grpc/test/cpp/util/test_config.h new file mode 100644 index 0000000000..094ed44f63 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/test_config.h @@ -0,0 +1,30 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_TEST_CONFIG_H +#define GRPC_TEST_CPP_UTIL_TEST_CONFIG_H + +namespace grpc { +namespace testing { + +void InitTest(int* argc, char*** argv, bool remove_flags); + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_TEST_CONFIG_H diff --git a/contrib/libs/grpc/test/cpp/util/test_config_cc.cc b/contrib/libs/grpc/test/cpp/util/test_config_cc.cc new file mode 100644 index 0000000000..e4b6886335 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/test_config_cc.cc @@ -0,0 +1,37 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <gflags/gflags.h> +#include "test/cpp/util/test_config.h" + +// In some distros, gflags is in the namespace google, and in some others, +// in gflags. This hack is enabling us to find both. +namespace google {} +namespace gflags {} +using namespace google; +using namespace gflags; + +namespace grpc { +namespace testing { + +void InitTest(int* argc, char*** argv, bool remove_flags) { + ParseCommandLineFlags(argc, argv, remove_flags); +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/test_credentials_provider.cc b/contrib/libs/grpc/test/cpp/util/test_credentials_provider.cc new file mode 100644 index 0000000000..f7134b773f --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/test_credentials_provider.cc @@ -0,0 +1,181 @@ + +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/cpp/util/test_credentials_provider.h" + +#include <cstdio> +#include <fstream> +#include <iostream> + +#include <mutex> +#include <unordered_map> + +#include <gflags/gflags.h> +#include <grpc/support/log.h> +#include <grpc/support/sync.h> +#include <grpcpp/security/server_credentials.h> + +#include "test/core/end2end/data/ssl_test_data.h" + +DEFINE_string(tls_cert_file, "", "The TLS cert file used when --use_tls=true"); +DEFINE_string(tls_key_file, "", "The TLS key file used when --use_tls=true"); + +namespace grpc { +namespace testing { +namespace { + +TString ReadFile(const TString& src_path) { + std::ifstream src; + src.open(src_path, std::ifstream::in | std::ifstream::binary); + + TString contents; + src.seekg(0, std::ios::end); + contents.reserve(src.tellg()); + src.seekg(0, std::ios::beg); + contents.assign((std::istreambuf_iterator<char>(src)), + (std::istreambuf_iterator<char>())); + return contents; +} + +class DefaultCredentialsProvider : public CredentialsProvider { + public: + DefaultCredentialsProvider() { + if (!FLAGS_tls_key_file.empty()) { + custom_server_key_ = ReadFile(FLAGS_tls_key_file); + } + if (!FLAGS_tls_cert_file.empty()) { + custom_server_cert_ = ReadFile(FLAGS_tls_cert_file); + } + } + ~DefaultCredentialsProvider() override {} + + void AddSecureType( + const TString& type, + std::unique_ptr<CredentialTypeProvider> type_provider) override { + // This clobbers any existing entry for type, except the defaults, which + // can't be clobbered. + std::unique_lock<std::mutex> lock(mu_); + auto it = std::find(added_secure_type_names_.begin(), + added_secure_type_names_.end(), type); + if (it == added_secure_type_names_.end()) { + added_secure_type_names_.push_back(type); + added_secure_type_providers_.push_back(std::move(type_provider)); + } else { + added_secure_type_providers_[it - added_secure_type_names_.begin()] = + std::move(type_provider); + } + } + + std::shared_ptr<ChannelCredentials> GetChannelCredentials( + const TString& type, ChannelArguments* args) override { + if (type == grpc::testing::kInsecureCredentialsType) { + return InsecureChannelCredentials(); + } else if (type == grpc::testing::kAltsCredentialsType) { + grpc::experimental::AltsCredentialsOptions alts_opts; + return grpc::experimental::AltsCredentials(alts_opts); + } else if (type == grpc::testing::kTlsCredentialsType) { + SslCredentialsOptions ssl_opts = {test_root_cert, "", ""}; + args->SetSslTargetNameOverride("foo.test.google.fr"); + return grpc::SslCredentials(ssl_opts); + } else if (type == grpc::testing::kGoogleDefaultCredentialsType) { + return grpc::GoogleDefaultCredentials(); + } else { + std::unique_lock<std::mutex> lock(mu_); + auto it(std::find(added_secure_type_names_.begin(), + added_secure_type_names_.end(), type)); + if (it == added_secure_type_names_.end()) { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + return nullptr; + } + return added_secure_type_providers_[it - added_secure_type_names_.begin()] + ->GetChannelCredentials(args); + } + } + + std::shared_ptr<ServerCredentials> GetServerCredentials( + const TString& type) override { + if (type == grpc::testing::kInsecureCredentialsType) { + return InsecureServerCredentials(); + } else if (type == grpc::testing::kAltsCredentialsType) { + grpc::experimental::AltsServerCredentialsOptions alts_opts; + return grpc::experimental::AltsServerCredentials(alts_opts); + } else if (type == grpc::testing::kTlsCredentialsType) { + SslServerCredentialsOptions ssl_opts; + ssl_opts.pem_root_certs = ""; + if (!custom_server_key_.empty() && !custom_server_cert_.empty()) { + SslServerCredentialsOptions::PemKeyCertPair pkcp = { + custom_server_key_, custom_server_cert_}; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + } else { + SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key, + test_server1_cert}; + ssl_opts.pem_key_cert_pairs.push_back(pkcp); + } + return SslServerCredentials(ssl_opts); + } else { + std::unique_lock<std::mutex> lock(mu_); + auto it(std::find(added_secure_type_names_.begin(), + added_secure_type_names_.end(), type)); + if (it == added_secure_type_names_.end()) { + gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str()); + return nullptr; + } + return added_secure_type_providers_[it - added_secure_type_names_.begin()] + ->GetServerCredentials(); + } + } + std::vector<TString> GetSecureCredentialsTypeList() override { + std::vector<TString> types; + types.push_back(grpc::testing::kTlsCredentialsType); + std::unique_lock<std::mutex> lock(mu_); + for (auto it = added_secure_type_names_.begin(); + it != added_secure_type_names_.end(); it++) { + types.push_back(*it); + } + return types; + } + + private: + std::mutex mu_; + std::vector<TString> added_secure_type_names_; + std::vector<std::unique_ptr<CredentialTypeProvider>> + added_secure_type_providers_; + TString custom_server_key_; + TString custom_server_cert_; +}; + +CredentialsProvider* g_provider = nullptr; + +} // namespace + +CredentialsProvider* GetCredentialsProvider() { + if (g_provider == nullptr) { + g_provider = new DefaultCredentialsProvider; + } + return g_provider; +} + +void SetCredentialsProvider(CredentialsProvider* provider) { + // For now, forbids overriding provider. + GPR_ASSERT(g_provider == nullptr); + g_provider = provider; +} + +} // namespace testing +} // namespace grpc diff --git a/contrib/libs/grpc/test/cpp/util/test_credentials_provider.h b/contrib/libs/grpc/test/cpp/util/test_credentials_provider.h new file mode 100644 index 0000000000..acba277ada --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/test_credentials_provider.h @@ -0,0 +1,85 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CPP_UTIL_TEST_CREDENTIALS_PROVIDER_H +#define GRPC_TEST_CPP_UTIL_TEST_CREDENTIALS_PROVIDER_H + +#include <memory> + +#include <grpcpp/security/credentials.h> +#include <grpcpp/security/server_credentials.h> +#include <grpcpp/support/channel_arguments.h> + +namespace grpc { +namespace testing { + +const char kInsecureCredentialsType[] = "INSECURE_CREDENTIALS"; +// For real credentials, like tls/ssl, this name should match the AuthContext +// property "transport_security_type". +const char kTlsCredentialsType[] = "ssl"; +const char kAltsCredentialsType[] = "alts"; +const char kGoogleDefaultCredentialsType[] = "google_default_credentials"; + +// Provide test credentials of a particular type. +class CredentialTypeProvider { + public: + virtual ~CredentialTypeProvider() {} + + virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials( + ChannelArguments* args) = 0; + virtual std::shared_ptr<ServerCredentials> GetServerCredentials() = 0; +}; + +// Provide test credentials. Thread-safe. +class CredentialsProvider { + public: + virtual ~CredentialsProvider() {} + + // Add a secure type in addition to the defaults. The default provider has + // (kInsecureCredentialsType, kTlsCredentialsType). + virtual void AddSecureType( + const TString& type, + std::unique_ptr<CredentialTypeProvider> type_provider) = 0; + + // Provide channel credentials according to the given type. Alter the channel + // arguments if needed. Return nullptr if type is not registered. + virtual std::shared_ptr<ChannelCredentials> GetChannelCredentials( + const TString& type, ChannelArguments* args) = 0; + + // Provide server credentials according to the given type. + // Return nullptr if type is not registered. + virtual std::shared_ptr<ServerCredentials> GetServerCredentials( + const TString& type) = 0; + + // Provide a list of secure credentials type. + virtual std::vector<TString> GetSecureCredentialsTypeList() = 0; +}; + +// Get the current provider. Create a default one if not set. +// Not thread-safe. +CredentialsProvider* GetCredentialsProvider(); + +// Set the global provider. Takes ownership. The previous set provider will be +// destroyed. +// Not thread-safe. +void SetCredentialsProvider(CredentialsProvider* provider); + +} // namespace testing +} // namespace grpc + +#endif // GRPC_TEST_CPP_UTIL_TEST_CREDENTIALS_PROVIDER_H diff --git a/contrib/libs/grpc/test/cpp/util/time_test.cc b/contrib/libs/grpc/test/cpp/util/time_test.cc new file mode 100644 index 0000000000..bcbfa14f94 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/time_test.cc @@ -0,0 +1,72 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/time.h> +#include <grpcpp/support/time.h> +#include <gtest/gtest.h> + +#include "test/core/util/test_config.h" + +using std::chrono::duration_cast; +using std::chrono::microseconds; +using std::chrono::system_clock; + +namespace grpc { +namespace { + +class TimeTest : public ::testing::Test {}; + +TEST_F(TimeTest, AbsolutePointTest) { + int64_t us = 10000000L; + gpr_timespec ts = gpr_time_from_micros(us, GPR_TIMESPAN); + ts.clock_type = GPR_CLOCK_REALTIME; + system_clock::time_point tp{microseconds(us)}; + system_clock::time_point tp_converted = Timespec2Timepoint(ts); + gpr_timespec ts_converted; + Timepoint2Timespec(tp_converted, &ts_converted); + EXPECT_TRUE(ts.tv_sec == ts_converted.tv_sec); + EXPECT_TRUE(ts.tv_nsec == ts_converted.tv_nsec); + system_clock::time_point tp_converted_2 = Timespec2Timepoint(ts_converted); + EXPECT_TRUE(tp == tp_converted); + EXPECT_TRUE(tp == tp_converted_2); +} + +// gpr_inf_future is treated specially and mapped to/from time_point::max() +TEST_F(TimeTest, InfFuture) { + EXPECT_EQ(system_clock::time_point::max(), + Timespec2Timepoint(gpr_inf_future(GPR_CLOCK_REALTIME))); + gpr_timespec from_time_point_max; + Timepoint2Timespec(system_clock::time_point::max(), &from_time_point_max); + EXPECT_EQ( + 0, gpr_time_cmp(gpr_inf_future(GPR_CLOCK_REALTIME), from_time_point_max)); + // This will cause an overflow + Timepoint2Timespec( + std::chrono::time_point<system_clock, std::chrono::seconds>::max(), + &from_time_point_max); + EXPECT_EQ( + 0, gpr_time_cmp(gpr_inf_future(GPR_CLOCK_REALTIME), from_time_point_max)); +} + +} // namespace +} // namespace grpc + +int main(int argc, char** argv) { + grpc::testing::TestEnvironment env(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/contrib/libs/grpc/test/cpp/util/ya.make b/contrib/libs/grpc/test/cpp/util/ya.make new file mode 100644 index 0000000000..f043cc5b14 --- /dev/null +++ b/contrib/libs/grpc/test/cpp/util/ya.make @@ -0,0 +1,39 @@ +LIBRARY() + +LICENSE(Apache-2.0) + +LICENSE_TEXTS(.yandex_meta/licenses.list.txt) + +OWNER(orivej) + +PEERDIR( + contrib/libs/gflags + contrib/libs/protoc + contrib/libs/grpc/src/proto/grpc/reflection/v1alpha + contrib/restricted/googletest/googlemock + contrib/restricted/googletest/googletest +) + +ADDINCL( + ${ARCADIA_BUILD_ROOT}/contrib/libs/grpc + contrib/libs/grpc +) + +NO_COMPILER_WARNINGS() + +SRCS( + byte_buffer_proto_helper.cc + # grpc_cli_libs: + cli_call.cc + cli_credentials.cc + grpc_tool.cc + proto_file_parser.cc + service_describer.cc + string_ref_helper.cc + # grpc++_proto_reflection_desc_db: + proto_reflection_descriptor_database.cc + # grpc++_test_config: + test_config_cc.cc +) + +END() |