From 08e7cde56ecd933346be66e2d41418a6ea0a0ab3 Mon Sep 17 00:00:00 2001
From: jolex007 <jolex007@yandex-team.com>
Date: Thu, 27 Feb 2025 12:33:07 +0300
Subject: Fix cancellation in unifetcher
 commit_hash:909fa7aadbf673448dbc709b19d2088963b40404

---
 library/cpp/http/simple/http_client.cpp | 34 +++++++++++++++++-------------
 library/cpp/http/simple/http_client.h   | 37 +++++++++++++++++++++++++--------
 library/cpp/http/simple/ya.make         |  1 +
 3 files changed, 49 insertions(+), 23 deletions(-)

(limited to 'library/cpp/http/simple')

diff --git a/library/cpp/http/simple/http_client.cpp b/library/cpp/http/simple/http_client.cpp
index bac6bdc39e..2be5a14582 100644
--- a/library/cpp/http/simple/http_client.cpp
+++ b/library/cpp/http/simple/http_client.cpp
@@ -25,26 +25,30 @@ TKeepAliveHttpClient::TKeepAliveHttpClient(const TString& host,
 TKeepAliveHttpClient::THttpCode TKeepAliveHttpClient::DoGet(const TStringBuf relativeUrl,
                                                             IOutputStream* output,
                                                             const THeaders& headers,
-                                                            THttpHeaders* outHeaders) {
+                                                            THttpHeaders* outHeaders,
+                                                            NThreading::TCancellationToken cancellation) {
     return DoRequest(TStringBuf("GET"),
                      relativeUrl,
                      {},
                      output,
                      headers,
-                     outHeaders);
+                     outHeaders,
+                     std::move(cancellation));
 }
 
 TKeepAliveHttpClient::THttpCode TKeepAliveHttpClient::DoPost(const TStringBuf relativeUrl,
                                                              const TStringBuf body,
                                                              IOutputStream* output,
                                                              const THeaders& headers,
-                                                             THttpHeaders* outHeaders) {
+                                                             THttpHeaders* outHeaders,
+                                                             NThreading::TCancellationToken cancellation) {
     return DoRequest(TStringBuf("POST"),
                      relativeUrl,
                      body,
                      output,
                      headers,
-                     outHeaders);
+                     outHeaders,
+                     std::move(cancellation));
 }
 
 TKeepAliveHttpClient::THttpCode TKeepAliveHttpClient::DoRequest(const TStringBuf method,
@@ -52,15 +56,17 @@ TKeepAliveHttpClient::THttpCode TKeepAliveHttpClient::DoRequest(const TStringBuf
                                                                 const TStringBuf body,
                                                                 IOutputStream* output,
                                                                 const THeaders& inHeaders,
-                                                                THttpHeaders* outHeaders) {
+                                                                THttpHeaders* outHeaders,
+                                                                NThreading::TCancellationToken cancellation) {
     const TString contentLength = IntToString<10, size_t>(body.size());
-    return DoRequestReliable(FormRequest(method, relativeUrl, body, inHeaders, contentLength), output, outHeaders);
+    return DoRequestReliable(FormRequest(method, relativeUrl, body, inHeaders, contentLength), output, outHeaders, std::move(cancellation));
 }
 
 TKeepAliveHttpClient::THttpCode TKeepAliveHttpClient::DoRequestRaw(const TStringBuf raw,
                                                                    IOutputStream* output,
-                                                                   THttpHeaders* outHeaders) {
-    return DoRequestReliable(raw, output, outHeaders);
+                                                                   THttpHeaders* outHeaders,
+                                                                   NThreading::TCancellationToken cancellation) {
+    return DoRequestReliable(raw, output, outHeaders, std::move(cancellation));
 }
 
 void TKeepAliveHttpClient::DisableVerificationForHttps() {
@@ -189,28 +195,28 @@ void TSimpleHttpClient::EnableVerificationForHttps() {
     HttpsVerification = true;
 }
 
-void TSimpleHttpClient::DoGet(const TStringBuf relativeUrl, IOutputStream* output, const THeaders& headers) const {
+void TSimpleHttpClient::DoGet(const TStringBuf relativeUrl, IOutputStream* output, const THeaders& headers, NThreading::TCancellationToken cancellation) const {
     TKeepAliveHttpClient cl = CreateClient();
 
-    TKeepAliveHttpClient::THttpCode code = cl.DoGet(relativeUrl, output, headers);
+    TKeepAliveHttpClient::THttpCode code = cl.DoGet(relativeUrl, output, headers, nullptr, std::move(cancellation));
 
     Y_ENSURE(cl.GetHttpInput());
     ProcessResponse(relativeUrl, *cl.GetHttpInput(), output, code);
 }
 
-void TSimpleHttpClient::DoPost(const TStringBuf relativeUrl, TStringBuf body, IOutputStream* output, const THashMap<TString, TString>& headers) const {
+void TSimpleHttpClient::DoPost(const TStringBuf relativeUrl, TStringBuf body, IOutputStream* output, const THashMap<TString, TString>& headers, NThreading::TCancellationToken cancellation) const {
     TKeepAliveHttpClient cl = CreateClient();
 
-    TKeepAliveHttpClient::THttpCode code = cl.DoPost(relativeUrl, body, output, headers);
+    TKeepAliveHttpClient::THttpCode code = cl.DoPost(relativeUrl, body, output, headers, nullptr, std::move(cancellation));
 
     Y_ENSURE(cl.GetHttpInput());
     ProcessResponse(relativeUrl, *cl.GetHttpInput(), output, code);
 }
 
-void TSimpleHttpClient::DoPostRaw(const TStringBuf relativeUrl, const TStringBuf rawRequest, IOutputStream* output) const {
+void TSimpleHttpClient::DoPostRaw(const TStringBuf relativeUrl, const TStringBuf rawRequest, IOutputStream* output, NThreading::TCancellationToken cancellation) const {
     TKeepAliveHttpClient cl = CreateClient();
 
-    TKeepAliveHttpClient::THttpCode code = cl.DoRequestRaw(rawRequest, output);
+    TKeepAliveHttpClient::THttpCode code = cl.DoRequestRaw(rawRequest, output, nullptr, std::move(cancellation));
 
     Y_ENSURE(cl.GetHttpInput());
     ProcessResponse(relativeUrl, *cl.GetHttpInput(), output, code);
diff --git a/library/cpp/http/simple/http_client.h b/library/cpp/http/simple/http_client.h
index eab92d42da..224be58a24 100644
--- a/library/cpp/http/simple/http_client.h
+++ b/library/cpp/http/simple/http_client.h
@@ -12,6 +12,7 @@
 #include <library/cpp/http/io/stream.h>
 #include <library/cpp/http/misc/httpcodes.h>
 #include <library/cpp/openssl/io/stream.h>
+#include <library/cpp/threading/cancellation/cancellation_token.h>
 
 class TNetworkAddress;
 class IOutputStream;
@@ -54,14 +55,16 @@ public:
     THttpCode DoGet(const TStringBuf relativeUrl,
                     IOutputStream* output = nullptr,
                     const THeaders& headers = THeaders(),
-                    THttpHeaders* outHeaders = nullptr);
+                    THttpHeaders* outHeaders = nullptr,
+                    NThreading::TCancellationToken cancellation = NThreading::TCancellationToken::Default());
 
     // builds post request from headers and body
     THttpCode DoPost(const TStringBuf relativeUrl,
                      const TStringBuf body,
                      IOutputStream* output = nullptr,
                      const THeaders& headers = THeaders(),
-                     THttpHeaders* outHeaders = nullptr);
+                     THttpHeaders* outHeaders = nullptr,
+                     NThreading::TCancellationToken cancellation = NThreading::TCancellationToken::Default());
 
     // builds request with any HTTP method from headers and body
     THttpCode DoRequest(const TStringBuf method,
@@ -69,12 +72,14 @@ public:
                         const TStringBuf body,
                         IOutputStream* output = nullptr,
                         const THeaders& inHeaders = THeaders(),
-                        THttpHeaders* outHeaders = nullptr);
+                        THttpHeaders* outHeaders = nullptr,
+                        NThreading::TCancellationToken cancellation = NThreading::TCancellationToken::Default());
 
     // requires already well-formed request
     THttpCode DoRequestRaw(const TStringBuf raw,
                            IOutputStream* output = nullptr,
-                           THttpHeaders* outHeaders = nullptr);
+                           THttpHeaders* outHeaders = nullptr,
+                           NThreading::TCancellationToken cancellation = NThreading::TCancellationToken::Default());
 
     void DisableVerificationForHttps();
     void SetClientCertificate(const TOpenSslClientIO::TOptions::TClientCert& options);
@@ -93,7 +98,8 @@ private:
     template <class T>
     THttpCode DoRequestReliable(const T& raw,
                                 IOutputStream* output,
-                                THttpHeaders* outHeaders);
+                                THttpHeaders* outHeaders,
+                                NThreading::TCancellationToken cancellation);
 
     TVector<IOutputStream::TPart> FormRequest(TStringBuf method, const TStringBuf relativeUrl,
                                               TStringBuf body,
@@ -166,13 +172,13 @@ public:
 
     void EnableVerificationForHttps();
 
-    void DoGet(const TStringBuf relativeUrl, IOutputStream* output, const THeaders& headers = THeaders()) const;
+    void DoGet(const TStringBuf relativeUrl, IOutputStream* output, const THeaders& headers = THeaders(), NThreading::TCancellationToken cancellation = NThreading::TCancellationToken::Default()) const;
 
     // builds post request from headers and body
-    void DoPost(const TStringBuf relativeUrl, TStringBuf body, IOutputStream* output, const THeaders& headers = THeaders()) const;
+    void DoPost(const TStringBuf relativeUrl, TStringBuf body, IOutputStream* output, const THeaders& headers = THeaders(), NThreading::TCancellationToken cancellation = NThreading::TCancellationToken::Default()) const;
 
     // requires already well-formed post request
-    void DoPostRaw(const TStringBuf relativeUrl, TStringBuf rawRequest, IOutputStream* output) const;
+    void DoPostRaw(const TStringBuf relativeUrl, TStringBuf rawRequest, IOutputStream* output, NThreading::TCancellationToken cancellation = NThreading::TCancellationToken::Default()) const;
 
     virtual ~TSimpleHttpClient();
 
@@ -227,6 +233,10 @@ namespace NPrivate {
             return HttpIn.Get();
         }
 
+        void Shutdown() {
+            Socket.ShutDown(SHUT_RDWR);
+        }
+
     private:
         static TNetworkAddress Resolve(const TString& host, ui32 port);
 
@@ -250,12 +260,18 @@ namespace NPrivate {
 template <class T>
 TKeepAliveHttpClient::THttpCode TKeepAliveHttpClient::DoRequestReliable(const T& raw,
                                                                         IOutputStream* output,
-                                                                        THttpHeaders* outHeaders) {
+                                                                        THttpHeaders* outHeaders,
+                                                                        NThreading::TCancellationToken cancellation) {
+
     for (int i = 0; i < 2; ++i) {
         const bool haveNewConnection = CreateNewConnectionIfNeeded();
         const bool couldRetry = !haveNewConnection && i == 0; // Actually old connection could be already closed by server,
                                                               // so we should try one more time in this case.
         try {
+            cancellation.Future().Subscribe([&](auto&) {
+                Connection->Shutdown();
+            });
+
             Connection->Write(raw);
 
             THttpCode code = ReadAndTransferHttp(*Connection->GetHttpInput(), output, outHeaders);
@@ -265,16 +281,19 @@ TKeepAliveHttpClient::THttpCode TKeepAliveHttpClient::DoRequestReliable(const T&
             return code;
         } catch (const TSystemError& e) {
             Connection.Reset();
+            cancellation.ThrowIfCancellationRequested();
             if (!couldRetry || e.Status() != EPIPE) {
                 throw;
             }
         } catch (const THttpReadException&) { // Actually old connection is already closed by server
             Connection.Reset();
+            cancellation.ThrowIfCancellationRequested();
             if (!couldRetry) {
                 throw;
             }
         } catch (const std::exception&) {
             Connection.Reset();
+            cancellation.ThrowIfCancellationRequested();
             throw;
         }
     }
diff --git a/library/cpp/http/simple/ya.make b/library/cpp/http/simple/ya.make
index 40744675e8..6a4e5775a4 100644
--- a/library/cpp/http/simple/ya.make
+++ b/library/cpp/http/simple/ya.make
@@ -4,6 +4,7 @@ PEERDIR(
     library/cpp/http/io
     library/cpp/openssl/io
     library/cpp/string_utils/url
+    library/cpp/threading/cancellation
 )
 
 SRCS(
-- 
cgit v1.2.3