summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author aneporada <[email protected]> 2023-01-05 19:00:02 +0300
committer aneporada <[email protected]> 2023-01-05 19:00:02 +0300
commit d7b4c95518652104d51c915f0fd20fd3e954135c (patch)
tree b075db6ef90bf898c2dafa2791577cfefe9a728b
parent b095aea0de8d68cf45a23c9043f5b7d1ea91e52a (diff)
Emit StringContains/EndsWith for sql LIKE if pragma AnsiLike is set
-rw-r--r-- ydb/library/yql/core/common_opt/yql_co_simple1.cpp | 12
-rw-r--r-- ydb/library/yql/core/common_opt/yql_co_simple2.cpp | 2
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_core.cpp | 1
-rw-r--r-- ydb/library/yql/minikql/invoke_builtins/mkql_builtins_with.cpp | 2
-rw-r--r-- ydb/library/yql/minikql/mkql_program_builder.cpp | 30
-rw-r--r-- ydb/library/yql/minikql/mkql_program_builder.h | 1
-rw-r--r-- ydb/library/yql/minikql/mkql_runtime_version.h | 2
-rw-r--r-- ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 1
-rw-r--r-- ydb/library/yql/sql/v1/context.h | 1
-rw-r--r-- ydb/library/yql/sql/v1/sql.cpp | 148
10 files changed, 146 insertions, 54 deletions
diff --git a/ydb/library/yql/core/common_opt/yql_co_simple1.cpp b/ydb/library/yql/core/common_opt/yql_co_simple1.cpp
index b1a5b4e40c2..b4afdb81c1d 100644
--- a/ydb/library/yql/core/common_opt/yql_co_simple1.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_simple1.cpp
@@ -4519,8 +4519,16 @@ void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) {
map["IsNotDistinctFrom"] = std::bind(&OptimizeDistinctFrom<true>, _1, _2);
map["IsDistinctFrom"] = std::bind(&OptimizeDistinctFrom<false>, _1, _2);
- map["StartsWith"] = std::bind(&OptimizeEquality<true>, _1, _2);
- map["EndsWith"] = std::bind(&OptimizeEquality<true>, _1, _2);
+ map["StartsWith"] = map["EndsWith"] = map["StringContains"] = [](const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& /*optCtx*/) {
+ if (node->GetTypeAnn()->GetKind() != ETypeAnnotationKind::Optional &&
+ node->Tail().IsCallable("String") && node->Tail().Head().Content().empty())
+ {
+ YQL_CLOG(DEBUG, Core) << node->Content() << " with empty string in second argument";
+ return MakeBool<true>(node->Pos(), ctx);
+ }
+
+ return OptimizeEquality<true>(node, ctx);
+ };
map["<"] = map["<="] = map[">"] = map[">="] = std::bind(&OptimizeCompare, _1, _2);;
diff --git a/ydb/library/yql/core/common_opt/yql_co_simple2.cpp b/ydb/library/yql/core/common_opt/yql_co_simple2.cpp
index 100bc88cd42..1ac2dc216c1 100644
--- a/ydb/library/yql/core/common_opt/yql_co_simple2.cpp
+++ b/ydb/library/yql/core/common_opt/yql_co_simple2.cpp
@@ -480,7 +480,7 @@ void RegisterCoSimpleCallables2(TCallableOptimizerMap& map) {
map["AggrMin"] = map["AggrMax"] = map["Coalesce"] = std::bind(&DropAggrOverSame, _1);
- map["StartsWith"] = map["EndsWith"] = std::bind(&CheckCompareSame<true, false>, _1, _2);
+ map["StartsWith"] = map["EndsWith"] = map["StringContains"] = std::bind(&CheckCompareSame<true, false>, _1, _2);
map["=="] = map["<="] = map[">="] = std::bind(&CheckCompareSame<true, false>, _1, _2);
map["!="] = map["<"] = map[">"] = std::bind(&CheckCompareSame<false, false>, _1, _2);
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 1c3653b5863..bab3640cf57 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11359,6 +11359,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["RFind"] = &FindWrapper;
Functions["StartsWith"] = &WithWrapper;
Functions["EndsWith"] = &WithWrapper;
+ Functions["StringContains"] = &WithWrapper;
Functions["ByteAt"] = &ByteAtWrapper;
Functions["ListIf"] = &ListIfWrapper;
Functions["AsList"] = &AsListWrapper<false>;
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_with.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_with.cpp
index c8cd4a51d8c..dae40098025 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_with.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_with.cpp
@@ -59,7 +59,7 @@ template<NUdf::EDataSlot> using TContains = TStringWith<&StringContains>;
void RegisterWith(IBuiltinFunctionRegistry& registry) {
RegisterCompareStrings<TStartsWith, TCompareArgsOpt, false>(registry, "StartsWith");
RegisterCompareStrings<TEndsWith, TCompareArgsOpt, false>(registry, "EndsWith");
- RegisterCompareStrings<TContains, TCompareArgsOpt, false>(registry, "Contains");
+ RegisterCompareStrings<TContains, TCompareArgsOpt, false>(registry, "StringContains");
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index bb9bf360732..501e708c0af 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -4121,6 +4121,36 @@ TRuntimeNode TProgramBuilder::EndsWith(TRuntimeNode string, TRuntimeNode suffix)
return DataCompare(__func__, string, suffix);
}
+TRuntimeNode TProgramBuilder::StringContains(TRuntimeNode string, TRuntimeNode pattern) {
+ bool isOpt1, isOpt2;
+ TDataType* type1 = UnpackOptionalData(string, isOpt1);
+ TDataType* type2 = UnpackOptionalData(pattern, isOpt2);
+ MKQL_ENSURE(type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id ||
+ type1->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as first argument");
+ MKQL_ENSURE(type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id ||
+ type2->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as second argument");
+ if constexpr (RuntimeVersion < 32U) {
+ auto stringCasted = (type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id) ? ToString(string) : string;
+ auto patternCasted = (type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id) ? ToString(pattern) : pattern;
+ auto found = Exists(Find(stringCasted, patternCasted, NewDataLiteral(ui32(0))));
+ if (!isOpt1 && !isOpt2) {
+ return found;
+ }
+ TVector<TRuntimeNode> predicates;
+ if (isOpt1) {
+ predicates.push_back(Exists(string));
+ }
+ if (isOpt2) {
+ predicates.push_back(Exists(pattern));
+ }
+
+ TRuntimeNode argsNotNull = (predicates.size() == 1) ? predicates.front() : And(predicates);
+ return If(argsNotNull, NewOptional(found), NewEmptyOptionalDataLiteral(NUdf::TDataType<bool>::Id));
+ }
+
+ return DataCompare(__func__, string, pattern);
+}
+
TRuntimeNode TProgramBuilder::ByteAt(TRuntimeNode data, TRuntimeNode index) {
const std::array<TRuntimeNode, 2U> args = {{ data, index }};
return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui8>::Id)), args);
diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h
index d037a3cfd47..99b42332018 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.h
+++ b/ydb/library/yql/minikql/mkql_program_builder.h
@@ -210,6 +210,7 @@ public:
TRuntimeNode RFind(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos);
TRuntimeNode StartsWith(TRuntimeNode string, TRuntimeNode prefix);
TRuntimeNode EndsWith(TRuntimeNode string, TRuntimeNode suffix);
+ TRuntimeNode StringContains(TRuntimeNode string, TRuntimeNode pattern);
TRuntimeNode ByteAt(TRuntimeNode data, TRuntimeNode index);
TRuntimeNode Size(TRuntimeNode data);
template <bool Utf8 = false>
diff --git a/ydb/library/yql/minikql/mkql_runtime_version.h b/ydb/library/yql/minikql/mkql_runtime_version.h
index 44337f8b7b8..043e54c485f 100644
--- a/ydb/library/yql/minikql/mkql_runtime_version.h
+++ b/ydb/library/yql/minikql/mkql_runtime_version.h
@@ -24,7 +24,7 @@ namespace NMiniKQL {
// 1. Bump this version every time incompatible runtime nodes are introduced.
// 2. Make sure you provide runtime node generation for previous runtime versions.
#ifndef MKQL_RUNTIME_VERSION
-#define MKQL_RUNTIME_VERSION 31U
+#define MKQL_RUNTIME_VERSION 32U
#endif
// History:
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index a6c89236af8..020ba1e709c 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -538,6 +538,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
{"StartsWith", &TProgramBuilder::StartsWith},
{"EndsWith", &TProgramBuilder::EndsWith},
+ {"StringContains", &TProgramBuilder::StringContains},
{"SqueezeToList", &TProgramBuilder::SqueezeToList},
diff --git a/ydb/library/yql/sql/v1/context.h b/ydb/library/yql/sql/v1/context.h
index 0a6d2c31f6c..c1ba2537675 100644
--- a/ydb/library/yql/sql/v1/context.h
+++ b/ydb/library/yql/sql/v1/context.h
@@ -279,6 +279,7 @@ namespace NSQLTranslationV1 {
bool EmitStartsWith = true;
bool EmitAggApply = false;
bool UseBlocks = false;
+ bool AnsiLike = false;
};
class TColumnRefScope {
diff --git a/ydb/library/yql/sql/v1/sql.cpp b/ydb/library/yql/sql/v1/sql.cpp
index 4a62213ef2f..d2cca8d29e6 100644
--- a/ydb/library/yql/sql/v1/sql.cpp
+++ b/ydb/library/yql/sql/v1/sql.cpp
@@ -2177,44 +2177,52 @@ namespace {
}
template<typename TChar>
- struct TSplitResult {
+ struct TPatternComponent {
TBasicString<TChar> Prefix;
TBasicString<TChar> Suffix;
+ bool IsSimple = true;
+
+ void AppendPlain(TChar c) {
+ if (IsSimple) {
+ Prefix.push_back(c);
+ }
+ Suffix.push_back(c);
+ }
+
+ void AppendAnyChar() {
+ IsSimple = false;
+ Suffix.clear();
+ }
};
template<typename TChar>
- TSplitResult<TChar> SplitPattern(const TBasicString<TChar>& pattern, TMaybe<char> escape,
- bool& inEscape, bool& hasPattern, bool& isSimple)
- {
- inEscape = hasPattern = false;
- isSimple = true;
- TSplitResult<TChar> result;
- TBasicString<TChar> left, right, current;
- for (const auto c : pattern) {
+ TVector<TPatternComponent<TChar>> SplitPattern(const TBasicString<TChar>& pattern, TMaybe<char> escape, bool& inEscape) {
+ inEscape = false;
+ TVector<TPatternComponent<TChar>> result;
+ TPatternComponent<TChar> current;
+ bool prevIsPercentChar = false;
+ for (const TChar c : pattern) {
if (inEscape) {
- current.append(c);
+ current.AppendPlain(c);
inEscape = false;
+ prevIsPercentChar = false;
} else if (escape && c == static_cast<TChar>(*escape)) {
inEscape = true;
- } else if (c == '%' || c == '_') {
- if (result.Prefix.empty() && !hasPattern)
- std::swap(result.Prefix, current);
- else if (left.empty())
- std::swap(left, current);
- else if (right.empty())
- std::swap(right, current);
- else
- isSimple = false;
- hasPattern = true;
- if (c == '_') {
- isSimple = false;
- }
+ } else if (c == '%') {
+ if (!prevIsPercentChar) {
+ result.push_back(std::move(current));
+ }
+ current = {};
+ prevIsPercentChar = true;
+ } else if (c == '_') {
+ current.AppendAnyChar();
+ prevIsPercentChar = false;
} else {
- current.append(c);
+ current.AppendPlain(c);
+ prevIsPercentChar = false;
}
}
-
- result.Suffix = std::move(current);
+ result.push_back(std::move(current));
return result;
}
}
@@ -5160,26 +5168,26 @@ TNodePtr TSqlExpression::SubExpr(const TRule_xor_subexpr& node, const TTrailingQ
}
if (literalPattern) {
- TString prefix, suffix;
bool inEscape = false;
- bool hasPattern = false;
- bool isSimple = true;
-
TMaybe<char> escape;
if (escapeLiteral) {
escape = escapeLiteral->front();
}
bool mayIgnoreCase;
+ TVector<TPatternComponent<char>> components;
if (isUtf8) {
- auto splitResult = SplitPattern(UTF8ToUTF32<false>(*literalPattern), escape, inEscape, hasPattern, isSimple);
- prefix = WideToUTF8(splitResult.Prefix);
- suffix = WideToUTF8(splitResult.Suffix);
+ auto splitResult = SplitPattern(UTF8ToUTF32<false>(*literalPattern), escape, inEscape);
+ for (const auto& component : splitResult) {
+ TPatternComponent<char> converted;
+ converted.IsSimple = component.IsSimple;
+ converted.Prefix = WideToUTF8(component.Prefix);
+ converted.Suffix = WideToUTF8(component.Suffix);
+ components.push_back(std::move(converted));
+ }
mayIgnoreCase = ToLowerUTF8(*literalPattern) == ToUpperUTF8(*literalPattern);
} else {
- auto splitResult = SplitPattern(*literalPattern, escape, inEscape, hasPattern, isSimple);
- prefix = splitResult.Prefix;
- suffix = splitResult.Suffix;
+ components = SplitPattern(*literalPattern, escape, inEscape);
mayIgnoreCase = WithoutAlpha(*literalPattern);
}
@@ -5190,32 +5198,68 @@ TNodePtr TSqlExpression::SubExpr(const TRule_xor_subexpr& node, const TTrailingQ
}
if (opName == "like" || mayIgnoreCase) {
-//TODO: Drop regex if (isSimple) {}
-
- if (!(hasPattern || suffix.empty())) {
- isMatch = BuildBinaryOp(Ctx, pos, "==", res, BuildLiteralRawString(pos, suffix, isUtf8));
- } else if (!prefix.empty()) {
+ // TODO: expand LIKE in optimizers - we can analyze argument types there
+ YQL_ENSURE(!components.empty());
+ const auto& first = components.front();
+ if (components.size() == 1 && first.IsSimple) {
+ // no '%'s and '_'s in pattern
+ YQL_ENSURE(first.Prefix == first.Suffix);
+ isMatch = BuildBinaryOp(Ctx, pos, "==", res, BuildLiteralRawString(pos, first.Suffix, isUtf8));
+ } else if (!first.Prefix.empty()) {
+ const TString& prefix = first.Prefix;
+ TNodePtr prefixMatch;
if (Ctx.EmitStartsWith) {
- const auto& lowerBoundOp = BuildBinaryOp(Ctx, pos, "StartsWith", res, BuildLiteralRawString(pos, prefix, isUtf8));
- isMatch = BuildBinaryOp(Ctx, pos, "And", lowerBoundOp, isMatch);
+ prefixMatch = BuildBinaryOp(Ctx, pos, "StartsWith", res, BuildLiteralRawString(pos, prefix, isUtf8));
} else {
- const auto& lowerBoundOp = BuildBinaryOp(Ctx, pos, ">=", res, BuildLiteralRawString(pos, prefix, isUtf8));
+ prefixMatch = BuildBinaryOp(Ctx, pos, ">=", res, BuildLiteralRawString(pos, prefix, isUtf8));
auto upperBound = isUtf8 ? NextValidUtf8(prefix) : NextLexicographicString(prefix);
-
if (upperBound) {
- const auto& between = BuildBinaryOp(
+ prefixMatch = BuildBinaryOp(
Ctx,
pos,
"And",
- lowerBoundOp,
+ prefixMatch,
BuildBinaryOp(Ctx, pos, "<", res, BuildLiteralRawString(pos, TString(*upperBound), isUtf8))
);
- isMatch = BuildBinaryOp(Ctx, pos, "And", between, isMatch);
+ }
+ }
+
+ if (Ctx.AnsiLike && first.IsSimple && components.size() == 2 && components.back().IsSimple) {
+ const TString& suffix = components.back().Suffix;
+ // 'prefix%suffix'
+ if (suffix.empty()) {
+ isMatch = prefixMatch;
} else {
- isMatch = BuildBinaryOp(Ctx, pos, "And", lowerBoundOp, isMatch);
+ // len(str) >= len(prefix) + len(suffix) && StartsWith(str, prefix) && EndsWith(str, suffix)
+ TNodePtr sizePred = BuildBinaryOp(Ctx, pos, ">=",
+ TNodePtr(new TCallNodeImpl(pos, "Size", { res })),
+ TNodePtr(new TLiteralNumberNode<ui32>(pos, "Uint32", ToString(prefix.size() + suffix.size()))));
+ TNodePtr suffixMatch = BuildBinaryOp(Ctx, pos, "EndsWith", res, BuildLiteralRawString(pos, suffix, isUtf8));
+ isMatch = new TCallNodeImpl(pos, "And", {
+ sizePred,
+ prefixMatch,
+ suffixMatch
+ });
}
+ } else {
+ isMatch = BuildBinaryOp(Ctx, pos, "And", prefixMatch, isMatch);
+ }
+ } else if (Ctx.AnsiLike && AllOf(components, [](const auto& comp) { return comp.IsSimple; })) {
+ YQL_ENSURE(first.Prefix.empty());
+ if (components.size() == 3 && components.back().Prefix.empty()) {
+ // '%foo%'
+ YQL_ENSURE(!components[1].Prefix.empty());
+ isMatch = BuildBinaryOp(Ctx, pos, "StringContains", res, BuildLiteralRawString(pos, components[1].Prefix, isUtf8));
+ } else if (components.size() == 2) {
+ // '%foo'
+ isMatch = BuildBinaryOp(Ctx, pos, "EndsWith", res, BuildLiteralRawString(pos, components[1].Prefix, isUtf8));
}
+ } else if (Ctx.AnsiLike && !components.back().Suffix.empty()) {
+ const TString& suffix = components.back().Suffix;
+ TNodePtr suffixMatch = BuildBinaryOp(Ctx, pos, "EndsWith", res, BuildLiteralRawString(pos, suffix, isUtf8));
+ isMatch = BuildBinaryOp(Ctx, pos, "And", suffixMatch, isMatch);
}
+ // TODO: more StringContains/StartsWith/EndsWith cases?
}
}
@@ -10088,6 +10132,12 @@ TNodePtr TSqlQuery::PragmaStatement(const TRule_pragma_stmt& stmt, bool& success
} else if (normalizedPragma == "disableuseblocks") {
Ctx.UseBlocks = false;
Ctx.IncrementMonCounter("sql_pragma", "DisableUseBlocks");
+ } else if (normalizedPragma == "ansilike") {
+ Ctx.AnsiLike = true;
+ Ctx.IncrementMonCounter("sql_pragma", "AnsiLike");
+ } else if (normalizedPragma == "disableansilike") {
+ Ctx.AnsiLike = false;
+ Ctx.IncrementMonCounter("sql_pragma", "DisableAnsiLike");
} else {
Error() << "Unknown pragma: " << pragma;
Ctx.IncrementMonCounter("sql_errors", "UnknownPragma");