diff options
author | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
---|---|---|
committer | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
commit | 22f8ae0e3f5d68b92aecccdf96c1d841a0334311 (patch) | |
tree | bffa27765faf54126ad44bcafa89fadecb7a73d7 /library/cpp/http | |
parent | 332b99e2173f0425444abb759eebcb2fafaa9209 (diff) | |
download | ydb-22f8ae0e3f5d68b92aecccdf96c1d841a0334311.tar.gz |
validate canons without yatest_common
Diffstat (limited to 'library/cpp/http')
34 files changed, 4183 insertions, 0 deletions
diff --git a/library/cpp/http/client/client.cpp b/library/cpp/http/client/client.cpp new file mode 100644 index 0000000000..f36aac37d7 --- /dev/null +++ b/library/cpp/http/client/client.cpp @@ -0,0 +1,232 @@ +#include "client.h" +#include "request.h" + +#include <library/cpp/coroutine/dns/cache.h> +#include <library/cpp/coroutine/dns/coro.h> +#include <library/cpp/coroutine/dns/helpers.h> +#include <library/cpp/coroutine/engine/condvar.h> +#include <library/cpp/coroutine/engine/impl.h> +#include <library/cpp/coroutine/engine/network.h> +#include <library/cpp/coroutine/util/pipeque.h> + +#include <library/cpp/http/client/ssl/sslsock.h> +#include <library/cpp/http/client/fetch/coctx.h> +#include <library/cpp/http/client/fetch/codes.h> +#include <library/cpp/http/client/fetch/cosocket.h> +#include <library/cpp/http/client/fetch/fetch_single.h> + +#include <util/stream/output.h> +#include <util/thread/factory.h> +#include <util/system/event.h> +#include <util/system/spinlock.h> +#include <util/system/thread.h> + +namespace NHttp { + namespace { + using namespace NHttpFetcher; + + class TFetcher: private IThreadFactory::IThreadAble { + public: + TFetcher(const TClientOptions& options) + : Options_(options) + , FetchCoroutines(Max<size_t>(Options_.FetchCoroutines, 1)) + , RequestsQueue_(true, false) + , Done_(false) + { + } + + void Start() { + T_ = SystemThreadFactory()->Run(this); + } + + void Stop() { + if (T_) { + for (size_t i = 0; i < FetchCoroutines; ++i) { + RequestsQueue_.Push(nullptr); + } + + T_->Join(); + T_.Reset(); + } + } + + ~TFetcher() override { + Stop(); + } + + TFetchState FetchAsync(const TFetchRequestRef& req, NHttpFetcher::TCallBack cb) { + req->SetCallback(cb); + RequestsQueue_.Push(req); + return TFetchState(req); + } + + static TFetcher* Instance() { + static struct TFetcherHolder { + TFetcherHolder() { + Fetcher.Start(); + } + + TFetcher Fetcher{{}}; + } holder; + + return &holder.Fetcher; + } + + private: + void DoDrainLoop(TCont* c) { + 
TInstant nextDrain = TInstant::Now() + Options_.KeepAliveTimeout; + + while (true) { + DrainMutex_.LockI(c); + + while (true) { + if (Done_) { + DrainMutex_.UnLock(); + // All sockets in the connection pool should be cleared + // on some active couroutine. + SocketPool_.Clear(); + return; + } + if (DrainCond_.WaitD(c, &DrainMutex_, nextDrain) != 0) { + // In case of timeout the mutex will be in unlocked state. + break; + } + } + + SocketPool_.Drain(Options_.KeepAliveTimeout); + nextDrain = TInstant::Now() + Options_.KeepAliveTimeout; + } + } + + void DoFetchLoop(TCont* c) { + while (true) { + if (NCoro::PollI(c, RequestsQueue_.PopFd(), CONT_POLL_READ) == 0) { + TFetchRequestRef req; + + if (RequestsQueue_.Pop(&req)) { + if (!req) { + DrainMutex_.LockI(c); + const auto wasDone = Done_; + Done_ = true; + DrainMutex_.UnLock(); + if (!wasDone) { + DrainCond_.Signal(); + } + break; + } + + if (req->IsCancelled()) { + auto result = MakeIntrusive<TResult>(req->GetRequestImpl()->Url, FETCH_CANCELLED); + req->OnResponse(result); + continue; + } + + try { + while (true) { + auto getConnectionPool = [&] () -> TSocketPool* { + if (!Options_.KeepAlive || req->GetForceReconnect()) { + return nullptr; + } + return &SocketPool_; + }; + + auto sleep = req->OnResponse( + FetchSingleImpl(req->GetRequestImpl(), getConnectionPool())); + + if (!req->IsValid()) { + break; + } + + if (sleep != TDuration::Zero()) { + c->SleepT(sleep); + } + } + } catch (...) 
{ + req->SetException(std::current_exception()); + } + } + } + } + } + + void DoExecute() override { + // Executor must be initialized in the same thread that will use it + // for fibers to work correctly on windows + TContExecutor executor(Options_.ExecutorStackSize); + + TThread::SetCurrentThreadName(Options_.Name.c_str()); + NAsyncDns::TOptions dnsOpts; + dnsOpts.SetMaxRequests(200); + NAsyncDns::TContResolver resolver(&executor, dnsOpts); + + THolder<NAsyncDns::TContDnsCache> dnsCache; + if (Options_.DnsCacheLifetime != TDuration::Zero()) { + NAsyncDns::TCacheOptions cacheOptions; + cacheOptions.SetEntryLifetime(Options_.DnsCacheLifetime); + dnsCache = MakeHolder<NAsyncDns::TContDnsCache>(&executor, cacheOptions); + } + + TCoCtxSetter ctxSetter(&executor, &resolver, dnsCache.Get()); + + for (size_t i = 0; i < FetchCoroutines; ++i) { + executor.Create<TFetcher, &TFetcher::DoFetchLoop>(this, "fetch_loop"); + } + + if (Options_.KeepAlive) { + executor.Create<TFetcher, &TFetcher::DoDrainLoop>(this, "drain_loop"); + } + + executor.Execute(); + executor.Abort(); + } + + private: + using IThreadRef = THolder<IThreadFactory::IThread>; + + const TClientOptions Options_; + const size_t FetchCoroutines; + + TContCondVar DrainCond_; + TContMutex DrainMutex_; + TSocketPool SocketPool_; + + /// Queue of incoming requests. 
+ TPipeQueue<TFetchRequestRef> RequestsQueue_; + + bool Done_; + IThreadRef T_; + }; + + } + + class TFetchClient::TImpl: public TFetcher { + public: + inline TImpl(const TClientOptions& options) + : TFetcher(options) + { + } + }; + + TFetchClient::TFetchClient(const TClientOptions& options) + : Impl_(new TImpl(options)) + { + Impl_->Start(); + } + + TFetchClient::~TFetchClient() { + Impl_->Stop(); + } + + TFetchState TFetchClient::Fetch(const TFetchQuery& query, NHttpFetcher::TCallBack cb) { + return Impl_->FetchAsync(TFetchRequest::FromQuery(query), cb); + } + + TResultRef Fetch(const TFetchQuery& query) { + return FetchAsync(query, NHttpFetcher::TCallBack()).Get(); + } + + TFetchState FetchAsync(const TFetchQuery& query, NHttpFetcher::TCallBack cb) { + return TFetcher::Instance()->FetchAsync(TFetchRequest::FromQuery(query), cb); + } + +} diff --git a/library/cpp/http/client/client.h b/library/cpp/http/client/client.h new file mode 100644 index 0000000000..717601989e --- /dev/null +++ b/library/cpp/http/client/client.h @@ -0,0 +1,59 @@ +#pragma once + +#include "query.h" +#include "request.h" + +namespace NHttp { + struct TClientOptions { +#define DECLARE_FIELD(name, type, default) \ + type name{default}; \ + inline TClientOptions& Set##name(const type& value) { \ + name = value; \ + return *this; \ + } + + /// The size of stack of fetching coroutine. + DECLARE_FIELD(ExecutorStackSize, size_t, 1 << 20); + + /// The number of fetching coroutines. + DECLARE_FIELD(FetchCoroutines, size_t, 3); + + DECLARE_FIELD(Name, TString, "GlobalFetcher"); + + /// The lifetime of entries in the DNS cache (if zero then cache is not used). + DECLARE_FIELD(DnsCacheLifetime, TDuration, TDuration::Zero()); + + /// Established connections will be keept for further usage. + DECLARE_FIELD(KeepAlive, bool, false); + + /// How long established connections should be keept. 
+ DECLARE_FIELD(KeepAliveTimeout, TDuration, TDuration::Minutes(5)); + +#undef DECLARE_FIELD + }; + + /** + * Statefull fetching client. + * Can handle multiply fetching request simultaneously. Also it's may apply + * politeness policy to control load of each host. + */ + class TFetchClient { + public: + explicit TFetchClient(const TClientOptions& options = TClientOptions()); + ~TFetchClient(); + + /// Execute give fetch request in asynchronous fashion. + TFetchState Fetch(const TFetchQuery& query, NHttpFetcher::TCallBack cb); + + private: + class TImpl; + THolder<TImpl> Impl_; + }; + + /// Execute give fetch request in synchronous fashion. + NHttpFetcher::TResultRef Fetch(const TFetchQuery& query); + + /// Execute give fetch request in asynchronous fashion. + TFetchState FetchAsync(const TFetchQuery& query, NHttpFetcher::TCallBack cb); + +} diff --git a/library/cpp/http/client/cookies/cookie.h b/library/cpp/http/client/cookies/cookie.h new file mode 100644 index 0000000000..ff2f299cfb --- /dev/null +++ b/library/cpp/http/client/cookies/cookie.h @@ -0,0 +1,20 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/generic/string.h> + +namespace NHttp { + struct TCookie { + TString Name; + TString Value; + TString Domain; + TString Path; + TString Expires; + int MaxAge = -1; + bool IsSecure = false; + bool IsHttpOnly = false; + + static TCookie Parse(const TString& header); + }; + +} diff --git a/library/cpp/http/client/cookies/cookiestore.cpp b/library/cpp/http/client/cookies/cookiestore.cpp new file mode 100644 index 0000000000..cf053e2cf6 --- /dev/null +++ b/library/cpp/http/client/cookies/cookiestore.cpp @@ -0,0 +1,280 @@ +#include "cookiestore.h" + +#include <library/cpp/deprecated/split/split_iterator.h> + +#include <util/generic/algorithm.h> +#include <util/string/ascii.h> + +#include <time.h> + +namespace NHttp { + bool TCookieStore::TStoredCookie::IsEquivalent(const TStoredCookie& rhs) const { + return (IsHostOnly == rhs.IsHostOnly) && 
(Cookie.Domain == rhs.Cookie.Domain) && + (Cookie.Path == rhs.Cookie.Path) && (Cookie.Name == rhs.Cookie.Name); + } + + TCookieStore::TCookieStore() { + } + + TCookieStore::~TCookieStore() { + } + + bool TCookieStore::SetCookie(const NUri::TUri& requestUri, const TString& cookieHeader) { + // https://tools.ietf.org/html/rfc6265#section-5.3 + TStoredCookie stored; + stored.CreateTime = Now(); + try { + stored.Cookie = TCookie::Parse(cookieHeader); + } catch (const yexception&) { + // Parse failed, ignore cookie + return false; + } + if (stored.Cookie.Domain) { + if (!DomainMatch(requestUri.GetHost(), stored.Cookie.Domain)) { + // Cookie for other domain + return false; + } + stored.IsHostOnly = false; + } else { + stored.Cookie.Domain = requestUri.GetHost(); + stored.IsHostOnly = true; + } + if (!stored.Cookie.Path) { + stored.Cookie.Path = requestUri.GetField(NUri::TField::FieldPath); + } + stored.ExpireTime = GetExpireTime(stored.Cookie); + + auto g(Guard(Lock_)); + for (auto it = Cookies_.begin(); it != Cookies_.end(); ++it) { + if (it->IsEquivalent(stored)) { + *it = stored; + return true; + } + } + Cookies_.push_back(stored); + return true; + } + + TString TCookieStore::GetCookieString(const NUri::TUri& requestUri) const { + // https://tools.ietf.org/html/rfc6265#section-5.4 + const TInstant now = Now(); + auto g(Guard(Lock_)); + + TVector<TCookieVector::const_iterator> validCookies; + validCookies.reserve(Cookies_.size()); + + // Filter cookies + for (auto it = Cookies_.begin(); it != Cookies_.end(); ++it) { + if (it->IsHostOnly) { + if (!AsciiEqualsIgnoreCase(it->Cookie.Domain, requestUri.GetHost())) { + continue; + } + } else { + if (!DomainMatch(requestUri.GetHost(), it->Cookie.Domain)) { + continue; + } + } + if (!PathMatch(requestUri.GetField(NUri::TField::FieldPath), it->Cookie.Path)) { + continue; + } + if (it->Cookie.IsSecure && requestUri.GetScheme() != NUri::TScheme::SchemeHTTPS) { + continue; + } + if (now >= it->ExpireTime) { + continue; + } + 
validCookies.push_back(it); + } + // Sort cookies + Sort(validCookies.begin(), validCookies.end(), [](const TCookieVector::const_iterator& a, const TCookieVector::const_iterator& b) { + // Cookies with longer paths are listed before cookies with shorter paths. + auto pa = a->Cookie.Path.length(); + auto pb = b->Cookie.Path.length(); + if (pa != pb) { + return pa > pb; + } + // cookies with earlier creation-times are listed before cookies with later creation-times. + if (a->CreateTime != b->CreateTime) { + return a->CreateTime < b->CreateTime; + } + return &*a < &*b; //Any order + }); + TStringStream os; + for (auto it = validCookies.begin(); it != validCookies.end(); ++it) { + if (!os.Empty()) { + os << "; "; + } + const TStoredCookie& stored = **it; + os << stored.Cookie.Name << "=" << stored.Cookie.Value; + } + return os.Str(); + } + + void TCookieStore::Clear() { + auto g(Guard(Lock_)); + Cookies_.clear(); + } + + bool TCookieStore::DomainMatch(const TStringBuf& requestDomain, const TStringBuf& cookieDomain) const { + // https://tools.ietf.org/html/rfc6265#section-5.1.3 + if (AsciiEqualsIgnoreCase(requestDomain, cookieDomain)) { + return true; + } + if (requestDomain.length() > cookieDomain.length() && + AsciiHasSuffixIgnoreCase(requestDomain, cookieDomain) && + requestDomain[requestDomain.length() - cookieDomain.length() - 1] == '.') { + return true; + } + return false; + } + + bool TCookieStore::PathMatch(const TStringBuf& requestPath, const TStringBuf& cookiePath) const { + // https://tools.ietf.org/html/rfc6265#section-5.1.4 + if (cookiePath == requestPath) { + return true; + } + if (requestPath.StartsWith(cookiePath)) { + if (!cookiePath.empty() && cookiePath.back() == '/') { + return true; + } + if (requestPath.length() > cookiePath.length() && requestPath[cookiePath.length()] == '/') { + return true; + } + } + return false; + } + + TInstant TCookieStore::GetExpireTime(const TCookie& cookie) { + // Алгоритм скопирован из blink: 
net/cookies/canonical_cookie.cc:CanonExpiration + // First, try the Max-Age attribute. + if (cookie.MaxAge >= 0) { + return TDuration::Seconds(cookie.MaxAge).ToDeadLine(); + } + // Try the Expires attribute. + if (cookie.Expires) { + static const TStringBuf kMonths[] = { + TStringBuf("jan"), TStringBuf("feb"), TStringBuf("mar"), + TStringBuf("apr"), TStringBuf("may"), TStringBuf("jun"), + TStringBuf("jul"), TStringBuf("aug"), TStringBuf("sep"), + TStringBuf("oct"), TStringBuf("nov"), TStringBuf("dec")}; + static const int kMonthsLen = Y_ARRAY_SIZE(kMonths); + // We want to be pretty liberal, and support most non-ascii and non-digit + // characters as a delimiter. We can't treat : as a delimiter, because it + // is the delimiter for hh:mm:ss, and we want to keep this field together. + // We make sure to include - and +, since they could prefix numbers. + // If the cookie attribute came in in quotes (ex expires="XXX"), the quotes + // will be preserved, and we will get them here. So we make sure to include + // quote characters, and also \ for anything that was internally escaped. 
+ static const TSplitDelimiters kDelimiters("\t !\"#$%&'()*+,-./;<=>?@[\\]^_`{|}~"); + + struct tm exploded; + Zero(exploded); + + TDelimitersSplit tokenizer(cookie.Expires.data(), cookie.Expires.size(), kDelimiters); + TDelimitersSplit::TIterator tokenizerIt = tokenizer.Iterator(); + + bool found_day_of_month = false; + bool found_month = false; + bool found_time = false; + bool found_year = false; + + while (true) { + const TStringBuf token = tokenizerIt.NextTok(); + if (!token.IsInited()) { + break; + } + if (token.empty()) { + continue; + } + + bool numerical = IsAsciiDigit(token[0]); + + // String field + if (!numerical) { + if (!found_month) { + for (int i = 0; i < kMonthsLen; ++i) { + // Match prefix, so we could match January, etc + if (AsciiHasPrefixIgnoreCase(token, kMonths[i])) { + exploded.tm_mon = i; + found_month = true; + break; + } + } + } else { + // If we've gotten here, it means we've already found and parsed our + // month, and we have another string, which we would expect to be the + // the time zone name. According to the RFC and my experiments with + // how sites format their expirations, we don't have much of a reason + // to support timezones. We don't want to ever barf on user input, + // but this DCHECK should pass for well-formed data. + // DCHECK(token == "GMT"); + } + // Numeric field w/ a colon + } else if (token.Contains(':')) { + if (!found_time && + sscanf( + token.data(), "%2u:%2u:%2u", &exploded.tm_hour, + &exploded.tm_min, &exploded.tm_sec) == 3) { + found_time = true; + } else { + // We should only ever encounter one time-like thing. If we're here, + // it means we've found a second, which shouldn't happen. We keep + // the first. This check should be ok for well-formed input: + // NOTREACHED(); + } + // Numeric field + } else { + // Overflow with atoi() is unspecified, so we enforce a max length. 
+ if (!found_day_of_month && token.size() <= 2) { + exploded.tm_mday = atoi(token.data()); + found_day_of_month = true; + } else if (!found_year && token.size() <= 5) { + exploded.tm_year = atoi(token.data()); + found_year = true; + } else { + // If we're here, it means we've either found an extra numeric field, + // or a numeric field which was too long. For well-formed input, the + // following check would be reasonable: + // NOTREACHED(); + } + } + } + + if (!found_day_of_month || !found_month || !found_time || !found_year) { + // We didn't find all of the fields we need. For well-formed input, the + // following check would be reasonable: + // NOTREACHED() << "Cookie parse expiration failed: " << time_string; + return TInstant::Max(); + } + + // Normalize the year to expand abbreviated years to the full year. + if (exploded.tm_year >= 69 && exploded.tm_year <= 99) { + exploded.tm_year += 1900; + } + if (exploded.tm_year >= 0 && exploded.tm_year <= 68) { + exploded.tm_year += 2000; + } + + // If our values are within their correct ranges, we got our time. + if (exploded.tm_mday >= 1 && exploded.tm_mday <= 31 && + exploded.tm_mon >= 0 && exploded.tm_mon <= 11 && + exploded.tm_year >= 1601 && exploded.tm_year <= 30827 && + exploded.tm_hour <= 23 && exploded.tm_min <= 59 && exploded.tm_sec <= 59) + { + exploded.tm_year -= 1900; // Adopt to tm struct + // Convert to TInstant + time_t tt = TimeGM(&exploded); + if (tt != -1) { + return TInstant::Seconds(tt); + } + } + + // One of our values was out of expected range. For well-formed input, + // the following check would be reasonable: + // NOTREACHED() << "Cookie exploded expiration failed: " << time_string; + } + // Invalid or no expiration, persistent cookie. 
+ return TInstant::Max(); + } + +} diff --git a/library/cpp/http/client/cookies/cookiestore.h b/library/cpp/http/client/cookies/cookiestore.h new file mode 100644 index 0000000000..1bdc19bfe7 --- /dev/null +++ b/library/cpp/http/client/cookies/cookiestore.h @@ -0,0 +1,54 @@ +#pragma once + +#include "cookie.h" + +#include <library/cpp/uri/uri.h> + +#include <util/generic/vector.h> +#include <util/system/mutex.h> + +namespace NHttp { + /** + * Cookie storage for values obtained from a server via Set-Cookie header. + * + * Later client may use GetCookieString to build a cookie for sending + * back to the server via Cookie header. + */ + class TCookieStore { + public: + TCookieStore(); + ~TCookieStore(); + + /// Removes all cookies from store. + void Clear(); + + /// Builds Cookie header from the given url. + TString GetCookieString(const NUri::TUri& requestUri) const; + + /// Parses cookie from the Set-Cookie header and stores it. + bool SetCookie(const NUri::TUri& requestUri, const TString& cookieHeader); + + private: + bool DomainMatch(const TStringBuf& requestDomain, const TStringBuf& cookieDomain) const; + bool PathMatch(const TStringBuf& requestPath, const TStringBuf& cookiePath) const; + static TInstant GetExpireTime(const TCookie& cookie); + + private: + struct TStoredCookie { + TCookie Cookie; + TInstant CreateTime; + TInstant ExpireTime; + bool IsHostOnly = true; + + /// Compares only Domain, Path and name. 
+ bool IsEquivalent(const TStoredCookie& rhs) const; + }; + + using TCookieVector = TVector<TStoredCookie>; + + + TMutex Lock_; + TCookieVector Cookies_; + }; + +} diff --git a/library/cpp/http/client/cookies/parser.rl6 b/library/cpp/http/client/cookies/parser.rl6 new file mode 100644 index 0000000000..700140e582 --- /dev/null +++ b/library/cpp/http/client/cookies/parser.rl6 @@ -0,0 +1,121 @@ +#include <library/cpp/http/client/cookies/cookie.h> + +#include <util/datetime/parser.h> + +namespace NHttp { +namespace { +%%{ +machine http_cookie_parser; + +include HttpDateTimeParser "../../../../../util/datetime/parser.rl6"; + +alphtype unsigned char; + +################# Actions ################# +action set_name { + result.Name = TString((const char*)S_, p - S_); +} + +action set_value { + valueSet = true; + result.Value = TString((const char*)S_, p - S_); +} + +action set_expires { + // TODO take care about server's date ? + result.Expires = TString((const char*)S_, p - S_); +} + +action set_max_age { + // TODO take care about server's date ? + result.MaxAge = I; +} + +action set_domain { + result.Domain = TString((const char*)S_, p - S_); + result.Domain.to_lower(); +} + +action set_path { + result.Path = TString((const char*)S_, p - S_); +} + +action set_secure { + result.IsSecure = true; + +} +action set_httponly { + result.IsHttpOnly = true; +} + +################# Basic Rules ################# +ws = [ \t]; + +separators = '(' | ')' | '<' | '>' | '@' | ',' | ';' | ':' | '\\' | + '"' | '/' | '[' | ']' | '?' 
| '=' | '{' | '}' | 32 | 9; + +token_octet = 32..126 - separators; +token = token_octet+; + +other_octet = 32..126 - ';'; +other = other_octet+; + +cookie_name_octet = other_octet - '='; + +############ Set-Cookie line ################# +# See https://tools.ietf.org/html/rfc6265 +cookie_name = cookie_name_octet+ > { S_ = p; } %set_name; +cookie_value = other_octet* > { S_ = p; } %set_value; +cookie_pair = cookie_name "=" cookie_value; + +expires_av = "expires="i other > { S_ = p; } %set_expires; +max_age_av = "max-age="i int %set_max_age; +domain_av = "domain="i "."? other > { S_ = p; } %set_domain; +path_av = "path="i other > { S_ = p; } %set_path; +secure_av = "secure"i %set_secure; +httponly_av = "httponly"i %set_httponly; +extension_av = other; + +cookie_av = extension_av | expires_av | max_age_av | domain_av | + path_av | secure_av | httponly_av; + +set_cookie_string = cookie_pair ( ";" ws* cookie_av )*; + +################# main ############################ +main := set_cookie_string; + +}%% + +%% write data; + +} // namespace + +/////////////////////////////////////////////////////////////////////////////// + +TCookie TCookie::Parse(const TString& data) { + TCookie result; + { + const unsigned char* S_ = nullptr; + long I = -1; + int Dc; + int cs; + + const unsigned char *p = (const unsigned char*)data.data(); + const unsigned char *pe = p + data.size(); + const unsigned char* eof = pe; + bool valueSet = false; + %% write init; + %% write exec; + if (cs == %%{ write error; }%%) { + throw yexception() << "Cookie parse error"; + } + if (!valueSet) { + throw yexception() << "Cookie value not set"; + } + } + return result; +} + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace NHttp diff --git a/library/cpp/http/client/fetch/coctx.cpp b/library/cpp/http/client/fetch/coctx.cpp new file mode 100644 index 0000000000..dfbb6943a3 --- /dev/null +++ b/library/cpp/http/client/fetch/coctx.cpp @@ -0,0 +1 @@ +#include 
"coctx.h" diff --git a/library/cpp/http/client/fetch/coctx.h b/library/cpp/http/client/fetch/coctx.h new file mode 100644 index 0000000000..bd1b61cb59 --- /dev/null +++ b/library/cpp/http/client/fetch/coctx.h @@ -0,0 +1,50 @@ +#pragma once + +#include <library/cpp/coroutine/engine/impl.h> + +#include <util/thread/singleton.h> + +namespace NAsyncDns { + class TContResolver; + class TContDnsCache; +} + +namespace NHttpFetcher { + struct TCoCtx { + TContExecutor* Executor; + NAsyncDns::TContResolver* Resolver; + NAsyncDns::TContDnsCache* DnsCache; + + TCoCtx(TContExecutor* executor, NAsyncDns::TContResolver* resolver, NAsyncDns::TContDnsCache* dnsCache = nullptr) + : Executor(executor) + , Resolver(resolver) + , DnsCache(dnsCache) + { + } + + TCont* Cont() { + return Executor->Running(); + } + }; + + inline TCoCtx*& CoCtx() { + return *FastTlsSingletonWithPriority<TCoCtx*, 0>(); + } + + class TCoCtxSetter { + public: + TCoCtxSetter(TContExecutor* executor, NAsyncDns::TContResolver* resolver, NAsyncDns::TContDnsCache* dnsCache = nullptr) + : Instance(executor, resolver, dnsCache) + { + Y_VERIFY(!CoCtx(), "coCtx already exists"); + CoCtx() = &Instance; + } + + ~TCoCtxSetter() { + CoCtx() = nullptr; + } + + private: + TCoCtx Instance; + }; +} diff --git a/library/cpp/http/client/fetch/codes.h b/library/cpp/http/client/fetch/codes.h new file mode 100644 index 0000000000..25d09c88f8 --- /dev/null +++ b/library/cpp/http/client/fetch/codes.h @@ -0,0 +1,36 @@ +#pragma once + +namespace NHttpFetcher { + const int FETCH_SUCCESS_CODE = 200; + const int SERVICE_UNAVAILABLE = 503; + const int ZORA_TIMEOUT_CODE = 5000; + const int URL_FILTER_CODE = 6000; + const int WRONG_HTTP_HEADER_CODE = 6001; + const int FETCH_LARGE_FILE = 6002; + const int FETCH_CANNOT_PARSE = 6003; + const int HOSTS_QUEUE_TIMEOUT = 6004; + const int WRONG_HTTP_RESPONSE = 6005; + const int UNKNOWN_ERROR = 6006; + const int FETCHER_QUEUE_TIMEOUT = 6007; + const int FETCH_IGNORE = 6008; + const int 
FETCH_CANCELLED = 6009; + + inline bool IsRedirectCode(int code) { + return 301 == code || 302 == code || 303 == code || + 305 == code || 307 == code || 308 == code; + } + + inline bool IsSuccessCode(int code) { + return code >= 200 && code < 300; + } + + inline bool NoRefetch(int code) { + return code == 415 || // Unsupported media type + code == 601 || // Large file + (code >= 400 && code < 500) || + code == 1003 || // disallowed by robots.txt + code == 1006 || // not found by dns server + code == 6008; // once ignored, always ignored + } + +} diff --git a/library/cpp/http/client/fetch/cosocket.h b/library/cpp/http/client/fetch/cosocket.h new file mode 100644 index 0000000000..8230d36bbe --- /dev/null +++ b/library/cpp/http/client/fetch/cosocket.h @@ -0,0 +1,97 @@ +#pragma once + +#include "coctx.h" + +#include <library/cpp/coroutine/engine/network.h> +#include <library/cpp/http/fetch_gpl/sockhandler.h> + +#include <util/system/error.h> + +namespace NHttpFetcher { + class TCoSocketHandler { + public: + TCoSocketHandler() = default; + + ~TCoSocketHandler() { + Disconnect(); + } + + int Good() const { + return (Fd != INVALID_SOCKET); + } + + int Connect(const TAddrList& addrs, TDuration timeout) { + TCont* cont = CoCtx()->Cont(); + Timeout = timeout; + for (const auto& item : addrs) { + try { + const sockaddr* sa = item->Addr(); + TSocketHolder s(NCoro::Socket(sa->sa_family, SOCK_STREAM, 0)); + if (s.Closed()) { + continue; + } + int err = NCoro::ConnectT(cont, s, sa, item->Len(), Timeout); + if (err) { + s.Close(); + errno = err; + continue; + } + SetZeroLinger(s); + SetKeepAlive(s, true); + Fd.Swap(s); + return 0; + } catch (const TSystemError&) { + } + } + return errno ? 
errno : EBADF; + } + + void Disconnect() { + if (Fd.Closed()) + return; + try { + ShutDown(Fd, SHUT_RDWR); + } catch (const TSystemError&) { + } + Fd.Close(); + } + + void shutdown() { + try { + ShutDown(Fd, SHUT_WR); + } catch (TSystemError&) { + } + } + + ssize_t send(const void* message, size_t messlen) { + TCont* cont = CoCtx()->Cont(); + TContIOStatus status = NCoro::WriteT(cont, Fd, message, messlen, Timeout); + errno = status.Status(); + return status.Status() ? -1 : (ssize_t)status.Processed(); + } + + bool peek() { + TCont* cont = CoCtx()->Cont(); + if ((errno = NCoro::PollT(cont, Fd, CONT_POLL_READ, Timeout))) + return false; + char buf[1]; +#ifdef _win32_ + return (1 == ::recv(Fd, buf, 1, MSG_PEEK)); +#else + return (1 == ::recv(Fd, buf, 1, MSG_PEEK | MSG_DONTWAIT)); +#endif + } + + ssize_t read(void* message, size_t messlen) { + TCont* cont = CoCtx()->Cont(); + TContIOStatus status = NCoro::ReadT(cont, Fd, message, messlen, Timeout); + errno = status.Status(); + return status.Status() ? 
-1 : (ssize_t)status.Processed(); + } + + protected: + TSocketHolder Fd; + TDuration Timeout; + static THolder<TIpAddress> AddrToBind; + }; +} diff --git a/library/cpp/http/client/fetch/fetch_request.cpp b/library/cpp/http/client/fetch/fetch_request.cpp new file mode 100644 index 0000000000..2f8453fc45 --- /dev/null +++ b/library/cpp/http/client/fetch/fetch_request.cpp @@ -0,0 +1,114 @@ +#include "fetch_request.h" + +#include <library/cpp/deprecated/atomic/atomic.h> + +// TRequest +namespace NHttpFetcher { + const TString DEFAULT_ACCEPT_ENCODING = "gzip, deflate"; + const size_t DEFAULT_MAX_HEADER_SIZE = 100 << 10; + const size_t DEFAULT_MAX_BODY_SIZE = 1 << 29; + + TRequest::TRequest(const TString& url, TCallBack onFetch) + : Url(url) + , Deadline(TInstant::Now() + DEFAULT_REQUEST_TIMEOUT) + , Freshness(DEFAULT_REQUEST_FRESHNESS) + , Priority(40) + , IgnoreRobotsTxt(false) + , LangRegion(ELR_RU) + , OnFetch(onFetch) + , AcceptEncoding(DEFAULT_ACCEPT_ENCODING) + , OnlyHeaders(false) + , MaxHeaderSize(DEFAULT_MAX_HEADER_SIZE) + , MaxBodySize(DEFAULT_MAX_BODY_SIZE) + { + GenerateSequence(); + } + + TRequest::TRequest(const TString& url, bool ignoreRobotsTxt, TDuration timeout, TDuration freshness, TCallBack onFetch) + : Url(url) + , Deadline(Now() + timeout) + , Freshness(freshness) + , Priority(40) + , IgnoreRobotsTxt(ignoreRobotsTxt) + , LangRegion(ELR_RU) + , OnFetch(onFetch) + , AcceptEncoding(DEFAULT_ACCEPT_ENCODING) + , OnlyHeaders(false) + , MaxHeaderSize(DEFAULT_MAX_HEADER_SIZE) + , MaxBodySize(DEFAULT_MAX_BODY_SIZE) + { + GenerateSequence(); + } + + TRequest::TRequest(const TString& url, TDuration timeout, TDuration freshness, bool ignoreRobots, + size_t priority, const TMaybe<TString>& login, const TMaybe<TString>& password, + ELangRegion langRegion, TCallBack onFetch) + : Url(url) + , Deadline(Now() + timeout) + , Freshness(freshness) + , Priority(priority) + , Login(login) + , Password(password) + , IgnoreRobotsTxt(ignoreRobots) + , LangRegion(langRegion) 
+ , OnFetch(onFetch) + , AcceptEncoding(DEFAULT_ACCEPT_ENCODING) + , OnlyHeaders(false) + , MaxHeaderSize(DEFAULT_MAX_HEADER_SIZE) + , MaxBodySize(DEFAULT_MAX_BODY_SIZE) + { + GenerateSequence(); + } + + void TRequest::GenerateSequence() { + static TAtomic nextSeq = 0; + Sequence = AtomicIncrement(nextSeq); + } + + TRequestRef TRequest::Clone() { + THolder<TRequest> request = THolder<TRequest>(new TRequest(*this)); + request->GenerateSequence(); + return request.Release(); + } + + void TRequest::Dump(IOutputStream& out) { + out << "url: " << Url << "\n"; + out << "timeout: " << (Deadline - Now()).MilliSeconds() << " ms\n"; + out << "freshness: " << Freshness.Seconds() << "\n"; + out << "priority: " << Priority << "\n"; + if (!!Login) { + out << "login: " << *Login << "\n"; + } + if (!!Password) { + out << "password: " << *Password << "\n"; + } + if (!!OAuthToken) { + out << "oauth token: " << *OAuthToken << "\n"; + } + if (IgnoreRobotsTxt) { + out << "ignore robots: " << IgnoreRobotsTxt << "\n"; + } + out << "lang reg: " << LangRegion2Str(LangRegion) << "\n"; + if (!!CustomHost) { + out << "custom host: " << *CustomHost << "\n"; + } + if (!!UserAgent) { + out << "user agent: " << *UserAgent << "\n"; + } + if (!!AcceptEncoding) { + out << "accept enc: " << *AcceptEncoding << "\n"; + } + if (OnlyHeaders) { + out << "only headers: " << OnlyHeaders << "\n"; + } + out << "max header sz: " << MaxHeaderSize << "\n"; + out << "max body sz: " << MaxBodySize << "\n"; + if (!!PostData) { + out << "post data: " << *PostData << "\n"; + } + if (!!ContentType) { + out << "content type: " << *ContentType << "\n"; + } + } + +} diff --git a/library/cpp/http/client/fetch/fetch_request.h b/library/cpp/http/client/fetch/fetch_request.h new file mode 100644 index 0000000000..169c2940d7 --- /dev/null +++ b/library/cpp/http/client/fetch/fetch_request.h @@ -0,0 +1,65 @@ +#pragma once + +#include "fetch_result.h" + +#include <kernel/langregion/langregion.h> + +#include 
<util/datetime/base.h> +#include <util/generic/ptr.h> + +namespace NHttpFetcher { + const TDuration DEFAULT_REQUEST_TIMEOUT = TDuration::Minutes(1); + const TDuration DEFAULT_REQUEST_FRESHNESS = TDuration::Seconds(10000); + + class TRequest; + using TRequestRef = TIntrusivePtr<TRequest>; + + class TRequest: public TAtomicRefCount<TRequest> { + private: + TRequest(const TRequest&) = default; + TRequest& operator=(const TRequest&) = default; + void GenerateSequence(); + + public: + TRequest(const TString& url = "", TCallBack onFetch = TCallBack()); + TRequest(const TString& url, bool ignoreRobotsTxt, TDuration timeout, + TDuration freshness, TCallBack onFetch = TCallBack()); + TRequest(const TString& url, TDuration timeout, TDuration freshness, bool ignoreRobots, + size_t priority, const TMaybe<TString>& login = TMaybe<TString>(), + const TMaybe<TString>& password = TMaybe<TString>(), + ELangRegion langRegion = ELR_RU, TCallBack onFetch = TCallBack()); + void Dump(IOutputStream& out); + TRequestRef Clone(); + + public: + TString Url; + TMaybe<TString> UnixSocketPath; + + TInstant Deadline; // [default = 1 min] + TDuration RdWrTimeout; + TMaybe<TDuration> ConnectTimeout; + TDuration Freshness; // [default = 1000 sec] + size_t Priority; // lower is more important; range [0, 100], default 40 + + TMaybe<TString> Login; + TMaybe<TString> Password; + TMaybe<TString> OAuthToken; + bool IgnoreRobotsTxt; // [default = false] + ELangRegion LangRegion; // [default = ELR_RU] + TCallBack OnFetch; // for async requests + ui64 Sequence; // unique id + TMaybe<TString> CustomHost; // Use custom host for "Host" header + TMaybe<TString> UserAgent; // custom user agen, [default = YandexNews] + TMaybe<TString> AcceptEncoding; // custom accept encoding, [default = "gzip, deflate"] + bool OnlyHeaders; // [default = false], if true - no content will be fetched (HEAD request) + size_t MaxHeaderSize; // returns 1002 error if exceeded + size_t MaxBodySize; // returns 1002 error if exceeded + 
TNeedDataCallback NeedDataCallback; // set this callback if you need to check data while fetching + // true - continue fetching, false - stop + TMaybe<TString> Method; // for http exotics like "PUT ", "PATCH ", "DELETE ". if it doesn't exist, GET or POST will be used + TMaybe<TString> PostData; // if exists - send post request + TMaybe<TString> ContentType; // custom content-type for post requests + // [default = "application/x-www-form-urlencoded"] + TVector<TString> ExtraHeaders; // needed for some servers (to auth, for ex.); don't forget to add "\r\n"! + }; +} diff --git a/library/cpp/http/client/fetch/fetch_result.cpp b/library/cpp/http/client/fetch/fetch_result.cpp new file mode 100644 index 0000000000..0ba1b1e6be --- /dev/null +++ b/library/cpp/http/client/fetch/fetch_result.cpp @@ -0,0 +1,32 @@ +#include "codes.h" +#include "fetch_result.h" + +#include <library/cpp/charset/recyr.hh> + +namespace NHttpFetcher { + TResult::TResult(const TString& url, int code) + : RequestUrl(url) + , ResolvedUrl(url) + , Code(code) + , ConnectionReused(false) + { + } + + TString TResult::DecodeData(bool* decoded) const { + if (!!Encoding && *Encoding != CODES_UTF8) { + if (decoded) { + *decoded = true; + } + return Recode(*Encoding, CODES_UTF8, Data); + } + if (decoded) { + *decoded = false; + } + return Data; + } + + bool TResult::Success() const { + return Code == FETCH_SUCCESS_CODE; + } + +} diff --git a/library/cpp/http/client/fetch/fetch_result.h b/library/cpp/http/client/fetch/fetch_result.h new file mode 100644 index 0000000000..24fe49e1f6 --- /dev/null +++ b/library/cpp/http/client/fetch/fetch_result.h @@ -0,0 +1,40 @@ +#pragma once + +#include <library/cpp/charset/doccodes.h> +#include <library/cpp/http/io/headers.h> +#include <library/cpp/langs/langs.h> + +#include <util/generic/maybe.h> +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/generic/vector.h> + +#include <functional> + +namespace NHttpFetcher { + // Result + using TResultRef = 
TIntrusivePtr<struct TResult>; + struct TResult: public TAtomicRefCount<TResult> { + TResult(const TString& url, int code = 0); + TString DecodeData(bool* decoded = nullptr) const; + bool Success() const; + + public: + TString RequestUrl; + TString ResolvedUrl; + TString Location; + int Code; + bool ConnectionReused; + TString StatusStr; + TString MimeType; + THttpHeaders Headers; + TString Data; + TMaybe<ECharset> Encoding; + TMaybe<ELanguage> Language; + TVector<TResultRef> Redirects; + TString HttpVersion; + }; + + using TCallBack = std::function<void(TResultRef)>; + using TNeedDataCallback = std::function<bool(const TString&)>; +} diff --git a/library/cpp/http/client/fetch/fetch_single.cpp b/library/cpp/http/client/fetch/fetch_single.cpp new file mode 100644 index 0000000000..27d0888b80 --- /dev/null +++ b/library/cpp/http/client/fetch/fetch_single.cpp @@ -0,0 +1,205 @@ +#include "codes.h" +#include "fetch_single.h" +#include "parse.h" + +#include <library/cpp/string_utils/base64/base64.h> + +#include <util/string/ascii.h> +#include <util/string/cast.h> +#include <library/cpp/string_utils/url/url.h> + +namespace NHttpFetcher { + static sockaddr_in6 IPv6Loopback() { + sockaddr_in6 sock = {}; + sock.sin6_family = AF_INET6; + sock.sin6_addr = IN6ADDR_LOOPBACK_INIT; + return sock; + } + + class THeaders { + public: + inline void Build(const TRequestRef& request) { + if (!!request->Password || !!request->Login) { + TString pass; + TString login; + TString raw; + TString encoded; + + if (!!request->Password) { + pass = *request->Password; + } + if (!!request->Login) { + login = *request->Login; + } + raw = TString::Join(login, ":", pass); + Base64Encode(raw, encoded); + BasicAuth_ = TString::Join("Authorization: Basic ", encoded, "\r\n"); + Headers_.push_back(BasicAuth_.c_str()); + } + + if (request->ExtraHeaders) { + Headers_.reserve(Headers_.size() + request->ExtraHeaders.size()); + for (const TString& header : request->ExtraHeaders) { + if 
(AsciiHasSuffixIgnoreCase(header, "\r\n")) { + Headers_.push_back(header.c_str()); + } else { + Fixed_.push_back(header + "\r\n"); + Headers_.push_back(Fixed_.back().c_str()); + } + } + } + + if (!!request->OAuthToken) { + OAuth_ = TString::Join("Authorization: OAuth ", *request->OAuthToken, "\r\n"); + Headers_.push_back(OAuth_.c_str()); + } + + if (!!request->AcceptEncoding) { + AcceptEncoding_ = TString::Join("Accept-Encoding: ", *request->AcceptEncoding, "\r\n"); + Headers_.push_back(AcceptEncoding_.c_str()); + } + + ContentType_ = "application/x-www-form-urlencoded"; + if (!!request->PostData) { + if (!!request->ContentType) { + ContentType_ = *request->ContentType; + } + ContentLength_ = TString::Join("Content-Length: ", ToString(request->PostData->size()), "\r\n"); + ContentType_ = TString::Join("Content-Type: ", ContentType_, "\r\n"); + Headers_.push_back(ContentLength_.c_str()); + Headers_.push_back(ContentType_.c_str()); + } + + Headers_.push_back((const char*)nullptr); + } + + inline const char* const* Data() { + return Headers_.data(); + } + + private: + TVector<const char*> Headers_; + TVector<TString> Fixed_; + TString BasicAuth_; + TString OAuth_; + TString AcceptEncoding_; + TString ContentLength_; + TString ContentType_; + }; + + TResultRef FetchSingleImpl(TRequestRef request, TSocketPool* pool) { + Y_ASSERT(!!request->Url && "no url passed in fetch request"); + TResultRef result(new TResult(request->Url)); + try { + TSimpleFetcherFetcher fetcher; + THttpHeader header; + + THttpURL::EKind kind = THttpURL::SchemeHTTP; + ui16 port = 80; + TString host; + ParseUrl(request->Url, kind, host, port); + + TString path = ToString(GetPathAndQuery(request->Url)); + + bool defaultPort = (kind == THttpURL::SchemeHTTP && port == 80) || + (kind == THttpURL::SchemeHTTPS && port == 443); + + if (request->UnixSocketPath && !request->UnixSocketPath->empty()) { + TAddrList addrs; + addrs.emplace_back(new NAddr::TUnixSocketAddr{*request->UnixSocketPath}); + 
fetcher.SetHost(host.data(), port, addrs, kind); + } else if (host == "127.0.0.1") { + // todo: correctly handle /etc/hosts records && ip-addresses + + // bypass normal DNS resolving for localhost + TAddrList addrs({new NAddr::TIPv4Addr(TIpAddress(0x0100007F, port))}); + fetcher.SetHost(host.data(), port, addrs, kind); + } else if (host == "localhost") { + sockaddr_in6 ipv6Addr = IPv6Loopback(); + ipv6Addr.sin6_port = HostToInet(port); + + TAddrList addrs({new NAddr::TIPv6Addr(ipv6Addr), new NAddr::TIPv4Addr(TIpAddress(0x0100007F, port))}); + fetcher.SetHost(host.data(), port, addrs, kind); + } else { + Y_ASSERT(!!host && "no host detected in url passed"); + fetcher.SetHost(host.data(), port, kind); + } + header.Init(); + + THeaders headers; + headers.Build(request); + + TString hostHeader = (!!request->CustomHost ? *request->CustomHost : host) + + (defaultPort ? "" : ":" + ToString(port)); + fetcher.SetHostHeader(hostHeader.data()); + + if (!!request->UserAgent) { + fetcher.SetIdentification((*request->UserAgent).data(), nullptr); + } else { + fetcher.SetIdentification("Mozilla/5.0 (compatible; YandexNews/3.0; +http://yandex.com/bots)", nullptr); + } + + if (!!request->Method) { + fetcher.SetMethod(request->Method->data(), request->Method->size()); + } + + if (!!request->PostData) { + fetcher.SetPostData(request->PostData->data(), request->PostData->size()); + } + + fetcher.SetMaxBodySize(request->MaxBodySize); + fetcher.SetMaxHeaderSize(request->MaxHeaderSize); + + if (request->ConnectTimeout) { + fetcher.SetConnectTimeout(*request->ConnectTimeout); + } + + { + const TDuration rest = request->Deadline - Now(); + fetcher.SetTimeout(request->RdWrTimeout != TDuration::Zero() + ? 
::Min(request->RdWrTimeout, rest) + : rest); + } + fetcher.SetNeedDataCallback(request->NeedDataCallback); + bool persistent = !request->OnlyHeaders; + void* handle = nullptr; + + if (pool) { + while (auto socket = pool->GetSocket(host, port)) { + if (socket->Good()) { + handle = socket.Get(); + + fetcher.SetSocket(socket.Release()); + fetcher.SetPersistent(true); + break; + } + } + } + + int fetchResult = fetcher.Fetch(&header, path.data(), headers.Data(), persistent, request->OnlyHeaders); + + if (!fetcher.Data.empty()) { + TStringInput httpIn(fetcher.Data.Str()); + ParseHttpResponse(*result, httpIn, kind, host, port); + } + + if (fetchResult < 0 || header.error != 0) { + result->Code = header.error; + } + + if (pool && persistent && !header.error && !header.connection_closed) { + THolder<TSocketPool::TSocketHandle> socket(fetcher.PickOutSocket()); + + if (!!socket && socket->Good()) { + if (handle == socket.Get()) { + result->ConnectionReused = true; + } + pool->ReturnSocket(host, port, std::move(socket)); + } + } + } catch (...) 
{ + result->Code = UNKNOWN_ERROR; + } + return result; + } +} diff --git a/library/cpp/http/client/fetch/fetch_single.h b/library/cpp/http/client/fetch/fetch_single.h new file mode 100644 index 0000000000..890c925cc9 --- /dev/null +++ b/library/cpp/http/client/fetch/fetch_single.h @@ -0,0 +1,88 @@ +#pragma once + +#include "cosocket.h" +#include "fetch_request.h" +#include "fetch_result.h" +#include "pool.h" + +#include <library/cpp/coroutine/dns/helpers.h> +#include <library/cpp/http/client/ssl/sslsock.h> +#include <library/cpp/http/fetch_gpl/httpagent.h> +#include <library/cpp/http/fetch/httpfetcher.h> +#include <library/cpp/http/fetch/httpheader.h> + +#include <util/generic/algorithm.h> + +namespace NHttpFetcher { + class TCoIpResolver { + public: + TAddrList Resolve(const char* host, TIpPort port) const { + NAsyncDns::TAddrs addrs; + try { + NAsyncDns::ResolveAddr(*CoCtx()->Resolver, host, port, addrs, CoCtx()->DnsCache); + } catch (...) { + return TAddrList(); + } + + // prefer IPv6 + SortBy(addrs.begin(), addrs.end(), [](const auto& addr) { + return addr->Addr()->sa_family == AF_INET6 ? 
0 : 1; + }); + + return TAddrList(addrs.begin(), addrs.end()); + } + }; + + struct TStringSaver { + int Write(const void* buf, size_t len) { + Data.Write(buf, len); + return 0; + } + TStringStream Data; + }; + + struct TSimpleCheck { + inline bool Check(THttpHeader*) { + return false; + } + void CheckDocPart(void* data, size_t size, THttpHeader*) { + if (!!NeedDataCallback) { + CheckData += TString(static_cast<const char*>(data), size); + if (!NeedDataCallback(CheckData)) { + BodyMax = 0; + } + } + } + void CheckEndDoc(THttpHeader*) { + } + size_t GetMaxHeaderSize() { + return HeaderMax; + } + size_t GetMaxBodySize(THttpHeader*) { + return BodyMax; + } + void SetMaxHeaderSize(size_t headerMax) { + HeaderMax = headerMax; + } + void SetMaxBodySize(size_t bodyMax) { + BodyMax = bodyMax; + } + void SetNeedDataCallback(const TNeedDataCallback& callback) { + NeedDataCallback = callback; + } + + private: + size_t HeaderMax; + size_t BodyMax; + TNeedDataCallback NeedDataCallback; + TString CheckData; + }; + + using TSimpleHttpAgent = THttpsAgent<TCoSocketHandler, TCoIpResolver, + TSslSocketBase::TFakeLogger, TNoTimer, + NHttpFetcher::TSslSocketHandler>; + using TSimpleFetcherFetcher = THttpFetcher<TFakeAlloc<>, TSimpleCheck, TStringSaver, TSimpleHttpAgent>; + + //! Private method of fetcher library. Don't use it in your code. 
+ TResultRef FetchSingleImpl(TRequestRef request, TSocketPool* pool = nullptr); +} diff --git a/library/cpp/http/client/fetch/parse.cpp b/library/cpp/http/client/fetch/parse.cpp new file mode 100644 index 0000000000..62d610102b --- /dev/null +++ b/library/cpp/http/client/fetch/parse.cpp @@ -0,0 +1,160 @@ +#include "codes.h" +#include "parse.h" + +#include <library/cpp/charset/codepage.h> +#include <library/cpp/http/io/stream.h> +#include <library/cpp/mime/types/mime.h> +#include <library/cpp/uri/uri.h> + +#include <library/cpp/string_utils/url/url.h> +#include <util/string/vector.h> + +namespace NHttpFetcher { + namespace { + static TString MimeTypeFromUrl(const NUri::TUri& httpUrl) { + TStringBuf path = httpUrl.GetField(NUri::TField::FieldPath); + size_t pos = path.find_last_of('.'); + if (pos == TStringBuf::npos) { + return ""; + } + // TODO (stanly) replace TString with TStringBuf + TString ext = TString(path.substr(pos + 1)); + TString mime = mimetypeByExt(path.data()); + if (mime) { + return mime; + } + + if (ext == "jpg" || ext == "jpeg" || ext == "png" || ext == "gif") { + return "image/" + ext; + } else if (ext == "m4v" || ext == "mp4" || ext == "flv" || ext == "mpeg") { + return "video/" + ext; + } else if (ext == "mp3" || ext == "wav" || ext == "ogg") { + return "audio/" + ext; + } else if (ext == "zip" || ext == "doc" || ext == "docx" || ext == "xls" || ext == "xlsx" || ext == "pdf" || ext == "ppt") { + return "application/" + ext; + } else if (ext == "rar" || ext == "7z") { + return "application/x-" + ext + "-compressed"; + } else if (ext == "exe") { + return "application/octet-stream"; + } + + return ""; + } + + static TString MimeTypeFromUrl(const TString& url) { + static const ui64 flags = NUri::TFeature::FeaturesRobot | NUri::TFeature::FeatureToLower; + + NUri::TUri httpUrl; + if (httpUrl.Parse(url, flags) != NUri::TUri::ParsedOK) { + return ""; + } + + return MimeTypeFromUrl(httpUrl); + } + + // Extracts encoding & content-type from headers + 
static void ProcessHeaders(TResult& result) { + for (THttpHeaders::TConstIterator it = result.Headers.Begin(); it != result.Headers.End(); it++) { + TString name = it->Name(); + name.to_lower(); + if (name == "content-type") { + TString value = it->Value(); + value.to_lower(); + size_t delimPos = value.find(';'); + if (delimPos == TString::npos) { + delimPos = value.size(); + } + result.MimeType = value.substr(0, delimPos); + size_t charsetPos = value.find("charset="); + if (charsetPos == TString::npos) { + continue; + } + delimPos = value.find(';', charsetPos + 1); + TString charsetStr = value.substr(charsetPos + 8, + delimPos == TString::npos ? delimPos : delimPos - charsetPos - 8); + ECharset charset = CharsetByName(charsetStr.data()); + if (charset != CODES_UNSUPPORTED && charset != CODES_UNKNOWN) { + result.Encoding = charset; + } + } + } + + if (result.MimeType.empty() || result.MimeType == "application/octet-stream") { + const TString& detectedMimeType = MimeTypeFromUrl(result.ResolvedUrl); + if (detectedMimeType) { + result.MimeType = detectedMimeType; + } + } + } + + } + + void ParseHttpResponse(TResult& result, IInputStream& is, THttpURL::EKind kind, + TStringBuf host, ui16 port) { + THttpInput httpIn(&is); + TString firstLine = httpIn.FirstLine(); + TVector<TString> params = SplitString(firstLine, " "); + try { + if (params.size() < 2) { + ythrow yexception() << "failed to parse first line"; + } + result.HttpVersion = params[0]; + result.Code = FromString(params[1]); + } catch (const std::exception&) { + result.Code = WRONG_HTTP_HEADER_CODE; + } + for (auto it = httpIn.Headers().Begin(); it < httpIn.Headers().End(); ++it) { + const THttpInputHeader& header = *it; + TString name = header.Name(); + name.to_lower(); + if (name == "location" && IsRedirectCode(result.Code)) { + // TODO (stanly) use correct routine to parse location + result.Location = header.Value(); + result.ResolvedUrl = header.Value(); + if (result.ResolvedUrl.StartsWith('/')) { + const 
bool defaultPort = + (kind == THttpURL::SchemeHTTP && port == 80) || + (kind == THttpURL::SchemeHTTPS && port == 443); + + result.ResolvedUrl = TString(NUri::SchemeKindToString(kind)) + "://" + host + + (defaultPort ? "" : ":" + ToString(port)) + + result.ResolvedUrl; + } + } + } + try { + result.Headers = httpIn.Headers(); + result.Data = httpIn.ReadAll(); + ProcessHeaders(result); + // TODO (stanly) try to detect mime-type by content + } catch (const yexception& /* exception */) { + result.Code = WRONG_HTTP_RESPONSE; + } + } + + void ParseHttpResponse(TResult& result, IInputStream& stream, const TString& url) { + THttpURL::EKind kind; + TString host; + ui16 port; + ParseUrl(url, kind, host, port); + ParseHttpResponse(result, stream, kind, host, port); + } + + void ParseUrl(const TStringBuf url, THttpURL::EKind& kind, TString& host, ui16& port) { + using namespace NUri; + + static const int URI_PARSE_FLAGS = + TFeature::FeatureSchemeKnown | TFeature::FeatureConvertHostIDN | TFeature::FeatureEncodeExtendedDelim | TFeature::FeatureEncodePercent; + + TUri uri; + // Cut out url's path to speedup processing. 
+ if (uri.Parse(GetSchemeHostAndPort(url, false, false), URI_PARSE_FLAGS) != TUri::ParsedOK) { + ythrow yexception() << "can't parse url: " << url; + } + + kind = uri.GetScheme(); + host = uri.GetField(TField::FieldHost); + port = uri.GetPort(); + } + +} diff --git a/library/cpp/http/client/fetch/parse.h b/library/cpp/http/client/fetch/parse.h new file mode 100644 index 0000000000..dacfa9bf84 --- /dev/null +++ b/library/cpp/http/client/fetch/parse.h @@ -0,0 +1,14 @@ +#pragma once + +#include "fetch_result.h" +#include <library/cpp/uri/http_url.h> + +namespace NHttpFetcher { + void ParseUrl(const TStringBuf url, THttpURL::EKind& kind, TString& host, ui16& port); + + void ParseHttpResponse(TResult& result, IInputStream& stream, THttpURL::EKind kind, + TStringBuf host, ui16 port); + + void ParseHttpResponse(TResult& result, IInputStream& stream, const TString& url); + +} diff --git a/library/cpp/http/client/fetch/pool.cpp b/library/cpp/http/client/fetch/pool.cpp new file mode 100644 index 0000000000..f0a142eced --- /dev/null +++ b/library/cpp/http/client/fetch/pool.cpp @@ -0,0 +1,57 @@ +#include "pool.h" + +namespace NHttpFetcher { + void TSocketPool::Clear() { + TSocketMap sockets; + + { + auto g(Guard(Lock_)); + Sockets_.swap(sockets); + } + } + + void TSocketPool::Drain(const TDuration timeout) { + const TInstant now = TInstant::Now(); + TVector<THolder<TSocketHandle>> sockets; + + { + auto g(Guard(Lock_)); + for (auto si = Sockets_.begin(); si != Sockets_.end();) { + if (si->second.Touched + timeout < now) { + sockets.push_back(std::move(si->second.Socket)); + Sockets_.erase(si++); + } else { + ++si; + } + } + } + } + + THolder<TSocketPool::TSocketHandle> TSocketPool::GetSocket(const TString& host, const TIpPort port) { + THolder<TSocketPool::TSocketHandle> socket; + + { + auto g(Guard(Lock_)); + auto si = Sockets_.find(std::make_pair(host, port)); + if (si != Sockets_.end()) { + socket = std::move(si->second.Socket); + Sockets_.erase(si); + } + } + + return 
socket; + } + + void TSocketPool::ReturnSocket(const TString& host, const TIpPort port, THolder<TSocketHandle> socket) { + TConnection conn; + + conn.Socket = std::move(socket); + conn.Touched = TInstant::Now(); + + { + auto g(Guard(Lock_)); + Sockets_.emplace(std::make_pair(host, port), std::move(conn)); + } + } + +} diff --git a/library/cpp/http/client/fetch/pool.h b/library/cpp/http/client/fetch/pool.h new file mode 100644 index 0000000000..73c3eda0c6 --- /dev/null +++ b/library/cpp/http/client/fetch/pool.h @@ -0,0 +1,41 @@ +#pragma once + +#include "cosocket.h" + +#include <library/cpp/http/client/ssl/sslsock.h> + +#include <util/generic/hash_multi_map.h> +#include <util/generic/ptr.h> +#include <util/system/mutex.h> + +namespace NHttpFetcher { + class TSocketPool { + public: + using TSocketHandle = TSslSocketHandler<TCoSocketHandler, TSslSocketBase::TFakeLogger>; + + public: + /// Closes all sockets. + void Clear(); + + /// Closes all sockets that have been opened too long. + void Drain(const TDuration timeout); + + /// Returns socket for the given endpoint if available. + THolder<TSocketHandle> GetSocket(const TString& host, const TIpPort port); + + /// Puts socket to the pool. 
+ void ReturnSocket(const TString& host, const TIpPort port, THolder<TSocketHandle> socket); + + private: + struct TConnection { + THolder<TSocketHandle> Socket; + TInstant Touched; + }; + + using TSocketMap = THashMultiMap<std::pair<TString, TIpPort>, TConnection>; + + TMutex Lock_; + TSocketMap Sockets_; + }; + +} diff --git a/library/cpp/http/client/query.cpp b/library/cpp/http/client/query.cpp new file mode 100644 index 0000000000..36a946074b --- /dev/null +++ b/library/cpp/http/client/query.cpp @@ -0,0 +1,92 @@ +#include "query.h" +#include "request.h" + +namespace NHttp { + TFetchQuery::TFetchQuery(const TString& url, + const TFetchOptions& options) + : Url_(url) + , Options_(options) + { + } + + TFetchQuery::TFetchQuery(const TString& url, + const TVector<TString>& headers, + const TFetchOptions& options) + : Url_(url) + , Headers_(headers) + , Options_(options) + { + } + + TFetchQuery::~TFetchQuery() = default; + + TString TFetchQuery::GetUrl() const { + return Url_; + } + + TFetchQuery& TFetchQuery::OnFail(TOnFail cb) { + OnFailCb_ = cb; + return *this; + } + + TFetchQuery& TFetchQuery::OnRedirect(TOnRedirect cb) { + OnRedirectCb_ = cb; + return *this; + } + + TFetchQuery& TFetchQuery::OnPartialRead(NHttpFetcher::TNeedDataCallback cb) { + OnPartialReadCb_ = cb; + return *this; + } + + TFetchRequestRef TFetchQuery::ConstructRequest() const { + TFetchRequestRef request = new TFetchRequest(Url_, Headers_, Options_); + if (OnFailCb_) { + request->SetOnFail(*OnFailCb_); + } + + if (OnRedirectCb_) { + request->SetOnRedirect(*OnRedirectCb_); + } + + if (OnPartialReadCb_) { + request->SetOnPartialRead(*OnPartialReadCb_); + } + + return request; + } + + TFetchState::TFetchState() { + } + + TFetchState::TFetchState(const TFetchRequestRef& req) + : Request_(req) + { + } + + void TFetchState::Cancel() const { + if (Request_) { + Request_->Cancel(); + } + } + + NHttpFetcher::TResultRef TFetchState::Get() const { + if (Request_) { + WaitI(); + return 
Request_->MakeResult(); + } + return NHttpFetcher::TResultRef(); + } + + void TFetchState::WaitI() const { + WaitT(TDuration::Max()); + } + + bool TFetchState::WaitT(TDuration timeout) const { + if (Request_) { + return Request_->WaitT(timeout); + } + return false; + } + +} diff --git a/library/cpp/http/client/query.h b/library/cpp/http/client/query.h new file mode 100644 index 0000000000..e3dbc3b7be --- /dev/null +++ b/library/cpp/http/client/query.h @@ -0,0 +1,162 @@ +#pragma once + +#include <library/cpp/http/client/fetch/fetch_result.h> + +#include <util/datetime/base.h> +#include <util/generic/maybe.h> +#include <util/generic/string.h> + +#include <functional> + +namespace NHttp { + /** + * Various options for fetching a document. + */ + struct TFetchOptions { +#define DECLARE_FIELD(name, type, default) \ + type name{default}; \ + inline TFetchOptions& Set##name(const type& value) { \ + name = value; \ + return *this; \ + } + + /// Set a timeout for connection establishment + DECLARE_FIELD(ConnectTimeout, TMaybe<TDuration>, Nothing()); + + /// Total request timeout for each attempt + DECLARE_FIELD(Timeout, TDuration, TDuration::Minutes(1)); + + /// Number of additional attempts before an error code is returned. + DECLARE_FIELD(RetryCount, ui32, 0); + + /// Sleep delay before next retry + DECLARE_FIELD(RetryDelay, TDuration, TDuration::Seconds(5)); + + /// Parse cookies from the server's response and attach them to further + /// requests in case of redirects. + DECLARE_FIELD(UseCookie, bool, true); + + /// Maximum number of redirects to follow, to prevent hanging on an infinite + /// redirect sequence. + DECLARE_FIELD(RedirectDepth, ui32, 15); + + /// If true - no content will be fetched (HEAD request). + DECLARE_FIELD(OnlyHeaders, bool, false); + + /// Use custom host for "Host" header. + DECLARE_FIELD(CustomHost, TMaybe<TString>, Nothing()); + + /// Login for basic HTTP authorization. 
+ DECLARE_FIELD(Login, TMaybe<TString>, Nothing()); + + /// Password for basic HTTP authorization. + DECLARE_FIELD(Password, TMaybe<TString>, Nothing()); + + /// OAuth token. + DECLARE_FIELD(OAuthToken, TMaybe<TString>, Nothing()); + + /// For http exotics like "PUT ", "PATCH ", "DELETE ". If doesn't exist, + /// GET or POST will be used. + DECLARE_FIELD(Method, TMaybe<TString>, Nothing()); + + /// If exists - send POST request. + DECLARE_FIELD(PostData, TMaybe<TString>, Nothing()); + + /// Custom content-type for POST requests. + DECLARE_FIELD(ContentType, TMaybe<TString>, Nothing()); + + /// Custom value for "UserAgent" header. + DECLARE_FIELD(UserAgent, TString, "Python-urllib/2.6"); + + /// Always establish new connection to the target host. + DECLARE_FIELD(ForceReconnect, bool, false); + + /// Set max header size if needed + DECLARE_FIELD(MaxHeaderSize, TMaybe<size_t>, Nothing()); + + /// Set max body size if needed + DECLARE_FIELD(MaxBodySize, TMaybe<size_t>, Nothing()); + + /// If set, unix domain socket will be used as a backend + DECLARE_FIELD(UnixSocketPath, TMaybe<TString>, Nothing()); + +#undef DECLARE_FIELD + }; + + using TFetchRequestRef = TIntrusivePtr<class TFetchRequest>; + + /// If handler will return true then try again else stop fetching. + using TOnFail = std::function<bool(const NHttpFetcher::TResultRef& reply)>; + + /// If handler will return true then follow redirect else stop fetching. + using TOnRedirect = std::function<bool(const TString& from, const TString& location)>; + + class TFetchQuery { + public: + TFetchQuery() = default; + TFetchQuery(const TString& url, + const TFetchOptions& options = TFetchOptions()); + TFetchQuery(const TString& url, + const TVector<TString>& headers, + const TFetchOptions& options = TFetchOptions()); + ~TFetchQuery(); + + /// Returns original request's url. + TString GetUrl() const; + + /// Set handler for fail's handling. + TFetchQuery& OnFail(TOnFail cb); + + /// Set handler for redirect's handling. 
+ TFetchQuery& OnRedirect(TOnRedirect cb); + + /// Set handler which will be called periodically while reading data from peer + /// with ALL the data accumulated so far. One is able to cancel further data exchange + /// by returning false from the callback. + TFetchQuery& OnPartialRead(NHttpFetcher::TNeedDataCallback cb); + + TFetchRequestRef ConstructRequest() const; + + private: + friend class TFetchRequest; + + TString Url_; + TVector<TString> Headers_; + TFetchOptions Options_; + TMaybe<TOnFail> OnFailCb_; + TMaybe<TOnRedirect> OnRedirectCb_; + TMaybe<NHttpFetcher::TNeedDataCallback> OnPartialReadCb_; + }; + + class TFetchState { + public: + TFetchState(); + TFetchState(const TFetchRequestRef& req); + TFetchState(const TFetchState&) = default; + TFetchState(TFetchState&&) = default; + + ~TFetchState() = default; + + TFetchState& operator=(TFetchState&&) = default; + TFetchState& operator=(const TFetchState&) = default; + + /// Cancel the request. + void Cancel() const; + + /// Wait for completion of the request and return result. + NHttpFetcher::TResultRef Get() const; + + /// Waits for completion of the request. + void WaitI() const; + + /// Waits for completion of the request no more than given timeout. + /// + /// \return true if fetching is finished. + /// \return false if timed out. 
+ bool WaitT(TDuration timeout) const; + + private: + TFetchRequestRef Request_; + }; + +} diff --git a/library/cpp/http/client/request.cpp b/library/cpp/http/client/request.cpp new file mode 100644 index 0000000000..08cf73da7b --- /dev/null +++ b/library/cpp/http/client/request.cpp @@ -0,0 +1,249 @@ +#include "request.h" + +#include <library/cpp/http/client/fetch/codes.h> +#include <library/cpp/uri/location.h> + +#include <util/string/ascii.h> + +namespace NHttp { + static const ui64 URI_PARSE_FLAGS = + (NUri::TFeature::FeaturesRecommended | NUri::TFeature::FeatureConvertHostIDN | NUri::TFeature::FeatureEncodeExtendedDelim | NUri::TFeature::FeatureEncodePercent) & ~NUri::TFeature::FeatureHashBangToEscapedFragment; + + /// Generates sequence of unique identifiers of requests. + static TAtomic RequestCounter = 0; + + TFetchRequest::TRedirects::TRedirects(bool parseCookie) { + if (parseCookie) { + CookieStore.Reset(new NHttp::TCookieStore); + } + } + + size_t TFetchRequest::TRedirects::Level() const { + return this->size(); + } + + void TFetchRequest::TRedirects::ParseCookies(const TString& url, + const THttpHeaders& headers) { + if (CookieStore) { + NUri::TUri uri; + if (uri.Parse(url, URI_PARSE_FLAGS) != NUri::TUri::ParsedOK) { + return; + } + if (!uri.IsValidGlobal()) { + return; + } + + for (THttpHeaders::TConstIterator it = headers.Begin(); it != headers.End(); it++) { + if (AsciiEqualsIgnoreCase(it->Name(), TStringBuf("Set-Cookie"))) { + CookieStore->SetCookie(uri, it->Value()); + } + } + } + } + + TFetchRequest::TFetchRequest(const TString& url, const TFetchOptions& options) + : Url_(url) + , Options_(options) + , Id_(AtomicIncrement(RequestCounter)) + , RetryAttempts_(options.RetryCount) + , RetryDelay_(options.RetryDelay) + , Cancel_(false) + , CurrentUrl_(url) + { + } + + TFetchRequest::TFetchRequest(const TString& url, TVector<TString> headers, const TFetchOptions& options) + : TFetchRequest(url, options) + { + Headers_ = std::move(headers); + } + + void 
TFetchRequest::Cancel() { + AtomicSet(Cancel_, 1); + } + + NHttpFetcher::TRequestRef TFetchRequest::GetRequestImpl() const { + NHttpFetcher::TRequestRef req(new NHttpFetcher::TRequest(CurrentUrl_)); + + req->UnixSocketPath = Options_.UnixSocketPath; + req->Login = Options_.Login; + req->Password = Options_.Password; + req->OAuthToken = Options_.OAuthToken; + req->OnlyHeaders = Options_.OnlyHeaders; + req->CustomHost = Options_.CustomHost; + req->Method = Options_.Method; + req->OAuthToken = Options_.OAuthToken; + req->ContentType = Options_.ContentType; + req->PostData = Options_.PostData; + req->UserAgent = Options_.UserAgent; + req->ExtraHeaders.assign(Headers_.begin(), Headers_.end()); + req->NeedDataCallback = OnPartialRead_; + req->Deadline = TInstant::Now() + Options_.Timeout; + + if (Options_.ConnectTimeout) { + req->ConnectTimeout = Options_.ConnectTimeout; + } + + if (Options_.MaxHeaderSize) { + req->MaxHeaderSize = Options_.MaxHeaderSize.GetRef(); + } + if (Options_.MaxBodySize) { + req->MaxBodySize = Options_.MaxBodySize.GetRef(); + } + + if (Redirects_ && Redirects_->CookieStore) { + NUri::TUri uri; + if (uri.Parse(CurrentUrl_, URI_PARSE_FLAGS) == NUri::TUri::ParsedOK) { + if (TString cookies = Redirects_->CookieStore->GetCookieString(uri)) { + req->ExtraHeaders.push_back("Cookie: " + cookies + "\r\n"); + } + } + } + + return req; + } + + bool TFetchRequest::IsValid() const { + auto g(Guard(Lock_)); + return IsValidNoLock(); + } + + bool TFetchRequest::IsCancelled() const { + return AtomicGet(Cancel_); + } + + bool TFetchRequest::GetForceReconnect() const { + return Options_.ForceReconnect; + } + + NHttpFetcher::TResultRef TFetchRequest::MakeResult() const { + if (Exception_) { + std::rethrow_exception(Exception_); + } + + return Result_; + } + + void TFetchRequest::SetException(std::exception_ptr ptr) { + Exception_ = ptr; + } + + void TFetchRequest::SetCallback(NHttpFetcher::TCallBack cb) { + Cb_ = cb; + } + + void TFetchRequest::SetOnFail(TOnFail cb) 
{ + OnFail_ = cb; + } + + void TFetchRequest::SetOnRedirect(TOnRedirect cb) { + OnRedirect_ = cb; + } + + void TFetchRequest::SetOnPartialRead(NHttpFetcher::TNeedDataCallback cb) { + OnPartialRead_ = cb; + } + + bool TFetchRequest::WaitT(TDuration timeout) { + TCondVar c; + + { + auto g(Guard(Lock_)); + + if (IsValidNoLock()) { + THolder<TWaitState> state(new TWaitState(&c)); + Awaitings_.PushBack(state.Get()); + if (!c.WaitT(Lock_, timeout)) { + AtomicSet(Cancel_, 1); + // Remove the element from the wait queue if + // the specified timeout has expired. + state->Unlink(); + return false; + } + } + } + + return true; + } + + bool TFetchRequest::IsValidNoLock() const { + return !Result_ && !Exception_ && !AtomicGet(Cancel_); + } + + void TFetchRequest::Reply(NHttpFetcher::TResultRef result) { + NHttpFetcher::TCallBack cb; + + { + auto g(Guard(Lock_)); + + cb.swap(Cb_); + Result_.Swap(result); + + if (Redirects_) { + Result_->Redirects.assign(Redirects_->begin(), Redirects_->end()); + } + } + + if (cb) { + try { + cb(Result_); + } catch (...) 
{ + SetException(std::current_exception()); + } + } + + { + auto g(Guard(Lock_)); + while (!Awaitings_.Empty()) { + Awaitings_.PopFront()->Signal(); + } + } + } + + TDuration TFetchRequest::OnResponse(NHttpFetcher::TResultRef result) { + if (AtomicGet(Cancel_)) { + goto finish; + } + + if (NHttpFetcher::IsRedirectCode(result->Code)) { + auto location = NUri::ResolveRedirectLocation(CurrentUrl_, result->Location); + + if (!Redirects_) { + Redirects_.Reset(new TRedirects(true)); + } + result->ResolvedUrl = location; + Redirects_->push_back(result); + + if (Options_.UseCookie) { + Redirects_->ParseCookies(CurrentUrl_, result->Headers); + } + + if (Redirects_->Level() < Options_.RedirectDepth) { + if (OnRedirect_) { + if (!OnRedirect_(CurrentUrl_, location)) { + goto finish; + } + } + + CurrentUrl_ = location; + return TDuration::Zero(); + } + } else if (!NHttpFetcher::IsSuccessCode(result->Code)) { + bool again = RetryAttempts_ > 0; + + if (OnFail_ && !OnFail_(result)) { + again = false; + } + if (again) { + RetryAttempts_--; + return RetryDelay_; + } + } + + finish: + Reply(result); + + return TDuration::Zero(); + } + +} diff --git a/library/cpp/http/client/request.h b/library/cpp/http/client/request.h new file mode 100644 index 0000000000..6dcf39ae9c --- /dev/null +++ b/library/cpp/http/client/request.h @@ -0,0 +1,133 @@ +#pragma once + +#include "query.h" + +#include <library/cpp/deprecated/atomic/atomic.h> + +#include <library/cpp/http/client/fetch/fetch_request.h> +#include <library/cpp/http/client/fetch/fetch_result.h> + +#include <library/cpp/http/client/cookies/cookiestore.h> + +#include <util/generic/intrlist.h> +#include <util/system/condvar.h> +#include <util/system/spinlock.h> + +namespace NHttp { + class TFetchRequest: public TAtomicRefCount<TFetchRequest> { + public: + TFetchRequest(const TString& url, const TFetchOptions& options); + TFetchRequest(const TString& url, TVector<TString> headers, const TFetchOptions& options); + + /// Returns reference to 
request object. + static TFetchRequestRef FromQuery(const TFetchQuery& query) { + return query.ConstructRequest(); + } + + /// Cancel the request. + void Cancel(); + + /// Is the current request is still valid? + bool IsValid() const; + + /// Is the current request cancelled? + bool IsCancelled() const; + + /// Makes request in the format of underlining fetch subsystem. + NHttpFetcher::TRequestRef GetRequestImpl() const; + + /// Whether new connection should been established. + bool GetForceReconnect() const; + + /// Unique identifier of the request. + ui64 GetId() const { + return Id_; + } + + /// Returns url of original request. + TString GetUrl() const { + return Url_; + } + + /// Makes final result. + NHttpFetcher::TResultRef MakeResult() const; + + void SetException(std::exception_ptr ptr); + + void SetCallback(NHttpFetcher::TCallBack cb); + + void SetOnFail(TOnFail cb); + + void SetOnRedirect(TOnRedirect cb); + + void SetOnPartialRead(NHttpFetcher::TNeedDataCallback cb); + + /// Waits for completion of the request no more than given timeout. + /// + /// \return true if fetching is finished. + /// \return false if timed out. + bool WaitT(TDuration timeout); + + /// Response has been gotten. + TDuration OnResponse(NHttpFetcher::TResultRef result); + + private: + bool IsValidNoLock() const; + + void Reply(NHttpFetcher::TResultRef result); + + private: + /// State of redirects processing. 
+ struct TRedirects: public std::vector<NHttpFetcher::TResultRef> { + THolder<NHttp::TCookieStore> CookieStore; + + explicit TRedirects(bool parseCookie); + + size_t Level() const; + + void ParseCookies(const TString& url, const THttpHeaders& headers); + }; + + struct TWaitState : TIntrusiveListItem<TWaitState> { + inline TWaitState(TCondVar* c) + : C_(c) + { + } + + inline void Signal() { + C_->Signal(); + } + + TCondVar* const C_; + }; + + const TString Url_; + const TFetchOptions Options_; + TVector<TString> Headers_; + + TAtomic Id_; + ui32 RetryAttempts_; + TDuration RetryDelay_; + + NHttpFetcher::TCallBack + Cb_; + TOnFail OnFail_; + TOnRedirect OnRedirect_; + NHttpFetcher::TNeedDataCallback OnPartialRead_; + + std::exception_ptr Exception_; + NHttpFetcher::TResultRef + Result_; + + TMutex Lock_; + TIntrusiveList<TWaitState> + Awaitings_; + TAtomic Cancel_; + + /// During following through redirects the url is changing. + /// So, this is actual url for the current step. + TString CurrentUrl_; + THolder<TRedirects> Redirects_; + }; + +} diff --git a/library/cpp/http/client/scheduler.cpp b/library/cpp/http/client/scheduler.cpp new file mode 100644 index 0000000000..87670bfb45 --- /dev/null +++ b/library/cpp/http/client/scheduler.cpp @@ -0,0 +1,37 @@ +#include "scheduler.h" + +namespace NHttp { + namespace { + class TDefaultHostsPolicy: public IHostsPolicy { + public: + size_t GetMaxHostConnections(const TStringBuf&) const override { + return 20; + } + }; + + } + + TScheduler::TScheduler() + : HostsPolicy_(new TDefaultHostsPolicy) + { + } + + TFetchRequestRef TScheduler::Extract() { + { + auto g(Guard(Lock_)); + + if (!RequestQueue_.empty()) { + TFetchRequestRef result(RequestQueue_.front()); + RequestQueue_.pop(); + return result; + } + } + return TFetchRequestRef(); + } + + void TScheduler::Schedule(TFetchRequestRef req) { + auto g(Guard(Lock_)); + RequestQueue_.push(req); + } + +} diff --git a/library/cpp/http/client/scheduler.h 
b/library/cpp/http/client/scheduler.h new file mode 100644 index 0000000000..6700d7cee9 --- /dev/null +++ b/library/cpp/http/client/scheduler.h @@ -0,0 +1,47 @@ +#pragma once + +#include "request.h" + +#include <util/generic/ptr.h> +#include <util/generic/queue.h> +#include <util/generic/strbuf.h> +#include <util/system/mutex.h> + +namespace NHttp { + using namespace NHttpFetcher; + + // Asynchronous download mechanism. + // Several documents at once. + // - several threads + // - per-host load control => one object per application. + // Redirects + // Request cancellation by timer. + + class IHostsPolicy { + public: + virtual ~IHostsPolicy() = default; + + //! Maximum number of simultaneous connections to a host. + virtual size_t GetMaxHostConnections(const TStringBuf& host) const = 0; + }; + + //! Manages the process of downloading a document from the given URL. + class TScheduler { + // host loading + // redirects + public: + TScheduler(); + + //! Get a download request. + TFetchRequestRef Extract(); + + //! Put a request into the download queue.
+ void Schedule(TFetchRequestRef req); + + private: + THolder<IHostsPolicy> HostsPolicy_; + TMutex Lock_; + TQueue<TFetchRequestRef> RequestQueue_; + }; + +} diff --git a/library/cpp/http/client/ssl/sslsock.cpp b/library/cpp/http/client/ssl/sslsock.cpp new file mode 100644 index 0000000000..97104de9ec --- /dev/null +++ b/library/cpp/http/client/ssl/sslsock.cpp @@ -0,0 +1,163 @@ +#include "sslsock.h" + +#include <library/cpp/openssl/init/init.h> + +#include <util/datetime/base.h> +#include <util/draft/holder_vector.h> +#include <util/generic/singleton.h> +#include <util/stream/str.h> +#include <util/system/mutex.h> + +#include <contrib/libs/openssl/include/openssl/conf.h> +#include <contrib/libs/openssl/include/openssl/crypto.h> +#include <contrib/libs/openssl/include/openssl/err.h> +#include <contrib/libs/openssl/include/openssl/x509v3.h> + +#include <contrib/libs/libc_compat/string.h> + +namespace NHttpFetcher { + namespace { + struct TSslInit { + inline TSslInit() { + InitOpenSSL(); + } + } SSL_INIT; + } + + TString TSslSocketBase::GetSslErrors() { + TStringStream ss; + unsigned long err = 0; + while (err = ERR_get_error()) { + if (ss.Str().size()) { + ss << "\n"; + } + ss << ERR_error_string(err, nullptr); + } + return ss.Str(); + } + + class TSslSocketBase::TSslCtx { + public: + SSL_CTX* Ctx; + + TSslCtx() { + const SSL_METHOD* method = SSLv23_method(); + if (!method) { + TString err = GetSslErrors(); + Y_FAIL("SslSocket StaticInit: SSLv23_method failed: %s", err.data()); + } + Ctx = SSL_CTX_new(method); + if (!Ctx) { + TString err = GetSslErrors(); + Y_FAIL("SSL_CTX_new: %s", err.data()); + } + + SSL_CTX_set_options(Ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION); + SSL_CTX_set_verify(Ctx, SSL_VERIFY_NONE, nullptr); + } + + ~TSslCtx() { + SSL_CTX_free(Ctx); + } + }; + + SSL_CTX* TSslSocketBase::SslCtx() { + return Singleton<TSslCtx>()->Ctx; + } + + void TSslSocketBase::LoadCaCerts(const TString& caFile, const TString& caPath) { + if (1 != 
SSL_CTX_load_verify_locations(SslCtx(), + !!caFile ? caFile.data() : nullptr, + !!caPath ? caPath.data() : nullptr)) + { + ythrow yexception() << "Error loading CA certs: " << GetSslErrors(); + } + SSL_CTX_set_verify(SslCtx(), SSL_VERIFY_PEER, nullptr); + } + + namespace { + enum EMatchResult { + MATCH_FOUND, + NO_MATCH, + NO_EXTENSION, + ERROR + }; + + bool EqualNoCase(TStringBuf a, TStringBuf b) { + return (a.size() == b.size()) && (strncasecmp(a.data(), b.data(), a.size()) == 0); + } + + bool MatchDomainName(TStringBuf tmpl, TStringBuf name) { + // match wildcards only in the left-most part + // do not support (optional according to RFC) partial wildcards (ww*.yandex.ru) + // see RFC-6125 + TStringBuf tmplRest = tmpl; + TStringBuf tmplFirst = tmplRest.NextTok('.'); + if (tmplFirst == "*") { + tmpl = tmplRest; + name.NextTok('.'); + } + return EqualNoCase(tmpl, name); + } + + EMatchResult MatchCertCommonName(X509* cert, TStringBuf hostname) { + int commonNameLoc = X509_NAME_get_index_by_NID(X509_get_subject_name(cert), NID_commonName, -1); + if (commonNameLoc < 0) { + return ERROR; + } + + X509_NAME_ENTRY* commonNameEntry = X509_NAME_get_entry(X509_get_subject_name(cert), commonNameLoc); + if (!commonNameEntry) { + return ERROR; + } + + ASN1_STRING* commonNameAsn1 = X509_NAME_ENTRY_get_data(commonNameEntry); + if (!commonNameAsn1) { + return ERROR; + } + + TStringBuf commonName((const char*)ASN1_STRING_get0_data(commonNameAsn1), ASN1_STRING_length(commonNameAsn1)); + + return MatchDomainName(commonName, hostname) + ? 
MATCH_FOUND + : NO_MATCH; + } + + EMatchResult MatchCertAltNames(X509* cert, TStringBuf hostname) { + EMatchResult result = NO_MATCH; + STACK_OF(GENERAL_NAME)* names = (STACK_OF(GENERAL_NAME)*)X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, NULL); + if (!names) { + return NO_EXTENSION; + } + + int namesCt = sk_GENERAL_NAME_num(names); + for (int i = 0; i < namesCt; ++i) { + const GENERAL_NAME* name = sk_GENERAL_NAME_value(names, i); + + if (name->type == GEN_DNS) { + TStringBuf dnsName((const char*)ASN1_STRING_get0_data(name->d.dNSName), ASN1_STRING_length(name->d.dNSName)); + if (MatchDomainName(dnsName, hostname)) { + result = MATCH_FOUND; + break; + } + } + } + sk_GENERAL_NAME_pop_free(names, GENERAL_NAME_free); + return result; + } + } + + bool TSslSocketBase::CheckCertHostname(X509* cert, TStringBuf hostname) { + switch (MatchCertAltNames(cert, hostname)) { + case MATCH_FOUND: + return true; + break; + case NO_EXTENSION: + return MatchCertCommonName(cert, hostname) == MATCH_FOUND; + break; + default: + return false; + } + } + +} diff --git a/library/cpp/http/client/ssl/sslsock.h b/library/cpp/http/client/ssl/sslsock.h new file mode 100644 index 0000000000..1e57ce4e56 --- /dev/null +++ b/library/cpp/http/client/ssl/sslsock.h @@ -0,0 +1,414 @@ +#pragma once + +#include <library/cpp/http/fetch/sockhandler.h> +#include <library/cpp/openssl/method/io.h> + +#include <util/generic/maybe.h> +#include <util/network/ip.h> +#include <util/network/socket.h> +#include <util/system/yassert.h> + +#include <contrib/libs/openssl/include/openssl/ssl.h> +#include <cerrno> +#include <util/generic/noncopyable.h> + +namespace NHttpFetcher { + class TSslSocketBase { + public: + enum ECertErrors { + SSL_CERT_VALIDATION_FAILED = 0x01, + SSL_CERT_HOSTNAME_MISMATCH = 0x02, + }; + + struct TSessIdDestructor { + static void Destroy(SSL_SESSION* session) { + SSL_SESSION_free(session); + } + }; + + static void LoadCaCerts(const TString& caFile, const TString& caPath); + + protected: + 
class TSocketCtx { + public: + ui16 SslError; + ui16 CertErrors; + const char* Host; + size_t HostLen; + + public: + TSocketCtx() + : SslError(0) + , CertErrors(0) + , Host(nullptr) + , HostLen(0) + { + } + + void AllocBuffers() { + } + + void FreeBuffers() { + } + }; + + protected: + static SSL_CTX* SslCtx(); + + static TString GetSslErrors(); + + static bool CheckCertHostname(X509* cert, TStringBuf hostname); + + private: + class TSslCtx; + + public: + class TFakeLogger { + public: + static void Write(const char* /*format*/, ...) { + } + }; + }; + + namespace NPrivate { + template <typename TSocketHandler> + class TSocketHandlerIO : public NOpenSSL::TAbstractIO { + private: + TSocketHandler& Sock; + + public: + TSocketHandlerIO(TSocketHandler& sock) + : Sock(sock) + { + } + + int Write(const char* data, size_t dlen, size_t* written) override { + int ret = Sock.send(data, dlen); + + if (ret <= 0) { + *written = 0; + return ret; + } + + *written = dlen; // send returns only 0 or 1 + return 1; + } + + int Read(char* data, size_t dlen, size_t* readbytes) override { + ssize_t ret = Sock.read(data, dlen); + + if (ret <= 0) { + *readbytes = 0; + return ret; + } + + *readbytes = ret; + return 1; + } + + int Puts(const char* buf) override { + TStringBuf sb(buf); + size_t written = 0; + + int ret = Write(sb.data(), sb.size(), &written); + + if (ret <= 0) { + return ret; + } + + return written; + } + + int Gets(char* buf, int size) override { + Y_UNUSED(buf); + Y_UNUSED(size); + return -1; + } + + void Flush() override { + } + }; + + class TDestroyBio { + public: + static void Destroy(BIO* bio) { + if (BIO_free(bio) != 1) { + Y_FAIL("BIO_free failed"); + } + } + }; + + class TDestroyCert { + public: + static void Destroy(X509* cert) { + X509_free(cert); + } + }; + + class TErrnoRestore { + private: + int Value; + + public: + TErrnoRestore() + : Value(errno) + { + } + + ~TErrnoRestore() { + errno = Value; + } + }; + } + + template <class TSocketHandler, class TErrorLogger = 
TSslSocketBase::TFakeLogger> + class TSslSocketHandler + : public TSslSocketBase, + protected TSocketHandler, + TNonCopyable { + private: + public: + struct TSocketCtx: public TSslSocketBase::TSocketCtx { + THolder<SSL_SESSION, TSessIdDestructor> CachedSession; + TMaybe<char> PeekedByte; + }; + + TSslSocketHandler() + : TSocketHandler() + { + } + + virtual ~TSslSocketHandler() { + Disconnect(); + } + + int Good() const { + return TSocketHandler::Good(); + } + bool HasSsl() const { + return !!SslBio; + } + + // set reconnect "true" to try to recover from cached session id + int Connect(TSocketCtx* ctx, const TAddrList& addrs, TDuration timeout, bool isHttps, bool reconnect = false); + + // for debug "file" socket + bool open(const char* name) { + Disconnect(); + return TSocketHandler::open(name); + } + + void Disconnect(TSocketCtx* ctx = nullptr) { // if ctx is non-NULL, cache session id in it. + if (!!SslBio) { + if (ctx) { + SSL* ssl; + if (!BIO_get_ssl(SslBio.Get(), &ssl)) { + Y_FAIL("BIO_get_ssl failed"); + } + SSL_SESSION* sess = SSL_get1_session(ssl); + if (!sess) { + TErrorLogger::Write("TSslSocketHandler::Disconnect: failed to create session id for host %s\n", TString(ctx->Host, ctx->HostLen).data()); + } + ctx->CachedSession.Reset(sess); + } + BIO_ssl_shutdown(SslBio.Get()); + SslBio.Destroy(); + SocketBio.Destroy(); + } + TSocketHandler::Disconnect(); + } + + void shutdown() { + TSocketHandler::shutdown(); + } + + ssize_t send(TSocketCtx* ctx, const void* message, size_t messlen) { + Y_ASSERT(TSocketHandler::Good()); + if (!SslBio) + return TSocketHandler::send(message, messlen); + int rc = SslWrite(ctx, static_cast<const char*>(message), (int)messlen); + if (rc < 0) { + NPrivate::TErrnoRestore errnoRest; + Disconnect(); + return false; + } + Y_ASSERT((size_t)rc == messlen); + return true; + } + + bool peek(TSocketCtx* ctx) { + if (!SslBio) + return TSocketHandler::peek(); + int rc; + rc = SslRead(ctx, nullptr, 0); + if (rc < 0) { + NPrivate::TErrnoRestore 
errnoRest; + Disconnect(); + return false; + } else { + return (rc > 0); + } + } + + ssize_t read(TSocketCtx* ctx, void* message, size_t messlen) { + if (!SslBio) + return TSocketHandler::read(message, messlen); + int rc; + if (!messlen) + return 0; + rc = SslRead(ctx, static_cast<char*>(message), (int)messlen); + if (rc < 0) { + NPrivate::TErrnoRestore errnoRest; + Disconnect(); + } + return rc; + } + + private: + int SslRead(TSocketCtx* ctx, char* buf, int buflen); + int SslWrite(TSocketCtx* ctx, const char* buf, int len); + + THolder<BIO, NPrivate::TDestroyBio> SslBio; + THolder<NPrivate::TSocketHandlerIO<TSocketHandler>> SocketBio; + }; + + template <typename TSocketHandler, typename TErrorLogger> + int TSslSocketHandler<TSocketHandler, TErrorLogger>::Connect(TSocketCtx* ctx, const TAddrList& addrs, TDuration timeout, bool isHttps, bool reconnect) { + ctx->SslError = 0; + ctx->CertErrors = 0; + Disconnect(); + int res = TSocketHandler::Connect(addrs, timeout); + if (!isHttps || res != 0) { + ctx->CachedSession.Destroy(); + return res; + } + + // create ssl session + SslBio.Reset(BIO_new_ssl(SslCtx(), /*client =*/1)); + if (!SslBio) { + ctx->SslError = 1; + NPrivate::TErrnoRestore errnoRest; + Disconnect(); + return -1; + } + + SSL* ssl; + if (BIO_get_ssl(SslBio.Get(), &ssl) != 1) { + Y_FAIL("BIO_get_ssl failed"); + } + + SSL_set_ex_data(ssl, /*index =*/0, ctx); + + if (reconnect) { + SSL_set_session(ssl, ctx->CachedSession.Get()); + } + + TString host(ctx->Host, ctx->HostLen); + host = host.substr(0, host.rfind(':')); + if (SSL_set_tlsext_host_name(ssl, host.data()) != 1) { + ctx->SslError = 1; + return -1; + } + + SocketBio.Reset(new NPrivate::TSocketHandlerIO<TSocketHandler>(*this)); + + BIO_push(SslBio.Get(), *SocketBio); + + ctx->CachedSession.Destroy(); + + // now it's time to perform handshake + if (BIO_do_handshake(SslBio.Get()) != 1) { + long verify_err = SSL_get_verify_result(ssl); + if (verify_err != X509_V_OK) { + // It failed because the certificate 
chain validation failed + TErrorLogger::Write("SSL Handshake failed: %s", GetSslErrors().data()); + ctx->CertErrors |= SSL_CERT_VALIDATION_FAILED; + } + ctx->SslError = 1; + return -1; + } + + THolder<X509, NPrivate::TDestroyCert> cert(SSL_get_peer_certificate(ssl)); + if (!cert) { + // The handshake was successful although the server did not provide a certificate + // Most likely using an insecure anonymous cipher suite... get out! + TErrorLogger::Write("No SSL certificate"); + ctx->SslError = 1; + return -1; + } + + if (!CheckCertHostname(cert.Get(), host)) { + ctx->SslError = 1; + ctx->CertErrors |= SSL_CERT_HOSTNAME_MISMATCH; + return -1; + } + + return 0; + } + + template <typename TSocketHandler, typename TErrorLogger> + int TSslSocketHandler<TSocketHandler, TErrorLogger>::SslRead(TSocketCtx* ctx, char* buf, int buflen) { + Y_ASSERT(SslCtx()); + + if (!SslBio || buflen < 0) + return -1; + + if (!buf) { + // peek + char byte; + int res = BIO_read(SslBio.Get(), &byte, 1); + if (res < 0) { + ctx->SslError = 1; + return -1; + } + if (res == 0) { + return 0; + } + Y_VERIFY(res == 1); + ctx->PeekedByte = byte; + return 1; + } + + if (buflen) { + size_t read = 0; + if (!!ctx->PeekedByte) { + *buf = *(ctx->PeekedByte); + ++buf; + --buflen; + ctx->PeekedByte.Clear(); + read = 1; + } + int res = BIO_read(SslBio.Get(), buf, buflen); + if (res < 0) { + ctx->SslError = 1; + return -1; + } + read += res; + return read; + } + + return 0; + } + + template <typename TSocketHandler, typename TErrorLogger> + int TSslSocketHandler<TSocketHandler, TErrorLogger>::SslWrite(TSocketCtx* ctx, const char* buf, int len) { + if (len <= 0) + return len ? 
-1 : 0; + + size_t remaining = len; + while (remaining) { + int res = BIO_write(SslBio.Get(), buf, len); + if (res < 0) { + ctx->SslError = 1; + return -1; + } + remaining -= res; + buf += res; + } + if (BIO_flush(SslBio.Get()) != 1) { + ctx->SslError = 1; + return -1; + } + return len; + } +} diff --git a/library/cpp/http/cookies/cookies.cpp b/library/cpp/http/cookies/cookies.cpp new file mode 100644 index 0000000000..12b66c7f9d --- /dev/null +++ b/library/cpp/http/cookies/cookies.cpp @@ -0,0 +1,33 @@ +#include "cookies.h" + +#include <library/cpp/string_utils/scan/scan.h> +#include <util/string/strip.h> +#include <util/string/builder.h> + +namespace { + struct TCookiesScanner { + THttpCookies* C; + + inline void operator()(const TStringBuf& key, const TStringBuf& val) { + C->Add(StripString(key), StripString(val)); + } + }; +} + +void THttpCookies::Scan(const TStringBuf& s) { + Clear(); + TCookiesScanner scan = {this}; + ScanKeyValue<true, ';', '='>(s, scan); +} + +/*** https://datatracker.ietf.org/doc/html/rfc6265#section-5.4 ***/ +TString THttpCookies::ToString() const { + TStringBuilder result; + for (const auto& [key, value] : *this) { + if (!result.empty()) { + result << "; "; + } + result << key << "=" << value; + } + return result; +} diff --git a/library/cpp/http/cookies/cookies.h b/library/cpp/http/cookies/cookies.h new file mode 100644 index 0000000000..d7a0030c8b --- /dev/null +++ b/library/cpp/http/cookies/cookies.h @@ -0,0 +1,17 @@ +#pragma once + +#include "lctable.h" + +class THttpCookies: public TLowerCaseTable<TStringBuf> { +public: + inline THttpCookies(const TStringBuf& cookieString) { + Scan(cookieString); + } + + inline THttpCookies() noexcept { + } + + void Scan(const TStringBuf& cookieString); + + TString ToString() const; +}; diff --git a/library/cpp/http/cookies/lctable.h b/library/cpp/http/cookies/lctable.h new file mode 100644 index 0000000000..09c88eafb8 --- /dev/null +++ b/library/cpp/http/cookies/lctable.h @@ -0,0 +1,86 @@ +#pragma 
once + +#include <library/cpp/digest/lower_case/lchash.h> + +#include <util/generic/hash_multi_map.h> +#include <util/generic/strbuf.h> +#include <util/generic/algorithm.h> +#include <util/generic/singleton.h> + +struct TStrBufHash { + inline size_t operator()(const TStringBuf& s) const noexcept { + return FnvCaseLess<size_t>(s); + } +}; + +struct TStrBufEqualToCaseLess { + inline bool operator()(const TStringBuf& c1, const TStringBuf& c2) const noexcept { + typedef TLowerCaseIterator<const TStringBuf::TChar> TIter; + + return (c1.size() == c2.size()) && std::equal(TIter(c1.begin()), TIter(c1.end()), TIter(c2.begin())); + } +}; + +template <class T> +class TLowerCaseTable: private THashMultiMap<TStringBuf, T, TStrBufHash, TStrBufEqualToCaseLess> { + typedef THashMultiMap<TStringBuf, T, TStrBufHash, TStrBufEqualToCaseLess> TBase; + +public: + typedef typename TBase::const_iterator const_iterator; + typedef std::pair<const_iterator, const_iterator> TConstIteratorPair; + + using TBase::TBase; + using TBase::begin; + using TBase::end; + + inline TConstIteratorPair EqualRange(const TStringBuf& name) const { + return TBase::equal_range(name); + } + + inline const T& Get(const TStringBuf& name, size_t numOfValue = 0) const { + TConstIteratorPair range = EqualRange(name); + + if (range.first == TBase::end()) + return Default<T>(); + + if (numOfValue == 0) + return range.first->second; + + const_iterator next = range.first; + for (size_t c = 0; c < numOfValue; ++c) { + ++next; + if (next == range.second) + return Default<T>(); + } + + return next->second; + } + + inline bool Has(const TStringBuf& name) const { + return TBase::find(name) != TBase::end(); + } + + size_t NumOfValues(const TStringBuf& name) const { + return TBase::count(name); + } + + inline size_t Size() const noexcept { + return TBase::size(); + } + + inline bool Empty() const noexcept { + return TBase::empty(); + } + + inline void Add(const TStringBuf& key, const T& val) { + TBase::insert(typename 
TBase::value_type(key, val)); + } + + inline void Clear() noexcept { + TBase::clear(); + } + + inline size_t Erase(const TStringBuf& key) { + return TBase::erase(key); + } +}; diff --git a/library/cpp/http/fetch_gpl/httpagent.h b/library/cpp/http/fetch_gpl/httpagent.h new file mode 100644 index 0000000000..b77c246c4f --- /dev/null +++ b/library/cpp/http/fetch_gpl/httpagent.h @@ -0,0 +1,292 @@ +#pragma once + +#include <library/cpp/http/fetch/httpagent.h> + +template <class TSockHndl = TSimpleSocketHandler, + class TDnsClient = TIpResolver, + class TErrorLogger = TSslSocketBase::TFakeLogger, + class TTimer = TNoTimer, + template <class, class> class TSslSocketImpl = TSslSocketHandler> +class THttpsAgent: public TTimer { +public: + typedef TSslSocketImpl<TSockHndl, TErrorLogger> TSocket; + THttpsAgent() + : Socket(new TSocket) + , Scheme(0) + , Persistent(0) + , Timeout(TDuration::MicroSeconds(150)) + , Hostheader(nullptr) + , Footer(nullptr) + , pHostBeg(nullptr) + , pHostEnd(nullptr) + , AltFooter(nullptr) + , PostData(nullptr) + , PostDataLen(0) + , Method(nullptr) + , MethodLen(0) + , HostheaderLen(0) + { + SetIdentification("YandexSomething/1.0", "webadmin@yandex.ru"); + } + + ~THttpsAgent() { + Disconnect(); + delete[] Hostheader; + delete[] Footer; + } + + void SetIdentification(const char* userAgent, const char* httpFrom) { + Y_VERIFY(Socket.Get(), "HttpsAgent: socket is picked out. Can't use until a valid socket is set"); + delete[] Footer; + size_t len = userAgent ? strlen(userAgent) + 15 : 0; + len += httpFrom ? 
strlen(httpFrom) + 9 : 0; + len += 3; + Footer = new char[len]; + // Start from an empty string: the strcat() calls below need a terminated destination + // even when userAgent (or both arguments) are null. + Footer[0] = '\0'; + if (userAgent) + strcat(strcat(strcpy(Footer, "User-Agent: "), userAgent), "\r\n"); + if (httpFrom) + strcat(strcat(strcat(Footer, "From: "), httpFrom), "\r\n"); + } + + void SetUserAgentFooter(const char* altFooter) { + AltFooter = altFooter; + } + + void SetPostData(const char* postData, size_t postDataLen) { + PostData = postData; + PostDataLen = postDataLen; + } + + void SetMethod(const char* method, size_t methodLen) { + Method = method; + MethodLen = methodLen; + } + + // deprecated + ui32 GetIp() const { + return Addrs.GetV4Addr().first; + } + + int GetScheme() const { + return Scheme; + } + + void SetTimeout(TDuration tim) { + Timeout = tim; + } + + void SetConnectTimeout(TDuration timeout) { + ConnectTimeout = timeout; + } + + int Disconnected() { + return !Persistent || !Socket.Get() || !Socket->Good(); + } + + int SetHost(const char* hostname, TIpPort port, int scheme = THttpURL::SchemeHTTP) { + TStringBuf host{hostname}; + if (host.StartsWith('[') && host.EndsWith(']')) { + TString tmp = ToString(host.Skip(1).Chop(1)); + TSockAddrInet6 sa(tmp.data(), port); + NAddr::IRemoteAddrRef addr = new NAddr::TIPv6Addr(sa); + SetHost(hostname, port, {addr}, scheme); + return 0; + } + TAddrList addrs = DnsClient.Resolve(hostname, port); + if (!addrs.size()) { + return 1; + } + SetHost(hostname, port, addrs, scheme); + return 0; + } + + int SetHost(const char* hostname, TIpPort port, const TAddrList& addrs, int scheme = THttpURL::SchemeHTTP) { + Disconnect(); + Addrs = addrs; + Scheme = scheme; + size_t reqHostheaderLen = strlen(hostname) + 20; + if (HostheaderLen < reqHostheaderLen) { + delete[] Hostheader; + Hostheader = new char[(HostheaderLen = reqHostheaderLen)]; + } + // NOTE(review): the parentheses only spell out the existing precedence. As written, port 80 + // is omitted from the header for every scheme, including HTTPS - confirm that is intended. + if ((Scheme == THttpURL::SchemeHTTPS && port == 443) || port == 80) + sprintf(Hostheader, "Host: %s\r\n", hostname); + else + sprintf(Hostheader, "Host: %s:%u\r\n", hostname, port); + pHostBeg = strchr(Hostheader, ' ') + 1; +
pHostEnd = strchr(pHostBeg, '\r'); + // convert hostname to lower case since some web server don't like + // uppper case (Task ROBOT-562) + for (char* p = pHostBeg; p < pHostEnd; p++) + *p = tolower(*p); + SocketCtx.Host = pHostBeg; + SocketCtx.HostLen = pHostEnd - pHostBeg; + return 0; + } + + // deprecated v4-only version + int SetHost(const char* hostname, TIpPort port, ui32 ip, int scheme = THttpURL::SchemeHTTP, TIpPort connPort = 0) { + connPort = connPort ? connPort : port; + return SetHost(hostname, port, TAddrList::MakeV4Addr(ip, connPort), scheme); + } + + void SetHostHeader(const char* host) { + size_t reqHostheaderLen = strlen(host) + 20; + if (HostheaderLen < reqHostheaderLen) { + delete[] Hostheader; + Hostheader = new char[(HostheaderLen = reqHostheaderLen)]; + } + sprintf(Hostheader, "Host: %s\r\n", host); + pHostBeg = strchr(Hostheader, ' ') + 1; + pHostEnd = strchr(pHostBeg, '\r'); + SocketCtx.Host = pHostBeg; + SocketCtx.HostLen = pHostEnd - pHostBeg; + } + + void SetSocket(TSocket* s) { + Y_VERIFY(s, "HttpsAgent: socket handler is null"); + SocketCtx.FreeBuffers(); + if (s->HasSsl()) + SocketCtx.AllocBuffers(); + Socket.Reset(s); + } + + void SetPersistent(const bool value) { + Persistent = value; + } + + TSocket* PickOutSocket() { + SocketCtx.FreeBuffers(); + SocketCtx.CachedSession.Destroy(); + return Socket.Release(); + } + + void Disconnect() { + if (Socket.Get()) + Socket->Disconnect(); + SocketCtx.FreeBuffers(); + SocketCtx.CachedSession.Destroy(); + } + + ssize_t read(void* buffer, size_t buflen) { + Y_VERIFY(Socket.Get(), "HttpsAgent: socket is picked out. Can't use until a valid socket is set"); + ssize_t ret = Socket->read(&SocketCtx, buffer, buflen); + TTimer::OnAfterRecv(); + return ret; + } + + int RequestGet(const char* url, const char* const* headers, int persistent = 1, bool head_request = false) { + Y_VERIFY(Socket.Get(), "HttpsAgent: socket is picked out. 
Can't use until a valid socket is set"); + if (!Addrs.size()) + return HTTP_DNS_FAILURE; + char message[MessageMax]; + ssize_t messlen = 0; + if (Method) { + strncpy(message, Method, MethodLen); + message[MethodLen] = ' '; + messlen = MethodLen + 1; + } else if (PostData) { + strcpy(message, "POST "); + messlen = 5; + } else if (head_request) { + strcpy(message, "HEAD "); + messlen = 5; + } else { + strcpy(message, "GET "); + messlen = 4; + } +#define _AppendMessage(mes) messlen += Min(MessageMax - messlen, \ + (ssize_t)strlcpy(message + messlen, (mes), MessageMax - messlen)) + _AppendMessage(url); + _AppendMessage(" HTTP/1.1\r\n"); + _AppendMessage(Hostheader); + _AppendMessage("Connection: "); + _AppendMessage(persistent ? "Keep-Alive\r\n" : "Close\r\n"); + while (headers && *headers) + _AppendMessage(*headers++); + if (AltFooter) + _AppendMessage(AltFooter); + else + _AppendMessage(Footer); + _AppendMessage("\r\n"); +#undef _AppendMessage + if (messlen >= MessageMax) + return HTTP_HEADER_TOO_LARGE; + + if (!Persistent) + Socket->Disconnect(&SocketCtx); + Persistent = persistent; + int connected = Socket->Good(); + SocketCtx.FreeBuffers(); + if (Scheme == THttpURL::SchemeHTTPS) { + SocketCtx.AllocBuffers(); + } + + bool success = false; + Y_SCOPE_EXIT(&success, this) { if (!success) { this->SocketCtx.FreeBuffers(); }; }; + + TTimer::OnBeforeSend(); + for (int attempt = !connected; attempt < 2; attempt++) { + const auto connectTimeout = ConnectTimeout ? ConnectTimeout : Timeout; + if (!Socket->Good() && Socket->Connect(&SocketCtx, Addrs, connectTimeout , Scheme == THttpURL::SchemeHTTPS, true)) { + return SocketCtx.SslError ? 
HTTP_SSL_ERROR : HTTP_CONNECT_FAILED; + } else { // We successfully connected + connected = true; + } + + int sendOk = Socket->send(&SocketCtx, message, messlen); + if (sendOk && PostData && PostDataLen) + sendOk = Socket->send(&SocketCtx, PostData, PostDataLen); + if (!sendOk) { + int err = errno; + Socket->Disconnect(&SocketCtx); + errno = err; + continue; + } + TTimer::OnAfterSend(); + + if (!Socket->peek(&SocketCtx)) { + int err = errno; + Socket->Disconnect(&SocketCtx); + if (err == EINTR) { + errno = err; + return HTTP_INTERRUPTED; + } + if (err == ETIMEDOUT) { + errno = err; + return HTTP_TIMEDOUT_WHILE_BYTES_RECEIVING; + } + } else { + TTimer::OnBeforeRecv(); + if (!persistent) { + Socket->shutdown(); + } + success = true; + return 0; + } + } + return SocketCtx.SslError ? HTTP_SSL_ERROR : (connected ? HTTP_CONNECTION_LOST : HTTP_CONNECT_FAILED); + } + + ui16 CertCheckErrors() const { + return SocketCtx.CertErrors; + } + +protected: + THolder<TSocket> Socket; + typename TSocket::TSocketCtx SocketCtx; + TIpResolverWrapper<TDnsClient> DnsClient; + TAddrList Addrs; + int Scheme; + int Persistent; + TDuration Timeout; + TDuration ConnectTimeout; + char *Hostheader, *Footer, *pHostBeg, *pHostEnd; + const char* AltFooter; // alternative footer can be set by the caller + const char* PostData; + size_t PostDataLen; + const char* Method; + size_t MethodLen; + unsigned short HostheaderLen; + static const ssize_t MessageMax = 65536; +}; diff --git a/library/cpp/http/fetch_gpl/sockhandler.cpp b/library/cpp/http/fetch_gpl/sockhandler.cpp new file mode 100644 index 0000000000..a3fa2e9de1 --- /dev/null +++ b/library/cpp/http/fetch_gpl/sockhandler.cpp @@ -0,0 +1,135 @@ +#include "sockhandler.h" + +#include <util/datetime/base.h> + +bool TSslSocketBase::Initialized = false; +sslKeys_t* TSslSocketBase::Keys = nullptr; +THolder<TSslSocketBase::TBufferAllocator> TSslSocketBase::BufAlloc; + +bool TSslSocketBase::StaticInit(const char* caFile) { + Y_VERIFY(!Initialized, 
"SslSocket StaticInit: already initialized"); + BufAlloc.Reset(new TSslSocketBase::TBufferAllocator); + if (matrixSslOpen() < 0) + Y_FAIL("SslSocket StaticInit: unable to initialize matrixSsl"); + Y_VERIFY(caFile && caFile[0], "SslSocket StaticInit: no certificate authority file"); + + if (matrixSslReadKeys(&Keys, nullptr, nullptr, nullptr, caFile) < 0) { + Y_FAIL("SslSocket StaticInit: unable to load ssl keys from %s", caFile); + } + Initialized = true; + return Initialized; +} + +bool TSslSocketBase::StaticInit(unsigned char* caBuff, int caLen) { + Y_VERIFY(!Initialized, "SslSocket StaticInit: already initialized"); + BufAlloc.Reset(new TSslSocketBase::TBufferAllocator); + if (matrixSslOpen() < 0) + Y_FAIL("SslSocket StaticInit: unable to initialize matrixSsl"); + Y_VERIFY(caBuff && caBuff[0] && caLen > 0, "SslSocket StaticInit: no certificate authority file"); + + if (matrixSslReadKeysMem(&Keys, nullptr, 0, nullptr, 0, caBuff, caLen) < 0) { + Y_FAIL("SslSocket StaticInit: unable to load ssl keys from memory"); + } + Initialized = true; + return Initialized; +} + +void TSslSocketBase::StaticTerm() { + Y_VERIFY(Initialized, "SslSocket StaticTerm: not initialized"); + matrixSslFreeKeys(Keys); + matrixSslClose(); + Keys = nullptr; + BufAlloc.Reset(nullptr); + Initialized = false; +} + +bool MatchPattern(const char* p, const char* pe, const char* s, const char* se, int maxAsteriskNum) { + if (maxAsteriskNum <= 0) + return false; + while (p < pe && s < se) { + if (*p == '*') { + ++p; + while (p < pe && *p == '*') + ++p; + while (s < se) { + if (MatchPattern(p, pe, s, se, maxAsteriskNum - 1)) + return true; + if (*s == '.') + return false; + ++s; + } + return p == pe; + } else { + if (*p != *s) + return false; + ++p; + ++s; + } + } + while (p < pe && *p == '*') + ++p; + return (p == pe && s == se); +} + +bool MatchHostName(const char* pattern, size_t patternLen, const char* name, size_t nameLen) { + // rfc 2818 says: + // wildcard character * can match any single 
domain name component or component fragment + Y_VERIFY(name && nameLen, "Ssl certificate check error: hostname is empty"); + if (!pattern || !patternLen) + return false; + const char* ne = strchr(name, ':'); + if (!ne || ne > name + nameLen) + ne = name + nameLen; + return MatchPattern(pattern, pattern + patternLen, name, ne, 5); +} + +bool IsExpired(const char* notBefore, const char* notAfter) { + time_t notbefore, notafter; + if (!ParseX509ValidityDateTimeDeprecated(notBefore, notbefore) || !ParseX509ValidityDateTimeDeprecated(notAfter, notafter)) + return true; + time_t t = Seconds(); + return notbefore > t || t > notafter; +} + +int TSslSocketBase::CertChecker(sslCertInfo_t* cert, void* arg) { + Y_ASSERT(cert); + Y_ASSERT(arg); + TSocketCtx* ctx = (TSocketCtx*)arg; + ctx->CertErrors = 0; + + // matching hostname + if (ctx->Host && ctx->HostLen) { + bool nameMatched = false; + sslSubjectAltNameEntry* an = cert->subjectAltName; + while (an) { + // dNSName id is 2. + if (an->id == 2 && MatchHostName((const char*)an->data, an->dataLen, ctx->Host, ctx->HostLen)) { + nameMatched = true; + break; + } + an = an->next; + } + if (!nameMatched && cert->subject.commonName) { + nameMatched = MatchHostName(cert->subject.commonName, strlen(cert->subject.commonName), ctx->Host, ctx->HostLen); + } + if (!nameMatched) + ctx->CertErrors |= SSL_CERT_HOSTNAME_MISMATCH; + } + + // walk through certificate chain and check if they are signed correctly and not expired + sslCertInfo_t* c = cert; + while (c->next) { + if (IsExpired(c->notBefore, c->notAfter)) + ctx->CertErrors |= SSL_CERT_EXPIRED; + if (c->verified < 0) { + ctx->CertErrors |= SSL_CERT_BAD_CHAIN; + } + c = c->next; + } + if (c->verified < 0) + ctx->CertErrors |= SSL_CERT_UNTRUSTED; + if (IsExpired(c->notBefore, c->notAfter)) + ctx->CertErrors |= SSL_CERT_EXPIRED; + + return SSL_ALLOW_ANON_CONNECTION; +} diff --git a/library/cpp/http/fetch_gpl/sockhandler.h b/library/cpp/http/fetch_gpl/sockhandler.h new file mode 100644 
index 0000000000..91d4f67a06 --- /dev/null +++ b/library/cpp/http/fetch_gpl/sockhandler.h @@ -0,0 +1,557 @@ +#pragma once + +#include <library/cpp/http/fetch/sockhandler.h> +#include <contrib/libs/matrixssl/matrixSsl.h> + +class TSslSocketBase { +public: + static bool StaticInit(const char* caFile = nullptr); + static bool StaticInit(unsigned char* caBuff, int caLen); + static void StaticTerm(); + static int CertChecker(sslCertInfo_t* cert, void* arg); + enum ECertErrors { + SSL_CERT_UNTRUSTED = 0x01, + SSL_CERT_BAD_CHAIN = 0x02, + SSL_CERT_HOSTNAME_MISMATCH = 0x04, + SSL_CERT_EXPIRED = 0x08 + }; + + struct TSessIdDestructor { + static void Destroy(sslSessionId_t* id) { + matrixSslFreeSessionId(id); + } + }; + +protected: + enum ESslStatus { + SSLSOCKET_EOF = 0x1, + SSLSOCKET_CLOSE_NOTIFY = 0x2 + }; + struct TSocketCtx { + ui16 SslError; + ui16 CertErrors; + const char* Host; + size_t HostLen; + TSocketCtx() + : SslError(0) + , CertErrors(0) + , Host(nullptr) + , HostLen(0) + { + } + }; + +protected: + class TBufferAllocator { + class TChunk; + typedef TIntrusiveSListItem<TChunk> TChunkBase; + + class TChunk: public TChunkBase { + public: + inline unsigned char* ToPointer() { + //shut up clang warning + (void)Buf; + + return (unsigned char*)this; + + static_assert(sizeof(TChunk) >= SSL_MAX_BUF_SIZE, "expect sizeof(TChunk) >= SSL_MAX_BUF_SIZE"); + } + static inline TChunk* FromPointer(unsigned char* ptr) { + return (TChunk*)ptr; + } + + private: + unsigned char Buf[SSL_MAX_BUF_SIZE - sizeof(TChunkBase)]; + }; + + public: + TBufferAllocator() + : NFree(0) + , NAllocated(0) + { + static_assert(InitialItems > 0 && A1 > 0 && A2 > 0 && A1 >= A2, "expect InitialItems > 0 && A1 > 0 && A2 > 0 && A1 >= A2"); + ResizeList(InitialItems); + } + + ~TBufferAllocator() { + Y_VERIFY(!NAllocated, "Ssl bufferAllocator: %" PRISZT " blocks lost!", NAllocated); + ResizeList(0); + } + + unsigned char* Alloc() { + TGuard<TMutex> guard(Lock); + if (Free_.Empty()) + ResizeList(A2 * 
NAllocated); + + NAllocated++; + NFree--; + return Free_.PopFront()->ToPointer(); + } + + void Free(unsigned char* p) { + if (!p) + return; + TGuard<TMutex> guard(Lock); + Y_VERIFY(NAllocated, "Ssl bufferAllocator: multiple frees?"); + TChunk* ch = TChunk::FromPointer(p); + Free_.PushFront(ch); + NFree++; + NAllocated--; + + // destroy some items if NFree/NAllocated increased too much + size_t newSize = A2 * NAllocated; + if (NAllocated + newSize >= InitialItems && NFree >= A1 * NAllocated) + ResizeList(newSize); + } + + private: + inline void ResizeList(size_t newSize) { + while (NFree < newSize) { + Free_.PushFront(new TChunk); + NFree++; + } + while (NFree > newSize) { + TChunk* ch = Free_.PopFront(); + Y_VERIFY(ch, "Ssl bufferAllocator: internal error"); + delete ch; + NFree--; + } + } + + static const size_t InitialItems = 100; + static const unsigned A1 = 3; // maximum reserved/allocated ratio + static const unsigned A2 = 1; // if ratio A1 is reached, decrease by A2 + + TIntrusiveSList<TChunk> Free_; + size_t NFree; + size_t NAllocated; + TMutex Lock; + }; + + static bool Initialized; + static sslKeys_t* Keys; + static THolder<TBufferAllocator> BufAlloc; + +public: + class TFakeLogger { + public: + static void Write(const char* /*format*/, ...) 
{ + } + }; +}; + +template <class TSocketHandler = TSimpleSocketHandler, class TErrorLogger = TSslSocketBase::TFakeLogger> +class TSslSocketHandler: public TSslSocketBase, protected TSocketHandler, TNonCopyable { +public: + struct TSocketCtx: public TSslSocketBase::TSocketCtx { + sslBuf_t InBuf; + sslBuf_t InSock; + THolder<sslSessionId_t, TSessIdDestructor> CachedSession; + TSocketCtx() { + Zero(InBuf); + Zero(InSock); + } + void AllocBuffers() { + Y_ASSERT(InBuf.size == 0); + InBuf.size = SSL_MAX_BUF_SIZE; + InBuf.start = InBuf.end = InBuf.buf = BufAlloc->Alloc(); + + Y_ASSERT(InSock.size == 0); + InSock.size = SSL_MAX_BUF_SIZE; + InSock.start = InSock.end = InSock.buf = BufAlloc->Alloc(); + } + void FreeBuffers() { + if (InBuf.buf) { + if (InBuf.end - InBuf.start) { + // We had some data read and decrypted. Too sad, nobody needs it now :( + TErrorLogger::Write("TSocketCtx::FreeBuffers: %i bytes of data lost in InBuf (%s)\n", (int)(InBuf.end - InBuf.start), TString(Host, HostLen).data()); + } + BufAlloc->Free(InBuf.buf); + Zero(InBuf); + } + if (InSock.buf) { + if (InSock.end - InSock.start) { + // We had some data read and waiting for decryption. Most likely we disconnected before server's "bye". + TErrorLogger::Write("TSocketCtx::FreeBuffers: %i bytes of data lost in InSock (%s)\n", (int)(InSock.end - InSock.start), TString(Host, HostLen).data()); + } + BufAlloc->Free(InSock.buf); + Zero(InSock); + } + } + void ResetBuffers() { + InBuf.start = InBuf.end = InBuf.buf; + InSock.start = InSock.end = InSock.buf; + } + }; + + TSslSocketHandler() + : TSocketHandler() + , Ssl(nullptr) + { + Y_VERIFY(TSslSocketBase::Initialized, "Ssl library isn't initialized. 
Call TSslSocketBase::StaticInit() first"); + } + + virtual ~TSslSocketHandler() { + Y_ASSERT(Initialized); + Disconnect(); + } + + int Good() const { + return TSocketHandler::Good(); + } + bool HasSsl() const { + return Ssl; + } + + // set reconnect "true" to try to recover from cached session id + int Connect(TSocketCtx* ctx, const TAddrList& addrs, TDuration timeout, bool isHttps, bool reconnect = false); + + // for debug "file" socket + bool open(const char* name) { + Y_ASSERT(Initialized); + Disconnect(); + return TSocketHandler::open(name); + } + + void Disconnect(TSocketCtx* ctx = nullptr) { // if ctx is non-NULL, cache session id in it. + Y_ASSERT(Initialized); + if (Ssl) { + if (ctx) { + sslSessionId_t* cs; + if (matrixSslGetSessionId(Ssl, &cs) < 0) { + cs = nullptr; + TErrorLogger::Write("TSslSocketHandler::Disconnect: failed to create session id for host %s\n", TString(ctx->Host, ctx->HostLen).data()); + } + ctx->CachedSession.Reset(cs); + } + matrixSslDeleteSession(Ssl); + Ssl = nullptr; + } + TSocketHandler::Disconnect(); + } + + void shutdown() { + TSocketHandler::shutdown(); + } + + ssize_t send(TSocketCtx* ctx, const void* message, size_t messlen) { + Y_ASSERT(TSocketHandler::Good()); + if (!Ssl) + return TSocketHandler::send(message, messlen); + int status; + int rc = SslWrite(ctx, static_cast<const char*>(message), (int)messlen, &status); + if (rc < 0) { + errno = status; + ctx->ResetBuffers(); + Disconnect(); + return false; + } + Y_ASSERT((size_t)rc == messlen); + return true; + } + + bool peek(TSocketCtx* ctx) { + if (!Ssl) + return TSocketHandler::peek(); + int rc; + int status; + while (true) { + rc = SslRead(ctx, nullptr, 0, &status); + if (rc < 0) { + errno = status; + ctx->ResetBuffers(); + Disconnect(); + return false; + } else if (rc > 0) { + return true; + } + // else if (rc == 0) + if (status) { + errno = status; + ctx->ResetBuffers(); + Disconnect(); + return false; + } + } + } + + ssize_t read(TSocketCtx* ctx, void* message, size_t 
messlen) { + if (!Ssl) + return TSocketHandler::read(message, messlen); + int rc; + int status; + if (!messlen) + return 0; + while (true) { + rc = SslRead(ctx, static_cast<char*>(message), (int)messlen, &status); + if (rc < 0) { + errno = status; + ctx->ResetBuffers(); + Disconnect(); + return rc; + } else if (rc > 0) + return rc; + // else if (rc == 0) + if (status) { + errno = status; + ctx->ResetBuffers(); + Disconnect(); + return 0; + } + } + } + +private: + int SslRead(TSocketCtx* ctx, char* buf, int buflen, int* status); + int SslWrite(TSocketCtx* ctx, const char* buf, int len, int* status); + + ssl_t* Ssl; +}; + +template <typename TSocketHandler, typename TErrorLogger> +int TSslSocketHandler<TSocketHandler, TErrorLogger>::Connect(TSocketCtx* ctx, const TAddrList& addrs, TDuration timeout, bool isHttps, bool reconnect) { + Y_ASSERT(Initialized); + ctx->SslError = 0; + ctx->ResetBuffers(); + Disconnect(); + int res = TSocketHandler::Connect(addrs, timeout); + if (!isHttps || res != 0) { + ctx->CachedSession.Destroy(); + return res; + } + + // create ssl session + if ((res = matrixSslNewSession(&Ssl, Keys, reconnect ? 
ctx->CachedSession.Get() : nullptr, 0)) < 0) { + ctx->SslError = 1; + ctx->ResetBuffers(); + Disconnect(); + return res; + } + ctx->CachedSession.Destroy(); + + matrixSslSetCertValidator(Ssl, CertChecker, ctx); + + // now it's time to perform handshake + sslBuf_t outsock; + outsock.buf = outsock.start = outsock.end = BufAlloc->Alloc(); + outsock.size = SSL_MAX_BUF_SIZE; + + res = matrixSslEncodeClientHello(Ssl, &outsock, 0); + if (res) { + TErrorLogger::Write("TSslSocketHandler::Connect: internal error %i\n", res); + BufAlloc->Free(outsock.buf); + ctx->SslError = 1; + ctx->ResetBuffers(); + Disconnect(); + return -1; + } + + if (!TSocketHandler::send(outsock.start, outsock.end - outsock.start)) { + BufAlloc->Free(outsock.buf); + ctx->SslError = 1; + ctx->ResetBuffers(); + Disconnect(); + return -1; + } + BufAlloc->Free(outsock.buf); + + // SslRead will handle handshake and is suppozed to return 0 + int status, rc; + int ncalls = 10; // FIXME: maybe it's better to check time + while (true) { + rc = SslRead(ctx, nullptr, 0, &status); + if (rc == 0) { + if (status == SSLSOCKET_EOF || status == SSLSOCKET_CLOSE_NOTIFY) { + ctx->SslError = 1; + ctx->ResetBuffers(); + Disconnect(); + return -1; + } + if (matrixSslHandshakeIsComplete(Ssl)) + break; + if (--ncalls <= 0) { + TErrorLogger::Write("TSslSocketHandler::Connect: handshake too long (server wants multiple handshakes maybe)\n"); + ctx->SslError = 1; + ctx->ResetBuffers(); + Disconnect(); + return -1; + } + continue; + } else if (rc > 0) { + TErrorLogger::Write("TSslSocketHandler::Connect: server sent data instead of a handshake\n"); + ctx->SslError = 1; + ctx->ResetBuffers(); + Disconnect(); + return -1; + } else { // rc < 0 + //this is an error + ctx->SslError = 1; + ctx->ResetBuffers(); + Disconnect(); + return -1; + } + } + + return 0; +} + +template <typename TSocketHandler, typename TErrorLogger> +int TSslSocketHandler<TSocketHandler, TErrorLogger>::SslRead(TSocketCtx* ctx, char* buf, int buflen, int* status) { 
+ Y_ASSERT(Initialized); + int remaining, bytes, rc; + + *status = 0; + if (Ssl == nullptr || buflen < 0) + return -1; + + // Return data if we still have cached + if (ctx->InBuf.start < ctx->InBuf.end) { + remaining = (int)(ctx->InBuf.end - ctx->InBuf.start); + if (!buflen) // polling + return remaining; + bytes = Min(buflen, remaining); + memcpy(buf, ctx->InBuf.start, bytes); + ctx->InBuf.start += bytes; + return bytes; + } + + // Pack buffered socket data (if any) so that start is at zero. + if (ctx->InSock.buf < ctx->InSock.start) { + if (ctx->InSock.start == ctx->InSock.end) { + ctx->InSock.start = ctx->InSock.end = ctx->InSock.buf; + } else { + memmove(ctx->InSock.buf, ctx->InSock.start, ctx->InSock.end - ctx->InSock.start); + ctx->InSock.end -= (ctx->InSock.start - ctx->InSock.buf); + ctx->InSock.start = ctx->InSock.buf; + } + } + + bool performRead = false; + bool dontPerformRead = false; + unsigned char error; + unsigned char alertLevel; + unsigned char alertDescription; + + while (true) { + // Read data from socket + if (!dontPerformRead && (performRead || ctx->InSock.end == ctx->InSock.start)) { + performRead = true; + bytes = TSocketHandler::read((void*)ctx->InSock.end, (ctx->InSock.buf + ctx->InSock.size) - ctx->InSock.end); + if (bytes == SOCKET_ERROR) { + *status = errno; + return -1; + } + if (bytes == 0) { + *status = SSLSOCKET_EOF; + return 0; + } + ctx->InSock.end += bytes; + } + dontPerformRead = false; + + error = 0; + alertLevel = 0; + alertDescription = 0; + + ctx->InBuf.start = ctx->InBuf.end = ctx->InBuf.buf; + rc = matrixSslDecode(Ssl, &ctx->InSock, &ctx->InBuf, &error, &alertLevel, &alertDescription); + + switch (rc) { + // Successfully decoded a record that did not return data or require a response. 
+ case SSL_SUCCESS: + return 0; + + case SSL_PROCESS_DATA: + rc = (int)(ctx->InBuf.end - ctx->InBuf.start); + if (!buflen) + return rc; + rc = Min(rc, buflen); + memcpy(buf, ctx->InBuf.start, rc); + ctx->InBuf.start += rc; + return rc; + + case SSL_SEND_RESPONSE: + if (!TSocketHandler::send(ctx->InBuf.start, ctx->InBuf.end - ctx->InBuf.start)) { + *status = errno; + return -1; + } + ctx->InBuf.start = ctx->InBuf.end = ctx->InBuf.buf; + return 0; + + case SSL_ERROR: + if (ctx->InBuf.start < ctx->InBuf.end) + TSocketHandler::send(ctx->InBuf.start, ctx->InBuf.end - ctx->InBuf.start); + ctx->InBuf.start = ctx->InBuf.end = ctx->InBuf.buf; + ctx->SslError = 1; + return -1; + + case SSL_ALERT: + if (alertDescription == SSL_ALERT_CLOSE_NOTIFY) { + *status = SSLSOCKET_CLOSE_NOTIFY; + ctx->InBuf.start = ctx->InBuf.end = ctx->InBuf.buf; + return 0; + } + ctx->InBuf.start = ctx->InBuf.end = ctx->InBuf.buf; + ctx->SslError = 1; + return -1; + + case SSL_PARTIAL: + if (ctx->InSock.start == ctx->InSock.buf && ctx->InSock.end == (ctx->InSock.buf + ctx->InSock.size)) { + ctx->InSock.start = ctx->InSock.end = ctx->InSock.buf; + ctx->SslError = 1; + return -1; + } + if (!performRead) { + performRead = 1; + ctx->InBuf.start = ctx->InBuf.end = ctx->InBuf.buf; + continue; + } else { + ctx->InBuf.start = ctx->InBuf.end = ctx->InBuf.buf; + return 0; + } + + case SSL_FULL: + ctx->InBuf.start = ctx->InBuf.end = ctx->InBuf.buf; + ctx->SslError = 1; + return -1; + } + } + + return 0; +} + +template <typename TSocketHandler, typename TErrorLogger> +int TSslSocketHandler<TSocketHandler, TErrorLogger>::SslWrite(TSocketCtx* ctx, const char* buf, int len, int* status) { + Y_ASSERT(Initialized); + if (len <= 0) + return len ? 
-1 : 0; + int rc; + *status = 0; + + sslBuf_t outsock; + outsock.size = SSL_MAX_BUF_SIZE; + outsock.start = outsock.end = outsock.buf = BufAlloc->Alloc(); + + size_t remaining = len; + while (remaining) { + size_t l = Min<size_t>(remaining, SSL_MAX_PLAINTEXT_LEN); + rc = matrixSslEncode(Ssl, (const unsigned char*)buf, l, &outsock); + if (rc <= 0) { + TErrorLogger::Write("TSslSocketHandler::SslWrite: internal error: %u\n", rc); + BufAlloc->Free(outsock.buf); + ctx->SslError = 1; + return -1; + } + rc = TSocketHandler::send(outsock.start, outsock.end - outsock.start); + if (!rc) { + *status = errno; + BufAlloc->Free(outsock.buf); + return -1; + } + remaining -= l; + buf += l; + outsock.start = outsock.end = outsock.buf; + } + BufAlloc->Free(outsock.buf); + return len; +} |