From 9314042511cd9d2598ed16eb0a19c84909895938 Mon Sep 17 00:00:00 2001
From: trofimenkov <trofimenkov@yandex-team.com>
Date: Wed, 29 May 2024 11:24:03 +0300
Subject: MaxRedirectCount param for http/simple
 1b80d64b6a03772edc52f2331a860ff0b5621898

---
 library/cpp/http/simple/http_client.cpp       |  19 +++-
 library/cpp/http/simple/http_client.h         |   7 ++
 library/cpp/http/simple/http_client_options.h |  10 ++
 library/cpp/http/simple/ut/http_ut.cpp        | 135 ++++++++++++++++++++++++++
 4 files changed, 169 insertions(+), 2 deletions(-)

(limited to 'library/cpp')

diff --git a/library/cpp/http/simple/http_client.cpp b/library/cpp/http/simple/http_client.cpp
index 818dc048ad..342236f9a3 100644
--- a/library/cpp/http/simple/http_client.cpp
+++ b/library/cpp/http/simple/http_client.cpp
@@ -301,8 +301,14 @@ TKeepAliveHttpClient TSimpleHttpClient::CreateClient() const {
 void TSimpleHttpClient::PrepareClient(TKeepAliveHttpClient&) const {
 }
 
+TRedirectableHttpClient::TRedirectableHttpClient(const TOptions& options)
+    : TSimpleHttpClient(options)
+    , Opts(options)
+{
+}
+
 TRedirectableHttpClient::TRedirectableHttpClient(const TString& host, ui32 port, TDuration socketTimeout, TDuration connectTimeout)
-    : TSimpleHttpClient(host, port, socketTimeout, connectTimeout)
+    : TRedirectableHttpClient(TOptions().Host(host).Port(port).SocketTimeout(socketTimeout).ConnectTimeout(connectTimeout))
 {
 }
 
@@ -315,6 +321,10 @@ void TRedirectableHttpClient::PrepareClient(TKeepAliveHttpClient& cl) const {
 void TRedirectableHttpClient::ProcessResponse(const TStringBuf relativeUrl, THttpInput& input, IOutputStream* output, const unsigned statusCode) const {
     for (auto i = input.Headers().Begin(), e = input.Headers().End(); i != e; ++i) {
         if (0 == TString::compare(i->Name(), TStringBuf("Location"))) {
+            if (Opts.MaxRedirectCount() == 0) {
+                ythrow THttpRequestException(statusCode) << "Exceeds MaxRedirectCount limit, code " << statusCode << " at " << Host << relativeUrl;
+            }
+
             TVector<TString> request_url_parts, request_body_parts;
 
             size_t splitted_index = 0;
@@ -339,7 +349,12 @@ void TRedirectableHttpClient::ProcessResponse(const TStringBuf relativeUrl, THtt
                 }
             }
 
-            TRedirectableHttpClient cl(url, port, TDuration::Seconds(60), TDuration::Seconds(60));
+            auto opts = Opts;
+            opts.Host(url);
+            opts.Port(port);
+            opts.MaxRedirectCount(opts.MaxRedirectCount() - 1);
+
+            TRedirectableHttpClient cl(opts);
             if (HttpsVerification) {
                 cl.EnableVerificationForHttps();
             }
diff --git a/library/cpp/http/simple/http_client.h b/library/cpp/http/simple/http_client.h
index c01b11ba43..eab92d42da 100644
--- a/library/cpp/http/simple/http_client.h
+++ b/library/cpp/http/simple/http_client.h
@@ -185,12 +185,19 @@ private:
 
 class TRedirectableHttpClient: public TSimpleHttpClient {
 public:
+    using TOptions = TSimpleHttpClientOptions;
+
+    explicit TRedirectableHttpClient(const TOptions& options);
+
     TRedirectableHttpClient(const TString& host, ui32 port, TDuration socketTimeout = TDuration::Seconds(5),
                             TDuration connectTimeout = TDuration::Seconds(30));
 
 private:
     void PrepareClient(TKeepAliveHttpClient& cl) const override;
     void ProcessResponse(const TStringBuf relativeUrl, THttpInput& input, IOutputStream* output, const unsigned statusCode) const override;
+
+private:
+    TOptions Opts;
 };
 
 namespace NPrivate {
diff --git a/library/cpp/http/simple/http_client_options.h b/library/cpp/http/simple/http_client_options.h
index f2e964a462..603ca5103a 100644
--- a/library/cpp/http/simple/http_client_options.h
+++ b/library/cpp/http/simple/http_client_options.h
@@ -51,9 +51,19 @@ public:
         return ConnectTimeout_;
     }
 
+    TSelf& MaxRedirectCount(int count) {
+        MaxRedirectCount_ = count;
+        return *this;
+    }
+
+    ui16 MaxRedirectCount() const noexcept {
+        return MaxRedirectCount_;
+    }
+
 private:
     TString Host_;
     ui16 Port_;
     TDuration SocketTimeout_ = TDuration::Seconds(5);
     TDuration ConnectTimeout_ = TDuration::Seconds(30);
+    int MaxRedirectCount_ = INT_MAX;
 };
diff --git a/library/cpp/http/simple/ut/http_ut.cpp b/library/cpp/http/simple/ut/http_ut.cpp
index bf7e767428..7768fdc4fa 100644
--- a/library/cpp/http/simple/ut/http_ut.cpp
+++ b/library/cpp/http/simple/ut/http_ut.cpp
@@ -71,6 +71,76 @@ Y_UNIT_TEST_SUITE(SimpleHttp) {
         }
     };
 
+    class TScenario {
+    public:
+        struct TElem {
+            TString Url;
+            int Status = HTTP_OK;
+            TString Content{};
+        };
+
+        TScenario(const TVector<TElem>& seq, ui16 port = 80, TDuration sleep = TDuration())
+            : Seq_(seq)
+            , Sleep_(sleep)
+            , Port_(port)
+        {
+        }
+
+        bool DoReply(const TRequestReplier::TReplyParams& params, TRequestReplier* replier) {
+            const auto parsed = TParsedHttpFull(params.Input.FirstLine());
+            const auto url = parsed.Request;
+            params.Input.ReadAll();
+
+            UNIT_ASSERT(SeqIdx_ < Seq_.size());
+            auto& elem = Seq_[SeqIdx_++];
+
+            UNIT_ASSERT_VALUES_EQUAL(elem.Url, url);
+
+            Sleep(Sleep_);
+
+            if (elem.Status == -1) {
+                replier->ResetConnection(); // RST / ECONNRESET
+                return true;
+            }
+
+            THttpResponse resp((HttpCodes)elem.Status);
+
+            if (elem.Status >= 300 && elem.Status < 400) {
+                UNIT_ASSERT(SeqIdx_ < Seq_.size());
+                resp.AddHeader("Location", TStringBuilder() << "http://localhost:" << Port_ << Seq_[SeqIdx_].Url);
+            }
+
+            resp.SetContent(elem.Content);
+            resp.OutTo(params.Output);
+
+            return true;
+        }
+
+        void VerifyInvariants() {
+            UNIT_ASSERT_VALUES_EQUAL(SeqIdx_, Seq_.size());
+        }
+
+    private:
+        TVector<TElem> Seq_;
+        size_t SeqIdx_ = 0;
+        TDuration Sleep_;
+        ui16 Port_;
+    };
+
+    class TScenarioReplier: public TRequestReplier {
+        TScenario* Scenario_ = nullptr;
+
+    public:
+        TScenarioReplier(TScenario* scenario)
+            : Scenario_(scenario)
+        {
+        }
+
+        bool DoReply(const TReplyParams& params) override {
+            return Scenario_->DoReply(params, this);
+        }
+    };
+
     class TCodedPong: public TRequestReplier {
         HttpCodes Code_;
 
@@ -129,6 +199,32 @@ Y_UNIT_TEST_SUITE(SimpleHttp) {
         }
     };
 
+    static void TestRedirectCountParam(int maxRedirectCount, int redirectCount) {
+        TPortManager pm;
+        ui16 port = pm.GetPort(80);
+
+        TVector<TScenario::TElem> steps;
+        for (int i = 0; i < redirectCount; ++i) {
+            steps.push_back({"/any", 302});
+        }
+        steps.push_back({"/any", 200, "Hello"});
+        TScenario scenario(steps, port);
+
+        NMock::TMockServer server(createOptions(port, true), [&scenario]() { return new TScenarioReplier(&scenario); });
+
+        TRedirectableHttpClient cl(TSimpleHttpClientOptions().Host("localhost").Port(port).MaxRedirectCount(maxRedirectCount));
+        UNIT_ASSERT_VALUES_EQUAL(0, server.GetClientCount());
+
+        TStringStream s;
+        if (maxRedirectCount >= redirectCount) {
+            UNIT_ASSERT_NO_EXCEPTION(cl.DoGet("/any", &s));
+            UNIT_ASSERT_VALUES_EQUAL("Hello", s.Str());
+            scenario.VerifyInvariants();
+        } else {
+            UNIT_ASSERT_EXCEPTION_CONTAINS(cl.DoGet("/any", &s), THttpRequestException, "");
+        }
+    }
+
     Y_UNIT_TEST(simpleSuccessful) {
         TPortManager pm;
         ui16 port = pm.GetPort(80);
@@ -274,6 +370,45 @@ Y_UNIT_TEST_SUITE(SimpleHttp) {
         }
     }
 
+    Y_UNIT_TEST(redirectCountDefault) {
+        TPortManager pm;
+        ui16 port = pm.GetPort(80);
+
+        TScenario scenario({
+            {"/any", 307},
+            {"/any?param=1", 302},
+            {"/any?param=1", 302},
+            {"/any?param=1", 302},
+            {"/any?param=1", 302},
+            {"/any?param=1", 302},
+            {"/any?param=1", 302},
+            {"/any?param=1", 302},
+            {"/any?param=1", 302},
+            {"/any?param=1", 302},
+            {"/any?param=2", 200, "Hello"}
+        }, port);
+
+        NMock::TMockServer server(createOptions(port, true), [&scenario]() { return new TScenarioReplier(&scenario); });
+
+        TRedirectableHttpClient cl("localhost", port);
+        UNIT_ASSERT_VALUES_EQUAL(0, server.GetClientCount());
+
+        TStringStream s;
+        UNIT_ASSERT_NO_EXCEPTION(cl.DoGet("/any", &s));
+        UNIT_ASSERT_VALUES_EQUAL("Hello", s.Str());
+
+        scenario.VerifyInvariants();
+    }
+
+    Y_UNIT_TEST(redirectCountN) {
+        TestRedirectCountParam(0, 0);
+        TestRedirectCountParam(0, 1);
+        TestRedirectCountParam(1, 1);
+        TestRedirectCountParam(3, 3);
+        TestRedirectCountParam(20, 20);
+        TestRedirectCountParam(20, 21);
+    }
+
     Y_UNIT_TEST(redirectable) {
         TPortManager pm;
         ui16 port = pm.GetPort(80);
-- 
cgit v1.2.3