diff options
author | a-romanov <Anton.Romanov@ydb.tech> | 2023-04-29 22:05:30 +0300 |
---|---|---|
committer | a-romanov <Anton.Romanov@ydb.tech> | 2023-04-29 22:05:30 +0300 |
commit | 30fcddb738365ea92672976dc2f67fb8762a50c3 (patch) | |
tree | cbdca0b80946889689b18ca36e45e92b37ed664b | |
parent | 092195a4eec9472a10ae234bbebbffd779256cc0 (diff) | |
download | ydb-30fcddb738365ea92672976dc2f67fb8762a50c3.tar.gz |
YQL-8971 YQL-15555 Constraints for JoinDict, Casts and many fixes.
-rw-r--r-- | ydb/library/yql/ast/yql_constraint.cpp | 88 | ||||
-rw-r--r-- | ydb/library/yql/ast/yql_constraint.h | 2 | ||||
-rw-r--r-- | ydb/library/yql/core/yql_expr_constraint.cpp | 426 |
3 files changed, 251 insertions, 265 deletions
diff --git a/ydb/library/yql/ast/yql_constraint.cpp b/ydb/library/yql/ast/yql_constraint.cpp index 569d705123c..acb36b15dd2 100644 --- a/ydb/library/yql/ast/yql_constraint.cpp +++ b/ydb/library/yql/ast/yql_constraint.cpp @@ -30,7 +30,7 @@ void TConstraintNode::Out(IOutputStream& out) const { } const TTypeAnnotationNode* TConstraintNode::GetSubTypeByPath(const TPathType& path, const TTypeAnnotationNode& type) { - if (path.empty()) + if (path.empty() && ETypeAnnotationKind::Optional != type.GetKind()) return &type; const auto tail = [](const TPathType& path) { @@ -55,6 +55,16 @@ const TTypeAnnotationNode* TConstraintNode::GetSubTypeByPath(const TPathType& pa if (const auto multiType = type.Cast<TMultiExprType>(); multiType->GetSize() > *index) return GetSubTypeByPath(tail(path), *multiType->GetItems()[*index]); break; + case ETypeAnnotationKind::Variant: + return GetSubTypeByPath(path, *type.Cast<TVariantExprType>()->GetUnderlyingType()); + case ETypeAnnotationKind::Dict: + if (const auto index = TryFromString<ui8>(TStringBuf(path.front()))) + switch (*index) { + case 0U: return GetSubTypeByPath(tail(path), *type.Cast<TDictExprType>()->GetKeyType()); + case 1U: return GetSubTypeByPath(tail(path), *type.Cast<TDictExprType>()->GetPayloadType()); + default: break; + } + break; default: break; } @@ -1014,20 +1024,19 @@ TPartOfConstraintNode<TOriginalConstraintNode>::FilterFields(TExprContext& ctx, template<class TOriginalConstraintNode> const TPartOfConstraintNode<TOriginalConstraintNode>* TPartOfConstraintNode<TOriginalConstraintNode>::RenameFields(TExprContext& ctx, const TPathReduce& rename) const { - auto mapping = Mapping_; - for (auto part = mapping.begin(); mapping.end() != part;) { - TPartType old; - part->second.swap(old); - for (auto& item : std::move(old)) { + TMapType mapping(Mapping_.size()); + for (const auto& part : Mapping_) { + TPartType map; + map.reserve(part.second.size()); + + for (const auto& item : part.second) { for (auto& path : rename(item.first)) { - part->second.insert_unique(std::make_pair(std::move(path), std::move(item.second))); + map.insert_unique(std::make_pair(std::move(path), item.second)); } } - if (part->second.empty()) - part = mapping.erase(part); - else - ++part; + if (!map.empty()) + mapping.emplace(part.first, std::move(map)); } return mapping.empty() ? nullptr : ctx.MakeConstraint<TPartOfConstraintNode>(std::move(mapping)); } @@ -1353,21 +1362,14 @@ void TPassthroughConstraintNode::Out(IOutputStream& out) const { bool first = true; for (const auto& part : Mapping_) { for (const auto& item : part.second) { - if (!first) { + if (first) + first = false; + else out.Write(','); - } - if (!item.first.empty()) { - auto it = item.first.cbegin(); - out.Write(*it); - while (item.first.cend() > ++it) { - out.Write('#'); - out.Write(*it); - } - } - out.Write(':'); - out.Write(item.second); - first = false; + out << item.first; + out.Write(':'); + out << item.second; } } out.Write(')'); @@ -1420,6 +1422,44 @@ const TPassthroughConstraintNode* TPassthroughConstraintNode::ExtractField(TExpr return passtrought.empty() ? nullptr : ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(passtrought)); } +const TPassthroughConstraintNode* +TPassthroughConstraintNode::FilterFields(TExprContext& ctx, const TPathFilter& predicate) const { + TMapType passtrought(Mapping_.size()); + for (const auto& part : Mapping_) { + TPartType mapping; + mapping.reserve(part.second.size()); + + for (const auto& item : part.second) { + if (predicate(item.first)) { + mapping.insert_unique(std::make_pair(item.first, item.second)); + } + } + + if (!mapping.empty()) + passtrought.emplace(part.first ? part.first : this, std::move(mapping)); + } + return passtrought.empty() ? nullptr : ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(passtrought)); +} + +const TPassthroughConstraintNode* +TPassthroughConstraintNode::RenameFields(TExprContext& ctx, const TPathReduce& rename) const { + TMapType passtrought(Mapping_.size()); + for (const auto& part : Mapping_) { + TPartType mapping; + mapping.reserve(part.second.size()); + + for (const auto& item : part.second) { + for (auto& path : rename(item.first)) { + mapping.insert_unique(std::make_pair(std::move(path), item.second)); + } + } + + if (!mapping.empty()) + passtrought.emplace(part.first ? part.first : this, std::move(mapping)); + } + return passtrought.empty() ? nullptr : ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(passtrought)); +} + const TPassthroughConstraintNode* TPassthroughConstraintNode::MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx) { if (constraints.empty()) { return nullptr; diff --git a/ydb/library/yql/ast/yql_constraint.h b/ydb/library/yql/ast/yql_constraint.h index 188d517f1df..c05e153bf71 100644 --- a/ydb/library/yql/ast/yql_constraint.h +++ b/ydb/library/yql/ast/yql_constraint.h @@ -374,6 +374,8 @@ public: static void UniqueMerge(TMapType& output, TMapType&& input); const TPassthroughConstraintNode* ExtractField(TExprContext& ctx, const std::string_view& field) const; + const TPassthroughConstraintNode* FilterFields(TExprContext& ctx, const TPathFilter& predicate) const; + const TPassthroughConstraintNode* RenameFields(TExprContext& ctx, const TPathReduce& reduce) const; static const TPassthroughConstraintNode* MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx); const TPassthroughConstraintNode* MakeCommon(const TPassthroughConstraintNode* other, TExprContext& ctx) const; diff --git a/ydb/library/yql/core/yql_expr_constraint.cpp b/ydb/library/yql/core/yql_expr_constraint.cpp index 976e49f9028..3a153c2cdb6 100644 --- a/ydb/library/yql/core/yql_expr_constraint.cpp +++ b/ydb/library/yql/core/yql_expr_constraint.cpp @@ -117,8 +117,8 @@ public: Functions["SelectMembers"] = &TCallableConstraintTransformer::SelectMembersWrap; Functions["FilterMembers"] = &TCallableConstraintTransformer::SelectMembersWrap; Functions["CastStruct"] = &TCallableConstraintTransformer::SelectMembersWrap; - Functions["SafeCast"] = &TCallableConstraintTransformer::SelectMembersWrap<true>; - Functions["StrictCast"] = &TCallableConstraintTransformer::SelectMembersWrap<true>; + Functions["SafeCast"] = &TCallableConstraintTransformer::CastWrap<false>; + Functions["StrictCast"] = &TCallableConstraintTransformer::CastWrap<true>; Functions["DivePrefixMembers"] = &TCallableConstraintTransformer::DivePrefixMembersWrap; Functions["OrderedFilter"] = &TCallableConstraintTransformer::FilterWrap<true>; Functions["Filter"] = &TCallableConstraintTransformer::FilterWrap<false>; @@ -194,6 +194,7 @@ public: Functions["Mux"] = &TCallableConstraintTransformer::MuxWrap; Functions["Nth"] = &TCallableConstraintTransformer::NthWrap; Functions["EquiJoin"] = &TCallableConstraintTransformer::EquiJoinWrap; + Functions["JoinDict"] = &TCallableConstraintTransformer::JoinDictWrap; Functions["MapJoinCore"] = &TCallableConstraintTransformer::MapJoinCoreWrap; Functions["GraceJoinCore"] = &TCallableConstraintTransformer::GraceJoinCoreWrap; Functions["CommonJoinCore"] = &TCallableConstraintTransformer::FromFirst<TEmptyConstraintNode>; @@ -441,16 +442,45 @@ private: return FromFirst<TPassthroughConstraintNode, TEmptyConstraintNode, TUniqueConstraintNode, TDistinctConstraintNode, TVarIndexConstraintNode>(input, output, ctx); } - template <bool CheckMembersType = false> + template<class TConstraint> + static void FilterFromHead(const TExprNode& input, TConstraintSet& constraints, const TConstraintNode::TPathFilter& filter, TExprContext& ctx) { + if (const auto source = input.Head().GetConstraint<TConstraint>()) { + if (const auto filtered = source->FilterFields(ctx, filter)) { + constraints.AddConstraint(filtered); + } + } + } + + template<class TConstraint> + static void ReduceFromHead(const TExprNode::TPtr& input, const TConstraintNode::TPathReduce& reduce, TExprContext& ctx) { + if (const auto source = input->Head().GetConstraint<TConstraint>()) { + if (const auto filtered = source->RenameFields(ctx, reduce)) { + input->AddConstraint(filtered); + } + } + } + + template<class TConstraint> + static void FilterFromHead(const TExprNode::TPtr& input, const TConstraintNode::TPathFilter& filter, TExprContext& ctx) { + if (const auto source = input->Head().GetConstraint<TConstraint>()) { + if (const auto filtered = source->FilterFields(ctx, filter)) { + input->AddConstraint(filtered); + } + } + } + + template<class TConstraint> + static void FilterFromHeadIfMissed(const TExprNode::TPtr& input, const TConstraintNode::TPathFilter& filter, TExprContext& ctx) { + if (!input->GetConstraint<TConstraint>()) + FilterFromHead<TConstraint>(input, filter, ctx); + } + TStatus SelectMembersWrap(const TExprNode::TPtr& input, TExprNode::TPtr& /*output*/, TExprContext& ctx) const { auto outItemType = input->GetTypeAnn(); while (outItemType->GetKind() == ETypeAnnotationKind::Optional) { outItemType = outItemType->Cast<TOptionalExprType>()->GetItemType(); } - auto inItemType = input->Head().GetTypeAnn(); - while (inItemType->GetKind() == ETypeAnnotationKind::Optional) { - inItemType = inItemType->Cast<TOptionalExprType>()->GetItemType(); - } + if (outItemType->GetKind() == ETypeAnnotationKind::Variant) { if (outItemType->Cast<TVariantExprType>()->GetUnderlyingType()->GetKind() == ETypeAnnotationKind::Tuple) { const auto outSize = outItemType->Cast<TVariantExprType>()->GetUnderlyingType()->Cast<TTupleExprType>()->GetSize(); @@ -485,114 +515,49 @@ private: } } else if (outItemType->GetKind() == ETypeAnnotationKind::Struct) { - const auto outStructType = outItemType->Cast<TStructExprType>(); - const auto inStructType = inItemType->Cast<TStructExprType>(); - if (const auto passthrough = input->Head().GetConstraint<TPassthroughConstraintNode>()) { - TPassthroughConstraintNode::TMapType filteredMapping; - if constexpr (CheckMembersType) { - const auto& inItems = inStructType->GetItems(); - const auto& outItems = outStructType->GetItems(); - for (const auto& part : passthrough->GetColumnMapping()) { - TPassthroughConstraintNode::TPartType filtered; - filtered.reserve(part.second.size()); - for (const auto& item : part.second) { - if (!item.first.empty()) { - const auto outItem = outStructType->FindItem(item.first.front()); - const auto inItem = inStructType->FindItem(item.first.front()); - if (outItem && inItem && IsSameAnnotation(*outItems[*outItem]->GetItemType(), *inItems[*inItem]->GetItemType())) { - filtered.push_back(item); - } - } - if (!filtered.empty()) { - filteredMapping.emplace(part.first ? part.first : passthrough, std::move(filtered)); - } - } - } - } else { - for (const auto& part: passthrough->GetColumnMapping()) { - TPassthroughConstraintNode::TPartType filtered; - filtered.reserve(part.second.size()); - for (const auto& item : part.second) { - if (!item.first.empty() && outStructType->FindItem(item.first.front())) { - filtered.push_back(item); - } - } - if (!filtered.empty()) { - filteredMapping.emplace(part.first ? part.first : passthrough, std::move(filtered)); - } - } - } - if (!filteredMapping.empty()) { - input->AddConstraint(ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(filteredMapping))); - } - } - - const auto filter = CheckMembersType ? - TConstraintNode::TPathFilter([inStructType, outStructType](const TConstraintNode::TPathType& path) { - if (path.empty()) - return false; - if (const auto itemType = TConstraintNode::GetSubTypeByPath(path, *outStructType)) - return IsSameAnnotation(*itemType, *TConstraintNode::GetSubTypeByPath(path, *inStructType)); - return false; - }): - TConstraintNode::TPathFilter([outStructType](const TConstraintNode::TPathType& path) { - return !path.empty() && TConstraintNode::GetSubTypeByPath(path, *outStructType); } - ); + const auto filter = [outItemType](const TConstraintNode::TPathType& path) { + return !path.empty() && TConstraintNode::GetSubTypeByPath(path, *outItemType); + }; - if (const auto part = input->Head().GetConstraint<TPartOfSortedConstraintNode>()) { - if (const auto filtered = part->FilterFields(ctx, filter)) { - input->AddConstraint(filtered); - } - } + FilterFromHead<TPassthroughConstraintNode>(input, filter, ctx); + FilterFromHead<TPartOfSortedConstraintNode>(input, filter, ctx); + FilterFromHead<TPartOfChoppedConstraintNode>(input, filter, ctx); + FilterFromHead<TPartOfUniqueConstraintNode>(input, filter, ctx); + FilterFromHead<TPartOfDistinctConstraintNode>(input, filter, ctx); + } - if (const auto part = input->Head().GetConstraint<TPartOfChoppedConstraintNode>()) { - if (const auto filtered = part->FilterFields(ctx, filter)) { - input->AddConstraint(filtered); - } - } + return TStatus::Ok; + } - if (const auto part = input->Head().GetConstraint<TPartOfUniqueConstraintNode>()) { - if (const auto filtered = part->FilterFields(ctx, filter)) { - input->AddConstraint(filtered); - } - } + template <bool Strict> + TStatus CastWrap(const TExprNode::TPtr& input, TExprNode::TPtr& /*output*/, TExprContext& ctx) const { + const auto outItemType = input->GetTypeAnn(); + const auto inItemType = input->Head().GetTypeAnn(); + const auto filter = [inItemType, outItemType](const TConstraintNode::TPathType& path) { + if (const auto outType = TConstraintNode::GetSubTypeByPath(path, *outItemType)) + return IsSameAnnotation(*outType, *TConstraintNode::GetSubTypeByPath(path, *inItemType)); + return false; + }; - if (const auto part = input->Head().GetConstraint<TPartOfDistinctConstraintNode>()) { - if (const auto filtered = part->FilterFields(ctx, filter)) { - input->AddConstraint(filtered); - } - } - } + const auto filterForUnique = Strict ? [outItemType](const TConstraintNode::TPathType& path) { + return bool(TConstraintNode::GetSubTypeByPath(path, *outItemType)); + } : TConstraintNode::TPathFilter(filter); + FilterFromHead<TPassthroughConstraintNode>(input, filter, ctx); + FilterFromHead<TSortedConstraintNode>(input, filter, ctx); + FilterFromHead<TChoppedConstraintNode>(input, filter, ctx); + FilterFromHead<TUniqueConstraintNode>(input, filter, ctx); + FilterFromHead<TDistinctConstraintNode>(input, filter, ctx); + FilterFromHead<TPartOfSortedConstraintNode>(input, filter, ctx); + FilterFromHead<TPartOfChoppedConstraintNode>(input, filter, ctx); + FilterFromHead<TPartOfUniqueConstraintNode>(input, filterForUnique, ctx); + FilterFromHead<TPartOfDistinctConstraintNode>(input, filterForUnique, ctx); return TStatus::Ok; } TStatus DivePrefixMembersWrap(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) const { - const auto prefixes = input->Child(1)->Children(); - if (const auto passthrough = input->Head().GetConstraint<TPassthroughConstraintNode>()) { - TPassthroughConstraintNode::TMapType filteredMapping; - for (const auto& part: passthrough->GetColumnMapping()) { - TPassthroughConstraintNode::TPartType filtered; - filtered.reserve(part.second.size()); - for (auto item: part.second) { - for (const auto& p : prefixes) { - if (const auto& prefix = p->Content(); !item.first.empty() && item.first.front().starts_with(prefix)) { - item.first.front() = item.first.front().substr(prefix.length()); - filtered.insert_unique(std::move(item)); - break; - } - } - } - if (!filtered.empty()) { - filteredMapping.emplace(part.first ? part.first : passthrough, std::move(filtered)); - } - } - if (!filteredMapping.empty()) { - input->AddConstraint(ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(filteredMapping))); - } - } - - const auto rename = [&](const TConstraintNode::TPathType& path) -> std::vector<TConstraintNode::TPathType> { + const auto prefixes = input->Tail().Children(); + const auto rename = [&prefixes](const TConstraintNode::TPathType& path) -> std::vector<TConstraintNode::TPathType> { if (path.empty()) return {}; @@ -607,92 +572,26 @@ private: return {}; }; - - if (const auto part = input->Head().GetConstraint<TPartOfSortedConstraintNode>()) { - if (const auto filtered = part->RenameFields(ctx, rename)) { - input->AddConstraint(filtered); - } - } - - if (const auto part = input->Head().GetConstraint<TPartOfUniqueConstraintNode>()) { - if (const auto filtered = part->RenameFields(ctx, rename)) { - input->AddConstraint(filtered); - } - } - - if (const auto part = input->Head().GetConstraint<TPartOfDistinctConstraintNode>()) { - if (const auto filtered = part->RenameFields(ctx, rename)) { - input->AddConstraint(filtered); - } - } - + ReduceFromHead<TPassthroughConstraintNode>(input, rename, ctx); + ReduceFromHead<TPartOfSortedConstraintNode>(input, rename, ctx); + ReduceFromHead<TPartOfChoppedConstraintNode>(input, rename, ctx); + ReduceFromHead<TPartOfUniqueConstraintNode>(input, rename, ctx); + ReduceFromHead<TPartOfDistinctConstraintNode>(input, rename, ctx); return FromFirst<TVarIndexConstraintNode>(input, output, ctx); } - template<class TConstraint> - static void FilterFromHead(const TExprNode& input, TConstraintSet& constraints, const TConstraintNode::TPathFilter& filter, TExprContext& ctx) { - if (const auto source = input.Head().GetConstraint<TConstraint>()) { - if (const auto filtered = source->FilterFields(ctx, filter)) { - constraints.AddConstraint(filtered); - } - } - } - - template<class TConstraint> - static void ReduceFromHead(const TExprNode::TPtr& input, const TConstraintNode::TPathReduce& reduce, TExprContext& ctx) { - if (const auto source = input->Head().GetConstraint<TConstraint>()) { - if (const auto filtered = source->RenameFields(ctx, reduce)) { - input->AddConstraint(filtered); - } - } - } - - template<class TConstraint> - static void FilterFromHead(const TExprNode::TPtr& input, const TConstraintNode::TPathFilter& filter, TExprContext& ctx) { - if (const auto source = input->Head().GetConstraint<TConstraint>()) { - if (const auto filtered = source->FilterFields(ctx, filter)) { - input->AddConstraint(filtered); - } - } - } - - template<class TConstraint> - static void FilterFromHeadIfMissed(const TExprNode::TPtr& input, const TConstraintNode::TPathFilter& filter, TExprContext& ctx) { - if (!input->GetConstraint<TConstraint>()) - FilterFromHead<TConstraint>(input, filter, ctx); - } - TStatus ExtractMembersWrap(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) const { const auto outItemType = GetSeqItemType(*input->GetTypeAnn()).Cast<TStructExprType>(); - if (const auto passthrough = input->Head().GetConstraint<TPassthroughConstraintNode>()) { - TPassthroughConstraintNode::TMapType filteredMapping; - for (const auto& part: passthrough->GetColumnMapping()) { - TPassthroughConstraintNode::TPartType filtered; - filtered.reserve(part.second.size()); - for (const auto& item: part.second) { - if (!item.first.empty() && outItemType->FindItem(item.first.front())) { - filtered.push_back(item); - } - } - if (!filtered.empty()) { - filteredMapping.emplace(part.first ? part.first : passthrough, std::move(filtered)); - } - } - if (!filteredMapping.empty()) { - input->AddConstraint(ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(filteredMapping))); - } - } - const auto filter = [outItemType](const TConstraintNode::TPathType& path) { return !path.empty() && outItemType->FindItem(path.front()); }; + FilterFromHead<TPassthroughConstraintNode>(input, filter, ctx); FilterFromHead<TSortedConstraintNode>(input, filter, ctx); FilterFromHead<TChoppedConstraintNode>(input, filter, ctx); FilterFromHead<TUniqueConstraintNode>(input, filter, ctx); FilterFromHead<TDistinctConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfSortedConstraintNode>(input, filter, ctx); - FilterFromHead<TChoppedConstraintNode>(input, filter, ctx); + FilterFromHead<TPartOfChoppedConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfUniqueConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfDistinctConstraintNode>(input, filter, ctx); - return FromFirst<TEmptyConstraintNode, TVarIndexConstraintNode>(input, output, ctx); } @@ -704,26 +603,8 @@ private: if (outItemType->GetKind() == ETypeAnnotationKind::Struct) { const auto outStructType = outItemType->Cast<TStructExprType>(); - if (const auto passthrough = input->Head().GetConstraint<TPassthroughConstraintNode>()) { - TPassthroughConstraintNode::TMapType filteredMapping; - for (const auto& part: passthrough->GetColumnMapping()) { - TPassthroughConstraintNode::TPartType filtered; - filtered.reserve(part.second.size()); - for (const auto& item: part.second) { - if (!item.first.empty() && outStructType->FindItem(item.first.front())) { - filtered.push_back(item); - } - } - if (!filtered.empty()) { - filteredMapping.emplace(part.first ? part.first : passthrough, std::move(filtered)); - } - } - if (!filteredMapping.empty()) { - input->AddConstraint(ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(filteredMapping))); - } - } - const auto filter = [outStructType](const TConstraintNode::TPathType& path) { return !path.empty() && outStructType->FindItem(path.front()); }; + FilterFromHead<TPassthroughConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfSortedConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfChoppedConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfUniqueConstraintNode>(input, filter, ctx); @@ -737,28 +618,9 @@ private: YQL_ENSURE(item.first < tupleUnderType->GetSize()); auto& constr = multiItems[item.first]; - const TStructExprType* outStructType = tupleUnderType->GetItems()[item.first]->Cast<TStructExprType>(); - - if (const auto passthrough = item.second.GetConstraint<TPassthroughConstraintNode>()) { - TPassthroughConstraintNode::TMapType filteredMapping; - for (const auto& part: passthrough->GetColumnMapping()) { - TPassthroughConstraintNode::TPartType filtered; - filtered.reserve(part.second.size()); - for (const auto& item: part.second) { - if (!item.first.empty() && outStructType->FindItem(item.first.front())) { - filtered.push_back(item); - } - } - if (!filtered.empty()) { - filteredMapping.emplace(part.first ? part.first : passthrough, std::move(filtered)); - } - } - if (!filteredMapping.empty()) { - constr.AddConstraint(ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(filteredMapping))); - } - } - + const auto outStructType = tupleUnderType->GetItems()[item.first]->Cast<TStructExprType>(); const auto filter = [outStructType](const TConstraintNode::TPathType& path) { return !path.empty() && outStructType->FindItem(path.front()); }; + FilterFromHead<TPassthroughConstraintNode>(*input, constr, filter, ctx); FilterFromHead<TPartOfSortedConstraintNode>(*input, constr, filter, ctx); FilterFromHead<TPartOfChoppedConstraintNode>(*input, constr, filter, ctx); FilterFromHead<TPartOfUniqueConstraintNode>(*input, constr, filter, ctx); @@ -1666,33 +1528,12 @@ private: TStatus RemoveMemberWrap(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) const { const auto& name = input->Tail().Content(); - if (const auto structPassthrough = input->Head().GetConstraint<TPassthroughConstraintNode>()) { - const TConstraintNode::TPathType key(1U, name); - auto mapping = structPassthrough->GetColumnMapping(); - if (const auto self = mapping.find(nullptr); mapping.cend() != self) - mapping.emplace(structPassthrough, std::move(mapping.extract(self).mapped())); - for (auto p = mapping.begin(); mapping.end() != p;) { - if (auto it = p->second.lower_bound(key); p->second.cend() > it && it->first.front() == key.front()) { - do p->second.erase(it++); - while (p->second.end() > it && it->first.front() == key.front()); - if (p->second.empty()) { - mapping.erase(p++); - continue; - } - } - ++p; - } - if (!mapping.empty()) { - input->AddConstraint(ctx.MakeConstraint<TPassthroughConstraintNode>(std::move(mapping))); - } - } - const auto filter = [&name](const TConstraintNode::TPathType& path) { return !path.empty() && path.front() != name; }; + FilterFromHead<TPassthroughConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfSortedConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfChoppedConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfUniqueConstraintNode>(input, filter, ctx); FilterFromHead<TPartOfDistinctConstraintNode>(input, filter, ctx); - return FromFirst<TVarIndexConstraintNode>(input, output, ctx); } @@ -2613,6 +2454,109 @@ private: return TStatus::Ok; } + template<bool Distinct> + static const TUniqueConstraintNodeBase<Distinct>* GetForPayload(const TExprNode& input, TExprContext& ctx) { + if (const auto constraint = input.GetConstraint<TUniqueConstraintNodeBase<Distinct>>()) { + return constraint->RenameFields(ctx, [&ctx](const TConstraintNode::TPathType& path) -> std::vector<TConstraintNode::TPathType> { + if (path.empty() || path.front() != ctx.GetIndexAsString(1U)) + return {}; + auto copy = path; + copy.pop_front(); + return {copy}; + }); + } + return nullptr; + } + + TStatus JoinDictWrap(const TExprNode::TPtr& input, TExprNode::TPtr& /*output*/, TExprContext& ctx) const { + const TCoJoinDict join(input); + const auto& joinType = join.JoinKind().Ref(); + if (const auto lEmpty = join.LeftInput().Ref().GetConstraint<TEmptyConstraintNode>(), rEmpty = join.RightInput().Ref().GetConstraint<TEmptyConstraintNode>(); lEmpty && rEmpty) { + input->AddConstraint(ctx.MakeConstraint<TEmptyConstraintNode>()); + } else if (lEmpty && joinType.Content().starts_with("Left")) { + input->AddConstraint(lEmpty); + } else if (rEmpty && joinType.Content().starts_with("Right")) { + input->AddConstraint(rEmpty); + } else if ((lEmpty || rEmpty) && (joinType.IsAtom("Inner") || joinType.Content().ends_with("Semi"))) { + input->AddConstraint(ctx.MakeConstraint<TEmptyConstraintNode>()); + } + + bool lOneRow = false, rOneRow = false; + if (const auto& flags = join.Flags()) { + flags.Cast().Ref().ForEachChild([&](const TExprNode& flag) { + lOneRow = lOneRow || flag.IsAtom("LeftUnique"); + rOneRow = rOneRow || flag.IsAtom("RightUnique"); + }); + } + + const auto lUnique = GetForPayload<false>(join.LeftInput().Ref(), ctx); + const auto rUnique = GetForPayload<false>(join.RightInput().Ref(), ctx); + + const auto lDistinct = GetForPayload<true>(join.LeftInput().Ref(), ctx); + const auto rDistinct = GetForPayload<true>(join.RightInput().Ref(), ctx); + + const bool leftSide = joinType.Content().starts_with("Left"); + const bool rightSide = joinType.Content().starts_with("Right"); + + if (joinType.Content().ends_with("Semi") || joinType.Content().ends_with("Only")) { + if (leftSide) { + if (lUnique) + input->AddConstraint(lUnique); + if (lDistinct) + input->AddConstraint(lDistinct); + } else if (rightSide) { + if (rUnique) + input->AddConstraint(rUnique); + if (rDistinct) + input->AddConstraint(rDistinct); + } + } else if (lOneRow || rOneRow) { + const auto rename = [](const std::string_view& prefix, TConstraintNode::TPathType path) { + path.emplace_front(prefix); + return std::vector<TConstraintNode::TPathType>(1U, std::move(path)); + }; + const auto leftRename = std::bind(rename, ctx.GetIndexAsString(0U), std::placeholders::_1); + const auto rightRename = std::bind(rename, ctx.GetIndexAsString(1U), std::placeholders::_1); + + if (lUnique || rUnique) { + const TUniqueConstraintNode* unique = nullptr; + + const bool exclusion = joinType.IsAtom("Exclusion"); + const bool useLeft = lUnique && (rOneRow || exclusion); + const bool useRight = rUnique && (lOneRow || exclusion); + + if (useLeft && !useRight) + unique = lUnique->RenameFields(ctx, leftRename); + else if (useRight && !useLeft) + unique = rUnique->RenameFields(ctx, rightRename); + else if (useLeft && useRight) + unique = TUniqueConstraintNode::Merge(lUnique->RenameFields(ctx, leftRename), rUnique->RenameFields(ctx, rightRename), ctx); + + if (unique) + input->AddConstraint(unique); + } + + if (lDistinct || rDistinct) { + const TDistinctConstraintNode* distinct = nullptr; + + const bool inner = joinType.IsAtom("Inner"); + const bool useLeft = lDistinct && rOneRow && (inner || leftSide); + const bool useRight = rDistinct && lOneRow && (inner || rightSide); + + if (useLeft && !useRight) + distinct = lDistinct->RenameFields(ctx, leftRename); + else if (useRight && !useLeft) + distinct = rDistinct->RenameFields(ctx, rightRename); + else if (useLeft && useRight) + distinct = TDistinctConstraintNode::Merge(lDistinct->RenameFields(ctx, leftRename), rDistinct->RenameFields(ctx, rightRename), ctx); + + if (distinct) + input->AddConstraint(distinct); + } + } + + return TStatus::Ok; + } TStatus IsKeySwitchWrap(const TExprNode::TPtr& input, TExprNode::TPtr& /*output*/, TExprContext& ctx) const { const TCoIsKeySwitch keySwitch(input); |