diff options
author | Andrey Neporada <neporada@gmail.com> | 2022-07-01 13:40:47 +0300 |
---|---|---|
committer | Andrey Neporada <neporada@gmail.com> | 2022-07-01 13:40:47 +0300 |
commit | 0ec42a7c3c870620ed0896b14a08d91017c7cfcc (patch) | |
tree | d4311073cf27c4dd1917773dfab2b83b396d2460 | |
parent | 4b120368d66d843f27f868cd84374c732d22e830 (diff) | |
download | ydb-0ec42a7c3c870620ed0896b14a08d91017c7cfcc.tar.gz |
[YQL-15057] Refactor S3 object lister. Collect pattern matching groups.
ref:151e9d001165ef1afa05aef50367c467809ca137
4 files changed, 452 insertions, 315 deletions
diff --git a/ydb/library/yql/providers/s3/provider/CMakeLists.txt b/ydb/library/yql/providers/s3/provider/CMakeLists.txt index 0e2cc06957..5bbc6501ec 100644 --- a/ydb/library/yql/providers/s3/provider/CMakeLists.txt +++ b/ydb/library/yql/providers/s3/provider/CMakeLists.txt @@ -54,6 +54,7 @@ target_sources(providers-s3-provider PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/s3/provider/yql_s3_dq_integration.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/s3/provider/yql_s3_exec.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/s3/provider/yql_s3_io_discovery.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/s3/provider/yql_s3_list.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/s3/provider/yql_s3_logical_opt.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/s3/provider/yql_s3_mkql_compiler.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/providers/s3/provider/yql_s3_phy_opt.cpp diff --git a/ydb/library/yql/providers/s3/provider/yql_s3_io_discovery.cpp b/ydb/library/yql/providers/s3/provider/yql_s3_io_discovery.cpp index e8c3f5a320..fabdd7b037 100644 --- a/ydb/library/yql/providers/s3/provider/yql_s3_io_discovery.cpp +++ b/ydb/library/yql/providers/s3/provider/yql_s3_io_discovery.cpp @@ -1,4 +1,5 @@ #include "yql_s3_provider_impl.h" +#include "yql_s3_list.h" #include <ydb/library/yql/providers/s3/expr_nodes/yql_s3_expr_nodes.h> #include <ydb/library/yql/core/yql_expr_optimize.h> @@ -6,13 +7,6 @@ #include <ydb/library/yql/utils/url_builder.h> #include <util/generic/size_literals.h> -#include <contrib/libs/re2/re2/re2.h> -#include <library/cpp/retry/retry_policy.h> - -#ifdef THROW -#undef THROW -#endif -#include <library/cpp/xml/document/xml-document.h> namespace NYql { @@ -31,160 +25,23 @@ std::array<TExprNode::TPtr, 2U> ExtractSchema(TExprNode::TListType& settings) { return {}; } -using TItemsMap = std::map<TString, ui64>; -using TPendingBuckets = std::unordered_map<std::tuple<TString, TString, TString>, std::tuple<TNodeSet, TItemsMap, TIssues>, THash<std::tuple<TString, TString, TString>>>; +struct TListRequest { + TString Token; + TString Url; + TString Pattern; +}; -ERetryErrorClass RetryS3SlowDown(long httpResponseCode) { - return httpResponseCode == 503 ? ERetryErrorClass::LongRetry : ERetryErrorClass::NoRetry; // S3 Slow Down == 503 +bool operator<(const TListRequest& a, const TListRequest& b) { + return std::tie(a.Token, a.Url, a.Pattern) < std::tie(b.Token, b.Url, b.Pattern); } -void OnDiscovery( - IHTTPGateway::TWeakPtr gateway, - TPosition pos, - IHTTPGateway::TResult&& result, - const TPendingBuckets::key_type& keys, - TPendingBuckets::mapped_type& output, - NThreading::TPromise<void> promise, - std::weak_ptr<TPendingBuckets> pendingBucketsWPtr, - int promiseInd, - const IRetryPolicy<long>::TPtr& retryPolicy, - ui64 maxDiscoveryFilesPerQuery) { - auto pendingBuckets = pendingBucketsWPtr.lock(); // keys and output could be used only when TPendingBuckets is alive - if (!pendingBuckets) { - return; - } - TString logMsg = TStringBuilder() << "promise #" << promiseInd << ": "; - switch (result.index()) { - case 0U: try { - logMsg += "Result received"; - const NXml::TDocument xml(std::get<IHTTPGateway::TContent>(std::move(result)).Extract(), NXml::TDocument::String); - if (const auto& root = xml.Root(); root.Name() == "Error") { - const auto& code = root.Node("Code", true).Value<TString>(); - const auto& message = root.Node("Message", true).Value<TString>(); - std::get<TIssues>(output) = {TIssue(pos, TStringBuilder() << message << ", error: code: " << code)}; - break; - } else if (root.Name() != "ListBucketResult") { - std::get<TIssues>(output) = { TIssue(pos, TStringBuilder() << "Unexpected response '" << root.Name() << "' on discovery.") }; - break; - } else if (const NXml::TNamespacesForXPath nss(1U, {"s3", "http://s3.amazonaws.com/doc/2006-03-01/"}); - root.Node("s3:KeyCount", false, nss).Value<unsigned>() > 0U) { - const auto& contents = root.XPath("s3:Contents", false, nss); - auto& items = std::get<TItemsMap>(output); - if (maxDiscoveryFilesPerQuery && items.size() + contents.size() > maxDiscoveryFilesPerQuery) { - std::get<TIssues>(output) = { TIssue(pos, TStringBuilder() << "Over " << maxDiscoveryFilesPerQuery << " files discovered in '" << std::get<0U>(keys) << std::get<1U>(keys) << "'")}; - break; - } - - for (const auto& content : contents) { - items.emplace(content.Node("s3:Key", false, nss).Value<TString>(), content.Node("s3:Size", false, nss).Value<ui64>()); - } - - if (root.Node("s3:IsTruncated", false, nss).Value<bool>()) { - if (const auto g = gateway.lock()) { - const auto& next = root.Node("s3:NextContinuationToken", false, nss).Value<TString>(); - const auto& maxKeys = root.Node("s3:MaxKeys", false, nss).Value<TString>(); - - IHTTPGateway::THeaders headers; - if (const auto& token = std::get<2U>(keys); !token.empty()) - headers.emplace_back(token); - - TString prefix(std::get<1U>(keys)); - TUrlBuilder urlBuilder(std::get<0U>(keys)); - auto url = urlBuilder.AddUrlParam("list-type", "2") - .AddUrlParam("prefix", prefix) - .AddUrlParam("continuation-token", next) - .AddUrlParam("max-keys", maxKeys) - .Build(); - - return g->Download( - url, - std::move(headers), - 0U, - std::bind(&OnDiscovery, gateway, pos, std::placeholders::_1, std::cref(keys), std::ref(output), std::move(promise), pendingBucketsWPtr, promiseInd, retryPolicy, maxDiscoveryFilesPerQuery), - /*data=*/"", - false, - retryPolicy); - } - YQL_CLOG(INFO, ProviderS3) << "Gateway disappeared."; - } - } - - break; - } catch (const std::exception& ex) { - logMsg += TStringBuilder() << "Exception occurred: " << ex.what(); - std::get<TIssues>(output) = {TIssue(pos, TStringBuilder() << "Error '" << ex.what() << "' on parse discovery response.")}; - break; - } - case 1U: - logMsg += TStringBuilder() << "Issues occurred: " << std::get<TIssues>(result).ToString(); - std::get<TIssues>(output) = std::get<TIssues>(std::move(result)); - break; - default: - logMsg += TStringBuilder() << "Undefined variant index: " << result.index(); - std::get<TIssues>(output) = {TIssue(pos, TStringBuilder() << "Unexpected variant index " << result.index() << " on discovery.")}; - break; - } - - // this logging does not work at the moment since we are trying to do it in non-pipeline thread (http gateway thread) - // todo: fix logging - YQL_CLOG(DEBUG, ProviderS3) << "Set promise with log message: " << logMsg; - promise.SetValue(); -} - -TString RegexFromWildcards(const std::string_view& pattern) { - const auto& escaped = RE2::QuoteMeta(re2::StringPiece(pattern)); - TStringBuilder result; - result << "(?s)"; - bool slash = false; - bool group = false; - - for (const char& c : escaped) { - switch (c) { - case '{': - result << '('; - group = true; - slash = false; - break; - case '}': - result << ')'; - group = false; - slash = false; - break; - case ',': - if (group) - result << '|'; - else - result << "\\,"; - slash = false; - break; - case '\\': - if (slash) - result << "\\\\"; - slash = !slash; - break; - case '*': - result << ".*"; - slash = false; - break; - case '?': - result << '.'; - slash = false; - break; - default: - if (slash) - result << '\\'; - result << c; - slash = false; - break; - } - } - return result; -} +using TPendingRequests = TMap<TListRequest, NThreading::TFuture<IS3Lister::TListResult>>; class TS3IODiscoveryTransformer : public TGraphTransformerBase { public: TS3IODiscoveryTransformer(TS3State::TPtr state, IHTTPGateway::TPtr gateway) - : State_(std::move(state)), Gateway_(std::move(gateway)) + : State_(std::move(state)) + , Lister_(IS3Lister::Make(gateway, State_->Configuration->MaxDiscoveryFilesPerQuery)) {} TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final { @@ -193,69 +50,47 @@ public: return TStatus::Ok; } - if (auto reads = FindNodes(input, [&](const TExprNode::TPtr& node) { + auto reads = FindNodes(input, [&](const TExprNode::TPtr& node) { if (const auto maybeRead = TMaybeNode<TS3Read>(node)) { if (maybeRead.DataSource()) { return maybeRead.Cast().Arg(2).Ref().IsCallable({"MrObject", "MrTableConcat"}); } } return false; - }); !reads.empty()) { - for (auto& r : reads) { - const TS3Read read(std::move(r)); - std::unordered_set<std::string_view> paths; - if (const auto& object = read.Arg(2).Ref(); object.IsCallable("MrObject")) - paths.emplace(object.Head().Content()); - else if (object.IsCallable("MrTableConcat")) - object.ForEachChild([&paths](const TExprNode& child){ paths.emplace(child.Head().Tail().Head().Content()); }); - const auto& connect = State_->Configuration->Clusters.at(read.DataSource().Cluster().StringValue()); - const auto& token = State_->Configuration->Tokens.at(read.DataSource().Cluster().StringValue()); - const auto credentialsProviderFactory = CreateCredentialsProviderFactoryForStructuredToken(State_->CredentialsFactory, token); - const auto authToken = credentialsProviderFactory->CreateProvider()->GetAuthInfo(); - - for (const auto& path : paths) { - const auto& prefix = path.substr(0U, path.find_first_of("*?{")); - std::get<TNodeSet>((*PendingBuckets_)[std::make_tuple(connect.Url, TString(prefix), authToken.empty() ? TString() : TString("X-YaCloud-SubjectToken:") += authToken)]).emplace(read.Raw()); + }); + + TVector<NThreading::TFuture<IS3Lister::TListResult>> futures; + for (auto& r : reads) { + const TS3Read read(std::move(r)); + std::unordered_set<std::string_view> paths; + if (const auto& object = read.Arg(2).Ref(); object.IsCallable("MrObject")) + paths.emplace(object.Head().Content()); + else if (object.IsCallable("MrTableConcat")) + object.ForEachChild([&paths](const TExprNode& child){ paths.emplace(child.Head().Tail().Head().Content()); }); + const auto& connect = State_->Configuration->Clusters.at(read.DataSource().Cluster().StringValue()); + const auto& token = State_->Configuration->Tokens.at(read.DataSource().Cluster().StringValue()); + const auto credentialsProviderFactory = CreateCredentialsProviderFactoryForStructuredToken(State_->CredentialsFactory, token); + + TListRequest req; + req.Token = credentialsProviderFactory->CreateProvider()->GetAuthInfo(); + req.Url = connect.Url; + for (const auto& path : paths) { + req.Pattern = path; + RequestsByNode_[read.Raw()].push_back(req); + + if (PendingRequests_.find(req) == PendingRequests_.end()) { + auto future = Lister_->List(req.Token, req.Url, req.Pattern); + PendingRequests_[req] = future; + futures.push_back(std::move(future)); } } } - std::vector<NThreading::TFuture<void>> handles; - handles.reserve(PendingBuckets_->size()); - - int i = 0; - const auto retryPolicy = IRetryPolicy<long>::GetExponentialBackoffPolicy(RetryS3SlowDown); - for (auto& bucket : *PendingBuckets_) { - auto promise = NThreading::NewPromise(); - handles.emplace_back(promise.GetFuture()); - IHTTPGateway::THeaders headers; - if (const auto& token = std::get<2U>(bucket.first); !token.empty()) - headers.emplace_back(token); - std::weak_ptr<TPendingBuckets> pendingBucketsWPtr = PendingBuckets_; - TString prefix(std::get<1U>(bucket.first)); - TUrlBuilder urlBuilder(std::get<0U>(bucket.first)); - const auto url = urlBuilder.AddUrlParam("list-type", "2") - .AddUrlParam("prefix", prefix) - .Build(); - Gateway_->Download( - url, - headers, - 0U, - std::bind(&OnDiscovery, - IHTTPGateway::TWeakPtr(Gateway_), ctx.GetPosition((*std::get<TNodeSet>(bucket.second).cbegin())->Pos()), std::placeholders::_1, - std::cref(bucket.first), std::ref(bucket.second), std::move(promise), pendingBucketsWPtr, i++, retryPolicy, State_->Configuration->MaxDiscoveryFilesPerQuery), - /*data=*/"", - false, - retryPolicy - ); - YQL_CLOG(INFO, ProviderS3) << "Enumerate items in " << std::get<0U>(bucket.first) << std::get<1U>(bucket.first); - } - - if (handles.empty()) { + if (futures.empty()) { return TStatus::Ok; } - AllFuture_ = NThreading::WaitExceptionOrAll(handles); + AllFuture_ = NThreading::WaitExceptionOrAll(futures); return TStatus::Async; } @@ -267,126 +102,99 @@ public: // Raise errors if any AllFuture_.GetValue(); - TNodeOnNodeOwnedMap replaces(PendingBuckets_->size()); - auto buckets = std::move(*PendingBuckets_); - auto count = 0ULL; - auto readSize = 0ULL; - for (auto& bucket : buckets) { - if (const auto issues = std::move(std::get<TIssues>(bucket.second))) { - YQL_CLOG(INFO, ProviderS3) << "Discovery " << std::get<0U>(bucket.first) << std::get<1U>(bucket.first) << " error " << issues.ToString(); - std::for_each(issues.begin(), issues.end(), std::bind(&TExprContext::AddError, std::ref(ctx), std::placeholders::_1)); - return TStatus::Error; - } + TPendingRequests pendingRequests; + TNodeMap<TVector<TListRequest>> requestsByNode; + + pendingRequests.swap(PendingRequests_); + requestsByNode.swap(RequestsByNode_); + + TNodeOnNodeOwnedMap replaces; + size_t count = 0; + size_t totalSize = 0; + for (auto& [node, requests] : requestsByNode) { + const TS3Read read(node); + const auto& object = read.Arg(2).Ref(); + size_t readSize = 0; + TExprNode::TListType pathNodes; + for (auto& req : requests) { + auto it = pendingRequests.find(req); + YQL_ENSURE(it != pendingRequests.end()); + YQL_ENSURE(it->second.HasValue()); + + const IS3Lister::TListResult& listResult = it->second.GetValue(); + if (listResult.index() == 1) { + const auto& issues = std::get<TIssues>(listResult); + YQL_CLOG(INFO, ProviderS3) << "Discovery " << req.Url << req.Pattern << " error " << issues.ToString(); + std::for_each(issues.begin(), issues.end(), std::bind(&TExprContext::AddError, std::ref(ctx), std::placeholders::_1)); + return TStatus::Error; + } - const auto nodes = std::move(std::get<TNodeSet>(bucket.second)); - for (const auto r : nodes) { - const TS3Read read(r); - - std::vector<std::string_view> keys; - const auto& object = read.Arg(2).Ref(); - if (object.IsCallable("MrObject")) - keys.emplace_back(object.Head().Content()); - else if (object.IsCallable("MrTableConcat")) - object.ForEachChild([&keys](const TExprNode& child){ keys.emplace_back(child.Head().Tail().Head().Content()); }); - - const auto& items = std::get<TItemsMap>(bucket.second); - YQL_CLOG(INFO, ProviderS3) << "Discovered " << items.size() << " items in " << std::get<0U>(bucket.first) << std::get<1U>(bucket.first); - - TExprNode::TListType paths; - for (const auto& path : keys) { - if (std::string_view::npos != path.find_first_of("?*{")) { - const RE2 re(re2::StringPiece(RegexFromWildcards(path)), RE2::Options()); - paths.reserve(items.size()); - auto total = 0ULL; - for (const auto& item : items) { - if (const re2::StringPiece piece(item.first); re.Match(piece, 0, item.first.size(), RE2::ANCHOR_BOTH, nullptr, 0)) { - if (item.second > State_->Configuration->FileSizeLimit) { - ctx.AddError(TIssue(ctx.GetPosition(object.Pos()), TStringBuilder() << "Object " << item.first << " size " << item.second << " is too large, but limit is " << State_->Configuration->FileSizeLimit)); - return TStatus::Error; - } - - total += item.second; - ++count; - paths.emplace_back( - ctx.Builder(object.Pos()) - .List() - .Atom(0, item.first) - .Atom(1, ToString(item.second), TNodeFlags::Default) - .Seal() - .Build() - ); - } - } - - readSize += total; - - if (paths.empty()) { - ctx.AddError(TIssue(ctx.GetPosition(object.Pos()), TStringBuilder() << "Object " << path << " has no items.")); - return TStatus::Error; - } - YQL_CLOG(INFO, ProviderS3) << "Object " << path << " has " << paths.size() << " items with total size " << total; - } else if (const auto f = items.find(TString(path)); items.cend() == f) { - ctx.AddError(TIssue(ctx.GetPosition(object.Pos()), TStringBuilder() << "Object " << path << " doesn't exist.")); - return TStatus::Error; - } else if (const auto size = f->second; size > State_->Configuration->FileSizeLimit) { - ctx.AddError(TIssue(ctx.GetPosition(object.Pos()), TStringBuilder() << "Object " << path << " size " << size << " is too large, but limit is " << State_->Configuration->FileSizeLimit)); - return TStatus::Error; + const auto& listEntries = std::get<IS3Lister::TListEntries>(listResult); + if (listEntries.empty()) { + if (IS3Lister::HasWildcards(req.Pattern)) { + ctx.AddError(TIssue(ctx.GetPosition(object.Pos()), TStringBuilder() << "Object " << req.Pattern << " has no items.")); } else { - YQL_CLOG(INFO, ProviderS3) << "Object " << path << " size is " << size; - readSize += size; - ++count; - paths.emplace_back( - ctx.Builder(object.Pos()) - .List() - .Atom(0, path) - .Atom(1, ToString(size), TNodeFlags::Default) - .Seal() - .Build() - ); + ctx.AddError(TIssue(ctx.GetPosition(object.Pos()), TStringBuilder() << "Object " << req.Pattern << " doesn't exist.")); } + return TStatus::Error; } - - auto settings = read.Ref().Child(4)->ChildrenList(); - auto userSchema = ExtractSchema(settings); - TExprNode::TPtr s3Object; - if (object.IsCallable("MrObject")) { - auto children = object.ChildrenList(); - children.front() = ctx.NewList(object.Pos(), std::move(paths)); - s3Object = ctx.NewCallable(object.Pos(), TS3Object::CallableName(), std::move(children)); - } else if (object.IsCallable("MrTableConcat")) { - s3Object = Build<TS3Object>(ctx, object.Pos()) - .Paths(ctx.NewList(object.Pos(), std::move(paths))) - .Format(ExtractFormat(settings)) - .Settings(ctx.NewList(object.Pos(), std::move(settings))) - .Done().Ptr(); + for (auto& entry : listEntries) { + pathNodes.emplace_back( + ctx.Builder(object.Pos()) + .List() + .Atom(0, entry.Path) + .Atom(1, ToString(entry.Size), TNodeFlags::Default) + .Seal() + .Build() + ); + ++count; + readSize += entry.Size; } - replaces.emplace(r, userSchema.back() ? - Build<TS3ReadObject>(ctx, read.Pos()) - .World(read.World()) - .DataSource(read.DataSource()) - .Object(std::move(s3Object)) - .RowType(std::move(userSchema.front())) - .ColumnOrder(std::move(userSchema.back())) - .Done().Ptr(): - Build<TS3ReadObject>(ctx, read.Pos()) - .World(read.World()) - .DataSource(read.DataSource()) - .Object(std::move(s3Object)) - .RowType(std::move(userSchema.front())) - .Done().Ptr()); + YQL_CLOG(INFO, ProviderS3) << "Object " << req.Pattern << " has " << listEntries.size() << " items with total size " << readSize; + totalSize += readSize; + } + + auto settings = read.Ref().Child(4)->ChildrenList(); + auto userSchema = ExtractSchema(settings); + TExprNode::TPtr s3Object; + if (object.IsCallable("MrObject")) { + auto children = object.ChildrenList(); + children.front() = ctx.NewList(object.Pos(), std::move(pathNodes)); + s3Object = ctx.NewCallable(object.Pos(), TS3Object::CallableName(), std::move(children)); + } else if (object.IsCallable("MrTableConcat")) { + s3Object = Build<TS3Object>(ctx, object.Pos()) + .Paths(ctx.NewList(object.Pos(), std::move(pathNodes))) + .Format(ExtractFormat(settings)) + .Settings(ctx.NewList(object.Pos(), std::move(settings))) + .Done().Ptr(); } - } - YQL_CLOG(INFO, ProviderS3) << "Read " << count << " objects with total size is " << readSize; + replaces.emplace(node, userSchema.back() ? + Build<TS3ReadObject>(ctx, read.Pos()) + .World(read.World()) + .DataSource(read.DataSource()) + .Object(std::move(s3Object)) + .RowType(std::move(userSchema.front())) + .ColumnOrder(std::move(userSchema.back())) + .Done().Ptr(): + Build<TS3ReadObject>(ctx, read.Pos()) + .World(read.World()) + .DataSource(read.DataSource()) + .Object(std::move(s3Object)) + .RowType(std::move(userSchema.front())) + .Done().Ptr()); + } - if (count > State_->Configuration->MaxFilesPerQuery) { - ctx.AddError(TIssue(ctx.GetPosition(input->Pos()), TStringBuilder() << "Too many objects to read: " << count << ", but limit is " << State_->Configuration->MaxFilesPerQuery)); + const auto maxFiles = State_->Configuration->MaxFilesPerQuery; + if (count > maxFiles) { + ctx.AddError(TIssue(ctx.GetPosition(input->Pos()), TStringBuilder() << "Too many objects to read: " << count << ", but limit is " << maxFiles)); return TStatus::Error; } - if (readSize > State_->Configuration->MaxReadSizePerQuery) { - ctx.AddError(TIssue(ctx.GetPosition(input->Pos()), TStringBuilder() << "Too large objects to read: " << readSize << ", but limit is " << State_->Configuration->MaxReadSizePerQuery)); + const auto maxSize = State_->Configuration->MaxReadSizePerQuery; + if (totalSize > maxSize) { + ctx.AddError(TIssue(ctx.GetPosition(input->Pos()), TStringBuilder() << "Too large objects to read: " << totalSize << ", but limit is " << maxSize)); return TStatus::Error; } @@ -394,10 +202,10 @@ public: } private: const TS3State::TPtr State_; - const IHTTPGateway::TPtr Gateway_; - - const std::shared_ptr<TPendingBuckets> PendingBuckets_ = std::make_shared<TPendingBuckets>(); + const IS3Lister::TPtr Lister_; + TPendingRequests PendingRequests_; + TNodeMap<TVector<TListRequest>> RequestsByNode_; NThreading::TFuture<void> AllFuture_; }; diff --git a/ydb/library/yql/providers/s3/provider/yql_s3_list.cpp b/ydb/library/yql/providers/s3/provider/yql_s3_list.cpp new file mode 100644 index 0000000000..780e0b964d --- /dev/null +++ b/ydb/library/yql/providers/s3/provider/yql_s3_list.cpp @@ -0,0 +1,289 @@ +#include "yql_s3_list.h" + +#include <ydb/library/yql/utils/log/log.h> +#include <ydb/library/yql/utils/url_builder.h> +#include <ydb/library/yql/utils/yql_panic.h> + +#include <contrib/libs/re2/re2/re2.h> + +#ifdef THROW +#undef THROW +#endif +#include <library/cpp/xml/document/xml-document.h> +#include <library/cpp/retry/retry_policy.h> + +#include <util/string/builder.h> + + +namespace NYql { + +namespace { + +ERetryErrorClass RetryS3SlowDown(long httpResponseCode) { + return httpResponseCode == 503 ? ERetryErrorClass::LongRetry : ERetryErrorClass::NoRetry; // S3 Slow Down == 503 +} + +size_t GetFirstWildcardPos(const TString& pattern) { + return pattern.find_first_of("*?{"); +} + +TString RegexFromWildcards(const std::string_view& pattern) { + const auto& escaped = RE2::QuoteMeta(re2::StringPiece(pattern)); + TStringBuilder result; + result << "(?s)"; + bool slash = false; + bool group = false; + + for (const char& c : escaped) { + switch (c) { + case '{': + result << '('; + group = true; + slash = false; + break; + case '}': + result << ')'; + group = false; + slash = false; + break; + case ',': + if (group) + result << '|'; + else + result << "\\,"; + slash = false; + break; + case '\\': + if (slash) + result << "\\\\"; + slash = !slash; + break; + case '*': + result << "(.*)"; + slash = false; + break; + case '?': + result << "(.)"; + slash = false; + break; + default: + if (slash) + result << '\\'; + result << c; + slash = false; + break; + } + } + return result; +} + +using namespace NThreading; + +class TS3Lister : public IS3Lister { +public: + explicit TS3Lister(const IHTTPGateway::TPtr& httpGateway, ui64 maxFilesPerQuery) + : Gateway(httpGateway) + , MaxFilesPerQuery(maxFilesPerQuery) + {} +private: + using TResultFilter = std::function<bool (const TString& path, TVector<TString>& matchedGlobs)>; + + static TResultFilter MakeFilter(const TString& pattern, TString& prefix) { + prefix.clear(); + if (auto pos = GetFirstWildcardPos(pattern); pos != TString::npos) { + prefix = pattern.substr(0, pos); + const auto regex = RegexFromWildcards(pattern); + auto re = std::make_shared<RE2>(re2::StringPiece(regex), RE2::Options()); + YQL_ENSURE(re->ok()); + YQL_ENSURE(re->NumberOfCapturingGroups() > 0); + + const size_t numGroups = re->NumberOfCapturingGroups(); + YQL_CLOG(INFO, ProviderS3) << "Got prefix: '" << prefix << "', regex: '" << regex + << "' with " << numGroups << " capture groups from original pattern '" << pattern << "'"; + + auto groups = std::make_shared<std::vector<std::string>>(numGroups); + auto reArgs = std::make_shared<std::vector<re2::RE2::Arg>>(numGroups); + auto reArgsPtr = std::make_shared<std::vector<re2::RE2::Arg*>>(numGroups); + + for (size_t i = 0; i < size_t(numGroups); ++i) { + (*reArgs)[i] = &(*groups)[i]; + (*reArgsPtr)[i] = &(*reArgs)[i]; + } + + return [groups, reArgs, reArgsPtr, re](const TString& path, TVector<TString>& matchedGlobs) { + matchedGlobs.clear(); + bool matched = re2::RE2::FullMatchN(path, *re, reArgsPtr->data(), reArgsPtr->size()); + if (matched) { + matchedGlobs.reserve(groups->size()); + for (auto& group : *groups) { + matchedGlobs.push_back(ToString(group)); + } + } + return matched; + }; + } + prefix = pattern; + return [pattern](const TString& path, TVector<TString>& matchedGlobs) { + matchedGlobs.clear(); + return path == pattern; + }; + } + + static void OnDiscovery( + const IHTTPGateway::TWeakPtr& gatewayWeak, + IHTTPGateway::TResult&& result, + NThreading::TPromise<IS3Lister::TListResult> promise, + const std::shared_ptr<IS3Lister::TListEntries>& output, + const IRetryPolicy<long>::TPtr& retryPolicy, + const TResultFilter& filter, + const TString& token, + const TString& urlStr, + const TString& prefix, + ui64 maxDiscoveryFilesPerQuery) + try { + auto gateway = gatewayWeak.lock(); + if (!gateway) { + ythrow yexception() << "Gateway disappeared"; + } + switch (result.index()) { + case 0U: { + const NXml::TDocument xml(std::get<IHTTPGateway::TContent>(std::move(result)).Extract(), NXml::TDocument::String); + if (const auto& root = xml.Root(); root.Name() == "Error") { + const auto& code = root.Node("Code", true).Value<TString>(); + const auto& message = root.Node("Message", true).Value<TString>(); + ythrow yexception() << message << ", error: code: " << code; + } else if (root.Name() != "ListBucketResult") { + ythrow yexception() << "Unexpected response '" << root.Name() << "' on discovery."; + } else if ( + const NXml::TNamespacesForXPath nss(1U, {"s3", "http://s3.amazonaws.com/doc/2006-03-01/"}); + root.Node("s3:KeyCount", false, nss).Value<unsigned>() > 0U) + { + const auto& contents = root.XPath("s3:Contents", false, nss); + YQL_CLOG(INFO, ProviderS3) << "Listing of " << urlStr << prefix << ": have " << output->size() << " entries, got another " << contents.size() << " entries"; + if (maxDiscoveryFilesPerQuery && output->size() + contents.size() > maxDiscoveryFilesPerQuery) { + ythrow yexception() << "Over " << maxDiscoveryFilesPerQuery << " files discovered in '" << urlStr << prefix << "'"; + } + + for (const auto& content : contents) { + TString path = content.Node("s3:Key", false, nss).Value<TString>(); + TVector<TString> matchedGlobs; + if (filter(path, matchedGlobs)) { + output->emplace_back(); + output->back().Path = path; + output->back().Size = content.Node("s3:Size", false, nss).Value<ui64>(); + output->back().MatchedGlobs.swap(matchedGlobs); + } + } + + if (root.Node("s3:IsTruncated", false, nss).Value<bool>()) { + YQL_CLOG(INFO, ProviderS3) << "Listing of " << urlStr << prefix << ": got truncated flag, will continue"; + const auto& next = root.Node("s3:NextContinuationToken", false, nss).Value<TString>(); + const auto& maxKeys = root.Node("s3:MaxKeys", false, nss).Value<TString>(); + + IHTTPGateway::THeaders headers; + if (!token.empty()) { + headers.emplace_back("X-YaCloud-SubjectToken:" + token); + } + + TUrlBuilder urlBuilder(urlStr); + auto url = urlBuilder.AddUrlParam("list-type", "2") + .AddUrlParam("prefix", prefix) + .AddUrlParam("continuation-token", next) + .AddUrlParam("max-keys", maxKeys) + .Build(); + + return gateway->Download( + url, + std::move(headers), + 0U, + std::bind(&OnDiscovery, + IHTTPGateway::TWeakPtr(gateway), + std::placeholders::_1, + promise, + output, + retryPolicy, + filter, + token, + urlStr, + prefix, + maxDiscoveryFilesPerQuery), + /*data=*/"", + false, + retryPolicy); + } + } + promise.SetValue(std::move(*output)); + break; + } + case 1U: { + auto issues = std::get<TIssues>(std::move(result)); + YQL_CLOG(INFO, ProviderS3) << "Listing of " << urlStr << prefix << ": got error from http gateway: " << issues.ToString(true); + promise.SetValue(std::move(issues)); + break; + } + default: + ythrow yexception() << "Undefined variant index: " << result.index(); + } + } catch (const std::exception& ex) { + YQL_CLOG(INFO, ProviderS3) << "Listing of " << urlStr << prefix << " : got exception: " << ex.what(); + promise.SetException(std::current_exception()); + } + + + TFuture<TListResult> List(const TString& token, const TString& urlStr, const TString& pattern) override { + TString prefix; + TResultFilter filter = MakeFilter(pattern, prefix); + YQL_CLOG(INFO, ProviderS3) << "Enumerate items in " << urlStr << pattern; + + const auto retryPolicy = IRetryPolicy<long>::GetExponentialBackoffPolicy(RetryS3SlowDown); + TUrlBuilder urlBuilder(urlStr); + const auto url = urlBuilder + .AddUrlParam("list-type", "2") + .AddUrlParam("prefix", prefix) + .Build(); + + IHTTPGateway::THeaders headers; + if (!token.empty()) { + headers.emplace_back("X-YaCloud-SubjectToken:" + token); + } + + auto promise = NewPromise<IS3Lister::TListResult>(); + auto future = promise.GetFuture(); + + Gateway->Download( + url, + std::move(headers), + 0U, + std::bind(&OnDiscovery, + IHTTPGateway::TWeakPtr(Gateway), + std::placeholders::_1, + promise, + std::make_shared<IS3Lister::TListEntries>(), + retryPolicy, + filter, + token, + urlStr, + prefix, + MaxFilesPerQuery), + /*data=*/"", + false, + retryPolicy); + return future; + } + + const IHTTPGateway::TPtr Gateway; + const ui64 MaxFilesPerQuery; +}; + + +} + +bool IS3Lister::HasWildcards(const TString& pattern) { + return GetFirstWildcardPos(pattern) != TString::npos; +} + +IS3Lister::TPtr IS3Lister::Make(const IHTTPGateway::TPtr& httpGateway, ui64 maxFilesPerQuery) { + return IS3Lister::TPtr(new TS3Lister(httpGateway, maxFilesPerQuery)); +} + +} diff --git a/ydb/library/yql/providers/s3/provider/yql_s3_list.h b/ydb/library/yql/providers/s3/provider/yql_s3_list.h new file mode 100644 index 0000000000..577c4e507b --- /dev/null +++ b/ydb/library/yql/providers/s3/provider/yql_s3_list.h @@ -0,0 +1,39 @@ +#pragma once + +#include <ydb/library/yql/providers/common/http_gateway/yql_http_gateway.h> + +#include <library/cpp/threading/future/future.h> + +#include <variant> +#include <vector> +#include <memory> + +namespace NYql { + +class IS3Lister { +public: + using TPtr = std::shared_ptr<IS3Lister>; + + struct TListEntry { + TString Path; + ui64 Size = 0; + std::vector<TString> MatchedGlobs; + }; + + using TListEntries = std::vector<TListEntry>; + using TListResult = std::variant<TListEntries, TIssues>; + + virtual ~IS3Lister() = default; + // List all S3 objects matching wildcard pattern. + // Pattern may include following wildcard expressions: + // * - any (possibly empty) sequence of characters + // ? - single character + // {variant1, variant2} - list of alternatives + virtual NThreading::TFuture<TListResult> List(const TString& token, const TString& url, const TString& pattern) = 0; + + static TPtr Make(const IHTTPGateway::TPtr& httpGateway, ui64 maxFilesPerQuery = 0); + + static bool HasWildcards(const TString& pattern); +}; + +}
\ No newline at end of file |