diff options
author | vvvv <[email protected]> | 2024-11-06 10:40:15 +0300 |
---|---|---|
committer | vvvv <[email protected]> | 2024-11-06 10:50:39 +0300 |
commit | 8cf98f8169af124576399a29eac2bc2a691124e3 (patch) | |
tree | 54260d85e28822402e0d17d1e840ccfb62b33b61 /yql/essentials/ast | |
parent | 13cc9ffc62224711fd2923aed53525fc7d1838f8 (diff) |
Moved yql/ast YQL-19206
init
commit_hash:a6a63582073784b9318cc04ffcc1e212f3df703b
Diffstat (limited to 'yql/essentials/ast')
33 files changed, 17501 insertions, 0 deletions
diff --git a/yql/essentials/ast/serialize/ya.make b/yql/essentials/ast/serialize/ya.make new file mode 100644 index 00000000000..d1d37d057ea --- /dev/null +++ b/yql/essentials/ast/serialize/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +SRCS( + yql_expr_serialize.cpp +) + +PEERDIR( + yql/essentials/ast + yql/essentials/core/issue + contrib/ydb/library/yql/minikql +) + +END() diff --git a/yql/essentials/ast/serialize/yql_expr_serialize.cpp b/yql/essentials/ast/serialize/yql_expr_serialize.cpp new file mode 100644 index 00000000000..c163bf1311f --- /dev/null +++ b/yql/essentials/ast/serialize/yql_expr_serialize.cpp @@ -0,0 +1,499 @@ +#include "yql_expr_serialize.h" +#include <contrib/ydb/library/yql/minikql/pack_num.h> +#include <util/generic/algorithm.h> +#include <util/generic/deque.h> + +namespace NYql { + +namespace { + +enum ESerializeCommands { + NODE_REF = 0x00, + NODE_VALUE = 0x10, + INLINE_STR = 0x08, // string is unique, don't write it to the pool + SAME_POSITION = 0x40, + ATOM_FLAG = 0x20, + WIDE = 0x80, // mark wide lambdas + ATOM = ATOM_FLAG | NODE_VALUE, // for atoms we will use TNodeFlags bits (1/2/4) + LIST = TExprNode::List | NODE_VALUE, + CALLABLE = TExprNode::Callable | NODE_VALUE, + LAMBDA = TExprNode::Lambda | NODE_VALUE, + ARGUMENT = TExprNode::Argument | NODE_VALUE, + ARGUMENTS = TExprNode::Arguments | NODE_VALUE, + WORLD = TExprNode::World | NODE_VALUE, +}; + +using namespace NKikimr; + +class TWriter { +public: + TWriter(TExprContext& ctx, ui16 components) + : Ctx(ctx) + , Components_(components) + { + } + + const TString& Out() const { + //Cerr << "Nodes:" << WrittenNodes_.size() << ", pos: " << Positions_.size() << ", bytes: " << Out_.size() << "\n"; + return Out_; + } + + void Prepare(const TExprNode& node) { + TNodeSet visited; + PrepareImpl(node, visited); + } + + void Init() { + WriteVar32(Components_); + ui32 reusedStringCount = 0; + for (auto& x : StringCounters_) { + if (x.second.first > 1) { + x.second.second = reusedStringCount; + ++reusedStringCount; + } + } + + WriteVar32(reusedStringCount); + TVector<std::pair<TStringBuf, ui32>> sortedStrings; + sortedStrings.reserve(reusedStringCount); + for (const auto& x : StringCounters_) { + if (x.second.first > 1) { + sortedStrings.push_back({ x.first, x.second.second }); + } + } + + Sort(sortedStrings.begin(), sortedStrings.end(), [](const auto& x, const auto& y) { return x.second < y.second; }); + + for (const auto& x : sortedStrings) { + WriteVar32(x.first.length()); + WriteMany(x.first.data(), x.first.length()); + } + + if (Components_ & TSerializedExprGraphComponents::Positions) { + WriteVar32(Files_.size()); + TVector<std::pair<TStringBuf, ui32>> sortedFiles; + sortedFiles.reserve(Files_.size()); + for (const auto& x : Files_) { + sortedFiles.push_back({ x.first, x.second }); + } + + Sort(sortedFiles.begin(), sortedFiles.end(), [](const auto& x, const auto& y) { return x.second < y.second; }); + for (const auto& x : sortedFiles) { + WriteVar32(x.first.length()); + WriteMany(x.first.data(), x.first.length()); + } + + WriteVar32(Positions_.size()); + TVector<std::tuple<ui32, ui32, ui32, ui32>> sortedPositions; + sortedPositions.reserve(Positions_.size()); + for (const auto& x : Positions_) { + sortedPositions.push_back({ std::get<0>(x.first), std::get<1>(x.first), std::get<2>(x.first), x.second }); + } + + Sort(sortedPositions.begin(), sortedPositions.end(), [](const auto& x, const auto& y) + { return std::get<3>(x) < std::get<3>(y); }); + + for (const auto& x : sortedPositions) { + WriteVar32(std::get<0>(x)); + WriteVar32(std::get<1>(x)); + WriteVar32(std::get<2>(x)); + } + } + } + + void Save(const TExprNode& node) { + auto writtenIt = WrittenNodes_.find(&node); + if (writtenIt != WrittenNodes_.end()) { + Write(NODE_REF); + WriteVar32(writtenIt->second); + return; + } + + char command = (node.Type() == TExprNode::Atom) ? ATOM : ((node.Type() & TExprNode::TypeMask) | NODE_VALUE); + + if (node.Type() == TExprNode::Lambda && node.ChildrenSize() > 2U) { + command |= WIDE; + } + + if (Components_ & TSerializedExprGraphComponents::Positions) { + // will write position + if (Ctx.GetPosition(node.Pos()) == LastPosition_) { + command |= SAME_POSITION; + } + } + + if (node.Type() == TExprNode::Atom) { + command |= (TNodeFlags::FlagsMask & node.Flags()); + } + + ui32 strNum = 0; + if (node.Type() == TExprNode::Atom || node.Type() == TExprNode::Callable || node.Type() == TExprNode::Argument) { + auto strIt = StringCounters_.find(node.Content()); + YQL_ENSURE(strIt != StringCounters_.end()); + if (strIt->second.first == 1) { + command |= INLINE_STR; + } else { + strNum = strIt->second.second; + } + } + + Write(command); + if ((Components_ & TSerializedExprGraphComponents::Positions) && !(command & SAME_POSITION)) { + const auto& pos = Ctx.GetPosition(node.Pos()); + ui32 fileNum = 0; + if (pos.File) { + auto fileIt = Files_.find(pos.File); + YQL_ENSURE(fileIt != Files_.end()); + fileNum = fileIt->second; + } + + auto posIt = Positions_.find(std::make_tuple(std::move(pos.Row), std::move(pos.Column), + std::move(fileNum))); + YQL_ENSURE(posIt != Positions_.end()); + WriteVar32(posIt->second); + LastPosition_ = pos; + } + + if (node.Type() == TExprNode::Atom || node.Type() == TExprNode::Callable || node.Type() == TExprNode::Argument) { + if (command & INLINE_STR) { + WriteVar32(node.Content().length()); + WriteMany(node.Content().data(), node.Content().length()); + } else { + WriteVar32(strNum); + } + } + + if (node.Type() == TExprNode::Callable || node.Type() == TExprNode::Arguments || node.Type() == TExprNode::List || (node.Type() == TExprNode::Lambda && node.ChildrenSize() > 2U)) { + WriteVar32(node.ChildrenSize()); + } + + for (const auto& x : node.Children()) { + Save(*x); + } + + WrittenNodes_.emplace(&node, 1 + WrittenNodes_.size()); + } + +private: + void PrepareImpl(const TExprNode& node, TNodeSet& visited) { + if (!visited.emplace(&node).second) { + return; + } + + if (Components_ & TSerializedExprGraphComponents::Positions) { + const auto& pos = Ctx.GetPosition(node.Pos()); + const auto& file = pos.File; + ui32 fileNum = 0; + if (file) { + fileNum = Files_.emplace(file, 1 + (ui32)Files_.size()).first->second; + } + + Positions_.emplace(std::make_tuple(std::move(pos.Row), std::move(pos.Column), + std::move(fileNum)), (ui32)Positions_.size()); + } + + if (node.IsAtom() || node.IsCallable() || node.Type() == TExprNode::Argument) { + auto& x = StringCounters_[node.Content()]; + x.first++; + } + + for (const auto& x : node.Children()) { + PrepareImpl(*x, visited); + } + } + + Y_FORCE_INLINE void Write(char c) { + Out_.append(c); + } + + Y_FORCE_INLINE void WriteMany(const void* buf, size_t len) { + Out_.AppendNoAlias((const char*)buf, len); + } + + Y_FORCE_INLINE void WriteVar32(ui32 value) { + char buf[MAX_PACKED32_SIZE]; + Out_.AppendNoAlias(buf, Pack32(value, buf)); + } + +private: + TExprContext& Ctx; + const ui16 Components_; + THashMap<TStringBuf, ui32> Files_; + THashMap<std::tuple<ui32, ui32, ui32>, ui32> Positions_; + THashMap<TStringBuf, std::pair<ui32, ui32>> StringCounters_; // str -> id + serialized id + + TNodeMap<ui32> WrittenNodes_; + TPosition LastPosition_; + + TString Out_; +}; + +class TReader { +public: + TReader(TPosition pos, TStringBuf buffer, TExprContext& ctx) + : Pos_(pos) + , Current_(buffer.data()) + , End_(buffer.data() + buffer.size()) + , Ctx_(ctx) + , Components_(0) + { + } + + TExprNode::TPtr Load() { + try { + Components_ = ReadVar32(); + auto reusedStringCount = ReadVar32(); + Strings_.reserve(reusedStringCount); + for (ui32 i = 0; i < reusedStringCount; ++i) { + ui32 length = ReadVar32(); + auto internedBuf = Ctx_.AppendString(TStringBuf(ReadMany(length), length)); + Strings_.push_back(internedBuf); + } + + if (Components_ & TSerializedExprGraphComponents::Positions) { + auto filesCount = ReadVar32(); + Files_.reserve(filesCount); + for (ui32 i = 0; i < filesCount; ++i) { + ui32 length = ReadVar32(); + TStringBuf file(ReadMany(length), length); + Files_.push_back(TString(file)); + } + + auto positionsCount = ReadVar32(); + Positions_.reserve(positionsCount); + for (ui32 i = 0; i < positionsCount; ++i) { + ui32 row = ReadVar32(); + ui32 column = ReadVar32(); + ui32 fileNum = ReadVar32(); + if (fileNum > Files_.size()) { + ThrowCorrupted(); + } + + Positions_.push_back({ row, column, fileNum }); + } + } + + TExprNode::TPtr result = Fetch(); + if (Current_ != End_) { + ThrowCorrupted(); + } + + return result; + } catch (const yexception& e) { + TIssue issue(Pos_, TStringBuilder() << "Failed to deserialize expression graph, reason:\n" << e.what()); + issue.SetCode(UNEXPECTED_ERROR, ESeverity::TSeverityIds_ESeverityId_S_FATAL); + Ctx_.AddError(issue); + return nullptr; + } + } + +private: + TExprNode::TPtr Fetch() { + char command = Read(); + if (!(command & NODE_VALUE)) { + ui32 nodeId = ReadVar32(); + if (nodeId == 0 || nodeId > Nodes_.size()) { + ThrowCorrupted(); + } + + return Nodes_[nodeId - 1]; + } + + + command &= ~NODE_VALUE; + TPosition pos = Pos_; + if (Components_ & TSerializedExprGraphComponents::Positions) { + if (command & SAME_POSITION) { + pos = LastPosition_; + command &= ~SAME_POSITION; + } else { + ui32 posNum = ReadVar32(); + if (posNum >= Positions_.size()) { + ThrowCorrupted(); + } + + const auto& posItem = Positions_[posNum]; + + pos = TPosition(); + pos.Row = std::get<0>(posItem); + pos.Column = std::get<1>(posItem); + auto fileNum = std::get<2>(posItem); + if (fileNum > 0) { + pos.File = Files_[fileNum - 1]; + } + + LastPosition_ = pos; + } + } + + ui32 atomFlags = 0; + bool hasInlineStr = command & INLINE_STR; + command &= ~INLINE_STR; + if (command & ATOM_FLAG) { + atomFlags = command & TNodeFlags::FlagsMask; + command &= ~(ATOM_FLAG | TNodeFlags::FlagsMask); + command |= TExprNode::Atom; + } + + const bool wide = command & WIDE; + command &= ~WIDE; + + TStringBuf content; + if (command == TExprNode::Atom || command == TExprNode::Callable || command == TExprNode::Argument) { + if (hasInlineStr) { + ui32 length = ReadVar32(); + content = TStringBuf(ReadMany(length), length); + } else { + ui32 strNum = ReadVar32(); + if (strNum >= Strings_.size()) { + ThrowCorrupted(); + } + + content = Strings_[strNum]; + } + } + + ui32 childrenSize = 0; + if (command == TExprNode::Callable || command == TExprNode::Arguments || command == TExprNode::List || (command == TExprNode::Lambda && wide)) { + childrenSize = ReadVar32(); + } + + TExprNode::TPtr ret; + switch (command) { + case TExprNode::Atom: + ret = Ctx_.NewAtom(pos, content, atomFlags); + break; + case TExprNode::List: { + TExprNode::TListType children; + children.reserve(childrenSize); + for (ui32 i = 0U; i < childrenSize; ++i) { + children.emplace_back(Fetch()); + } + + ret = Ctx_.NewList(pos, std::move(children)); + break; + } + + case TExprNode::Callable: { + TExprNode::TListType children; + children.reserve(childrenSize); + for (ui32 i = 0U; i < childrenSize; ++i) { + children.emplace_back(Fetch()); + } + + ret = Ctx_.NewCallable(pos, content, std::move(children)); + break; + } + + case TExprNode::Argument: + ret = Ctx_.NewArgument(pos, content); + break; + + case TExprNode::Arguments: { + TExprNode::TListType children; + children.reserve(childrenSize); + for (ui32 i = 0U; i < childrenSize; ++i) { + children.emplace_back(Fetch()); + } + + ret = Ctx_.NewArguments(pos, std::move(children)); + break; + } + + case TExprNode::Lambda: + if (wide) { + TExprNode::TListType children; + children.reserve(childrenSize); + for (ui32 i = 0U; i < childrenSize; ++i) { + children.emplace_back(Fetch()); + } + ret = Ctx_.NewLambda(pos, std::move(children)); + } else { + auto args = Fetch(); + auto body = Fetch(); + ret = Ctx_.NewLambda(pos, {std::move(args), std::move(body)}); + } + break; + + case TExprNode::World: + ret = Ctx_.NewWorld(pos); + break; + + default: + ThrowCorrupted(); + } + + Nodes_.push_back(ret); + return ret; + } + + Y_FORCE_INLINE char Read() { + if (Current_ == End_) + ThrowNoData(); + + return *Current_++; + } + + Y_FORCE_INLINE const char* ReadMany(ui32 count) { + if (Current_ + count > End_) + ThrowNoData(); + + const char* result = Current_; + Current_ += count; + return result; + } + + Y_FORCE_INLINE ui32 ReadVar32() { + ui32 result = 0; + size_t count = Unpack32(Current_, End_ - Current_, result); + if (!count) { + ThrowCorrupted(); + } + Current_ += count; + return result; + } + + [[noreturn]] static void ThrowNoData() { + ythrow yexception() << "No more data in buffer"; + } + + [[noreturn]] static void ThrowCorrupted() { + ythrow yexception() << "Serialized data is corrupted"; + } + +private: + const TPosition Pos_; + const char* Current_; + const char* const End_; + TExprContext& Ctx_; + ui16 Components_; + + TVector<TStringBuf> Strings_; + TVector<TString> Files_; + TVector<std::tuple<ui32, ui32, ui32>> Positions_; + + TPosition LastPosition_; + TDeque<TExprNode::TPtr> Nodes_; +}; + +} + +TString SerializeGraph(const TExprNode& node, TExprContext& ctx, ui16 components) { + TWriter writer(ctx, components); + writer.Prepare(node); + writer.Init(); + writer.Save(node); + return writer.Out(); +} + +TExprNode::TPtr DeserializeGraph(TPositionHandle pos, TStringBuf buffer, TExprContext& ctx) { + return DeserializeGraph(ctx.GetPosition(pos), buffer, ctx); +} + +TExprNode::TPtr DeserializeGraph(TPosition pos, TStringBuf buffer, TExprContext& ctx) { + TReader reader(pos, buffer, ctx); + return reader.Load(); +} + +} // namespace NYql + diff --git a/yql/essentials/ast/serialize/yql_expr_serialize.h b/yql/essentials/ast/serialize/yql_expr_serialize.h new file mode 100644 index 00000000000..1496f609e82 --- /dev/null +++ b/yql/essentials/ast/serialize/yql_expr_serialize.h @@ -0,0 +1,19 @@ +#pragma once + +#include <yql/essentials/ast/yql_expr.h> + +namespace NYql { + +struct TSerializedExprGraphComponents { + enum : ui16 { + Graph = 0x00, + Positions = 0x01 + }; +}; + +TString SerializeGraph(const TExprNode& node, TExprContext& ctx, ui16 components = TSerializedExprGraphComponents::Graph); +TExprNode::TPtr DeserializeGraph(TPositionHandle pos, TStringBuf buffer, TExprContext& ctx); +TExprNode::TPtr DeserializeGraph(TPosition pos, TStringBuf buffer, TExprContext& ctx); + +} // namespace NYql + diff --git a/yql/essentials/ast/ut/ya.make b/yql/essentials/ast/ut/ya.make new file mode 100644 index 00000000000..4502c1c05f9 --- /dev/null +++ b/yql/essentials/ast/ut/ya.make @@ -0,0 +1,18 @@ +UNITTEST_FOR(yql/essentials/ast) + +FORK_SUBTESTS() + +SRCS( + yql_ast_ut.cpp + yql_expr_check_args_ut.cpp + yql_expr_builder_ut.cpp + yql_expr_ut.cpp + yql_type_string_ut.cpp + yql_constraint_ut.cpp +) + +PEERDIR( + library/cpp/yson/node +) + +END() diff --git a/yql/essentials/ast/ya.make b/yql/essentials/ast/ya.make new file mode 100644 index 00000000000..04b39457a0f --- /dev/null +++ b/yql/essentials/ast/ya.make @@ -0,0 +1,48 @@ +LIBRARY() + +SRCS( + yql_ast.cpp + yql_ast.h + yql_constraint.cpp + yql_constraint.h + yql_ast_annotation.cpp + yql_ast_annotation.h + yql_ast_escaping.cpp + yql_ast_escaping.h + yql_errors.cpp + yql_errors.h + yql_expr.cpp + yql_expr.h + yql_expr_builder.cpp + yql_expr_builder.h + yql_expr_types.cpp + yql_expr_types.h + yql_gc_nodes.cpp + yql_gc_nodes.h + yql_type_string.cpp + yql_type_string.h +) + +PEERDIR( + contrib/libs/openssl + library/cpp/colorizer + library/cpp/containers/sorted_vector + library/cpp/containers/stack_vector + library/cpp/deprecated/enum_codegen + library/cpp/enumbitset + library/cpp/string_utils/levenshtein_diff + library/cpp/yson + library/cpp/yson/node + yql/essentials/public/udf + yql/essentials/utils + yql/essentials/utils/fetch + yql/essentials/core/issue + yql/essentials/core/url_lister/interface + yql/essentials/parser/pg_catalog +) + +END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/ast/yql_ast.cpp b/yql/essentials/ast/yql_ast.cpp new file mode 100644 index 00000000000..4ba247130ba --- /dev/null +++ b/yql/essentials/ast/yql_ast.cpp @@ -0,0 +1,665 @@ +#include "yql_ast.h" +#include "yql_ast_escaping.h" + +#include <util/string/builder.h> +#include <util/system/compiler.h> +#include <library/cpp/containers/stack_vector/stack_vec.h> +#include <yql/essentials/utils/utf8.h> + +#include <cstdlib> + +namespace NYql { + +namespace { + + inline bool IsWhitespace(char c) { + return c == ' ' || c == '\n' || c == '\r' || c == '\t'; + } + + inline bool IsListStart(char c) { return c == '('; } + inline bool IsListEnd(char c) { return c == ')'; } + inline bool IsCommentStart(char c) { return c == '#'; } + inline bool IsQuoteStart(char c) { return c == '\''; } + inline bool IsStringStart(char c) { return c == '"'; } + inline bool IsHexStringStart(char c) { return c == 'x'; } + inline bool IsMultilineStringStart(char c) { return c == '@'; } + + inline bool NeedEscaping(const TStringBuf& str) { + for (char ch: str) { + if (IsWhitespace(ch) || IsListStart(ch) || + IsListEnd(ch) || IsCommentStart(ch) || + IsQuoteStart(ch) || IsStringStart(ch) || + !isprint(ch & 0xff)) + { + return true; + } + } + + return str.empty(); + } + + /////////////////////////////////////////////////////////////////////////// + // TAstParser + /////////////////////////////////////////////////////////////////////////// + class TAstParserContext { + public: + inline TAstParserContext(const TStringBuf& str, TMemoryPool* externalPool, const TString& file) + : Str_(str) + , Position_(1, 1, file) + , Offset_(0) + , Pool_(externalPool) + { + if (!Pool_) { + InnerPool_ = std::make_unique<TMemoryPool>(4096); + Pool_ = InnerPool_.get(); + } + } + + inline char Peek() const { + Y_ABORT_UNLESS(!AtEnd()); + return Str_[Offset_]; + } + + inline bool AtEnd() const { + return Str_.size() == Offset_; + } + + inline char Next() { + Y_ABORT_UNLESS(!AtEnd()); + char ch = Str_[Offset_]; + if (ch == '\n') { + ++Position_.Row; + Position_.Column = 1; + } else { + ++Position_.Column; + } + + ++Offset_; + return ch; + } + + // stops right afetr stopChar + inline void SeekTo(char stopChar) { + while (!AtEnd() && Next() != stopChar) { + // empty loop + } + } + + inline TStringBuf GetToken(ui32 begin, ui32 end) { + Y_ABORT_UNLESS(end >= begin); + return Str_.SubString(begin, end - begin); + } + + inline bool IsAtomEnded() { + if (AtEnd()) { + return true; + } + char c = Peek(); + return IsWhitespace(c) || IsListStart(c) || IsListEnd(c); + } + + inline const TStringBuf& Str() const { return Str_; } + inline ui32 Offset() const { return Offset_; } + inline const TPosition& Position() const { return Position_; } + inline TMemoryPool& Pool() { return *Pool_; } + inline std::unique_ptr<TMemoryPool>&& InnerPool() { return std::move(InnerPool_); } + + private: + TStringBuf Str_; + TPosition Position_; + ui32 Offset_; + TMemoryPool* Pool_; + std::unique_ptr<TMemoryPool> InnerPool_; + }; + + /////////////////////////////////////////////////////////////////////////// + // TAstParser + /////////////////////////////////////////////////////////////////////////// + class TAstParser { + public: + TAstParser(const TStringBuf& str, TMemoryPool* externalPool, const TString& file) + : Ctx_(str, externalPool, file) + { + } + + TAstParseResult Parse() { + TAstNode* root = nullptr; + if (!IsUtf8(Ctx_.Str())) { + AddError("Invalid UTF8 input"); + } else { + root = ParseList(0U); + + SkipSpace(); + if (!Ctx_.AtEnd()) { + AddError("Unexpected symbols after end of root list"); + } + } + + TAstParseResult result; + if (!Issues_.Empty()) { + result.Issues = std::move(Issues_); + } else { + result.Root = root; + result.Pool = Ctx_.InnerPool(); + } + return result; + } + + private: + inline void AddError(const TString& message) { + Issues_.AddIssue(Ctx_.Position(), message); + } + + inline void SkipComment() { + Ctx_.SeekTo('\n'); + } + + void SkipSpace() { + while (!Ctx_.AtEnd()) { + char c = Ctx_.Peek(); + if (IsWhitespace(c)) { + Ctx_.Next(); + continue; + } + + if (IsCommentStart(c)) { + SkipComment(); + continue; + } + + break; + } + } + + TAstNode* ParseList(size_t level) { + if (level >= 1000U) { + AddError("Too deep graph!"); + return nullptr; + } + + SkipSpace(); + + if (Ctx_.AtEnd()) { + AddError("Unexpected end"); + return nullptr; + } + + if (!IsListStart(Ctx_.Peek())) { + AddError("Expected ("); + return nullptr; + } + + Ctx_.Next(); + + TSmallVec<TAstNode*> children; + auto listPos = Ctx_.Position(); + while (true) { + SkipSpace(); + + if (Ctx_.AtEnd()) { + AddError("Expected )"); + return nullptr; + } + + if (IsListEnd(Ctx_.Peek())) { + Ctx_.Next(); + return TAstNode::NewList(listPos, children.data(), children.size(), Ctx_.Pool()); + } + + TAstNode* elem = ParseElement(level); + if (!elem) + return nullptr; + + children.push_back(elem); + } + } + + TAstNode* ParseElement(size_t level) { + if (Ctx_.AtEnd()) { + AddError("Expected element"); + return nullptr; + } + + char c = Ctx_.Peek(); + if (IsQuoteStart(c)) { + auto resPosition = Ctx_.Position(); + Ctx_.Next(); + + char ch = Ctx_.Peek(); + if (Ctx_.AtEnd() || IsWhitespace(ch) || IsCommentStart(ch) || + IsListEnd(ch)) + { + AddError("Expected quotation"); + return nullptr; + } + + TAstNode* content = IsListStart(ch) + ? ParseList(++level) + : ParseAtom(); + if (!content) + return nullptr; + + return TAstNode::NewList(resPosition, Ctx_.Pool(), &TAstNode::QuoteAtom, content); + } + + if (IsListStart(c)) + return ParseList(++level); + + return ParseAtom(); + } + + TAstNode* ParseAtom() { + if (Ctx_.AtEnd()) { + AddError("Expected atom"); + return nullptr; + } + + auto resPosition = Ctx_.Position(); + ui32 atomStart = Ctx_.Offset(); + + while (true) { + char c = Ctx_.Peek(); + // special symbols = 0x20, 0x0a, 0x0d, 0x09 space + // 0x22, 0x23, 0x28, 0x29 "#() + // 0x27 ' + // special symbols = 0x40, 0x78 @x + // &0x3f = 0x00,0x38 +#define MASK(x) (1ull << ui64(x)) + const ui64 mask1 = MASK(0x20) | MASK(0x0a) | MASK(0x0d) + | MASK(0x09) | MASK(0x22) | MASK(0x23) | MASK(0x28) | MASK(0x29) | MASK(0x27); + const ui64 mask2 = MASK(0x00) | MASK(0x38); +#undef MASK + if (!(c & 0x80) && ((1ull << (c & 0x3f)) & (c <= 0x3f ? mask1 : mask2))) { + if (IsWhitespace(c) || IsListStart(c) || IsListEnd(c)) + break; + + if (IsCommentStart(c)) { + AddError("Unexpected comment"); + return nullptr; + } + + if (IsQuoteStart(c)) { + AddError("Unexpected quotation"); + return nullptr; + } + + // multiline starts with '@@' + if (IsMultilineStringStart(c)) { + Ctx_.Next(); + if (Ctx_.AtEnd()) break; + + if (!IsMultilineStringStart(Ctx_.Peek())) { + continue; + } + + TString token; + if (!TryParseMultilineToken(token)) { + return nullptr; + } + + if (!Ctx_.IsAtomEnded()) { + AddError("Unexpected end of @@"); + return nullptr; + } + + return TAstNode::NewAtom(resPosition, token, Ctx_.Pool(), TNodeFlags::MultilineContent); + } + // hex string starts with 'x"' + else if (IsHexStringStart(c)) { + Ctx_.Next(); // skip 'x' + if (Ctx_.AtEnd()) break; + + if (!IsStringStart(Ctx_.Peek())) { + continue; + } + + Ctx_.Next(); // skip first '"' + + size_t readBytes = 0; + TStringStream ss; + TStringBuf atom = Ctx_.Str().SubStr(Ctx_.Offset()); + EUnescapeResult unescapeResult = UnescapeBinaryAtom( + atom, '"', &ss, &readBytes); + + // advance position + while (readBytes-- != 0) { + Ctx_.Next(); + } + + if (unescapeResult != EUnescapeResult::OK) { + AddError(TString(UnescapeResultToString(unescapeResult))); + return nullptr; + } + + Ctx_.Next(); // skip last '"' + if (!Ctx_.IsAtomEnded()) { + AddError("Unexpected end of \""); + return nullptr; + } + + return TAstNode::NewAtom(resPosition, ss.Str(), Ctx_.Pool(), TNodeFlags::BinaryContent); + } + else if (IsStringStart(c)) { + if (Ctx_.Offset() != atomStart) { + AddError("Unexpected \""); + return nullptr; + } + + Ctx_.Next(); // skip first '"' + + size_t readBytes = 0; + TStringStream ss; + TStringBuf atom = Ctx_.Str().SubStr(Ctx_.Offset()); + EUnescapeResult unescapeResult = UnescapeArbitraryAtom( + atom, '"', &ss, &readBytes); + + // advance position + while (readBytes-- != 0) { + Ctx_.Next(); + } + + if (unescapeResult != EUnescapeResult::OK) { + AddError(TString(UnescapeResultToString(unescapeResult))); + return nullptr; + } + + if (!Ctx_.IsAtomEnded()) { + AddError("Unexpected end of \""); + return nullptr; + } + + return TAstNode::NewAtom(resPosition, ss.Str(), Ctx_.Pool(), TNodeFlags::ArbitraryContent); + } + } + + Ctx_.Next(); + if (Ctx_.AtEnd()) { + break; + } + } + + return TAstNode::NewAtom(resPosition, Ctx_.GetToken(atomStart, Ctx_.Offset()), Ctx_.Pool()); + } + + bool TryParseMultilineToken(TString& token) { + Ctx_.Next(); // skip second '@' + + ui32 start = Ctx_.Offset(); + while (true) { + Ctx_.SeekTo('@'); + + if (Ctx_.AtEnd()) { + AddError("Unexpected multiline atom end"); + return false; + } + + ui32 count = 1; // we seek to first '@' + while (!Ctx_.AtEnd() && Ctx_.Peek() == '@') { + count++; + Ctx_.Next(); + if (count == 4) { + // Reduce each four '@' to two + token.append(Ctx_.GetToken(start, Ctx_.Offset() - 2)); + start = Ctx_.Offset(); + count = 0; + } + } + if (count >= 2) { + break; + } + } + + // two '@@' at the end + token.append(Ctx_.GetToken(start, Ctx_.Offset() - 2)); + return true; + } + + private: + TAstParserContext Ctx_; + TIssues Issues_; + }; + + /////////////////////////////////////////////////////////////////////////// + // ast node printing functions + /////////////////////////////////////////////////////////////////////////// + + inline bool IsQuoteNode(const TAstNode& node) { + return node.GetChildrenCount() == 2 + && node.GetChild(0)->GetType() == TAstNode::Atom + && node.GetChild(0)->GetContent() == TStringBuf("quote"); + } + + inline bool IsBlockNode(const TAstNode& node) { + return node.GetChildrenCount() == 2 + && node.GetChild(0)->GetType() == TAstNode::Atom + && node.GetChild(0)->GetContent() == TStringBuf("block"); + } + + Y_NO_INLINE void Indent(IOutputStream& out, ui32 indentation) { + char* whitespaces = reinterpret_cast<char*>(alloca(indentation)); + memset(whitespaces, ' ', indentation); + out.Write(whitespaces, indentation); + } + + void MultilineAtomPrint(IOutputStream& out, const TStringBuf& str) { + out << TStringBuf("@@"); + size_t idx = str.find('@'); + if (idx == TString::npos) { + out << str; + } else { + const char* begin = str.data(); + do { + ui32 count = 0; + for (; idx < str.length() && str[idx] == '@'; ++idx) { + ++count; + } + + if (count % 2 == 0) { + out.Write(begin, idx - (begin - str.data()) - count); + begin = str.data() + idx; + + while (count--) { + out.Write(TStringBuf("@@")); + } + } + idx = str.find('@', idx); + } while (idx != TString::npos); + out.Write(begin, str.end() - begin); + } + out << TStringBuf("@@"); + } + + void PrintNode(IOutputStream& out, const TAstNode& node) { + if (node.GetType() == TAstNode::Atom) { + if (node.GetFlags() & TNodeFlags::ArbitraryContent) { + EscapeArbitraryAtom(node.GetContent(), '"', &out); + } else if (node.GetFlags() & TNodeFlags::BinaryContent) { + EscapeBinaryAtom(node.GetContent(), '"', &out); + } else if (node.GetFlags() & TNodeFlags::MultilineContent) { + MultilineAtomPrint(out, node.GetContent()); + } else if (node.GetContent().empty()) { + EscapeArbitraryAtom(node.GetContent(), '"', &out); + } else { + out << node.GetContent(); + } + } else if (node.GetType() == TAstNode::List) { + if (!node.GetChildrenCount()) { + out << TStringBuf("()"); + } else if (IsQuoteNode(node)) { + out << '\''; + PrintNode(out, *node.GetChild(1)); + } else { + out << '('; + ui32 index = 0; + while (true) { + PrintNode(out, *node.GetChild(index)); + ++index; + if (index == node.GetChildrenCount()) break; + out << ' '; + } + out << ')'; + } + } + } + + void PrettyPrintNode( + IOutputStream& out, const TAstNode& node, + i32 indent, i32 blockLevel, i32 localIndent, ui32 flags) + { + if (!(flags & TAstPrintFlags::PerLine)) { + Indent(out, indent * 2); + } else if (localIndent == 1) { + Indent(out, blockLevel * 2); + } + + bool isQuote = false; + if (node.GetType() == TAstNode::Atom) { + if (node.GetFlags() & TNodeFlags::ArbitraryContent) { + if ((flags & TAstPrintFlags::AdaptArbitraryContent) && + !NeedEscaping(node.GetContent())) + { + out << node.GetContent(); + } else { + EscapeArbitraryAtom(node.GetContent(), '"', &out); + } + } else if (node.GetFlags() & TNodeFlags::BinaryContent) { + EscapeBinaryAtom(node.GetContent(), '"', &out); + } else if (node.GetFlags() & TNodeFlags::MultilineContent) { + MultilineAtomPrint(out, node.GetContent()); + } else { + if (((flags & TAstPrintFlags::AdaptArbitraryContent) && NeedEscaping(node.GetContent())) || + node.GetContent().empty()) + { + EscapeArbitraryAtom(node.GetContent(), '"', &out); + } else { + out << node.GetContent(); + } + } + } else if (node.GetType() == TAstNode::List) { + isQuote = IsQuoteNode(node); + if (isQuote && (flags & TAstPrintFlags::ShortQuote)) { + out << '\''; + if (localIndent == 0 || !(flags & TAstPrintFlags::PerLine)) { + out << Endl; + } + + PrettyPrintNode(out, *node.GetChild(1), indent + 1, blockLevel, localIndent + 1, flags); + } else { + out << '('; + if (localIndent == 0 || !(flags & TAstPrintFlags::PerLine)) { + out << Endl; + } + + bool isBlock = IsBlockNode(node); + for (ui32 index = 0; index < node.GetChildrenCount(); ++index) { + if (localIndent > 0 && (index > 0) && (flags & TAstPrintFlags::PerLine)) { + out << ' '; + } + + if (isBlock && (index > 0)) { + PrettyPrintNode(out, *node.GetChild(index), indent + 1, blockLevel + 1, -1, flags); + } else { + PrettyPrintNode(out, *node.GetChild(index), indent + 1, blockLevel, localIndent + 1, flags); + } + } + } + + if (!isQuote || !(flags & TAstPrintFlags::ShortQuote)) { + if (!(flags & TAstPrintFlags::PerLine)) { + Indent(out, indent * 2); + } + + if (localIndent == 0 && blockLevel > 0) { + Indent(out, (blockLevel - 1) * 2); + } + + out << ')'; + } + } + + if (!isQuote || !(flags & TAstPrintFlags::ShortQuote)) { + if (localIndent > 0 || blockLevel == 0) { + if (localIndent <= 1 || !(flags & TAstPrintFlags::PerLine)) { + out << Endl; + } + } + } + } + + void DestroyNode(TAstNode* node) { + if (node->IsList()) { + for (ui32 i = 0; i < node->GetChildrenCount(); ++i) { + DestroyNode(node->GetChild(i)); + } + } + + if (node != &TAstNode::QuoteAtom) { + node->Destroy(); + } + } +} // namespace + +TAstParseResult::~TAstParseResult() { + Destroy(); +} + +TAstParseResult::TAstParseResult(TAstParseResult&& other) + : Pool(std::move(other.Pool)) + , Root(other.Root) + , Issues(std::move(other.Issues)) + , PgAutoParamValues(std::move(other.PgAutoParamValues)) + , ActualSyntaxType(other.ActualSyntaxType) +{ + other.Root = nullptr; +} + +TAstParseResult& TAstParseResult::operator=(TAstParseResult&& other) { + Destroy(); + Pool = std::move(other.Pool); + Root = other.Root; + other.Root = nullptr; + Issues = std::move(other.Issues); + PgAutoParamValues = std::move(other.PgAutoParamValues); + ActualSyntaxType = other.ActualSyntaxType; + return *this; +} + +void TAstParseResult::Destroy() { + if (Root) { + DestroyNode(Root); + Root = nullptr; + } +} + +TAstParseResult ParseAst(const TStringBuf& str, TMemoryPool* externalPool, const TString& file) +{ + TAstParser parser(str, externalPool, file); + return parser.Parse(); +} + +void TAstNode::PrintTo(IOutputStream& out) const { + PrintNode(out, *this); +} + +void TAstNode::PrettyPrintTo(IOutputStream& out, ui32 flags) const { + PrettyPrintNode(out, *this, 0, 0, 0, flags); +} + +TAstNode TAstNode::QuoteAtom(TPosition(0, 0), TStringBuf("quote"), TNodeFlags::Default); + +} // namespace NYql + +template<> +void Out<NYql::TAstNode::EType>(class IOutputStream &o, NYql::TAstNode::EType x) { +#define YQL_AST_NODE_TYPE_MAP_TO_STRING_IMPL(name, ...) \ + case ::NYql::TAstNode::name: \ + o << #name; \ + return; + + switch (x) { + YQL_AST_NODE_TYPE_MAP(YQL_AST_NODE_TYPE_MAP_TO_STRING_IMPL) + default: + o << static_cast<int>(x); + return; + } +} diff --git a/yql/essentials/ast/yql_ast.h b/yql/essentials/ast/yql_ast.h new file mode 100644 index 00000000000..aaae6f53a1e --- /dev/null +++ b/yql/essentials/ast/yql_ast.h @@ -0,0 +1,355 @@ +#pragma once + +#include "yql_errors.h" + +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> +#include <util/generic/ptr.h> +#include <util/generic/strbuf.h> +#include <util/generic/string.h> +#include <util/generic/vector.h> +#include <util/stream/output.h> +#include <util/stream/str.h> +#include <util/memory/pool.h> +#include <util/generic/array_ref.h> + +namespace NYql { + +struct TNodeFlags { + enum : ui16 { + Default = 0, + ArbitraryContent = 0x01, + BinaryContent = 0x02, + MultilineContent = 0x04, + }; + + static constexpr ui32 FlagsMask = 0x07; // all flags should fit here +}; + +struct TAstNode { +#define YQL_AST_NODE_TYPE_MAP(xx) \ + xx(List, 0) \ + xx(Atom, 1) \ + + enum EType : ui32 { + YQL_AST_NODE_TYPE_MAP(ENUM_VALUE_GEN) + }; + + + static const ui32 SmallListCount = 2; + + void PrintTo(IOutputStream& out) const; + void PrettyPrintTo(IOutputStream& out, ui32 prettyFlags) const; + + inline TString ToString() const { + TStringStream str; + PrintTo(str); + return str.Str(); + } + + inline TString ToString(ui32 prettyFlags) const { + TStringStream str; + PrettyPrintTo(str, prettyFlags); + return str.Str(); + } + + inline EType GetType() const { + return Type; + } + + inline bool IsAtom() const { + return Type == Atom; + } + + inline bool IsList() const { + return Type == List; + } + + inline bool IsListOfSize(ui32 len) const { + return Type == List && ListCount == len; + } + + inline TPosition GetPosition() const { + return Position; + } + + inline void SetPosition(TPosition position) { + Position = position; + } + + inline TStringBuf GetContent() const { + Y_ABORT_UNLESS(IsAtom()); + return TStringBuf(Data.A.Content, Data.A.Size); + } + + inline void SetContent(TStringBuf newContent, TMemoryPool& pool) { + Y_ABORT_UNLESS(IsAtom()); + auto poolContent = pool.AppendString(newContent); + Data.A.Content = poolContent.data(); + Data.A.Size = poolContent.size(); + } + + inline void SetLiteralContent(TStringBuf newContent) { + Y_ABORT_UNLESS(IsAtom()); + Data.A.Content = newContent.data(); + Data.A.Size = newContent.size(); + } + + inline ui32 GetFlags() const { + Y_ABORT_UNLESS(IsAtom()); + return Data.A.Flags; + } + + inline void SetFlags(ui32 flags) { + Y_ABORT_UNLESS(IsAtom()); + Data.A.Flags = flags; + } + + inline ui32 GetChildrenCount() const { + Y_ABORT_UNLESS(IsList()); + return ListCount; + } + + inline const TAstNode* GetChild(ui32 index) const { + Y_ABORT_UNLESS(IsList()); + Y_ABORT_UNLESS(index < ListCount); + if (ListCount <= SmallListCount) { + return Data.S.Children[index]; + } else { + return Data.L.Children[index]; + } + } + + inline TAstNode* GetChild(ui32 index) { + Y_ABORT_UNLESS(IsList()); + Y_ABORT_UNLESS(index < ListCount); + if (ListCount <= SmallListCount) { + return Data.S.Children[index]; + } else { + return Data.L.Children[index]; + } + } + + inline TArrayRef<TAstNode* const> GetChildren() const { + Y_ABORT_UNLESS(IsList()); + return {ListCount <= SmallListCount ? Data.S.Children : Data.L.Children, ListCount}; + } + + static inline TAstNode* NewAtom(TPosition position, TStringBuf content, TMemoryPool& pool, ui32 flags = TNodeFlags::Default) { + auto poolContent = pool.AppendString(content); + auto ret = pool.Allocate<TAstNode>(); + ::new(ret) TAstNode(position, poolContent, flags); + return ret; + } + + // atom with non-owning content, useful for literal strings + static inline TAstNode* NewLiteralAtom(TPosition position, TStringBuf content, TMemoryPool& pool, ui32 flags = TNodeFlags::Default) { + auto ret = pool.Allocate<TAstNode>(); + ::new(ret) TAstNode(position, content, flags); + return ret; + } + + static inline TAstNode* NewList(TPosition position, TAstNode** children, ui32 childrenCount, TMemoryPool& pool) { + TAstNode** poolChildren = nullptr; + if (childrenCount) { + if (childrenCount > SmallListCount) { + poolChildren = pool.AllocateArray<TAstNode*>(childrenCount); + memcpy(poolChildren, children, sizeof(void*) * childrenCount); + } else { + poolChildren = children; + } + + for (ui32 index = 0; index < childrenCount; ++index) { + Y_ABORT_UNLESS(poolChildren[index]); + } + } + + auto ret = pool.Allocate<TAstNode>(); + ::new(ret) TAstNode(position, poolChildren, childrenCount); + return ret; + } + + template <typename... TNodes> + static inline TAstNode* NewList(TPosition position, TMemoryPool& pool, TNodes... nodes) { + TAstNode* children[] = { nodes... }; + return NewList(position, children, sizeof...(nodes), pool); + } + + static inline TAstNode* NewList(TPosition position, TMemoryPool& pool) { + return NewList(position, nullptr, 0, pool); + } + + static TAstNode QuoteAtom; + + static inline TAstNode* Quote(TPosition position, TMemoryPool& pool, TAstNode* node) { + return NewList(position, pool, &QuoteAtom, node); + } + + inline ~TAstNode() {} + + void Destroy() { + TString().swap(Position.File); + } + +private: + inline TAstNode(TPosition position, TStringBuf content, ui32 flags) + : Position(position) + , Type(Atom) + , ListCount(0) + { + Data.A.Content = content.data(); + Data.A.Size = content.size(); + Data.A.Flags = flags; + } + + inline TAstNode(TPosition position, TAstNode** children, ui32 childrenCount) + : Position(position) + , Type(List) + , ListCount(childrenCount) + { + if (childrenCount <= SmallListCount) { + for (ui32 index = 0; index < childrenCount; ++index) { + Data.S.Children[index] = children[index]; + } + } else { + Data.L.Children = children; + } + } + + TPosition Position; + const EType Type; + const ui32 ListCount; + + struct TAtom { + const char* Content; + ui32 Size; + ui32 Flags; + }; + + struct TListType { + TAstNode** Children; + }; + + struct TSmallList { + TAstNode* Children[SmallListCount]; + }; + + union { + TAtom A; + TListType L; + TSmallList S; + } Data; +}; + +enum class ESyntaxType { + YQLv0, + YQLv1, + Pg, +}; + +class IAutoParamBuilder; +class IAutoParamDataBuilder; + +class IAutoParamTypeBuilder { +public: + virtual ~IAutoParamTypeBuilder() = default; + + virtual void Pg(const TString& name) = 0; + + virtual void BeginList() = 0; + + virtual void EndList() = 0; + + virtual void BeginTuple() = 0; + + virtual void EndTuple() = 0; + + virtual void BeforeItem() = 0; + + virtual void AfterItem() = 0; + + virtual IAutoParamDataBuilder& FinishType() = 0; +}; + +class IAutoParamDataBuilder { +public: + virtual ~IAutoParamDataBuilder() = default; + + virtual void Pg(const TMaybe<TString>& value) = 0; + + virtual void BeginList() = 0; + + virtual void EndList() = 0; + + virtual void BeginTuple() = 0; + + virtual void EndTuple() = 0; + + virtual void BeforeItem() = 0; + + virtual void AfterItem() = 0; + + virtual IAutoParamBuilder& FinishData() = 0; +}; + +class IAutoParamBuilder : public TThrRefBase { +public: + virtual ~IAutoParamBuilder() = default; + + virtual ui32 Size() const = 0; + + virtual bool Contains(const TString& name) const = 0; + + virtual IAutoParamTypeBuilder& Add(const TString& name) = 0; +}; + +using IAutoParamBuilderPtr = TIntrusivePtr<IAutoParamBuilder>; + +class IAutoParamBuilderFactory { +public: + virtual ~IAutoParamBuilderFactory() = default; + + virtual IAutoParamBuilderPtr MakeBuilder() = 0; +}; + +struct TAstParseResult { + std::unique_ptr<TMemoryPool> Pool; + TAstNode* Root = nullptr; + TIssues Issues; + IAutoParamBuilderPtr PgAutoParamValues; + ESyntaxType ActualSyntaxType = ESyntaxType::YQLv1; + + inline bool IsOk() const { + return !!Root; + } + + TAstParseResult() = default; + ~TAstParseResult(); + TAstParseResult(const TAstParseResult&) = delete; + TAstParseResult& operator=(const TAstParseResult&) = delete; + + TAstParseResult(TAstParseResult&&); + TAstParseResult& operator=(TAstParseResult&&); + + void Destroy(); +}; + +struct TStmtParseInfo { + bool KeepInCache = true; + TMaybe<TString> CommandTagName = {}; +}; + +struct TAstPrintFlags { + enum { + Default = 0, + PerLine = 0x01, + ShortQuote = 0x02, + AdaptArbitraryContent = 0x04, + }; +}; + +TAstParseResult ParseAst(const TStringBuf& str, TMemoryPool* externalPool = nullptr, const TString& file = {}); + +} // namespace NYql + +template<> +void Out<NYql::TAstNode::EType>(class IOutputStream &o, NYql::TAstNode::EType x); diff --git a/yql/essentials/ast/yql_ast_annotation.cpp b/yql/essentials/ast/yql_ast_annotation.cpp new file mode 100644 index 00000000000..fc97b879c89 --- /dev/null +++ b/yql/essentials/ast/yql_ast_annotation.cpp @@ -0,0 +1,189 @@ +#include "yql_ast_annotation.h" +#include <util/string/printf.h> +#include <util/string/split.h> +#include <util/string/cast.h> +#include <util/string/builder.h> +#include <library/cpp/containers/stack_vector/stack_vec.h> + +namespace NYql { + +namespace { + +TAstNode* AnnotateNodePosition(TAstNode& node, TMemoryPool& pool) { + auto newPosition = node.GetPosition(); + TAstNode* pos = PositionAsNode(node.GetPosition(), pool); + TAstNode* shallowClone = &node; + if (node.IsList()) { + TSmallVec<TAstNode*> listChildren(node.GetChildrenCount()); + for (ui32 index = 0; index < node.GetChildrenCount(); ++index) { + listChildren[index] = AnnotateNodePosition(*node.GetChild(index), pool); + } + + shallowClone = TAstNode::NewList(node.GetPosition(), listChildren.data(), listChildren.size(), pool); + } + + return TAstNode::NewList(newPosition, pool, pos, shallowClone); +} + +TAstNode* RemoveNodeAnnotations(TAstNode& node, TMemoryPool& pool) { + if (!node.IsList()) + return nullptr; + + if (node.GetChildrenCount() == 0) + return nullptr; + + auto lastNode = node.GetChild(node.GetChildrenCount() - 1); + auto res = lastNode; + if (lastNode->IsList()) { + TSmallVec<TAstNode*> listChildren(lastNode->GetChildrenCount()); + for (ui32 index = 0; index < lastNode->GetChildrenCount(); ++index) { + auto item = RemoveNodeAnnotations(*lastNode->GetChild(index), pool); + if (!item) + return nullptr; + + listChildren[index] = item; + } + + res = TAstNode::NewList(lastNode->GetPosition(), listChildren.data(), listChildren.size(), pool); + } + + return res; +} + +TAstNode* ExtractNodeAnnotations(TAstNode& node, TAnnotationNodeMap& annotations, TMemoryPool& pool) { + if (!node.IsList()) + return nullptr; + + if (node.GetChildrenCount() == 0) + return nullptr; + + auto lastNode = node.GetChild(node.GetChildrenCount() - 1); + auto res = lastNode; + if (lastNode->IsList()) { + TSmallVec<TAstNode*> listChildren(lastNode->GetChildrenCount()); + for (ui32 index = 0; index < lastNode->GetChildrenCount(); ++index) { + auto item = ExtractNodeAnnotations(*lastNode->GetChild(index), annotations, pool); + if (!item) + return nullptr; + + listChildren[index] = item; + } + + res = TAstNode::NewList(lastNode->GetPosition(), listChildren.data(), listChildren.size(), pool); + } + + auto& v = annotations[res]; + v.resize(node.GetChildrenCount() - 1); + for (ui32 index = 0; index + 1 < node.GetChildrenCount(); ++index) { + v[index] = node.GetChild(index); + } + + return res; +} + +TAstNode* ApplyNodePositionAnnotations(TAstNode& node, ui32 annotationIndex, TMemoryPool& pool) { + if (!node.IsList()) + return nullptr; + + if (node.GetChildrenCount() < annotationIndex + 2) + return nullptr; + + auto annotation = node.GetChild(annotationIndex); + auto str = annotation->GetContent(); + TStringBuf rowPart; + TStringBuf colPart; + TString filePart; + GetNext(str, ':', rowPart); + GetNext(str, ':', colPart); + filePart = str; + + ui32 row = 0, col = 0; + if (!TryFromString(rowPart, row) || !TryFromString(colPart, col)) + return nullptr; + + TSmallVec<TAstNode*> listChildren(node.GetChildrenCount()); + for (ui32 index = 0; index < node.GetChildrenCount() - 1; ++index) { + listChildren[index] = node.GetChild(index); + } + + auto lastNode = node.GetChild(node.GetChildrenCount() - 1); + TAstNode* lastResNode; + if (lastNode->IsAtom()) { + lastResNode = TAstNode::NewAtom(TPosition(col, row, filePart), lastNode->GetContent(), pool, lastNode->GetFlags()); + } else { + TSmallVec<TAstNode*> lastNodeChildren(lastNode->GetChildrenCount()); + for (ui32 index = 0; index < lastNode->GetChildrenCount(); ++index) { + lastNodeChildren[index] = ApplyNodePositionAnnotations(*lastNode->GetChild(index), annotationIndex, pool); + } + + lastResNode = TAstNode::NewList(TPosition(col, row, filePart), lastNodeChildren.data(), lastNodeChildren.size(), pool); + } + + listChildren[node.GetChildrenCount() - 1] = lastResNode; + return TAstNode::NewList(node.GetPosition(), listChildren.data(), listChildren.size(), pool); +} + +bool ApplyNodePositionAnnotationsInplace(TAstNode& node, ui32 annotationIndex) { + if (!node.IsList()) + return false; + + if (node.GetChildrenCount() < annotationIndex + 2) + return false; + + auto annotation = node.GetChild(annotationIndex); + TStringBuf str = annotation->GetContent(); + TStringBuf rowPart; + TStringBuf colPart; + TString filePart; + GetNext(str, ':', rowPart); + GetNext(str, ':', colPart); + filePart = str; + ui32 row = 0, col = 0; + if (!TryFromString(rowPart, row) || !TryFromString(colPart, col)) + return false; + + auto lastNode = node.GetChild(node.GetChildrenCount() - 1); + lastNode->SetPosition(TPosition(col, row, filePart)); + if (lastNode->IsList()) { + for (ui32 index = 0; index < lastNode->GetChildrenCount(); ++index) { + if (!ApplyNodePositionAnnotationsInplace(*lastNode->GetChild(index), annotationIndex)) + return false; + } + } + + return true; +} + +} + +TAstNode* AnnotatePositions(TAstNode& root, TMemoryPool& pool) { + return AnnotateNodePosition(root, pool); +} + +TAstNode* RemoveAnnotations(TAstNode& root, TMemoryPool& pool) { + return RemoveNodeAnnotations(root, pool); +} + +TAstNode* ApplyPositionAnnotations(TAstNode& root, ui32 annotationIndex, TMemoryPool& pool) { + return ApplyNodePositionAnnotations(root, annotationIndex, pool); +} + +bool ApplyPositionAnnotationsInplace(TAstNode& root, ui32 annotationIndex) { + return ApplyNodePositionAnnotationsInplace(root, annotationIndex); +} + +TAstNode* PositionAsNode(TPosition position, TMemoryPool& pool) { + TStringBuilder str; + str << position.Row << ':' << position.Column; + if (!position.File.empty()) { + str << ':' << position.File; + } + + return TAstNode::NewAtom(position, str, pool); +} + +TAstNode* ExtractAnnotations(TAstNode& root, TAnnotationNodeMap& annotations, TMemoryPool& pool) { + return ExtractNodeAnnotations(root, annotations, pool); +} + +} diff --git a/yql/essentials/ast/yql_ast_annotation.h b/yql/essentials/ast/yql_ast_annotation.h new file mode 100644 index 00000000000..c909ce9df4d --- /dev/null +++ b/yql/essentials/ast/yql_ast_annotation.h @@ -0,0 +1,22 @@ +#pragma once +#include "yql_ast.h" +#include <util/generic/hash.h> + +namespace NYql { + +TAstNode* PositionAsNode(TPosition position, TMemoryPool& pool); + +TAstNode* AnnotatePositions(TAstNode& root, TMemoryPool& pool); +// returns nullptr in case of error +TAstNode* RemoveAnnotations(TAstNode& root, TMemoryPool& pool); +// returns nullptr in case of error +TAstNode* ApplyPositionAnnotations(TAstNode& root, ui32 annotationIndex, TMemoryPool& pool); +// returns false in case of error +bool ApplyPositionAnnotationsInplace(TAstNode& root, ui32 annotationIndex); + +typedef THashMap<const TAstNode*, TVector<const TAstNode*>> TAnnotationNodeMap; + +// returns nullptr in case of error +TAstNode* ExtractAnnotations(TAstNode& root, TAnnotationNodeMap& annotations, TMemoryPool& pool); + +} diff --git a/yql/essentials/ast/yql_ast_escaping.cpp b/yql/essentials/ast/yql_ast_escaping.cpp new file mode 100644 index 00000000000..56aee8f5964 --- /dev/null +++ b/yql/essentials/ast/yql_ast_escaping.cpp @@ -0,0 +1,275 @@ +#include "yql_ast_escaping.h" + +#include <util/charset/wide.h> +#include <util/stream/output.h> +#include <util/string/hex.h> + + +namespace NYql { + +static char HexDigit(char c) +{ + return (c < 10 ? '0' + c : 'A' + (c - 10)); +} + +static void EscapedPrintChar(ui8 c, IOutputStream* out) +{ + switch (c) { + case '\\': out->Write("\\\\", 2); break; + case '"' : out->Write("\\\"", 2); break; + case '\t': out->Write("\\t", 2); break; + case '\n': out->Write("\\n", 2); break; + case '\r': out->Write("\\r", 2); break; + case '\b': out->Write("\\b", 2); break; + case '\f': out->Write("\\f", 2); break; + case '\a': out->Write("\\a", 2); break; + case '\v': out->Write("\\v", 2); break; + default: { + if (isprint(c)) out->Write(static_cast<char>(c)); + else { + char buf[4] = { "\\x" }; + buf[2] = HexDigit((c & 0xf0) >> 4); + buf[3] = HexDigit((c & 0x0f)); + out->Write(buf, 4); + } + } + } +} + +static void EscapedPrintUnicode(wchar32 rune, IOutputStream* out) +{ + static const int MAX_ESCAPE_LEN = 10; + + if (rune < 0x80) { + EscapedPrintChar(static_cast<ui8>(rune & 0xff), out); + } else { + int i = 0; + char buf[MAX_ESCAPE_LEN]; + + if (rune < 0x10000) { + buf[i++] = '\\'; + buf[i++] = 'u'; + } else { + buf[i++] = '\\'; + buf[i++] = 'U'; + buf[i++] = HexDigit((rune & 0xf0000000) >> 28); + buf[i++] = HexDigit((rune & 0x0f000000) >> 24); + buf[i++] = HexDigit((rune & 0x00f00000) >> 20); + buf[i++] = HexDigit((rune & 0x000f0000) >> 16); + } + + buf[i++] = HexDigit((rune & 0xf000) >> 12); + buf[i++] = HexDigit((rune & 0x0f00) >> 8); + buf[i++] = HexDigit((rune & 0x00f0) >> 4); + buf[i++] = HexDigit((rune & 0x000f)); + + out->Write(buf, i); + } +} + +static bool TryParseOctal(const char*& p, const char* e, int maxlen, wchar32* value) +{ + while (maxlen-- && p != e) { + if (*value > 255) return false; + + char ch = *p++; + if (ch >= '0' && ch <= '7') { + *value = *value * 8 + (ch - '0'); + continue; + } + + break; + } + + return (maxlen == -1); +} + +static bool TryParseHex(const char*& p, const char* e, int maxlen, wchar32* value) +{ + while (maxlen-- > 0 && p != e) { + char ch = *p++; + if (ch >= '0' && ch <= '9') { + *value = *value * 16 + (ch - '0'); + continue; + } + + // to lower case + ch |= 0x20; + + if (ch >= 'a' && ch <= 'f') { + *value = *value * 16 + (ch - 'a') + 10; + continue; + } + + break; + } + + return (maxlen == -1); +} + +static bool IsValidUtf8Rune(wchar32 value) { + return value <= 0x10ffff && (value < 0xd800 || value > 0xdfff); +} + +TStringBuf UnescapeResultToString(EUnescapeResult result) +{ + switch (result) { + case EUnescapeResult::OK: + return "OK"; + case EUnescapeResult::INVALID_ESCAPE_SEQUENCE: + return "Expected escape sequence"; + case EUnescapeResult::INVALID_BINARY: + return "Invalid binary value"; + case EUnescapeResult::INVALID_OCTAL: + return "Invalid octal value"; + case EUnescapeResult::INVALID_HEXADECIMAL: + return "Invalid hexadecimal value"; + case EUnescapeResult::INVALID_UNICODE: + return "Invalid unicode value"; + case EUnescapeResult::INVALID_END: + return "Unexpected end of atom"; + } + return "Unknown unescape error"; +} + +void EscapeArbitraryAtom(TStringBuf atom, char quoteChar, IOutputStream* out) +{ + out->Write(quoteChar); + const ui8 *p = reinterpret_cast<const ui8*>(atom.begin()), + *e = reinterpret_cast<const ui8*>(atom.end()); + while (p != e) { + wchar32 rune = 0; + size_t rune_len = 0; + + if (SafeReadUTF8Char(rune, rune_len, p, e) == RECODE_RESULT::RECODE_OK && IsValidUtf8Rune(rune)) { + EscapedPrintUnicode(rune, out); + p += rune_len; + } else { + EscapedPrintChar(*p++, out); + } + } + out->Write(quoteChar); +} + +EUnescapeResult UnescapeArbitraryAtom( + TStringBuf atom, char endChar, IOutputStream* out, size_t* readBytes) +{ + const char *p = atom.begin(), + *e = atom.end(); + + while (p != e) { + char current = *p++; + + // C-style escape sequences + if (current == '\\') { + if (p == e) { + *readBytes = p - atom.begin(); + return EUnescapeResult::INVALID_ESCAPE_SEQUENCE; + } + + char next = *p++; + switch (next) { + case 't': current = '\t'; break; + case 'n': current = '\n'; break; + case 'r': current = '\r'; break; + case 'b': current = '\b'; break; + case 'f': current = '\f'; break; + case 'a': current = '\a'; break; + case 'v': current = '\v'; break; + case '0': case '1': case '2': case '3': { + wchar32 value = (next - '0'); + if (!TryParseOctal(p, e, 2, &value)) { + *readBytes = p - atom.begin(); + return EUnescapeResult::INVALID_OCTAL; + } + current = value & 0xff; + break; + } + case 'x': { + wchar32 value = 0; + if (!TryParseHex(p, e, 2, &value)) { + *readBytes = p - atom.begin(); + return EUnescapeResult::INVALID_HEXADECIMAL; + } + current = value & 0xff; + break; + } + case 'u': + case 'U': { + wchar32 value = 0; + int len = (next == 'u' ? 4 : 8); + if (!TryParseHex(p, e, len, &value) || !IsValidUtf8Rune(value)) { + *readBytes = p - atom.begin(); + return EUnescapeResult::INVALID_UNICODE; + } + size_t written = 0; + char buf[4]; + WideToUTF8(&value, 1, buf, written); + out->Write(buf, written); + continue; + } + default: { + current = next; + } + } + } else if (endChar == '`') { + if (current == '`') { + if (p == e) { + *readBytes = p - atom.begin(); + return EUnescapeResult::OK; + } else { + if (*p != '`') { + *readBytes = p - atom.begin(); + return EUnescapeResult::INVALID_ESCAPE_SEQUENCE; + } else { + p++; + } + } + } + } else if (current == endChar) { + *readBytes = p - atom.begin(); + return EUnescapeResult::OK; + } + + out->Write(current); + } + + *readBytes = p - atom.begin(); + return EUnescapeResult::INVALID_END; +} + +void EscapeBinaryAtom(TStringBuf atom, char quoteChar, IOutputStream* out) +{ + char prefix[] = { 'x', quoteChar }; + out->Write(prefix, 2); + out->Write(HexEncode(atom.data(), atom.size())); + out->Write(quoteChar); +} + +EUnescapeResult UnescapeBinaryAtom( + TStringBuf atom, char endChar, IOutputStream* out, size_t* readBytes) +{ + const char *p = atom.begin(), + *e = atom.end(); + + while (p != e) { + char current = *p; + if (current == endChar) { + *readBytes = p - atom.begin(); + return EUnescapeResult::OK; + } + + wchar32 byte = 0; + if (!TryParseHex(p, e, 2, &byte)) { + *readBytes = p - atom.begin(); + return EUnescapeResult::INVALID_BINARY; + } + + out->Write(byte & 0xff); + } + + *readBytes = p - atom.begin(); + return EUnescapeResult::INVALID_END; +} + +} // namspace NYql diff --git a/yql/essentials/ast/yql_ast_escaping.h b/yql/essentials/ast/yql_ast_escaping.h new file mode 100644 index 00000000000..744ab985161 --- /dev/null +++ b/yql/essentials/ast/yql_ast_escaping.h @@ -0,0 +1,35 @@ +#pragma once + +#include <util/generic/fwd.h> +#include <util/system/types.h> +#include <util/generic/strbuf.h> + + +class IOutputStream; + +namespace NYql { + +enum class EUnescapeResult +{ + OK, + INVALID_ESCAPE_SEQUENCE, + INVALID_BINARY, + INVALID_OCTAL, + INVALID_HEXADECIMAL, + INVALID_UNICODE, + INVALID_END, +}; + +TStringBuf UnescapeResultToString(EUnescapeResult result); + +void EscapeArbitraryAtom(TStringBuf atom, char quoteChar, IOutputStream* out); + +EUnescapeResult UnescapeArbitraryAtom( + TStringBuf atom, char endChar, IOutputStream* out, size_t* readBytes); + +void EscapeBinaryAtom(TStringBuf atom, char quoteChar, IOutputStream* out); + +EUnescapeResult UnescapeBinaryAtom( + TStringBuf atom, char endChar, IOutputStream* out, size_t* readBytes); + +} // namspace NYql diff --git a/yql/essentials/ast/yql_ast_ut.cpp b/yql/essentials/ast/yql_ast_ut.cpp new file mode 100644 index 00000000000..6f228966c2a --- /dev/null +++ b/yql/essentials/ast/yql_ast_ut.cpp @@ -0,0 +1,403 @@ +#include "yql_ast.h" +#include "yql_ast_annotation.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/string/util.h> +#include <util/system/sanitizers.h> + +namespace NYql { + +Y_UNIT_TEST_SUITE(TParseYqlAst) { + constexpr TStringBuf TEST_PROGRAM = + "(\n" + "#comment\n" + "(let mr_source (DataSource 'yamr 'cedar))\n" + "(let x (Read! world mr_source (Key '('table (KeyString 'Input))) '('key 'value) '()))\n" + "(let world (Left! x))\n" + "(let table1 (Right! x))\n" + "(let tresh (Int32 '100))\n" + "(let table1low (Filter table1 (lambda '(item) (< (member item 'key) tresh))))\n" + "(let mr_sink (DataSink 'yamr (quote cedar)))\n" + "(let world (Write! world mr_sink (Key '('table (KeyString 'Output))) table1low '('('mode 'append))))\n" + "(let world (Commit! world mr_sink))\n" + "(return world)\n" + ")"; + + Y_UNIT_TEST(ParseAstTest) { + TAstParseResult res = ParseAst(TEST_PROGRAM); + UNIT_ASSERT(res.IsOk()); + UNIT_ASSERT(res.Root->IsList()); + UNIT_ASSERT(res.Issues.Empty()); + } + + Y_UNIT_TEST(ParseAstTestPerf) { +#ifdef WITH_VALGRIND + const ui32 n = 1000; +#else + const ui32 n = NSan::PlainOrUnderSanitizer(100000, 1000); +#endif + auto t1 = TInstant::Now(); + for (ui32 i = 0; i < n; ++i) { + TAstParseResult res = ParseAst(TEST_PROGRAM); + UNIT_ASSERT(res.IsOk()); + UNIT_ASSERT(res.Root->IsList()); + UNIT_ASSERT(res.Issues.Empty()); + } + auto t2 = TInstant::Now(); + Cout << t2 - t1 << Endl; + } + + Y_UNIT_TEST(PrintAstTest) { + TAstParseResult ast = ParseAst(TEST_PROGRAM); + UNIT_ASSERT(ast.IsOk()); + + TString printedProgram = ast.Root->ToString(); + UNIT_ASSERT(printedProgram.find('\n') == TString::npos); + + TAstParseResult parsedAst = ParseAst(printedProgram); + UNIT_ASSERT(parsedAst.IsOk()); + } + + Y_UNIT_TEST(PrettyPrintAst) { + const ui32 testFlags[] = { + TAstPrintFlags::Default, + TAstPrintFlags::PerLine, + //TAstPrintFlags::ShortQuote, //-- generates invalid AST + TAstPrintFlags::PerLine | TAstPrintFlags::ShortQuote + }; + + TAstParseResult ast = ParseAst(TEST_PROGRAM); + UNIT_ASSERT(ast.IsOk()); + + for (ui32 i = 0; i < Y_ARRAY_SIZE(testFlags); ++i) { + ui32 prettyFlags = testFlags[i]; + + TString printedProgram1 = ast.Root->ToString(prettyFlags); + TAstParseResult parsedAst = ParseAst(printedProgram1); + UNIT_ASSERT(parsedAst.IsOk()); + + TString printedProgram2 = parsedAst.Root->ToString(prettyFlags); + UNIT_ASSERT_STRINGS_EQUAL(printedProgram1, printedProgram2); + } + } + + Y_UNIT_TEST(AnnotatedAstPrint) { + TMemoryPool pool(4096); + TAstParseResult ast = ParseAst(TEST_PROGRAM, &pool); + UNIT_ASSERT(ast.IsOk()); + + TAstParseResult astWithPositions; + astWithPositions.Root = AnnotatePositions(*ast.Root, pool); + UNIT_ASSERT(!!astWithPositions.Root); + + TString sAnn = astWithPositions.Root->ToString(); + UNIT_ASSERT(false == sAnn.empty()); + + TAstParseResult annRes = ParseAst(sAnn); + UNIT_ASSERT(annRes.IsOk()); + + TAstParseResult removedAnn; + removedAnn.Root = RemoveAnnotations(*annRes.Root, pool); + UNIT_ASSERT(!!removedAnn.Root); + + TString strOriginal = ast.Root->ToString(); + TString strAnnRemoved = removedAnn.Root->ToString(); + UNIT_ASSERT_VALUES_EQUAL(strOriginal, strAnnRemoved); + + astWithPositions.Root->GetChild(0)->SetContent("100:100", pool); + + TAstParseResult appliedPositionsAnn; + appliedPositionsAnn.Root = ApplyPositionAnnotations(*astWithPositions.Root, 0, pool); + UNIT_ASSERT(appliedPositionsAnn.Root); + + TAstParseResult removedAnn2; + removedAnn2.Root = RemoveAnnotations(*appliedPositionsAnn.Root, pool); + UNIT_ASSERT(removedAnn2.Root); + UNIT_ASSERT_VALUES_EQUAL(removedAnn2.Root->GetPosition().Row, 100); + } + + template <typename TCharType> + void TestGoodArbitraryAtom( + const TString& program, + const TBasicStringBuf<TCharType>& expectedValue) + { + TAstParseResult ast = ParseAst(program); + UNIT_ASSERT(ast.IsOk()); + UNIT_ASSERT_VALUES_EQUAL(ast.Root->GetChildrenCount(), 1); + + TAstNode* atom = ast.Root->GetChild(0); + UNIT_ASSERT(atom->IsAtom()); + UNIT_ASSERT_STRINGS_EQUAL_C( + atom->GetContent(), + TString((char*)expectedValue.data(), expectedValue.size()), + program); + } + + Y_UNIT_TEST(GoodArbitraryAtom) { + TestGoodArbitraryAtom("(\"\")", TStringBuf()); + TestGoodArbitraryAtom("(\" 1 a 3 b \")", TStringBuf(" 1 a 3 b ")); + + ui8 expectedHex[] = { 0xab, 'c', 'd', 0x00 }; + TestGoodArbitraryAtom("(\"\\xabcd\")", TBasicStringBuf<ui8>(expectedHex)); + TestGoodArbitraryAtom("(\" \\x3d \")", TStringBuf(" \x3d ")); + + ui8 expectedOctal[] = { 056, '7', '8', 0x00 }; + TestGoodArbitraryAtom("(\"\\05678\")", TBasicStringBuf<ui8>(expectedOctal)); + TestGoodArbitraryAtom("(\" \\056 \")", TStringBuf(" \056 ")); + TestGoodArbitraryAtom("(\" \\177 \")", TStringBuf(" \177 ")); + TestGoodArbitraryAtom("(\" \\377 \")", TStringBuf(" \377 ")); + TestGoodArbitraryAtom("(\" \\477 \")", TStringBuf(" 477 ")); + + { + ui8 expected1[] = { 0x01, 0x00 }; + TestGoodArbitraryAtom("(\"\\u0001\")", TBasicStringBuf<ui8>(expected1)); + + ui8 expected2[] = { 0xE1, 0x88, 0xB4, 0x00 }; + TestGoodArbitraryAtom("(\"\\u1234\")", TBasicStringBuf<ui8>(expected2)); + + ui8 expected3[] = { 0xef, 0xbf, 0xbf, 0x00 }; + TestGoodArbitraryAtom("(\"\\uffff\")", TBasicStringBuf<ui8>(expected3)); + } + + { + ui8 expected1[] = { 0x01, 0x00 }; + TestGoodArbitraryAtom("(\"\\U00000001\")", TBasicStringBuf<ui8>(expected1)); + + ui8 expected2[] = { 0xf4, 0x8f, 0xbf, 0xbf, 0x00 }; + TestGoodArbitraryAtom("(\"\\U0010ffff\")", TBasicStringBuf<ui8>(expected2)); + } + + TestGoodArbitraryAtom("(\"\\t\")", TStringBuf("\t")); + TestGoodArbitraryAtom("(\"\\n\")", TStringBuf("\n")); + TestGoodArbitraryAtom("(\"\\r\")", TStringBuf("\r")); + TestGoodArbitraryAtom("(\"\\b\")", TStringBuf("\b")); + TestGoodArbitraryAtom("(\"\\f\")", TStringBuf("\f")); + TestGoodArbitraryAtom("(\"\\a\")", TStringBuf("\a")); + TestGoodArbitraryAtom("(\"\\v\")", TStringBuf("\v")); + } + + void TestBadArbitraryAtom( + const TString& program, + const TString& expectedError) + { + TAstParseResult ast = ParseAst(program); + UNIT_ASSERT(false == ast.IsOk()); + UNIT_ASSERT(false == !!ast.Root); + UNIT_ASSERT(false == ast.Issues.Empty()); + UNIT_ASSERT_STRINGS_EQUAL(ast.Issues.begin()->GetMessage(), expectedError); + } + + Y_UNIT_TEST(BadArbitraryAtom) { + TestBadArbitraryAtom("(a\")", "Unexpected \""); + TestBadArbitraryAtom("(\"++++\"11111)", "Unexpected end of \""); + TestBadArbitraryAtom("(\"\\", "Expected escape sequence"); + TestBadArbitraryAtom("(\"\\\")", "Unexpected end of atom"); + TestBadArbitraryAtom("(\"abc)", "Unexpected end of atom"); + + TestBadArbitraryAtom("(\"\\018\")", "Invalid octal value"); + TestBadArbitraryAtom("(\"\\01\")", "Invalid octal value"); + TestBadArbitraryAtom("(\"\\378\")", "Invalid octal value"); + + TestBadArbitraryAtom("(\"\\x1g\")", "Invalid hexadecimal value"); + TestBadArbitraryAtom("(\"\\xf\")", "Invalid hexadecimal value"); + + TestBadArbitraryAtom("(\"\\u\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\u1\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\u12\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\u123\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\ughij\")", "Invalid unicode value"); + + TestBadArbitraryAtom("(\"\\U\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\U11\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\U1122\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\U112233\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\Ughijklmn\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\U00110000\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\U00123456\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\U00200000\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\Uffffffff\")", "Invalid unicode value"); + + // surrogate range + TestBadArbitraryAtom("(\"\\ud800\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\udfff\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\U0000d800\")", "Invalid unicode value"); + TestBadArbitraryAtom("(\"\\U0000dfff\")", "Invalid unicode value"); + + TestBadArbitraryAtom("(x\"ag\")", "Invalid binary value"); + TestBadArbitraryAtom("(x\"abc\")", "Invalid binary value"); + TestBadArbitraryAtom("(x\"abcd)", "Invalid binary value"); + TestBadArbitraryAtom("(x\"abcd", "Unexpected end of atom"); + } + + void ParseAndPrint(const TString& program, const TString& expected) { + TAstParseResult ast = ParseAst(program); + UNIT_ASSERT_C(ast.IsOk(), program); + + TString result = ast.Root->ToString(); + UNIT_ASSERT_STRINGS_EQUAL_C(result, expected, program); + } + + Y_UNIT_TEST(ArbitraryAtomEscaping) { + ParseAndPrint( + "(\"\\t\\n\\r\\b\\a\\f\\v\")", + "(\"\\t\\n\\r\\b\\a\\f\\v\")"); + + ParseAndPrint("(\"\\u1234\")", "(\"\\u1234\")"); + ParseAndPrint("(\"\\u1234abcd\")", "(\"\\u1234abcd\")"); + ParseAndPrint("(\"\\177\")", "(\"\\x7F\")"); + ParseAndPrint("(\"\\377\")", "(\"\\xFF\")"); + + ParseAndPrint( + "(\"тестовая строка\")", + "(\"\\u0442\\u0435\\u0441\\u0442\\u043E\\u0432\\u0430" + "\\u044F \\u0441\\u0442\\u0440\\u043E\\u043A\\u0430\")"); + + ParseAndPrint("(\"\")", "(\"\")"); + } + + Y_UNIT_TEST(BinaryAtom) { + ParseAndPrint("(x\"abcdef\")", "(x\"ABCDEF\")"); + ParseAndPrint("(x\"aBcDeF\")", "(x\"ABCDEF\")"); + ParseAndPrint("(x)", "(x)"); + ParseAndPrint("(x x)", "(x x)"); + ParseAndPrint("(x\"\" x)", "(x\"\" x)"); + ParseAndPrint("(x\"ab12cd\" x)", "(x\"AB12CD\" x)"); + } + + void ParseAndAdaptPrint(const TString& program, const TString& expected) { + TAstParseResult ast = ParseAst(program); + UNIT_ASSERT_C(ast.IsOk(), program); + + TString result = ast.Root->ToString( + TAstPrintFlags::ShortQuote | TAstPrintFlags::PerLine | + TAstPrintFlags::AdaptArbitraryContent); + + RemoveAll(result, '\n'); // for simplify expected string + UNIT_ASSERT_STRINGS_EQUAL_C(result, expected, program); + } + + Y_UNIT_TEST(AdaptArbitraryAtom) { + ParseAndAdaptPrint("(\"test\")", "(test)"); + ParseAndAdaptPrint("(\"another test\")", "(\"another test\")"); + ParseAndAdaptPrint("(\"braces(in)test\")", "(\"braces(in)test\")"); + ParseAndAdaptPrint("(\"escaped\\u1234sequence\")", "(\"escaped\\u1234sequence\")"); + ParseAndAdaptPrint("(\"escaped\\x41sequence\")", "(escapedAsequence)"); + ParseAndAdaptPrint("(\"\")", "(\"\")"); + } + + void ParseError(const TString& program) { + TAstParseResult ast = ParseAst(program); + UNIT_ASSERT_C(!ast.IsOk(), program); + } + + Y_UNIT_TEST(MultilineAtomTrivial) { + TStringStream s; + for (ui32 i = 4; i < 13; ++i) { + TStringStream prog; + prog << "("; + for (ui32 j = 0; j < i; ++j) { + prog << "@"; + } + prog << ")"; + TAstParseResult ast = ParseAst(prog.Str()); + s << prog.Str() << " --> "; + if (ast.IsOk()) { + UNIT_ASSERT_VALUES_EQUAL(ast.Root->GetChildrenCount(), 1); + + TAstNode* atom = ast.Root->GetChild(0); + UNIT_ASSERT(atom->IsAtom()); + UNIT_ASSERT(atom->GetFlags() & TNodeFlags::MultilineContent); + s << "'" << atom->GetContent() << "'" << Endl; + } else { + s << "Error" << Endl; + } + } + //~ Cerr << s.Str() << Endl; + UNIT_ASSERT_NO_DIFF( + "(@@@@) --> ''\n" + "(@@@@@) --> '@'\n" + "(@@@@@@) --> Error\n" + "(@@@@@@@) --> Error\n" + "(@@@@@@@@) --> '@@'\n" + "(@@@@@@@@@) --> '@@@'\n" + "(@@@@@@@@@@) --> Error\n" + "(@@@@@@@@@@@) --> Error\n" + "(@@@@@@@@@@@@) --> '@@@@'\n", + s.Str() + ); + } + + Y_UNIT_TEST(MultilineAtom) { + TString s1 = "(@@multi \n" + "line \n" + "string@@)"; + ParseAndPrint(s1, s1); + + TString s2 = "(@@multi \n" + "l@ine \n" + "string@@)"; + ParseAndPrint(s2, s2); + + TString s3 = "(@@multi \n" + "l@@@ine \n" + "string@@)"; + ParseError(s3); + + TString s4 = "(@@multi \n" + "l@@@@ine \n" + "string@@)"; + ParseAndPrint(s4, s4); + + TString s5 = "(@@\n" + "one@\n" + "two@@@@\n" + "four@@@@@@@@\n" + "@@@@two\n" + "@one\n" + "@@)"; + + TAstParseResult ast = ParseAst(s5); + UNIT_ASSERT(ast.IsOk()); + UNIT_ASSERT_VALUES_EQUAL(ast.Root->GetChildrenCount(), 1); + + TAstNode* atom = ast.Root->GetChild(0); + UNIT_ASSERT(atom->IsAtom()); + UNIT_ASSERT(atom->GetFlags() & TNodeFlags::MultilineContent); + + TString expected = "\n" + "one@\n" + "two@@\n" + "four@@@@\n" + "@@two\n" + "@one\n"; + UNIT_ASSERT_STRINGS_EQUAL(atom->GetContent(), expected); + + TString printResult = ast.Root->ToString(); + UNIT_ASSERT_STRINGS_EQUAL(s5, printResult); + } + + Y_UNIT_TEST(UnicodePrettyPrint) { + ParseAndAdaptPrint("(\"абв αβγ ﬡ\")", "(\"\\u0430\\u0431\\u0432 \\u03B1\\u03B2\\u03B3 \\uFB21\")"); + } + + Y_UNIT_TEST(SerializeQuotedEmptyAtom) { + TMemoryPool pool(1024); + TPosition pos(1, 1); + TAstNode* empty = TAstNode::Quote(pos, pool, TAstNode::NewAtom(pos, "", pool)); + TString expected = "'\"\""; + + UNIT_ASSERT_STRINGS_EQUAL(empty->ToString(), expected); + + TString pretty = empty->ToString(TAstPrintFlags::ShortQuote | TAstPrintFlags::PerLine | + TAstPrintFlags::AdaptArbitraryContent); + RemoveAll(pretty, '\n'); + UNIT_ASSERT_EQUAL(pretty, expected); + + pretty = empty->ToString(TAstPrintFlags::ShortQuote | TAstPrintFlags::PerLine); + RemoveAll(pretty, '\n'); + UNIT_ASSERT_EQUAL(pretty, expected); + } +} + +} // namespace NYql diff --git a/yql/essentials/ast/yql_constraint.cpp b/yql/essentials/ast/yql_constraint.cpp new file mode 100644 index 00000000000..296d2c4f5e6 --- /dev/null +++ b/yql/essentials/ast/yql_constraint.cpp @@ -0,0 +1,2382 @@ +#include "yql_constraint.h" +#include "yql_expr.h" + +#include <util/digest/murmur.h> +#include <util/generic/utility.h> +#include <util/generic/algorithm.h> +#include <util/string/join.h> + +#include <algorithm> +#include <iterator> + +namespace NYql { + +TConstraintNode::TConstraintNode(TExprContext& ctx, std::string_view name) + : Hash_(MurmurHash<ui64>(name.data(), name.size())) + , Name_(ctx.AppendString(name)) +{ +} + +TConstraintNode::TConstraintNode(TConstraintNode&& constr) + : Hash_(constr.Hash_) + , Name_(constr.Name_) +{ + constr.Hash_ = 0; + constr.Name_ = {}; +} + +void TConstraintNode::Out(IOutputStream& out) const { + out.Write(Name_); +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TPartOfConstraintBase::TPartOfConstraintBase(TExprContext& ctx, std::string_view name) + : TConstraintNode(ctx, name) +{} + +TConstraintWithFieldsNode::TConstraintWithFieldsNode(TExprContext& ctx, std::string_view name) + : TPartOfConstraintBase(ctx, name) +{} + +const TTypeAnnotationNode* TPartOfConstraintBase::GetSubTypeByPath(const TPathType& path, const TTypeAnnotationNode& type) { + if (path.empty() && ETypeAnnotationKind::Optional != type.GetKind()) + return &type; + + const auto tail = [](const TPathType& path) { + auto p(path); + p.pop_front(); + return p; + }; + switch (type.GetKind()) { + case ETypeAnnotationKind::Optional: + return GetSubTypeByPath(path, *type.Cast<TOptionalExprType>()->GetItemType()); + case ETypeAnnotationKind::List: // TODO: Remove later: temporary stub for single AsList in FlatMap and same cases. + return GetSubTypeByPath(path, *type.Cast<TListExprType>()->GetItemType()); + case ETypeAnnotationKind::Struct: + if (const auto itemType = type.Cast<TStructExprType>()->FindItemType(path.front())) + return GetSubTypeByPath(tail(path), *itemType); + break; + case ETypeAnnotationKind::Tuple: + if (const auto index = TryFromString<ui64>(TStringBuf(path.front()))) + if (const auto typleType = type.Cast<TTupleExprType>(); typleType->GetSize() > *index) + return GetSubTypeByPath(tail(path), *typleType->GetItems()[*index]); + break; + case ETypeAnnotationKind::Multi: + if (const auto index = TryFromString<ui64>(TStringBuf(path.front()))) + if (const auto multiType = type.Cast<TMultiExprType>(); multiType->GetSize() > *index) + return GetSubTypeByPath(tail(path), *multiType->GetItems()[*index]); + break; + case ETypeAnnotationKind::Variant: + return GetSubTypeByPath(path, *type.Cast<TVariantExprType>()->GetUnderlyingType()); + case ETypeAnnotationKind::Dict: + if (const auto index = TryFromString<ui8>(TStringBuf(path.front()))) + switch (*index) { + case 0U: return GetSubTypeByPath(tail(path), *type.Cast<TDictExprType>()->GetKeyType()); + case 1U: return GetSubTypeByPath(tail(path), *type.Cast<TDictExprType>()->GetPayloadType()); + default: break; + } + break; + default: + break; + } + return nullptr; +} + +bool TPartOfConstraintBase::HasDuplicates(const TSetOfSetsType& sets) { + for (auto ot = sets.cbegin(); sets.cend() != ot; ++ot) { + for (auto it = sets.cbegin(); sets.cend() != it; ++it) { + if (ot->size() < it->size() && std::all_of(ot->cbegin(), ot->cend(), [it](const TPathType& path) { return it->contains(path); })) + return true; + } + } + return false; +} + +NYT::TNode TPartOfConstraintBase::PathToNode(const TPartOfConstraintBase::TPathType& path) { + if (1U == path.size()) + return TStringBuf(path.front()); + + return std::accumulate(path.cbegin(), path.cend(), + NYT::TNode::CreateList(), + [](NYT::TNode node, std::string_view p) -> NYT::TNode { return std::move(node).Add(TStringBuf(p)); } + ); +}; + +NYT::TNode TPartOfConstraintBase::SetToNode(const TPartOfConstraintBase::TSetType& set, bool withShortcut) { + if (withShortcut && 1U == set.size() && 1U == set.front().size()) + return TStringBuf(set.front().front()); + + return std::accumulate(set.cbegin(), set.cend(), + NYT::TNode::CreateList(), + [](NYT::TNode node, const TPathType& path) -> NYT::TNode { return std::move(node).Add(PathToNode(path)); } + ); +}; + +NYT::TNode TPartOfConstraintBase::SetOfSetsToNode(const TPartOfConstraintBase::TSetOfSetsType& sets) { + return std::accumulate(sets.cbegin(), sets.cend(), + NYT::TNode::CreateList(), + [](NYT::TNode node, const TSetType& s) { + return std::move(node).Add(TPartOfConstraintBase::SetToNode(s, true)); + }); +} + +TPartOfConstraintBase::TPathType TPartOfConstraintBase::NodeToPath(TExprContext& ctx, const NYT::TNode& node) { + if (node.IsString()) + return TPartOfConstraintBase::TPathType{ctx.AppendString(node.AsString())}; + + TPartOfConstraintBase::TPathType path; + for (const auto& col : node.AsList()) { + path.emplace_back(ctx.AppendString(col.AsString())); + } + return path; +}; + +TPartOfConstraintBase::TSetType TPartOfConstraintBase::NodeToSet(TExprContext& ctx, const NYT::TNode& node) { + if (node.IsString()) + return TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType(1U, ctx.AppendString(node.AsString()))}; + + TPartOfConstraintBase::TSetType set; + for (const auto& col : node.AsList()) { + set.insert_unique(NodeToPath(ctx, col)); + } + return set; +}; + +TPartOfConstraintBase::TSetOfSetsType TPartOfConstraintBase::NodeToSetOfSets(TExprContext& ctx, const NYT::TNode& node) { + TSetOfSetsType sets; + for (const auto& s : node.AsList()) { + sets.insert_unique(NodeToSet(ctx, s)); + } + return sets; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +const TConstraintNode* TConstraintSet::GetConstraint(std::string_view name) const { + const auto it = std::lower_bound(Constraints_.cbegin(), Constraints_.cend(), name, TConstraintNode::TCompare()); + if (it != Constraints_.cend() && (*it)->GetName() == name) { + return *it; + } + return nullptr; +} + +void TConstraintSet::AddConstraint(const TConstraintNode* node) { + if (!node) { + return; + } + const auto it = std::lower_bound(Constraints_.begin(), Constraints_.end(), node, TConstraintNode::TCompare()); + if (it == Constraints_.end() || (*it)->GetName() != node->GetName()) { + Constraints_.insert(it, node); + } else { + Y_ENSURE(node->Equals(**it), "Adding unequal constraint: " << *node << " != " << **it); + } +} + +const TConstraintNode* TConstraintSet::RemoveConstraint(std::string_view name) { + const TConstraintNode* res = nullptr; + const auto it = std::lower_bound(Constraints_.begin(), Constraints_.end(), name, TConstraintNode::TCompare()); + if (it != Constraints_.end() && (*it)->GetName() == name) { + res = *it; + Constraints_.erase(it); + } + return res; +} + +void TConstraintSet::Out(IOutputStream& out) const { + out.Write('{'); + bool first = true; + for (const auto& c: Constraints_) { + if (!first) + out.Write(','); + out << *c; + first = false; + } + out.Write('}'); +} + +void TConstraintSet::ToJson(NJson::TJsonWriter& writer) const { + writer.OpenMap(); + for (const auto& node : Constraints_) { + writer.WriteKey(node->GetName()); + node->ToJson(writer); + } + writer.CloseMap(); +} + +NYT::TNode TConstraintSet::ToYson() const { + auto res = NYT::TNode::CreateMap(); + for (const auto& node : Constraints_) { + auto serialized = node->ToYson(); + YQL_ENSURE(!serialized.IsUndefined(), "Cannot serialize " << node->GetName() << " constraint"); + res[node->GetName()] = std::move(serialized); + } + return res; +} + +bool TConstraintSet::FilterConstraints(const TPredicate& predicate) { + const auto size = Constraints_.size(); + for (auto it = Constraints_.begin(); Constraints_.end() != it;) + if (predicate((*it)->GetName())) + ++it; + else + it = Constraints_.erase(it); + return Constraints_.size() != size; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace { + +size_t GetElementsCount(const TTypeAnnotationNode* type) { + if (type) { + switch (type->GetKind()) { + case ETypeAnnotationKind::Tuple: return type->Cast<TTupleExprType>()->GetSize(); + case ETypeAnnotationKind::Multi: return type->Cast<TMultiExprType>()->GetSize(); + case ETypeAnnotationKind::Struct: return type->Cast<TStructExprType>()->GetSize(); + default: break; + } + } + return 0U; +} + +std::deque<std::string_view> GetAllItemTypeFields(const TTypeAnnotationNode* type, TExprContext& ctx) { + std::deque<std::string_view> fields; + if (type) { + switch (type->GetKind()) { + case ETypeAnnotationKind::Struct: + if (const auto structType = type->Cast<TStructExprType>()) { + fields.resize(structType->GetSize()); + std::transform(structType->GetItems().cbegin(), structType->GetItems().cend(), fields.begin(), std::bind(&TItemExprType::GetName, std::placeholders::_1)); + } + break; + case ETypeAnnotationKind::Tuple: + if (const auto size = type->Cast<TTupleExprType>()->GetSize()) { + fields.resize(size); + ui32 i = 0U; + std::generate(fields.begin(), fields.end(), [&]() { return ctx.GetIndexAsString(i++); }); + } + break; + case ETypeAnnotationKind::Multi: + if (const auto size = type->Cast<TMultiExprType>()->GetSize()) { + fields.resize(size); + ui32 i = 0U; + std::generate(fields.begin(), fields.end(), [&]() { return ctx.GetIndexAsString(i++); }); + } + break; + default: + break; + } + } + return fields; +} + +TPartOfConstraintBase::TSetOfSetsType MakeFullSet(const TPartOfConstraintBase::TSetType& keys) { + TPartOfConstraintBase::TSetOfSetsType sets; + sets.reserve(sets.size()); + for (const auto& key : keys) + sets.insert_unique(TPartOfConstraintBase::TSetType{key}); + return sets; +} + +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TSortedConstraintNode::TSortedConstraintNode(TExprContext& ctx, TContainerType&& content) + : TConstraintWithFieldsT(ctx, Name()) + , Content_(std::move(content)) +{ + YQL_ENSURE(!Content_.empty()); + for (const auto& c : Content_) { + YQL_ENSURE(!c.first.empty()); + for (const auto& path : c.first) + Hash_ = std::accumulate(path.cbegin(), path.cend(), c.second ? Hash_ : ~Hash_, [](ui64 hash, const std::string_view& field) { return MurmurHash<ui64>(field.data(), field.size(), hash); }); + } +} + +TSortedConstraintNode::TSortedConstraintNode(TExprContext& ctx, const NYT::TNode& serialized) + : TSortedConstraintNode(ctx, NodeToContainer(ctx, serialized)) +{ +} + +TSortedConstraintNode::TContainerType TSortedConstraintNode::NodeToContainer(TExprContext& ctx, const NYT::TNode& serialized) { + TSortedConstraintNode::TContainerType sorted; + try { + for (const auto& pair : serialized.AsList()) { + TPartOfConstraintBase::TSetType set = TPartOfConstraintBase::NodeToSet(ctx, pair.AsList().front()); + sorted.emplace_back(std::move(set), pair.AsList().back().AsBool()); + } + } catch (...) { + YQL_ENSURE(false, "Cannot deserialize " << Name() << " constraint: " << CurrentExceptionMessage()); + } + return sorted; +} + +TSortedConstraintNode::TSortedConstraintNode(TSortedConstraintNode&&) = default; + +bool TSortedConstraintNode::Equals(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + + if (const auto c = dynamic_cast<const TSortedConstraintNode*>(&node)) { + return GetContent() == c->GetContent(); + } + return false; +} + +bool TSortedConstraintNode::Includes(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (GetName() != node.GetName()) { + return false; + } + + const auto& content = static_cast<const TSortedConstraintNode&>(node).GetContent(); + if (content.size() > Content_.size()) + return false; + for (TContainerType::size_type i = 0U; i < content.size(); ++i) { + if (Content_[i].second != content[i].second || + !(std::includes(Content_[i].first.cbegin(), Content_[i].first.cend(), content[i].first.cbegin(), content[i].first.cend()) || std::includes(content[i].first.cbegin(), content[i].first.cend(), Content_[i].first.cbegin(), Content_[i].first.cend()))) + return false; + } + + return true; +} + +void TSortedConstraintNode::Out(IOutputStream& out) const { + TConstraintNode::Out(out); + out.Write('('); + bool first = true; + for (const auto& c : Content_) { + if (first) + first = false; + else + out.Write(';'); + + out.Write(JoinSeq(',', c.first)); + out.Write('['); + out.Write(c.second ? "asc" : "desc"); + out.Write(']'); + } + out.Write(')'); +} + +void TSortedConstraintNode::ToJson(NJson::TJsonWriter& out) const { + out.OpenArray(); + for (const auto& c : Content_) { + out.OpenArray(); + out.Write(JoinSeq(';', c.first)); + out.Write(c.second); + out.CloseArray(); + } + out.CloseArray(); +} + +NYT::TNode TSortedConstraintNode::ToYson() const { + return std::accumulate(Content_.cbegin(), Content_.cend(), + NYT::TNode::CreateList(), + [](NYT::TNode node, const std::pair<TSetType, bool>& pair) { + return std::move(node).Add(NYT::TNode::CreateList().Add(TPartOfConstraintBase::SetToNode(pair.first, false)).Add(pair.second)); + }); +} + +bool TSortedConstraintNode::IsPrefixOf(const TSortedConstraintNode& node) const { + return node.Includes(*this); +} + +bool TSortedConstraintNode::StartsWith(const TSetType& prefix) const { + auto set = prefix; + for (const auto& key : Content_) { + bool found = false; + std::for_each(key.first.cbegin(), key.first.cend(), [&set, &found] (const TPathType& path) { + if (const auto it = set.find(path); set.cend() != it) { + set.erase(it); + found = true; + } + }); + + if (!found) + break; + } + + return set.empty(); +} + +TPartOfConstraintBase::TSetType TSortedConstraintNode::GetFullSet() const { + TSetType set; + set.reserve(Content_.size()); + for (const auto& key : Content_) + set.insert_unique(key.first.cbegin(), key.first.cend()); + return set; +} + +void TSortedConstraintNode::FilterUncompleteReferences(TSetType& references) const { + TSetType complete; + complete.reserve(references.size()); + + for (const auto& item : Content_) { + bool found = false; + for (const auto& path : item.first) { + if (references.contains(path)) { + found = true; + complete.insert_unique(path); + } + } + + if (!found) + break; + } + + references = std::move(complete); +} + +const TSortedConstraintNode* TSortedConstraintNode::MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx) { + if (constraints.empty()) { + return nullptr; + } + + if (constraints.size() == 1) { + return constraints.front()->GetConstraint<TSortedConstraintNode>(); + } + + std::optional<TContainerType> content; + for (size_t i = 0U; i < constraints.size(); ++i) { + if (const auto sort = constraints[i]->GetConstraint<TSortedConstraintNode>()) { + const auto& nextContent = sort->GetContent(); + if (content) { + const auto size = std::min(content->size(), nextContent.size()); + content->resize(size); + for (auto j = 0U; j < size; ++j) { + auto& one = (*content)[j]; + auto& two = nextContent[j]; + TSetType common; + common.reserve(std::min(one.first.size(), two.first.size())); + std::set_intersection(one.first.cbegin(), one.first.cend(), two.first.cbegin(), two.first.cend(), std::back_inserter(common)); + if (common.empty() || one.second != two.second) { + content->resize(j); + break; + } else + one.first = std::move(common); + } + if (content->empty()) + break; + } else { + content = nextContent; + } + } else if (!constraints[i]->GetConstraint<TEmptyConstraintNode>()) { + content.reset(); + break; + } + } + + return !content || content->empty() ? nullptr : ctx.MakeConstraint<TSortedConstraintNode>(std::move(*content)); +} + +const TSortedConstraintNode* TSortedConstraintNode::MakeCommon(const TSortedConstraintNode* other, TExprContext& ctx) const { + if (!other) { + return nullptr; + } else if (this == other) { + return this; + } + + auto content = other->GetContent(); + const auto size = std::min(content.size(), Content_.size()); + content.resize(size); + for (auto j = 0U; j < size; ++j) { + auto& one = content[j]; + auto& two = Content_[j]; + TSetType common; + common.reserve(std::min(one.first.size(), two.first.size())); + std::set_intersection(one.first.cbegin(), one.first.cend(), two.first.cbegin(), two.first.cend(), std::back_inserter(common)); + if (common.empty() || one.second != two.second) { + content.resize(j); + break; + } else + one.first = std::move(common); + } + + return content.empty() ? nullptr : ctx.MakeConstraint<TSortedConstraintNode>(std::move(content)); +} + +const TSortedConstraintNode* TSortedConstraintNode::CutPrefix(size_t newPrefixLength, TExprContext& ctx) const { + if (!newPrefixLength) + return nullptr; + + if (newPrefixLength >= Content_.size()) + return this; + + auto content = Content_; + content.resize(newPrefixLength); + return ctx.MakeConstraint<TSortedConstraintNode>(std::move(content)); +} + +const TConstraintWithFieldsNode* TSortedConstraintNode::DoFilterFields(TExprContext& ctx, const TPathFilter& filter) const { + if (!filter) + return this; + + TContainerType sorted; + sorted.reserve(Content_.size()); + for (const auto& item : Content_) { + TSetType newSet; + newSet.reserve(item.first.size()); + for (const auto& path : item.first) { + if (filter(path)) + newSet.insert_unique(path); + } + + if (newSet.empty()) + break; + else + sorted.emplace_back(std::move(newSet), item.second); + } + return sorted.empty() ? nullptr : ctx.MakeConstraint<TSortedConstraintNode>(std::move(sorted)); +} + +const TConstraintWithFieldsNode* TSortedConstraintNode::DoRenameFields(TExprContext& ctx, const TPathReduce& reduce) const { + if (!reduce) + return this; + + TContainerType sorted; + sorted.reserve(Content_.size()); + for (const auto& item : Content_) { + TSetType newSet; + newSet.reserve(item.first.size()); + for (const auto& path : item.first) { + if (const auto& newPaths = reduce(path); !newPaths.empty()) + newSet.insert_unique(newPaths.cbegin(), newPaths.cend()); + } + + if (newSet.empty()) + break; + else + sorted.emplace_back(std::move(newSet), item.second); + } + return sorted.empty() ? nullptr : ctx.MakeConstraint<TSortedConstraintNode>(std::move(sorted)); +} + +bool TSortedConstraintNode::IsApplicableToType(const TTypeAnnotationNode& type) const { + const auto& itemType = GetSeqItemType(type); + return std::all_of(Content_.cbegin(), Content_.cend(), [&itemType](const std::pair<TSetType, bool>& pair) { + return std::all_of(pair.first.cbegin(), pair.first.cend(), std::bind(&GetSubTypeByPath, std::placeholders::_1, std::cref(itemType))); + }); +} + + +const TConstraintWithFieldsNode* +TSortedConstraintNode::DoGetComplicatedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const { + const auto& rowType = GetSeqItemType(type); + bool changed = false; + auto content = Content_; + for (auto it = content.begin(); content.end() != it;) { + const auto subType = GetSubTypeByPath(it->first.front(), rowType); + auto fields = GetAllItemTypeFields(subType, ctx); + for (auto j = it->first.cbegin(); it->first.cend() != ++j;) { + if (!IsSameAnnotation(*GetSubTypeByPath(*j, rowType), *subType)) { + fields.clear(); + break; + } + } + + if (fields.empty() || ETypeAnnotationKind::Struct == subType->GetKind()) + ++it; + else { + changed = true; + const bool dir = it->second; + auto set = it->first; + for (auto& path : set) + path.emplace_back(); + for (it = content.erase(it); !fields.empty(); fields.pop_front()) { + auto paths = set; + for (auto& path : paths) + path.back() = fields.front(); + it = content.emplace(it, std::move(paths), dir); + ++it; + } + } + } + + return changed ? ctx.MakeConstraint<TSortedConstraintNode>(std::move(content)) : this; +} + +const TConstraintWithFieldsNode* +TSortedConstraintNode::DoGetSimplifiedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const { + if (Content_.size() == 1U && Content_.front().first.size() == 1U && Content_.front().first.front().empty()) + return DoGetComplicatedForType(type, ctx); + + const auto& rowType = GetSeqItemType(type); + const auto getPrefix = [](TPartOfConstraintBase::TPathType path) { + path.pop_back(); + return path; + }; + + bool changed = false; + auto content = Content_; + for (bool setChanged = true; setChanged;) { + setChanged = false; + for (auto it = content.begin(); content.end() != it;) { + if (it->first.size() > 1U) { + for (const auto& path : it->first) { + if (path.size() > 1U && path.back() == ctx.GetIndexAsString(0U)) { + const auto prefix = getPrefix(path); + if (const auto subType = GetSubTypeByPath(prefix, rowType); ETypeAnnotationKind::Struct != subType->GetKind() && 1 == GetElementsCount(subType)) { + it->first.erase(path); + it->first.insert(prefix); + changed = setChanged = true; + } + } + } + ++it; + } else if (it->first.size() == 1U && it->first.front().size() > 1U) { + const auto prefix = getPrefix(it->first.front()); + if (const auto subType = GetSubTypeByPath(prefix, rowType); it->first.front().back() == ctx.GetIndexAsString(0U) && ETypeAnnotationKind::Struct != subType->GetKind()) { + auto from = it++; + for (auto i = 1U; content.cend() != it && it->first.size() == 1U && it->first.front().size() > 1U && ctx.GetIndexAsString(i) == it->first.front().back() && prefix == getPrefix(it->first.front()) && from->second == it->second; ++i) + ++it; + + if (ssize_t(GetElementsCount(subType)) == std::distance(from, it)) { + *from = std::make_pair(TPartOfConstraintBase::TSetType{std::move(prefix)}, from->second); + ++from; + it = content.erase(from, it); + changed = setChanged = true; + } + } else + ++it; + } else + ++it; + } + } + + return changed ? ctx.MakeConstraint<TSortedConstraintNode>(std::move(content)) : this; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TChoppedConstraintNode::TChoppedConstraintNode(TExprContext& ctx, TSetOfSetsType&& sets) + : TConstraintWithFieldsT(ctx, Name()) + , Sets_(std::move(sets)) +{ + YQL_ENSURE(!Sets_.empty()); + YQL_ENSURE(!HasDuplicates(Sets_)); + const auto size = Sets_.size(); + Hash_ = MurmurHash<ui64>(&size, sizeof(size), Hash_); + for (const auto& set : Sets_) { + YQL_ENSURE(!set.empty()); + for (const auto& path : set) + Hash_ = std::accumulate(path.cbegin(), path.cend(), Hash_, [](ui64 hash, const std::string_view& field) { return MurmurHash<ui64>(field.data(), field.size(), hash); }); + } +} + +TChoppedConstraintNode::TChoppedConstraintNode(TExprContext& ctx, const TSetType& keys) + : TChoppedConstraintNode(ctx, MakeFullSet(keys)) +{} + +TChoppedConstraintNode::TChoppedConstraintNode(TExprContext& ctx, const NYT::TNode& serialized) + : TChoppedConstraintNode(ctx, NodeToSets(ctx, serialized)) +{ +} + +TChoppedConstraintNode::TSetOfSetsType TChoppedConstraintNode::NodeToSets(TExprContext& ctx, const NYT::TNode& serialized) { + try { + return TPartOfConstraintBase::NodeToSetOfSets(ctx, serialized); + } catch (...) { + YQL_ENSURE(false, "Cannot deserialize " << Name() << " constraint: " << CurrentExceptionMessage()); + } + Y_UNREACHABLE(); +} + +TChoppedConstraintNode::TChoppedConstraintNode(TChoppedConstraintNode&& constr) = default; + +TPartOfConstraintBase::TSetType TChoppedConstraintNode::GetFullSet() const { + TSetType set; + set.reserve(Sets_.size()); + for (const auto& key : Sets_) + set.insert_unique(key.cbegin(), key.cend()); + return set; +} + +bool TChoppedConstraintNode::Equals(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (GetHash() != node.GetHash()) { + return false; + } + if (const auto c = dynamic_cast<const TChoppedConstraintNode*>(&node)) { + return Sets_ == c->Sets_; + } + return false; +} + +bool TChoppedConstraintNode::Includes(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (const auto c = dynamic_cast<const TChoppedConstraintNode*>(&node)) { + return std::includes(Sets_.cbegin(), Sets_.cend(), c->Sets_.cbegin(), c->Sets_.cend()); + } + return false; +} + +void TChoppedConstraintNode::Out(IOutputStream& out) const { + TConstraintNode::Out(out); + out.Write('('); + + for (const auto& set : Sets_) { + out.Write('('); + bool first = true; + for (const auto& path : set) { + if (first) + first = false; + else + out.Write(','); + out << path; + } + out.Write(')'); + } + out.Write(')'); +} + +void TChoppedConstraintNode::ToJson(NJson::TJsonWriter& out) const { + out.OpenArray(); + for (const auto& set : Sets_) { + out.OpenArray(); + for (const auto& path : set) { + out.Write(JoinSeq(';', path)); + } + out.CloseArray(); + } + out.CloseArray(); +} + +NYT::TNode TChoppedConstraintNode::ToYson() const { + return TPartOfConstraintBase::SetOfSetsToNode(Sets_); +} + +bool TChoppedConstraintNode::Equals(const TSetType& prefix) const { + auto set = prefix; + for (const auto& key : Sets_) { + bool found = false; + std::for_each(key.cbegin(), key.cend(), [&set, &found] (const TPathType& path) { + if (const auto it = set.find(path); set.cend() != it) { + set.erase(it); + found = true; + } + }); + + if (!found) + return false; + } + + return set.empty(); +} + + +void TChoppedConstraintNode::FilterUncompleteReferences(TSetType& references) const { + TSetType complete; + complete.reserve(references.size()); + + for (const auto& item : Sets_) { + bool found = false; + for (const auto& path : item) { + if (references.contains(path)) { + found = true; + complete.insert_unique(path); + } + } + + if (!found) + break; + } + + references = std::move(complete); +} + +const TChoppedConstraintNode* TChoppedConstraintNode::MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx) { + if (constraints.empty()) { + return nullptr; + } + if (constraints.size() == 1) { + return constraints.front()->GetConstraint<TChoppedConstraintNode>(); + } + + TSetOfSetsType sets; + for (auto c: constraints) { + if (const auto uniq = c->GetConstraint<TChoppedConstraintNode>()) { + if (sets.empty()) + sets = uniq->GetContent(); + else { + TSetOfSetsType both; + both.reserve(std::min(sets.size(), uniq->GetContent().size())); + std::set_intersection(sets.cbegin(), sets.cend(), uniq->GetContent().cbegin(), uniq->GetContent().cend(), std::back_inserter(both)); + if (both.empty()) { + if (!c->GetConstraint<TEmptyConstraintNode>()) + return nullptr; + } else + sets = std::move(both); + } + } else if (!c->GetConstraint<TEmptyConstraintNode>()) { + return nullptr; + } + } + + return sets.empty() ? nullptr : ctx.MakeConstraint<TChoppedConstraintNode>(std::move(sets)); +} + +const TConstraintWithFieldsNode* +TChoppedConstraintNode::DoFilterFields(TExprContext& ctx, const TPathFilter& predicate) const { + if (!predicate) + return this; + + TSetOfSetsType chopped; + chopped.reserve(Sets_.size()); + for (const auto& set : Sets_) { + auto newSet = set; + for (auto it = newSet.cbegin(); newSet.cend() != it;) { + if (predicate(*it)) + ++it; + else + it = newSet.erase(it); + } + + if (newSet.empty()) + return nullptr;; + + chopped.insert_unique(std::move(newSet)); + } + return ctx.MakeConstraint<TChoppedConstraintNode>(std::move(chopped)); +} + +const TConstraintWithFieldsNode* +TChoppedConstraintNode::DoRenameFields(TExprContext& ctx, const TPathReduce& reduce) const { + if (!reduce) + return this; + + TSetOfSetsType chopped; + chopped.reserve(Sets_.size()); + for (const auto& set : Sets_) { + TSetType newSet; + newSet.reserve(set.size()); + for (const auto& path : set) { + if (const auto& newPaths = reduce(path); !newPaths.empty()) + newSet.insert_unique(newPaths.cbegin(), newPaths.cend()); + } + + if (newSet.empty()) + return nullptr; + + chopped.insert_unique(std::move(newSet)); + } + + return ctx.MakeConstraint<TChoppedConstraintNode>(std::move(chopped)); +} + +const TChoppedConstraintNode* +TChoppedConstraintNode::MakeCommon(const TChoppedConstraintNode* other, TExprContext& ctx) const { + if (!other) { + return nullptr; + } + if (this == other) { + return this; + } + + TSetOfSetsType both; + both.reserve(std::min(Sets_.size(), other->Sets_.size())); + std::set_intersection(Sets_.cbegin(), Sets_.cend(), other->Sets_.cbegin(), other->Sets_.cend(), std::back_inserter(both)); + return both.empty() ? nullptr : ctx.MakeConstraint<TChoppedConstraintNode>(std::move(both)); +} + +bool TChoppedConstraintNode::IsApplicableToType(const TTypeAnnotationNode& type) const { + const auto& itemType = GetSeqItemType(type); + return std::all_of(Sets_.cbegin(), Sets_.cend(), [&itemType](const TSetType& set) { + return std::all_of(set.cbegin(), set.cend(), std::bind(&GetSubTypeByPath, std::placeholders::_1, std::cref(itemType))); + }); +} + +const TConstraintWithFieldsNode* +TChoppedConstraintNode::DoGetComplicatedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const { + const auto& rowType = GetSeqItemType(type); + bool changed = false; + auto sets = Sets_; + for (auto it = sets.begin(); sets.end() != it;) { + auto fields = GetAllItemTypeFields(GetSubTypeByPath(it->front(), rowType), ctx); + for (auto j = it->cbegin(); it->cend() != ++j;) { + if (const auto& copy = GetAllItemTypeFields(GetSubTypeByPath(*j, rowType), ctx); copy != fields) { + fields.clear(); + break; + } + } + + if (fields.empty()) + ++it; + else { + changed = true; + auto set = *it; + for (auto& path : set) + path.emplace_back(); + for (it = sets.erase(it); !fields.empty(); fields.pop_front()) { + auto paths = set; + for (auto& path : paths) + path.back() = fields.front(); + it = sets.insert_unique(std::move(paths)).first; + } + } + } + + return changed ? ctx.MakeConstraint<TChoppedConstraintNode>(std::move(sets)) : this; +} + +const TConstraintWithFieldsNode* +TChoppedConstraintNode::DoGetSimplifiedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const { + if (Sets_.size() == 1U && Sets_.front().size() == 1U && Sets_.front().front().empty()) + return DoGetComplicatedForType(type, ctx); + + const auto& rowType = GetSeqItemType(type); + const auto getPrefix = [](TPartOfConstraintBase::TPathType path) { + path.pop_back(); + return path; + }; + + bool changed = false; + auto sets = Sets_; + for (bool setChanged = true; setChanged;) { + setChanged = false; + for (auto it = sets.begin(); sets.end() != it;) { + if (it->size() != 1U || it->front().size() <= 1U) + ++it; + else { + auto from = it++; + const auto prefix = getPrefix(from->front()); + while (sets.cend() != it && it->size() == 1U && it->front().size() > 1U && prefix == getPrefix(it->front())) + ++it; + + if (ssize_t(GetElementsCount(GetSubTypeByPath(prefix, rowType))) == std::distance(from, it)) { + *from++ = TPartOfConstraintBase::TSetType{std::move(prefix)}; + it = sets.erase(from, it); + changed = setChanged = true; + } + } + } + } + + return changed ? ctx.MakeConstraint<TChoppedConstraintNode>(std::move(sets)) : this; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template<bool Distinct> +TConstraintWithFieldsNode::TSetOfSetsType +TUniqueConstraintNodeBase<Distinct>::ColumnsListToSets(const std::vector<std::string_view>& columns) { + YQL_ENSURE(!columns.empty()); + TConstraintWithFieldsNode::TSetOfSetsType sets; + sets.reserve(columns.size()); + std::for_each(columns.cbegin(), columns.cend(), [&sets](const std::string_view& column) { sets.insert_unique(TConstraintWithFieldsNode::TSetType{column.empty() ? TConstraintWithFieldsNode::TPathType() : TConstraintWithFieldsNode::TPathType(1U, column)}); }); + return sets; +} + +template<bool Distinct> +typename TUniqueConstraintNodeBase<Distinct>::TContentType +TUniqueConstraintNodeBase<Distinct>::DedupSets(TContentType&& sets) { + for (bool found = true; found && sets.size() > 1U;) { + found = false; + for (auto ot = sets.cbegin(); !found && sets.cend() != ot; ++ot) { + for (auto it = sets.cbegin(); sets.cend() != it;) { + if (ot->size() < it->size() && std::all_of(ot->cbegin(), ot->cend(), [it](const TConstraintWithFieldsNode::TSetType& set) { return it->contains(set); })) { + it = sets.erase(it); + found = true; + } else + ++it; + } + } + } + + return std::move(sets); +} + +template<bool Distinct> +typename TUniqueConstraintNodeBase<Distinct>::TContentType +TUniqueConstraintNodeBase<Distinct>::MakeCommonContent(const TContentType& one, const TContentType& two) { + TContentType both; + both.reserve(std::min(one.size(), two.size())); + for (const auto& setsOne : one) { + for (const auto& setsTwo : two) { + if (setsOne.size() == setsTwo.size()) { + TConstraintWithFieldsNode::TSetOfSetsType sets; + sets.reserve(setsTwo.size()); + for (const auto& setOne : setsOne) { + for (const auto& setTwo : setsTwo) { + TConstraintWithFieldsNode::TSetType set; + set.reserve(std::min(setOne.size(), setTwo.size())); + std::set_intersection(setOne.cbegin(), setOne.cend(), setTwo.cbegin(), setTwo.cend(), std::back_inserter(set)); + if (!set.empty()) + sets.insert_unique(std::move(set)); + } + } + if (sets.size() == setsOne.size()) + both.insert_unique(std::move(sets)); + } + } + } + return both; +} + +template<bool Distinct> +TUniqueConstraintNodeBase<Distinct>::TUniqueConstraintNodeBase(TExprContext& ctx, TContentType&& sets) + : TBase(ctx, Name()) + , Content_(DedupSets(std::move(sets))) +{ + YQL_ENSURE(!Content_.empty()); + const auto size = Content_.size(); + TBase::Hash_ = MurmurHash<ui64>(&size, sizeof(size), TBase::Hash_); + for (const auto& sets : Content_) { + YQL_ENSURE(!sets.empty()); + YQL_ENSURE(!TConstraintWithFieldsNode::HasDuplicates(sets)); + for (const auto& set : sets) { + YQL_ENSURE(!set.empty()); + for (const auto& path : set) + TBase::Hash_ = std::accumulate(path.cbegin(), path.cend(), TBase::Hash_, [](ui64 hash, const std::string_view& field) { return MurmurHash<ui64>(field.data(), field.size(), hash); }); + } + } +} + +template<bool Distinct> +TUniqueConstraintNodeBase<Distinct>::TUniqueConstraintNodeBase(TExprContext& ctx, const std::vector<std::string_view>& columns) + : TUniqueConstraintNodeBase(ctx, TContentType{TPartOfConstraintBase::TSetOfSetsType{ColumnsListToSets(columns)}}) +{} + +template<bool Distinct> +TUniqueConstraintNodeBase<Distinct>::TUniqueConstraintNodeBase(TExprContext& ctx, const NYT::TNode& serialized) + : TUniqueConstraintNodeBase(ctx, NodeToContent(ctx, serialized)) +{ +} + +template<bool Distinct> +typename TUniqueConstraintNodeBase<Distinct>::TContentType TUniqueConstraintNodeBase<Distinct>::NodeToContent(TExprContext& ctx, const NYT::TNode& serialized) { + TUniqueConstraintNode::TContentType content; + try { + for (const auto& item : serialized.AsList()) { + content.insert_unique(TPartOfConstraintBase::NodeToSetOfSets(ctx, item)); + } + } catch (...) { + YQL_ENSURE(false, "Cannot deserialize " << Name() << " constraint: " << CurrentExceptionMessage()); + } + return content; +} + +template<bool Distinct> +TUniqueConstraintNodeBase<Distinct>::TUniqueConstraintNodeBase(TUniqueConstraintNodeBase&& constr) = default; + +template<bool Distinct> +TPartOfConstraintBase::TSetType +TUniqueConstraintNodeBase<Distinct>::GetFullSet() const { + TPartOfConstraintBase::TSetType set; + set.reserve(Content_.size()); + for (const auto& sets : Content_) + for (const auto& key : sets) + set.insert_unique(key.cbegin(), key.cend()); + return set; +} + +template<bool Distinct> +bool TUniqueConstraintNodeBase<Distinct>::Equals(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (TBase::GetHash() != node.GetHash()) { + return false; + } + if (const auto c = dynamic_cast<const TUniqueConstraintNodeBase*>(&node)) { + return Content_ == c->Content_; + } + return false; +} + +template<bool Distinct> +bool TUniqueConstraintNodeBase<Distinct>::Includes(const TConstraintNode& node) const { + if (this == &node) + return true; + + if (const auto c = dynamic_cast<const TUniqueConstraintNodeBase*>(&node)) { + return std::all_of(c->Content_.cbegin(), c->Content_.cend(), [&] (const TConstraintWithFieldsNode::TSetOfSetsType& oldSets) { + return std::any_of(Content_.cbegin(), Content_.cend(), [&] (const TConstraintWithFieldsNode::TSetOfSetsType& newSets) { + return oldSets.size() == newSets.size() && std::all_of(oldSets.cbegin(), oldSets.cend(), [&] (const TConstraintWithFieldsNode::TSetType& oldSet) { + return std::any_of(newSets.cbegin(), newSets.cend(), [&] (const TConstraintWithFieldsNode::TSetType& newSet) { + return std::includes(newSet.cbegin(), newSet.cend(), oldSet.cbegin(), oldSet.cend()); + }); + }); + }); + }); + } + return false; +} + +template<bool Distinct> +void TUniqueConstraintNodeBase<Distinct>::Out(IOutputStream& out) const { + TConstraintNode::Out(out); + out.Write('('); + for (const auto& sets : Content_) { + out.Write('('); + bool first = true; + for (const auto& set : sets) { + if (first) + first = false; + else + out << ','; + if (1U == set.size()) + out << set.front(); + else + out << set; + } + out.Write(')'); + } + out.Write(')'); +} + +template<bool Distinct> +void TUniqueConstraintNodeBase<Distinct>::ToJson(NJson::TJsonWriter& out) const { + out.OpenArray(); + for (const auto& sets : Content_) { + out.OpenArray(); + for (const auto& set : sets) { + out.OpenArray(); + for (const auto& path : set) { + out.Write(JoinSeq(';', path)); + } + out.CloseArray(); + } + out.CloseArray(); + } + out.CloseArray(); +} + +template<bool Distinct> +NYT::TNode TUniqueConstraintNodeBase<Distinct>::ToYson() const { + return std::accumulate(Content_.cbegin(), Content_.cend(), + NYT::TNode::CreateList(), + [](NYT::TNode node, const TConstraintWithFieldsNode::TSetOfSetsType& sets) { + return std::move(node).Add(TConstraintWithFieldsNode::SetOfSetsToNode(sets)); + }); +} + +template<bool Distinct> +const TUniqueConstraintNodeBase<Distinct>* TUniqueConstraintNodeBase<Distinct>::MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx) { + if (constraints.empty()) { + return nullptr; + } + if (constraints.size() == 1) { + return constraints.front()->GetConstraint<TUniqueConstraintNodeBase>(); + } + + TContentType content; + for (auto c: constraints) { + if (const auto uniq = c->GetConstraint<TUniqueConstraintNodeBase>()) { + if (content.empty()) + content = uniq->GetContent(); + else { + if (auto both = MakeCommonContent(content, uniq->Content_); both.empty()) { + if (!c->GetConstraint<TEmptyConstraintNode>()) + return nullptr; + } else + content = std::move(both); + } + } else if (!c->GetConstraint<TEmptyConstraintNode>()) { + return nullptr; + } + } + + return content.empty() ? nullptr : ctx.MakeConstraint<TUniqueConstraintNodeBase>(std::move(content)); +} + +template<bool Distinct> +bool TUniqueConstraintNodeBase<Distinct>::IsOrderBy(const TSortedConstraintNode& sorted) const { + TConstraintWithFieldsNode::TSetType ordered; + TConstraintWithFieldsNode::TSetOfSetsType columns; + for (const auto& key : sorted.GetContent()) { + ordered.insert_unique(key.first.cbegin(), key.first.cend()); + columns.insert_unique(key.first); + } + + for (const auto& sets : Content_) { + if (std::all_of(sets.cbegin(), sets.cend(), [&ordered](const TConstraintWithFieldsNode::TSetType& set) { + return std::any_of(set.cbegin(), set.cend(), [&ordered](const TConstraintWithFieldsNode::TPathType& path) { return ordered.contains(path); }); + })) { + std::for_each(sets.cbegin(), sets.cend(), [&columns](const TConstraintWithFieldsNode::TSetType& set) { + std::for_each(set.cbegin(), set.cend(), [&columns](const TConstraintWithFieldsNode::TPathType& path) { + if (const auto it = std::find_if(columns.cbegin(), columns.cend(), [&path](const TConstraintWithFieldsNode::TSetType& s) { return s.contains(path); }); columns.cend() != it) + columns.erase(it); + }); + }); + if (columns.empty()) + return true; + } + } + + return false; +} + +template<bool Distinct> +bool TUniqueConstraintNodeBase<Distinct>::ContainsCompleteSet(const std::vector<std::string_view>& columns) const { + if (columns.empty()) + return false; + + const std::unordered_set<std::string_view> ordered(columns.cbegin(), columns.cend()); + for (const auto& sets : Content_) { + if (std::all_of(sets.cbegin(), sets.cend(), [&ordered](const TConstraintWithFieldsNode::TSetType& set) { + return std::any_of(set.cbegin(), set.cend(), [&ordered](const TConstraintWithFieldsNode::TPathType& path) { return !path.empty() && ordered.contains(path.front()); }); + })) + return true; + } + return false; +} + +template<bool Distinct> +void TUniqueConstraintNodeBase<Distinct>::FilterUncompleteReferences(TPartOfConstraintBase::TSetType& references) const { + TPartOfConstraintBase::TSetType input(std::move(references)); + references.clear(); + references.reserve(input.size()); + for (const auto& sets : Content_) { + if (std::all_of(sets.cbegin(), sets.cend(), [&input] (const TPartOfConstraintBase::TSetType& set) { return std::any_of(set.cbegin(), set.cend(), std::bind(&TPartOfConstraintBase::TSetType::contains<TPartOfConstraintBase::TPathType>, std::cref(input), std::placeholders::_1)); })) + std::for_each(sets.cbegin(), sets.cend(), [&] (const TPartOfConstraintBase::TSetType& set) { std::for_each(set.cbegin(), set.cend(), [&] (const TPartOfConstraintBase::TPathType& path) { + if (input.contains(path)) + references.insert_unique(path); + }); }); + } +} + +template<bool Distinct> +const TConstraintWithFieldsNode* +TUniqueConstraintNodeBase<Distinct>::DoFilterFields(TExprContext& ctx, const TPartOfConstraintBase::TPathFilter& predicate) const { + if (!predicate) + return this; + + TContentType content; + content.reserve(Content_.size()); + for (const auto& sets : Content_) { + if (std::all_of(sets.cbegin(), sets.cend(), [&predicate](const TPartOfConstraintBase::TSetType& set) { return std::any_of(set.cbegin(), set.cend(), predicate); })) { + TPartOfConstraintBase::TSetOfSetsType newSets; + newSets.reserve(sets.size()); + std::for_each(sets.cbegin(), sets.cend(), [&](const TPartOfConstraintBase::TSetType& set) { + TPartOfConstraintBase::TSetType newSet; + newSet.reserve(set.size()); + std::copy_if(set.cbegin(), set.cend(), std::back_inserter(newSet), predicate); + newSets.insert_unique(std::move(newSet)); + }); + content.insert_unique(std::move(newSets)); + } + } + return content.empty() ? nullptr : ctx.MakeConstraint<TUniqueConstraintNodeBase>(std::move(content)); +} + +template<bool Distinct> +const TConstraintWithFieldsNode* +TUniqueConstraintNodeBase<Distinct>::DoRenameFields(TExprContext& ctx, const TPartOfConstraintBase::TPathReduce& reduce) const { + if (!reduce) + return this; + + TContentType content; + content.reserve(Content_.size()); + for (const auto& sets : Content_) { + TPartOfConstraintBase::TSetOfSetsType newSets; + newSets.reserve(sets.size()); + for (const auto& set : sets) { + TPartOfConstraintBase::TSetType newSet; + newSet.reserve(set.size()); + for (const auto& path : set) { + const auto newPaths = reduce(path); + newSet.insert_unique(newPaths.cbegin(), newPaths.cend()); + } + if (!newSet.empty()) + newSets.insert_unique(std::move(newSet)); + } + if (sets.size() == newSets.size()) + content.insert_unique(std::move(newSets)); + } + return content.empty() ? nullptr : ctx.MakeConstraint<TUniqueConstraintNodeBase>(std::move(content)); +} + +template<bool Distinct> +const TUniqueConstraintNodeBase<Distinct>* +TUniqueConstraintNodeBase<Distinct>::MakeCommon(const TUniqueConstraintNodeBase* other, TExprContext& ctx) const { + if (!other) + return nullptr; + + if (this == other) + return this; + + auto both = MakeCommonContent(Content_, other->Content_); + return both.empty() ? nullptr : ctx.MakeConstraint<TUniqueConstraintNodeBase>(std::move(both)); +} + +template<bool Distinct> +const TUniqueConstraintNodeBase<Distinct>* TUniqueConstraintNodeBase<Distinct>::Merge(const TUniqueConstraintNodeBase* one, const TUniqueConstraintNodeBase* two, TExprContext& ctx) { + if (!one) + return two; + if (!two) + return one; + + auto content = one->Content_; + content.insert_unique(two->Content_.cbegin(), two->Content_.cend()); + return ctx.MakeConstraint<TUniqueConstraintNodeBase<Distinct>>(std::move(content)); +} + +template<bool Distinct> +const TConstraintWithFieldsNode* +TUniqueConstraintNodeBase<Distinct>::DoGetComplicatedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const { + const auto& rowType = GetSeqItemType(type); + bool changed = false; + auto content = Content_; + for (auto& sets : content) { + for (auto it = sets.begin(); sets.end() != it;) { + auto fields = GetAllItemTypeFields(TBase::GetSubTypeByPath(it->front(), rowType), ctx); + for (auto j = it->cbegin(); it->cend() != ++j;) { + if (const auto& copy = GetAllItemTypeFields(TBase::GetSubTypeByPath(*j, rowType), ctx); copy != fields) { + fields.clear(); + break; + } + } + + if (fields.empty()) + ++it; + else { + changed = true; + auto set = *it; + for (auto& path : set) + path.emplace_back(); + for (it = sets.erase(it); !fields.empty(); fields.pop_front()) { + auto paths = set; + for (auto& path : paths) + path.back() = fields.front(); + it = sets.insert_unique(std::move(paths)).first; + } + } + } + } + + return changed ? ctx.MakeConstraint<TUniqueConstraintNodeBase<Distinct>>(std::move(content)) : this; +} + +template<bool Distinct> +const TConstraintWithFieldsNode* +TUniqueConstraintNodeBase<Distinct>::DoGetSimplifiedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const { + if (Content_.size() == 1U && Content_.front().size() == 1U && Content_.front().front().size() == 1U && Content_.front().front().front().empty()) + return DoGetComplicatedForType(type, ctx); + + const auto& rowType = GetSeqItemType(type); + const auto getPrefix = [](TPartOfConstraintBase::TPathType path) { + path.pop_back(); + return path; + }; + + bool changed = false; + auto content = Content_; + for (auto& sets : content) { + for (bool setChanged = true; setChanged;) { + setChanged = false; + for (auto it = sets.begin(); sets.end() != it;) { + if (!it->empty() && it->front().size() > 1U) { + TPartOfConstraintBase::TSetType prefixes; + prefixes.reserve(it->size()); + for (const auto& path : *it) { + if (path.size() > 1U) { + prefixes.emplace_back(getPrefix(path)); + } + } + + auto from = it++; + if (prefixes.size() < from->size()) + continue; + + while (sets.cend() != it && it->size() == prefixes.size() && + std::all_of(it->cbegin(), it->cend(), [&](const TPartOfConstraintBase::TPathType& path) { return path.size() > 1U && prefixes.contains(getPrefix(path)); })) { + ++it; + } + + if (std::all_of(prefixes.cbegin(), prefixes.cend(), + [width = std::distance(from, it), &rowType] (const TPartOfConstraintBase::TPathType& path) { return width == ssize_t(GetElementsCount(TBase::GetSubTypeByPath(path, rowType))); })) { + *from++ =std::move(prefixes); + it = sets.erase(from, it); + changed = setChanged = true; + } + } else + ++it; + } + } + } + + return changed ? ctx.MakeConstraint<TUniqueConstraintNodeBase<Distinct>>(std::move(content)) : this; +} + +template<bool Distinct> +bool TUniqueConstraintNodeBase<Distinct>::IsApplicableToType(const TTypeAnnotationNode& type) const { + if (ETypeAnnotationKind::Dict == type.GetKind()) + return true; // TODO: check for dict. + const auto& itemType = GetSeqItemType(type); + return std::all_of(Content_.cbegin(), Content_.cend(), [&itemType](const TConstraintWithFieldsNode::TSetOfSetsType& sets) { + return std::all_of(sets.cbegin(), sets.cend(), [&itemType](const TConstraintWithFieldsNode::TSetType& set) { + return std::all_of(set.cbegin(), set.cend(), std::bind(&TConstraintWithFieldsNode::GetSubTypeByPath, std::placeholders::_1, std::cref(itemType))); + }); + }); +} + +template class TUniqueConstraintNodeBase<false>; +template class TUniqueConstraintNodeBase<true>; + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template<class TOriginalConstraintNode> +TPartOfConstraintNode<TOriginalConstraintNode>::TPartOfConstraintNode(TExprContext& ctx, TMapType&& mapping) + : TBase(ctx, Name()) + , Mapping_(std::move(mapping)) +{ + YQL_ENSURE(!Mapping_.empty()); + for (const auto& part : Mapping_) { + YQL_ENSURE(!part.second.empty()); + const auto hash = part.first->GetHash(); + TBase::Hash_ = MurmurHash<ui64>(&hash, sizeof(hash), TBase::Hash_); + for (const auto& item: part.second) { + TBase::Hash_ = std::accumulate(item.first.cbegin(), item.first.cend(), TBase::Hash_, [](ui64 hash, const std::string_view& field) { return MurmurHash<ui64>(field.data(), field.size(), hash); }); + TBase::Hash_ = std::accumulate(item.second.cbegin(), item.second.cend(), TBase::Hash_, [](ui64 hash, const std::string_view& field) { return MurmurHash<ui64>(field.data(), field.size(), hash); }); + } + } +} + +template<class TOriginalConstraintNode> +TPartOfConstraintNode<TOriginalConstraintNode>::TPartOfConstraintNode(TExprContext& ctx, const NYT::TNode&) + : TBase(ctx, Name()) +{ + YQL_ENSURE(false, "TPartOfConstraintNode cannot be deserialized"); +} + +template<class TOriginalConstraintNode> +TPartOfConstraintNode<TOriginalConstraintNode>::TPartOfConstraintNode(TPartOfConstraintNode&& constr) = default; + +template<class TOriginalConstraintNode> +bool TPartOfConstraintNode<TOriginalConstraintNode>::Equals(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (TBase::GetHash() != node.GetHash()) { + return false; + } + if (const auto c = dynamic_cast<const TPartOfConstraintNode*>(&node)) { + return Mapping_ == c->Mapping_; + } + return false; +} + +template<class TOriginalConstraintNode> +bool TPartOfConstraintNode<TOriginalConstraintNode>::Includes(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (const auto c = dynamic_cast<const TPartOfConstraintNode*>(&node)) { + for (const auto& part : c->Mapping_) { + if (const auto it = Mapping_.find(part.first); Mapping_.cend() != it) { + for (const auto& pair : part.second) { + if (const auto p = it->second.find(pair.first); it->second.cend() == p || p->second != pair.second) { + return false; + } + } + } else + return false; + } + return true; + } + return false; +} + +template<class TOriginalConstraintNode> +void TPartOfConstraintNode<TOriginalConstraintNode>::Out(IOutputStream& out) const { + TConstraintNode::Out(out); + out.Write('('); + bool first = true; + for (const auto& part : Mapping_) { + for (const auto& item : part.second) { + if (first) + first = false; + else + out.Write(','); + + out << item.first; + out.Write(':'); + out << item.second; + } + } + out.Write(')'); +} + +template<class TOriginalConstraintNode> +void TPartOfConstraintNode<TOriginalConstraintNode>::ToJson(NJson::TJsonWriter& out) const { + out.OpenMap(); + for (const auto& part : Mapping_) { + for (const auto& [resultColumn, originalColumn] : part.second) { + out.Write(JoinSeq(';', resultColumn), JoinSeq(';', originalColumn)); + } + } + out.CloseMap(); +} + +template<class TOriginalConstraintNode> +NYT::TNode TPartOfConstraintNode<TOriginalConstraintNode>::ToYson() const { + return {}; // cannot be serialized +} + +template<class TOriginalConstraintNode> +const TPartOfConstraintNode<TOriginalConstraintNode>* +TPartOfConstraintNode<TOriginalConstraintNode>::ExtractField(TExprContext& ctx, const std::string_view& field) const { + TMapType passtrought; + for (const auto& part : Mapping_) { + auto it = part.second.lower_bound(TPartOfConstraintBase::TPathType(1U, field)); + if (part.second.cend() == it || it->first.front() != field) + continue; + + TPartType mapping; + mapping.reserve(part.second.size()); + while (it < part.second.cend() && !it->first.empty() && field == it->first.front()) { + auto item = *it++; + item.first.pop_front(); + mapping.emplace_back(std::move(item)); + } + + if (!mapping.empty()) { + passtrought.emplace(part.first, std::move(mapping)); + } + } + return passtrought.empty() ? nullptr : ctx.MakeConstraint<TPartOfConstraintNode>(std::move(passtrought)); +} + +template<class TOriginalConstraintNode> +const TPartOfConstraintBase* +TPartOfConstraintNode<TOriginalConstraintNode>::DoFilterFields(TExprContext& ctx, const TPartOfConstraintBase::TPathFilter& predicate) const { + if (!predicate) + return this; + + auto mapping = Mapping_; + for (auto part = mapping.begin(); mapping.end() != part;) { + for (auto it = part->second.cbegin(); part->second.cend() != it;) { + if (predicate(it->first)) + ++it; + else + it = part->second.erase(it); + } + + if (part->second.empty()) + part = mapping.erase(part); + else + ++part; + } + return mapping.empty() ? nullptr : ctx.MakeConstraint<TPartOfConstraintNode>(std::move(mapping)); +} + +template<class TOriginalConstraintNode> +const TPartOfConstraintBase* +TPartOfConstraintNode<TOriginalConstraintNode>::DoRenameFields(TExprContext& ctx, const TPartOfConstraintBase::TPathReduce& rename) const { + if (!rename) + return this; + + TMapType mapping(Mapping_.size()); + for (const auto& part : Mapping_) { + TPartType map; + map.reserve(part.second.size()); + + for (const auto& item : part.second) { + for (auto& path : rename(item.first)) { + map.insert_unique(std::make_pair(std::move(path), item.second)); + } + } + + if (!map.empty()) + mapping.emplace(part.first, std::move(map)); + } + return mapping.empty() ? nullptr : ctx.MakeConstraint<TPartOfConstraintNode>(std::move(mapping)); +} + +template<class TOriginalConstraintNode> +const TPartOfConstraintNode<TOriginalConstraintNode>* +TPartOfConstraintNode<TOriginalConstraintNode>::CompleteOnly(TExprContext& ctx) const { + TMapType mapping(Mapping_); + + for (auto it = mapping.begin(); mapping.end() != it;) { + TPartOfConstraintBase::TSetType set; + set.reserve(it->second.size()); + std::for_each(it->second.cbegin(), it->second.cend(), [&](const typename TPartType::value_type& pair) { set.insert_unique(pair.second); }); + + it->first->FilterUncompleteReferences(set); + + for (auto jt = it->second.cbegin(); it->second.cend() != jt;) { + if (set.contains(jt->second)) + ++jt; + else + jt = it->second.erase(jt); + } + + if (it->second.empty()) + it = mapping.erase(it); + else + ++it; + } + + return mapping.empty() ? nullptr : ctx.MakeConstraint<TPartOfConstraintNode>(std::move(mapping)); +} + +template<class TOriginalConstraintNode> +const TPartOfConstraintNode<TOriginalConstraintNode>* +TPartOfConstraintNode<TOriginalConstraintNode>:: RemoveOriginal(TExprContext& ctx, const TMainConstraint* original) const { + TMapType mapping(Mapping_); + mapping.erase(original); + return mapping.empty() ? nullptr : ctx.MakeConstraint<TPartOfConstraintNode>(std::move(mapping)); +} + +template<class TOriginalConstraintNode> +typename TPartOfConstraintNode<TOriginalConstraintNode>::TMapType +TPartOfConstraintNode<TOriginalConstraintNode>::GetColumnMapping(const std::string_view& asField) const { + auto mapping = Mapping_; + if (!asField.empty()) { + for (auto& item : mapping) { + for (auto& part : item.second) { + part.first.emplace_front(asField); + } + } + } + return mapping; +} + +template<class TOriginalConstraintNode> +typename TPartOfConstraintNode<TOriginalConstraintNode>::TMapType +TPartOfConstraintNode<TOriginalConstraintNode>::GetColumnMapping(TExprContext& ctx, const std::string_view& prefix) const { + auto mapping = Mapping_; + if (!prefix.empty()) { + const TString str(prefix); + for (auto& item : mapping) { + for (auto& part : item.second) { + if (part.first.empty()) + part.first.emplace_front(prefix); + else + part.first.front() = ctx.AppendString(str + part.first.front()); + } + } + } + return mapping; +} + +template<class TOriginalConstraintNode> +const TPartOfConstraintNode<TOriginalConstraintNode>* +TPartOfConstraintNode<TOriginalConstraintNode>::MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx) { + if (constraints.empty()) { + return nullptr; + } + + if (constraints.size() == 1) { + return constraints.front()->GetConstraint<TPartOfConstraintNode>(); + } + + bool first = true; + TMapType mapping; + for (size_t i = 0; i < constraints.size(); ++i) { + const auto part = constraints[i]->GetConstraint<TPartOfConstraintNode>(); + if (!part) + return nullptr; + if (first) { + mapping = part->GetColumnMapping(); + first = false; + } else { + for (const auto& nextMapping : part->GetColumnMapping()) { + if (const auto it = mapping.find(nextMapping.first); mapping.cend() != it) { + TPartType result; + std::set_intersection( + it->second.cbegin(), it->second.cend(), + nextMapping.second.cbegin(), nextMapping.second.cend(), + std::back_inserter(result), + [] (const typename TPartType::value_type& c1, const typename TPartType::value_type& c2) { + return c1 < c2; + } + ); + if (result.empty()) + mapping.erase(it); + else + it->second = std::move(result); + } + } + } + if (mapping.empty()) { + break; + } + } + + return mapping.empty() ? nullptr : ctx.MakeConstraint<TPartOfConstraintNode>(std::move(mapping)); +} + +template<class TOriginalConstraintNode> +const typename TPartOfConstraintNode<TOriginalConstraintNode>::TMapType& +TPartOfConstraintNode<TOriginalConstraintNode>::GetColumnMapping() const { + return Mapping_; +} + +template<class TOriginalConstraintNode> +typename TPartOfConstraintNode<TOriginalConstraintNode>::TMapType +TPartOfConstraintNode<TOriginalConstraintNode>::GetCommonMapping(const TOriginalConstraintNode* complete, const TPartOfConstraintNode* incomplete, const std::string_view& field) { + TMapType mapping; + if (incomplete) { + mapping = incomplete->GetColumnMapping(); + mapping.erase(complete); + if (!field.empty()) { + for (auto& part : mapping) { + std::for_each(part.second.begin(), part.second.end(), [&field](typename TPartType::value_type& map) { map.first.push_front(field); }); + } + } + } + + if (complete) { + auto& part = mapping[complete]; + for (const auto& path : complete->GetFullSet()) { + auto key = path; + if (!field.empty()) + key.emplace_front(field); + part.insert_unique(std::make_pair(key, path)); + } + } + + return mapping; +} + +template<class TOriginalConstraintNode> +void TPartOfConstraintNode<TOriginalConstraintNode>::UniqueMerge(TMapType& output, TMapType&& input) { + output.merge(input); + while (!input.empty()) { + const auto exists = input.extract(input.cbegin()); + auto& target = output[exists.key()]; + target.reserve(target.size() + exists.mapped().size()); + for (auto& item : exists.mapped()) + target.insert_unique(std::move(item)); + } +} + +template<class TOriginalConstraintNode> +typename TPartOfConstraintNode<TOriginalConstraintNode>::TMapType +TPartOfConstraintNode<TOriginalConstraintNode>::ExtractField(const TMapType& mapping, const std::string_view& field) { + TMapType parts; + for (const auto& part : mapping) { + auto it = part.second.lower_bound(TPartOfConstraintBase::TPathType(1U, field)); + if (part.second.cend() == it || it->first.empty() || it->first.front() != field) + continue; + + TPartType mapping; + mapping.reserve(part.second.size()); + while (it < part.second.cend() && !it->first.empty() && field == it->first.front()) { + auto item = *it++; + item.first.pop_front(); + mapping.emplace_back(std::move(item)); + } + + if (!mapping.empty()) { + parts.emplace(part.first, std::move(mapping)); + } + } + return parts; +} + +template<class TOriginalConstraintNode> +const TOriginalConstraintNode* +TPartOfConstraintNode<TOriginalConstraintNode>::MakeComplete(TExprContext& ctx, const TMapType& mapping, const TOriginalConstraintNode* original, const std::string_view& field) { + if (const auto it = mapping.find(original); mapping.cend() != it) { + TReversePartType reverseMap; + reverseMap.reserve(it->second.size()); + for (const auto& map : it->second) + reverseMap[map.second].insert_unique(map.first); + + const auto rename = [&](const TPartOfConstraintBase::TPathType& path) { + const auto& set = reverseMap[path]; + std::vector<TPartOfConstraintBase::TPathType> out(set.cbegin(), set.cend()); + if (!field.empty()) + std::for_each(out.begin(), out.end(), [&field](TPartOfConstraintBase::TPathType& path) { path.emplace_front(field); }); + return out; + }; + + return it->first->RenameFields(ctx, rename); + } + + return nullptr; +} + +template<class TOriginalConstraintNode> +const TOriginalConstraintNode* +TPartOfConstraintNode<TOriginalConstraintNode>::MakeComplete(TExprContext& ctx, const TPartOfConstraintNode* partial, const TOriginalConstraintNode* original, const std::string_view& field) { + if (!partial) + return nullptr; + + return MakeComplete(ctx, partial->GetColumnMapping(), original, field); +} + +template<class TOriginalConstraintNode> +bool TPartOfConstraintNode<TOriginalConstraintNode>::IsApplicableToType(const TTypeAnnotationNode& type) const { + if (ETypeAnnotationKind::Dict == type.GetKind()) + return true; // TODO: check for dict. + + const auto itemType = GetSeqItemType(&type); + const auto& actualType = itemType ? *itemType : type; + return std::all_of(Mapping_.cbegin(), Mapping_.cend(), [&actualType](const typename TMapType::value_type& pair) { + return std::all_of(pair.second.cbegin(), pair.second.cend(), [&actualType](const typename TPartType::value_type& part) { return bool(TPartOfConstraintBase::GetSubTypeByPath(part.first, actualType)); }); + }); +} + +template class TPartOfConstraintNode<TSortedConstraintNode>; +template class TPartOfConstraintNode<TChoppedConstraintNode>; +template class TPartOfConstraintNode<TUniqueConstraintNode>; +template class TPartOfConstraintNode<TDistinctConstraintNode>; + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TEmptyConstraintNode::TEmptyConstraintNode(TExprContext& ctx) + : TConstraintNode(ctx, Name()) +{ +} + +TEmptyConstraintNode::TEmptyConstraintNode(TEmptyConstraintNode&& constr) + : TConstraintNode(std::move(static_cast<TConstraintNode&>(constr))) +{ +} + +TEmptyConstraintNode::TEmptyConstraintNode(TExprContext& ctx, const NYT::TNode& serialized) + : TConstraintNode(ctx, Name()) +{ + YQL_ENSURE(serialized.IsEntity(), "Unexpected serialized content of " << Name() << " constraint"); +} + +bool TEmptyConstraintNode::Equals(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (GetHash() != node.GetHash()) { + return false; + } + return GetName() == node.GetName(); +} + +void TEmptyConstraintNode::ToJson(NJson::TJsonWriter& out) const { + out.Write(true); +} + +NYT::TNode TEmptyConstraintNode::ToYson() const { + return NYT::TNode::CreateEntity(); +} + +const TEmptyConstraintNode* TEmptyConstraintNode::MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& /*ctx*/) { + if (constraints.empty()) { + return nullptr; + } + + auto empty = constraints.front()->GetConstraint<TEmptyConstraintNode>(); + if (AllOf(constraints.cbegin() + 1, constraints.cend(), [empty](const TConstraintSet* c) { return c->GetConstraint<TEmptyConstraintNode>() == empty; })) { + return empty; + } + return nullptr; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TVarIndexConstraintNode::TVarIndexConstraintNode(TExprContext& ctx, const TMapType& mapping) + : TConstraintNode(ctx, Name()) + , Mapping_(mapping) +{ + Hash_ = MurmurHash<ui64>(Mapping_.data(), Mapping_.size() * sizeof(TMapType::value_type), Hash_); + YQL_ENSURE(!Mapping_.empty()); +} + +TVarIndexConstraintNode::TVarIndexConstraintNode(TExprContext& ctx, const TVariantExprType& itemType) + : TVarIndexConstraintNode(ctx, itemType.GetUnderlyingType()->Cast<TTupleExprType>()->GetSize()) +{ +} + +TVarIndexConstraintNode::TVarIndexConstraintNode(TExprContext& ctx, size_t mapItemsCount) + : TConstraintNode(ctx, Name()) +{ + YQL_ENSURE(mapItemsCount > 0); + for (size_t i = 0; i < mapItemsCount; ++i) { + Mapping_.push_back(std::make_pair(i, i)); + } + Hash_ = MurmurHash<ui64>(Mapping_.data(), Mapping_.size() * sizeof(TMapType::value_type), Hash_); + YQL_ENSURE(!Mapping_.empty()); +} + +TVarIndexConstraintNode::TVarIndexConstraintNode(TExprContext& ctx, const NYT::TNode& serialized) + : TVarIndexConstraintNode(ctx, NodeToMapping(serialized)) +{ +} + +TVarIndexConstraintNode::TVarIndexConstraintNode(TVarIndexConstraintNode&& constr) + : TConstraintNode(std::move(static_cast<TConstraintNode&>(constr))) + , Mapping_(std::move(constr.Mapping_)) +{ +} + +TVarIndexConstraintNode::TMapType TVarIndexConstraintNode::NodeToMapping(const NYT::TNode& serialized) { + TMapType mapping; + try { + for (const auto& pair: serialized.AsList()) { + mapping.insert(std::make_pair<ui32, ui32>(pair.AsList().front().AsUint64(), pair.AsList().back().AsUint64())); + } + } catch (...) { + YQL_ENSURE(false, "Cannot deserialize " << Name() << " constraint: " << CurrentExceptionMessage()); + } + return mapping; +} + +TVarIndexConstraintNode::TMapType TVarIndexConstraintNode::GetReverseMapping() const { + TMapType reverseMapping; + std::transform(Mapping_.cbegin(), Mapping_.cend(), + std::back_inserter(reverseMapping), + [] (const std::pair<size_t, size_t>& p) { return std::make_pair(p.second, p.first); } + ); + ::Sort(reverseMapping); + return reverseMapping; +} + +bool TVarIndexConstraintNode::Equals(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (GetHash() != node.GetHash()) { + return false; + } + if (GetName() != node.GetName()) { + return false; + } + if (auto c = dynamic_cast<const TVarIndexConstraintNode*>(&node)) { + return GetIndexMapping() == c->GetIndexMapping(); + } + return false; +} + +bool TVarIndexConstraintNode::Includes(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (GetName() != node.GetName()) { + return false; + } + if (auto c = dynamic_cast<const TVarIndexConstraintNode*>(&node)) { + for (auto& pair: c->Mapping_) { + if (auto p = Mapping_.FindPtr(pair.first)) { + if (*p != pair.second) { + return false; + } + } else { + return false; + } + } + return true; + } + return false; +} + +void TVarIndexConstraintNode::Out(IOutputStream& out) const { + TConstraintNode::Out(out); + out.Write('('); + + bool first = true; + for (auto& item: Mapping_) { + if (!first) { + out.Write(','); + } + out << item.first << ':' << item.second; + first = false; + } + out.Write(')'); +} + +void TVarIndexConstraintNode::ToJson(NJson::TJsonWriter& out) const { + out.OpenArray(); + for (const auto& [resultIndex, originalIndex]: Mapping_) { + out.OpenArray(); + out.Write(resultIndex); + out.Write(originalIndex); + out.CloseArray(); + } + out.CloseArray(); +} + +NYT::TNode TVarIndexConstraintNode::ToYson() const { + return std::accumulate(Mapping_.cbegin(), Mapping_.cend(), + NYT::TNode::CreateList(), + [](NYT::TNode node, const TMapType::value_type& p) { + return std::move(node).Add(NYT::TNode::CreateList().Add(p.first).Add(p.second)); + }); +} + +const TVarIndexConstraintNode* TVarIndexConstraintNode::MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx) { + if (constraints.empty()) { + return nullptr; + } + + if (constraints.size() == 1) { + return constraints.front()->GetConstraint<TVarIndexConstraintNode>(); + } + + TVarIndexConstraintNode::TMapType mapping; + for (size_t i = 0; i < constraints.size(); ++i) { + if (auto varIndex = constraints[i]->GetConstraint<TVarIndexConstraintNode>()) { + mapping.insert(varIndex->GetIndexMapping().begin(), varIndex->GetIndexMapping().end()); + } + } + if (mapping.empty()) { + return nullptr; + } + ::SortUnique(mapping); + return ctx.MakeConstraint<TVarIndexConstraintNode>(std::move(mapping)); +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TMultiConstraintNode::TMultiConstraintNode(TExprContext& ctx, TMapType&& items) + : TConstraintNode(ctx, Name()) + , Items_(std::move(items)) +{ + YQL_ENSURE(Items_.size()); + for (auto& item: Items_) { + Hash_ = MurmurHash<ui64>(&item.first, sizeof(item.first), Hash_); + for (auto c: item.second.GetAllConstraints()) { + const auto itemHash = c->GetHash(); + Hash_ = MurmurHash<ui64>(&itemHash, sizeof(itemHash), Hash_); + } + } +} + +TMultiConstraintNode::TMultiConstraintNode(TExprContext& ctx, ui32 index, const TConstraintSet& constraints) + : TMultiConstraintNode(ctx, TMapType{{index, constraints}}) +{ +} + +TMultiConstraintNode::TMultiConstraintNode(TExprContext& ctx, const NYT::TNode& serialized) + : TMultiConstraintNode(ctx, NodeToMapping(ctx, serialized)) +{ +} + +TMultiConstraintNode::TMultiConstraintNode(TMultiConstraintNode&& constr) + : TConstraintNode(std::move(static_cast<TConstraintNode&>(constr))) + , Items_(std::move(constr.Items_)) +{ +} + +TMultiConstraintNode::TMapType TMultiConstraintNode::NodeToMapping(TExprContext& ctx, const NYT::TNode& serialized) { + TMapType mapping; + try { + for (const auto& pair: serialized.AsList()) { + mapping.insert(std::make_pair((ui32)pair.AsList().front().AsUint64(), ctx.MakeConstraintSet(pair.AsList().back()))); + } + } catch (...) { + YQL_ENSURE(false, "Cannot deserialize " << Name() << " constraint: " << CurrentExceptionMessage()); + } + return mapping; +} + +bool TMultiConstraintNode::Equals(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (GetHash() != node.GetHash()) { + return false; + } + if (GetName() != node.GetName()) { + return false; + } + if (auto c = dynamic_cast<const TMultiConstraintNode*>(&node)) { + return GetItems() == c->GetItems(); + } + return false; +} + +bool TMultiConstraintNode::Includes(const TConstraintNode& node) const { + if (this == &node) { + return true; + } + if (GetName() != node.GetName()) { + return false; + } + + if (auto m = dynamic_cast<const TMultiConstraintNode*>(&node)) { + for (auto& item: Items_) { + const auto it = m->Items_.find(item.first); + if (it == m->Items_.end()) { + if (!item.second.GetConstraint<TEmptyConstraintNode>()) { + return false; + } + continue; + } + + for (auto c: it->second.GetAllConstraints()) { + auto cit = item.second.GetConstraint(c->GetName()); + if (!cit) { + return false; + } + if (!cit->Includes(*c)) { + return false; + } + } + } + return true; + } + return false; +} + +bool TMultiConstraintNode::FilteredIncludes(const TConstraintNode& node, const THashSet<TString>& blacklist) const { + if (this == &node) { + return true; + } + if (GetName() != node.GetName()) { + return false; + } + + if (auto m = dynamic_cast<const TMultiConstraintNode*>(&node)) { + for (auto& item: Items_) { + const auto it = m->Items_.find(item.first); + if (it == m->Items_.end()) { + if (!item.second.GetConstraint<TEmptyConstraintNode>()) { + return false; + } + continue; + } + + for (auto c: it->second.GetAllConstraints()) { + if (!blacklist.contains(c->GetName())) { + const auto cit = item.second.GetConstraint(c->GetName()); + if (!cit) { + return false; + } + if (!cit->Includes(*c)) { + return false; + } + } + } + } + return true; + } + return false; +} + +void TMultiConstraintNode::Out(IOutputStream& out) const { + TConstraintNode::Out(out); + out.Write('('); + bool first = true; + for (auto& item: Items_) { + if (!first) { + out.Write(','); + } + out << item.first << ':' << '{'; + bool firstConstr = true; + for (auto c: item.second.GetAllConstraints()) { + if (!firstConstr) { + out.Write(','); + } + out << *c; + firstConstr = false; + } + out << '}'; + first = false; + } + out.Write(')'); +} + +void TMultiConstraintNode::ToJson(NJson::TJsonWriter& out) const { + out.OpenMap(); + for (const auto& [index, constraintSet] : Items_) { + out.WriteKey(ToString(index)); + constraintSet.ToJson(out); + } + out.CloseMap(); +} + +NYT::TNode TMultiConstraintNode::ToYson() const { + return std::accumulate(Items_.cbegin(), Items_.cend(), + NYT::TNode::CreateList(), + [](NYT::TNode node, const TMapType::value_type& p) { + return std::move(node).Add(NYT::TNode::CreateList().Add(p.first).Add(p.second.ToYson())); + }); +} + +const TMultiConstraintNode* TMultiConstraintNode::MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx) { + if (constraints.empty()) { + return nullptr; + } else if (constraints.size() == 1) { + return constraints.front()->GetConstraint<TMultiConstraintNode>(); + } + + TMapType multiItems; + for (auto c: constraints) { + if (auto m = c->GetConstraint<TMultiConstraintNode>()) { + multiItems.insert(m->GetItems().begin(), m->GetItems().end()); + } else if (!c->GetConstraint<TEmptyConstraintNode>()) { + return nullptr; + } + } + if (multiItems.empty()) { + return nullptr; + } + + multiItems.sort(); + // Remove duplicates + // For duplicated items keep only Empty constraint + auto cur = multiItems.begin(); + while (cur != multiItems.end()) { + auto start = cur; + do { + ++cur; + } while (cur != multiItems.end() && start->first == cur->first); + + switch (std::distance(start, cur)) { + case 0: + break; + case 1: + if (start->second.GetConstraint<TEmptyConstraintNode>()) { + cur = multiItems.erase(start, cur); + } + break; + default: + { + std::vector<TMapType::value_type> nonEmpty; + std::copy_if(start, cur, std::back_inserter(nonEmpty), + [] (const TMapType::value_type& v) { + return !v.second.GetConstraint<TEmptyConstraintNode>(); + } + ); + start->second.Clear(); + if (nonEmpty.empty()) { + start->second.AddConstraint(ctx.MakeConstraint<TEmptyConstraintNode>()); + } else if (nonEmpty.size() == 1) { + start->second = nonEmpty.front().second; + } + cur = multiItems.erase(start + 1, cur); + } + } + } + if (!multiItems.empty()) { + return ctx.MakeConstraint<TMultiConstraintNode>(std::move(multiItems)); + } + + return nullptr; +} + +const TMultiConstraintNode* TMultiConstraintNode::FilterConstraints(TExprContext& ctx, const TConstraintSet::TPredicate& predicate) const { + auto items = Items_; + bool hasContent = false, hasChanges = false; + for (auto& item : items) { + hasChanges = hasChanges || item.second.FilterConstraints(predicate); + hasContent = hasContent || item.second; + } + + return hasContent ? hasChanges ? ctx.MakeConstraint<TMultiConstraintNode>(std::move(items)) : this : nullptr; +} + +} // namespace NYql + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +void Out<NYql::TPartOfConstraintBase::TPathType>(IOutputStream& out, const NYql::TPartOfConstraintBase::TPathType& path) { + if (path.empty()) + out.Write('/'); + else { + bool first = true; + for (const auto& c : path) { + if (first) + first = false; + else + out.Write('/'); + out.Write(c); + } + } +} + +template<> +void Out<NYql::TPartOfConstraintBase::TSetType>(IOutputStream& out, const NYql::TPartOfConstraintBase::TSetType& c) { + out.Write('{'); + bool first = true; + for (const auto& path : c) { + if (first) + first = false; + else + out.Write(','); + out << path; + } + out.Write('}'); +} + +template<> +void Out<NYql::TPartOfConstraintBase::TSetOfSetsType>(IOutputStream& out, const NYql::TPartOfConstraintBase::TSetOfSetsType& c) { + out.Write('{'); + bool first = true; + for (const auto& path : c) { + if (first) + first = false; + else + out.Write(','); + out << path; + } + out.Write('}'); +} + +template<> +void Out<NYql::TConstraintNode>(IOutputStream& out, const NYql::TConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TConstraintSet>(IOutputStream& out, const NYql::TConstraintSet& s) { + s.Out(out); +} + +template<> +void Out<NYql::TSortedConstraintNode>(IOutputStream& out, const NYql::TSortedConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TChoppedConstraintNode>(IOutputStream& out, const NYql::TChoppedConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TUniqueConstraintNode>(IOutputStream& out, const NYql::TUniqueConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TDistinctConstraintNode>(IOutputStream& out, const NYql::TDistinctConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TPartOfSortedConstraintNode>(IOutputStream& out, const NYql::TPartOfSortedConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TPartOfChoppedConstraintNode>(IOutputStream& out, const NYql::TPartOfChoppedConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TPartOfUniqueConstraintNode>(IOutputStream& out, const NYql::TPartOfUniqueConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TPartOfDistinctConstraintNode>(IOutputStream& out, const NYql::TPartOfDistinctConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TEmptyConstraintNode>(IOutputStream& out, const NYql::TEmptyConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TVarIndexConstraintNode>(IOutputStream& out, const NYql::TVarIndexConstraintNode& c) { + c.Out(out); +} + +template<> +void Out<NYql::TMultiConstraintNode>(IOutputStream& out, const NYql::TMultiConstraintNode& c) { + c.Out(out); +} diff --git a/yql/essentials/ast/yql_constraint.h b/yql/essentials/ast/yql_constraint.h new file mode 100644 index 00000000000..191d3195079 --- /dev/null +++ b/yql/essentials/ast/yql_constraint.h @@ -0,0 +1,563 @@ +#pragma once + +#include <yql/essentials/utils/yql_panic.h> + +#include <library/cpp/containers/stack_vector/stack_vec.h> +#include <library/cpp/containers/sorted_vector/sorted_vector.h> +#include <library/cpp/json/json_writer.h> +#include <library/cpp/yson/node/node_builder.h> + +#include <util/stream/output.h> + +#include <deque> +#include <unordered_map> + +namespace NYql { + +struct TExprContext; +class TTypeAnnotationNode; +class TStructExprType; +class TVariantExprType; + +class TConstraintNode { +protected: + TConstraintNode(TExprContext& ctx, std::string_view name); + TConstraintNode(TConstraintNode&& constr); +public: + using TListType = std::vector<const TConstraintNode*>; + + struct THash { + size_t operator()(const TConstraintNode* node) const { + return node->GetHash(); + } + }; + + struct TEqual { + bool operator()(const TConstraintNode* one, const TConstraintNode* two) const { + return one->Equals(*two); + } + }; + + struct TCompare { + inline bool operator()(const TConstraintNode* l, const TConstraintNode* r) const { + return l->GetName() < r->GetName(); + } + + inline bool operator()(const std::string_view name, const TConstraintNode* r) const { + return name < r->GetName(); + } + + inline bool operator()(const TConstraintNode* l, const std::string_view name) const { + return l->GetName() < name; + } + }; + + virtual ~TConstraintNode() = default; + + ui64 GetHash() const { + return Hash_; + } + + virtual bool Equals(const TConstraintNode& node) const = 0; + virtual bool Includes(const TConstraintNode& node) const { + return Equals(node); + } + virtual void Out(IOutputStream& out) const; + virtual void ToJson(NJson::TJsonWriter& out) const = 0; + virtual NYT::TNode ToYson() const = 0; + + virtual bool IsApplicableToType(const TTypeAnnotationNode&) const { return true; } + + template <typename T> + const T* Cast() const { + static_assert(std::is_base_of<TConstraintNode, T>::value, + "Should be derived from TConstraintNode"); + + const auto ret = dynamic_cast<const T*>(this); + YQL_ENSURE(ret, "Cannot cast '" << Name_ << "' constraint to " << T::Name()); + return ret; + } + + const std::string_view& GetName() const { + return Name_; + } +protected: + ui64 Hash_; + std::string_view Name_; +}; + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +class TPartOfConstraintBase : public TConstraintNode { +protected: + TPartOfConstraintBase(TExprContext& ctx, std::string_view name); + TPartOfConstraintBase(TPartOfConstraintBase&& constr) = default; +public: + // Path to constraint components through nested static containers (Struct/Tuple/Multi). + // All elements is struct member name or tuple element index. + // Empty deque means root. + using TPathType = std::deque<std::string_view>; + + using TSetType = NSorted::TSimpleSet<TPathType>; + using TSetOfSetsType = NSorted::TSimpleSet<TSetType>; + + using TPathFilter = std::function<bool(const TPathType&)>; + using TPathReduce = std::function<std::vector<TPathType>(const TPathType&)>; + + static const TTypeAnnotationNode* GetSubTypeByPath(const TPathType& path, const TTypeAnnotationNode& type); + + static NYT::TNode PathToNode(const TPathType& path); + static NYT::TNode SetToNode(const TSetType& set, bool withShortcut); + static NYT::TNode SetOfSetsToNode(const TSetOfSetsType& sets); + static TPathType NodeToPath(TExprContext& ctx, const NYT::TNode& node); + static TSetType NodeToSet(TExprContext& ctx, const NYT::TNode& node); + static TSetOfSetsType NodeToSetOfSets(TExprContext& ctx, const NYT::TNode& node); + +protected: + virtual const TPartOfConstraintBase* DoFilterFields(TExprContext& ctx, const TPathFilter& predicate) const = 0; + virtual const TPartOfConstraintBase* DoRenameFields(TExprContext& ctx, const TPathReduce& reduce) const = 0; + + static bool HasDuplicates(const TSetOfSetsType& sets); +}; + +class TConstraintWithFieldsNode : public TPartOfConstraintBase { +protected: + TConstraintWithFieldsNode(TExprContext& ctx, std::string_view name); + TConstraintWithFieldsNode(TConstraintWithFieldsNode&& constr) = default; + + // Split fields with static containers (Struct/Tuple/Multi) on separeted list of all components. + // As example (/tuple_of_two_elements) -> (/tuple_of_two_elements/0,/tuple_of_two_elements/1) + virtual const TConstraintWithFieldsNode* DoGetComplicatedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const = 0; + + // Combine list of separeted fields of static containers (Struct/Tuple/Multi) in single path if possible. + // As example (/tuple_of_two_elements/0,/tuple_of_two_elements/1) -> (/tuple_of_two_elements) + virtual const TConstraintWithFieldsNode* DoGetSimplifiedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const = 0; +public: + // Leaves in the set of references only those that currently contain some complete constraint. Basically all or nothing. + virtual void FilterUncompleteReferences(TPartOfConstraintBase::TSetType& references) const = 0; +}; + +template<class TTnheritConstraint> +class TPartOfConstraintBaseT : public TPartOfConstraintBase { +protected: + TPartOfConstraintBaseT(TExprContext& ctx, std::string_view name) + : TPartOfConstraintBase(ctx, name) + {} + TPartOfConstraintBaseT(TPartOfConstraintBaseT&& constr) = default; +public: + const TTnheritConstraint* FilterFields(TExprContext& ctx, const TPathFilter& predicate) const { + return static_cast<const TTnheritConstraint*>(DoFilterFields(ctx, predicate)); + } + + const TTnheritConstraint* RenameFields(TExprContext& ctx, const TPathReduce& reduce) const { + return static_cast<const TTnheritConstraint*>(DoRenameFields(ctx, reduce)); + } +}; + +template<class TTnheritConstraint> +class TConstraintWithFieldsT : public TConstraintWithFieldsNode { +protected: + TConstraintWithFieldsT(TExprContext& ctx, std::string_view name) + : TConstraintWithFieldsNode(ctx, name) + {} + TConstraintWithFieldsT(TConstraintWithFieldsT&& constr) = default; +public: + const TTnheritConstraint* FilterFields(TExprContext& ctx, const TPathFilter& predicate) const { + return static_cast<const TTnheritConstraint*>(DoFilterFields(ctx, predicate)); + } + + const TTnheritConstraint* RenameFields(TExprContext& ctx, const TPathReduce& reduce) const { + return static_cast<const TTnheritConstraint*>(DoRenameFields(ctx, reduce)); + } + + const TTnheritConstraint* GetComplicatedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const { + return static_cast<const TTnheritConstraint*>(DoGetComplicatedForType(type, ctx)); + } + + const TTnheritConstraint* GetSimplifiedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const { + return static_cast<const TTnheritConstraint*>(DoGetSimplifiedForType(type, ctx)); + } +}; + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +class TConstraintSet { +public: + TConstraintSet() = default; + TConstraintSet(const TConstraintSet&) = default; + TConstraintSet(TConstraintSet&&) = default; + + TConstraintSet& operator =(const TConstraintSet&) = default; + TConstraintSet& operator =(TConstraintSet&&) = default; + + template <class TConstraintType> + const TConstraintType* GetConstraint() const { + auto res = GetConstraint(TConstraintType::Name()); + return res ? res->template Cast<TConstraintType>() : nullptr; + } + + template <class TConstraintType> + const TConstraintType* RemoveConstraint() { + auto res = RemoveConstraint(TConstraintType::Name()); + return res ? res->template Cast<TConstraintType>() : nullptr; + } + + const TConstraintNode::TListType& GetAllConstraints() const { + return Constraints_; + } + + void Clear() { + Constraints_.clear(); + } + + explicit operator bool() const { + return !Constraints_.empty(); + } + + bool operator ==(const TConstraintSet& s) const { + return Constraints_ == s.Constraints_; + } + + bool operator !=(const TConstraintSet& s) const { + return Constraints_ != s.Constraints_; + } + + const TConstraintNode* GetConstraint(std::string_view name) const; + void AddConstraint(const TConstraintNode* node); + const TConstraintNode* RemoveConstraint(std::string_view name); + + using TPredicate = std::function<bool(const std::string_view& name)>; + bool FilterConstraints(const TPredicate& predicate); + + void Out(IOutputStream& out) const; + void ToJson(NJson::TJsonWriter& writer) const; + NYT::TNode ToYson() const; +private: + TConstraintNode::TListType Constraints_; +}; + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +class TSortedConstraintNode final: public TConstraintWithFieldsT<TSortedConstraintNode> { +public: + using TContainerType = TSmallVec<std::pair<TSetType, bool>>; +private: + friend struct TExprContext; + + TSortedConstraintNode(TExprContext& ctx, TContainerType&& content); + TSortedConstraintNode(TExprContext& ctx, const NYT::TNode& serialized); + TSortedConstraintNode(TSortedConstraintNode&& constr); +public: + static constexpr std::string_view Name() { + return "Sorted"; + } + + const TContainerType& GetContent() const { + return Content_; + } + + TSetType GetFullSet() const; + + bool Equals(const TConstraintNode& node) const override; + bool Includes(const TConstraintNode& node) const override; + void Out(IOutputStream& out) const override; + void ToJson(NJson::TJsonWriter& out) const override; + NYT::TNode ToYson() const override; + + bool IsPrefixOf(const TSortedConstraintNode& node) const; + bool StartsWith(const TSetType& prefix) const; + + const TSortedConstraintNode* CutPrefix(size_t newPrefixLength, TExprContext& ctx) const; + + void FilterUncompleteReferences(TSetType& references) const final; + + static const TSortedConstraintNode* MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx); + const TSortedConstraintNode* MakeCommon(const TSortedConstraintNode* other, TExprContext& ctx) const; + + bool IsApplicableToType(const TTypeAnnotationNode& type) const override; +private: + const TConstraintWithFieldsNode* DoFilterFields(TExprContext& ctx, const TPathFilter& predicate) const final; + const TConstraintWithFieldsNode* DoRenameFields(TExprContext& ctx, const TPathReduce& reduce) const final; + + const TConstraintWithFieldsNode* DoGetComplicatedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const final; + const TConstraintWithFieldsNode* DoGetSimplifiedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const final; + + static TContainerType NodeToContainer(TExprContext& ctx, const NYT::TNode& serialized); + + TContainerType Content_; +}; + +class TChoppedConstraintNode final: public TConstraintWithFieldsT<TChoppedConstraintNode> { +private: + friend struct TExprContext; + + TChoppedConstraintNode(TExprContext& ctx, TSetOfSetsType&& sets); + TChoppedConstraintNode(TExprContext& ctx, const TSetType& keys); + TChoppedConstraintNode(TExprContext& ctx, const NYT::TNode& serialized); + TChoppedConstraintNode(TChoppedConstraintNode&& constr); +public: + static constexpr std::string_view Name() { + return "Chopped"; + } + + const TSetOfSetsType& GetContent() const { return Sets_; } + + TSetType GetFullSet() const; + + bool Equals(const TConstraintNode& node) const override; + bool Includes(const TConstraintNode& node) const override; + void Out(IOutputStream& out) const override; + void ToJson(NJson::TJsonWriter& out) const override; + NYT::TNode ToYson() const override; + + bool Equals(const TSetType& prefix) const; + + void FilterUncompleteReferences(TSetType& references) const final; + + static const TChoppedConstraintNode* MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx); + const TChoppedConstraintNode* MakeCommon(const TChoppedConstraintNode* other, TExprContext& ctx) const; + + bool IsApplicableToType(const TTypeAnnotationNode& type) const override; +private: + const TConstraintWithFieldsNode* DoFilterFields(TExprContext& ctx, const TPathFilter& predicate) const final; + const TConstraintWithFieldsNode* DoRenameFields(TExprContext& ctx, const TPathReduce& reduce) const final; + + const TConstraintWithFieldsNode* DoGetComplicatedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const final; + const TConstraintWithFieldsNode* DoGetSimplifiedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const final; + + static TSetOfSetsType NodeToSets(TExprContext& ctx, const NYT::TNode& serialized); + + TSetOfSetsType Sets_; +}; + +template<bool Distinct> +class TUniqueConstraintNodeBase final: public TConstraintWithFieldsT<TUniqueConstraintNodeBase<Distinct>> { +public: + using TBase = TConstraintWithFieldsT<TUniqueConstraintNodeBase<Distinct>>; + using TContentType = NSorted::TSimpleSet<TConstraintWithFieldsNode::TSetOfSetsType>; +protected: + friend struct TExprContext; + + TUniqueConstraintNodeBase(TExprContext& ctx, const std::vector<std::string_view>& columns); + TUniqueConstraintNodeBase(TExprContext& ctx, TContentType&& sets); + TUniqueConstraintNodeBase(TExprContext& ctx, const NYT::TNode& serialized); + TUniqueConstraintNodeBase(TUniqueConstraintNodeBase&& constr); +public: + static constexpr std::string_view Name() { + return Distinct ? "Distinct" : "Unique"; + } + + const TContentType& GetContent() const { return Content_; } + + TPartOfConstraintBase::TSetType GetFullSet() const; + + bool Equals(const TConstraintNode& node) const override; + bool Includes(const TConstraintNode& node) const override; + void Out(IOutputStream& out) const override; + void ToJson(NJson::TJsonWriter& out) const override; + NYT::TNode ToYson() const override; + + bool IsOrderBy(const TSortedConstraintNode& sorted) const; + bool ContainsCompleteSet(const std::vector<std::string_view>& columns) const; + + void FilterUncompleteReferences(TPartOfConstraintBase::TSetType& references) const final; + + static const TUniqueConstraintNodeBase* MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx); + + const TUniqueConstraintNodeBase* MakeCommon(const TUniqueConstraintNodeBase* other, TExprContext& ctx) const; + + static const TUniqueConstraintNodeBase* Merge(const TUniqueConstraintNodeBase* one, const TUniqueConstraintNodeBase* two, TExprContext& ctx); + + bool IsApplicableToType(const TTypeAnnotationNode& type) const override; +private: + const TConstraintWithFieldsNode* DoFilterFields(TExprContext& ctx, const TPartOfConstraintBase::TPathFilter& predicate) const final; + const TConstraintWithFieldsNode* DoRenameFields(TExprContext& ctx, const TPartOfConstraintBase::TPathReduce& reduce) const final; + + const TConstraintWithFieldsNode* DoGetComplicatedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const final; + const TConstraintWithFieldsNode* DoGetSimplifiedForType(const TTypeAnnotationNode& type, TExprContext& ctx) const final; + + static TConstraintWithFieldsNode::TSetOfSetsType ColumnsListToSets(const std::vector<std::string_view>& columns); + static TContentType DedupSets(TContentType&& sets); + static TContentType MakeCommonContent(const TContentType& one, const TContentType& two); + + static TContentType NodeToContent(TExprContext& ctx, const NYT::TNode& serialized); + + TContentType Content_; +}; + +using TUniqueConstraintNode = TUniqueConstraintNodeBase<false>; +using TDistinctConstraintNode = TUniqueConstraintNodeBase<true>; + +template<class TOriginalConstraintNode> +class TPartOfConstraintNode final: public TPartOfConstraintBaseT<TPartOfConstraintNode<TOriginalConstraintNode>> { +public: + using TBase = TPartOfConstraintBaseT<TPartOfConstraintNode<TOriginalConstraintNode>>; + using TMainConstraint = TOriginalConstraintNode; + using TPartType = NSorted::TSimpleMap<typename TBase::TPathType, typename TBase::TPathType>; + using TReversePartType = NSorted::TSimpleMap<typename TBase::TPathType, NSorted::TSimpleSet<typename TBase::TPathType>>; + using TMapType = std::unordered_map<const TMainConstraint*, TPartType>; +private: + friend struct TExprContext; + + TPartOfConstraintNode(TPartOfConstraintNode&& constr); + TPartOfConstraintNode(TExprContext& ctx, TMapType&& mapping); + TPartOfConstraintNode(TExprContext& ctx, const NYT::TNode& serialized); +public: + static constexpr std::string_view Name(); + + const TMapType& GetColumnMapping() const; + TMapType GetColumnMapping(const std::string_view& asField) const; + TMapType GetColumnMapping(TExprContext& ctx, const std::string_view& prefix) const; + + bool Equals(const TConstraintNode& node) const override; + bool Includes(const TConstraintNode& node) const override; + void Out(IOutputStream& out) const override; + void ToJson(NJson::TJsonWriter& out) const override; + NYT::TNode ToYson() const override; + + const TPartOfConstraintNode* ExtractField(TExprContext& ctx, const std::string_view& field) const; + const TPartOfConstraintNode* CompleteOnly(TExprContext& ctx) const; + const TPartOfConstraintNode* RemoveOriginal(TExprContext& ctx, const TMainConstraint* original) const; + + static const TPartOfConstraintNode* MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx); + + static TMapType GetCommonMapping(const TMainConstraint* complete, const TPartOfConstraintNode* incomplete = nullptr, const std::string_view& field = {}); + static void UniqueMerge(TMapType& output, TMapType&& input); + static TMapType ExtractField(const TMapType& mapping, const std::string_view& field); + + static const TMainConstraint* MakeComplete(TExprContext& ctx, const TMapType& mapping, const TMainConstraint* original, const std::string_view& field = {}); + static const TMainConstraint* MakeComplete(TExprContext& ctx, const TPartOfConstraintNode* partial, const TMainConstraint* original, const std::string_view& field = {}); + + bool IsApplicableToType(const TTypeAnnotationNode& type) const override; +private: + const TPartOfConstraintBase* DoFilterFields(TExprContext& ctx, const TPartOfConstraintBase::TPathFilter& predicate) const final; + const TPartOfConstraintBase* DoRenameFields(TExprContext& ctx, const TPartOfConstraintBase::TPathReduce& reduce) const final; + + TMapType Mapping_; +}; + +using TPartOfSortedConstraintNode = TPartOfConstraintNode<TSortedConstraintNode>; +using TPartOfChoppedConstraintNode = TPartOfConstraintNode<TChoppedConstraintNode>; +using TPartOfUniqueConstraintNode = TPartOfConstraintNode<TUniqueConstraintNode>; +using TPartOfDistinctConstraintNode = TPartOfConstraintNode<TDistinctConstraintNode>; + +template<> +constexpr std::string_view TPartOfSortedConstraintNode::Name() { + return "PartOfSorted"; +} + +template<> +constexpr std::string_view TPartOfChoppedConstraintNode::Name() { + return "PartOfChopped"; +} + +template<> +constexpr std::string_view TPartOfUniqueConstraintNode::Name() { + return "PartOfUnique"; +} + +template<> +constexpr std::string_view TPartOfDistinctConstraintNode::Name() { + return "PartOfDistinct"; +} + +class TEmptyConstraintNode final: public TConstraintNode { +protected: + friend struct TExprContext; + + TEmptyConstraintNode(TExprContext& ctx); + TEmptyConstraintNode(TExprContext& ctx, const NYT::TNode& serialized); + TEmptyConstraintNode(TEmptyConstraintNode&& constr); + +public: + static constexpr std::string_view Name() { + return "Empty"; + } + + bool Equals(const TConstraintNode& node) const override; + void ToJson(NJson::TJsonWriter& out) const override; + NYT::TNode ToYson() const override; + + static const TEmptyConstraintNode* MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx); +}; + +class TVarIndexConstraintNode final: public TConstraintNode { +public: + using TMapType = NSorted::TSimpleMap<ui32, ui32>; + +protected: + friend struct TExprContext; + + TVarIndexConstraintNode(TExprContext& ctx, const TMapType& mapping); + TVarIndexConstraintNode(TExprContext& ctx, const TVariantExprType& itemType); + TVarIndexConstraintNode(TExprContext& ctx, size_t mapItemsCount); + TVarIndexConstraintNode(TExprContext& ctx, const NYT::TNode& serialized); + TVarIndexConstraintNode(TVarIndexConstraintNode&& constr); +public: + static constexpr std::string_view Name() { + return "VarIndex"; + } + + // multimap: result index -> {original indices} + const TMapType& GetIndexMapping() const { + return Mapping_; + } + + // original index -> {result indices} + TMapType GetReverseMapping() const; + + bool Equals(const TConstraintNode& node) const override; + bool Includes(const TConstraintNode& node) const override; + void Out(IOutputStream& out) const override; + void ToJson(NJson::TJsonWriter& out) const override; + NYT::TNode ToYson() const override; + + static const TVarIndexConstraintNode* MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx); + +private: + static TMapType NodeToMapping(const NYT::TNode& serialized); + + TMapType Mapping_; +}; + +class TMultiConstraintNode final: public TConstraintNode { +public: + using TMapType = NSorted::TSimpleMap<ui32, TConstraintSet>; +public: + TMultiConstraintNode(TExprContext& ctx, TMapType&& items); + TMultiConstraintNode(TExprContext& ctx, ui32 index, const TConstraintSet& constraints); + TMultiConstraintNode(TExprContext& ctx, const NYT::TNode& serialized); + TMultiConstraintNode(TMultiConstraintNode&& constr); +public: + static constexpr std::string_view Name() { + return "Multi"; + } + + const TMapType& GetItems() const { + return Items_; + } + + const TConstraintSet* GetItem(ui32 index) const { + return Items_.FindPtr(index); + } + + bool Equals(const TConstraintNode& node) const override; + bool Includes(const TConstraintNode& node) const override; + void Out(IOutputStream& out) const override; + void ToJson(NJson::TJsonWriter& out) const override; + NYT::TNode ToYson() const override; + + static const TMultiConstraintNode* MakeCommon(const std::vector<const TConstraintSet*>& constraints, TExprContext& ctx); + + const TMultiConstraintNode* FilterConstraints(TExprContext& ctx, const TConstraintSet::TPredicate& predicate) const; + + bool FilteredIncludes(const TConstraintNode& node, const THashSet<TString>& blacklist) const; + +private: + static TMapType NodeToMapping(TExprContext& ctx, const NYT::TNode& serialized); + + TMapType Items_; +}; + +} // namespace NYql + diff --git a/yql/essentials/ast/yql_constraint_ut.cpp b/yql/essentials/ast/yql_constraint_ut.cpp new file mode 100644 index 00000000000..6c305d06222 --- /dev/null +++ b/yql/essentials/ast/yql_constraint_ut.cpp @@ -0,0 +1,211 @@ +#include "yql_constraint.h" +#include "yql_expr.h" + +#include <yql/essentials/utils/yql_panic.h> + +#include <library/cpp/testing/unittest/registar.h> +#include <library/cpp/yson/node/node_io.h> + + +namespace NYql { + +Y_UNIT_TEST_SUITE(TSerializeConstrains) { + + Y_UNIT_TEST(SerializeSorted) { + TExprContext ctx; + auto c = ctx.MakeConstraint<TSortedConstraintNode>(TSortedConstraintNode::TContainerType{ + std::pair{TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a", "b"}}, true}, + std::pair{TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"d"}}, false}, + std::pair{TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"e"}, TPartOfConstraintBase::TPathType{"f", "g"}}, false}, + }); + auto yson = c->ToYson(); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), R"([[[["a";"b"]];%true];[["d"];%false];[["e";["f";"g"]];%false]])"); + auto c2 = ctx.MakeConstraint<TSortedConstraintNode>(yson); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), NYT::NodeToCanonicalYsonString(c2->ToYson())); + UNIT_ASSERT_EQUAL(c, c2); + } + + Y_UNIT_TEST(SerializeChopped) { + TExprContext ctx; + auto c = ctx.MakeConstraint<TChoppedConstraintNode>(TPartOfConstraintBase::TSetOfSetsType{ + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a"}}, + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a", "b"}}, + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"c", "d"}, TPartOfConstraintBase::TPathType{"e"}}, + }); + auto yson = c->ToYson(); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), R"(["a";[["a";"b"]];[["c";"d"];"e"]])"); + auto c2 = ctx.MakeConstraint<TChoppedConstraintNode>(yson); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), NYT::NodeToCanonicalYsonString(c2->ToYson())); + UNIT_ASSERT_EQUAL(c, c2); + } + + template <class TUniqueConstraint> + void CheckSerializeUnique() { + TExprContext ctx; + auto c = ctx.MakeConstraint<TUniqueConstraint>(typename TUniqueConstraint::TContentType{ + TConstraintWithFieldsNode::TSetOfSetsType{ + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a"}}, + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a", "b"}} + }, + TConstraintWithFieldsNode::TSetOfSetsType{ + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"c"}}, + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"d"}, TPartOfConstraintBase::TPathType{"e"}} + }, + }); + auto yson = c->ToYson(); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), R"([["a";[["a";"b"]]];["c";["d";"e"]]])"); + auto c2 = ctx.MakeConstraint<TUniqueConstraint>(yson); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), NYT::NodeToCanonicalYsonString(c2->ToYson())); + UNIT_ASSERT_EQUAL(c, c2); + } + + Y_UNIT_TEST(SerializeUnique) { + CheckSerializeUnique<TUniqueConstraintNode>(); + } + + Y_UNIT_TEST(SerializeDistint) { + CheckSerializeUnique<TDistinctConstraintNode>(); + } + + Y_UNIT_TEST(SerializeEmpty) { + TExprContext ctx; + auto c = ctx.MakeConstraint<TEmptyConstraintNode>(); + auto yson = c->ToYson(); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), R"(#)"); + auto c2 = ctx.MakeConstraint<TEmptyConstraintNode>(yson); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), NYT::NodeToCanonicalYsonString(c2->ToYson())); + UNIT_ASSERT_EQUAL(c, c2); + } + + Y_UNIT_TEST(SerializeVarIndex) { + TExprContext ctx; + auto c = ctx.MakeConstraint<TVarIndexConstraintNode>(TVarIndexConstraintNode::TMapType{ + std::pair{1u, 3u}, + std::pair{0u, 1u}, + }); + auto yson = c->ToYson(); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), R"([[0u;1u];[1u;3u]])"); + auto c2 = ctx.MakeConstraint<TVarIndexConstraintNode>(yson); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), NYT::NodeToCanonicalYsonString(c2->ToYson())); + UNIT_ASSERT_EQUAL(c, c2); + } + + Y_UNIT_TEST(SerializeMulti) { + TExprContext ctx; + + TConstraintSet s1; + s1.AddConstraint(ctx.MakeConstraint<TEmptyConstraintNode>()); + s1.AddConstraint( + ctx.MakeConstraint<TSortedConstraintNode>(TSortedConstraintNode::TContainerType{ + std::pair{TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a"}}, true}, + std::pair{TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"b"}}, false}, + }) + ); + + TConstraintSet s2; + s2.AddConstraint( + ctx.MakeConstraint<TUniqueConstraintNode>(typename TUniqueConstraintNode::TContentType{ + TConstraintWithFieldsNode::TSetOfSetsType{ + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a"}}, + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"b"}} + }, + TConstraintWithFieldsNode::TSetOfSetsType{ + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"c"}}, + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"d"}, TPartOfConstraintBase::TPathType{"e"}} + }, + }) + ); + s2.AddConstraint( + ctx.MakeConstraint<TVarIndexConstraintNode>(TVarIndexConstraintNode::TMapType{ + std::pair{0u, 1u}, + std::pair{1u, 2u}, + }) + ); + + auto c = ctx.MakeConstraint<TMultiConstraintNode>(TMultiConstraintNode::TMapType{ + std::pair{0u, s1}, + std::pair{1u, s2}, + }); + auto yson = c->ToYson(); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), R"([[0u;{"Empty"=#;"Sorted"=[[["a"];%true];[["b"];%false]]}];[1u;{"Unique"=[["a";"b"];["c";["d";"e"]]];"VarIndex"=[[0u;1u];[1u;2u]]}]])"); + auto c2 = ctx.MakeConstraint<TMultiConstraintNode>(yson); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), NYT::NodeToCanonicalYsonString(c2->ToYson())); + UNIT_ASSERT_EQUAL(c, c2); + } + + Y_UNIT_TEST(SerializeConstrainSet) { + TExprContext ctx; + + TConstraintSet s; + s.AddConstraint(ctx.MakeConstraint<TEmptyConstraintNode>()); + s.AddConstraint( + ctx.MakeConstraint<TSortedConstraintNode>(TSortedConstraintNode::TContainerType{ + std::pair{TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a"}}, true}, + std::pair{TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"b"}}, false}, + }) + ); + s.AddConstraint( + ctx.MakeConstraint<TUniqueConstraintNode>(typename TUniqueConstraintNode::TContentType{ + TConstraintWithFieldsNode::TSetOfSetsType{ + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"a"}}, + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"b"}} + }, + TConstraintWithFieldsNode::TSetOfSetsType{ + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"c"}}, + TPartOfConstraintBase::TSetType{TPartOfConstraintBase::TPathType{"d"}, TPartOfConstraintBase::TPathType{"e"}} + }, + }) + ); + s.AddConstraint( + ctx.MakeConstraint<TVarIndexConstraintNode>(TVarIndexConstraintNode::TMapType{ + std::pair{0u, 1u}, + std::pair{1u, 2u}, + }) + ); + s.AddConstraint( + ctx.MakeConstraint<TVarIndexConstraintNode>(TVarIndexConstraintNode::TMapType{ + std::pair{0u, 1u}, + std::pair{1u, 2u}, + }) + ); + + TConstraintSet inner; + inner.AddConstraint(ctx.MakeConstraint<TEmptyConstraintNode>()); + s.AddConstraint( + ctx.MakeConstraint<TMultiConstraintNode>(TMultiConstraintNode::TMapType{ + std::pair{0u, inner}, + }) + ); + + auto yson = s.ToYson(); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), R"({"Empty"=#;"Multi"=[[0u;{"Empty"=#}]];"Sorted"=[[["a"];%true];[["b"];%false]];"Unique"=[["a";"b"];["c";["d";"e"]]];"VarIndex"=[[0u;1u];[1u;2u]]})"); + auto s2 = ctx.MakeConstraintSet(yson); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), NYT::NodeToCanonicalYsonString(s2.ToYson())); + UNIT_ASSERT_EQUAL(s, s2); + + s.Clear(); + yson = s.ToYson(); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), R"({})"); + auto s3 = ctx.MakeConstraintSet(yson); + UNIT_ASSERT_VALUES_EQUAL(NYT::NodeToCanonicalYsonString(yson), NYT::NodeToCanonicalYsonString(s3.ToYson())); + UNIT_ASSERT_EQUAL(s, s3); + } + + Y_UNIT_TEST(DeserializeBadConstrainSet) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION( + ctx.MakeConstraintSet(NYT::NodeFromYsonString(R"(#)")), + NYql::TYqlPanic + ); + UNIT_ASSERT_EXCEPTION( + ctx.MakeConstraintSet(NYT::NodeFromYsonString(R"({"Unknown"=[]})")), + NYql::TYqlPanic + ); + UNIT_ASSERT_EXCEPTION( + ctx.MakeConstraintSet(NYT::NodeFromYsonString(R"({"Empty"=1u})")), + NYql::TYqlPanic + ); + } +} + +} // namespace NYql diff --git a/yql/essentials/ast/yql_errors.cpp b/yql/essentials/ast/yql_errors.cpp new file mode 100644 index 00000000000..c16c3650e16 --- /dev/null +++ b/yql/essentials/ast/yql_errors.cpp @@ -0,0 +1 @@ +#include "yql_errors.h" diff --git a/yql/essentials/ast/yql_errors.h b/yql/essentials/ast/yql_errors.h new file mode 100644 index 00000000000..325ea0e55fc --- /dev/null +++ b/yql/essentials/ast/yql_errors.h @@ -0,0 +1,3 @@ +#pragma once + +#include <yql/essentials/public/issue/yql_issue.h> diff --git a/yql/essentials/ast/yql_expr.cpp b/yql/essentials/ast/yql_expr.cpp new file mode 100644 index 00000000000..08a10e46bbb --- /dev/null +++ b/yql/essentials/ast/yql_expr.cpp @@ -0,0 +1,3917 @@ +#include "yql_expr.h" +#include "yql_ast_annotation.h" +#include "yql_gc_nodes.h" + +#include <yql/essentials/utils/utf8.h> +#include <yql/essentials/utils/fetch/fetch.h> +#include <yql/essentials/core/issue/yql_issue.h> + +#include <yql/essentials/parser/pg_catalog/catalog.h> +#include <library/cpp/containers/stack_vector/stack_vec.h> +#include <util/generic/hash.h> +#include <util/generic/size_literals.h> +#include <util/string/cast.h> +#include <util/string/join.h> + +#include <util/digest/fnv.h> +#include <util/digest/murmur.h> +#include <util/digest/city.h> +#include <util/digest/numeric.h> +#include <util/string/cast.h> + +#include <openssl/sha.h> + +#include <map> +#include <unordered_set> + +namespace NYql { + +const TStringBuf ZeroString = ""; +const char Dot = '.'; +const char Sep = '/'; +const TStringBuf PkgPrefix = "pkg"; + +void ReportError(TExprContext& ctx, const TIssue& issue) { + ctx.AddError(issue); +} + +namespace { + template <typename T> + const T* FindType(const T& sample, TExprContext& ctx) { + const auto it = ctx.TypeSet.find(&sample); + return ctx.TypeSet.cend() != it ? static_cast<const T*>(*it) : nullptr; + } + + template <typename T, typename... Args> + const T* AddType(TExprContext& ctx, Args&&... args) { + Y_DEBUG_ABORT_UNLESS(!ctx.Frozen); + ctx.TypeNodes.emplace(new T(std::forward<Args>(args)...)); + const auto ins = ctx.TypeSet.emplace(ctx.TypeNodes.top().get()); + return static_cast<const T*>(*ins.first); + } + + void DumpNode(const TExprNode& node, IOutputStream& out, ui32 level, TNodeSet& visited) { + for (ui32 i = 0; i < level; ++i) { + out.Write(' '); + } + + out << "#" << node.UniqueId() << " [" << node.Type() << "]"; + if (node.Type() == TExprNode::Atom || node.Type() == TExprNode::Callable || node.Type() == TExprNode::Argument) { + out << " <" << node.Content() << ">"; + } + + constexpr bool WithTypes = false; + constexpr bool WithConstraints = false; + constexpr bool WithScope = false; + + if constexpr (WithTypes) { + if (node.GetTypeAnn()) { + out << ", " << *node.GetTypeAnn(); + } + } + + if constexpr (WithConstraints) { + if (node.GetState() >= TExprNode::EState::ConstrComplete) { + out << ", " << node.GetConstraintSet(); + } + } + + if constexpr (WithScope) { + if (const auto scope = node.GetDependencyScope()) { + out << ", ("; + if (const auto outer = scope->first) { + out << '#' << outer->UniqueId(); + } else { + out << "null"; + } + + out << ','; + if (const auto inner = scope->second) { + out << '#' << inner->UniqueId(); + } else { + out << "null"; + } + out << ')'; + } + } + + bool showChildren = true; + if (!visited.emplace(&node).second) { + if (node.Type() == TExprNode::Callable || node.Type() == TExprNode::List + || node.Type() == TExprNode::Lambda || node.Type() == TExprNode::Arguments) { + out << " ..."; + showChildren = false; + } + } + + out << "\n"; + if (showChildren) { + for (auto& child : node.Children()) { + DumpNode(*child, out, level + 1, visited); + } + } + } + + struct TContext { + struct TFrame { + THashMap<TString, TExprNode::TListType> Bindings; + THashMap<TString, TString> Imports; + TExprNode::TListType Return; + }; + + TExprContext& Expr; + TVector<TFrame> Frames; + TLibraryCohesion Cohesion; + std::unordered_set<TString> OverrideLibraries; + + TNodeOnNodeOwnedMap DeepClones; + + const TAnnotationNodeMap* Annotations = nullptr; + IModuleResolver* ModuleResolver = nullptr; + IUrlListerManager* UrlListerManager = nullptr; + ui32 TypeAnnotationIndex = Max<ui32>(); + TString File; + ui16 SyntaxVersion = 0; + + TContext(TExprContext& expr) + : Expr(expr) + { + } + + void AddError(const TAstNode& node, const TString& message) { + Expr.AddError(TIssue(node.GetPosition(), message)); + } + + void AddInfo(const TAstNode& node, const TString& message) { + auto issue = TIssue(node.GetPosition(), message); + issue.SetCode(TIssuesIds::INFO, TSeverityIds::S_INFO); + Expr.AddError(issue); + } + + TExprNode::TPtr&& ProcessNode(const TAstNode& node, TExprNode::TPtr&& exprNode) { + if (TypeAnnotationIndex != Max<ui32>()) { + exprNode->SetTypeAnn(CompileTypeAnnotation(node)); + } + + return std::move(exprNode); + } + + void PushFrame() { + Frames.push_back(TFrame()); + } + + void PopFrame() { + Frames.pop_back(); + } + + TExprNode::TListType FindBinding(const TStringBuf& name) const { + for (auto it = Frames.crbegin(); it != Frames.crend(); ++it) { + const auto r = it->Bindings.find(name); + if (it->Bindings.cend() != r) + return r->second; + } + + return {}; + } + + TString FindImport(const TStringBuf& name) const { + for (auto it = Frames.crbegin(); it != Frames.crend(); ++it) { + const auto r = it->Imports.find(name); + if (it->Imports.cend() != r) + return r->second; + } + + return TString(); + } + + const TTypeAnnotationNode* CompileTypeAnnotation(const TAstNode& node) { + auto ptr = Annotations->FindPtr(&node); + if (!ptr || TypeAnnotationIndex >= ptr->size()) { + AddError(node, "Failed to load type annotation"); + return nullptr; + } + + return CompileTypeAnnotationNode(*(*ptr)[TypeAnnotationIndex]); + } + + const TTypeAnnotationNode* CompileTypeAnnotationNode(const TAstNode& node) { + if (node.IsAtom()) { + if (node.GetContent() == TStringBuf(".")) { + return nullptr; + } + else if (node.GetContent() == TStringBuf("Unit")) { + return Expr.MakeType<TUnitExprType>(); + } + else if (node.GetContent() == TStringBuf("World")) { + return Expr.MakeType<TWorldExprType>(); + } + else if (node.GetContent() == TStringBuf("Void")) { + return Expr.MakeType<TVoidExprType>(); + } + else if (node.GetContent() == TStringBuf("Null")) { + return Expr.MakeType<TNullExprType>(); + } + else if (node.GetContent() == TStringBuf("Generic")) { + return Expr.MakeType<TGenericExprType>(); + } + else if (node.GetContent() == TStringBuf("EmptyList")) { + return Expr.MakeType<TEmptyListExprType>(); + } + else if (node.GetContent() == TStringBuf("EmptyDict")) { + return Expr.MakeType<TEmptyDictExprType>(); + } + else { + AddError(node, TStringBuilder() << "Unknown type annotation: " << node.GetContent()); + return nullptr; + } + } else { + if (node.GetChildrenCount() == 0) { + AddError(node, "Bad type annotation, expected not empty list"); + return nullptr; + } + + if (!node.GetChild(0)->IsAtom()) { + AddError(node, "Bad type annotation, first list item must be an atom"); + return nullptr; + } + + auto content = node.GetChild(0)->GetContent(); + if (content == TStringBuf("Data")) { + const auto count = node.GetChildrenCount(); + if (!(count == 2 || count == 4) || !node.GetChild(1)->IsAtom()) { + AddError(node, "Bad data type annotation"); + return nullptr; + } + + auto slot = NUdf::FindDataSlot(node.GetChild(1)->GetContent()); + if (!slot) { + AddError(node, "Bad data type annotation"); + return nullptr; + } + + if (count == 2) { + return Expr.MakeType<TDataExprType>(*slot); + } else { + if (!(node.GetChild(2)->IsAtom() && node.GetChild(3)->IsAtom())) { + AddError(node, "Bad data type annotation"); + return nullptr; + } + auto ann = Expr.MakeType<TDataExprParamsType>(*slot, node.GetChild(2)->GetContent(), node.GetChild(3)->GetContent()); + if (!ann->Validate(node.GetPosition(), Expr)) { + return nullptr; + } + + return ann; + } + } else if (content == TStringBuf("Pg")) { + const auto count = node.GetChildrenCount(); + if (count != 2 || !node.GetChild(1)->IsAtom()) { + AddError(node, "Bad data type annotation"); + return nullptr; + } + + auto typeId = NPg::LookupType(TString(node.GetChild(1)->GetContent())).TypeId; + return Expr.MakeType<TPgExprType>(typeId); + } else if (content == TStringBuf("List")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad list type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + return Expr.MakeType<TListExprType>(r); + } else if (content == TStringBuf("Stream")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad stream type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + return Expr.MakeType<TStreamExprType>(r); + } else if (content == TStringBuf("Struct")) { + TVector<const TItemExprType*> children; + for (size_t index = 1; index < node.GetChildrenCount(); ++index) { + auto r = CompileTypeAnnotationNode(*node.GetChild(index)); + if (!r) + return nullptr; + + if (r->GetKind() != ETypeAnnotationKind::Item) { + AddError(node, "Expected item type annotation"); + return nullptr; + } + + children.push_back(r->Cast<TItemExprType>()); + } + + auto ann = Expr.MakeType<TStructExprType>(children); + if (!ann->Validate(node.GetPosition(), Expr)) { + return nullptr; + } + + return ann; + } else if (content == TStringBuf("Multi")) { + TTypeAnnotationNode::TListType children; + for (size_t index = 1; index < node.GetChildrenCount(); ++index) { + auto r = CompileTypeAnnotationNode(*node.GetChild(index)); + if (!r) + return nullptr; + + children.push_back(r); + } + + auto ann = Expr.MakeType<TMultiExprType>(children); + if (!ann->Validate(node.GetPosition(), Expr)) { + return nullptr; + } + + return ann; + } else if (content == TStringBuf("Tuple")) { + TTypeAnnotationNode::TListType children; + for (size_t index = 1; index < node.GetChildrenCount(); ++index) { + auto r = CompileTypeAnnotationNode(*node.GetChild(index)); + if (!r) + return nullptr; + + children.push_back(r); + } + + auto ann = Expr.MakeType<TTupleExprType>(children); + if (!ann->Validate(node.GetPosition(), Expr)) { + return nullptr; + } + + return ann; + } else if (content == TStringBuf("Item")) { + if (node.GetChildrenCount() != 3 || !node.GetChild(1)->IsAtom()) { + AddError(node, "Bad item type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(2)); + if (!r) + return nullptr; + + return Expr.MakeType<TItemExprType>(TString(node.GetChild(1)->GetContent()), r); + } else if (content == TStringBuf("Optional")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad optional type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + return Expr.MakeType<TOptionalExprType>(r); + } else if (content == TStringBuf("Type")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + return Expr.MakeType<TTypeExprType>(r); + } + else if (content == TStringBuf("Dict")) { + if (node.GetChildrenCount() != 3) { + AddError(node, "Bad dict annotation"); + return nullptr; + } + + auto r1 = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r1) + return nullptr; + + auto r2 = CompileTypeAnnotationNode(*node.GetChild(2)); + if (!r2) + return nullptr; + + auto ann = Expr.MakeType<TDictExprType>(r1, r2); + if (!ann->Validate(node.GetPosition(), Expr)) { + return nullptr; + } + + return ann; + } + else if (content == TStringBuf("Callable")) { + if (node.GetChildrenCount() <= 2) { + AddError(node, "Bad callable annotation"); + return nullptr; + } + + TVector<TCallableExprType::TArgumentInfo> args; + size_t optCount = 0; + TString payload; + if (!node.GetChild(1)->IsList()) { + AddError(node, "Bad callable annotation - expected list"); + return nullptr; + } + + if (node.GetChild(1)->GetChildrenCount() > 2) { + AddError(node, "Bad callable annotation - too many settings nodes"); + return nullptr; + } + + if (node.GetChild(1)->GetChildrenCount() > 0) { + auto optChild = node.GetChild(1)->GetChild(0); + if (!optChild->IsAtom()) { + AddError(node, "Bad callable annotation - expected atom"); + return nullptr; + } + + if (!TryFromString(optChild->GetContent(), optCount)) { + AddError(node, TStringBuilder() << "Bad callable optional args count: " << node.GetChild(1)->GetContent()); + return nullptr; + } + } + + if (node.GetChild(1)->GetChildrenCount() > 1) { + auto payloadChild = node.GetChild(1)->GetChild(1); + if (!payloadChild->IsAtom()) { + AddError(node, "Bad callable annotation - expected atom"); + return nullptr; + } + + payload = payloadChild->GetContent(); + } + + auto retSettings = node.GetChild(2); + if (!retSettings->IsList() || retSettings->GetChildrenCount() != 1) { + AddError(node, "Bad callable annotation - expected list of size 1"); + return nullptr; + } + + auto retType = CompileTypeAnnotationNode(*retSettings->GetChild(0)); + if (!retType) + return nullptr; + + for (size_t index = 3; index < node.GetChildrenCount(); ++index) { + auto argSettings = node.GetChild(index); + if (!argSettings->IsList() || argSettings->GetChildrenCount() < 1 || + argSettings->GetChildrenCount() > 3) { + AddError(node, "Bad callable annotation - expected list of size 1..3"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*argSettings->GetChild(0)); + if (!r) + return nullptr; + + TCallableExprType::TArgumentInfo arg; + arg.Type = r; + + if (argSettings->GetChildrenCount() > 1) { + auto nameChild = argSettings->GetChild(1); + if (!nameChild->IsAtom()) { + AddError(node, "Bad callable annotation - expected atom"); + return nullptr; + } + + arg.Name = nameChild->GetContent(); + } + + if (argSettings->GetChildrenCount() > 2) { + auto flagsChild = argSettings->GetChild(2); + if (!flagsChild->IsAtom()) { + AddError(node, "Bad callable annotation - expected atom"); + return nullptr; + } + + if (!TryFromString(flagsChild->GetContent(), arg.Flags)) { + AddError(node, "Bad callable annotation - bad integer"); + return nullptr; + } + } + + args.push_back(arg); + } + + auto ann = Expr.MakeType<TCallableExprType>(retType, args, optCount, payload); + if (!ann->Validate(node.GetPosition(), Expr)) { + return nullptr; + } + + return ann; + } else if (content == TStringBuf("Resource")) { + if (node.GetChildrenCount() != 2 || !node.GetChild(1)->IsAtom()) { + AddError(node, "Bad resource type annotation"); + return nullptr; + } + + return Expr.MakeType<TResourceExprType>(TString(node.GetChild(1)->GetContent())); + } else if (content == TStringBuf("Tagged")) { + if (node.GetChildrenCount() != 3 || !node.GetChild(2)->IsAtom()) { + AddError(node, "Bad tagged type annotation"); + return nullptr; + } + + auto type = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!type) + return nullptr; + + TString tag(node.GetChild(2)->GetContent()); + auto ann = Expr.MakeType<TTaggedExprType>(type, tag); + if (!ann->Validate(node.GetPosition(), Expr)) { + return nullptr; + } + + return ann; + } else if (content == TStringBuf("Error")) { + if (node.GetChildrenCount() != 5 || !node.GetChild(1)->IsAtom() || + !node.GetChild(2)->IsAtom() || !node.GetChild(3)->IsAtom() || !node.GetChild(4)->IsAtom()) { + AddError(node, "Bad error type annotation"); + return nullptr; + } + + ui32 row; + if (!TryFromString(node.GetChild(1)->GetContent(), row)) { + AddError(node, TStringBuilder() << "Bad integer: " << node.GetChild(1)->GetContent()); + return nullptr; + } + + ui32 column; + if (!TryFromString(node.GetChild(2)->GetContent(), column)) { + AddError(node, TStringBuilder() << "Bad integer: " << node.GetChild(2)->GetContent()); + return nullptr; + } + + auto file = TString(node.GetChild(3)->GetContent()); + return Expr.MakeType<TErrorExprType>(TIssue(TPosition(column, row, file), TString(node.GetChild(4)->GetContent()))); + } else if (content == TStringBuf("Variant")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad variant type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + auto ann = Expr.MakeType<TVariantExprType>(r); + if (!ann->Validate(node.GetPosition(), Expr)) { + return nullptr; + } + + return ann; + } else if (content == TStringBuf("Stream")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad stream type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + return Expr.MakeType<TStreamExprType>(r); + } else if (content == TStringBuf("Flow")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad flow type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + return Expr.MakeType<TFlowExprType>(r); + } else if (content == TStringBuf("Block")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad block type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + return Expr.MakeType<TBlockExprType>(r); + } else if (content == TStringBuf("Scalar")) { + if (node.GetChildrenCount() != 2) { + AddError(node, "Bad scalar type annotation"); + return nullptr; + } + + auto r = CompileTypeAnnotationNode(*node.GetChild(1)); + if (!r) + return nullptr; + + return Expr.MakeType<TScalarExprType>(r); + } else { + AddError(node, TStringBuilder() << "Unknown type annotation"); + return nullptr; + } + } + } + }; + + TAstNode* ConvertTypeAnnotationToAst(const TTypeAnnotationNode& annotation, TMemoryPool& pool, bool refAtoms) { + switch (annotation.GetKind()) { + case ETypeAnnotationKind::Unit: + { + return TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Unit"), pool); + } + + case ETypeAnnotationKind::Tuple: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Tuple"), pool); + TSmallVec<TAstNode*> children; + children.push_back(self); + for (auto& child : annotation.Cast<TTupleExprType>()->GetItems()) { + children.push_back(ConvertTypeAnnotationToAst(*child, pool, refAtoms)); + } + + return TAstNode::NewList(TPosition(), children.data(), children.size(), pool); + } + + case ETypeAnnotationKind::Struct: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Struct"), pool); + TSmallVec<TAstNode*> children; + children.push_back(self); + for (auto& child : annotation.Cast<TStructExprType>()->GetItems()) { + children.push_back(ConvertTypeAnnotationToAst(*child, pool, refAtoms)); + } + + return TAstNode::NewList(TPosition(), children.data(), children.size(), pool); + } + + case ETypeAnnotationKind::List: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("List"), pool); + auto itemType = ConvertTypeAnnotationToAst(*annotation.Cast<TListExprType>()->GetItemType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, itemType); + } + + case ETypeAnnotationKind::Optional: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Optional"), pool); + auto itemType = ConvertTypeAnnotationToAst(*annotation.Cast<TOptionalExprType>()->GetItemType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, itemType); + } + + case ETypeAnnotationKind::Item: + { + auto casted = annotation.Cast<TItemExprType>(); + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Item"), pool); + auto name = refAtoms ? + TAstNode::NewLiteralAtom(TPosition(), casted->GetName(), pool, TNodeFlags::ArbitraryContent) : + TAstNode::NewAtom(TPosition(), casted->GetName(), pool, TNodeFlags::ArbitraryContent); + auto itemType = ConvertTypeAnnotationToAst(*casted->GetItemType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, name, itemType); + } + + case ETypeAnnotationKind::Data: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Data"), pool); + auto datatype = refAtoms ? + TAstNode::NewLiteralAtom(TPosition(), annotation.Cast<TDataExprType>()->GetName(), pool) : + TAstNode::NewAtom(TPosition(), annotation.Cast<TDataExprType>()->GetName(), pool); + if (auto params = dynamic_cast<const TDataExprParamsType*>(&annotation)) { + auto param1 = refAtoms ? + TAstNode::NewLiteralAtom(TPosition(), params->GetParamOne(), pool) : + TAstNode::NewAtom(TPosition(), params->GetParamOne(), pool); + + auto param2 = refAtoms ? + TAstNode::NewLiteralAtom(TPosition(), params->GetParamTwo(), pool) : + TAstNode::NewAtom(TPosition(), params->GetParamTwo(), pool); + + return TAstNode::NewList(TPosition(), pool, self, datatype, param1, param2); + } + + return TAstNode::NewList(TPosition(), pool, self, datatype); + } + + case ETypeAnnotationKind::Pg: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Pg"), pool); + auto name = TAstNode::NewLiteralAtom(TPosition(), annotation.Cast<TPgExprType>()->GetName(), pool); + return TAstNode::NewList(TPosition(), pool, self, name); + } + + case ETypeAnnotationKind::World: + { + return TAstNode::NewLiteralAtom(TPosition(), TStringBuf("World"), pool); + } + + case ETypeAnnotationKind::Type: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Type"), pool); + auto type = ConvertTypeAnnotationToAst(*annotation.Cast<TTypeExprType>()->GetType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, type); + } + + case ETypeAnnotationKind::Dict: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Dict"), pool); + auto dictType = annotation.Cast<TDictExprType>(); + auto keyType = ConvertTypeAnnotationToAst(*dictType->GetKeyType(), pool, refAtoms); + auto payloadType = ConvertTypeAnnotationToAst(*dictType->GetPayloadType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, keyType, payloadType); + } + + case ETypeAnnotationKind::Void: + { + return TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Void"), pool); + } + + case ETypeAnnotationKind::Null: + { + return TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Null"), pool); + } + + case ETypeAnnotationKind::Callable: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Callable"), pool); + auto callable = annotation.Cast<TCallableExprType>(); + TSmallVec<TAstNode*> mainSettings; + if (callable->GetOptionalArgumentsCount() > 0 || !callable->GetPayload().empty()) { + auto optArgs = TAstNode::NewAtom(TPosition(), ToString(callable->GetOptionalArgumentsCount()), pool); + + mainSettings.push_back(optArgs); + } + + if (!callable->GetPayload().empty()) { + auto payload = TAstNode::NewAtom(TPosition(), callable->GetPayload(), pool, TNodeFlags::ArbitraryContent); + mainSettings.push_back(payload); + } + + TSmallVec<TAstNode*> children; + children.push_back(self); + + children.push_back(TAstNode::NewList(TPosition(), mainSettings.data(), mainSettings.size(), pool)); + + TSmallVec<TAstNode*> retSettings; + retSettings.push_back(ConvertTypeAnnotationToAst(*callable->GetReturnType(), pool, refAtoms)); + children.push_back(TAstNode::NewList(TPosition(), retSettings.data(), retSettings.size(), pool)); + + for (auto& arg : callable->GetArguments()) { + TSmallVec<TAstNode*> argSettings; + argSettings.push_back(ConvertTypeAnnotationToAst(*arg.Type, pool, refAtoms)); + if (!arg.Name.empty() || arg.Flags != 0) { + auto name = TAstNode::NewAtom(TPosition(), arg.Name, pool, TNodeFlags::ArbitraryContent); + argSettings.push_back(name); + } + + if (arg.Flags != 0) { + auto flags = TAstNode::NewAtom(TPosition(), ToString(arg.Flags), pool, TNodeFlags::ArbitraryContent); + argSettings.push_back(flags); + } + + children.push_back(TAstNode::NewList(TPosition(), argSettings.data(), argSettings.size(), pool)); + } + + return TAstNode::NewList(TPosition(), children.data(), children.size(), pool); + } + + case ETypeAnnotationKind::Generic: + { + return TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Generic"), pool); + } + + case ETypeAnnotationKind::Resource: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Resource"), pool); + auto restype = refAtoms ? + TAstNode::NewLiteralAtom(TPosition(), annotation.Cast<TResourceExprType>()->GetTag(), pool, TNodeFlags::ArbitraryContent) : + TAstNode::NewAtom(TPosition(), annotation.Cast<TResourceExprType>()->GetTag(), pool, TNodeFlags::ArbitraryContent); + return TAstNode::NewList(TPosition(), pool, self, restype); + } + + case ETypeAnnotationKind::Tagged: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Tagged"), pool); + auto type = ConvertTypeAnnotationToAst(*annotation.Cast<TTaggedExprType>()->GetBaseType(), pool, refAtoms); + auto restype = refAtoms ? + TAstNode::NewLiteralAtom(TPosition(), annotation.Cast<TTaggedExprType>()->GetTag(), pool, TNodeFlags::ArbitraryContent) : + TAstNode::NewAtom(TPosition(), annotation.Cast<TTaggedExprType>()->GetTag(), pool, TNodeFlags::ArbitraryContent); + return TAstNode::NewList(TPosition(), pool, self, type, restype); + } + + case ETypeAnnotationKind::Error: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Error"), pool); + const auto& err = annotation.Cast<TErrorExprType>()->GetError(); + auto row = TAstNode::NewAtom(TPosition(), ToString(err.Position.Row), pool); + auto column = TAstNode::NewAtom(TPosition(), ToString(err.Position.Column), pool); + auto file = refAtoms ? + TAstNode::NewLiteralAtom(TPosition(), err.Position.File, pool, TNodeFlags::ArbitraryContent) : + TAstNode::NewAtom(TPosition(), err.Position.File, pool, TNodeFlags::ArbitraryContent); + auto message = refAtoms ? + TAstNode::NewLiteralAtom(TPosition(), err.GetMessage(), pool, TNodeFlags::ArbitraryContent) : + TAstNode::NewAtom(TPosition(), err.GetMessage(), pool, TNodeFlags::ArbitraryContent); + return TAstNode::NewList(TPosition(), pool, self, row, column, file, message); + } + + case ETypeAnnotationKind::Variant: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Variant"), pool); + auto underlyingType = ConvertTypeAnnotationToAst(*annotation.Cast<TVariantExprType>()->GetUnderlyingType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, underlyingType); + } + + case ETypeAnnotationKind::Stream: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Stream"), pool); + auto itemType = ConvertTypeAnnotationToAst(*annotation.Cast<TStreamExprType>()->GetItemType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, itemType); + } + + case ETypeAnnotationKind::Flow: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Flow"), pool); + auto itemType = ConvertTypeAnnotationToAst(*annotation.Cast<TFlowExprType>()->GetItemType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, itemType); + } + + case ETypeAnnotationKind::Multi: + { + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Multi"), pool); + TSmallVec<TAstNode*> children; + children.push_back(self); + for (auto& child : annotation.Cast<TMultiExprType>()->GetItems()) { + children.push_back(ConvertTypeAnnotationToAst(*child, pool, refAtoms)); + } + + return TAstNode::NewList(TPosition(), children.data(), children.size(), pool); + } + + case ETypeAnnotationKind::EmptyList: + { + return TAstNode::NewLiteralAtom(TPosition(), TStringBuf("EmptyList"), pool); + } + case ETypeAnnotationKind::EmptyDict: + { + return TAstNode::NewLiteralAtom(TPosition(), TStringBuf("EmptyDict"), pool); + } + + case ETypeAnnotationKind::Block: + { + auto type = annotation.Cast<TBlockExprType>(); + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Block"), pool); + auto itemType = ConvertTypeAnnotationToAst(*type->GetItemType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, itemType); + } + + case ETypeAnnotationKind::Scalar: + { + auto type = annotation.Cast<TScalarExprType>(); + auto self = TAstNode::NewLiteralAtom(TPosition(), TStringBuf("Scalar"), pool); + auto itemType = ConvertTypeAnnotationToAst(*type->GetItemType(), pool, refAtoms); + return TAstNode::NewList(TPosition(), pool, self, itemType); + } + + case ETypeAnnotationKind::LastType: + YQL_ENSURE(false, "Unknown kind: " << annotation.GetKind()); + + } + } + + TAstNode* AnnotateAstNode(TAstNode* node, const TExprNode* exprNode, ui32 flags, TMemoryPool& pool, bool refAtoms) { + if (!flags) + return node; + + TSmallVec<TAstNode*> children; + if (flags & TExprAnnotationFlags::Position) { + children.push_back(PositionAsNode(node->GetPosition(), pool)); + } + + if ((flags & TExprAnnotationFlags::Types)) { + TAstNode* typeAnn = nullptr; + if (exprNode) { + YQL_ENSURE(exprNode->GetTypeAnn()); + typeAnn = ConvertTypeAnnotationToAst(*exprNode->GetTypeAnn(), pool, refAtoms); + } else { + typeAnn = TAstNode::NewLiteralAtom(node->GetPosition(), TStringBuf("."), pool); + } + + children.push_back(typeAnn); + } + + children.push_back(node); + return TAstNode::NewList(node->GetPosition(), children.data(), children.size(), pool); + } + + bool AddParameterDependencies(const TString& url, const TAstNode& node, TContext& ctx) { + auto world = ctx.FindBinding("world"); + if (!world.empty()) { + TSet<TString> names; + SubstParameters(url, Nothing(), &names); + for (const auto& name : names) { + auto nameRef = ctx.FindBinding(name); + if (nameRef.empty()) { + ctx.AddError(node, TStringBuilder() << "Name not found: " << name); + return false; + } + + TExprNode::TListType args = world; + args.insert(args.end(), nameRef.begin(), nameRef.end()); + auto newWorld = TExprNode::TListType{ ctx.Expr.NewCallable(node.GetPosition(), "Left!", { + ctx.Expr.NewCallable(node.GetPosition(), "Cons!", std::move(args)) })}; + + ctx.Frames.back().Bindings["world"] = newWorld; + world = newWorld; + } + } + + return true; + } + + TExprNode::TListType Compile(const TAstNode& node, TContext& ctx); + + TExprNode::TPtr CompileQuote(const TAstNode& node, TContext& ctx) { + if (node.IsAtom()) { + return ctx.ProcessNode(node, ctx.Expr.NewAtom(node.GetPosition(), TString(node.GetContent()), node.GetFlags())); + } else { + TExprNode::TListType children; + children.reserve(node.GetChildrenCount()); + for (ui32 index = 0; index < node.GetChildrenCount(); ++index) { + auto r = Compile(*node.GetChild(index), ctx); + if (r.empty()) + return {}; + + std::move(r.begin(), r.end(), std::back_inserter(children)); + } + + return ctx.ProcessNode(node, ctx.Expr.NewList(node.GetPosition(), std::move(children))); + } + } + + TExprNode::TListType CompileLambda(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() < 2) { + ctx.AddError(node, "Expected size of list at least 3."); + return {}; + } + + const auto args = node.GetChild(1); + if (!args->IsList() || args->GetChildrenCount() != 2 || !args->GetChild(0)->IsAtom() || + args->GetChild(0)->GetContent() != TStringBuf("quote") || !args->GetChild(1)->IsList()) { + ctx.AddError(node, "Lambda arguments must be a quoted list of atoms"); + return {}; + } + + const auto params = args->GetChild(1); + for (ui32 index = 0; index < params->GetChildrenCount(); ++index) { + if (!params->GetChild(index)->IsAtom()) { + ctx.AddError(node, "Lambda arguments must be a quoted list of atoms"); + return {}; + } + } + + ctx.PushFrame(); + TExprNode::TListType argNodes; + for (ui32 index = 0; index < params->GetChildrenCount(); ++index) { + auto arg = params->GetChild(index); + auto lambdaArg = ctx.ProcessNode(*arg, ctx.Expr.NewArgument(arg->GetPosition(), TString(arg->GetContent()))); + argNodes.push_back(lambdaArg); + auto& binding = ctx.Frames.back().Bindings[arg->GetContent()]; + if (!binding.empty()) { + ctx.PopFrame(); + ctx.AddError(*arg, TStringBuilder() << "Duplicated name of lambda parameter: " << arg->GetContent()); + return {}; + } + + binding = {lambdaArg}; + } + + TExprNode::TListType body; + body.reserve(node.GetChildrenCount() - 2U); + for (auto i = 2U; i < node.GetChildrenCount(); ++i) { + auto r = Compile(*node.GetChild(i), ctx); + if (r.empty()) + return {}; + std::move(r.begin(), r.end(), std::back_inserter(body)); + } + ctx.PopFrame(); + + auto arguments = ctx.ProcessNode(*args, ctx.Expr.NewArguments(args->GetPosition(), std::move(argNodes))); + auto lambda = ctx.ProcessNode(node, ctx.Expr.NewLambda(node.GetPosition(), std::move(arguments), std::move(body))); + return {lambda}; + } + + bool CompileSetPackageVersion(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() != 3) { + ctx.AddError(node, "Expected list of size 3"); + return false; + } + + const auto name = node.GetChild(1); + if (name->IsAtom() || name->GetChildrenCount() != 2 || !name->GetChild(0)->IsAtom() || !name->GetChild(1)->IsAtom() || + name->GetChild(0)->GetContent() != TStringBuf("quote")) { + ctx.AddError(*name, "Expected quoted atom for package name"); + return false; + } + + const auto versionNode = node.GetChild(2); + if (versionNode->IsAtom() || versionNode->GetChildrenCount() != 2 || !versionNode->GetChild(0)->IsAtom() || !versionNode->GetChild(1)->IsAtom() || + versionNode->GetChild(0)->GetContent() != TStringBuf("quote")) { + ctx.AddError(*versionNode, "Expected quoted atom for package version"); + return false; + } + + ui32 version = 0; + if (!versionNode->GetChild(1)->IsAtom() || !TryFromString(versionNode->GetChild(1)->GetContent(), version)) { + ctx.AddError(*versionNode, TString("Expected package version as number, bad content ") + versionNode->GetChild(1)->GetContent()); + return false; + } + + if (ctx.ModuleResolver && !ctx.ModuleResolver->SetPackageDefaultVersion(TString(name->GetChild(1)->GetContent()), version)) { + ctx.AddError(*versionNode, TStringBuilder() << "Unable to specify version " << version << " for package " << name->GetChild(1)->GetContent()); + return false; + } + return true; + } + + TExprNode::TPtr CompileBind(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() != 3) { + ctx.AddError(node, "Expected list of size 3"); + return nullptr; + } + + const auto name = node.GetChild(1); + if (!name->IsAtom()) { + ctx.AddError(*name, "Expected atom"); + return nullptr; + } + + const auto alias = node.GetChild(2); + if (alias->IsAtom() || alias->GetChildrenCount() != 2 || !alias->GetChild(0)->IsAtom() || !alias->GetChild(1)->IsAtom() || + alias->GetChild(0)->GetContent() != TStringBuf("quote")) { + ctx.AddError(*alias, "Expected quoted pair"); + return nullptr; + } + + const auto& aliasValue = alias->GetChild(1)->GetContent(); + const auto& moduleName = name->GetContent(); + TStringBuilder baseMsg; + baseMsg << "Module '" << name->GetContent() << "'"; + + const auto& import = ctx.FindImport(moduleName); + if (import.empty()) { + ctx.AddError(*name, baseMsg << " does not exist"); + return nullptr; + } + + if (ctx.ModuleResolver) { + auto exportsPtr = ctx.ModuleResolver->GetModule(import); + if (!exportsPtr) { + ctx.AddError(*name, baseMsg << "'" << import << "' does not exist"); + return nullptr; + } + + const auto& exports = exportsPtr->Symbols(); + + const auto ex = exports.find(aliasValue); + if (exports.cend() == ex) { + ctx.AddError(*alias, baseMsg << " export '" << aliasValue << "' does not exist"); + return nullptr; + } + + return ctx.Expr.DeepCopy(*ex->second, exportsPtr->ExprCtx(), ctx.DeepClones, true, false); + } else { + const auto stub = ctx.Expr.NewAtom(node.GetPosition(), "stub"); + ctx.Frames.back().Bindings[name->GetContent()] = {stub}; + ctx.Cohesion.Imports[stub.Get()] = std::make_pair(import, TString(aliasValue)); + return stub; + } + } + + bool CompileLet(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() < 3) { + ctx.AddError(node, "Expected size of list at least 3."); + return false; + } + + const auto name = node.GetChild(1); + if (!name->IsAtom()) { + ctx.AddError(*name, "Expected atom"); + return false; + } + + TExprNode::TListType bind; + bind.reserve(node.GetChildrenCount() - 2U); + for (auto i = 2U; i < node.GetChildrenCount(); ++i) { + auto r = Compile(*node.GetChild(i), ctx); + if (r.empty()) + return false; + std::move(r.begin(), r.end(), std::back_inserter(bind)); + } + + ctx.Frames.back().Bindings[name->GetContent()] = std::move(bind); + return true; + } + + bool CompileImport(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() != 3) { + ctx.AddError(node, "Expected list of size 3"); + return false; + } + + const auto name = node.GetChild(1); + if (!name->IsAtom()) { + ctx.AddError(*name, "Expected atom"); + return false; + } + + const auto alias = node.GetChild(2); + if (!alias->IsListOfSize(2) || !alias->GetChild(0)->IsAtom() || !alias->GetChild(1)->IsAtom() || + alias->GetChild(0)->GetContent() != TStringBuf("quote")) { + ctx.AddError(node, "Expected quoted pair"); + return false; + } + + ctx.Frames.back().Imports[name->GetContent()] = alias->GetChild(1)->GetContent(); + return true; + } + + bool CompileExport(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() != 2) { + ctx.AddError(node, "Expected list of size 2"); + return false; + } + + const auto name = node.GetChild(1); + if (!name->IsAtom()) { + ctx.AddError(*name, "Expected atom"); + return false; + } + + auto r = Compile(*node.GetChild(1), ctx); + if (r.size() != 1U) + return false; + + ctx.Cohesion.Exports.Symbols(ctx.Expr)[name->GetContent()] = std::move(r.front()); + return true; + } + + bool CompileDeclare(const TAstNode& node, TContext& ctx, bool checkOnly) { + if (node.GetChildrenCount() != 3) { + ctx.AddError(node, "Expected list of size 3"); + return false; + } + + const auto name = node.GetChild(1); + if (!name->IsAtom()) { + ctx.AddError(*name, "Expected atom"); + return false; + } + + TString nameStr = TString(name->GetContent()); + if (nameStr.size() < 2) { + ctx.AddError(*name, "Parameter name should be at least 2 characters long."); + return false; + } + + if (nameStr[0] == '$' && std::isdigit(nameStr[1])) { + ctx.AddError(*name, "Parameter name cannot start with digit."); + return false; + } + + auto typeExpr = Compile(*node.GetChild(2), ctx); + if (typeExpr.size() != 1U) + return false; + + auto typePos = node.GetChild(2)->GetPosition(); + auto parameterExpr = ctx.ProcessNode(node, + ctx.Expr.NewCallable(typePos, "Parameter", { + ctx.Expr.NewAtom(node.GetPosition(), nameStr), + std::move(typeExpr.front()) + })); + + bool error = false; + if (checkOnly) { + auto it = ctx.Frames.back().Bindings.find(nameStr); + if (it == ctx.Frames.back().Bindings.end()) { + ctx.AddError(*name, TStringBuilder() << "Missing parameter: " << nameStr); + return false; + } + + if (it->second.size() != 1 || !it->second.front()->IsCallable("Parameter")) { + error = true; + } + } else { + if (!ctx.Frames.back().Bindings.emplace(nameStr, TExprNode::TListType{ std::move(parameterExpr) }).second) { + error = true; + } + } + + if (error) { + ctx.AddError(node, TStringBuilder() << "Declare statement hides previously defined name: " << nameStr); + return false; + } + + return true; + } + + bool CompileLibraryDef(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() < 2 || node.GetChildrenCount() > 4) { + ctx.AddError(node, "Expected list of size from 2 to 4"); + return false; + } + + const auto name = node.GetChild(1); + if (!name->IsAtom()) { + ctx.AddError(*name, "Expected atom"); + return false; + } + + TString url; + TString token; + if (node.GetChildrenCount() > 2) { + const auto file = node.GetChild(2); + if (!file->IsAtom()) { + ctx.AddError(*file, "Expected atom"); + return false; + } + + url = file->GetContent(); + + if (node.GetChildrenCount() > 3) { + const auto tokenNode = node.GetChild(3); + if (!tokenNode->IsAtom()) { + ctx.AddError(*tokenNode, "Expected atom"); + return false; + } + + token = tokenNode->GetContent(); + } + } + + if (url && !AddParameterDependencies(url, node, ctx)) { + return false; + } + + if (!ctx.ModuleResolver) { + return true; + } + + if (url) { + if (!ctx.ModuleResolver->AddFromUrl(name->GetContent(), url, token, ctx.Expr, ctx.SyntaxVersion, 0, name->GetPosition())) { + return false; + } + } else { + if (!ctx.ModuleResolver->AddFromFile(name->GetContent(), ctx.Expr, ctx.SyntaxVersion, 0, name->GetPosition())) { + return false; + } + } + + return true; + } + + bool CompilePackageDef(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() < 2 || node.GetChildrenCount() > 4) { + ctx.AddError(node, "Expected list of size from 2 to 4"); + return false; + } + + auto nameNode = node.GetChild(1); + if (!nameNode->IsAtom()) { + ctx.AddError(*nameNode, "Expected atom"); + return false; + } + + auto name = TString(nameNode->GetContent()); + + TString url; + if (node.GetChildrenCount() > 2) { + const auto file = node.GetChild(2); + if (!file->IsAtom()) { + ctx.AddError(*file, "Expected atom"); + return false; + } + + url = file->GetContent(); + } + + TString token; + if (node.GetChildrenCount() > 3) { + const auto tokenNode = node.GetChild(3); + if (!tokenNode->IsAtom()) { + ctx.AddError(*tokenNode, "Expected atom"); + return false; + } + + token = tokenNode->GetContent(); + } + + if (url && !AddParameterDependencies(url, node, ctx)) { + return false; + } + + if (!ctx.ModuleResolver) { + return true; + } + + if (!ctx.UrlListerManager) { + return true; + } + + ctx.ModuleResolver->RegisterPackage(name); + + auto packageModuleName = TStringBuilder() << PkgPrefix; + + TStringBuf nameBuf(name); + while (auto part = nameBuf.NextTok(Dot)) { + packageModuleName << Sep << part; + } + + auto queue = TVector<std::pair<TString, THttpURL>> { + {packageModuleName, ParseURL(url)} + }; + + while (queue) { + auto [prefix, httpUrl] = queue.back(); + queue.pop_back(); + + TVector<TUrlListEntry> urlListEntries; + try { + urlListEntries = ctx.UrlListerManager->ListUrl(httpUrl, token); + } catch (const std::exception& e) { + ctx.AddError(*nameNode, + TStringBuilder() + << "UrlListerManager: failed to list URL \"" << httpUrl.PrintS() + << "\", details: " << e.what() + ); + + return false; + } + + for (auto& urlListEntry: urlListEntries) { + switch (urlListEntry.Type) { + case EUrlListEntryType::FILE: { + auto moduleName = TStringBuilder() + << prefix << Sep << urlListEntry.Name; + + if (ctx.OverrideLibraries.contains(moduleName)) { + continue; + } + + if (!ctx.ModuleResolver->AddFromUrl( + moduleName, urlListEntry.Url.PrintS(), token, ctx.Expr, + ctx.SyntaxVersion, 0, nameNode->GetPosition() + )) { + return false; + } + + break; + } + + case EUrlListEntryType::DIRECTORY: { + queue.push_back({ + TStringBuilder() << prefix << Sep << urlListEntry.Name, + urlListEntry.Url + }); + + break; + } + } + } + } + + return true; + } + + bool CompileOverrideLibraryDef(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() != 2) { + ctx.AddError(node, "Expected list of size 2"); + return false; + } + + auto nameNode = node.GetChild(1); + if (!nameNode->IsAtom()) { + ctx.AddError(*nameNode, "Expected atom"); + return false; + } + + if (!ctx.ModuleResolver) { + return true; + } + + auto overrideLibraryName = TStringBuilder() + << PkgPrefix << Sep << nameNode->GetContent(); + + if (!ctx.ModuleResolver->AddFromFile( + overrideLibraryName, ctx.Expr, ctx.SyntaxVersion, 0, nameNode->GetPosition() + )) { + return false; + } + + ctx.OverrideLibraries.insert(std::move(overrideLibraryName)); + + return true; + } + + bool CompileReturn(const TAstNode& node, TContext& ctx) { + if (node.GetChildrenCount() < 2U) { + ctx.AddError(node, "Expected non empty list."); + return false; + } + + TExprNode::TListType returns; + returns.reserve(node.GetChildrenCount() - 1U); + for (auto i = 1U; i < node.GetChildrenCount(); ++i) { + auto r = Compile(*node.GetChild(i), ctx); + if (r.empty()) + return false; + std::move(r.begin(), r.end(), std::back_inserter(returns)); + } + + ctx.Frames.back().Return = std::move(returns); + return true; + } + + TExprNode::TListType CompileFunction(const TAstNode& root, TContext& ctx, bool topLevel = false) { + if (!root.IsList()) { + ctx.AddError(root, "Expected list"); + return {}; + } + + if (ctx.Frames.size() > 1000U) { + ctx.AddError(root, "Too deep graph!"); + return {}; + } + + ctx.PushFrame(); + if (topLevel) { + for (ui32 index = 0; index < root.GetChildrenCount(); ++index) { + const auto node = root.GetChild(index); + if (!node->IsList()) { + ctx.AddError(*node, "Expected list"); + return {}; + } + + if (node->GetChildrenCount() == 0) { + ctx.AddError(*node, "Expected not empty list"); + return {}; + } + + const auto firstChild = node->GetChild(0); + if (!firstChild->IsAtom()) { + ctx.AddError(*firstChild, "Expected atom"); + return {}; + } + + if (firstChild->GetContent() == TStringBuf("library")) { + if (!CompileLibraryDef(*node, ctx)) + return {}; + } else if (firstChild->GetContent() == TStringBuf("set_package_version")) { + if (!CompileSetPackageVersion(*node, ctx)) + return {}; + } else if (firstChild->GetContent() == TStringBuf("declare")) { + if (!CompileDeclare(*node, ctx, false)) + return {}; + } else if (firstChild->GetContent() == TStringBuf("package")) { + if (!CompilePackageDef(*node, ctx)) { + return {}; + } + } else if (firstChild->GetContent() == TStringBuf("override_library")) { + if (!CompileOverrideLibraryDef(*node, ctx)) { + return {}; + } + } + } + + if (ctx.ModuleResolver) { + if (!ctx.ModuleResolver->Link(ctx.Expr)) { + return {}; + } + + ctx.ModuleResolver->UpdateNextUniqueId(ctx.Expr); + } + } + + for (ui32 index = 0; index < root.GetChildrenCount(); ++index) { + const auto node = root.GetChild(index); + if (!ctx.Frames.back().Return.empty()) { + ctx.Frames.back().Return.clear(); + ctx.AddError(*node, "Return is already exist"); + return {}; + } + + if (!node->IsList()) { + ctx.AddError(*node, "Expected list"); + return {}; + } + + if (node->GetChildrenCount() == 0) { + ctx.AddError(*node, "Expected not empty list"); + return {}; + } + + auto firstChild = node->GetChild(0); + if (!firstChild->IsAtom()) { + ctx.AddError(*firstChild, "Expected atom"); + return {}; + } + + if (firstChild->GetContent() == TStringBuf("let")) { + if (!CompileLet(*node, ctx)) + return {}; + } else if (firstChild->GetContent() == TStringBuf("return")) { + if (!CompileReturn(*node, ctx)) + return {}; + } else if (firstChild->GetContent() == TStringBuf("import")) { + if (!CompileImport(*node, ctx)) + return {}; + } else if (firstChild->GetContent() == TStringBuf("declare")) { + if (!topLevel) { + ctx.AddError(*firstChild, "Declare statements are only allowed on top level block"); + return {}; + } + + if (!CompileDeclare(*node, ctx, true)) + return {}; + + continue; + } else if (firstChild->GetContent() == TStringBuf("library")) { + if (!topLevel) { + ctx.AddError(*firstChild, "Library statements are only allowed on top level block"); + return {}; + } + + continue; + } else if (firstChild->GetContent() == TStringBuf("set_package_version")) { + if (!topLevel) { + ctx.AddError(*firstChild, "set_package_version statements are only allowed on top level block"); + return {}; + } + + continue; + } else if (firstChild->GetContent() == TStringBuf("package")) { + if (!topLevel) { + ctx.AddError(*firstChild, "Package statements are only allowed on top level block"); + return {}; + } + } else if (firstChild->GetContent() == TStringBuf("override_library")) { + if (!topLevel) { + ctx.AddError(*firstChild, "override_library statements are only allowed on top level block"); + return {}; + } + } else { + ctx.AddError(*firstChild, ToString("expected either let, return or import, but have ") + firstChild->GetContent()); + return {}; + } + } + + auto ret = std::move(ctx.Frames.back().Return); + ctx.PopFrame(); + if (ret.empty()) { + ctx.AddError(root, "No return found"); + } + + return ret; + } + + bool CompileLibrary(const TAstNode& root, TContext& ctx) { + if (!root.IsList()) { + ctx.AddError(root, "Expected list"); + return false; + } + + ctx.PushFrame(); + for (ui32 index = 0; index < root.GetChildrenCount(); ++index) { + const auto node = root.GetChild(index); + + if (!node->IsList()) { + ctx.AddError(*node, "Expected list"); + return false; + } + + if (node->GetChildrenCount() == 0) { + ctx.AddError(*node, "Expected not empty list"); + return false; + } + + auto firstChild = node->GetChild(0); + if (!firstChild->IsAtom()) { + ctx.AddError(*firstChild, "Expected atom"); + return false; + } + + if (firstChild->GetContent() == TStringBuf("let")) { + if (!CompileLet(*node, ctx)) + return false; + } else if (firstChild->GetContent() == TStringBuf("import")) { + if (!CompileImport(*node, ctx)) + return false; + } else if (firstChild->GetContent() == TStringBuf("export")) { + if (!CompileExport(*node, ctx)) + return false; + } else { + ctx.AddError(*firstChild, "expected either let, export or import"); + return false; + } + } + + ctx.PopFrame(); + return true; + } + + TExprNode::TListType Compile(const TAstNode& node, TContext& ctx) { + if (node.IsAtom()) { + const auto foundNode = ctx.FindBinding(node.GetContent()); + if (foundNode.empty()) { + ctx.AddError(node, TStringBuilder() << "Name not found: " << node.GetContent()); + return {}; + } + + return foundNode; + } + + if (node.GetChildrenCount() == 0) { + ctx.AddError(node, "Empty list, did you forget quote?"); + return {}; + } + + if (!node.GetChild(0)->IsAtom()) { + ctx.AddError(node, "First item in list is not an atom, did you forget quote?"); + return {}; + } + + auto function = node.GetChild(0)->GetContent(); + if (function == TStringBuf("quote")) { + if (node.GetChildrenCount() != 2) { + ctx.AddError(node, "Quote should have one argument"); + return {}; + } + + if (auto quote = CompileQuote(*node.GetChild(1), ctx)) + return {std::move(quote)}; + + return {}; + } + + if (function == TStringBuf("let") || function == TStringBuf("return")) { + ctx.AddError(node, "Let and return should be used only at first level or inside def"); + return {}; + } + + if (function == TStringBuf("lambda")) { + return CompileLambda(node, ctx); + } + + if (function == TStringBuf("bind")) { + if (auto bind = CompileBind(node, ctx)) + return {std::move(bind)}; + return {}; + } + + if (function == TStringBuf("block")) { + if (node.GetChildrenCount() != 2) { + ctx.AddError(node, "Block should have one argument"); + return {}; + } + + const auto quotedList = node.GetChild(1); + if (quotedList->GetChildrenCount() != 2 || !quotedList->GetChild(0)->IsAtom() || + quotedList->GetChild(0)->GetContent() != TStringBuf("quote")) { + ctx.AddError(node, "Expected quoted list"); + return {}; + } + + return CompileFunction(*quotedList->GetChild(1), ctx); + } + + TExprNode::TListType children; + children.reserve(node.GetChildrenCount() - 1U); + for (auto index = 1U; index < node.GetChildrenCount(); ++index) { + auto r = Compile(*node.GetChild(index), ctx); + if (r.empty()) + return {}; + + std::move(r.begin(), r.end(), std::back_inserter(children)); + } + + return {ctx.ProcessNode(node, ctx.Expr.NewCallable(node.GetPosition(), TString(function), std::move(children)))}; + } + + struct TFrameContext { + size_t Index = 0; + size_t Parent = 0; + std::map<size_t, const TExprNode*> Nodes; + std::vector<const TExprNode*> TopoSortedNodes; + TNodeMap<TString> Bindings; + }; + + struct TVisitNodeContext { + explicit TVisitNodeContext(TExprContext& expr) + : Expr(expr) + {} + + TExprContext& Expr; + size_t Order = 0ULL; + bool RefAtoms = false; + bool AllowFreeArgs = false; + bool NormalizeAtomFlags = false; + TNodeMap<size_t> FreeArgs; + std::unique_ptr<TMemoryPool> Pool; + std::vector<TFrameContext> Frames; + TFrameContext* CurrentFrame = nullptr; + TNodeMap<size_t> LambdaFrames; + std::map<TStringBuf, std::pair<const TExprNode*, TAstNode*>> Parameters; + + struct TCounters { + size_t References = 0ULL, Neighbors = 0ULL, Order = 0ULL, Frame = 0ULL; + }; + + TNodeMap<TCounters> References; + + const TString& FindBinding(const TExprNode* node) const { + for (const auto* frame = CurrentFrame; frame; frame = frame->Index > 0 ? &Frames[frame->Parent] : nullptr) { + const auto it = frame->Bindings.find(node); + if (frame->Bindings.cend() != it) + return it->second; + } + + static const TString stub; + return stub; + } + + size_t FindCommonAncestor(size_t one, size_t two) const { + while (one && two) { + if (one == two) + return one; + if (one > two) + one = Frames[one].Parent; + else + two = Frames[two].Parent; + } + + return 0ULL; + } + }; + + void VisitArguments(const TExprNode& node, TVisitNodeContext& ctx) { + YQL_ENSURE(node.Type() == TExprNode::Arguments); + for (const auto& arg : node.Children()) { + auto& counts = ctx.References[arg.Get()]; + ++counts.References; + YQL_ENSURE(ctx.CurrentFrame->Nodes.emplace(counts.Order = ++ctx.Order, arg.Get()).second); + } + } + + void RevisitNode(const TExprNode& node, TVisitNodeContext& ctx); + + void RevisitNode(TVisitNodeContext::TCounters& counts, const TExprNode& node, TVisitNodeContext& ctx) { + const auto nf = ctx.FindCommonAncestor(ctx.CurrentFrame->Index, counts.Frame); + if (counts.Frame != nf) { + auto& frame = ctx.Frames[counts.Frame = nf]; + frame.Nodes.emplace(counts.Order, &node); + if (TExprNode::Lambda == node.Type()) { + ctx.Frames[ctx.LambdaFrames[&node]].Parent = counts.Frame; + } else { + node.ForEachChild([&ctx](const TExprNode& child) { + RevisitNode(child, ctx); + }); + } + } + } + + void RevisitNode(const TExprNode& node, TVisitNodeContext& ctx) { + if (TExprNode::Argument != node.Type()) { + RevisitNode(ctx.References[&node], node, ctx); + } + } + + void VisitNode(const TExprNode& node, size_t neighbors, TVisitNodeContext& ctx) { + if (TExprNode::Argument == node.Type()) + return; + + auto& counts = ctx.References[&node]; + counts.Neighbors += neighbors; + if (counts.References++) { + RevisitNode(counts, node, ctx); + } else { + counts.Frame = ctx.CurrentFrame->Index; + + if (node.Type() == TExprNode::Lambda) { + YQL_ENSURE(node.ChildrenSize() > 0U); + const auto index = ctx.Frames.size(); + if (ctx.LambdaFrames.emplace(&node, index).second) { + const auto prevFrameIndex = ctx.CurrentFrame - &ctx.Frames.front(); + const auto parentIndex = ctx.CurrentFrame->Index; + ctx.Frames.emplace_back(); + ctx.CurrentFrame = &ctx.Frames.back(); + ctx.CurrentFrame->Index = index; + ctx.CurrentFrame->Parent = parentIndex; + VisitArguments(node.Head(), ctx); + for(ui32 i = 1U; i < node.ChildrenSize(); ++i) { + VisitNode(*node.Child(i), node.ChildrenSize() - 1U, ctx); + } + ctx.CurrentFrame = &ctx.Frames.front() + prevFrameIndex; + } + } else { + node.ForEachChild([&](const TExprNode& child) { + VisitNode(child, node.ChildrenSize(), ctx); + }); + } + + if (!counts.Order) + counts.Order = ++ctx.Order; + + ctx.CurrentFrame->Nodes.emplace(counts.Order, &node); + } + } + + using TRoots = TSmallVec<const TExprNode*>; + + TAstNode* ConvertFunction(TPositionHandle position, const TRoots& roots, TVisitNodeContext& ctx, ui32 annotationFlags, TMemoryPool& pool); + + TAstNode* BuildValueNode(const TExprNode& node, TVisitNodeContext& ctx, const TString& topLevelName, ui32 annotationFlags, TMemoryPool& pool, bool useBindings) { + TAstNode* res = nullptr; + const auto& name = ctx.FindBinding(&node); + if (!name.empty() && name != topLevelName && useBindings) { + res = TAstNode::NewAtom(ctx.Expr.GetPosition(node.Pos()), name, pool); + } else { + switch (node.Type()) { + case TExprNode::Atom: + { + auto quote = AnnotateAstNode(&TAstNode::QuoteAtom, nullptr, annotationFlags, pool, ctx.RefAtoms); + auto flags = ctx.NormalizeAtomFlags ? TNodeFlags::ArbitraryContent : node.Flags(); + auto content = AnnotateAstNode( + ctx.RefAtoms ? + TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(node.Pos()), node.Content(), pool, flags) : + TAstNode::NewAtom(ctx.Expr.GetPosition(node.Pos()), node.Content(), pool, flags), + &node, annotationFlags, pool, ctx.RefAtoms); + + res = TAstNode::NewList(ctx.Expr.GetPosition(node.Pos()), pool, quote, content); + break; + } + + case TExprNode::List: + { + TSmallVec<TAstNode*> values; + for (const auto& child : node.Children()) { + values.push_back(BuildValueNode(*child, ctx, topLevelName, annotationFlags, pool, useBindings)); + } + + auto quote = AnnotateAstNode(&TAstNode::QuoteAtom, nullptr, annotationFlags, pool, ctx.RefAtoms); + auto list = AnnotateAstNode(TAstNode::NewList( + ctx.Expr.GetPosition(node.Pos()), values.data(), values.size(), pool), &node, annotationFlags, pool, ctx.RefAtoms); + + res = TAstNode::NewList(ctx.Expr.GetPosition(node.Pos()), pool, quote, list); + break; + } + + case TExprNode::Callable: + { + if (node.Content() == "Parameter") { + const auto& nameNode = *node.Child(0); + const auto& typeNode = *node.Child(1); + Y_UNUSED(typeNode); + + res = TAstNode::NewAtom(ctx.Expr.GetPosition(node.Pos()), nameNode.Content(), pool); + + auto it = ctx.Parameters.find(nameNode.Content()); + if (it != ctx.Parameters.end()) { + break; + } + + auto declareAtom = TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(node.Pos()), TStringBuf("declare"), pool); + auto nameAtom = ctx.RefAtoms + ? TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(nameNode.Pos()), nameNode.Content(), pool) + : TAstNode::NewAtom(ctx.Expr.GetPosition(nameNode.Pos()), nameNode.Content(), pool); + + TSmallVec<TAstNode*> children; + children.push_back(AnnotateAstNode(declareAtom, nullptr, annotationFlags, pool, ctx.RefAtoms)); + children.push_back(AnnotateAstNode(nameAtom, nullptr, annotationFlags, pool, ctx.RefAtoms)); + children.push_back(BuildValueNode(typeNode, ctx, topLevelName, annotationFlags, pool, false)); + auto declareNode = TAstNode::NewList(ctx.Expr.GetPosition(node.Pos()), children.data(), children.size(), pool); + declareNode = AnnotateAstNode(declareNode, &node, annotationFlags, pool, ctx.RefAtoms); + + ctx.Parameters.insert(std::make_pair(nameNode.Content(), + std::make_pair(&typeNode, declareNode))); + break; + } + + TSmallVec<TAstNode*> children; + children.push_back(AnnotateAstNode( + ctx.RefAtoms ? + TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(node.Pos()), node.Content(), pool) : + TAstNode::NewAtom(ctx.Expr.GetPosition(node.Pos()), node.Content(), pool), + nullptr, annotationFlags, pool, ctx.RefAtoms)); + for (const auto& child : node.Children()) { + children.push_back(BuildValueNode(*child, ctx, topLevelName, annotationFlags, pool, useBindings)); + } + + res = TAstNode::NewList(ctx.Expr.GetPosition(node.Pos()), children.data(), children.size(), pool); + break; + } + + case TExprNode::Lambda: + { + const auto prevFrame = ctx.CurrentFrame; + const auto it = ctx.LambdaFrames.find(&node); + YQL_ENSURE(it != ctx.LambdaFrames.end()); + ctx.CurrentFrame = &ctx.Frames[it->second]; + YQL_ENSURE(node.ChildrenSize() > 0U); + const auto& args = node.Head(); + TSmallVec<TAstNode*> argsChildren; + for (const auto& arg : args.Children()) { + const auto& name = ctx.FindBinding(arg.Get()); + const auto atom = TAstNode::NewAtom(ctx.Expr.GetPosition(node.Pos()), name, pool); + argsChildren.emplace_back(AnnotateAstNode(atom, arg.Get(), annotationFlags, pool, ctx.RefAtoms)); + } + + auto argsNode = TAstNode::NewList(ctx.Expr.GetPosition(args.Pos()), argsChildren.data(), argsChildren.size(), pool); + auto argsContainer = TAstNode::NewList(ctx.Expr.GetPosition(args.Pos()), pool, + AnnotateAstNode(&TAstNode::QuoteAtom, nullptr, annotationFlags, pool, ctx.RefAtoms), + AnnotateAstNode(argsNode, nullptr, annotationFlags, pool, ctx.RefAtoms)); + + const bool block = ctx.CurrentFrame->Bindings.cend() != std::find_if(ctx.CurrentFrame->Bindings.cbegin(), ctx.CurrentFrame->Bindings.cend(), + [](const auto& bind) { return bind.first->Type() != TExprNode::Argument; } + ); + + if (block) { + TSmallVec<const TExprNode*> body(node.ChildrenSize() - 1U); + for (ui32 i = 0U; i < body.size(); ++i) + body[i] = node.Child(i + 1U); + const auto blockNode = TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(node.Pos()), TStringBuf("block"), pool); + const auto quotedListNode = TAstNode::NewList(ctx.Expr.GetPosition(node.Pos()), pool, + AnnotateAstNode(&TAstNode::QuoteAtom, nullptr, annotationFlags, pool, ctx.RefAtoms), + ConvertFunction(node.Pos(), body, ctx, annotationFlags, pool)); + + const auto blockBody = TAstNode::NewList(ctx.Expr.GetPosition(node.Pos()), pool, + AnnotateAstNode(blockNode, nullptr, annotationFlags, pool, ctx.RefAtoms), + AnnotateAstNode(quotedListNode, nullptr, annotationFlags, pool, ctx.RefAtoms)); + res = AnnotateAstNode(blockBody, nullptr, annotationFlags, pool, ctx.RefAtoms); + + ctx.CurrentFrame = prevFrame; + res = TAstNode::NewList(ctx.Expr.GetPosition(node.Pos()), pool, + AnnotateAstNode(TAstNode::NewLiteralAtom( + ctx.Expr.GetPosition(node.Pos()), TStringBuf("lambda"), pool), nullptr, annotationFlags, pool, ctx.RefAtoms), + AnnotateAstNode(argsContainer, &args, annotationFlags, pool, ctx.RefAtoms), + res); + } else { + TSmallVec<TAstNode*> children(node.ChildrenSize() + 1U); + for (ui32 i = 1U; i < node.ChildrenSize(); ++i) { + children[i + 1U] = BuildValueNode(*node.Child(i), ctx, topLevelName, annotationFlags, pool, useBindings); + } + + ctx.CurrentFrame = prevFrame; + children[0] = AnnotateAstNode(TAstNode::NewLiteralAtom( + ctx.Expr.GetPosition(node.Pos()), TStringBuf("lambda"), pool), nullptr, annotationFlags, pool, ctx.RefAtoms); + children[1] = AnnotateAstNode(argsContainer, &args, annotationFlags, pool, ctx.RefAtoms); + res = TAstNode::NewList(ctx.Expr.GetPosition(node.Pos()), children.data(), children.size(), pool); + } + break; + } + + case TExprNode::World: + res = TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(node.Pos()), TStringBuf("world"), pool); + break; + case TExprNode::Argument: { + YQL_ENSURE(ctx.AllowFreeArgs, "Free arguments are not allowed"); + auto iter = ctx.FreeArgs.emplace(&node, ctx.FreeArgs.size()); + res = TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(node.Pos()), ctx.Expr.AppendString("_FreeArg" + ToString(iter.first->second)), pool); + break; + } + default: + YQL_ENSURE(false, "Unknown type: " << static_cast<ui32>(node.Type())); + } + } + + return AnnotateAstNode(res, &node, annotationFlags, pool, ctx.RefAtoms); + } + + TAstNode* ConvertFunction(TPositionHandle position, const TRoots& roots, TVisitNodeContext& ctx, ui32 annotationFlags, TMemoryPool& pool) { + YQL_ENSURE(!roots.empty(), "Missed roots."); + TSmallVec<TAstNode*> children; + for (const auto& node : ctx.CurrentFrame->TopoSortedNodes) { + const auto& name = ctx.FindBinding(node); + if (name.empty() || node->Type() == TExprNode::Arguments || node->Type() == TExprNode::Argument) { + continue; + } + + const auto letAtom = TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(node->Pos()), TStringBuf("let"), pool); + const auto nameAtom = TAstNode::NewAtom(ctx.Expr.GetPosition(node->Pos()), name, pool); + const auto valueNode = BuildValueNode(*node, ctx, name, annotationFlags, pool, true); + + const auto letNode = TAstNode::NewList(ctx.Expr.GetPosition(node->Pos()), pool, + AnnotateAstNode(letAtom, nullptr, annotationFlags, pool, ctx.RefAtoms), + AnnotateAstNode(nameAtom, nullptr, annotationFlags, pool, ctx.RefAtoms), + valueNode); + children.push_back(AnnotateAstNode(letNode, nullptr, annotationFlags, pool, ctx.RefAtoms)); + } + + const auto returnAtom = TAstNode::NewLiteralAtom(ctx.Expr.GetPosition(position), TStringBuf("return"), pool); + TSmallVec<TAstNode*> returnChildren; + returnChildren.reserve(roots.size() + 1U); + returnChildren.emplace_back(AnnotateAstNode(returnAtom, nullptr, annotationFlags, pool, ctx.RefAtoms)); + for (const auto root : roots) { + returnChildren.emplace_back(BuildValueNode(*root, ctx, TString(), annotationFlags, pool, true)); + } + const auto returnList = TAstNode::NewList(ctx.Expr.GetPosition(position), returnChildren.data(), returnChildren.size(), pool); + children.emplace_back(AnnotateAstNode(returnList, 1U == roots.size() ? roots.front() : nullptr, annotationFlags, pool, ctx.RefAtoms)); + + if (!ctx.CurrentFrame->Index && !ctx.Parameters.empty()) { + TSmallVec<TAstNode*> parameterNodes; + parameterNodes.reserve(ctx.Parameters.size()); + + for (auto& pair : ctx.Parameters) { + parameterNodes.push_back(pair.second.second); + } + + children.insert(children.begin(), parameterNodes.begin(), parameterNodes.end()); + } + + const auto res = TAstNode::NewList(ctx.Expr.GetPosition(position), children.data(), children.size(), pool); + return AnnotateAstNode(res, nullptr, annotationFlags, pool, ctx.RefAtoms); + } + + bool InlineNode(const TExprNode& node, size_t references, size_t neighbors, const TConvertToAstSettings& settings) { + if (settings.NoInlineFunc) { + if (settings.NoInlineFunc(node)) { + return false; + } + } + + switch (node.Type()) { + case TExprNode::Argument: + return false; + case TExprNode::Atom: + if (const auto flags = node.Flags()) { + if ((TNodeFlags::BinaryContent | TNodeFlags::MultilineContent) & flags) + return false; + else { + if (TNodeFlags::ArbitraryContent & flags) + return node.Content().length() <= (references == 1U ? 0x40U : 0x10U); + else + return true; + } + } else + return true; + default: + if (neighbors < 2U) + return true; + if (const auto children = node.ChildrenSize()) + return references == 1U && children < 3U; + else + return true; + } + } + + typedef std::pair<const TExprNode*, const TExprNode*> TPairOfNodePotinters; + typedef std::unordered_set<TPairOfNodePotinters, THash<TPairOfNodePotinters>> TNodesPairSet; + typedef TNodeMap<std::pair<ui32, ui32>> TArgumentsMap; + + bool CompareExpressions(const TExprNode*& one, const TExprNode*& two, TArgumentsMap& argumentsMap, ui32 level, TNodesPairSet& visited) { + const auto ins = visited.emplace(one, two); + if (!ins.second) { + return true; + } + + if (one->Type() != two->Type()) + return false; + + if (one->ChildrenSize() != two->ChildrenSize()) + return false; + + switch (two->Type()) { + case TExprNode::Arguments: { + ui32 i1 = 0U, i2 = 0U; + one->ForEachChild([&](const TExprNode& arg){ argumentsMap.emplace(&arg, std::make_pair(level, ++i1)); }); + two->ForEachChild([&](const TExprNode& arg){ argumentsMap.emplace(&arg, std::make_pair(level, ++i2)); }); + return true; + } + case TExprNode::Argument: + if (const auto oneArg = argumentsMap.find(one), twoArg = argumentsMap.find(two); oneArg == twoArg) + return argumentsMap.cend() != oneArg || one == two; + else if (argumentsMap.cend() != oneArg && argumentsMap.cend() != twoArg) { + return oneArg->second == twoArg->second; + } + return false; + case TExprNode::Atom: + if (one->GetFlagsToCompare() != two->GetFlagsToCompare()) + return false; + [[fallthrough]]; // AUTOGENERATED_FALLTHROUGH_FIXME + case TExprNode::Callable: + if (one->Content() != two->Content()) + return false; + [[fallthrough]]; // AUTOGENERATED_FALLTHROUGH_FIXME + default: + break; + case TExprNode::Lambda: + ++level; + } + + if (const auto childs = one->ChildrenSize()) { + const auto& l = one->Children(); + const auto& r = two->Children(); + for (ui32 i = 0U; i < childs; ++i) { + if (!CompareExpressions(one = l[i].Get(), two = r[i].Get(), argumentsMap, level, visited)) { + return false; + } + } + } + + return true; + } + + using TNodeSetPtr = std::shared_ptr<TNodeSet>; + + TNodeSetPtr ExcludeFromUnresolved(const TExprNode& args, const TNodeSetPtr& unresolved) { + if (!unresolved || unresolved->empty() || args.ChildrenSize() == 0) { + return unresolved; + } + + size_t excluded = 0; + auto newUnresolved = std::make_shared<TNodeSet>(*unresolved); + for (auto& toExclude : args.Children()) { + excluded += newUnresolved->erase(toExclude.Get()); + } + + return excluded ? newUnresolved : unresolved; + } + + TNodeSetPtr MergeUnresolvedArgs(const TNodeSetPtr& one, const TNodeSetPtr& two) { + if (!one || one->empty()) { + return two; + } + + if (!two || two->empty()) { + return one; + } + + const TNodeSetPtr& bigger = (one->size() > two->size()) ? one : two; + const TNodeSetPtr& smaller = (one->size() > two->size()) ? two : one; + + TNodeSetPtr result = std::make_shared<TNodeSet>(*bigger); + + bool inserted = false; + for (auto& item : *smaller) { + if (result->insert(item).second) { + inserted = true; + } + } + + return inserted ? result : bigger; + } + + TNodeSetPtr CollectUnresolvedArgs(const TExprNode& root, TNodeMap<TNodeSetPtr>& unresolvedArgs, TNodeSet& allArgs) { + auto it = unresolvedArgs.find(&root); + if (it != unresolvedArgs.end()) { + return it->second; + } + + TNodeSetPtr result; + switch (root.Type()) { + case TExprNode::Argument: + result = std::make_shared<TNodeSet>(TNodeSet{&root}); + break; + case TExprNode::Lambda: + { + if (!root.ChildrenSize()) { + ythrow yexception() << "lambda #" << root.UniqueId() << " has " << root.ChildrenSize() << " children"; + } + + const auto& arguments = root.Head(); + if (arguments.Type() != TExprNode::Arguments) { + ythrow yexception() << "unexpected type of arguments node in lambda #" << root.UniqueId(); + } + + arguments.ForEachChild([&](const TExprNode& arg) { + if (arg.Type() != TExprNode::Argument) { + ythrow yexception() << "expecting argument, #[" << arg.UniqueId() << "]"; + } + if (!allArgs.insert(&arg).second) { + ythrow yexception() << "argument is duplicated, #[" << arg.UniqueId() << "]"; + } + }); + + for (ui32 i = 1U; i < root.ChildrenSize(); ++i) { + const auto bodyUnresolvedArgs = CollectUnresolvedArgs(*root.Child(i), unresolvedArgs, allArgs); + result = ExcludeFromUnresolved(arguments, bodyUnresolvedArgs); + } + break; + } + case TExprNode::Callable: + case TExprNode::List: + { + root.ForEachChild([&](const TExprNode& child) { + result = MergeUnresolvedArgs(result, CollectUnresolvedArgs(child, unresolvedArgs, allArgs)); + }); + break; + } + case TExprNode::Atom: + case TExprNode::World: + break; + case TExprNode::Arguments: + ythrow yexception() << "unexpected free arguments node #[" << root.UniqueId() << "]"; + break; + } + + unresolvedArgs[&root] = result; + return result; + } + + typedef TNodeMap<long> TRefCountsMap; + + + void CalculateReferences(const TExprNode& node, TRefCountsMap& refCounts) { + if (!refCounts[&node]++) + for (const auto& child : node.Children()) + CalculateReferences(*child, refCounts); + } + + void CheckReferences(const TExprNode& node, TRefCountsMap& refCounts, TNodeSet& visited) { + if (visited.emplace(&node).second) { + for (const auto& child : node.Children()) { + YQL_ENSURE(child->UseCount() == refCounts[child.Get()]); + CheckReferences(*child, refCounts, visited); + } + } + } + + bool GatherParentsImpl(const TExprNode& node, TParentsMap& parentsMap, TNodeSet& visited) { + if (node.Type() == TExprNode::Arguments || node.Type() == TExprNode::Atom || node.Type() == TExprNode::World) { + return false; + } + + if (!visited.emplace(&node).second) { + return true; + } + + node.ForEachChild([&](const TExprNode& child) { + if (GatherParentsImpl(child, parentsMap, visited)) { + parentsMap[&child].emplace(&node); + } + }); + + return true; + } + +} // namespace + +bool CompileExpr(TAstNode& astRoot, TExprNode::TPtr& exprRoot, TExprContext& ctx, + IModuleResolver* resolver, IUrlListerManager* urlListerManager, + bool hasAnnotations, ui32 typeAnnotationIndex, ui16 syntaxVersion) { + exprRoot.Reset(); + TAstNode* cleanRoot = nullptr; + TAnnotationNodeMap annotations; + const TAnnotationNodeMap* currentAnnotations = nullptr; + TAstParseResult cleanupRes; + if (!hasAnnotations) { + typeAnnotationIndex = Max<ui32>(); + cleanRoot = &astRoot; + currentAnnotations = nullptr; + } else if (typeAnnotationIndex != Max<ui32>()) { + cleanupRes.Pool = std::make_unique<TMemoryPool>(4096); + cleanRoot = ExtractAnnotations(astRoot, annotations, *cleanupRes.Pool); + cleanupRes.Root = cleanRoot; + currentAnnotations = &annotations; + } else { + cleanupRes.Pool = std::make_unique<TMemoryPool>(4096); + cleanRoot = RemoveAnnotations(astRoot, *cleanupRes.Pool); + cleanupRes.Root = cleanRoot; + currentAnnotations = nullptr; + } + + if (!cleanRoot) { + return false; + } + + TContext compileCtx(ctx); + compileCtx.SyntaxVersion = syntaxVersion; + compileCtx.Annotations = currentAnnotations; + compileCtx.TypeAnnotationIndex = typeAnnotationIndex; + compileCtx.ModuleResolver = resolver; + compileCtx.UrlListerManager = urlListerManager; + compileCtx.PushFrame(); + auto world = compileCtx.Expr.NewWorld(astRoot.GetPosition()); + if (typeAnnotationIndex != Max<ui32>()) { + world->SetTypeAnn(compileCtx.Expr.MakeType<TWorldExprType>()); + } + + compileCtx.Frames.back().Bindings[TStringBuf("world")] = {std::move(world)}; + auto ret = CompileFunction(*cleanRoot, compileCtx, true); + if (1U != ret.size()) + return false; + exprRoot = std::move(ret.front()); + compileCtx.PopFrame(); + return bool(exprRoot); +} + +bool CompileExpr(TAstNode& astRoot, TExprNode::TPtr& exprRoot, TExprContext& ctx, + IModuleResolver* resolver, IUrlListerManager* urlListerManager, + ui32 annotationFlags, ui16 syntaxVersion) +{ + bool hasAnnotations = annotationFlags != TExprAnnotationFlags::None; + ui32 typeAnnotationIndex = Max<ui32>(); + if (annotationFlags & TExprAnnotationFlags::Types) { + bool hasPostions = annotationFlags & TExprAnnotationFlags::Position; + typeAnnotationIndex = hasPostions ? 1 : 0; + } + + return CompileExpr(astRoot, exprRoot, ctx, resolver, urlListerManager, hasAnnotations, typeAnnotationIndex, syntaxVersion); +} + +bool CompileExpr(TAstNode& astRoot, TLibraryCohesion& library, TExprContext& ctx, ui16 syntaxVersion) { + const TAstNode* cleanRoot = &astRoot; + TContext compileCtx(ctx); + compileCtx.Annotations = nullptr; + compileCtx.TypeAnnotationIndex = Max<ui32>(); + compileCtx.SyntaxVersion = syntaxVersion; + const bool ok = CompileLibrary(*cleanRoot, compileCtx); + library = compileCtx.Cohesion; + return ok; +} + +const TTypeAnnotationNode* CompileTypeAnnotation(const TAstNode& node, TExprContext& ctx) { + TContext compileCtx(ctx); + return compileCtx.CompileTypeAnnotationNode(node); +} + +template<class Set> +bool IsDependedImpl(const TExprNode& node, const Set& dependences, TNodeSet& visited) { + if (!visited.emplace(&node).second) + return false; + + if (dependences.cend() != dependences.find(&node)) + return true; + + for (const auto& child : node.Children()) { + if (IsDependedImpl(*child, dependences, visited)) + return true; + } + + return false; +} + +namespace { + +enum EChangeState : ui8 { + Unknown = 0, + Changed = 1, + Unchanged = 2 +}; + +ui64 CalcBloom(const ui64 id) { + return 1ULL | + (2ULL << (std::hash<ui64>()(id) % 63ULL)) | + (2ULL << (IntHash<ui64>(id) % 63ULL)) | + (2ULL << (FnvHash<ui64>(&id, sizeof(id)) % 63ULL)) | + (2ULL << (MurmurHash<ui64>(&id, sizeof(id)) % 63ULL)) | + (2ULL << (CityHash64(reinterpret_cast<const char*>(&id), sizeof(id)) % 63ULL)); +} + +inline bool InBloom(const ui64 set, const ui64 bloom) { + return (bloom >> 1) == ((bloom & set) >> 1); +} + +EChangeState GetChanges(TExprNode* start, const TNodeOnNodeOwnedMap& replaces, const TNodeMap<TNodeOnNodeOwnedMap>& localReplaces, + TNodeMap<EChangeState>& changes, TNodeMap<bool>& updatedLambdas); + +EChangeState DoGetChanges(TExprNode* start, const TNodeOnNodeOwnedMap& replaces, const TNodeMap<TNodeOnNodeOwnedMap>& localReplaces, + TNodeMap<EChangeState>& changes, TNodeMap<bool>& updatedLambdas) { + + if (start->GetBloom() & 1ULL) { + bool maybe = false; + for (const auto& repl : replaces) { + if (repl.second && !repl.first->Dead()) { + if (TExprNode::Argument != repl.first->Type()) { + maybe = true; + break; + } + + if (!repl.first->GetBloom()) + const_cast<TExprNode*>(repl.first)->SetBloom(CalcBloom(repl.first->UniqueId())); + + if (InBloom(start->GetBloom(), repl.first->GetBloom())) { + maybe = true; + break; + } + } + } + + if (!maybe) { + return EChangeState::Unchanged; + } + } + + start->SetBloom(1ULL); + ui32 combinedState = EChangeState::Unchanged; + bool incompleteBloom = false; + start->ForEachChild([&](TExprNode& child) { + combinedState |= GetChanges(&child, replaces, localReplaces, changes, updatedLambdas); + start->SetBloom(start->GetBloom() | child.GetBloom()); + incompleteBloom = incompleteBloom || (child.Type() != TExprNode::Arguments && !child.GetBloom()); + }); + if (incompleteBloom) { + start->SetBloom(0ULL); + } + + return (EChangeState)combinedState; +} + +EChangeState GetChanges(TExprNode* start, const TNodeOnNodeOwnedMap& replaces, const TNodeMap<TNodeOnNodeOwnedMap>& localReplaces, + TNodeMap<EChangeState>& changes, TNodeMap<bool>& updatedLambdas) { + if (start->Type() == TExprNode::Arguments) { + return EChangeState::Unchanged; + } + + if (!start->GetBloom() && TExprNode::Argument == start->Type()) { + start->SetBloom(CalcBloom(start->UniqueId())); + } + + auto& state = changes[start]; + if (state != EChangeState::Unknown) { + return state; + } + + if (const auto it = replaces.find(start); it != replaces.cend()) { + return state = it->second ? EChangeState::Changed : EChangeState::Unchanged; + } + + if (start->ChildrenSize() == 0) { + return state = EChangeState::Unchanged; + } + + if (start->Type() == TExprNode::Lambda) { + TNodeOnNodeOwnedMap newReplaces = replaces; + + start->Head().ForEachChild([&](const TExprNode& arg){ newReplaces[&arg] = {}; }); + + const auto locIt = localReplaces.find(start); + if (locIt != localReplaces.end()) { + for (auto& r: locIt->second) { + newReplaces[r.first] = r.second; + } + } + + state = DoGetChanges(start, newReplaces, localReplaces, changes, updatedLambdas); + + if ((state & EChangeState::Changed) != 0) { + updatedLambdas.emplace(start, false); + } + + return state; + } + + return state = DoGetChanges(start, replaces, localReplaces, changes, updatedLambdas); +} + +template<bool KeepTypeAnns> +TExprNode::TPtr DoReplace(const TExprNode::TPtr& start, const TNodeOnNodeOwnedMap& replaces, + const TNodeOnNodeOwnedMap& argReplaces, const TNodeMap<TNodeOnNodeOwnedMap>& localReplaces, + TNodeMap<EChangeState>& changes, TNodeOnNodeOwnedMap& processed, TExprContext& ctx) +{ + auto& target = processed[start.Get()]; + if (target) { + return target; + } + + TMaybe<TExprNode::TPtr> replace; + const auto it = replaces.find(start.Get()); + if (it != replaces.end()) { + replace = it->second; + } + const auto argIt = argReplaces.find(start.Get()); + if (argIt != argReplaces.end()) { + YQL_ENSURE(!replace.Defined()); + replace = argIt->second; + } + + if (replace.Defined()) { + if (*replace) { + return target = ctx.ReplaceNodes(std::move(*replace), argReplaces); + } + + return target = start; + } + + if (start->ChildrenSize() != 0) { + auto changeIt = changes.find(start.Get()); + YQL_ENSURE(changeIt != changes.end(), "Missing change"); + const bool isChanged = (changeIt->second & EChangeState::Changed) != 0; + if (isChanged) { + if (start->Type() == TExprNode::Lambda) { + TNodeOnNodeOwnedMap newArgReplaces = argReplaces; + const auto locIt = localReplaces.find(start.Get()); + YQL_ENSURE(locIt != localReplaces.end(), "Missing local changes"); + for (auto& r: locIt->second) { + newArgReplaces[r.first] = r.second; + } + + const auto& args = start->Head(); + TExprNode::TListType newArgsList; + newArgsList.reserve(args.ChildrenSize()); + args.ForEachChild([&](const TExprNode& arg) { + const auto argIt = newArgReplaces.find(&arg); + YQL_ENSURE(argIt != newArgReplaces.end(), "Missing argument"); + processed.emplace(&arg, argIt->second); + newArgsList.emplace_back(argIt->second); + }); + + auto newBody = GetLambdaBody(*start); + std::for_each(newBody.begin(), newBody.end(), [&](TExprNode::TPtr& node) { + node = DoReplace<KeepTypeAnns>(node, replaces, newArgReplaces, localReplaces, + changes, processed, ctx); + }); + auto newArgs = ctx.NewArguments(start->Pos(), std::move(newArgsList)); + if constexpr (KeepTypeAnns) + newArgs->SetTypeAnn(ctx.MakeType<TUnitExprType>()); + target = ctx.NewLambda(start->Pos(), std::move(newArgs), std::move(newBody)); + if constexpr (KeepTypeAnns) + target->SetTypeAnn(start->GetTypeAnn()); + return target; + } else { + bool replaced = false; + TExprNode::TListType newChildren; + newChildren.reserve(start->ChildrenSize()); + for (const auto& child : start->Children()) { + auto newChild = DoReplace<KeepTypeAnns>(child, replaces, argReplaces, localReplaces, + changes, processed, ctx); + if (newChild != child) + replaced = true; + + newChildren.emplace_back(std::move(newChild)); + } + + if (replaced) { + target = ctx.ChangeChildren(*start, std::move(newChildren)); + if constexpr (KeepTypeAnns) + target->SetTypeAnn(start->GetTypeAnn()); + return target; + } + } + } + } + + return target = start; +} + +void EnsureNoBadReplaces(const TExprNode& start, const TNodeOnNodeOwnedMap& replaces, TNodeSet&& visited = TNodeSet()) { + if (!visited.insert(&start).second) { + return; + } + + const auto it = replaces.find(&start); + if (it != replaces.end() && it->second) { + ythrow yexception() << "Bad replace for node: " << start.UniqueId() << "\n"; + } + + if (start.Type() == TExprNode::Lambda) { + TNodeOnNodeOwnedMap newReplaces = replaces; + start.Head().ForEachChild([&](const TExprNode& arg){ newReplaces[&arg] = {}; }); + start.ForEachChild([&](const TExprNode& child){ EnsureNoBadReplaces(child, newReplaces, std::move(visited)); }); + } else { + start.ForEachChild([&](const TExprNode& child){ EnsureNoBadReplaces(child, replaces, std::move(visited)); }); + } +} + +const bool InternalDebug = false; + +template<bool KeepTypeAnns> +TExprNode::TPtr ReplaceNodesImpl(TExprNode::TPtr&& start, const TNodeOnNodeOwnedMap& replaces, TNodeOnNodeOwnedMap& processed, TExprContext& ctx) { + if (InternalDebug) { + Cerr << "Before\n" << start->Dump() << "\n"; + Cerr << "Replaces\n"; + ui32 rep = 0; + for (auto& x : replaces) { + if (x.second) { + Cerr << "#" << ++rep << " " << x.first->Dump() << "\n into " << x.second->Dump() << "\n"; + } + } + } + + TNodeMap<EChangeState> changes; + TNodeMap<bool> updatedLambdas; + TNodeMap<TNodeOnNodeOwnedMap> localReplaces; + if ((GetChanges(start.Get(), replaces, localReplaces, changes, updatedLambdas) & EChangeState::Changed) == 0) { + return std::move(start); + } + + if (!updatedLambdas.empty()) { + for (;;) { + changes.clear(); + for (auto& x : updatedLambdas) { + if (!x.second) { + TNodeOnNodeOwnedMap& lambdaReplaces = localReplaces[x.first]; + const auto& args = x.first->Head(); + args.ForEachChild([&](const TExprNode& arg) { + const auto newArg = lambdaReplaces.emplace(&arg, ctx.ShallowCopy(arg)).first->second; + if constexpr (KeepTypeAnns) + newArg->SetTypeAnn(arg.GetTypeAnn()); + }); + x.second = true; + } + } + + auto prevSize = updatedLambdas.size(); + GetChanges(start.Get(), replaces, localReplaces, changes, updatedLambdas); + if (updatedLambdas.size() == prevSize) { + break; + } + } + } + + auto ret = DoReplace<KeepTypeAnns>(start, replaces, {}, localReplaces, changes, processed, ctx); + if (InternalDebug) { + Cerr << "After\n" << ret->Dump() << "\n"; + EnsureNoBadReplaces(*ret, replaces); + } + + return ret; +} + +} + +TExprNode::TPtr TExprContext::ReplaceNode(TExprNode::TPtr&& start, const TExprNode& src, TExprNode::TPtr dst) { + if (start->Type() == TExprNode::Lambda) { + const auto& args = start->Head(); + auto body = GetLambdaBody(*start); + std::optional<ui32> argIndex; + for (ui32 index = 0U; index < args.ChildrenSize(); ++index) { + const auto arg = args.Child(index); + if (arg == &src) { + if (argIndex) { + ythrow yexception() << "argument is duplicated, #[" << arg->UniqueId() << "]"; + } + + argIndex = index; + } + } + + if (argIndex) { + TExprNode::TListType newArgNodes; + newArgNodes.reserve(args.ChildrenSize()); + TNodeOnNodeOwnedMap replaces(args.ChildrenSize()); + + for (ui32 i = 0U; i < args.ChildrenSize(); ++i) { + const auto arg = args.Child(i); + auto newArg = (i == *argIndex) ? dst : ShallowCopy(*arg); + YQL_ENSURE(replaces.emplace(arg, newArg).second); + newArgNodes.emplace_back(std::move(newArg)); + } + + return NewLambda(start->Pos(), NewArguments(args.Pos(), std::move(newArgNodes)), ReplaceNodes<false>(std::move(body), replaces)); + } + } else if (&src == start) { + return dst; + } + + return ReplaceNodes(std::move(start), {{&src, std::move(dst)}}); +} + +TExprNode::TPtr TExprContext::ReplaceNodes(TExprNode::TPtr&& start, const TNodeOnNodeOwnedMap& replaces) { + TNodeOnNodeOwnedMap processed; + return replaces.empty() ? std::move(start) : ReplaceNodesImpl<false>(std::move(start), replaces, processed, *this); +} + +template<bool KeepTypeAnns> +TExprNode::TListType TExprContext::ReplaceNodes(TExprNode::TListType&& starts, const TNodeOnNodeOwnedMap& replaces) { + if (!replaces.empty()) { + TNodeOnNodeOwnedMap processed; + for (auto& node : starts) { + node = ReplaceNodesImpl<KeepTypeAnns>(std::move(node), replaces, processed, *this); + } + } + return std::move(starts); +} + +template TExprNode::TListType TExprContext::ReplaceNodes<true>(TExprNode::TListType&& starts, const TNodeOnNodeOwnedMap& replaces); +template TExprNode::TListType TExprContext::ReplaceNodes<false>(TExprNode::TListType&& starts, const TNodeOnNodeOwnedMap& replaces); + +bool IsDepended(const TExprNode& node, const TNodeSet& dependences) { + TNodeSet visited; + return !dependences.empty() && IsDependedImpl(node, dependences, visited); +} + +void CheckArguments(const TExprNode& root) { + try { + TNodeMap<TNodeSetPtr> unresolvedArgsMap; + TNodeSet allArgs; + auto rootUnresolved = CollectUnresolvedArgs(root, unresolvedArgsMap, allArgs); + if (rootUnresolved && !rootUnresolved->empty()) { + TVector<ui64> ids; + for (auto& i : *rootUnresolved) { + ids.push_back(i->UniqueId()); + } + ythrow yexception() << "detected unresolved arguments at top level: #[" << JoinSeq(", ", ids) << "]"; + } + } catch (yexception& e) { + e << "\n" << root.Dump(); + throw; + } +} + +TAstParseResult ConvertToAst(const TExprNode& root, TExprContext& exprContext, const TConvertToAstSettings& settings) { +#ifdef _DEBUG + CheckArguments(root); +#endif + TVisitNodeContext ctx(exprContext); + ctx.RefAtoms = settings.RefAtoms; + ctx.AllowFreeArgs = settings.AllowFreeArgs; + ctx.NormalizeAtomFlags = settings.NormalizeAtomFlags; + ctx.Pool = std::make_unique<TMemoryPool>(4096, TMemoryPool::TExpGrow::Instance(), settings.Allocator); + ctx.Frames.push_back(TFrameContext()); + ctx.CurrentFrame = &ctx.Frames.front(); + VisitNode(root, 0ULL, ctx); + ui32 uniqueNum = 0; + + for (auto& frame : ctx.Frames) { + ctx.CurrentFrame = &frame; + frame.TopoSortedNodes.reserve(frame.Nodes.size()); + for (const auto& node : frame.Nodes) { + const auto name = ctx.FindBinding(node.second); + if (name.empty()) { + const auto& ref = ctx.References[node.second]; + if (!InlineNode(*node.second, ref.References, ref.Neighbors, settings)) { + if (settings.PrintArguments && node.second->IsArgument()) { + auto buffer = TStringBuilder() << "$" << ++uniqueNum + << "{" << node.second->Content() << ":" + << node.second->UniqueId() << "}"; + YQL_ENSURE(frame.Bindings.emplace(node.second, buffer).second); + } else { + char buffer[1 + 10 + 1]; + snprintf(buffer, sizeof(buffer), "$%" PRIu32, ++uniqueNum); + YQL_ENSURE(frame.Bindings.emplace(node.second, buffer).second); + } + frame.TopoSortedNodes.emplace_back(node.second); + } + } + } + } + + ctx.CurrentFrame = &ctx.Frames.front(); + TAstParseResult result; + result.Root = ConvertFunction(exprContext.AppendPosition(TPosition(1, 1)), {&root}, ctx, settings.AnnotationFlags, *ctx.Pool); + result.Pool = std::move(ctx.Pool); + return result; +} + +TAstParseResult ConvertToAst(const TExprNode& root, TExprContext& exprContext, ui32 annotationFlags, bool refAtoms) { + TConvertToAstSettings settings; + settings.AnnotationFlags = annotationFlags; + settings.RefAtoms = refAtoms; + return ConvertToAst(root, exprContext, settings); +} + +TString TExprNode::Dump() const { + TNodeSet visited; + TStringStream out; + DumpNode(*this, out, 0, visited); + return out.Str(); +} + +TPosition TExprNode::Pos(const TExprContext& ctx) const { + return ctx.GetPosition(Pos()); +} + +TExprNode::TPtr TExprContext::RenameNode(const TExprNode& node, const TStringBuf& name) { + const auto newNode = node.ChangeContent(AllocateNextUniqueId(), AppendString(name)); + ExprNodes.emplace_back(newNode.Get()); + return newNode; +} + +TExprNode::TPtr TExprContext::ShallowCopy(const TExprNode& node) { + YQL_ENSURE(node.Type() != TExprNode::Lambda); + const auto newNode = node.Clone(AllocateNextUniqueId()); + ExprNodes.emplace_back(newNode.Get()); + return newNode; +} + +TExprNode::TPtr TExprContext::ShallowCopyWithPosition(const TExprNode& node, TPositionHandle pos) { + YQL_ENSURE(node.Type() != TExprNode::Lambda); + const auto newNode = node.CloneWithPosition(AllocateNextUniqueId(), pos); + ExprNodes.emplace_back(newNode.Get()); + return newNode; +} + +TExprNode::TPtr TExprContext::ChangeChildren(const TExprNode& node, TExprNode::TListType&& children) { + const auto newNode = node.ChangeChildren(AllocateNextUniqueId(), std::move(children)); + ExprNodes.emplace_back(newNode.Get()); + return newNode; +} + +TExprNode::TPtr TExprContext::ChangeChild(const TExprNode& node, ui32 index, TExprNode::TPtr&& child) { + const auto newNode = node.ChangeChild(AllocateNextUniqueId(), index, std::move(child)); + ExprNodes.emplace_back(newNode.Get()); + return newNode; +} + +TExprNode::TPtr TExprContext::ExactChangeChildren(const TExprNode& node, TExprNode::TListType&& children) { + const auto newNode = node.ChangeChildren(AllocateNextUniqueId(), std::move(children)); + newNode->SetTypeAnn(node.GetTypeAnn()); + newNode->CopyConstraints(node); + newNode->SetState(node.GetState()); + newNode->Result = node.Result; + ExprNodes.emplace_back(newNode.Get()); + return newNode; +} + +TExprNode::TPtr TExprContext::ExactShallowCopy(const TExprNode& node) { + YQL_ENSURE(node.Type() != TExprNode::Lambda); + const auto newNode = node.Clone(AllocateNextUniqueId()); + newNode->SetTypeAnn(node.GetTypeAnn()); + newNode->CopyConstraints(node); + newNode->SetState(node.GetState()); + newNode->Result = node.Result; + ExprNodes.emplace_back(newNode.Get()); + return newNode; +} + +TExprNode::TListType GetLambdaBody(const TExprNode& node) { + switch (node.ChildrenSize()) { + case 1U: return {}; + case 2U: return {node.TailPtr()}; + default: break; + } + + auto body = node.ChildrenList(); + body.erase(body.cbegin()); + return body; +} + +TExprNode::TPtr TExprContext::DeepCopyLambda(const TExprNode& node, TExprNode::TListType&& body) { + YQL_ENSURE(node.IsLambda()); + const auto& prevArgs = node.Head(); + + TNodeOnNodeOwnedMap replaces(prevArgs.ChildrenSize()); + + TExprNode::TListType newArgNodes; + newArgNodes.reserve(prevArgs.ChildrenSize()); + prevArgs.ForEachChild([&](const TExprNode& arg) { + auto newArg = ShallowCopy(arg); + YQL_ENSURE(replaces.emplace(&arg, newArg).second); + newArgNodes.emplace_back(std::move(newArg)); + }); + + auto newBody = ReplaceNodes(std::move(body), replaces); + return NewLambda(node.Pos(), NewArguments(prevArgs.Pos(), std::move(newArgNodes)), std::move(newBody)); +} + +TExprNode::TPtr TExprContext::DeepCopyLambda(const TExprNode& node, TExprNode::TPtr&& body) { + YQL_ENSURE(node.IsLambda()); + const auto& prevArgs = node.Head(); + + TNodeOnNodeOwnedMap replaces(prevArgs.ChildrenSize()); + + TExprNode::TListType newArgNodes; + newArgNodes.reserve(prevArgs.ChildrenSize()); + prevArgs.ForEachChild([&](const TExprNode& arg) { + auto newArg = ShallowCopy(arg); + YQL_ENSURE(replaces.emplace(&arg, newArg).second); + newArgNodes.emplace_back(std::move(newArg)); + }); + + auto newBody = ReplaceNodes(body ? TExprNode::TListType{std::move(body)} : GetLambdaBody(node), replaces); + return NewLambda(node.Pos(), NewArguments(prevArgs.Pos(), std::move(newArgNodes)), std::move(newBody)); +} + +TExprNode::TPtr TExprContext::FuseLambdas(const TExprNode& outer, const TExprNode& inner) { + YQL_ENSURE(outer.IsLambda() && inner.IsLambda()); + const auto& outerArgs = outer.Head(); + const auto& innerArgs = inner.Head(); + + TNodeOnNodeOwnedMap innerReplaces(innerArgs.ChildrenSize()); + + TExprNode::TListType newArgNodes; + newArgNodes.reserve(innerArgs.ChildrenSize()); + + innerArgs.ForEachChild([&](const TExprNode& arg) { + auto newArg = ShallowCopy(arg); + YQL_ENSURE(innerReplaces.emplace(&arg, newArg).second); + newArgNodes.emplace_back(std::move(newArg)); + }); + + auto body = ReplaceNodes(GetLambdaBody(inner), innerReplaces); + + TExprNode::TListType newBody; + auto outerBody = GetLambdaBody(outer); + if (outerArgs.ChildrenSize() + 1U == inner.ChildrenSize()) { + auto i = 0U; + TNodeOnNodeOwnedMap outerReplaces(outerArgs.ChildrenSize()); + outerArgs.ForEachChild([&](const TExprNode& arg) { + YQL_ENSURE(outerReplaces.emplace(&arg, std::move(body[i++])).second); + }); + newBody = ReplaceNodes(std::move(outerBody), outerReplaces); + } else if (1U == outerArgs.ChildrenSize()) { + newBody.reserve(newBody.size() * body.size()); + for (auto item : body) { + for (auto root : outerBody) { + newBody.emplace_back(ReplaceNode(TExprNode::TPtr(root), outerArgs.Head(), TExprNode::TPtr(item))); + } + } + } else { + YQL_ENSURE(outerBody.empty(), "Incompatible lambdas for fuse."); + } + + return NewLambda(outer.Pos(), NewArguments(inner.Head().Pos(), std::move(newArgNodes)), std::move(newBody)); +} + +TExprNode::TPtr TExprContext::DeepCopy(const TExprNode& node, TExprContext& nodeCtx, TNodeOnNodeOwnedMap& deepClones, + bool internStrings, bool copyTypes, bool copyResult, TCustomDeepCopier customCopier) +{ + const auto ins = deepClones.emplace(&node, nullptr); + if (ins.second) { + TExprNode::TListType children; + children.reserve(node.ChildrenSize()); + + if (customCopier && customCopier(node, children)) { + } else { + node.ForEachChild([&](const TExprNode& child) { + children.emplace_back(DeepCopy(child, nodeCtx, deepClones, internStrings, copyTypes, copyResult, customCopier)); + }); + } + + ++NodeAllocationCounter; + auto newNode = TExprNode::NewNode(AppendPosition(nodeCtx.GetPosition(node.Pos())), node.Type(), + std::move(children), internStrings ? AppendString(node.Content()) : node.Content(), node.Flags(), + AllocateNextUniqueId()); + + if (copyTypes && node.GetTypeAnn()) { + newNode->SetTypeAnn(node.GetTypeAnn()); + } + + if (copyResult && node.IsCallable() && node.HasResult()) { + newNode->SetResult(nodeCtx.ShallowCopy(node.GetResult())); + } + + ins.first->second = newNode; + ExprNodes.emplace_back(ins.first->second.Get()); + } + return ins.first->second; +} + +TExprNode::TPtr TExprContext::WrapByCallableIf(bool condition, const TStringBuf& callable, TExprNode::TPtr&& node) { + if (!condition) { + return node; + } + const auto pos = node->Pos(); + return NewCallable(pos, callable, {std::move(node)}); +} + +TExprNode::TPtr TExprContext::SwapWithHead(const TExprNode& node) { + return ChangeChild(node.Head(), 0U, ChangeChild(node, 0U, node.Head().HeadPtr())); +} + +TConstraintSet TExprContext::MakeConstraintSet(const NYT::TNode& serializedConstraints) { + const static std::unordered_map<std::string_view, std::function<const TConstraintNode*(TExprContext&, const NYT::TNode&)>> FACTORIES = { + {TSortedConstraintNode::Name(), std::mem_fn(&TExprContext::MakeConstraint<TSortedConstraintNode, const NYT::TNode&>)}, + {TChoppedConstraintNode::Name(), std::mem_fn(&TExprContext::MakeConstraint<TChoppedConstraintNode, const NYT::TNode&>)}, + {TUniqueConstraintNode::Name(), std::mem_fn(&TExprContext::MakeConstraint<TUniqueConstraintNode, const NYT::TNode&>)}, + {TDistinctConstraintNode::Name(), std::mem_fn(&TExprContext::MakeConstraint<TDistinctConstraintNode, const NYT::TNode&>)}, + {TEmptyConstraintNode::Name(), std::mem_fn(&TExprContext::MakeConstraint<TEmptyConstraintNode, const NYT::TNode&>)}, + {TVarIndexConstraintNode::Name(), std::mem_fn(&TExprContext::MakeConstraint<TVarIndexConstraintNode, const NYT::TNode&>)}, + {TMultiConstraintNode::Name(), std::mem_fn(&TExprContext::MakeConstraint<TMultiConstraintNode, const NYT::TNode&>)}, + }; + TConstraintSet res; + YQL_ENSURE(serializedConstraints.IsMap(), "Unexpected node type with serialize constraints: " << serializedConstraints.GetType()); + for (const auto& [key, node]: serializedConstraints.AsMap()) { + auto it = FACTORIES.find(key); + YQL_ENSURE(it != FACTORIES.cend(), "Unsupported constraint construction: " << key); + try { + res.AddConstraint((it->second)(*this, node)); + } catch (...) { + YQL_ENSURE(false, "Error while constructing constraint: " << CurrentExceptionMessage()); + } + } + return res; +} + +TNodeException::TNodeException() + : Pos_() +{ +} + +TNodeException::TNodeException(const TExprNode& node) + : Pos_(node.Pos()) +{ +} + +TNodeException::TNodeException(const TExprNode* node) + : Pos_(node ? node->Pos() : TPositionHandle()) +{ +} + +TNodeException::TNodeException(const TPositionHandle& pos) + : Pos_(pos) +{ +} + +bool ValidateName(TPosition position, TStringBuf name, TStringBuf descr, TExprContext& ctx) { + if (name.empty()) { + ctx.AddError(TIssue(position, + TStringBuilder() << "Empty " << descr << " name is not allowed")); + return false; + } + + if (!IsUtf8(name)) { + ctx.AddError(TIssue(position, TStringBuilder() << + TString(descr).to_title() << " name must be a valid utf-8 byte sequence: " << TString{name}.Quote())); + return false; + } + + if (name.size() > 16_KB) { + ctx.AddError(TIssue(position, TStringBuilder() << + TString(descr).to_title() << " name length must be less than " << 16_KB)); + return false; + } + + return true; +} + +bool ValidateName(TPositionHandle position, TStringBuf name, TStringBuf descr, TExprContext& ctx) { + return ValidateName(ctx.GetPosition(position), name, descr, ctx); +} + +bool TDataExprParamsType::Validate(TPosition position, TExprContext& ctx) const { + if (GetSlot() != EDataSlot::Decimal) { + ctx.AddError(TIssue(position, TStringBuilder() << "Only Decimal may contain parameters, but got: " << GetName())); + return false; + } + + ui8 precision; + if (!TryFromString<ui8>(GetParamOne(), precision)){ + ctx.AddError(TIssue(position, TStringBuilder() << "Invalid decimal precision: " << GetParamOne())); + return false; + } + + if (!precision || precision > 35) { + ctx.AddError(TIssue(position, TStringBuilder() << "Invalid decimal precision: " << GetParamOne())); + return false; + } + + ui8 scale; + if (!TryFromString<ui8>(GetParamTwo(), scale)){ + ctx.AddError(TIssue(position, TStringBuilder() << "Invalid decimal scale: " << GetParamTwo())); + return false; + } + + if (scale > precision) { + ctx.AddError(TIssue(position, TStringBuilder() << "Invalid decimal parameters: (" << GetParamOne() << "," << GetParamTwo() << ").")); + return false; + } + + return true; +} + +bool TDataExprParamsType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +bool TItemExprType::Validate(TPosition position, TExprContext& ctx) const { + return ValidateName(position, Name, "member", ctx); +} + +bool TItemExprType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +TStringBuf TItemExprType::GetCleanName(bool isVirtual) const { + if (!isVirtual) { + return Name; + } + + YQL_ENSURE(Name.StartsWith(YqlVirtualPrefix)); + return Name.SubStr(YqlVirtualPrefix.size()); +} + +const TItemExprType* TItemExprType::GetCleanItem(bool isVirtual, TExprContext& ctx) const { + if (!isVirtual) { + return this; + } + + YQL_ENSURE(Name.StartsWith(YqlVirtualPrefix)); + return ctx.MakeType<TItemExprType>(Name.SubStr(YqlVirtualPrefix.size()), ItemType); +} + +bool TMultiExprType::Validate(TPosition position, TExprContext& ctx) const { + if (Items.size() > Max<ui16>()) { + ctx.AddError(TIssue(position, TStringBuilder() << "Too many elements: " << Items.size())); + return false; + } + + return true; +} + +bool TMultiExprType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +bool TTupleExprType::Validate(TPosition position, TExprContext& ctx) const { + if (Items.size() > Max<ui16>()) { + ctx.AddError(TIssue(position, TStringBuilder() << "Too many tuple elements: " << Items.size())); + return false; + } + + return true; +} + +bool TTupleExprType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +bool TStructExprType::Validate(TPosition position, TExprContext& ctx) const { + if (Items.size() > Max<ui16>()) { + ctx.AddError(TIssue(position, TStringBuilder() << "Too many struct members: " << Items.size())); + return false; + } + + TString lastName; + for (auto& item : Items) { + if (!item->Validate(position, ctx)) { + return false; + } + + if (item->GetName() == lastName) { + ctx.AddError(TIssue(position, TStringBuilder() << "Duplicated member: " << lastName)); + return false; + } + + lastName = item->GetName(); + } + + return true; +} + +bool TStructExprType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +bool TVariantExprType::Validate(TPosition position, TExprContext& ctx) const { + if (UnderlyingType->GetKind() == ETypeAnnotationKind::Tuple) { + if (!UnderlyingType->Cast<TTupleExprType>()->GetSize()) { + ctx.AddError(TIssue(position, TStringBuilder() << "Empty tuple is not allowed as underlying type")); + return false; + } + } + else if (UnderlyingType->GetKind() == ETypeAnnotationKind::Struct) { + if (!UnderlyingType->Cast<TStructExprType>()->GetSize()) { + ctx.AddError(TIssue(position, TStringBuilder() << "Empty struct is not allowed as underlying type")); + return false; + } + } + else { + ctx.AddError(TIssue(position, TStringBuilder() << "Expected tuple or struct, but got:" << *UnderlyingType)); + return false; + } + + return true; +} + +bool TVariantExprType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +ui32 TVariantExprType::MakeFlags(const TTypeAnnotationNode* underlyingType) { + switch (underlyingType->GetKind()) { + case ETypeAnnotationKind::Tuple: { + const auto tupleType = underlyingType->Cast<TTupleExprType>(); + auto ret = CombineFlags(tupleType->GetItems()); + if (tupleType->GetSize() > 1) { + ret |= TypeHasManyValues; + } + return ret; + } + case ETypeAnnotationKind::Struct: { + const auto structType = underlyingType->Cast<TStructExprType>(); + auto ret = CombineFlags(structType->GetItems()); + if (structType->GetSize() > 1) { + ret |= TypeHasManyValues; + } + return ret; + } + default: break; + } + + ythrow yexception() << "unexpected underlying type" << *underlyingType; +} + + +bool TDictExprType::Validate(TPosition position, TExprContext& ctx) const { + if (KeyType->IsHashable() && KeyType->IsEquatable()) { + return true; + } + + if (KeyType->IsComparableInternal()) { + return true; + } + + ctx.AddError(TIssue(position, TStringBuilder() << "Expected hashable and equatable or internally comparable dict key type, but got: " << *KeyType)); + return false; +} + +bool TDictExprType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +bool TCallableExprType::Validate(TPosition position, TExprContext& ctx) const { + if (OptionalArgumentsCount > Arguments.size()) { + ctx.AddError(TIssue(position, TStringBuilder() << "Too many optional arguments: " << OptionalArgumentsCount + << ", function has only " << Arguments.size() << " arguments")); + return false; + } + + for (ui32 index = Arguments.size() - OptionalArgumentsCount; index < Arguments.size(); ++index) { + auto type = Arguments[index].Type; + if (type->GetKind() != ETypeAnnotationKind::Optional) { + ctx.AddError(TIssue(position, TStringBuilder() << "Expected optional type for argument: " << (index + 1) + << " because it's an optional argument, but got: " << *type)); + return false; + } + } + + bool startedNames = false; + std::unordered_set<std::string_view> usedNames(Arguments.size()); + for (ui32 index = 0; index < Arguments.size(); ++index) { + bool hasName = !Arguments[index].Name.empty(); + if (startedNames) { + if (!hasName) { + ctx.AddError(TIssue(position, TStringBuilder() << "Unexpected positional argument at position " + << (index + 1) << " just after named arguments")); + return false; + } + } else { + startedNames = hasName; + } + + if (hasName) { + if (!usedNames.insert(Arguments[index].Name).second) { + ctx.AddError(TIssue(position, TStringBuilder() << "Duplication of named argument: " << Arguments[index].Name)); + return false; + } + } + } + + return true; +} + +bool TCallableExprType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +bool TTaggedExprType::Validate(TPosition position, TExprContext& ctx) const { + return ValidateName(position, Tag, "tag", ctx); +} + +bool TTaggedExprType::Validate(TPositionHandle position, TExprContext& ctx) const { + return Validate(ctx.GetPosition(position), ctx); +} + +const TString& TPgExprType::GetName() const { + return NPg::LookupType(TypeId).Name; +} + +ui32 TPgExprType::GetFlags(ui32 typeId) { + auto descPtr = &NPg::LookupType(typeId); + if (descPtr->ArrayTypeId == descPtr->TypeId) { + auto elemType = descPtr->ElementTypeId; + descPtr = &NPg::LookupType(elemType); + } + + const auto& desc = *descPtr; + ui32 ret = TypeHasManyValues | TypeHasOptional; + if ((!desc.SendFuncId || !desc.ReceiveFuncId) && (!desc.OutFuncId || !desc.InFuncId)) { + ret |= TypeNonPersistable; + } + + if (!desc.LessProcId || !desc.CompareProcId) { + ret |= TypeNonComparable; + } + + if (!desc.EqualProcId || !desc.CompareProcId) { + if (desc.TypeId != NPg::UnknownOid) { + ret |= TypeNonEquatable; + } + } + + if (!desc.HashProcId) { + ret |= TypeNonHashable; + } + + static const std::unordered_set<std::string_view> PgSupportedPresort = { + "bool"sv, + "int2"sv, + "int4"sv, + "int8"sv, + "float4"sv, + "float8"sv, + "bytea"sv, + "varchar"sv, + "text"sv, + "cstring"sv + }; + + if (!PgSupportedPresort.contains(descPtr->Name)) { + ret |= TypeNonPresortable; + } + + return ret; +} + +ui64 TPgExprType::GetPgExtensionsMask(ui32 typeId) { + auto descPtr = &NPg::LookupType(typeId); + return MakePgExtensionMask(descPtr->ExtensionIndex); +} + +ui64 MakePgExtensionMask(ui32 extensionIndex) { + if (!extensionIndex) { + return 0; + } + + YQL_ENSURE(extensionIndex <= 64); + return 1ull << (extensionIndex - 1); +} + +TExprContext::TExprContext(ui64 nextUniqueId) + : StringPool(4096) + , NextUniqueId(nextUniqueId) + , Frozen(false) + , PositionSet( + 16, + [this](TPositionHandle p) { return GetHash(p); }, + [this](TPositionHandle a, TPositionHandle b) { return IsEqual(a, b); } + ) +{ + auto handle = AppendPosition(TPosition()); + YQL_ENSURE(handle.Handle == 0); + IssueManager.SetWarningToErrorTreatMessage( + "Treat warning as error mode enabled. " + "To disable it use \"pragma warning(\"default\", <code>);\""); + IssueManager.SetIssueCountLimit(100); +} + +TPositionHandle TExprContext::AppendPosition(const TPosition& pos) { + YQL_ENSURE(Positions.size() <= Max<ui32>(), "Too many positions"); + + TPositionHandle handle; + handle.Handle = (ui32)Positions.size(); + Positions.push_back(pos); + + auto inserted = PositionSet.insert(handle); + if (inserted.second) { + return handle; + } + + Positions.pop_back(); + return *inserted.first; +} + +TPosition TExprContext::GetPosition(TPositionHandle handle) const { + YQL_ENSURE(handle.Handle < Positions.size(), "Unknown PositionHandle"); + return Positions[handle.Handle]; +} + +TExprContext::~TExprContext() { + UnFreeze(); +} + +void TExprContext::Freeze() { + for (auto& node : ExprNodes) { + node->MarkFrozen(); + } + + Frozen = true; +} + +void TExprContext::UnFreeze() { + if (Frozen) { + for (auto& node : ExprNodes) { + node->MarkFrozen(false); + } + + Frozen = false; + } +} + +void TExprContext::Reset() { + YQL_ENSURE(!Frozen); + + IssueManager.Reset(); + Step.Reset(); + RepeatTransformCounter = 0; +} + +bool TExprContext::IsEqual(TPositionHandle a, TPositionHandle b) const { + YQL_ENSURE(a.Handle < Positions.size()); + YQL_ENSURE(b.Handle < Positions.size()); + return Positions[a.Handle] == Positions[b.Handle]; +} + +size_t TExprContext::GetHash(TPositionHandle p) const { + YQL_ENSURE(p.Handle < Positions.size()); + + const TPosition& pos = Positions[p.Handle]; + size_t h = ComputeHash(pos.File); + h = CombineHashes(h, NumericHash(pos.Row)); + return CombineHashes(h, NumericHash(pos.Column)); +} + +std::string_view TExprContext::GetIndexAsString(ui32 index) { + const auto it = Indexes.find(index); + if (it != Indexes.cend()) { + return it->second; + } + + const auto& newBuf = AppendString(ToString(index)); + Indexes.emplace_hint(it, index, newBuf); + return newBuf; +} + +template<class T, typename... Args> +const T* MakeSinglethonType(TExprContext& ctx, Args&&... args) { + auto& singleton = std::get<const T*>(ctx.SingletonTypeCache); + if (!singleton) + singleton = AddType<T>(ctx, T::MakeHash(args...), std::forward<Args>(args)...); + return singleton; +} + +const TVoidExprType* TMakeTypeImpl<TVoidExprType>::Make(TExprContext& ctx) { + return MakeSinglethonType<TVoidExprType>(ctx); +} + +const TNullExprType* TMakeTypeImpl<TNullExprType>::Make(TExprContext& ctx) { + return MakeSinglethonType<TNullExprType>(ctx); +} + +const TEmptyListExprType* TMakeTypeImpl<TEmptyListExprType>::Make(TExprContext& ctx) { + return MakeSinglethonType<TEmptyListExprType>(ctx); +} + +const TEmptyDictExprType* TMakeTypeImpl<TEmptyDictExprType>::Make(TExprContext& ctx) { + return MakeSinglethonType<TEmptyDictExprType>(ctx); +} + +const TUnitExprType* TMakeTypeImpl<TUnitExprType>::Make(TExprContext& ctx) { + return MakeSinglethonType<TUnitExprType>(ctx); +} + +const TWorldExprType* TMakeTypeImpl<TWorldExprType>::Make(TExprContext& ctx) { + return MakeSinglethonType<TWorldExprType>(ctx); +} + +const TGenericExprType* TMakeTypeImpl<TGenericExprType>::Make(TExprContext& ctx) { + return MakeSinglethonType<TGenericExprType>(ctx); +} + +const TItemExprType* TMakeTypeImpl<TItemExprType>::Make(TExprContext& ctx, const TStringBuf& name, const TTypeAnnotationNode* itemType) { + const auto hash = TItemExprType::MakeHash(name, itemType); + TItemExprType sample(hash, name, itemType); + if (const auto found = FindType(sample, ctx)) + return found; + + auto nameStr = ctx.AppendString(name); + return AddType<TItemExprType>(ctx, hash, nameStr, itemType); +} + +const TListExprType* TMakeTypeImpl<TListExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* itemType) { + const auto hash = TListExprType::MakeHash(itemType); + TListExprType sample(hash, itemType); + if (const auto found = FindType(sample, ctx)) + return found; + return AddType<TListExprType>(ctx, hash, itemType); +} + +const TOptionalExprType* TMakeTypeImpl<TOptionalExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* itemType) { + const auto hash = TOptionalExprType::MakeHash(itemType); + TOptionalExprType sample(hash, itemType); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TOptionalExprType>(ctx, hash, itemType); +} + +const TVariantExprType* TMakeTypeImpl<TVariantExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* underlyingType) { + const auto hash = TVariantExprType::MakeHash(underlyingType); + TVariantExprType sample(hash, underlyingType); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TVariantExprType>(ctx, hash, underlyingType); +} + +const TErrorExprType* TMakeTypeImpl<TErrorExprType>::Make(TExprContext& ctx, const TIssue& error) { + const auto hash = TErrorExprType::MakeHash(error); + TErrorExprType sample(hash, error); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TErrorExprType>(ctx, hash, error); +} + +const TDictExprType* TMakeTypeImpl<TDictExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* keyType, + const TTypeAnnotationNode* payloadType) { + const auto hash = TDictExprType::MakeHash(keyType, payloadType); + TDictExprType sample(hash, keyType, payloadType); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TDictExprType>(ctx, hash, keyType, payloadType); +} + +const TTypeExprType* TMakeTypeImpl<TTypeExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* baseType) { + const auto hash = TTypeExprType::MakeHash(baseType); + TTypeExprType sample(hash, baseType); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TTypeExprType>(ctx, hash, baseType); +} + +const TDataExprType* TMakeTypeImpl<TDataExprType>::Make(TExprContext& ctx, EDataSlot slot) { + const auto hash = TDataExprType::MakeHash(slot); + TDataExprType sample(hash, slot); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TDataExprType>(ctx, hash, slot); +} + +const TPgExprType* TMakeTypeImpl<TPgExprType>::Make(TExprContext& ctx, ui32 typeId) { + const auto hash = TPgExprType::MakeHash(typeId); + TPgExprType sample(hash, typeId); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TPgExprType>(ctx, hash, typeId); +} + +const TDataExprParamsType* TMakeTypeImpl<TDataExprParamsType>::Make(TExprContext& ctx, EDataSlot slot, const TStringBuf& one, const TStringBuf& two) { + const auto hash = TDataExprParamsType::MakeHash(slot, one, two); + TDataExprParamsType sample(hash, slot, one, two); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TDataExprParamsType>(ctx, hash, slot, ctx.AppendString(one), ctx.AppendString(two)); +} + +const TCallableExprType* TMakeTypeImpl<TCallableExprType>::Make( + TExprContext& ctx, const TTypeAnnotationNode* returnType, const TVector<TCallableExprType::TArgumentInfo>& arguments, + size_t optionalArgumentsCount, const TStringBuf& payload) { + const auto hash = TCallableExprType::MakeHash(returnType, arguments, optionalArgumentsCount, payload); + TCallableExprType sample(hash, returnType, arguments, optionalArgumentsCount, payload); + if (const auto found = FindType(sample, ctx)) + return found; + + TVector<TCallableExprType::TArgumentInfo> newArgs; + newArgs.reserve(arguments.size()); + for (const auto& x : arguments) { + TCallableExprType::TArgumentInfo arg; + arg.Type = x.Type; + arg.Name = ctx.AppendString(x.Name); + arg.Flags = x.Flags; + newArgs.emplace_back(arg); + } + + return AddType<TCallableExprType>(ctx, hash, returnType, newArgs, optionalArgumentsCount, ctx.AppendString(payload)); +} + +const TResourceExprType* TMakeTypeImpl<TResourceExprType>::Make(TExprContext& ctx, const TStringBuf& tag) { + const auto hash = TResourceExprType::MakeHash(tag); + TResourceExprType sample(hash, tag); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TResourceExprType>(ctx, hash, ctx.AppendString(tag)); +} + +const TTaggedExprType* TMakeTypeImpl<TTaggedExprType>::Make( + TExprContext& ctx, const TTypeAnnotationNode* baseType, const TStringBuf& tag) { + const auto hash = TTaggedExprType::MakeHash(baseType, tag); + TTaggedExprType sample(hash, baseType, tag); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TTaggedExprType>(ctx, hash, baseType, ctx.AppendString(tag)); +} + +const TStructExprType* TMakeTypeImpl<TStructExprType>::Make( + TExprContext& ctx, const TVector<const TItemExprType*>& items) { + if (items.empty()) + return MakeSinglethonType<TStructExprType>(ctx, items); + + auto sortedItems = items; + Sort(sortedItems, TStructExprType::TItemLess()); + const auto hash = TStructExprType::MakeHash(sortedItems); + TStructExprType sample(hash, sortedItems); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TStructExprType>(ctx, hash, sortedItems); +} + +const TMultiExprType* TMakeTypeImpl<TMultiExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode::TListType& items) { + if (items.empty()) + return MakeSinglethonType<TMultiExprType>(ctx, items); + + const auto hash = TMultiExprType::MakeHash(items); + TMultiExprType sample(hash, items); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TMultiExprType>(ctx, hash, items); +} + +const TTupleExprType* TMakeTypeImpl<TTupleExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode::TListType& items) { + if (items.empty()) + return MakeSinglethonType<TTupleExprType>(ctx, items); + + const auto hash = TTupleExprType::MakeHash(items); + TTupleExprType sample(hash, items); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TTupleExprType>(ctx, hash, items); +} + +const TStreamExprType* TMakeTypeImpl<TStreamExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* itemType) { + const auto hash = TStreamExprType::MakeHash(itemType); + TStreamExprType sample(hash, itemType); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TStreamExprType>(ctx, hash, itemType); +} + +const TFlowExprType* TMakeTypeImpl<TFlowExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* itemType) { + const auto hash = TFlowExprType::MakeHash(itemType); + TFlowExprType sample(hash, itemType); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TFlowExprType>(ctx, hash, itemType); +} + +const TBlockExprType* TMakeTypeImpl<TBlockExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* itemType) { + const auto hash = TBlockExprType::MakeHash(itemType); + TBlockExprType sample(hash, itemType); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TBlockExprType>(ctx, hash, itemType); +} + +const TScalarExprType* TMakeTypeImpl<TScalarExprType>::Make(TExprContext& ctx, const TTypeAnnotationNode* itemType) { + const auto hash = TScalarExprType::MakeHash(itemType); + TScalarExprType sample(hash, itemType); + if (const auto found = FindType(sample, ctx)) + return found; + + return AddType<TScalarExprType>(ctx, hash, itemType); +} + +bool CompareExprTrees(const TExprNode*& one, const TExprNode*& two) { + TArgumentsMap map; + ui32 level = 0; + TNodesPairSet visited; + return CompareExpressions(one, two, map, level, visited); +} + +bool CompareExprTreeParts(const TExprNode& one, const TExprNode& two, const TNodeMap<ui32>& argsMap) { + TArgumentsMap map; + ui32 level = 0; + TNodesPairSet visited; + map.reserve(argsMap.size()); + std::for_each(argsMap.cbegin(), argsMap.cend(), [&](const TNodeMap<ui32>::value_type& v){ map.emplace(v.first, std::make_pair(0U, v.second)); }); + auto l = &one, r = &two; + return CompareExpressions(l, r, map, level, visited); +} + +class TCacheKeyBuilder { +public: + TString Process(const TExprNode& root) { + SHA256_Init(&Sha); + unsigned char hash[SHA256_DIGEST_LENGTH]; + Visit(root); + SHA256_Final(hash, &Sha); + return TString((const char*)hash, sizeof(hash)); + } + +private: + void Visit(const TExprNode& node) { + auto [it, inserted] = Visited.emplace(&node, Visited.size()); + SHA256_Update(&Sha, &it->second, sizeof(it->second)); + if (!inserted) { + return; + } + + ui32 type = node.Type(); + SHA256_Update(&Sha, &type, sizeof(type)); + if (node.Type() == TExprNode::EType::Atom || node.Type() == TExprNode::EType::Callable) { + ui32 textLen = node.Content().size(); + SHA256_Update(&Sha, &textLen, sizeof(textLen)); + SHA256_Update(&Sha, node.Content().data(), textLen); + } + + if (node.Type() == TExprNode::EType::Atom || node.Type() == TExprNode::EType::Argument || node.Type() == TExprNode::EType::World) { + return; + } + + ui32 len = node.ChildrenSize(); + SHA256_Update(&Sha, &len, sizeof(len)); + for (const auto& child : node.Children()) { + Visit(*child); + } + } + +private: + SHA256_CTX Sha; + TNodeMap<ui64> Visited; +}; + +TString MakeCacheKey(const TExprNode& root) { + TCacheKeyBuilder builder; + return builder.Process(root); +} + +void GatherParents(const TExprNode& node, TParentsMap& parentsMap) { + parentsMap.clear(); + TNodeSet visisted; + GatherParentsImpl(node, parentsMap, visisted); +} + +void CheckCounts(const TExprNode& root) { + TRefCountsMap refCounts; + CalculateReferences(root, refCounts); + TNodeSet visited; + CheckReferences(root, refCounts, visited); +} + +TString SubstParameters(const TString& str, const TMaybe<NYT::TNode>& params, TSet<TString>* usedNames) { + size_t pos = 0; + try { + TStringBuilder res; + bool insideBrackets = false; + TStringBuilder paramBuilder; + for (char c : str) { + if (c == '{') { + if (insideBrackets) { + throw yexception() << "Unpexpected {"; + } + + insideBrackets = true; + continue; + } + + if (c == '}') { + if (!insideBrackets) { + throw yexception() << "Unexpected }"; + } + + insideBrackets = false; + TString param = paramBuilder; + paramBuilder.clear(); + if (usedNames) { + usedNames->insert(param); + } + + if (params) { + const auto& map = params->AsMap(); + auto it = map.find(param); + if (it == map.end()) { + throw yexception() << "No such parameter: '" << param << "'"; + } + + const auto& value = it->second["Data"]; + if (!value.IsString()) { + throw yexception() << "Parameter value must be a string"; + } + + res << value.AsString(); + } + + continue; + } + + if (insideBrackets) { + paramBuilder << c; + } + else { + res << c; + } + + ++pos; + } + + if (insideBrackets) { + throw yexception() << "Missing }"; + } + + return res; + } + catch (yexception& e) { + throw yexception() << "Failed to substitute parameters into url: " << str << ", reason:" << e.what() << ", position: " << pos; + } +} + +const TTypeAnnotationNode* GetSeqItemType(const TTypeAnnotationNode* type) { + if (!type) + return nullptr; + + switch (type->GetKind()) { + case ETypeAnnotationKind::List: return type->Cast<TListExprType>()->GetItemType(); + case ETypeAnnotationKind::Flow: return type->Cast<TFlowExprType>()->GetItemType(); + case ETypeAnnotationKind::Stream: return type->Cast<TStreamExprType>()->GetItemType(); + case ETypeAnnotationKind::Optional: return type->Cast<TOptionalExprType>()->GetItemType(); + default: break; + } + return nullptr; +} + +const TTypeAnnotationNode& GetSeqItemType(const TTypeAnnotationNode& type) { + if (const auto itemType = GetSeqItemType(&type)) + return *itemType; + throw yexception() << "Impossible to get item type from " << type; +} + +const TTypeAnnotationNode& RemoveOptionality(const TTypeAnnotationNode& type) { + return ETypeAnnotationKind::Optional == type.GetKind() ? *type.Cast<TOptionalExprType>()->GetItemType() : type; +} + +TMaybe<TIssue> NormalizeName(TPosition position, TString& name) { + const ui32 inputLength = name.length(); + ui32 startCharPos = 0; + ui32 totalSkipped = 0; + bool atStart = true; + bool justSkippedUnderscore = false; + for (ui32 i = 0; i < inputLength; ++i) { + const char c = name.at(i); + if (c == '_') { + if (!atStart) { + if (justSkippedUnderscore) { + return TIssue(position, TStringBuilder() << "\"" << name << "\" looks weird, has multiple consecutive underscores"); + } + justSkippedUnderscore = true; + ++totalSkipped; + continue; + } else { + ++startCharPos; + } + } + else { + atStart = false; + justSkippedUnderscore = false; + } + } + + if (totalSkipped >= 5) { + return TIssue(position, TStringBuilder() << "\"" << name << "\" looks weird, has multiple consecutive underscores"); + } + + ui32 outPos = startCharPos; + for (ui32 i = startCharPos; i < inputLength; i++) { + const char c = name.at(i); + if (c == '_') { + continue; + } else { + name[outPos] = AsciiToLower(c); + ++outPos; + } + } + + name.resize(outPos); + Y_ABORT_UNLESS(inputLength - outPos == totalSkipped); + + return Nothing(); +} + +TString NormalizeName(const TStringBuf& name) { + TString result(name); + TMaybe<TIssue> error = NormalizeName(TPosition(), result); + YQL_ENSURE(error.Empty(), "" << error->GetMessage()); + return result; +} + +} // namespace NYql + +template<> +void Out<NYql::TExprNode::EType>(class IOutputStream &o, NYql::TExprNode::EType x) { +#define YQL_EXPR_NODE_TYPE_MAP_TO_STRING_IMPL(name, ...) \ + case NYql::TExprNode::name: \ + o << #name; \ + return; + + switch (x) { + YQL_EXPR_NODE_TYPE_MAP(YQL_EXPR_NODE_TYPE_MAP_TO_STRING_IMPL) + default: + o << static_cast<int>(x); + return; + } +} + +template<> +void Out<NYql::TExprNode::EState>(class IOutputStream &o, NYql::TExprNode::EState x) { +#define YQL_EXPR_NODE_STATE_MAP_TO_STRING_IMPL(name, ...) \ + case NYql::TExprNode::EState::name: \ + o << #name; \ + return; + + switch (x) { + YQL_EXPR_NODE_STATE_MAP(YQL_EXPR_NODE_STATE_MAP_TO_STRING_IMPL) + default: + o << static_cast<int>(x); + return; + } +} diff --git a/yql/essentials/ast/yql_expr.h b/yql/essentials/ast/yql_expr.h new file mode 100644 index 00000000000..82382045526 --- /dev/null +++ b/yql/essentials/ast/yql_expr.h @@ -0,0 +1,2880 @@ +#pragma once + +#include "yql_ast.h" +#include "yql_expr_types.h" +#include "yql_type_string.h" +#include "yql_expr_builder.h" +#include "yql_gc_nodes.h" +#include "yql_constraint.h" +#include "yql_pos_handle.h" + +#include <yql/essentials/core/url_lister/interface/url_lister_manager.h> +#include <yql/essentials/utils/yql_panic.h> +#include <yql/essentials/public/issue/yql_issue_manager.h> +#include <yql/essentials/public/udf/udf_data_type.h> + +#include <library/cpp/yson/node/node.h> + +#include <library/cpp/string_utils/levenshtein_diff/levenshtein_diff.h> +#include <library/cpp/enumbitset/enumbitset.h> +#include <library/cpp/containers/stack_vector/stack_vec.h> +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> + +#include <util/string/ascii.h> +#include <util/string/builder.h> +#include <util/generic/array_ref.h> +#include <util/generic/deque.h> +#include <util/generic/cast.h> +#include <util/generic/hash.h> +#include <util/generic/maybe.h> +#include <util/generic/set.h> +#include <util/generic/bt_exception.h> +#include <util/generic/algorithm.h> +#include <util/digest/murmur.h> + +#include <algorithm> +#include <unordered_set> +#include <unordered_map> +#include <span> +#include <stack> + +//#define YQL_CHECK_NODES_CONSISTENCY +#ifdef YQL_CHECK_NODES_CONSISTENCY + #define ENSURE_NOT_DELETED \ + YQL_ENSURE(!Dead(), "Access to dead node # " << UniqueId_ << ": " << Type_ << " '" << ContentUnchecked() << "'"); + #define ENSURE_NOT_FROZEN \ + YQL_ENSURE(!Frozen(), "Change in frozen node # " << UniqueId_ << ": " << Type_ << " '" << ContentUnchecked() << "'"); + #define ENSURE_NOT_FROZEN_CTX \ + YQL_ENSURE(!Frozen, "Change in frozen expr context."); +#else + #define ENSURE_NOT_DELETED Y_DEBUG_ABORT_UNLESS(!Dead(), "Access to dead node # %lu: %d '%s'", UniqueId_, (int)Type_, TString(ContentUnchecked()).data()); + #define ENSURE_NOT_FROZEN Y_DEBUG_ABORT_UNLESS(!Frozen()); + #define ENSURE_NOT_FROZEN_CTX Y_DEBUG_ABORT_UNLESS(!Frozen); +#endif + +namespace NYql { + +using NUdf::EDataSlot; + +class TUnitExprType; +class TMultiExprType; +class TTupleExprType; +class TStructExprType; +class TItemExprType; +class TListExprType; +class TStreamExprType; +class TDataExprType; +class TPgExprType; +class TWorldExprType; +class TOptionalExprType; +class TCallableExprType; +class TResourceExprType; +class TTypeExprType; +class TDictExprType; +class TVoidExprType; +class TNullExprType; +class TGenericExprType; +class TTaggedExprType; +class TErrorExprType; +class TVariantExprType; +class TStreamExprType; +class TFlowExprType; +class TEmptyListExprType; +class TEmptyDictExprType; +class TBlockExprType; +class TScalarExprType; + +const size_t DefaultMistypeDistance = 3; +const TString YqlVirtualPrefix = "_yql_virtual_"; + +extern const TStringBuf ZeroString; + +struct TTypeAnnotationVisitor { + virtual ~TTypeAnnotationVisitor() = default; + + virtual void Visit(const TUnitExprType& type) = 0; + virtual void Visit(const TMultiExprType& type) = 0; + virtual void Visit(const TTupleExprType& type) = 0; + virtual void Visit(const TStructExprType& type) = 0; + virtual void Visit(const TItemExprType& type) = 0; + virtual void Visit(const TListExprType& type) = 0; + virtual void Visit(const TStreamExprType& type) = 0; + virtual void Visit(const TFlowExprType& type) = 0; + virtual void Visit(const TDataExprType& type) = 0; + virtual void Visit(const TPgExprType& type) = 0; + virtual void Visit(const TWorldExprType& type) = 0; + virtual void Visit(const TOptionalExprType& type) = 0; + virtual void Visit(const TCallableExprType& type) = 0; + virtual void Visit(const TResourceExprType& type) = 0; + virtual void Visit(const TTypeExprType& type) = 0; + virtual void Visit(const TDictExprType& type) = 0; + virtual void Visit(const TVoidExprType& type) = 0; + virtual void Visit(const TNullExprType& type) = 0; + virtual void Visit(const TGenericExprType& type) = 0; + virtual void Visit(const TTaggedExprType& type) = 0; + virtual void Visit(const TErrorExprType& type) = 0; + virtual void Visit(const TVariantExprType& type) = 0; + virtual void Visit(const TEmptyListExprType& type) = 0; + virtual void Visit(const TEmptyDictExprType& type) = 0; + virtual void Visit(const TBlockExprType& type) = 0; + virtual void Visit(const TScalarExprType& type) = 0; +}; + +enum ETypeAnnotationFlags : ui32 { + TypeNonComposable = 0x01, + TypeNonPersistable = 0x02, + TypeNonComputable = 0x04, + TypeNonInspectable = 0x08, + TypeNonHashable = 0x10, + TypeNonEquatable = 0x20, + TypeNonComparable = 0x40, + TypeHasNull = 0x80, + TypeHasOptional = 0x100, + TypeHasManyValues = 0x200, + TypeHasBareYson = 0x400, + TypeHasNestedOptional = 0x800, + TypeNonPresortable = 0x1000, + TypeHasDynamicSize = 0x2000, + TypeNonComparableInternal = 0x4000, +}; + +const ui64 TypeHashMagic = 0x10000; + +inline ui64 StreamHash(const void* buffer, size_t size, ui64 seed) { + return MurmurHash(buffer, size, seed); +} + +inline ui64 StreamHash(ui64 value, ui64 seed) { + return MurmurHash(&value, sizeof(value), seed); +} + +void ReportError(TExprContext& ctx, const TIssue& issue); + +class TTypeAnnotationNode { +protected: + TTypeAnnotationNode(ETypeAnnotationKind kind, ui32 flags, ui64 hash, ui64 usedPgExtensions) + : Kind(kind) + , Flags(flags) + , Hash(hash) + , UsedPgExtensions(usedPgExtensions) + { + } + +public: + virtual ~TTypeAnnotationNode() = default; + + template <typename T> + const T* Cast() const { + static_assert(std::is_base_of<TTypeAnnotationNode, T>::value, + "Should be derived from TTypeAnnotationNode"); + + const auto ret = dynamic_cast<const T*>(this); + YQL_ENSURE(ret, "Cannot cast type " << *this << " to " << ETypeAnnotationKind(T::KindValue)); + return ret; + } + + template <typename T> + const T* UserCast(TPosition pos, TExprContext& ctx) const { + static_assert(std::is_base_of<TTypeAnnotationNode, T>::value, + "Should be derived from TTypeAnnotationNode"); + + const auto ret = dynamic_cast<const T*>(this); + if (!ret) { + ReportError(ctx, TIssue(pos, TStringBuilder() << "Cannot cast type " << *this << " to " << ETypeAnnotationKind(T::KindValue))); + } + + return ret; + } + + ETypeAnnotationKind GetKind() const { + return Kind; + } + + bool IsComposable() const { + return (GetFlags() & TypeNonComposable) == 0; + } + + bool IsPersistable() const { + return (GetFlags() & TypeNonPersistable) == 0; + } + + bool IsComputable() const { + return (GetFlags() & TypeNonComputable) == 0; + } + + bool IsInspectable() const { + return (GetFlags() & TypeNonInspectable) == 0; + } + + bool IsHashable() const { + return IsPersistable() && (GetFlags() & TypeNonHashable) == 0; + } + + bool IsEquatable() const { + return IsPersistable() && (GetFlags() & TypeNonEquatable) == 0; + } + + bool IsComparable() const { + return IsPersistable() && (GetFlags() & TypeNonComparable) == 0; + } + + bool IsComparableInternal() const { + return IsPersistable() && (GetFlags() & TypeNonComparableInternal) == 0; + } + + bool HasNull() const { + return (GetFlags() & TypeHasNull) != 0; + } + + bool HasOptional() const { + return (GetFlags() & TypeHasOptional) != 0; + } + + bool HasNestedOptional() const { + return (GetFlags() & TypeHasNestedOptional) != 0; + } + + bool HasOptionalOrNull() const { + return (GetFlags() & (TypeHasOptional | TypeHasNull)) != 0; + } + + bool IsOptionalOrNull() const { + auto kind = GetKind(); + return kind == ETypeAnnotationKind::Optional || kind == ETypeAnnotationKind::Null || kind == ETypeAnnotationKind::Pg; + } + + bool IsBlockOrScalar() const { + return IsBlock() || IsScalar(); + } + + bool IsBlock() const { + return GetKind() == ETypeAnnotationKind::Block; + } + + bool IsScalar() const { + return GetKind() == ETypeAnnotationKind::Scalar; + } + + bool HasFixedSizeRepr() const { + return (GetFlags() & (TypeHasDynamicSize | TypeNonPersistable | TypeNonComputable)) == 0; + } + + bool IsSingleton() const { + return (GetFlags() & TypeHasManyValues) == 0; + } + + bool HasBareYson() const { + return (GetFlags() & TypeHasBareYson) != 0; + } + + bool IsPresortSupported() const { + return (GetFlags() & TypeNonPresortable) == 0; + } + + ui32 GetFlags() const { + return Flags; + } + + ui64 GetHash() const { + return Hash; + } + + ui64 GetUsedPgExtensions() const { + return UsedPgExtensions; + } + + bool Equals(const TTypeAnnotationNode& node) const; + void Accept(TTypeAnnotationVisitor& visitor) const; + + void Out(IOutputStream& out) const { + out << FormatType(this); + } + + struct THash { + size_t operator()(const TTypeAnnotationNode* node) const { + return node->GetHash(); + } + }; + + struct TEqual { + bool operator()(const TTypeAnnotationNode* one, const TTypeAnnotationNode* two) const { + return one->Equals(*two); + } + }; + + typedef std::vector<const TTypeAnnotationNode*> TListType; + typedef std::span<const TTypeAnnotationNode*> TSpanType; +protected: + template <typename T> + static ui32 CombineFlags(const T& items) { + ui32 flags = 0; + for (auto& item : items) { + flags |= item->GetFlags(); + } + + return flags; + } + + template <typename T> + static ui64 CombinePgExtensions(const T& items) { + ui64 mask = 0; + for (auto& item : items) { + mask |= item->GetUsedPgExtensions(); + } + + return mask; + } + +private: + const ETypeAnnotationKind Kind; + const ui32 Flags; + const ui64 Hash; + const ui64 UsedPgExtensions; +}; + +class TUnitExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Unit; + + TUnitExprType(ui64 hash) + : TTypeAnnotationNode(KindValue, + TypeNonComputable | TypeNonPersistable, hash, 0) + { + } + + static ui64 MakeHash() { + return TypeHashMagic | (ui64)ETypeAnnotationKind::Unit; + } + + bool operator==(const TUnitExprType& other) const { + Y_UNUSED(other); + return true; + } +}; + +class TTupleExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Tuple; + + TTupleExprType(ui64 hash, const TTypeAnnotationNode::TListType& items) + : TTypeAnnotationNode(KindValue, CombineFlags(items), hash, CombinePgExtensions(items)) + , Items(items) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode::TListType& items) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Tuple; + hash = StreamHash(items.size(), hash); + for (const auto& item : items) { + hash = StreamHash(item->GetHash(), hash); + } + + return hash; + } + + size_t GetSize() const { + return Items.size(); + } + + const TTypeAnnotationNode::TListType& GetItems() const { + return Items; + } + + bool operator==(const TTupleExprType& other) const { + if (GetSize() != other.GetSize()) { + return false; + } + + for (ui32 i = 0, e = GetSize(); i < e; ++i) { + if (GetItems()[i] != other.GetItems()[i]) { + return false; + } + } + + return true; + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + +private: + TTypeAnnotationNode::TListType Items; +}; + +class TMultiExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Multi; + + TMultiExprType(ui64 hash, const TTypeAnnotationNode::TListType& items) + : TTypeAnnotationNode(KindValue, CombineFlags(items), hash, CombinePgExtensions(items)) + , Items(items) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode::TListType& items) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Multi; + hash = StreamHash(items.size(), hash); + for (const auto& item : items) { + hash = StreamHash(item->GetHash(), hash); + } + + return hash; + } + + size_t GetSize() const { + return Items.size(); + } + + const TTypeAnnotationNode::TListType& GetItems() const { + return Items; + } + + bool operator==(const TMultiExprType& other) const { + if (GetSize() != other.GetSize()) { + return false; + } + + for (ui32 i = 0, e = GetSize(); i < e; ++i) { + if (GetItems()[i] != other.GetItems()[i]) { + return false; + } + } + + return true; + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + +private: + TTypeAnnotationNode::TListType Items; +}; + +struct TExprContext; + + +bool ValidateName(TPosition position, TStringBuf name, TStringBuf descr, TExprContext& ctx); +bool ValidateName(TPositionHandle position, TStringBuf name, TStringBuf descr, TExprContext& ctx); + +class TItemExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Item; + + TItemExprType(ui64 hash, const TStringBuf& name, const TTypeAnnotationNode* itemType) + : TTypeAnnotationNode(KindValue, itemType->GetFlags(), hash, itemType->GetUsedPgExtensions()) + , Name(name) + , ItemType(itemType) + { + } + + static ui64 MakeHash(const TStringBuf& name, const TTypeAnnotationNode* itemType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Item; + hash = StreamHash(name.size(), hash); + hash = StreamHash(name.data(), name.size(), hash); + return StreamHash(itemType->GetHash(), hash); + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + + const TStringBuf& GetName() const { + return Name; + } + + TStringBuf GetCleanName(bool isVirtual) const; + + const TTypeAnnotationNode* GetItemType() const { + return ItemType; + } + + bool operator==(const TItemExprType& other) const { + return GetName() == other.GetName() && GetItemType() == other.GetItemType(); + } + + const TItemExprType* GetCleanItem(bool isVirtual, TExprContext& ctx) const; + +private: + const TStringBuf Name; + const TTypeAnnotationNode* ItemType; +}; + +class TStructExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Struct; + + struct TItemLess { + bool operator()(const TItemExprType* x, const TItemExprType* y) const { + return x->GetName() < y->GetName(); + }; + + bool operator()(const TItemExprType* x, const TStringBuf& y) const { + return x->GetName() < y; + }; + + bool operator()(const TStringBuf& x, const TItemExprType* y) const { + return x < y->GetName(); + }; + }; + + TStructExprType(ui64 hash, const TVector<const TItemExprType*>& items) + : TTypeAnnotationNode(KindValue, TypeNonComparable | CombineFlags(items), hash, CombinePgExtensions(items)) + , Items(items) + { + } + + static ui64 MakeHash(const TVector<const TItemExprType*>& items) { + Y_DEBUG_ABORT_UNLESS(IsSorted(items.begin(), items.end(), TItemLess())); + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Struct; + hash = StreamHash(items.size(), hash); + for (const auto& item : items) { + hash = StreamHash(item->GetHash(), hash); + } + + return hash; + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + + size_t GetSize() const { + return Items.size(); + } + + const TVector<const TItemExprType*>& GetItems() const { + return Items; + } + + TMaybe<ui32> FindItem(const TStringBuf& name) const { + auto it = LowerBound(Items.begin(), Items.end(), name, TItemLess()); + if (it == Items.end() || (*it)->GetName() != name) { + return TMaybe<ui32>(); + } + + return it - Items.begin(); + } + + TMaybe<ui32> FindItemI(const TStringBuf& name, bool* isVirtual) const { + for (ui32 v = 0; v < 2; ++v) { + if (isVirtual) { + *isVirtual = v > 0; + } + + auto nameToSearch = (v ? YqlVirtualPrefix : "") + name; + auto strict = FindItem(nameToSearch); + if (strict) { + return strict; + } + + TMaybe<ui32> ret; + for (ui32 i = 0; i < Items.size(); ++i) { + if (AsciiEqualsIgnoreCase(nameToSearch, Items[i]->GetName())) { + if (ret) { + return Nothing(); + } + + ret = i; + } + } + + if (ret) { + return ret; + } + } + + return Nothing(); + } + + const TTypeAnnotationNode* FindItemType(const TStringBuf& name) const { + const auto it = LowerBound(Items.begin(), Items.end(), name, TItemLess()); + if (it == Items.end() || (*it)->GetName() != name) { + return nullptr; + } + + return (*it)->GetItemType(); + } + + TMaybe<TStringBuf> FindMistype(const TStringBuf& name) const { + for (const auto& item: Items) { + if (NLevenshtein::Distance(name, item->GetName()) < DefaultMistypeDistance) { + return item->GetName(); + } + } + return TMaybe<TStringBuf>(); + } + + bool operator==(const TStructExprType& other) const { + if (GetSize() != other.GetSize()) { + return false; + } + + for (ui32 i = 0, e = GetSize(); i < e; ++i) { + if (GetItems()[i] != other.GetItems()[i]) { + return false; + } + } + + return true; + } + + + TString ToString() const { + TStringBuilder sb; + + for (std::size_t i = 0; i < Items.size(); i++) { + sb << i << ": " << Items[i]->GetName() << "(" << FormatType(Items[i]->GetItemType()) << ")"; + if (i != Items.size() - 1) { + sb << ", "; + } + } + + return sb; + } + +private: + TVector<const TItemExprType*> Items; +}; + +class TListExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::List; + + TListExprType(ui64 hash, const TTypeAnnotationNode* itemType) + : TTypeAnnotationNode(KindValue, itemType->GetFlags() | TypeHasDynamicSize, hash, itemType->GetUsedPgExtensions()) + , ItemType(itemType) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode* itemType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::List; + return StreamHash(itemType->GetHash(), hash); + } + + const TTypeAnnotationNode* GetItemType() const { + return ItemType; + } + + bool operator==(const TListExprType& other) const { + return GetItemType() == other.GetItemType(); + } + +private: + const TTypeAnnotationNode* ItemType; +}; + +class TStreamExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Stream; + + TStreamExprType(ui64 hash, const TTypeAnnotationNode* itemType) + : TTypeAnnotationNode(KindValue, itemType->GetFlags() | TypeNonPersistable, hash, itemType->GetUsedPgExtensions()) + , ItemType(itemType) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode* itemType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Stream; + return StreamHash(itemType->GetHash(), hash); + } + + const TTypeAnnotationNode* GetItemType() const { + return ItemType; + } + + bool operator==(const TStreamExprType& other) const { + return GetItemType() == other.GetItemType(); + } + +private: + const TTypeAnnotationNode* ItemType; +}; + +class TFlowExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Flow; + + TFlowExprType(ui64 hash, const TTypeAnnotationNode* itemType) + : TTypeAnnotationNode(KindValue, itemType->GetFlags() | TypeNonPersistable, hash, itemType->GetUsedPgExtensions()) + , ItemType(itemType) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode* itemType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Flow; + return StreamHash(itemType->GetHash(), hash); + } + + const TTypeAnnotationNode* GetItemType() const { + return ItemType; + } + + bool operator==(const TFlowExprType& other) const { + return GetItemType() == other.GetItemType(); + } + +private: + const TTypeAnnotationNode* ItemType; +}; + +class TBlockExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Block; + + TBlockExprType(ui64 hash, const TTypeAnnotationNode* itemType) + : TTypeAnnotationNode(KindValue, itemType->GetFlags() | TypeNonPersistable, hash, itemType->GetUsedPgExtensions()) + , ItemType(itemType) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode* itemType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Block; + return StreamHash(itemType->GetHash(), hash); + } + + const TTypeAnnotationNode* GetItemType() const { + return ItemType; + } + + bool operator==(const TBlockExprType& other) const { + return GetItemType() == other.GetItemType(); + } + +private: + const TTypeAnnotationNode* ItemType; +}; + +class TScalarExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Scalar; + + TScalarExprType(ui64 hash, const TTypeAnnotationNode* itemType) + : TTypeAnnotationNode(KindValue, itemType->GetFlags() | TypeNonPersistable, hash, itemType->GetUsedPgExtensions()) + , ItemType(itemType) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode* itemType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Scalar; + return StreamHash(itemType->GetHash(), hash); + } + + const TTypeAnnotationNode* GetItemType() const { + return ItemType; + } + + bool operator==(const TScalarExprType& other) const { + return GetItemType() == other.GetItemType(); + } + +private: + const TTypeAnnotationNode* ItemType; +}; + +class TDataExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Data; + + TDataExprType(ui64 hash, EDataSlot slot) + : TTypeAnnotationNode(KindValue, GetFlags(slot), hash, 0) + , Slot(slot) + { + } + + static ui32 GetFlags(EDataSlot slot) { + ui32 ret = TypeHasManyValues; + auto props = NUdf::GetDataTypeInfo(slot).Features; + if (!(props & NUdf::CanHash)) { + ret |= TypeNonHashable; + } + + if (!(props & NUdf::CanEquate)) { + ret |= TypeNonEquatable; + } + + if (!(props & NUdf::CanCompare)) { + ret |= TypeNonComparable; + ret |= TypeNonComparableInternal; + } + + if (slot == NUdf::EDataSlot::Yson) { + ret |= TypeHasBareYson; + } + + if (props & NUdf::StringType) { + ret |= TypeHasDynamicSize; + } + + return ret; + } + + static ui64 MakeHash(EDataSlot slot) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Data; + auto dataType = NUdf::GetDataTypeInfo(slot).Name; + hash = StreamHash(dataType.size(), hash); + return StreamHash(dataType.data(), dataType.size(), hash); + } + + EDataSlot GetSlot() const { + return Slot; + } + + TStringBuf GetName() const { + return NUdf::GetDataTypeInfo(Slot).Name; + } + + bool operator==(const TDataExprType& other) const { + return Slot == other.Slot; + } + +private: + EDataSlot Slot; +}; + +class TDataExprParamsType : public TDataExprType { +public: + TDataExprParamsType(ui64 hash, EDataSlot slot, const TStringBuf& one, const TStringBuf& two) + : TDataExprType(hash, slot), One(one), Two(two) + {} + + static ui64 MakeHash(EDataSlot slot, const TStringBuf& one, const TStringBuf& two) { + auto hash = TDataExprType::MakeHash(slot); + hash = StreamHash(one.size(), hash); + hash = StreamHash(one.data(), one.size(), hash); + hash = StreamHash(two.size(), hash); + hash = StreamHash(two.data(), two.size(), hash); + return hash; + } + + const TStringBuf& GetParamOne() const { + return One; + } + + const TStringBuf& GetParamTwo() const { + return Two; + } + + bool operator==(const TDataExprParamsType& other) const { + return GetSlot() == other.GetSlot() && GetParamOne() == other.GetParamOne() && GetParamTwo() == other.GetParamTwo(); + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + +private: + const TStringBuf One, Two; +}; + +class TPgExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Pg; + + // TODO: TypeHasDynamicSize for Pg types + TPgExprType(ui64 hash, ui32 typeId) + : TTypeAnnotationNode(KindValue, GetFlags(typeId), hash, GetPgExtensionsMask(typeId)) + , TypeId(typeId) + { + } + + static ui64 MakeHash(ui32 typeId) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Pg; + return StreamHash(typeId, hash); + } + + const TString& GetName() const; + + ui32 GetId() const { + return TypeId; + } + + bool operator==(const TPgExprType& other) const { + return TypeId == other.TypeId; + } + +private: + ui32 GetFlags(ui32 typeId); + ui64 GetPgExtensionsMask(ui32 typeId); + +private: + ui32 TypeId; +}; + +ui64 MakePgExtensionMask(ui32 extensionIndex); + +class TWorldExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::World; + + TWorldExprType(ui64 hash) + : TTypeAnnotationNode(KindValue, + TypeNonComposable | TypeNonComputable | TypeNonPersistable | TypeNonInspectable, hash, 0) + { + } + + static ui64 MakeHash() { + return TypeHashMagic | (ui64)ETypeAnnotationKind::World; + } + + bool operator==(const TWorldExprType& other) const { + Y_UNUSED(other); + return true; + } +}; + +class TOptionalExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Optional; + + TOptionalExprType(ui64 hash, const TTypeAnnotationNode* itemType) + : TTypeAnnotationNode(KindValue, GetFlags(itemType), hash, itemType->GetUsedPgExtensions()) + , ItemType(itemType) + { + } + + static ui32 GetFlags(const TTypeAnnotationNode* itemType) { + auto ret = TypeHasOptional | itemType->GetFlags(); + if (itemType->GetKind() == ETypeAnnotationKind::Data && + itemType->Cast<TDataExprType>()->GetSlot() == NUdf::EDataSlot::Yson) { + ret = ret & ~TypeHasBareYson; + } + if (itemType->IsOptionalOrNull()) { + ret |= TypeHasNestedOptional; + } + + return ret; + } + + static ui64 MakeHash(const TTypeAnnotationNode* itemType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Optional; + return StreamHash(itemType->GetHash(), hash); + } + + const TTypeAnnotationNode* GetItemType() const { + return ItemType; + } + + bool operator==(const TOptionalExprType& other) const { + return GetItemType() == other.GetItemType(); + } + +private: + const TTypeAnnotationNode* ItemType; +}; + +class TVariantExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Variant; + + TVariantExprType(ui64 hash, const TTypeAnnotationNode* underlyingType) + : TTypeAnnotationNode(KindValue, MakeFlags(underlyingType), hash, underlyingType->GetUsedPgExtensions()) + , UnderlyingType(underlyingType) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode* underlyingType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Variant; + return StreamHash(underlyingType->GetHash(), hash); + } + + const TTypeAnnotationNode* GetUnderlyingType() const { + return UnderlyingType; + } + + bool operator==(const TVariantExprType& other) const { + return GetUnderlyingType() == other.GetUnderlyingType(); + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + + static ui32 MakeFlags(const TTypeAnnotationNode* underlyingType); + +private: + const TTypeAnnotationNode* UnderlyingType; +}; + +class TTypeExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Type; + + TTypeExprType(ui64 hash, const TTypeAnnotationNode* type) + : TTypeAnnotationNode(KindValue, TypeNonPersistable | TypeNonComputable, hash, 0) + , Type(type) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode* type) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Type; + return StreamHash(type->GetHash(), hash); + } + + const TTypeAnnotationNode* GetType() const { + return Type; + } + + bool operator==(const TTypeExprType& other) const { + return GetType() == other.GetType(); + } + +private: + const TTypeAnnotationNode* Type; +}; + +class TDictExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Dict; + + TDictExprType(ui64 hash, const TTypeAnnotationNode* keyType, const TTypeAnnotationNode* payloadType) + : TTypeAnnotationNode(KindValue, TypeNonComparable | TypeHasDynamicSize | + keyType->GetFlags() | payloadType->GetFlags(), hash, + keyType->GetUsedPgExtensions() | payloadType->GetUsedPgExtensions()) + , KeyType(keyType) + , PayloadType(payloadType) + { + } + + static ui64 MakeHash(const TTypeAnnotationNode* keyType, const TTypeAnnotationNode* payloadType) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Dict; + return StreamHash(StreamHash(keyType->GetHash(), hash), payloadType->GetHash()); + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + + const TTypeAnnotationNode* GetKeyType() const { + return KeyType; + } + + const TTypeAnnotationNode* GetPayloadType() const { + return PayloadType; + } + + bool operator==(const TDictExprType& other) const { + return GetKeyType() == other.GetKeyType() && + GetPayloadType() == other.GetPayloadType(); + } + +private: + const TTypeAnnotationNode* KeyType; + const TTypeAnnotationNode* PayloadType; +}; + +class TVoidExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Void; + + TVoidExprType(ui64 hash) + : TTypeAnnotationNode(KindValue, 0, hash, 0) + { + } + + static ui64 MakeHash() { + return TypeHashMagic | (ui64)ETypeAnnotationKind::Void; + } + + bool operator==(const TVoidExprType& other) const { + Y_UNUSED(other); + return true; + } +}; + +class TNullExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Null; + + TNullExprType(ui64 hash) + : TTypeAnnotationNode(KindValue, TypeHasNull, hash, 0) + { + } + + static ui64 MakeHash() { + return TypeHashMagic | (ui64)ETypeAnnotationKind::Null; + } + + bool operator==(const TNullExprType& other) const { + Y_UNUSED(other); + return true; + } +}; + +struct TArgumentFlags { + enum { + AutoMap = 0x01, + }; +}; + +class TCallableExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Callable; + + struct TArgumentInfo { + const TTypeAnnotationNode* Type = nullptr; + TStringBuf Name; + ui64 Flags = 0; + + bool operator==(const TArgumentInfo& other) const { + return Type == other.Type && Name == other.Name && Flags == other.Flags; + } + + bool operator!=(const TArgumentInfo& other) const { + return !(*this == other); + } + }; + + TCallableExprType(ui64 hash, const TTypeAnnotationNode* returnType, const TVector<TArgumentInfo>& arguments + , size_t optionalArgumentsCount, const TStringBuf& payload) + : TTypeAnnotationNode(KindValue, MakeFlags(returnType), hash, returnType->GetUsedPgExtensions()) + , ReturnType(returnType) + , Arguments(arguments) + , OptionalArgumentsCount(optionalArgumentsCount) + , Payload(payload) + { + for (ui32 i = 0; i < Arguments.size(); ++i) { + const auto& arg = Arguments[i]; + if (!arg.Name.empty()) { + IndexByName.insert({ arg.Name, i }); + } + } + } + + static ui64 MakeHash(const TTypeAnnotationNode* returnType, const TVector<TArgumentInfo>& arguments + , size_t optionalArgumentsCount, const TStringBuf& payload) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Callable; + hash = StreamHash(returnType->GetHash(), hash); + hash = StreamHash(arguments.size(), hash); + for (const auto& arg : arguments) { + hash = StreamHash(arg.Name.size(), hash); + hash = StreamHash(arg.Name.data(), arg.Name.size(), hash); + hash = StreamHash(arg.Flags, hash); + hash = StreamHash(arg.Type->GetHash(), hash); + } + + hash = StreamHash(optionalArgumentsCount, hash); + hash = StreamHash(payload.size(), hash); + hash = StreamHash(payload.data(), payload.size(), hash); + return hash; + } + + const TTypeAnnotationNode* GetReturnType() const { + return ReturnType; + } + + size_t GetOptionalArgumentsCount() const { + return OptionalArgumentsCount; + } + + const TStringBuf& GetPayload() const { + return Payload; + } + + size_t GetArgumentsSize() const { + return Arguments.size(); + } + + const TVector<TArgumentInfo>& GetArguments() const { + return Arguments; + } + + bool operator==(const TCallableExprType& other) const { + if (GetArgumentsSize() != other.GetArgumentsSize()) { + return false; + } + + if (GetOptionalArgumentsCount() != other.GetOptionalArgumentsCount()) { + return false; + } + + if (GetReturnType() != other.GetReturnType()) { + return false; + } + + for (ui32 i = 0, e = GetArgumentsSize(); i < e; ++i) { + if (GetArguments()[i] != other.GetArguments()[i]) { + return false; + } + } + + return true; + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + + TMaybe<ui32> ArgumentIndexByName(const TStringBuf& name) const { + auto it = IndexByName.find(name); + if (it == IndexByName.end()) { + return {}; + } + + return it->second; + } + +private: + static ui32 MakeFlags(const TTypeAnnotationNode* returnType) { + ui32 flags = TypeNonPersistable; + flags |= returnType->GetFlags(); + return flags; + } + +private: + const TTypeAnnotationNode* ReturnType; + TVector<TArgumentInfo> Arguments; + const size_t OptionalArgumentsCount; + const TStringBuf Payload; + THashMap<TStringBuf, ui32> IndexByName; +}; + +class TGenericExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Generic; + + TGenericExprType(ui64 hash) + : TTypeAnnotationNode(KindValue, TypeNonComputable, hash, 0) + { + } + + static ui64 MakeHash() { + return TypeHashMagic | (ui64)ETypeAnnotationKind::Generic; + } + + bool operator==(const TGenericExprType& other) const { + Y_UNUSED(other); + return true; + } +}; + +class TResourceExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Resource; + + TResourceExprType(ui64 hash, const TStringBuf& tag) + : TTypeAnnotationNode(KindValue, TypeNonPersistable | TypeHasManyValues, hash, 0) + , Tag(tag) + {} + + static ui64 MakeHash(const TStringBuf& tag) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Resource; + hash = StreamHash(tag.size(), hash); + return StreamHash(tag.data(), tag.size(), hash); + } + + const TStringBuf& GetTag() const { + return Tag; + } + + bool operator==(const TResourceExprType& other) const { + return Tag == other.Tag; + } + +private: + const TStringBuf Tag; +}; + +class TTaggedExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Tagged; + + TTaggedExprType(ui64 hash, const TTypeAnnotationNode* baseType, const TStringBuf& tag) + : TTypeAnnotationNode(KindValue, baseType->GetFlags(), hash, baseType->GetUsedPgExtensions()) + , BaseType(baseType) + , Tag(tag) + {} + + static ui64 MakeHash(const TTypeAnnotationNode* baseType, const TStringBuf& tag) { + ui64 hash = TypeHashMagic | (ui64)ETypeAnnotationKind::Tagged; + hash = StreamHash(baseType->GetHash(), hash); + hash = StreamHash(tag.size(), hash); + return StreamHash(tag.data(), tag.size(), hash); + } + + const TStringBuf& GetTag() const { + return Tag; + } + + const TTypeAnnotationNode* GetBaseType() const { + return BaseType; + } + + bool operator==(const TTaggedExprType& other) const { + return Tag == other.Tag && GetBaseType() == other.GetBaseType(); + } + + bool Validate(TPosition position, TExprContext& ctx) const; + bool Validate(TPositionHandle position, TExprContext& ctx) const; + +private: + const TTypeAnnotationNode* BaseType; + const TStringBuf Tag; +}; + +class TErrorExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::Error; + + TErrorExprType(ui64 hash, const TIssue& error) + : TTypeAnnotationNode(KindValue, 0, hash, 0) + , Error(error) + {} + + static ui64 MakeHash(const TIssue& error) { + return error.Hash(); + } + + const TIssue& GetError() const { + return Error; + } + + bool operator==(const TErrorExprType& other) const { + return Error == other.Error; + } + +private: + const TIssue Error; +}; + +class TEmptyListExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::EmptyList; + + TEmptyListExprType(ui64 hash) + : TTypeAnnotationNode(KindValue, 0, hash, 0) + { + } + + static ui64 MakeHash() { + return TypeHashMagic | (ui64)ETypeAnnotationKind::EmptyList; + } + + bool operator==(const TEmptyListExprType& other) const { + Y_UNUSED(other); + return true; + } +}; + +class TEmptyDictExprType : public TTypeAnnotationNode { +public: + static constexpr ETypeAnnotationKind KindValue = ETypeAnnotationKind::EmptyDict; + + TEmptyDictExprType(ui64 hash) + : TTypeAnnotationNode(KindValue, 0, hash, 0) + { + } + + static ui64 MakeHash() { + return TypeHashMagic | (ui64)ETypeAnnotationKind::EmptyDict; + } + + bool operator==(const TEmptyDictExprType& other) const { + Y_UNUSED(other); + return true; + } +}; + +inline bool TTypeAnnotationNode::Equals(const TTypeAnnotationNode& node) const { + if (this == &node) { + return true; + } + + if (Hash != node.GetHash()) { + return false; + } + + if (Kind != node.GetKind()) { + return false; + } + + switch (Kind) { + case ETypeAnnotationKind::Unit: + return static_cast<const TUnitExprType&>(*this) == static_cast<const TUnitExprType&>(node); + + case ETypeAnnotationKind::Tuple: + return static_cast<const TTupleExprType&>(*this) == static_cast<const TTupleExprType&>(node); + + case ETypeAnnotationKind::Struct: + return static_cast<const TStructExprType&>(*this) == static_cast<const TStructExprType&>(node); + + case ETypeAnnotationKind::Item: + return static_cast<const TItemExprType&>(*this) == static_cast<const TItemExprType&>(node); + + case ETypeAnnotationKind::List: + return static_cast<const TListExprType&>(*this) == static_cast<const TListExprType&>(node); + + case ETypeAnnotationKind::Data: + return static_cast<const TDataExprType&>(*this) == static_cast<const TDataExprType&>(node); + + case ETypeAnnotationKind::Pg: + return static_cast<const TPgExprType&>(*this) == static_cast<const TPgExprType&>(node); + + case ETypeAnnotationKind::World: + return static_cast<const TWorldExprType&>(*this) == static_cast<const TWorldExprType&>(node); + + case ETypeAnnotationKind::Optional: + return static_cast<const TOptionalExprType&>(*this) == static_cast<const TOptionalExprType&>(node); + + case ETypeAnnotationKind::Type: + return static_cast<const TTypeExprType&>(*this) == static_cast<const TTypeExprType&>(node); + + case ETypeAnnotationKind::Dict: + return static_cast<const TDictExprType&>(*this) == static_cast<const TDictExprType&>(node); + + case ETypeAnnotationKind::Void: + return static_cast<const TVoidExprType&>(*this) == static_cast<const TVoidExprType&>(node); + + case ETypeAnnotationKind::Null: + return static_cast<const TNullExprType&>(*this) == static_cast<const TNullExprType&>(node); + + case ETypeAnnotationKind::Callable: + return static_cast<const TCallableExprType&>(*this) == static_cast<const TCallableExprType&>(node); + + case ETypeAnnotationKind::Generic: + return static_cast<const TGenericExprType&>(*this) == static_cast<const TGenericExprType&>(node); + + case ETypeAnnotationKind::Resource: + return static_cast<const TResourceExprType&>(*this) == static_cast<const TResourceExprType&>(node); + + case ETypeAnnotationKind::Tagged: + return static_cast<const TTaggedExprType&>(*this) == static_cast<const TTaggedExprType&>(node); + + case ETypeAnnotationKind::Error: + return static_cast<const TErrorExprType&>(*this) == static_cast<const TErrorExprType&>(node); + + case ETypeAnnotationKind::Variant: + return static_cast<const TVariantExprType&>(*this) == static_cast<const TVariantExprType&>(node); + + case ETypeAnnotationKind::Stream: + return static_cast<const TStreamExprType&>(*this) == static_cast<const TStreamExprType&>(node); + + case ETypeAnnotationKind::Flow: + return static_cast<const TFlowExprType&>(*this) == static_cast<const TFlowExprType&>(node); + + case ETypeAnnotationKind::EmptyList: + return static_cast<const TEmptyListExprType&>(*this) == static_cast<const TEmptyListExprType&>(node); + + case ETypeAnnotationKind::EmptyDict: + return static_cast<const TEmptyDictExprType&>(*this) == static_cast<const TEmptyDictExprType&>(node); + + case ETypeAnnotationKind::Multi: + return static_cast<const TMultiExprType&>(*this) == static_cast<const TMultiExprType&>(node); + + case ETypeAnnotationKind::Block: + return static_cast<const TBlockExprType&>(*this) == static_cast<const TBlockExprType&>(node); + + case ETypeAnnotationKind::Scalar: + return static_cast<const TScalarExprType&>(*this) == static_cast<const TScalarExprType&>(node); + + case ETypeAnnotationKind::LastType: + YQL_ENSURE(false, "Incorrect type"); + + } + return false; +} + +inline void TTypeAnnotationNode::Accept(TTypeAnnotationVisitor& visitor) const { + switch (Kind) { + case ETypeAnnotationKind::Unit: + return visitor.Visit(static_cast<const TUnitExprType&>(*this)); + case ETypeAnnotationKind::Tuple: + return visitor.Visit(static_cast<const TTupleExprType&>(*this)); + case ETypeAnnotationKind::Struct: + return visitor.Visit(static_cast<const TStructExprType&>(*this)); + case ETypeAnnotationKind::Item: + return visitor.Visit(static_cast<const TItemExprType&>(*this)); + case ETypeAnnotationKind::List: + return visitor.Visit(static_cast<const TListExprType&>(*this)); + case ETypeAnnotationKind::Data: + return visitor.Visit(static_cast<const TDataExprType&>(*this)); + case ETypeAnnotationKind::Pg: + return visitor.Visit(static_cast<const TPgExprType&>(*this)); + case ETypeAnnotationKind::World: + return visitor.Visit(static_cast<const TWorldExprType&>(*this)); + case ETypeAnnotationKind::Optional: + return visitor.Visit(static_cast<const TOptionalExprType&>(*this)); + case ETypeAnnotationKind::Type: + return visitor.Visit(static_cast<const TTypeExprType&>(*this)); + case ETypeAnnotationKind::Dict: + return visitor.Visit(static_cast<const TDictExprType&>(*this)); + case ETypeAnnotationKind::Void: + return visitor.Visit(static_cast<const TVoidExprType&>(*this)); + case ETypeAnnotationKind::Null: + return visitor.Visit(static_cast<const TNullExprType&>(*this)); + case ETypeAnnotationKind::Callable: + return visitor.Visit(static_cast<const TCallableExprType&>(*this)); + case ETypeAnnotationKind::Generic: + return visitor.Visit(static_cast<const TGenericExprType&>(*this)); + case ETypeAnnotationKind::Resource: + return visitor.Visit(static_cast<const TResourceExprType&>(*this)); + case ETypeAnnotationKind::Tagged: + return visitor.Visit(static_cast<const TTaggedExprType&>(*this)); + case ETypeAnnotationKind::Error: + return visitor.Visit(static_cast<const TErrorExprType&>(*this)); + case ETypeAnnotationKind::Variant: + return visitor.Visit(static_cast<const TVariantExprType&>(*this)); + case ETypeAnnotationKind::Stream: + return visitor.Visit(static_cast<const TStreamExprType&>(*this)); + case ETypeAnnotationKind::Flow: + return visitor.Visit(static_cast<const TFlowExprType&>(*this)); + case ETypeAnnotationKind::EmptyList: + return visitor.Visit(static_cast<const TEmptyListExprType&>(*this)); + case ETypeAnnotationKind::EmptyDict: + return visitor.Visit(static_cast<const TEmptyDictExprType&>(*this)); + case ETypeAnnotationKind::Multi: + return visitor.Visit(static_cast<const TMultiExprType&>(*this)); + case ETypeAnnotationKind::Block: + return visitor.Visit(static_cast<const TBlockExprType&>(*this)); + case ETypeAnnotationKind::Scalar: + return visitor.Visit(static_cast<const TScalarExprType&>(*this)); + case ETypeAnnotationKind::LastType: + YQL_ENSURE(false, "Incorrect type"); + } +} + +class TExprNode { + friend class TExprNodeBuilder; + friend class TExprNodeReplaceBuilder; + friend struct TExprContext; + +private: + struct TExprFlags { + enum : ui16 { + Default = 0, + Dead = 0x01, + Frozen = 0x02, + }; + static constexpr ui32 FlagsMask = 0x03; // all flags should fit here + }; + +public: + typedef TIntrusivePtr<TExprNode> TPtr; + typedef std::vector<TPtr> TListType; + typedef TArrayRef<const TPtr> TChildrenType; + + struct TPtrHash : private std::hash<const TExprNode*> { + size_t operator()(const TPtr& p) const { + return std::hash<const TExprNode*>::operator()(p.Get()); + } + }; + +#define YQL_EXPR_NODE_TYPE_MAP(xx) \ + xx(List, 0) \ + xx(Atom, 1) \ + xx(Callable, 2) \ + xx(Lambda, 3) \ + xx(Argument, 4) \ + xx(Arguments, 5) \ + xx(World, 7) + + enum EType : ui8 { + YQL_EXPR_NODE_TYPE_MAP(ENUM_VALUE_GEN) + }; + + static constexpr ui32 TypeMask = 0x07; // all types should fit here + +#define YQL_EXPR_NODE_STATE_MAP(xx) \ + xx(Initial, 0) \ + xx(TypeInProgress, 1) \ + xx(TypePending, 2) \ + xx(TypeComplete, 3) \ + xx(ConstrInProgress, 4) \ + xx(ConstrPending, 5) \ + xx(ConstrComplete, 6) \ + xx(ExecutionRequired, 7) \ + xx(ExecutionInProgress, 8) \ + xx(ExecutionPending, 9) \ + xx(ExecutionComplete, 10) \ + xx(Error, 11) \ + xx(Last, 12) + + enum class EState : ui8 { + YQL_EXPR_NODE_STATE_MAP(ENUM_VALUE_GEN) + }; + + static TPtr GetResult(const TPtr& node) { + return node->Type() == Callable ? node->Result : node; + } + + const TExprNode& GetResult() const { + ENSURE_NOT_DELETED + return Type() == Callable ? *Result : *this; + } + + bool HasResult() const { + ENSURE_NOT_DELETED + return Type() != Callable || bool(Result); + } + + void SetResult(TPtr&& result) { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Result = std::move(result); + } + + bool IsCallable(const std::string_view& name) const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Callable && Content() == name; + } + + bool IsCallable(const std::initializer_list<std::string_view>& names) const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Callable && names.end() != std::find(names.begin(), names.end(), Content()); + } + + template <class TKey> + bool IsCallable(const THashSet<TKey>& names) const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Callable && names.contains(Content()); + } + + bool IsCallable() const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Callable; + } + + bool IsAtom() const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Atom; + } + + bool IsWorld() const { + ENSURE_NOT_DELETED + return Type() == TExprNode::World; + } + + bool IsAtom(const std::string_view& content) const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Atom && Content() == content; + } + + bool IsAtom(const std::initializer_list<std::string_view>& names) const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Atom && names.end() != std::find(names.begin(), names.end(), Content()); + } + + bool IsList() const { + ENSURE_NOT_DELETED + return Type() == TExprNode::List; + } + + bool IsLambda() const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Lambda; + } + + bool IsArgument() const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Argument; + } + + bool IsArguments() const { + ENSURE_NOT_DELETED + return Type() == TExprNode::Arguments; + } + + bool IsComposable() const { + ENSURE_NOT_DELETED + return !IsLambda() && TypeAnnotation_->IsComposable(); + } + + bool IsPersistable() const { + ENSURE_NOT_DELETED + return !IsLambda() && TypeAnnotation_->IsPersistable(); + } + + bool IsComputable() const { + ENSURE_NOT_DELETED + return !IsLambda() && TypeAnnotation_->IsComputable(); + } + + bool IsInspectable() const { + ENSURE_NOT_DELETED + return !IsLambda() && TypeAnnotation_->IsInspectable(); + } + + bool ForDisclosing() const { + ENSURE_NOT_DELETED + return Type() == TExprNode::List && ShallBeDisclosed; + } + + void SetDisclosing() { + ENSURE_NOT_DELETED + Y_ENSURE(Type() == TExprNode::List, "Must be list."); + ShallBeDisclosed = true; + } + + ui32 GetFlagsToCompare() const { + ENSURE_NOT_DELETED + ui32 ret = Flags(); + if ((ret & TNodeFlags::BinaryContent) == 0) { + ret |= TNodeFlags::ArbitraryContent | TNodeFlags::MultilineContent; + } + + return ret; + } + + TString Dump() const; + + bool StartsExecution() const { + ENSURE_NOT_DELETED + return State == EState::ExecutionComplete + || State == EState::ExecutionInProgress + || State == EState::ExecutionRequired + || State == EState::ExecutionPending; + } + + bool IsComplete() const { + YQL_ENSURE(HasLambdaScope); + return !OuterLambda; + } + + bool IsLiteralList() const { + YQL_ENSURE(IsList()); + return LiteralList; + } + + void SetLiteralList(bool literal) { + YQL_ENSURE(IsList()); + LiteralList = literal; + } + + void Ref() { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Y_ENSURE(RefCount_ < Max<ui32>()); + ++RefCount_; + } + + void UnRef() { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + if (!--RefCount_) { + Result.Reset(); + Children_.clear(); + Constraints_.Clear(); + MarkDead(); + } + } + + ui32 UseCount() const { return RefCount_; } + bool Unique() const { return 1U == UseCount(); } + + bool Dead() const { + return ExprFlags_ & TExprFlags::Dead; + } + + TPositionHandle Pos() const { + ENSURE_NOT_DELETED + return Position_; + } + + TPosition Pos(const TExprContext& ctx) const; + + EType Type() const { + ENSURE_NOT_DELETED + return (EType)Type_; + } + + TListType::size_type ChildrenSize() const { + ENSURE_NOT_DELETED + return Children_.size(); + } + + TExprNode* Child(ui32 index) const { + ENSURE_NOT_DELETED + Y_ENSURE(index < Children_.size(), "index out of range"); + return Children_[index].Get(); + } + + TPtr ChildPtr(ui32 index) const { + ENSURE_NOT_DELETED + Y_ENSURE(index < Children_.size(), "index out of range"); + return Children_[index]; + } + + TPtr& ChildRef(ui32 index) { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Y_ENSURE(index < Children_.size(), "index out of range"); + return Children_[index]; + } + + const TExprNode& Head() const { + ENSURE_NOT_DELETED + Y_ENSURE(!Children_.empty(), "no children"); + return *Children_.front(); + } + + TExprNode& Head() { + ENSURE_NOT_DELETED + Y_ENSURE(!Children_.empty(), "no children"); + return *Children_.front(); + } + + TPtr HeadPtr() const { + ENSURE_NOT_DELETED + Y_ENSURE(!Children_.empty(), "no children"); + return Children_.front(); + } + + TPtr& HeadRef() { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Y_ENSURE(!Children_.empty(), "no children"); + return Children_.front(); + } + + const TExprNode& Tail() const { + ENSURE_NOT_DELETED + Y_ENSURE(!Children_.empty(), "no children"); + return *Children_.back(); + } + + TExprNode& Tail() { + ENSURE_NOT_DELETED + Y_ENSURE(!Children_.empty(), "no children"); + return *Children_.back(); + } + + TPtr TailPtr() const { + ENSURE_NOT_DELETED + Y_ENSURE(!Children_.empty(), "no children"); + return Children_.back(); + } + + TPtr& TailRef() { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Y_ENSURE(!Children_.empty(), "no children"); + return Children_.back(); + } + + TChildrenType Children() const { + ENSURE_NOT_DELETED + return TChildrenType(Children_.data(), Children_.size()); + } + + TListType ChildrenList() const { + ENSURE_NOT_DELETED + return Children_; + } + + void ChangeChildrenInplace(TListType&& newChildren) { + ENSURE_NOT_DELETED + Children_ = std::move(newChildren); + } + + template<class F> + void ForEachChild(const F& visitor) const { + for (const auto& child : Children_) + visitor(*child); + } + + TStringBuf Content() const { + ENSURE_NOT_DELETED + return ContentUnchecked(); + } + + ui32 Flags() const { + ENSURE_NOT_DELETED + return Flags_; + } + + void NormalizeAtomFlags(const TExprNode& otherAtom) { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Y_ENSURE(Type_ == Atom && otherAtom.Type_ == Atom, "Expected atoms"); + Y_ENSURE((Flags_ & TNodeFlags::BinaryContent) == + (otherAtom.Flags_ & TNodeFlags::BinaryContent), "Mismatch binary atom flags"); + if (!(Flags_ & TNodeFlags::BinaryContent)) { + Flags_ = Min(Flags_, otherAtom.Flags_); + } + } + + ui64 UniqueId() const { + ENSURE_NOT_DELETED + return UniqueId_; + } + + const TConstraintNode* GetConstraint(TStringBuf name) const { + ENSURE_NOT_DELETED + Y_ENSURE(static_cast<EState>(State) >= EState::ConstrComplete); + return Constraints_.GetConstraint(name); + } + + template <class TConstraintType> + const TConstraintType* GetConstraint() const { + ENSURE_NOT_DELETED + Y_ENSURE(static_cast<EState>(State) >= EState::ConstrComplete); + return Constraints_.GetConstraint<TConstraintType>(); + } + + const TConstraintNode::TListType& GetAllConstraints() const { + ENSURE_NOT_DELETED + Y_ENSURE(static_cast<EState>(State) >= EState::ConstrComplete); + return Constraints_.GetAllConstraints(); + } + + const TConstraintSet& GetConstraintSet() const { + ENSURE_NOT_DELETED + Y_ENSURE(static_cast<EState>(State) >= EState::ConstrComplete); + return Constraints_; + } + + void AddConstraint(const TConstraintNode* node) { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Y_ENSURE(static_cast<EState>(State) >= EState::TypeComplete); + Y_ENSURE(!StartsExecution()); + Constraints_.AddConstraint(node); + State = EState::ConstrComplete; + } + + void CopyConstraints(const TExprNode& node) { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Y_ENSURE(static_cast<EState>(State) >= EState::TypeComplete); + Constraints_ = node.Constraints_; + State = EState::ConstrComplete; + } + + void SetConstraints(const TConstraintSet& constraints) { + ENSURE_NOT_DELETED + ENSURE_NOT_FROZEN + Y_ENSURE(static_cast<EState>(State) >= EState::TypeComplete); + Constraints_ = constraints; + State = EState::ConstrComplete; + } + + static TPtr NewAtom(ui64 uniqueId, TPositionHandle pos, const TStringBuf& content, ui32 flags) { + return Make(pos, Atom, {}, content, flags, uniqueId); + } + + static TPtr NewArgument(ui64 uniqueId, TPositionHandle pos, const TStringBuf& name) { + return Make(pos, Argument, {}, name, 0, uniqueId); + } + + static TPtr NewArguments(ui64 uniqueId, TPositionHandle pos, TListType&& argNodes) { + return Make(pos, Arguments, std::move(argNodes), ZeroString, 0, uniqueId); + } + + static TPtr NewLambda(ui64 uniqueId, TPositionHandle pos, TListType&& lambda) { + return Make(pos, Lambda, std::move(lambda), ZeroString, 0, uniqueId); + } + + static TPtr NewLambda(ui64 uniqueId, TPositionHandle pos, TPtr&& args, TListType&& body) { + TListType lambda(body.size() + 1U); + lambda.front() = std::move(args); + std::move(body.rbegin(), body.rend(), lambda.rbegin()); + return NewLambda(uniqueId, pos, std::move(lambda)); + } + + static TPtr NewLambda(ui64 uniqueId, TPositionHandle pos, TPtr&& args, TPtr&& body) { + TListType children(body ? 2 : 1); + children.front() = std::move(args); + if (body) { + children.back() = std::move(body); + } + + return NewLambda(uniqueId, pos, std::move(children)); + } + + static TPtr NewWorld(ui64 uniqueId, TPositionHandle pos) { + return Make(pos, World, {}, {}, 0, uniqueId); + } + + static TPtr NewList(ui64 uniqueId, TPositionHandle pos, TListType&& children) { + return Make(pos, List, std::move(children), ZeroString, 0, uniqueId); + } + + static TPtr NewCallable(ui64 uniqueId, TPositionHandle pos, const TStringBuf& name, TListType&& children) { + return Make(pos, Callable, std::move(children), name, 0, uniqueId); + } + + TPtr Clone(ui64 newUniqueId) const { + ENSURE_NOT_DELETED + return Make(Position_, (EType)Type_, TListType(Children_), Content(), Flags_, newUniqueId); + } + + TPtr CloneWithPosition(ui64 newUniqueId, TPositionHandle pos) const { + ENSURE_NOT_DELETED + return Make(pos, (EType)Type_, TListType(Children_), Content(), Flags_, newUniqueId); + } + + static TPtr NewNode(TPositionHandle position, EType type, TListType&& children, const TStringBuf& content, ui32 flags, ui64 uniqueId) { + return Make(position, type, std::move(children), content, flags, uniqueId); + } + + TPtr ChangeContent(ui64 newUniqueId, const TStringBuf& content) const { + ENSURE_NOT_DELETED + return Make(Position_, (EType)Type_, TListType(Children_), content, Flags_, newUniqueId); + } + + TPtr ChangeChildren(ui64 newUniqueId, TListType&& children) const { + ENSURE_NOT_DELETED + return Make(Position_, (EType)Type_, std::move(children), Content(), Flags_, newUniqueId); + } + + TPtr ChangeChild(ui64 newUniqueId, ui32 index, TPtr&& child) const { + ENSURE_NOT_DELETED + Y_ENSURE(index < Children_.size(), "index out of range"); + TListType newChildren(Children_); + newChildren[index] = std::move(child); + return Make(Position_, (EType)Type_, std::move(newChildren), Content(), Flags_, newUniqueId); + } + + void SetTypeAnn(const TTypeAnnotationNode* typeAnn) { + TypeAnnotation_ = typeAnn; + State = TypeAnnotation_ ? EState::TypeComplete : EState::Initial; + } + + const TTypeAnnotationNode* GetTypeAnn() const { + return TypeAnnotation_; + } + + EState GetState() const { + return State; + } + + void SetState(EState state) { + State = state; + } + + ui32 GetArgIndex() const { + YQL_ENSURE(Type() == EType::Argument); + return ArgIndex; + } + + void SetArgIndex(ui32 argIndex) { + YQL_ENSURE(Type() == EType::Argument); + YQL_ENSURE(argIndex <= Max<ui16>()); + ArgIndex = (ui16)argIndex; + } + + ui64 GetHash() const { + Y_DEBUG_ABORT_UNLESS(HashAbove == HashBelow); + return HashAbove; + } + + void SetHash(ui64 hash) { + HashAbove = HashBelow = hash; + } + + ui64 GetHashAbove() const { + return HashAbove; + } + + void SetHashAbove(ui64 hash) { + HashAbove = hash; + } + + ui64 GetHashBelow() const { + return HashBelow; + } + + void SetHashBelow(ui64 hash) { + HashBelow = hash; + } + + ui64 GetBloom() const { + return Bloom; + } + + void SetBloom(ui64 bloom) { + Bloom = bloom; + } + + // return pair of outer and inner lambda. + std::optional<std::pair<const TExprNode*, const TExprNode*>> GetDependencyScope() const { + if (HasLambdaScope) { + return std::make_pair(OuterLambda, InnerLambda); + } + return std::nullopt; + } + + void SetDependencyScope(const TExprNode* outerLambda, const TExprNode* innerLambda) { + Y_DEBUG_ABORT_UNLESS(outerLambda == innerLambda || outerLambda->GetLambdaLevel() < innerLambda->GetLambdaLevel(), "Wrong scope of closures."); + HasLambdaScope = 1; + OuterLambda = outerLambda; + InnerLambda = innerLambda; + } + + ui16 GetLambdaLevel() const { return LambdaLevel; } + void SetLambdaLevel(ui16 lambdaLevel) { LambdaLevel = lambdaLevel; } + + bool IsUsedInDependsOn() const { + YQL_ENSURE(Type() == EType::Argument); + return UsedInDependsOn; + } + + void SetUsedInDependsOn() { + YQL_ENSURE(Type() == EType::Argument); + UsedInDependsOn = 1; + } + + void SetUnorderedChildren() { + YQL_ENSURE(Type() == EType::List || Type() == EType::Callable); + UnordChildren = 1; + } + + bool UnorderedChildren() const { + YQL_ENSURE(Type() == EType::List || Type() == EType::Callable); + return bool(UnordChildren); + } + + ~TExprNode() { + Y_ABORT_UNLESS(Dead(), "Node (id: %lu, type: %s, content: '%s') not dead on destruction.", + UniqueId_, ToString(Type_).data(), TString(ContentUnchecked()).data()); + Y_ABORT_UNLESS(!UseCount(), "Node (id: %lu, type: %s, content: '%s') has non-zero use count on destruction.", + UniqueId_, ToString(Type_).data(), TString(ContentUnchecked()).data()); + } + +private: + static TPtr Make(TPositionHandle position, EType type, TListType&& children, const TStringBuf& content, ui32 flags, ui64 uniqueId) { + Y_ENSURE(flags <= TNodeFlags::FlagsMask); + Y_ENSURE(children.size() <= Max<ui32>()); + Y_ENSURE(content.size() <= Max<ui32>()); + for (size_t i = 0; i < children.size(); ++i) { + Y_ENSURE(children[i], "Unable to create node " << content << ": " << i << "th child is null"); + } + return TPtr(new TExprNode(position, type, std::move(children), content.data(), ui32(content.size()), flags, uniqueId)); + } + + TExprNode(TPositionHandle position, EType type, TListType&& children, + const char* content, ui32 contentSize, ui32 flags, ui64 uniqueId) + : Children_(std::move(children)) + , Content_(content) + , UniqueId_(uniqueId) + , Position_(position) + , ContentSize(contentSize) + , Type_(type) + , Flags_(flags) + , ExprFlags_(TExprFlags::Default) + , State(EState::Initial) + , HasLambdaScope(0) + , UsedInDependsOn(0) + , UnordChildren(0) + , ShallBeDisclosed(0) + , LiteralList(0) + {} + + TExprNode(const TExprNode&) = delete; + TExprNode(TExprNode&&) = delete; + TExprNode& operator=(const TExprNode&) = delete; + TExprNode& operator=(TExprNode&&) = delete; + + bool Frozen() const { + return ExprFlags_ & TExprFlags::Frozen; + } + + void MarkFrozen(bool frozen = true) { + if (frozen) { + ExprFlags_ |= TExprFlags::Frozen; + } else { + ExprFlags_ &= ~TExprFlags::Frozen; + } + } + + void MarkDead() { + ExprFlags_ |= TExprFlags::Dead; + } + + TStringBuf ContentUnchecked() const { + return TStringBuf(Content_, ContentSize); + } + + TListType Children_; + TConstraintSet Constraints_; + + const char* Content_ = nullptr; + + const TExprNode* OuterLambda = nullptr; + const TExprNode* InnerLambda = nullptr; + + TPtr Result; + + ui64 HashAbove = 0ULL; + ui64 HashBelow = 0ULL; + ui64 Bloom = 0ULL; + + const ui64 UniqueId_; + const TTypeAnnotationNode* TypeAnnotation_ = nullptr; + + const TPositionHandle Position_; + ui32 RefCount_ = 0U; + const ui32 ContentSize; + + ui16 ArgIndex = ui16(-1); + ui16 LambdaLevel = 0; // filled together with OuterLambda + ui16 IntermediateHashesCount = 0; + + static_assert(TypeMask <= 7, "EType wont fit in 3 bits, increase Type_ bitfield size"); + static_assert(TNodeFlags::FlagsMask <= 7, "TNodeFlags wont fit in 3 bits, increase Flags_ bitfield size"); + static_assert(TExprFlags::FlagsMask <= 3, "TExprFlags wont fit in 2 bits, increase ExprFlags_ bitfield size"); + static_assert(int(EState::Last) <= 16, "EState wont fit in 4 bits, increase State bitfield size"); + struct { + ui8 Type_ : 3; + ui8 Flags_ : 3; + ui8 ExprFlags_ : 2; + + EState State : 4; + ui8 HasLambdaScope : 1; + ui8 UsedInDependsOn : 1; + ui8 UnordChildren : 1; + ui8 ShallBeDisclosed: 1; + ui8 LiteralList : 1; + }; +}; + +class TExportTable { +public: + using TSymbols = THashMap<TString, TExprNode::TPtr>; + + TExportTable() = default; + TExportTable(TExprContext& ctx, TSymbols&& symbols) + : Symbols_(std::move(symbols)) + , Ctx_(&ctx) + {} + + const TSymbols& Symbols() const { + return Symbols_; + } + + TSymbols& Symbols(TExprContext& ctx) { + if (Ctx_) { + YQL_ENSURE(Ctx_ == &ctx); + } else { + Ctx_ = &ctx; + } + return Symbols_; + } + + TExprContext& ExprCtx() const { + YQL_ENSURE(Ctx_); + return *Ctx_; + } +private: + TSymbols Symbols_; + TExprContext* Ctx_ = nullptr; +}; + +using TModulesTable = THashMap<TString, TExportTable>; + +class IModuleResolver { +public: + typedef std::shared_ptr<IModuleResolver> TPtr; + virtual bool AddFromFile(const std::string_view& file, TExprContext& ctx, ui16 syntaxVersion, ui32 packageVersion, TPosition pos = {}) = 0; + virtual bool AddFromUrl(const std::string_view& file, const std::string_view& url, const std::string_view& tokenName, TExprContext& ctx, ui16 syntaxVersion, ui32 packageVersion, TPosition pos = {}) = 0; + virtual bool AddFromMemory(const std::string_view& file, const TString& body, TExprContext& ctx, ui16 syntaxVersion, ui32 packageVersion, TPosition pos = {}) = 0; + virtual bool AddFromMemory(const std::string_view& file, const TString& body, TExprContext& ctx, ui16 syntaxVersion, ui32 packageVersion, TPosition pos, TString& moduleName, std::vector<TString>* exports = nullptr, std::vector<TString>* imports = nullptr) = 0; + virtual bool Link(TExprContext& ctx) = 0; + virtual void UpdateNextUniqueId(TExprContext& ctx) const = 0; + virtual ui64 GetNextUniqueId() const = 0; + virtual void RegisterPackage(const TString& package) = 0; + virtual bool SetPackageDefaultVersion(const TString& package, ui32 version) = 0; + virtual const TExportTable* GetModule(const TString& module) const = 0; + /* + Create new resolver which will use already collected modules in readonly manner. + Parent resolver should be alive while using child due to raw data sharing. + */ + virtual IModuleResolver::TPtr CreateMutableChild() const = 0; + virtual void SetFileAliasPrefix(TString&& prefix) = 0; + virtual TString GetFileAliasPrefix() const = 0; + virtual ~IModuleResolver() = default; +}; + +struct TExprStep { + enum ELevel { + Params, + ExpandApplyForLambdas, + ValidateProviders, + Configure, + ExprEval, + DiscoveryIO, + Epochs, + Intents, + LoadTablesMetadata, + RewriteIO, + Recapture, + LastLevel + }; + + TExprStep() + { + } + + void Done(ELevel level) { + Steps_.Set(level); + } + + void Reset() { + Steps_.Reset(); + } + + TExprStep& Repeat(ELevel level) { + Steps_.Reset(level); + return *this; + } + + bool IsDone(ELevel level) { + return Steps_.Test(level); + } + +private: + TEnumBitSet<ELevel, Params, LastLevel> Steps_; +}; + +template <typename T> +struct TMakeTypeImpl; + +template <class T> +using TNodeMap = std::unordered_map<const TExprNode*, T>; +using TNodeSet = std::unordered_set<const TExprNode*>; +using TNodeOnNodeOwnedMap = TNodeMap<TExprNode::TPtr>; +using TParentsMap = TNodeMap<TNodeSet>; + +using TNodeMultiSet = std::unordered_multiset<const TExprNode*>; +using TParentsMultiMap = TNodeMap<TNodeMultiSet>; + +template <> +struct TMakeTypeImpl<TVoidExprType> { + static const TVoidExprType* Make(TExprContext& ctx); +}; + +template <> +struct TMakeTypeImpl<TNullExprType> { + static const TNullExprType* Make(TExprContext& ctx); +}; + +template <> +struct TMakeTypeImpl<TEmptyListExprType> { + static const TEmptyListExprType* Make(TExprContext& ctx); +}; + +template <> +struct TMakeTypeImpl<TEmptyDictExprType> { + static const TEmptyDictExprType* Make(TExprContext& ctx); +}; + +template <> +struct TMakeTypeImpl<TUnitExprType> { + static const TUnitExprType* Make(TExprContext& ctx); +}; + +template <> +struct TMakeTypeImpl<TWorldExprType> { + static const TWorldExprType* Make(TExprContext& ctx); +}; + +template <> +struct TMakeTypeImpl<TGenericExprType> { + static const TGenericExprType* Make(TExprContext& ctx); +}; + +template <> +struct TMakeTypeImpl<TItemExprType> { + static const TItemExprType* Make(TExprContext& ctx, const TStringBuf& name, const TTypeAnnotationNode* itemType); +}; + +template <> +struct TMakeTypeImpl<TListExprType> { + static const TListExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* itemType); +}; + +template <> +struct TMakeTypeImpl<TOptionalExprType> { + static const TOptionalExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* itemType); +}; + +template <> +struct TMakeTypeImpl<TVariantExprType> { + static const TVariantExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* underlyingType); +}; + +template <> +struct TMakeTypeImpl<TErrorExprType> { + static const TErrorExprType* Make(TExprContext& ctx, const TIssue& error); +}; + +template <> +struct TMakeTypeImpl<TDictExprType> { + static const TDictExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* keyType, + const TTypeAnnotationNode* payloadType); +}; + +template <> +struct TMakeTypeImpl<TTypeExprType> { + static const TTypeExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* baseType); +}; + +template <> +struct TMakeTypeImpl<TDataExprType> { + static const TDataExprType* Make(TExprContext& ctx, EDataSlot slot); +}; + +template <> +struct TMakeTypeImpl<TPgExprType> { + static const TPgExprType* Make(TExprContext& ctx, ui32 typeId); +}; + +template <> +struct TMakeTypeImpl<TDataExprParamsType> { + static const TDataExprParamsType* Make(TExprContext& ctx, EDataSlot slot, const TStringBuf& one, const TStringBuf& two); +}; + +template <> +struct TMakeTypeImpl<TCallableExprType> { + static const TCallableExprType* Make( + TExprContext& ctx, const TTypeAnnotationNode* returnType, const TVector<TCallableExprType::TArgumentInfo>& arguments, + size_t optionalArgumentsCount, const TStringBuf& payload); +}; + +template <> +struct TMakeTypeImpl<TResourceExprType> { + static const TResourceExprType* Make(TExprContext& ctx, const TStringBuf& tag); +}; + +template <> +struct TMakeTypeImpl<TTaggedExprType> { + static const TTaggedExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* baseType, const TStringBuf& tag); +}; + +template <> +struct TMakeTypeImpl<TStructExprType> { + static const TStructExprType* Make(TExprContext& ctx, const TVector<const TItemExprType*>& items); +}; + +template <> +struct TMakeTypeImpl<TTupleExprType> { + static const TTupleExprType* Make(TExprContext& ctx, const TTypeAnnotationNode::TListType& items); +}; + +template <> +struct TMakeTypeImpl<TMultiExprType> { + static const TMultiExprType* Make(TExprContext& ctx, const TTypeAnnotationNode::TListType& items); +}; + +template <> +struct TMakeTypeImpl<TStreamExprType> { + static const TStreamExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* itemType); +}; + +template <> +struct TMakeTypeImpl<TFlowExprType> { + static const TFlowExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* itemType); +}; + +template <> +struct TMakeTypeImpl<TBlockExprType> { + static const TBlockExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* itemType); +}; + +template <> +struct TMakeTypeImpl<TScalarExprType> { + static const TScalarExprType* Make(TExprContext& ctx, const TTypeAnnotationNode* itemType); +}; + +using TSingletonTypeCache = std::tuple< + const TVoidExprType*, + const TNullExprType*, + const TUnitExprType*, + const TEmptyListExprType*, + const TEmptyDictExprType*, + const TWorldExprType*, + const TGenericExprType*, + const TTupleExprType*, + const TStructExprType*, + const TMultiExprType* +>; + +struct TExprContext : private TNonCopyable { + class TFreezeGuard { + public: + TFreezeGuard(const TFreezeGuard&) = delete; + TFreezeGuard& operator=(const TFreezeGuard&) = delete; + + TFreezeGuard(TExprContext& ctx) + : Ctx(ctx) + { + Ctx.Freeze(); + } + + ~TFreezeGuard() { + Ctx.UnFreeze(); + } + + private: + TExprContext& Ctx; + }; + + TIssueManager IssueManager; + TNodeMap<TIssues> AssociativeIssues; + + TMemoryPool StringPool; + std::unordered_set<std::string_view> Strings; + std::unordered_map<ui32, std::string_view> Indexes; + + std::stack<std::unique_ptr<const TTypeAnnotationNode>> TypeNodes; + std::stack<std::unique_ptr<const TConstraintNode>> ConstraintNodes; + std::deque<std::unique_ptr<TExprNode>> ExprNodes; + + TSingletonTypeCache SingletonTypeCache; + std::unordered_set<const TTypeAnnotationNode*, TTypeAnnotationNode::THash, TTypeAnnotationNode::TEqual> TypeSet; + std::unordered_set<const TConstraintNode*, TConstraintNode::THash, TConstraintNode::TEqual> ConstraintSet; + std::unordered_map<const TTypeAnnotationNode*, TExprNode::TPtr> TypeAsNodeCache; + std::unordered_set<TStringBuf, THash<TStringBuf>> DisabledConstraints; + + ui64 NextUniqueId = 0; + ui64 NodeAllocationCounter = 0; + ui64 NodesAllocationLimit = 3000000; + ui64 StringsAllocationLimit = 100000000; + ui64 RepeatTransformLimit = 1000000; + ui64 RepeatTransformCounter = 0; + ui64 TypeAnnNodeRepeatLimit = 1000; + + TGcNodeConfig GcConfig; + + std::unordered_multimap<ui64, TExprNode*> UniqueNodes; + + TExprStep Step; + + bool Frozen; + + explicit TExprContext(ui64 nextUniqueId = 0ULL); + ~TExprContext(); + + ui64 AllocateNextUniqueId() { + ENSURE_NOT_FROZEN_CTX + const auto ret = ++NextUniqueId; + return ret; + } + + TStringBuf AppendString(const TStringBuf& buf) { + ENSURE_NOT_FROZEN_CTX + if (buf.size() == 0) { + return ZeroString; + } + + auto it = Strings.find(buf); + if (it != Strings.end()) { + return *it; + } + + auto newBuf = StringPool.AppendString(buf); + Strings.insert(it, newBuf); + return newBuf; + } + + TPositionHandle AppendPosition(const TPosition& pos); + TPosition GetPosition(TPositionHandle handle) const; + + TExprNodeBuilder Builder(TPositionHandle pos) { + return TExprNodeBuilder(pos, *this); + } + + [[nodiscard]] + TExprNode::TPtr RenameNode(const TExprNode& node, const TStringBuf& name); + [[nodiscard]] + TExprNode::TPtr ShallowCopy(const TExprNode& node); + [[nodiscard]] + TExprNode::TPtr ShallowCopyWithPosition(const TExprNode& node, TPositionHandle pos); + [[nodiscard]] + TExprNode::TPtr ChangeChildren(const TExprNode& node, TExprNode::TListType&& children); + [[nodiscard]] + TExprNode::TPtr ChangeChild(const TExprNode& node, ui32 index, TExprNode::TPtr&& child); + [[nodiscard]] + TExprNode::TPtr ExactChangeChildren(const TExprNode& node, TExprNode::TListType&& children); + [[nodiscard]] + TExprNode::TPtr ExactShallowCopy(const TExprNode& node); + [[nodiscard]] + TExprNode::TPtr DeepCopyLambda(const TExprNode& node, TExprNode::TListType&& body); + [[nodiscard]] + TExprNode::TPtr DeepCopyLambda(const TExprNode& node, TExprNode::TPtr&& body = TExprNode::TPtr()); + [[nodiscard]] + TExprNode::TPtr FuseLambdas(const TExprNode& outer, const TExprNode& inner); + + using TCustomDeepCopier = std::function<bool(const TExprNode& node, TExprNode::TListType& newChildren)>; + + [[nodiscard]] + TExprNode::TPtr DeepCopy(const TExprNode& node, TExprContext& nodeContext, TNodeOnNodeOwnedMap& deepClones, + bool internStrings, bool copyTypes, bool copyResult = false, TCustomDeepCopier customCopier = {}); + + [[nodiscard]] + TExprNode::TPtr SwapWithHead(const TExprNode& node); + TExprNode::TPtr ReplaceNode(TExprNode::TPtr&& start, const TExprNode& src, TExprNode::TPtr dst); + TExprNode::TPtr ReplaceNodes(TExprNode::TPtr&& start, const TNodeOnNodeOwnedMap& replaces); + template<bool KeepTypeAnns = false> + TExprNode::TListType ReplaceNodes(TExprNode::TListType&& start, const TNodeOnNodeOwnedMap& replaces); + + TExprNode::TPtr NewAtom(TPositionHandle pos, const TStringBuf& content, ui32 flags = TNodeFlags::ArbitraryContent) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewAtom(AllocateNextUniqueId(), pos, AppendString(content), flags); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewAtom(TPositionHandle pos, ui32 index) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewAtom(AllocateNextUniqueId(), pos, GetIndexAsString(index), TNodeFlags::Default); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewArgument(TPositionHandle pos, const TStringBuf& name) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewArgument(AllocateNextUniqueId(), pos, AppendString(name)); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewArguments(TPositionHandle pos, TExprNode::TListType&& argNodes) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewArguments(AllocateNextUniqueId(), pos, std::move(argNodes)); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewLambda(TPositionHandle pos, TExprNode::TListType&& lambda) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewLambda(AllocateNextUniqueId(), pos, std::move(lambda)); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewLambda(TPositionHandle pos, TExprNode::TPtr&& args, TExprNode::TListType&& body) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewLambda(AllocateNextUniqueId(), pos, std::move(args), std::move(body)); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewLambda(TPositionHandle pos, TExprNode::TPtr&& args, TExprNode::TPtr&& body) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewLambda(AllocateNextUniqueId(), pos, std::move(args), std::move(body)); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewWorld(TPositionHandle pos) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewWorld(AllocateNextUniqueId(), pos); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewList(TPositionHandle pos, TExprNode::TListType&& children) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewList(AllocateNextUniqueId(), pos, std::move(children)); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewCallable(TPositionHandle pos, const TStringBuf& name, TExprNode::TListType&& children) { + ++NodeAllocationCounter; + const auto node = TExprNode::NewCallable(AllocateNextUniqueId(), pos, AppendString(name), std::move(children)); + ExprNodes.emplace_back(node.Get()); + return node; + } + + TExprNode::TPtr NewAtom(TPosition pos, const TStringBuf& content, ui32 flags = TNodeFlags::ArbitraryContent) { + return NewAtom(AppendPosition(pos), content, flags); + } + + TExprNode::TPtr NewAtom(TPosition pos, ui32 index) { + return NewAtom(AppendPosition(pos), index); + } + + TExprNode::TPtr NewArgument(TPosition pos, const TStringBuf& name) { + return NewArgument(AppendPosition(pos), name); + } + + TExprNode::TPtr NewArguments(TPosition pos, TExprNode::TListType&& argNodes) { + return NewArguments(AppendPosition(pos), std::move(argNodes)); + } + + TExprNode::TPtr NewLambda(TPosition pos, TExprNode::TListType&& lambda) { + return NewLambda(AppendPosition(pos), std::move(lambda)); + } + + TExprNode::TPtr NewLambda(TPosition pos, TExprNode::TPtr&& args, TExprNode::TListType&& body) { + return NewLambda(AppendPosition(pos), std::move(args), std::move(body)); + } + + TExprNode::TPtr NewLambda(TPosition pos, TExprNode::TPtr&& args, TExprNode::TPtr&& body) { + return NewLambda(AppendPosition(pos), std::move(args), std::move(body)); + } + + TExprNode::TPtr NewWorld(TPosition pos) { + return NewWorld(AppendPosition(pos)); + } + + TExprNode::TPtr NewList(TPosition pos, TExprNode::TListType&& children) { + return NewList(AppendPosition(pos), std::move(children)); + } + + TExprNode::TPtr NewCallable(TPosition pos, const TStringBuf& name, TExprNode::TListType&& children) { + return NewCallable(AppendPosition(pos), name, std::move(children)); + } + + TExprNode::TPtr WrapByCallableIf(bool condition, const TStringBuf& callable, TExprNode::TPtr&& node); + + template <typename T, typename... Args> + const T* MakeType(Args&&... args); + + template <typename T, typename... Args> + const T* MakeConstraint(Args&&... args); + + TConstraintSet MakeConstraintSet(const NYT::TNode& serializedConstraints); + + void AddError(const TIssue& error) { + ENSURE_NOT_FROZEN_CTX + IssueManager.RaiseIssue(error); + } + + bool AddWarning(const TIssue& warning) { + ENSURE_NOT_FROZEN_CTX + return IssueManager.RaiseWarning(warning); + } + + void Freeze(); + void UnFreeze(); + + void Reset(); + + template <class TConstraint> + bool IsConstraintEnabled() const { + return DisabledConstraints.find(TConstraint::Name()) == DisabledConstraints.end(); + } + + std::string_view GetIndexAsString(ui32 index); +private: + using TPositionHandleEqualPred = std::function<bool(TPositionHandle, TPositionHandle)>; + using TPositionHandleHasher = std::function<size_t(TPositionHandle)>; + + bool IsEqual(TPositionHandle a, TPositionHandle b) const; + size_t GetHash(TPositionHandle p) const; + + std::unordered_set<TPositionHandle, TPositionHandleHasher, TPositionHandleEqualPred> PositionSet; + std::deque<TPosition> Positions; +}; + +template <typename T, typename... Args> +inline const T* TExprContext::MakeConstraint(Args&&... args) { + ENSURE_NOT_FROZEN_CTX + if (!IsConstraintEnabled<T>()) { + return nullptr; + } + + T sample(*this, std::forward<Args>(args)...); + const auto it = ConstraintSet.find(&sample); + if (ConstraintSet.cend() != it) { + return static_cast<const T*>(*it); + } + + ConstraintNodes.emplace(new T(std::move(sample))); + const auto ins = ConstraintSet.emplace(ConstraintNodes.top().get()); + return static_cast<const T*>(*ins.first); +} + +#undef ENSURE_NOT_DELETED +#undef ENSURE_NOT_FROZEN +#undef ENSURE_NOT_FROZEN_CTX + +inline bool IsSameAnnotation(const TTypeAnnotationNode& left, const TTypeAnnotationNode& right) { + return &left == &right; +} + +template <typename T, typename... Args> +const T* TExprContext::MakeType(Args&&... args) { + return TMakeTypeImpl<T>::Make(*this, std::forward<Args>(args)...); +} + +struct TExprAnnotationFlags { + enum { + None = 0x00, + Position = 0x01, + Types = 0x02 + }; +}; + +/////////////////////////////////////////////////////////////////////////////// +// TNodeException +/////////////////////////////////////////////////////////////////////////////// +class TNodeException: public yexception { +public: + TNodeException(); + explicit TNodeException(const TExprNode& node); + explicit TNodeException(const TExprNode* node); + explicit TNodeException(const TPositionHandle& pos); + + inline const TPositionHandle& Pos() const { + return Pos_; + } + +private: + const TPositionHandle Pos_; +}; + +bool CompileExpr(TAstNode& astRoot, TExprNode::TPtr& exprRoot, TExprContext& ctx, + IModuleResolver* resolver, IUrlListerManager* urlListerManager, + bool hasAnnotations = false, ui32 typeAnnotationIndex = Max<ui32>(), ui16 syntaxVersion = 0); + +bool CompileExpr(TAstNode& astRoot, TExprNode::TPtr& exprRoot, TExprContext& ctx, + IModuleResolver* resolver, IUrlListerManager* urlListerManager, + ui32 annotationFlags, ui16 syntaxVersion = 0); + +struct TLibraryCohesion { + TExportTable Exports; + TNodeMap<std::pair<TString, TString>> Imports; +}; + +bool CompileExpr(TAstNode& astRoot, TLibraryCohesion& cohesion, TExprContext& ctx, ui16 syntaxVersion = 0); + +const TTypeAnnotationNode* CompileTypeAnnotation(const TAstNode& node, TExprContext& ctx); + +// validate consistency of arguments and lambdas +void CheckArguments(const TExprNode& root); + +void CheckCounts(const TExprNode& root); + +// Compare expression trees and return first diffrent nodes. +bool CompareExprTrees(const TExprNode*& one, const TExprNode*& two); + +bool CompareExprTreeParts(const TExprNode& one, const TExprNode& two, const TNodeMap<ui32>& argsMap); + +TString MakeCacheKey(const TExprNode& root); + +void GatherParents(const TExprNode& node, TParentsMap& parentsMap); + +struct TConvertToAstSettings { + ui32 AnnotationFlags = 0; + bool RefAtoms = false; + std::function<bool(const TExprNode&)> NoInlineFunc; + bool PrintArguments = false; + bool AllowFreeArgs = false; + bool NormalizeAtomFlags = false; + IAllocator* Allocator = TDefaultAllocator::Instance(); +}; + +TAstParseResult ConvertToAst(const TExprNode& root, TExprContext& ctx, const TConvertToAstSettings& settings); + +// refAtoms allows omit copying of atom bodies - they will be referenced from expr graph +TAstParseResult ConvertToAst(const TExprNode& root, TExprContext& ctx, ui32 annotationFlags, bool refAtoms); + +TExprNode::TListType GetLambdaBody(const TExprNode& lambda); + +TString SubstParameters(const TString& str, const TMaybe<NYT::TNode>& params, TSet<TString>* usedNames); + +const TTypeAnnotationNode* GetSeqItemType(const TTypeAnnotationNode* seq); +const TTypeAnnotationNode& GetSeqItemType(const TTypeAnnotationNode& seq); + +const TTypeAnnotationNode& RemoveOptionality(const TTypeAnnotationNode& type); + +TMaybe<TIssue> NormalizeName(TPosition position, TString& name); +TString NormalizeName(const TStringBuf& name); + +} // namespace NYql + +template<> +inline void Out<NYql::TTypeAnnotationNode>( + IOutputStream &out, const NYql::TTypeAnnotationNode& type) +{ + type.Out(out); +} + +#include "yql_expr_builder.inl" diff --git a/yql/essentials/ast/yql_expr_builder.cpp b/yql/essentials/ast/yql_expr_builder.cpp new file mode 100644 index 00000000000..5468cadd2cc --- /dev/null +++ b/yql/essentials/ast/yql_expr_builder.cpp @@ -0,0 +1,502 @@ +#include "yql_expr_builder.h" +#include "yql_expr.h" + +namespace NYql { + +TExprNodeBuilder::TExprNodeBuilder(TPositionHandle pos, TExprContext& ctx) + : Ctx(ctx) + , Parent(nullptr) + , ParentReplacer(nullptr) + , Container(nullptr) + , Pos(pos) + , CurrentNode(nullptr) +{} + +TExprNodeBuilder::TExprNodeBuilder(TPositionHandle pos, TExprContext& ctx, ExtArgsFuncType extArgsFunc) + : Ctx(ctx) + , Parent(nullptr) + , ParentReplacer(nullptr) + , Container(nullptr) + , Pos(pos) + , CurrentNode(nullptr) + , ExtArgsFunc(extArgsFunc) +{} + +TExprNodeBuilder::TExprNodeBuilder(TPositionHandle pos, TExprNodeBuilder* parent, const TExprNode::TPtr& container) + : Ctx(parent->Ctx) + , Parent(parent) + , ParentReplacer(nullptr) + , Container(std::move(container)) + , Pos(pos) + , CurrentNode(nullptr) +{ + if (Parent) { + ExtArgsFunc = Parent->ExtArgsFunc; + } +} + +TExprNodeBuilder::TExprNodeBuilder(TPositionHandle pos, TExprNodeReplaceBuilder* parentReplacer) + : Ctx(parentReplacer->Owner->Ctx) + , Parent(nullptr) + , ParentReplacer(parentReplacer) + , Container(nullptr) + , Pos(pos) + , CurrentNode(nullptr) +{ +} + +TExprNode::TPtr TExprNodeBuilder::Build() { + Y_ENSURE(CurrentNode, "No current node"); + Y_ENSURE(!Parent, "Build is allowed only on top level"); + if (CurrentNode->Type() == TExprNode::Lambda) { + Y_ENSURE(CurrentNode->ChildrenSize() > 0U, "Lambda is not complete"); + } + + return CurrentNode; +} + +TExprNodeBuilder& TExprNodeBuilder::Seal() { + Y_ENSURE(Parent, "Seal is allowed only on non-top level"); + if (Container->Type() == TExprNode::Lambda) { + if (CurrentNode) { + Y_ENSURE(Container->ChildrenSize() == 1, "Lambda is already complete."); + Container->Children_.emplace_back(std::move(CurrentNode)); + } else { + Y_ENSURE(Container->ChildrenSize() > 0U, "Lambda isn't complete."); + } + } + + return *Parent; +} + +TExprNodeReplaceBuilder& TExprNodeBuilder::Done() { + Y_ENSURE(ParentReplacer, "Done is allowed only if parent is a replacer"); + Y_ENSURE(CurrentNode, "No current node"); + for (auto& body : ParentReplacer->Body) + body = Ctx.ReplaceNode(std::move(body), ParentReplacer->CurrentNode ? *ParentReplacer->CurrentNode : *ParentReplacer->Args->Child(ParentReplacer->CurrentIndex), CurrentNode); + return *ParentReplacer; +} + +TExprNodeBuilder& TExprNodeBuilder::Atom(ui32 index, TPositionHandle pos, const TStringBuf& content, ui32 flags) { + Y_ENSURE(Container && !Container->IsLambda(), "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index, + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + auto child = Ctx.NewAtom(pos, content, flags); + Container->Children_.push_back(child); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Atom(TPositionHandle pos, const TStringBuf& content, ui32 flags) { + Y_ENSURE(!Container || Container->IsLambda(), "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + CurrentNode = Ctx.NewAtom(pos, content, flags); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Atom(ui32 index, const TStringBuf& content, ui32 flags) { + return Atom(index, Pos, content, flags); +} + +TExprNodeBuilder& TExprNodeBuilder::Atom(const TStringBuf& content, ui32 flags) { + return Atom(Pos, content, flags); +} + +TExprNodeBuilder& TExprNodeBuilder::Atom(ui32 index, ui32 literalIndexValue) { + return Atom(index, Pos, Ctx.GetIndexAsString(literalIndexValue), TNodeFlags::Default); +} + +TExprNodeBuilder& TExprNodeBuilder::Atom(ui32 literalIndexValue) { + return Atom(Pos, Ctx.GetIndexAsString(literalIndexValue), TNodeFlags::Default); +} + +TExprNodeBuilder TExprNodeBuilder::List(ui32 index, TPositionHandle pos) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + const auto child = Ctx.NewList(pos, {}); + Container->Children_.push_back(child); + return TExprNodeBuilder(pos, this, child); +} + +TExprNodeBuilder TExprNodeBuilder::List(TPositionHandle pos) { + Y_ENSURE(!Container || Container->Type() == TExprNode::Lambda, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + CurrentNode = Ctx.NewList(pos, {}); + return TExprNodeBuilder(pos, this, CurrentNode); +} + +TExprNodeBuilder TExprNodeBuilder::List(ui32 index) { + return List(index, Pos); +} + +TExprNodeBuilder TExprNodeBuilder::List() { + return List(Pos); +} + +TExprNodeBuilder& TExprNodeBuilder::Add(ui32 index, const TExprNode::TPtr& child) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + Y_ENSURE(child, "child should not be nullptr"); + Container->Children_.push_back(child); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Add(ui32 index, TExprNode::TPtr&& child) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + Y_ENSURE(child, "child should not be nullptr"); + Container->Children_.push_back(std::move(child)); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Add(TExprNode::TListType&& children) { + Y_ENSURE(Container && Container->Type() != TExprNode::Lambda, "Container expected"); + Y_ENSURE(Container->Children_.empty(), "container should be empty"); + Container->Children_ = std::move(children); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Set(TExprNode::TPtr&& body) { + Y_ENSURE(Container && Container->Type() == TExprNode::Lambda, "Lambda expected"); + Y_ENSURE(!CurrentNode, "Lambda already has a body"); + CurrentNode = std::move(body); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Set(const TExprNode::TPtr& body) { + Y_ENSURE(Container && Container->Type() == TExprNode::Lambda, "Lambda expected"); + Y_ENSURE(!CurrentNode, "Lambda already has a body"); + CurrentNode = body; + return *this; +} + +TExprNodeBuilder TExprNodeBuilder::Callable(ui32 index, TPositionHandle pos, const TStringBuf& content) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + auto child = Ctx.NewCallable(pos, content, {}); + Container->Children_.push_back(child); + return TExprNodeBuilder(pos, this, child); +} + +TExprNodeBuilder TExprNodeBuilder::Callable(TPositionHandle pos, const TStringBuf& content) { + Y_ENSURE(!Container || Container->Type() == TExprNode::Lambda, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + CurrentNode = Ctx.NewCallable(pos, content, {}); + return TExprNodeBuilder(pos, this, CurrentNode); +} + +TExprNodeBuilder TExprNodeBuilder::Callable(ui32 index, const TStringBuf& content) { + return Callable(index, Pos, content); +} + +TExprNodeBuilder TExprNodeBuilder::Callable(const TStringBuf& content) { + return Callable(Pos, content); +} + +TExprNodeBuilder& TExprNodeBuilder::World(ui32 index, TPositionHandle pos) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + auto child = Ctx.NewWorld(pos); + Container->Children_.push_back(child); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::World(TPositionHandle pos) { + Y_ENSURE(!Container, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + CurrentNode = Ctx.NewWorld(pos); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::World(ui32 index) { + return World(index, Pos); +} + +TExprNodeBuilder& TExprNodeBuilder::World() { + return World(Pos); +} + +TExprNodeBuilder TExprNodeBuilder::Lambda(ui32 index, TPositionHandle pos) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + auto child = Ctx.NewLambda(pos, Ctx.NewArguments(pos, {}), nullptr); + Container->Children_.push_back(child); + return TExprNodeBuilder(pos, this, child); +} + +TExprNodeBuilder TExprNodeBuilder::Lambda(TPositionHandle pos) { + Y_ENSURE(!Container || Container->Type() == TExprNode::Lambda, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + CurrentNode = Ctx.NewLambda(pos, Ctx.NewArguments(pos, {}), nullptr); + return TExprNodeBuilder(pos, this, CurrentNode); +} + +TExprNodeBuilder TExprNodeBuilder::Lambda(ui32 index) { + return Lambda(index, Pos); +} + +TExprNodeBuilder TExprNodeBuilder::Lambda() { + return Lambda(Pos); +} + +TExprNodeBuilder& TExprNodeBuilder::Param(TPositionHandle pos, const TStringBuf& name) { + Y_ENSURE(!name.empty(), "Empty parameter name is not allowed"); + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->Type() == TExprNode::Lambda, "Container must be a lambda"); + Y_ENSURE(!CurrentNode, "Lambda already has a body"); + for (auto arg : Container->Head().Children()) { + Y_ENSURE(arg->Content() != name, "Duplicate of lambda param name: " << name); + } + + Container->Head().Children_.push_back(Ctx.NewArgument(pos, name)); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Param(const TStringBuf& name) { + return Param(Pos, name); +} + +TExprNodeBuilder& TExprNodeBuilder::Params(const TStringBuf& name, ui32 width) { + Y_ENSURE(!name.empty(), "Empty parameter name is not allowed"); + for (ui32 i = 0U; i < width; ++i) + Param(Pos, TString(name) += ToString(i)); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Arg(ui32 index, const TStringBuf& name) { + Y_ENSURE(!name.empty(), "Empty parameter name is not allowed"); + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + auto arg = FindArgument(name); + Container->Children_.push_back(arg); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Arg(const TStringBuf& name) { + Y_ENSURE(!Container || Container->Type() == TExprNode::Lambda, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + CurrentNode = FindArgument(name); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Arg(ui32 index, const TStringBuf& name, ui32 toIndex) { + Y_ENSURE(!name.empty(), "Empty parameter name is not allowed"); + return Arg(index, TString(name) += ToString(toIndex)); +} + +TExprNodeBuilder& TExprNodeBuilder::Arg(const TStringBuf& name, ui32 toIndex) { + Y_ENSURE(!name.empty(), "Empty parameter name is not allowed"); + return Arg(TString(name) += ToString(toIndex)); +} + +TExprNodeBuilder& TExprNodeBuilder::Arg(const TExprNodePtr& arg) { + Y_ENSURE(arg->Type() == TExprNode::Argument, "Argument expected"); + Y_ENSURE(!Container || Container->Type() == TExprNode::Lambda, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + CurrentNode = arg; + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Args(ui32 index, const TStringBuf& name, ui32 toIndex) { + Y_ENSURE(!name.empty(), "Empty parameter name is not allowed"); + for (auto i = 0U; index < toIndex; ++i) + Arg(index++, TString(name) += ToString(i)); + return *this; +} + +TExprNodeBuilder& TExprNodeBuilder::Args(const TStringBuf& name, ui32 toIndex) { + Y_ENSURE(!name.empty(), "Empty parameter name is not allowed"); + for (auto i = 0U; i < toIndex; ++i) + Arg(i, TString(name) += ToString(i)); + return *this; +} + +TExprNode::TPtr TExprNodeBuilder::FindArgument(const TStringBuf& name) { + for (auto builder = this; builder; builder = builder->Parent) { + while (builder->ParentReplacer) { + builder = builder->ParentReplacer->Owner; + } + + if (builder->Container && builder->Container->IsLambda()) { + for (const auto& arg : builder->Container->Head().Children()) { + if (arg->Content() == name) { + return arg; + } + } + } + } + + if (ExtArgsFunc) { + if (const auto arg = ExtArgsFunc(name)) { + return arg; + } + } + + ythrow yexception() << "Parameter not found: " << name; +} + +TExprNodeReplaceBuilder TExprNodeBuilder::Apply(ui32 index, const TExprNode& lambda) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + return TExprNodeReplaceBuilder(this, Container, lambda.HeadPtr(), GetLambdaBody(lambda)); +} + +TExprNodeReplaceBuilder TExprNodeBuilder::Apply(const TExprNode& lambda) { + Y_ENSURE(!Container || Container->Type() == TExprNode::Lambda, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + return TExprNodeReplaceBuilder(this, Container, lambda.HeadPtr(), GetLambdaBody(lambda)); +} + +TExprNodeReplaceBuilder TExprNodeBuilder::Apply(ui32 index, const TExprNode::TPtr& lambda) { + return Apply(index, *lambda); +} + +TExprNodeReplaceBuilder TExprNodeBuilder::Apply(const TExprNode::TPtr& lambda) { + return Apply(*lambda); +} + +TExprNodeReplaceBuilder TExprNodeBuilder::ApplyPartial(ui32 index, TExprNode::TPtr args, TExprNode::TPtr body) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + Y_ENSURE(!args || args->IsArguments()); + return TExprNodeReplaceBuilder(this, Container, std::move(args), std::move(body)); +} + +TExprNodeReplaceBuilder TExprNodeBuilder::ApplyPartial(ui32 index, TExprNode::TPtr args, TExprNode::TListType body) { + Y_ENSURE(Container, "Container expected"); + Y_ENSURE(Container->ChildrenSize() == index + (Container->IsLambda() ? 1U : 0U), + "Container position mismatch, expected: " << Container->ChildrenSize() << + ", actual: " << index); + Y_ENSURE(!args || args->IsArguments()); + return TExprNodeReplaceBuilder(this, Container, std::move(args), std::move(body)); +} + +TExprNodeReplaceBuilder TExprNodeBuilder::ApplyPartial(TExprNode::TPtr args, TExprNode::TPtr body) { + Y_ENSURE(!Container || Container->Type() == TExprNode::Lambda, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + Y_ENSURE(!args || args->IsArguments()); + return TExprNodeReplaceBuilder(this, Container, std::move(args), std::move(body)); +} + +TExprNodeReplaceBuilder TExprNodeBuilder::ApplyPartial(TExprNode::TPtr args, TExprNode::TListType body) { + Y_ENSURE(!Container || Container->Type() == TExprNode::Lambda, "No container expected"); + Y_ENSURE(!CurrentNode, "Node is already build"); + Y_ENSURE(!args || args->IsArguments()); + return TExprNodeReplaceBuilder(this, Container, std::move(args), std::move(body)); +} + +TExprNodeReplaceBuilder::TExprNodeReplaceBuilder(TExprNodeBuilder* owner, TExprNode::TPtr container, + TExprNode::TPtr&& args, TExprNode::TPtr&& body) + : Owner(owner) + , Container(std::move(container)) + , Args(std::move(args)) + , Body({std::move(body)}) + , CurrentIndex(0) + , CurrentNode(nullptr) +{ +} + +TExprNodeReplaceBuilder::TExprNodeReplaceBuilder(TExprNodeBuilder* owner, TExprNode::TPtr container, + TExprNode::TPtr&& args, TExprNode::TListType&& body) + : Owner(owner) + , Container(std::move(container)) + , Args(std::move(args)) + , Body(std::move(body)) + , CurrentIndex(0) + , CurrentNode(nullptr) +{ +} + +TExprNodeReplaceBuilder::TExprNodeReplaceBuilder(TExprNodeBuilder* owner, TExprNode::TPtr container, + const TExprNode& lambda) + : TExprNodeReplaceBuilder(owner, std::move(container), lambda.HeadPtr(), lambda.TailPtr()) +{ + Y_ENSURE(lambda.Type() == TExprNode::Lambda, "Expected lambda"); +} + +TExprNodeReplaceBuilder& TExprNodeReplaceBuilder::With( + ui32 argIndex, const TStringBuf& toName) { + Y_ENSURE(Args, "No arguments"); + Y_ENSURE(argIndex < Args->ChildrenSize(), "Wrong argument index"); + Body = Owner->Ctx.ReplaceNodes(std::move(Body), {{Args->Child(argIndex), Owner->FindArgument(toName)}}); + return *this; +} + +TExprNodeReplaceBuilder& TExprNodeReplaceBuilder::With( + ui32 argIndex, TExprNode::TPtr toNode) { + Y_ENSURE(Args, "No arguments"); + Y_ENSURE(argIndex < Args->ChildrenSize(), "Wrong argument index"); + Body = Owner->Ctx.ReplaceNodes(std::move(Body), {{Args->Child(argIndex), std::move(toNode)}}); + return *this; +} + +TExprNodeReplaceBuilder& TExprNodeReplaceBuilder::With(const TStringBuf& toName) { + Y_ENSURE(Args, "No arguments"); + Y_ENSURE(!toName.empty(), "Empty parameter name is not allowed"); + TNodeOnNodeOwnedMap replaces(Args->ChildrenSize()); + for (ui32 i = 0U; i < Args->ChildrenSize(); ++i) + Y_ENSURE(replaces.emplace(Args->Child(i), Owner->FindArgument(TString(toName) += ToString(i))).second, "Must be uinique."); + Body = Owner->Ctx.ReplaceNodes(std::move(Body), replaces); + return *this; +} + +TExprNodeReplaceBuilder& TExprNodeReplaceBuilder::With(const TStringBuf& toName, ui32 toIndex) { + Y_ENSURE(!toName.empty(), "Empty parameter name is not allowed"); + return With(TString(toName) += ToString(toIndex)); +} + +TExprNodeReplaceBuilder& TExprNodeReplaceBuilder::With(ui32 argIndex, const TStringBuf& toName, ui32 toIndex) { + Y_ENSURE(!toName.empty(), "Empty parameter name is not allowed"); + return With(argIndex, TString(toName) += ToString(toIndex)); +} + +TExprNodeReplaceBuilder& TExprNodeReplaceBuilder::WithNode(const TExprNode& fromNode, TExprNode::TPtr&& toNode) { + Body = Owner->Ctx.ReplaceNodes(std::move(Body), {{&fromNode, std::move(toNode)}}); + return *this; +} + +TExprNodeReplaceBuilder& TExprNodeReplaceBuilder::WithNode(const TExprNode& fromNode, const TStringBuf& toName) { + return WithNode(fromNode, Owner->FindArgument(toName)); +} + +TExprNodeBuilder TExprNodeReplaceBuilder::With(ui32 argIndex) { + CurrentIndex = argIndex; + return TExprNodeBuilder(Owner->Pos, this); +} + +TExprNodeBuilder TExprNodeReplaceBuilder::WithNode(TExprNode::TPtr&& fromNode) { + CurrentNode = std::move(fromNode); + return TExprNodeBuilder(Owner->Pos, this); +} + +TExprNodeBuilder& TExprNodeReplaceBuilder::Seal() { + if (Container) { + std::move(Body.begin(), Body.end(), std::back_inserter(Container->Children_)); + } else { + Y_ENSURE(1U == Body.size() && Body.front(), "Expected single node."); + Owner->CurrentNode = std::move(Body.front()); + } + Body.clear(); + return *Owner; +} + +} // namespace NYql + diff --git a/yql/essentials/ast/yql_expr_builder.h b/yql/essentials/ast/yql_expr_builder.h new file mode 100644 index 00000000000..db816359eb9 --- /dev/null +++ b/yql/essentials/ast/yql_expr_builder.h @@ -0,0 +1,175 @@ +#pragma once + +#include "yql_ast.h" +#include "yql_errors.h" +#include "yql_pos_handle.h" + +#include <functional> + +namespace NYql { + +struct TExprContext; +class TExprNode; +typedef TIntrusivePtr<TExprNode> TExprNodePtr; +typedef std::vector<TExprNodePtr> TExprNodeList; + +class TExprNodeReplaceBuilder; + +class TExprNodeBuilder { +friend class TExprNodeReplaceBuilder; +public: + typedef std::function<TExprNodePtr(const TStringBuf&)> ExtArgsFuncType; +public: + TExprNodeBuilder(TPositionHandle pos, TExprContext& ctx); + TExprNodeBuilder(TPositionHandle pos, TExprContext& ctx, ExtArgsFuncType extArgsFunc); + TExprNodePtr Build(); + TExprNodeBuilder& Seal(); + TExprNodeReplaceBuilder& Done(); + + // Indexed version of methods must be used inside of Callable or List, otherwise + // non-indexed version must be used (at root or as lambda body) + TExprNodeBuilder& Atom(ui32 index, TPositionHandle pos, const TStringBuf& content, ui32 flags = TNodeFlags::ArbitraryContent); + TExprNodeBuilder& Atom(TPositionHandle pos, const TStringBuf& content, ui32 flags = TNodeFlags::ArbitraryContent); + TExprNodeBuilder& Atom(ui32 index, const TStringBuf& content, ui32 flags = TNodeFlags::ArbitraryContent); + TExprNodeBuilder& Atom(const TStringBuf& content, ui32 flags = TNodeFlags::ArbitraryContent); + TExprNodeBuilder& Atom(ui32 index, ui32 literalIndexValue); + TExprNodeBuilder& Atom(ui32 literalIndexValue); + + TExprNodeBuilder List(ui32 index, TPositionHandle pos); + TExprNodeBuilder List(TPositionHandle pos); + TExprNodeBuilder List(ui32 index); + TExprNodeBuilder List(); + + TExprNodeBuilder& Add(ui32 index, TExprNodePtr&& child); + TExprNodeBuilder& Add(ui32 index, const TExprNodePtr& child); + TExprNodeBuilder& Add(TExprNodeList&& children); + // only for lambda bodies + TExprNodeBuilder& Set(TExprNodePtr&& body); + TExprNodeBuilder& Set(const TExprNodePtr& body); + + TExprNodeBuilder Callable(ui32 index, TPositionHandle pos, const TStringBuf& content); + TExprNodeBuilder Callable(TPositionHandle pos, const TStringBuf& content); + TExprNodeBuilder Callable(ui32 index, const TStringBuf& content); + TExprNodeBuilder Callable(const TStringBuf& content); + + TExprNodeBuilder& World(ui32 index, TPositionHandle pos); + TExprNodeBuilder& World(TPositionHandle pos); + TExprNodeBuilder& World(ui32 index); + TExprNodeBuilder& World(); + + TExprNodeBuilder Lambda(ui32 index, TPositionHandle pos); + TExprNodeBuilder Lambda(TPositionHandle pos); + TExprNodeBuilder Lambda(ui32 index); + TExprNodeBuilder Lambda(); + + TExprNodeBuilder& Param(TPositionHandle pos, const TStringBuf& name); + TExprNodeBuilder& Param(const TStringBuf& name); + TExprNodeBuilder& Params(const TStringBuf& name, ui32 width); + + TExprNodeBuilder& Arg(ui32 index, const TStringBuf& name); + TExprNodeBuilder& Arg(const TStringBuf& name); + TExprNodeBuilder& Arg(ui32 index, const TStringBuf& name, ui32 toIndex); + TExprNodeBuilder& Arg(const TStringBuf& name, ui32 toIndex); + TExprNodeBuilder& Arg(const TExprNodePtr& arg); + + TExprNodeBuilder& Args(ui32 index, const TStringBuf& name, ui32 toIndex); + TExprNodeBuilder& Args(const TStringBuf& name, ui32 toIndex); + + TExprNodeReplaceBuilder Apply(ui32 index, const TExprNode& lambda); + TExprNodeReplaceBuilder Apply(ui32 index, const TExprNodePtr& lambda); + TExprNodeReplaceBuilder Apply(const TExprNode& lambda); + TExprNodeReplaceBuilder Apply(const TExprNodePtr& lambda); + TExprNodeReplaceBuilder ApplyPartial(ui32 index, TExprNodePtr args, TExprNodePtr body); + TExprNodeReplaceBuilder ApplyPartial(ui32 index, TExprNodePtr args, TExprNodeList body); + TExprNodeReplaceBuilder ApplyPartial(TExprNodePtr args, TExprNodePtr body); + TExprNodeReplaceBuilder ApplyPartial(TExprNodePtr args, TExprNodeList body); + + template <typename TFunc> + TExprNodeBuilder& Do(const TFunc& func) { + return func(*this); + } + +private: + TExprNodeBuilder(TPositionHandle pos, TExprNodeBuilder* parent, const TExprNodePtr& container); + TExprNodeBuilder(TPositionHandle pos, TExprNodeReplaceBuilder* parentReplacer); + TExprNodePtr FindArgument(const TStringBuf& name); + +private: + TExprContext& Ctx; + TExprNodeBuilder* Parent; + TExprNodeReplaceBuilder* ParentReplacer; + TExprNodePtr Container; + TPositionHandle Pos; + TExprNodePtr CurrentNode; + ExtArgsFuncType ExtArgsFunc; +}; + +namespace NNodes { + template<typename TParent, typename TNode> + class TNodeBuilder; +} + +class TExprNodeReplaceBuilder { +friend class TExprNodeBuilder; +private: + struct TBuildAdapter { + typedef TExprNodeReplaceBuilder& ResultType; + + TBuildAdapter(TExprNodeReplaceBuilder& builder) + : Builder(builder) {} + + ResultType Value() { + return Builder; + } + + TExprNodeReplaceBuilder& Builder; + }; + +public: + TExprNodeReplaceBuilder(TExprNodeBuilder* owner, TExprNodePtr container, const TExprNode& lambda); + TExprNodeReplaceBuilder(TExprNodeBuilder* owner, TExprNodePtr container, TExprNodePtr&& args, TExprNodePtr&& body); + TExprNodeReplaceBuilder(TExprNodeBuilder* owner, TExprNodePtr container, TExprNodePtr&& args, TExprNodeList&& body); + TExprNodeReplaceBuilder& With(ui32 argIndex, const TStringBuf& toName); + TExprNodeReplaceBuilder& With(ui32 argIndex, const TStringBuf& toName, ui32 toIndex); + TExprNodeReplaceBuilder& With(ui32 argIndex, TExprNodePtr toNode); + TExprNodeReplaceBuilder& With(const TStringBuf& toName); + TExprNodeReplaceBuilder& With(const TStringBuf& toName, ui32 toIndex); + TExprNodeReplaceBuilder& WithNode(const TExprNode& fromNode, TExprNodePtr&& toNode); + TExprNodeReplaceBuilder& WithNode(const TExprNode& fromNode, const TStringBuf& toName); + TExprNodeBuilder With(ui32 argIndex); + TExprNodeBuilder WithNode(TExprNodePtr&& fromNode); + + template<typename TNode> + NNodes::TNodeBuilder<TBuildAdapter, TNode> With(ui32 argIndex) { + TBuildAdapter adapter(*this); + + NNodes::TNodeBuilder<TBuildAdapter, TNode> builder(Owner->Ctx, Owner->Pos, + [adapter, argIndex](const TNode& node) mutable -> TBuildAdapter& { + adapter.Builder = adapter.Builder.With(argIndex, node.Get()); + return adapter; + }, + [adapter] (const TStringBuf& argName) { + return adapter.Builder.Owner->FindArgument(argName); + }); + + return builder; + } + + TExprNodeBuilder& Seal(); + + template <typename TFunc> + TExprNodeReplaceBuilder& Do(const TFunc& func) { + return func(*this); + } + +private: + TExprNodeBuilder* Owner; + TExprNodePtr Container; + TExprNodePtr Args; + TExprNodeList Body; + ui32 CurrentIndex; + TExprNodePtr CurrentNode; +}; + +} // namespace NYql + diff --git a/yql/essentials/ast/yql_expr_builder.inl b/yql/essentials/ast/yql_expr_builder.inl new file mode 100644 index 00000000000..5f711f37f12 --- /dev/null +++ b/yql/essentials/ast/yql_expr_builder.inl @@ -0,0 +1,6 @@ +#pragma once + +namespace NYql { + +} // namespace NYql + diff --git a/yql/essentials/ast/yql_expr_builder_ut.cpp b/yql/essentials/ast/yql_expr_builder_ut.cpp new file mode 100644 index 00000000000..eb0f2a0bf16 --- /dev/null +++ b/yql/essentials/ast/yql_expr_builder_ut.cpp @@ -0,0 +1,690 @@ +#include "yql_expr.h" +#include <library/cpp/testing/unittest/registar.h> + +namespace NYql { + +Y_UNIT_TEST_SUITE(TExprBuilder) { + Y_UNIT_TEST(TestEmpty) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Build(), yexception); + } + + Y_UNIT_TEST(TestRootAtom) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()).Atom("ABC").Build(); + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Content(), "ABC"); + } + + Y_UNIT_TEST(TestRootAtomTwice) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Atom("ABC") + .Atom("ABC") + .Build(), yexception); + } + + Y_UNIT_TEST(TestRootEmptyList) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()).List().Seal().Build(); + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::List); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 0); + } + + Y_UNIT_TEST(TestRootEmptyListTwice) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .List() + .List() + .Build(), yexception); + } + + Y_UNIT_TEST(TestListWithAtoms) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .List() + .Atom(0, "ABC") + .Atom(1, "XYZ") + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::List); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Content(), "ABC"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "XYZ"); + } + + Y_UNIT_TEST(TestMismatchChildIndex) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .List() + .Atom(1, "") + .Build(), yexception); + } + + Y_UNIT_TEST(TestListWithAdd) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .List() + .Add(0, ctx.Builder(TPositionHandle()).Atom("ABC").Build()) + .Atom(1, "XYZ") + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::List); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Content(), "ABC"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "XYZ"); + } + + Y_UNIT_TEST(TestNestedListWithAtoms) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .List() + .List(0) + .Atom(0, "ABC") + .Atom(1, "DEF") + .Seal() + .Atom(1, "XYZ") + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::List); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::List); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "ABC"); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Content(), "DEF"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "XYZ"); + } + + Y_UNIT_TEST(TestWrongLevelBuild) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .List() + .Build(), yexception); + } + + Y_UNIT_TEST(TestWrongLevelSeal) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Seal(), yexception); + } + + Y_UNIT_TEST(TestCallableWithAtoms) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .Callable("Func") + .Atom(0, "ABC") + .Atom(1, "XYZ") + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(res->Content(), "Func"); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Content(), "ABC"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "XYZ"); + } + + Y_UNIT_TEST(TestNestedCallableWithAtoms) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .Callable("Func1") + .Callable(0, "Func2") + .Atom(0, "ABC") + .Atom(1, "DEF") + .Seal() + .Atom(1, "XYZ") + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(res->Content(), "Func1"); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Content(), "Func2"); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "ABC"); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Content(), "DEF"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "XYZ"); + } + + Y_UNIT_TEST(TestRootWorld) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()).World().Build(); + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::World); + } + + Y_UNIT_TEST(TestCallableWithWorld) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .Callable("Func") + .World(0) + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(res->Content(), "Func"); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::World); + } + + Y_UNIT_TEST(TestIncompleteRootLambda) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .Build(), yexception); + } + + Y_UNIT_TEST(TestIncompleteInnerLambda) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .List() + .Lambda() + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestRootLambda) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()).Lambda().Atom("ABC").Seal().Build(); + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 0); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "ABC"); + } + + Y_UNIT_TEST(TestRootLambdaWithBodyAsSet) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .Lambda() + .Set(ctx.Builder(TPositionHandle()).Atom("ABC").Build()) + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 0); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "ABC"); + } + + Y_UNIT_TEST(TestInnerLambdaWithParam) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .List() + .Lambda(0) + .Param("x") + .Atom("ABC") + .Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::List); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Head().Content(), "x"); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Content(), "ABC"); + } + + Y_UNIT_TEST(TestDuplicateLambdaParamNames) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Param("x") + .Atom("ABC") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestParamAtRoot) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Param("aaa") + .Build(), yexception); + } + + Y_UNIT_TEST(TestParamInList) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .List() + .Param("aaa") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestParamInCallable) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Callable("Func") + .Param("aaa") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestParamAfterLambdaBody) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .Param("aaa") + .Atom("ABC") + .Param("bbb") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestIndexedAtomAtRoot) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Atom(0, "ABC") + .Build(), yexception); + } + + Y_UNIT_TEST(TestIndexedListAtRoot) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .List(0) + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestIndexedWorldAtRoot) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .World(0) + .Build(), yexception); + } + + Y_UNIT_TEST(TestIndexedCallableAtRoot) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Callable(0, "Func") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestIndexedLambdaAtRoot) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda(0) + .Atom("ABC") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestWrongIndexAtomAtLambda) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .Atom(1, "ABC") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestWrongIndexListAtLambda) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .List(1) + .Seal() + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestWrongIndexWorldAtLambda) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .World(1) + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestWrongIndexCallableAtLambda) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .Callable(1, "Func") + .Seal() + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestWrongIndexLambdaAtLambda) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .Lambda(1) + .Atom("ABC") + .Seal() + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestAddAtLambda) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()) + .Lambda() + .Add(1, ctx.Builder(TPositionHandle()).Atom("ABC").Build()) + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestLambdaWithArgAsBody) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Param("y") + .Arg("x") + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "x"); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Content(), "y"); + UNIT_ASSERT_EQUAL(res->Child(1), res->Head().Child(0)); + } + + Y_UNIT_TEST(TestIndexedArgAsLambdaBody) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()).Lambda() + .Param("x") + .Arg(1, "x") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestWrongArgName) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()).Lambda() + .Param("x") + .Arg("y") + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestLambdaWithArgInCallables) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Param("y") + .Callable("+") + .Arg(0, "y") + .Arg(1, "x") + .Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "x"); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Content(), "y"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "+"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->ChildrenSize(), 2); + UNIT_ASSERT_EQUAL(res->Child(1)->Child(0), res->Head().Child(1)); + UNIT_ASSERT_EQUAL(res->Child(1)->Child(1), res->Head().Child(0)); + } + + Y_UNIT_TEST(TestNestedScopeInLambda) { + TExprContext ctx; + auto res = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Param("y") + .Callable("Apply") + .Lambda(0) + .Param("x") + .Callable("+") + .Arg(0, "x") + .Arg(1, "y") + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "x"); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Child(1)->Content(), "y"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Content(), "Apply"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Head().ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Head().Head().Content(), "x"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Child(1)->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Child(1)->Content(), "+"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Child(1)->ChildrenSize(), 2); + // nested x + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Child(1)->Child(0), + res->Child(1)->Head().Head().Child(0)); + // outer y + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Head().Child(1)->Child(1), + res->Head().Child(1)); + } + + Y_UNIT_TEST(TestNonIndexedArg) { + TExprContext ctx; + UNIT_ASSERT_EXCEPTION(ctx.Builder(TPositionHandle()).Lambda() + .Param("x") + .Callable("f") + .Arg("x") + .Seal() + .Seal() + .Build(), yexception); + } + + Y_UNIT_TEST(TestApplyLambdaArgAsRoot) { + TExprContext ctx; + auto lambda = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Arg("x") + .Seal() + .Build(); + + auto res = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("y") + .Apply(lambda).With(0, "y").Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "y"); + UNIT_ASSERT_EQUAL(res->Child(1), res->Head().Child(0)); + } + + Y_UNIT_TEST(TestApplyLambdaArgInContainer) { + TExprContext ctx; + auto lambda = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Arg("x") + .Seal() + .Build(); + + auto res = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("y") + .List() + .Apply(0, lambda).With(0, "y").Seal() + .Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "y"); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->Type(), TExprNode::List); + UNIT_ASSERT_VALUES_EQUAL(res->Child(1)->ChildrenSize(), 1); + UNIT_ASSERT_EQUAL(res->Child(1)->Child(0), res->Head().Child(0)); + } + + Y_UNIT_TEST(TestApplyPartialLambdaArgAsRoot) { + TExprContext ctx; + auto lambda = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Callable("Func1") + .Callable(0, "Func2") + .Atom(0, "ABC") + .Arg(1, "x") + .Seal() + .Seal() + .Seal() + .Build(); + + auto res = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("y") + .ApplyPartial(lambda->HeadPtr(), lambda->Child(1)->HeadPtr()).With(0, "y").Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "y"); + UNIT_ASSERT_EQUAL(res->Child(1)->Child(1), res->Head().Child(0)); + } + + Y_UNIT_TEST(TestApplyPartialLambdaArgInContainer) { + TExprContext ctx; + auto lambda = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Callable("Func1") + .Callable(0, "Func2") + .Atom(0, "ABC") + .Arg(1, "x") + .Seal() + .Seal() + .Seal() + .Build(); + + auto res = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("y") + .Callable("Func3") + .ApplyPartial(0, lambda->HeadPtr(), lambda->Child(1)->HeadPtr()).With(0, "y").Seal() + .Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res->Head().ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res->Head().Head().Content(), "y"); + UNIT_ASSERT_EQUAL(res->Child(1)->Head().Child(1), res->Head().Child(0)); + } + + Y_UNIT_TEST(TestApplyOuterArg) { + TExprContext ctx; + auto ast = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("x") + .Callable("Func1") + .Atom(0, "p1") + .Lambda(1) + .Callable("Func2") + .Atom(0, "ABC") + .Arg(1, "x") + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); + + auto res1 = ctx.Builder(TPositionHandle()) + .Lambda() + .Param("y") + .Callable("Func3") + .ApplyPartial(0, nullptr, ast->Child(1)->Child(1)->ChildPtr(1)) + .WithNode(*ast->Head().Child(0), "y").Seal() + .Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res1->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res1->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res1->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res1->Head().ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(res1->Head().Head().Type(), TExprNode::Argument); + UNIT_ASSERT_VALUES_EQUAL(res1->Head().Head().Content(), "y"); + UNIT_ASSERT_EQUAL(res1->Child(1)->Head().Child(1), res1->Head().Child(0)); + + auto atom = ctx.Builder(TPositionHandle()) + .Atom("const") + .Build(); + + auto res2 = ctx.Builder(TPositionHandle()) + .Lambda() + .Callable("Func3") + .ApplyPartial(0, nullptr, ast->Child(1)->Child(1)->ChildPtr(1)) + .WithNode(ast->Head().Head(), TExprNode::TPtr(atom)).Seal() + .Seal() + .Seal() + .Build(); + + UNIT_ASSERT_VALUES_EQUAL(res2->Type(), TExprNode::Lambda); + UNIT_ASSERT_VALUES_EQUAL(res2->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(res2->Head().Type(), TExprNode::Arguments); + UNIT_ASSERT_VALUES_EQUAL(res2->Head().ChildrenSize(), 0); + UNIT_ASSERT_EQUAL(res2->Child(1)->Head().ChildPtr(1), atom); + } +} + +} // namespace NYql diff --git a/yql/essentials/ast/yql_expr_check_args_ut.cpp b/yql/essentials/ast/yql_expr_check_args_ut.cpp new file mode 100644 index 00000000000..9f1a8af420a --- /dev/null +++ b/yql/essentials/ast/yql_expr_check_args_ut.cpp @@ -0,0 +1,115 @@ +#include "yql_expr.h" +#include <library/cpp/testing/unittest/registar.h> + +namespace NYql { + +Y_UNIT_TEST_SUITE(TExprCheckArguments) { + Y_UNIT_TEST(TestDuplicateArgument) { + TExprContext ctx; + auto pos = TPositionHandle(); + auto arg0 = ctx.NewArgument(pos, "arg0"); + auto args = ctx.NewArguments(pos, { arg0 }); + auto body = ctx.Builder(pos) + .Callable("+") + .Add(0, arg0) + .Add(1, arg0) + .Seal() + .Build(); + + auto left = ctx.NewLambda(pos, std::move(args), std::move(body)); + + auto arg1 = ctx.NewArgument(pos, "arg0"); + args = ctx.NewArguments(pos, { arg0, arg1 }); + body = ctx.Builder(pos) + .Callable("+") + .Add(0, arg0) + .Add(1, arg1) + .Seal() + .Build(); + + auto right = ctx.NewLambda(pos, std::move(args), std::move(body)); + + auto root = ctx.Builder(pos) + .Callable("SomeTopLevelCallableWithTwoLambdas") + .Add(0, left) + .Add(1, right) + .Seal() + .Build(); + + UNIT_ASSERT_EXCEPTION_CONTAINS(CheckArguments(*root), yexception, "argument is duplicated, #[1]"); + } + + Y_UNIT_TEST(TestUnresolved) { + TExprContext ctx; + auto pos = TPositionHandle(); + + auto arg1 = ctx.NewArgument(pos, "arg1"); + auto arg0 = ctx.NewArgument(pos, "arg0"); + + auto innerLambdaBody = ctx.Builder(pos) + .Callable("+") + .Add(0, arg0) + .Add(1, arg1) + .Seal() + .Build(); + + auto innerLambda = ctx.NewLambda(pos, ctx.NewArguments(pos, { arg1 }), std::move(innerLambdaBody)); + + auto outerLambda = ctx.NewLambda(pos, ctx.NewArguments(pos, { arg0 }), TExprNode::TPtr(innerLambda)); + + auto root = ctx.Builder(pos) + .Callable("SomeTopLevelCallableWithTwoLambdasAndFreeArg") + .Add(0, outerLambda) + .Add(1, innerLambda) + .Seal() + .Build(); + + UNIT_ASSERT_EXCEPTION_CONTAINS(CheckArguments(*root), yexception, "detected unresolved arguments at top level: #[2]"); + + root = ctx.Builder(pos) + .Callable("SomeTopLevelCallableWithTwoLambdasAndFreeArg") + .Add(0, outerLambda) + .Add(1, innerLambda) + .Add(2, ctx.NewArgument(pos, "arg3")) + .Seal() + .Build(); + + UNIT_ASSERT_EXCEPTION_CONTAINS(CheckArguments(*root), yexception, "detected unresolved arguments at top level: #[2, 10]"); + } + + Y_UNIT_TEST(TestUnresolvedFreeArg) { + TExprContext ctx; + auto pos = TPositionHandle(); + auto arg = ctx.NewArgument(pos, "arg"); + UNIT_ASSERT_EXCEPTION_CONTAINS(CheckArguments(*arg), yexception, "detected unresolved arguments at top level: #[1]"); + } + + Y_UNIT_TEST(TestOk) { + TExprContext ctx; + auto pos = TPositionHandle(); + + auto root = ctx.Builder(pos) + .Callable("TopLevelCallableWithTwoLambdas") + .Lambda(0) + .Param("one") + .Lambda() + .Param("two") + .Callable("+") + .Arg(0, "one") + .Arg(1, "two") + .Seal() + .Seal() + .Seal() + .Lambda(1) + .Param("three") + .Callable("Not") + .Arg(0, "three") + .Seal() + .Seal() + .Seal() + .Build(); + UNIT_ASSERT_NO_EXCEPTION(CheckArguments(*root)); + } +} + +} // namespace NYql diff --git a/yql/essentials/ast/yql_expr_types.cpp b/yql/essentials/ast/yql_expr_types.cpp new file mode 100644 index 00000000000..0e44a620f31 --- /dev/null +++ b/yql/essentials/ast/yql_expr_types.cpp @@ -0,0 +1,19 @@ +#include "yql_expr_types.h" + +namespace NYql { +} + +template<> +void Out<NYql::ETypeAnnotationKind>(class IOutputStream &o, NYql::ETypeAnnotationKind x) { +#define YQL_TYPE_ANN_KIND_MAP_TO_STRING_IMPL(name, ...) \ + case NYql::ETypeAnnotationKind::name: \ + o << #name; \ + return; + + switch (x) { + YQL_TYPE_ANN_KIND_MAP(YQL_TYPE_ANN_KIND_MAP_TO_STRING_IMPL) + default: + o << static_cast<int>(x); + return; + } +} diff --git a/yql/essentials/ast/yql_expr_types.h b/yql/essentials/ast/yql_expr_types.h new file mode 100644 index 00000000000..bb628af2725 --- /dev/null +++ b/yql/essentials/ast/yql_expr_types.h @@ -0,0 +1,43 @@ +#pragma once +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> +#include <util/stream/output.h> + +namespace NYql { + +#define YQL_TYPE_ANN_KIND_MAP(xx) \ + xx(Unit, 1) \ + xx(Tuple, 2) \ + xx(Struct, 3) \ + xx(Item, 4) \ + xx(List, 5) \ + xx(Data, 6) \ + xx(World, 7) \ + xx(Optional, 8) \ + xx(Type, 9) \ + xx(Dict, 10) \ + xx(Void, 11) \ + xx(Callable, 12) \ + xx(Generic, 13) \ + xx(Resource, 14) \ + xx(Tagged, 15) \ + xx(Error, 16) \ + xx(Variant, 17) \ + xx(Stream, 18) \ + xx(Null, 19) \ + xx(Flow, 20) \ + xx(EmptyList, 21) \ + xx(EmptyDict, 22) \ + xx(Multi, 23) \ + xx(Pg, 24) \ + xx(Block, 25) \ + xx(Scalar, 26) + +enum class ETypeAnnotationKind : ui64 { + YQL_TYPE_ANN_KIND_MAP(ENUM_VALUE_GEN) + LastType +}; + +} + +template<> +void Out<NYql::ETypeAnnotationKind>(class IOutputStream &o, NYql::ETypeAnnotationKind x); diff --git a/yql/essentials/ast/yql_expr_ut.cpp b/yql/essentials/ast/yql_expr_ut.cpp new file mode 100644 index 00000000000..0db8ecf506c --- /dev/null +++ b/yql/essentials/ast/yql_expr_ut.cpp @@ -0,0 +1,1218 @@ +#include "yql_expr.h" +#include <library/cpp/testing/unittest/registar.h> + +#include <util/string/hex.h> + +namespace NYql { + +Y_UNIT_TEST_SUITE(TCompileYqlExpr) { + + static TAstParseResult ParseAstWithCheck(const TStringBuf& s) { + TAstParseResult res = ParseAst(s); + res.Issues.PrintTo(Cout); + UNIT_ASSERT(res.IsOk()); + return res; + } + + static void CompileExprWithCheck(TAstNode& root, TExprNode::TPtr& exprRoot, TExprContext& exprCtx, ui32 typeAnnotationIndex = Max<ui32>()) { + const bool success = CompileExpr(root, exprRoot, exprCtx, nullptr, nullptr, typeAnnotationIndex != Max<ui32>(), typeAnnotationIndex); + exprCtx.IssueManager.GetIssues().PrintTo(Cout); + + UNIT_ASSERT(success); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->GetState(), typeAnnotationIndex != Max<ui32>() ? TExprNode::EState::TypeComplete : TExprNode::EState::Initial); + } + + static void CompileExprWithCheck(TAstNode& root, TLibraryCohesion& cohesion, TExprContext& exprCtx) { + const bool success = CompileExpr(root, cohesion, exprCtx); + exprCtx.IssueManager.GetIssues().PrintTo(Cout); + + UNIT_ASSERT(success); + } + + static bool ParseAndCompile(const TString& program) { + TAstParseResult astRes = ParseAstWithCheck(program); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + bool result = CompileExpr(*astRes.Root, exprRoot, exprCtx, nullptr, nullptr); + exprCtx.IssueManager.GetIssues().PrintTo(Cout); + return result; + } + + Y_UNIT_TEST(TestNoReturn1) { + auto s = "(\n" + ")\n"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestNoReturn2) { + auto s = "(\n" + "(let x 'y)\n" + ")\n"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestExportInsteadOfReturn) { + const auto s = + "# library\n" + "(\n" + " (let sqr (lambda '(x) (* x x)))\n" + " (export sqr)\n" + ")\n" + ; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestLeftAfterReturn) { + auto s = "(\n" + "(return 'x)\n" + "(let x 'y)\n" + ")\n"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestReturn) { + auto s = "(\n" + "(return world)\n" + ")\n"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + CompileExprWithCheck(*astRes.Root, exprRoot, exprCtx); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Type(), TExprNode::World); + } + + Y_UNIT_TEST(TestExport) { + auto s = "(\n" + "(let X 'Y)\n" + "(let ex '42)\n" + "(export ex)\n" + "(export X)\n" + ")\n"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TLibraryCohesion cohesion; + CompileExprWithCheck(*astRes.Root, cohesion, exprCtx); + auto& exports = cohesion.Exports.Symbols(exprCtx); + UNIT_ASSERT_VALUES_EQUAL(2U, exports.size()); + UNIT_ASSERT_VALUES_EQUAL("42", exports["ex"]->Content()); + UNIT_ASSERT_VALUES_EQUAL("Y", exports["X"]->Content()); + } + + Y_UNIT_TEST(TestEmptyLib) { + auto s = "(\n" + "(let X 'Y)\n" + "(let ex '42)\n" + ")\n"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TLibraryCohesion cohesion; + CompileExprWithCheck(*astRes.Root, cohesion, exprCtx); + UNIT_ASSERT(cohesion.Exports.Symbols().empty()); + UNIT_ASSERT(cohesion.Imports.empty()); + } + + Y_UNIT_TEST(TestArbitraryAtom) { + auto s = "(\n" + "(let x '\"\\x01\\x23\\x45\\x67\\x89\\xAB\\xCD\\xEF\")" + "(return x)\n" + ")\n"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + CompileExprWithCheck(*astRes.Root, exprRoot, exprCtx); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Type(), TExprNode::Atom); + UNIT_ASSERT_STRINGS_EQUAL(HexEncode(exprRoot->Content()), "0123456789ABCDEF"); + UNIT_ASSERT(exprRoot->Flags() & TNodeFlags::ArbitraryContent); + + auto ast = ConvertToAst(*exprRoot, exprCtx, TExprAnnotationFlags::None, true); + TAstNode* xValue = ast.Root->GetChild(0)->GetChild(1)->GetChild(1); + UNIT_ASSERT_STRINGS_EQUAL(HexEncode(TString(xValue->GetContent())), "0123456789ABCDEF"); + UNIT_ASSERT(xValue->GetFlags() & TNodeFlags::ArbitraryContent); + } + + Y_UNIT_TEST(TestBinaryAtom) { + auto s = "(\n" + "(let x 'x\"FEDCBA9876543210\")" + "(return x)\n" + ")\n"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + CompileExprWithCheck(*astRes.Root, exprRoot, exprCtx); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Type(), TExprNode::Atom); + UNIT_ASSERT_STRINGS_EQUAL(HexEncode(exprRoot->Content()), "FEDCBA9876543210"); + UNIT_ASSERT(exprRoot->Flags() & TNodeFlags::BinaryContent); + + auto ast = ConvertToAst(*exprRoot, exprCtx, TExprAnnotationFlags::None, true); + TAstNode* xValue = ast.Root->GetChild(0)->GetChild(2)->GetChild(1); + UNIT_ASSERT_STRINGS_EQUAL(HexEncode(TString(xValue->GetContent())), "FEDCBA9876543210"); + UNIT_ASSERT(xValue->GetFlags() & TNodeFlags::BinaryContent); + } + + Y_UNIT_TEST(TestLet) { + auto s = "(\n" + "(let x 'y)\n" + "(return x)\n" + ")\n"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + CompileExprWithCheck(*astRes.Root, exprRoot, exprCtx); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Content(), "y"); + } + + Y_UNIT_TEST(TestComplexQuote) { + auto s = "(\n" + "(let x 'a)\n" + "(let y 'b)\n" + "(let z (quote (x y)))\n" + "(return z)\n" + ")\n"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + CompileExprWithCheck(*astRes.Root, exprRoot, exprCtx); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Type(), TExprNode::List); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(0)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(0)->Content(), "a"); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(1)->Content(), "b"); + } + + Y_UNIT_TEST(TestEmptyReturn) { + auto s = "(\n" + "(return)\n" + ")\n"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestManyReturn) { + auto s = "(\n" + "(return world world)\n" + ")\n"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestUnknownFunction) { + auto s = "(\n" + "(let a '2)\n" + "(let x (+ a '3))\n" + "(return x)\n" + ")\n"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + CompileExprWithCheck(*astRes.Root, exprRoot, exprCtx); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Content(), "+"); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(0)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(0)->Content(), "2"); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(1)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(1)->Content(), "3"); + } + + Y_UNIT_TEST(TestReturnTwice) { + auto s = "(\n" + "(return)\n" + "(return)\n" + ")\n"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestDeclareNonTop) { + const auto s = R"( + ( + (let $1 (block '( + (declare $param (DataType 'Uint32)) + (return $param) + ))) + (return $1) + ) + )"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestDeclareHideLet) { + const auto s = R"( + ( + (let $name (Uint32 '10)) + (declare $name (DataType 'Uint32)) + (return $name) + ) + )"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestDeclareBadName) { + const auto s = R"( + ( + (declare $15 (DataType 'Uint32)) + (return $15) + ) + )"; + UNIT_ASSERT(false == ParseAndCompile(s)); + } + + Y_UNIT_TEST(TestLetHideDeclare) { + const auto s = R"( + ( + (declare $name (DataType 'Uint32)) + (let $name (Uint32 '10)) + (return $name) + ) + )"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + CompileExprWithCheck(*astRes.Root, exprRoot, exprCtx); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Content(), "Uint32"); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->ChildrenSize(), 1); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(0)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(0)->Content(), "10"); + } + + Y_UNIT_TEST(TestDeclare) { + const auto s = R"( + ( + (declare $param (DataType 'Uint32)) + (return $param) + ) + )"; + + TAstParseResult astRes = ParseAstWithCheck(s); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + CompileExprWithCheck(*astRes.Root, exprRoot, exprCtx); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->ChildrenSize(), 2); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(0)->Type(), TExprNode::Atom); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(0)->Content(), "$param"); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(1)->Type(), TExprNode::Callable); + UNIT_ASSERT_VALUES_EQUAL(exprRoot->Child(1)->Content(), "DataType"); + } +} + +Y_UNIT_TEST_SUITE(TCompareExprTrees) { + void CompileAndCompare(const TString& one, const TString& two, const std::pair<TPosition, TPosition> *const diffPositions = nullptr) { + const auto progOne(ParseAst(one)), progTwo(ParseAst(two)); + UNIT_ASSERT(progOne.IsOk() && progTwo.IsOk()); + + TExprContext ctxOne, ctxTwo; + TExprNode::TPtr rootOne, rootTwo; + + UNIT_ASSERT(CompileExpr(*progOne.Root, rootOne, ctxOne, nullptr, nullptr)); + UNIT_ASSERT(CompileExpr(*progTwo.Root, rootTwo, ctxTwo, nullptr, nullptr)); + + const TExprNode* diffOne = rootOne.Get(); + const TExprNode* diffTwo = rootTwo.Get(); + + if (diffPositions) { + UNIT_ASSERT(!CompareExprTrees(diffOne, diffTwo)); + UNIT_ASSERT_EQUAL(ctxOne.GetPosition(diffOne->Pos()), diffPositions->first); + UNIT_ASSERT_EQUAL(ctxTwo.GetPosition(diffTwo->Pos()), diffPositions->second); + } else + UNIT_ASSERT(CompareExprTrees(diffOne, diffTwo)); + } + + Y_UNIT_TEST(BigGoodCompare) { + const auto one = R"( + ( + (let $1 world) + (let $2 (DataSource '"yt" '"plato")) + (let $3 (MrTableRange '"statbox/yql-log" '"2016-05-25" '"2016-06-01")) + (let $4 '('table $3)) + (let $5 (Key $4)) + (let $6 '('"method" '"uri" '"login" '"user_agent" '"millis")) + (let $7 '()) + (let $8 (Read! $1 $2 $5 $6 $7)) + (let $9 (Left! $8)) + (let $10 (DataSink 'result)) + (let $11 (Key)) + (let $12 (Right! $8)) + (let $13 (lambda '($111) (block '( + (let $113 (Member $111 '"method")) + (let $114 (String '"POST")) + (let $115 (== $113 $114)) + (let $116 (Member $111 '"uri")) + (let $117 (String '"/api/v2/operations")) + (let $118 (== $116 $117)) + (let $119 (Udf '"String.HasPrefix")) + (let $120 (Member $111 '"uri")) + (let $121 (String '"/api/v2/tutorials/")) + (let $122 (Apply $119 $120 $121)) + (let $123 (Or $118 $122)) + (let $124 (Udf '"String.HasPrefix")) + (let $125 (Member $111 '"uri")) + (let $126 (String '"/api/v2/queries/")) + (let $127 (Apply $124 $125 $126)) + (let $128 (Or $123 $127)) + (let $129 (And $115 $128)) + (let $130 (Udf '"String.HasPrefix")) + (let $131 (Member $111 '"uri")) + (let $132 (String '"/api/v2/table_data_async")) + (let $133 (Apply $130 $131 $132)) + (let $134 (Or $129 $133)) + (let $135 (Udf '"String.HasPrefix")) + (let $136 (Member $111 '"login")) + (let $137 (String '"robot-")) + (let $138 (Apply $135 $136 $137)) + (let $139 (Not $138)) + (let $140 (And $134 $139)) + (let $141 (Bool 'false)) + (let $142 (Coalesce $140 $141)) + (return $142) + )))) + (let $14 (Filter $12 $13)) + (let $15 (lambda '($143) (block '( + (let $145 (Struct)) + (let $146 (Member $143 '"login")) + (let $147 (AddMember $145 '"login" $146)) + (let $148 (Udf '"String.HasPrefix")) + (let $149 (Member $143 '"user_agent")) + (let $150 (String '"YQL ")) + (let $151 (Apply $148 $149 $150)) + (let $152 (Bool 'false)) + (let $153 (Coalesce $151 $152)) + (let $154 (Udf '"String.SplitToList")) + (let $155 (Member $143 '"user_agent")) + (let $156 (String '" ")) + (let $157 (Apply $154 $155 $156)) + (let $158 (Int64 '"1")) + (let $159 (SqlAccess 'dict $157 $158)) + (let $160 (String '"CLI")) + (let $161 (== $159 $160)) + (let $162 (Bool 'false)) + (let $163 (Coalesce $161 $162)) + (let $164 (String '"CLI")) + (let $165 (String '"API")) + (let $166 (If $163 $164 $165)) + (let $167 (String '"Web UI")) + (let $168 (If $153 $166 $167)) + (let $169 (AddMember $147 '"client_type" $168)) + (let $170 (Udf '"DateTime.ToDate")) + (let $171 (Udf '"DateTime.StartOfWeek")) + (let $172 (Udf '"DateTime.FromMilliSeconds")) + (let $173 (Member $143 '"millis")) + (let $174 (Cast $173 'Uint64)) + (let $175 (Apply $172 $174)) + (let $176 (Apply $171 $175)) + (let $177 (Apply $170 $176)) + (let $178 (String '"")) + (let $179 (Coalesce $177 $178)) + (let $180 (String '" - ")) + (let $181 (Concat $179 $180)) + (let $182 (Udf '"DateTime.ToDate")) + (let $183 (Udf '"DateTime.StartOfWeek")) + (let $184 (Udf '"DateTime.FromMilliSeconds")) + (let $185 (Member $143 '"millis")) + (let $186 (Cast $185 'Uint64)) + (let $187 (Apply $184 $186)) + (let $188 (Apply $183 $187)) + (let $189 (Udf '"DateTime.FromDays")) + (let $190 (Int64 '"6")) + (let $191 (Apply $189 $190)) + (let $192 (+ $188 $191)) + (let $193 (Apply $182 $192)) + (let $194 (String '"")) + (let $195 (Coalesce $193 $194)) + (let $196 (Concat $181 $195)) + (let $197 (AddMember $169 '"week" $196)) + (let $198 (AsList $197)) + (return $198) + )))) + )" + R"( + (let $16 (FlatMap $14 $15)) + (let $17 '('"client_type" '"week")) + (let $18 (lambda '($199 $200) (block '( + (let $202 (lambda '($205 $206 $207) (block '( + (let $209 (ListItemType $205)) + (let $210 (lambda '($223) (block '( + (let $212 (ListItemType $205)) + (let $213 (InstanceOf $212)) + (let $214 (Apply $206 $213)) + (let $215 (TypeOf $214)) + (let $216 (ListType $215)) + (let $217 (Apply $207 $216)) + (let $225 (NthArg '1 $217)) + (let $226 (Apply $206 $223)) + (let $227 (Apply $225 $226)) + (return $227) + )))) + (let $211 (lambda '($228 $229) (block '( + (let $212 (ListItemType $205)) + (let $213 (InstanceOf $212)) + (let $214 (Apply $206 $213)) + (let $215 (TypeOf $214)) + (let $216 (ListType $215)) + (let $217 (Apply $207 $216)) + (let $231 (NthArg '2 $217)) + (let $232 (Apply $206 $228)) + (let $233 (Apply $231 $232 $229)) + (return $233) + )))) + (let $212 (ListItemType $205)) + (let $213 (InstanceOf $212)) + (let $214 (Apply $206 $213)) + (let $215 (TypeOf $214)) + (let $216 (ListType $215)) + (let $217 (Apply $207 $216)) + (let $218 (NthArg '3 $217)) + (let $219 (NthArg '4 $217)) + (let $220 (NthArg '5 $217)) + (let $221 (NthArg '6 $217)) + (let $222 (AggregationTraits $209 $210 $211 $218 $219 $220 $221)) + (return $222) + )))) + (let $203 (lambda '($234) (block '( + (let $236 (ListItemType $234)) + (let $237 (lambda '($244) (block '( + (let $246 (AggrCountInit $244)) + (return $246) + )))) + (let $238 (lambda '($247 $248) (block '( + (let $250 (AggrCountUpdate $247 $248)) + (return $250) + )))) + (let $239 (lambda '($251) (block '( + (return $251) + )))) + (let $240 (lambda '($253) (block '( + (return $253) + )))) + (let $241 (lambda '($255 $256) (block '( + (let $258 (+ $255 $256)) + (return $258) + )))) + (let $242 (lambda '($259) (block '( + (return $259) + )))) + (let $243 (AggregationTraits $236 $237 $238 $239 $240 $241 $242)) + (return $243) + )))) + (let $204 (Apply $202 $199 $200 $203)) + (return $204) + )))) + (let $19 (TypeOf $16)) + (let $20 (ListItemType $19)) + (let $21 (StructMemberType $20 '"login")) + (let $22 (ListType $21)) + (let $23 (lambda '($261) (block '( + (return $261) + )))) + (let $24 (Apply $18 $22 $23)) + (let $25 '('Count1 $24 '"login")) + (let $26 (TypeOf $16)) + (let $27 (lambda '($263) (block '( + (let $265 (Void)) + (return $265) + )))) + (let $28 (Apply $18 $26 $27)) + (let $29 '('Count2 $28)) + (let $30 (TypeOf $16)) + (let $31 (lambda '($266) (block '( + (let $268 (Void)) + (return $268) + )))) + (let $32 (Apply $18 $30 $31)) + (let $33 '('Count3 $32)) + (let $34 (TypeOf $16)) + (let $35 (ListItemType $34)) + (let $36 (StructMemberType $35 '"login")) + (let $37 (ListType $36)) + (let $38 (lambda '($269) (block '( + (return $269) + )))) + (let $39 (Apply $18 $37 $38)) + (let $40 '('Count4 $39 '"login")) + (let $41 '($25 $29 $33 $40)) + (let $42 (Aggregate $16 $17 $41)) + (let $43 (lambda '($271) (block '( + (let $273 (Struct)) + (let $274 (Member $271 '"week")) + (let $275 (AddMember $273 '"week" $274)) + (let $276 (Member $271 '"client_type")) + (let $277 (AddMember $275 '"client_type" $276)) + (let $278 (Member $271 'Count1)) + (let $279 (AddMember $277 '"users_count" $278)) + (let $280 (Member $271 'Count2)) + (let $281 (AddMember $279 '"operations_count" $280)) + (let $282 (Member $271 'Count3)) + (let $283 (Member $271 'Count4)) + (let $284 (/ $282 $283)) + (let $285 (AddMember $281 '"operations_per_user" $284)) + (let $286 (AsList $285)) + (return $286) + )))) + (let $44 (FlatMap $42 $43)) + (let $45 (Bool 'false)) + (let $46 (Bool 'false)) + (let $47 '($45 $46)) + (let $48 (lambda '($287) (block '( + (let $289 (Member $287 '"week")) + (let $290 (Member $287 '"users_count")) + (let $291 '($289 $290)) + (return $291) + )))) + (let $49 (Sort $44 $47 $48)) + (let $50 '('type)) + (let $51 '('autoref)) + (let $52 '('"week" '"client_type" '"users_count" '"operations_count" '"operations_per_user")) + (let $53 '('columns $52)) + (let $54 '($50 $51 $53)) + (let $55 (Write! $9 $10 $11 $49 $54)) + (let $56 (Commit! $55 $10)) + (let $57 (DataSource '"yt" '"plato")) + (let $58 (MrTableRange '"statbox/yql-log" '"2016-05-25" '"2016-06-01")) + (let $59 '('table $58)) + (let $60 (Key $59)) + (let $61 '('"method" '"uri" '"login" '"user_agent" '"millis")) + (let $62 '()) + (let $63 (Read! $56 $57 $60 $61 $62)) + (let $64 (Left! $63)) + (let $65 (DataSink 'result)) + (let $66 (Key)) + (let $67 (Right! $63)) + (let $68 (lambda '($292) (block '( + (let $294 (Member $292 '"method")) + (let $295 (String '"POST")) + (let $296 (== $294 $295)) + (let $297 (Member $292 '"uri")) + (let $298 (String '"/api/v2/operations")) + (let $299 (== $297 $298)) + (let $300 (Udf '"String.HasPrefix")) + (let $301 (Member $292 '"uri")) + (let $302 (String '"/api/v2/tutorials/")) + (let $303 (Apply $300 $301 $302)) + (let $304 (Or $299 $303)) + (let $305 (Udf '"String.HasPrefix")) + (let $306 (Member $292 '"uri")) + (let $307 (String '"/api/v2/queries/")) + (let $308 (Apply $305 $306 $307)) + (let $309 (Or $304 $308)) + (let $310 (And $296 $309)) + (let $311 (Udf '"String.HasPrefix")) + (let $312 (Member $292 '"uri")) + (let $313 (String '"/api/v2/table_data_async")) + (let $314 (Apply $311 $312 $313)) + (let $315 (Or $310 $314)) + (let $316 (Udf '"String.HasPrefix")) + (let $317 (Member $292 '"login")) + (let $318 (String '"robot-")) + (let $319 (Apply $316 $317 $318)) + (let $320 (Not $319)) + (let $321 (And $315 $320)) + (let $322 (Bool 'false)) + (let $323 (Coalesce $321 $322)) + (return $323) + )))) + (let $69 (Filter $67 $68)) + (let $70 (lambda '($324) (block '( + (let $326 (Struct)) + (let $327 (Member $324 '"login")) + (let $328 (AddMember $326 '"login" $327)) + (let $329 (Udf '"String.HasPrefix")) + (let $330 (Member $324 '"user_agent")) + (let $331 (String '"YQL ")) + (let $332 (Apply $329 $330 $331)) + (let $333 (Bool 'false)) + (let $334 (Coalesce $332 $333)) + (let $335 (Udf '"String.SplitToList")) + (let $336 (Member $324 '"user_agent")) + (let $337 (String '" ")) + (let $338 (Apply $335 $336 $337)) + (let $339 (Int64 '"1")) + (let $340 (SqlAccess 'dict $338 $339)) + (let $341 (String '"CLI")) + (let $342 (== $340 $341)) + (let $343 (Bool 'false)) + (let $344 (Coalesce $342 $343)) + (let $345 (String '"CLI")) + (let $346 (String '"API")) + (let $347 (If $344 $345 $346)) + (let $348 (String '"Web UI")) + (let $349 (If $334 $347 $348)) + (let $350 (AddMember $328 '"client_type" $349)) + (let $351 (Udf '"DateTime.ToDate")) + (let $352 (Udf '"DateTime.StartOfWeek")) + (let $353 (Udf '"DateTime.FromMilliSeconds")) + (let $354 (Member $324 '"millis")) + (let $355 (Cast $354 'Uint64)) + (let $356 (Apply $353 $355)) + (let $357 (Apply $352 $356)) + (let $358 (Apply $351 $357)) + (let $359 (String '"")) + (let $360 (Coalesce $358 $359)) + (let $361 (String '" - ")) + (let $362 (Concat $360 $361)) + (let $363 (Udf '"DateTime.ToDate")) + (let $364 (Udf '"DateTime.StartOfWeek")) + (let $365 (Udf '"DateTime.FromMilliSeconds")) + (let $366 (Member $324 '"millis")) + (let $367 (Cast $366 'Uint64)) + (let $368 (Apply $365 $367)) + (let $369 (Apply $364 $368)) + (let $370 (Udf '"DateTime.FromDays")) + (let $371 (Int64 '"6")) + (let $372 (Apply $370 $371)) + (let $373 (+ $369 $372)) + (let $374 (Apply $363 $373)) + (let $375 (String '"")) + (let $376 (Coalesce $374 $375)) + (let $377 (Concat $362 $376)) + (let $378 (AddMember $350 '"week" $377)) + (let $379 (AsList $378)) + (return $379) + )))) + (let $71 (FlatMap $69 $70)) + (let $72 '('"week")) + (let $73 (TypeOf $71)) + (let $74 (ListItemType $73)) + (let $75 (StructMemberType $74 '"login")) + (let $76 (ListType $75)) + (let $77 (lambda '($380) (block '( + (return $380) + )))) + (let $78 (Apply $18 $76 $77)) + (let $79 '('Count6 $78 '"login")) + (let $80 (TypeOf $71)) + (let $81 (lambda '($382) (block '( + (let $384 (Void)) + (return $384) + )))) + (let $82 (Apply $18 $80 $81)) + (let $83 '('Count7 $82)) + (let $84 (TypeOf $71)) + (let $85 (lambda '($385) (block '( + (let $387 (Void)) + (return $387) + )))) + (let $86 (Apply $18 $84 $85)) + (let $87 '('Count8 $86)) + (let $88 (TypeOf $71)) + (let $89 (ListItemType $88)) + (let $90 (StructMemberType $89 '"login")) + (let $91 (ListType $90)) + (let $92 (lambda '($388) (block '( + (return $388) + )))) + (let $93 (Apply $18 $91 $92)) + (let $94 '('Count9 $93 '"login")) + (let $95 '($79 $83 $87 $94)) + (let $96 (Aggregate $71 $72 $95)) + (let $97 (lambda '($390) (block '( + (let $392 (Struct)) + (let $393 (Member $390 '"week")) + (let $394 (AddMember $392 '"week" $393)) + (let $395 (Member $390 'Count6)) + (let $396 (AddMember $394 '"users_count" $395)) + (let $397 (Member $390 'Count7)) + (let $398 (AddMember $396 '"operations_count" $397)) + (let $399 (Member $390 'Count8)) + (let $400 (Member $390 'Count9)) + (let $401 (/ $399 $400)) + (let $402 (AddMember $398 '"operations_per_user" $401)) + (let $403 (AsList $402)) + (return $403) + )))) + (let $98 (FlatMap $96 $97)) + (let $99 (Bool 'false)) + (let $100 (lambda '($404) (block '( + (let $406 (Member $404 '"week")) + (return $406) + )))) + (let $101 (Sort $98 $99 $100)) + (let $102 '('type)) + (let $103 '('autoref)) + (let $104 '('"week" '"users_count" '"operations_count" '"operations_per_user")) + (let $105 '('columns $104)) + (let $106 '($102 $103 $105)) + (let $107 (Write! $64 $65 $66 $101 $106)) + (let $108 (Commit! $107 $65)) + (let $109 (DataSink '"yt" '"plato")) + (let $110 (Commit! $108 $109)) + (return $110) + ) + )"; + + const auto two = R"( + ( + (let $1 (MrTableRange '"statbox/yql-log" '"2016-05-25" '"2016-06-01")) + (let $2 '('"method" '"uri" '"login" '"user_agent" '"millis")) + (let $3 (Read! world (DataSource '"yt" '"plato") (Key '('table $1)) $2 '())) + (let $4 (DataSink 'result)) + (let $5 (FlatMap (Filter (Right! $3) (lambda '($36) (block '( + (let $37 (Apply (Udf '"String.HasPrefix") (Member $36 '"uri") (String '"/api/v2/tutorials/"))) + (let $38 (Apply (Udf '"String.HasPrefix") (Member $36 '"uri") (String '"/api/v2/queries/"))) + (let $39 (Apply (Udf '"String.HasPrefix") (Member $36 '"uri") (String '"/api/v2/table_data_async"))) + (let $40 (Apply (Udf '"String.HasPrefix") (Member $36 '"login") (String '"robot-"))) + (return (Coalesce (And (Or (And (== (Member $36 '"method") (String '"POST")) (Or (Or (== (Member $36 '"uri") (String '"/api/v2/operations")) $37) $38)) $39) (Not $40)) (Bool 'false))) + )))) (lambda '($41) (block '( + (let $42 (AddMember (Struct) '"login" (Member $41 '"login"))) + (let $43 (Apply (Udf '"String.HasPrefix") (Member $41 '"user_agent") (String '"YQL "))) + (let $44 (Apply (Udf '"String.SplitToList") (Member $41 '"user_agent") (String '" "))) + (let $45 (SqlAccess 'dict $44 (Int64 '"1"))) + (let $46 (If (Coalesce (== $45 (String '"CLI")) (Bool 'false)) (String '"CLI") (String '"API"))) + (let $47 (If (Coalesce $43 (Bool 'false)) $46 (String '"Web UI"))) + (let $48 (AddMember $42 '"client_type" $47)) + (let $49 (AddMember $48 '"week" (Concat (Concat (Coalesce (Apply (Udf '"DateTime.ToDate") (Apply (Udf '"DateTime.StartOfWeek") (Apply (Udf '"DateTime.FromMilliSeconds") (Cast (Member $41 '"millis") 'Uint64)))) (String '"")) (String '" - ")) (Coalesce (Apply (Udf '"DateTime.ToDate") (+ (Apply (Udf '"DateTime.StartOfWeek") (Apply (Udf '"DateTime.FromMilliSeconds") (Cast (Member $41 '"millis") 'Uint64))) (Apply (Udf '"DateTime.FromDays") (Int64 '"6")))) (String '""))))) + (return (AsList $49)) + ))))) + (let $6 (lambda '($50 $51) (block '( + (let $52 (Apply (lambda '($53 $54 $55) (block '( + (let $57 (Apply $55 (ListType (TypeOf (Apply $54 (InstanceOf (ListItemType $53))))))) + (let $58 (AggregationTraits (ListItemType $53) (lambda '($59) (block '( + (let $57 (Apply $55 (ListType (TypeOf (Apply $54 (InstanceOf (ListItemType $53))))))) + (return (Apply (NthArg '1 $57) (Apply $54 $59))) + ))) (lambda '($60 $61) (block '( + (let $57 (Apply $55 (ListType (TypeOf (Apply $54 (InstanceOf (ListItemType $53))))))) + (let $62 (Apply (NthArg '2 $57) (Apply $54 $60) $61)) + (return $62) + ))) (NthArg '3 $57) (NthArg '4 $57) (NthArg '5 $57) (NthArg '6 $57))) + (return $58) + ))) $50 $51 (lambda '($63) (block '( + (let $64 (AggregationTraits (ListItemType $63) (lambda '($65) (AggrCountInit $65)) (lambda '($66 $67) (AggrCountUpdate $66 $67)) (lambda '($68) $68) (lambda '($69) $69) (lambda '($70 $71) (+ $70 $71)) (lambda '($72) $72))) + (return $64) + ))))) + (return $52) + )))) + (let $7 (Apply $6 (ListType (StructMemberType (ListItemType (TypeOf $5)) '"login")) (lambda '($73) $73))) + (let $8 '('Count1 $7 '"login")) + (let $9 (Apply $6 (TypeOf $5) (lambda '($74) (Void)))) + (let $10 (Apply $6 (TypeOf $5) (lambda '($75) (Void)))) + (let $11 (Apply $6 (ListType (StructMemberType (ListItemType (TypeOf $5)) '"login")) (lambda '($76) $76))) + (let $12 '('Count4 $11 '"login")) + (let $13 '($8 '('Count2 $9) '('Count3 $10) $12)) + (let $14 (Aggregate $5 '('"client_type" '"week") $13)) + (let $15 (Sort (FlatMap $14 (lambda '($77) (block '( + (let $78 (AddMember (Struct) '"week" (Member $77 '"week"))) + (let $79 (AddMember $78 '"client_type" (Member $77 '"client_type"))) + (let $80 (AddMember $79 '"users_count" (Member $77 'Count1))) + (let $81 (AddMember $80 '"operations_count" (Member $77 'Count2))) + (let $82 (AddMember $81 '"operations_per_user" (/ (Member $77 'Count3) (Member $77 'Count4)))) + (return (AsList $82)) + )))) '((Bool 'false) (Bool 'false)) (lambda '($83) '((Member $83 '"week") (Member $83 '"users_count"))))) + (let $16 '('"week" '"client_type" '"users_count" '"operations_count" '"operations_per_user")) + (let $17 '('('type) '('autoref) '('columns $16))) + (let $18 (Write! (Left! $3) $4 (Key) $15 $17)) + (let $19 (MrTableRange '"statbox/yql-log" '"2016-05-25" '"2016-06-01")) + (let $20 '('"method" '"uri" '"login" '"user_agent" '"millis")) + (let $21 (Read! (Commit! $18 $4) (DataSource '"yt" '"plato") (Key '('table $19)) $20 '())) + (let $22 (DataSink 'result)) + (let $23 (FlatMap (Filter (Right! $21) (lambda '($84) (block '( + (let $85 (Apply (Udf '"String.HasPrefix") (Member $84 '"uri") (String '"/api/v2/tutorials/"))) + (let $86 (Apply (Udf '"String.HasPrefix") (Member $84 '"uri") (String '"/api/v2/queries/"))) + (let $87 (Apply (Udf '"String.HasPrefix") (Member $84 '"uri") (String '"/api/v2/table_data_async"))) + (let $88 (Apply (Udf '"String.HasPrefix") (Member $84 '"login") (String '"robot-"))) + (return (Coalesce (And (Or (And (== (Member $84 '"method") (String '"POST")) (Or (Or (== (Member $84 '"uri") (String '"/api/v2/operations")) $85) $86)) $87) (Not $88)) (Bool 'false))) + )))) (lambda '($89) (block '( + (let $90 (AddMember (Struct) '"login" (Member $89 '"login"))) + (let $91 (Apply (Udf '"String.HasPrefix") (Member $89 '"user_agent") (String '"YQL "))) + (let $92 (Apply (Udf '"String.SplitToList") (Member $89 '"user_agent") (String '" "))) + (let $93 (SqlAccess 'dict $92 (Int64 '"1"))) + (let $94 (If (Coalesce (== $93 (String '"CLI")) (Bool 'false)) (String '"CLI") (String '"API"))) + (let $95 (If (Coalesce $91 (Bool 'false)) $94 (String '"Web UI"))) + (let $96 (AddMember $90 '"client_type" $95)) + (let $97 (AddMember $96 '"week" (Concat (Concat (Coalesce (Apply (Udf '"DateTime.ToDate") (Apply (Udf '"DateTime.StartOfWeek") (Apply (Udf '"DateTime.FromMilliSeconds") (Cast (Member $89 '"millis") 'Uint64)))) (String '"")) (String '" - ")) (Coalesce (Apply (Udf '"DateTime.ToDate") (+ (Apply (Udf '"DateTime.StartOfWeek") (Apply (Udf '"DateTime.FromMilliSeconds") (Cast (Member $89 '"millis") 'Uint64))) (Apply (Udf '"DateTime.FromDays") (Int64 '"6")))) (String '""))))) + (return (AsList $97)) + ))))) + (let $24 (Apply $6 (ListType (StructMemberType (ListItemType (TypeOf $23)) '"login")) (lambda '($98) $98))) + (let $25 '('Count6 $24 '"login")) + (let $26 (Apply $6 (TypeOf $23) (lambda '($99) (Void)))) + (let $27 (Apply $6 (TypeOf $23) (lambda '($100) (Void)))) + (let $28 (Apply $6 (ListType (StructMemberType (ListItemType (TypeOf $23)) '"login")) (lambda '($101) $101))) + (let $29 '('Count9 $28 '"login")) + (let $30 '($25 '('Count7 $26) '('Count8 $27) $29)) + (let $31 (Aggregate $23 '('"week") $30)) + (let $32 (Sort (FlatMap $31 (lambda '($102) (block '( + (let $103 (AddMember (Struct) '"week" (Member $102 '"week"))) + (let $104 (AddMember $103 '"users_count" (Member $102 'Count6))) + (let $105 (AddMember $104 '"operations_count" (Member $102 'Count7))) + (let $106 (AddMember $105 '"operations_per_user" (/ (Member $102 'Count8) (Member $102 'Count9)))) + (return (AsList $106)) + )))) (Bool 'false) (lambda '($107) (Member $107 '"week")))) + (let $33 '('"week" '"users_count" '"operations_count" '"operations_per_user")) + (let $34 '('('type) '('autoref) '('columns $33))) + (let $35 (Write! (Left! $21) $22 (Key) $32 $34)) + (return (Commit! (Commit! $35 $22) (DataSink '"yt" '"plato"))) + ) + )"; + + CompileAndCompare(one, two); + } + + Y_UNIT_TEST(DiffrentAtoms) { + const auto one = "((return (+ '4 (- '3 '2))))"; + const auto two = "((let x '3)\n(let y '1)\n(let z (- x y))\n(let r (+ '4 z))\n(return r))"; + + const auto diff = std::make_pair(TPosition(23,1), TPosition(9,2)); + CompileAndCompare(one, two, &diff); + } + + Y_UNIT_TEST(DiffrentLists) { + const auto one = "((return '('7 '4 '('1 '3 '2))))"; + const auto two = "((let x '('1 '3))\n(let y '('7 '4 x))\n(return y))"; + + const auto diff = std::make_pair(TPosition(20,1), TPosition(11,1)); + CompileAndCompare(one, two, &diff); + } + + Y_UNIT_TEST(DiffrentCallables) { + const auto one = "((return (- '4 (- '3 '2))))"; + const auto two = "((let x '3)\n(let y '2)\n(let z (- x y))\n(let r (+ '4 z))\n(return r))"; + + const auto diff = std::make_pair(TPosition(11,1), TPosition(9,4)); + CompileAndCompare(one, two, &diff); + } + + Y_UNIT_TEST(SwapArguments) { + const auto one = "((let l (lambda '(x y) (+ x y)))\n(return (Apply l '7 '9)))"; + const auto two = "((return (Apply (lambda '(x y) (+ y x)) '7 '9)))"; + + const auto diff = std::make_pair(TPosition(19,1), TPosition(29,1)); + CompileAndCompare(one, two, &diff); + } +} + +Y_UNIT_TEST_SUITE(TConvertToAst) { + static TString CompileAndDisassemble(const TString& program, bool expectEqualExprs = true) { + const auto astRes = ParseAst(program); + UNIT_ASSERT(astRes.IsOk()); + TExprContext exprCtx; + TExprNode::TPtr exprRoot; + UNIT_ASSERT(CompileExpr(*astRes.Root, exprRoot, exprCtx, nullptr, nullptr)); + UNIT_ASSERT(exprRoot); + + const auto convRes = ConvertToAst(*exprRoot, exprCtx, 0, true); + UNIT_ASSERT(convRes.IsOk()); + + TExprContext exprCtx2; + TExprNode::TPtr exprRoot2; + auto compileOk = CompileExpr(*convRes.Root, exprRoot2, exprCtx2, nullptr, nullptr); + exprCtx2.IssueManager.GetIssues().PrintTo(Cout); + UNIT_ASSERT(compileOk); + UNIT_ASSERT(exprRoot2); + const TExprNode* node = exprRoot.Get(); + const TExprNode* node2 = exprRoot2.Get(); + bool equal = CompareExprTrees(node, node2); + UNIT_ASSERT(equal == expectEqualExprs); + + return convRes.Root->ToString(TAstPrintFlags::PerLine | TAstPrintFlags::ShortQuote); + } + + Y_UNIT_TEST(ManyLambdaWithCaptures) { + const auto program = R"( + ( + #comment + (let mr_source (DataSource 'yt 'plato)) + (let x (Read! world mr_source (Key '('table (String 'Input))) '('key 'subkey 'value) '())) + (let world (Left! x)) + (let table1 (Right! x)) + (let table1int (FlatMap table1 + (lambda '(item) (block '( + (let intKey (FromString (Member item 'key) 'Int32)) + (let keyDiv100 (FlatMap intKey (lambda '(x) (/ x (Int32 '100))))) + (let ret (Map keyDiv100 (lambda '(y) (block '( + (let r '(y (Member item 'value))) + (return r) + ))))) + (return ret) + ))) + )) + (let table1intDebug (Map table1int (lambda '(it) (block '( + (let s (Struct)) + (let s (AddMember s 'key (ToString (Nth it '0)))) + (let s (AddMember s 'subkey (String '.))) + (let s (AddMember s 'value (Nth it '1))) + (return s) + ))))) + (let mr_sink (DataSink 'yt (quote plato))) + (let world (Write! world mr_sink (Key '('table (String 'Output))) table1intDebug '('('mode 'append)))) + (let world (Commit! world mr_sink)) + (return world) + ) + )"; + CompileAndDisassemble(program); + } + + Y_UNIT_TEST(LambdaWithCaptureArgumentOfTopLambda) { + const auto program = R"( + ( + (let mr_source (DataSource 'yt 'plato)) + (let x (Read! world mr_source (Key '('table (String 'Input))) '('key 'subkey 'value) '())) + (let world (Left! x)) + (let table1 (Right! x)) + (let table1low (FlatMap table1 (lambda '(item) (block '( + (let intValueOpt (FromString (Member item 'key) 'Int32)) + (let ret (FlatMap intValueOpt (lambda '(item2) (block '( + (return (ListIf (< item2 (Int32 '100)) item)) + ))))) + (return ret) + ))))) + (let res_sink (DataSink 'result)) + (let data (AsList (String 'x))) + (let world (Write! world res_sink (Key) table1low '())) + (let world (Commit! world res_sink)) + (return world) + ) + )"; + CompileAndDisassemble(program); + } + + Y_UNIT_TEST(LambdaWithCapture) { + const auto program = R"( + ( + (let conf (Configure! world (DataSource 'yt '"$all") '"Attr" '"mapjoinlimit" '"1")) + (let dsr (DataSink 'result)) + (let dsy (DataSink 'yt 'plato)) + (let co '('key 'subkey 'value)) + (let data (DataSource 'yt 'plato)) + (let string (DataType 'String)) + (let ostr (OptionalType string)) + (let struct (StructType '('key ostr) '('subkey ostr) '('value ostr))) + (let scheme '('('"scheme" struct))) + (let temp (MrTempTable dsy '"tmp/bb686f68-2245bd5f-2318fa4e-1" scheme)) + (let str (lambda '(arg) (Just (AsStruct '('key (Just (Member arg 'key))) '('subkey (Just (Member arg 'subkey))) '('value (Just (Member arg 'value))))))) + (let map (MrMap! world dsy (Key '('table (String 'Input1))) co '() data temp '() str)) + (let tt (MrTempTable dsy '"tmp/7ae6459a-7382d1e7-7935c08e-2" scheme)) + (let map2 (MrMap! world dsy (Key '('table (String 'Input2))) co '() data tt '() str)) + (let s2 (StructType '('"a.key" string) '('"a.subkey" ostr) '('"a.value" ostr) '('"b.key" string) '('"b.subkey" ostr) '('"b.value" ostr))) + (let mtt3 (MrTempTable dsy '"tmp/ecfc6738-59d47572-b9936849-3" '('('"scheme" s2)))) + (let tuple '('('"take" (Uint64 '"101")))) + (let lmap (MrLMap! (Sync! map map2) dsy temp co '() data mtt3 '('('"limit" '(tuple))) (lambda '(arg) (block '( + (let read (MrReadTable! world data tt co '())) + (let key '('"key")) + (let wtf '('"Hashed" '"One" '"Compact" '('"ItemsCount" '"4"))) + (let dict (ToDict (FilterNullMembers (MrTableContent read '()) key) (lambda '(arg) (Member arg '"key")) (lambda '(x) x) wtf)) + (let acols '('"key" '"a.key" '"subkey" '"a.subkey" '"value" '"a.value")) + (let bcols '('"key" '"b.key" '"subkey" '"b.subkey" '"value" '"b.value")) + (return (MapJoinCore arg dict 'Inner key acols bcols)) + ))))) + (let cols '('"a.key" '"a.subkey" '"a.value" '"b.key" '"b.subkey" '"b.value")) + (let res (ResPull! conf dsr (Key) (Right! (MrReadTable! lmap data mtt3 cols tuple)) '('('type)) 'yt)) + (return (Commit! res dsr)) + ) + )"; + + const auto disassembled = CompileAndDisassemble(program); + UNIT_ASSERT(TString::npos != disassembled.find("'('key 'subkey 'value)")); + UNIT_ASSERT_EQUAL(disassembled.find("'('key 'subkey 'value)"), disassembled.rfind("'('key 'subkey 'value)")); + } + + Y_UNIT_TEST(ManyLambdasWithCommonCapture) { + const auto program = R"( + ( + (let c42 (+ (Int64 '40) (Int64 '2))) + (let c100 (Int64 '100)) + (let l0 (lambda '(x) (- (* x c42) c42))) + (let l1 (lambda '(y) (Apply l0 y))) + (let l2 (lambda '(z) (+ (* c42) (Apply l0 c42)))) + (return (* (Apply l1 c100)(Apply l2 c100))) + ) + )"; + + const auto disassembled = CompileAndDisassemble(program); + UNIT_ASSERT(TString::npos != disassembled.find("(+ (Int64 '40) (Int64 '2))")); + UNIT_ASSERT_EQUAL(disassembled.find("(+ (Int64 '40) (Int64 '2))"), disassembled.rfind("(+ (Int64 '40) (Int64 '2))")); + } + + Y_UNIT_TEST(CapturedUseInTopLevelAfrerLambda) { + const auto program = R"( + ( + (let $1 (DataSink 'result)) + (let $2 (DataSink '"yt" '"plato")) + (let $3 (DataSource '"yt" '"plato")) + (let $4 (TupleType (DataType 'Int64) (DataType 'Uint64))) + (let $5 (MrTempTable $2 '"tmp/ecfc6738-59d47572-b9936849-3" '('('"scheme" (StructType '('"key" (DataType 'String)) '('"value" (StructType '('Avg1 (OptionalType $4)) '('Avg2 $4)))))))) + (let $6 (MrMapCombine! world $2 (Key '('table (String '"Input"))) '('"key" '"subkey") '() $3 $5 '() (lambda '($13) (Just (AsStruct '('"key" (Cast (Member $13 '"key") 'Int64)) '('"sub" (Unwrap (Cast (Member $13 '"subkey") 'Int64)))))) (lambda '($14) (Uint32 '"0")) (lambda '($15 $16) (block '( + (let $18 (Uint64 '1)) + (let $17 (IfPresent (Member $16 '"key") (lambda '($19) (Just '($19 $18))) (Nothing (OptionalType (TupleType (DataType 'Int64) (DataType 'Uint64)))))) + (return (AsStruct '('Avg1 $17) '('Avg2 '((Member $16 '"sub") $18)))) + ))) (lambda '($20 $21 $22) (block '( + (let $23 (IfPresent (Member $21 '"key") (lambda '($28) (block '( + (let $29 (Uint64 '1)) + (return (Just '($28 $29))) + ))) (Nothing (OptionalType (TupleType (DataType 'Int64) (DataType 'Uint64)))))) + (let $24 (IfPresent (Member $22 'Avg1) (lambda '($26) (IfPresent (Member $21 '"key") (lambda '($27) (Just '((+ (Nth $26 '0) $27) (Inc (Nth $26 '1))))) (Just $26))) $23)) + (let $25 (Member $22 'Avg2)) + (return (AsStruct '('Avg1 $24) '('Avg2 '((+ (Nth $25 '0) (Member $21 '"sub")) (Inc (Nth $25 '1)))))) + ))) (lambda '($30 $31) (Just (AsStruct '('"value" $31) '('"key" (String '""))))))) + (return $6) + ) + )"; + + CompileAndDisassemble(program); + } + + Y_UNIT_TEST(SelectCommonAncestor) { + const auto program = R"( + ( + (let $1 (DataSink 'result)) + (let $2 (DataSink '"yt" '"plato")) + (let $3 '('"key" '"value")) + (let $4 (DataSource '"yt" '"plato")) + (let $5 (DataType 'String)) + (let $6 (OptionalType $5)) + (let $7 (MrTempTable $2 '"tmp/41c7eb81-87a9f8b6-70daa714-11" '('('"scheme" (StructType '('"key" $5) '('"value" (StructType '('Histogram0 $6) '('Histogram1 $6)))))))) + (let $8 (Udf 'Histogram.AdaptiveWardHistogram_Create (Void) (VoidType) '"" (CallableType '() '((ResourceType '"Histogram.AdaptiveWard")) '((DataType 'Double)) '((DataType 'Double)) '((DataType 'Uint32))))) + (let $9 (Double '1.0)) + (let $10 (Cast (Int32 '"1") 'Uint32)) + (let $11 (Double '1.0)) + (let $12 (Cast (Int32 '"1000000") 'Uint32)) + (let $13 (MrMapCombine! world $2 (Key '('table (String '"Input"))) $3 '() $4 $7 '() (lambda '($21) (Just $21)) (lambda '($22) (Uint32 '"0")) (lambda '($23 $24) (AsStruct '('Histogram0 (FlatMap (Cast (Member $24 '"key") 'Double) (lambda '($25) (block '( + (let $26 '((DataType 'Double))) + (let $27 (CallableType '() '((ResourceType '"Histogram.AdaptiveWard")) $26 $26 '((DataType 'Uint32)))) + (let $28 '((Unwrap (Cast (Member $24 '"key") 'Double)) $9 $10)) + (return (Just (NamedApply $8 $28 (AsStruct) (Uint32 '"0")))) + ))))) '('Histogram1 (FlatMap (Cast (Member $24 '"value") 'Double) (lambda '($29) (block '( + (let $30 '((Unwrap (Cast (Member $24 '"value") 'Double)) $11 $12)) + (return (Just (NamedApply $8 $30 (AsStruct) (Uint32 '"1")))) + ))))))) (lambda '($31 $32 $33) (block '( + (let $34 (Udf 'Histogram.AdaptiveWardHistogram_AddValue (Void) (VoidType) '"" (CallableType '() '((ResourceType '"Histogram.AdaptiveWard")) '((ResourceType '"Histogram.AdaptiveWard")) '((DataType 'Double)) '((DataType 'Double))))) + (let $35 (Uint32 '"0")) + (let $36 (IfPresent (Member $33 'Histogram0) (lambda '($39) (block '( + (let $40 (Cast (Member $32 '"key") 'Double)) + (let $41 '((ResourceType '"Histogram.AdaptiveWard"))) + (let $42 '((DataType 'Double))) + (let $43 (CallableType '() $41 $41 $42 $42)) + (let $44 '($39 (Unwrap $40) $9)) + (let $45 (NamedApply $34 $44 (AsStruct) $35)) + (return (Just (If (Exists $40) $45 $39))) + ))) (FlatMap (Cast (Member $32 '"key") 'Double) (lambda '($46) (block '( + (let $47 '((Unwrap (Cast (Member $32 '"key") 'Double)) $9 $10)) + (return (Just (NamedApply $8 $47 (AsStruct) $35))) + )))))) + (let $37 (Uint32 '"1")) + (let $38 (IfPresent (Member $33 'Histogram1) (lambda '($48) (block '( + (let $49 (Cast (Member $32 '"value") 'Double)) + (let $50 '($48 (Unwrap $49) $11)) + (let $51 (NamedApply $34 $50 (AsStruct) $37)) + (return (Just (If (Exists $49) $51 $48))) + ))) (FlatMap (Cast (Member $32 '"value") 'Double) (lambda '($52) (block '( + (let $53 '((Unwrap (Cast (Member $32 '"value") 'Double)) $11 $12)) + (return (Just (NamedApply $8 $53 (AsStruct) $37))) + )))))) + (return (AsStruct '('Histogram0 $36) '('Histogram1 $38))) + ))) (lambda '($54 $55) (block '( + (let $56 (lambda '($57) (block '( + (let $58 (CallableType '() '((DataType 'String)) '((ResourceType '"Histogram.AdaptiveWard")))) + (let $59 (Udf 'Histogram.AdaptiveWardHistogram_Serialize (Void) (VoidType) '"" $58)) + (return (Just (Apply $59 $57))) + )))) + (return (Just (AsStruct '('"value" (AsStruct '('Histogram0 (FlatMap (Member $55 'Histogram0) $56)) '('Histogram1 (FlatMap (Member $55 'Histogram1) $56)))) '('"key" (String '""))))) + ))))) + (let $14 (DataType 'Double)) + (let $15 (OptionalType (StructType '('"Bins" (ListType (StructType '('"Frequency" $14) '('"Position" $14)))) '('"Max" $14) '('"Min" $14) '('"WeightsSum" $14)))) + (let $16 (MrTempTable $2 '"tmp/b6d6c3ee-30bb3e55-ea0c48bc-12" '('('"scheme" (StructType '('"key_histogram" $15) '('"value_histogram" $15)))))) + (let $17 (MrReduce! $13 $2 $7 $3 '() $4 $16 '('('"reduceBy" '('"key"))) (lambda '($60 $61) (block '( + (let $67 (Udf 'Histogram.AdaptiveWardHistogram_Deserialize (Void) (VoidType) '"" (CallableType '() '((ResourceType '"Histogram.AdaptiveWard")) '((DataType 'String)) '((DataType 'Uint32))))) + (let $62 (lambda '($68) (block '( + (let $69 (CallableType '() '((ResourceType '"Histogram.AdaptiveWard")) '((DataType 'String)) '((DataType 'Uint32)))) + (return (Just (Apply $67 $68 $10))) + )))) + (let $63 (lambda '($70) (Just (Apply $67 $70 $12)))) + (let $64 (Fold1 (FlatMap $61 (lambda '($65) (Just (Member $65 '"value")))) (lambda '($66) (block '( + (return (AsStruct '('Histogram0 (FlatMap (Member $66 'Histogram0) $62)) '('Histogram1 (FlatMap (Member $66 'Histogram1) $63)))) + ))) (lambda '($71 $72) (block '( + (let $73 (lambda '($76 $77) (block '( + (let $78 '((ResourceType '"Histogram.AdaptiveWard"))) + (let $79 (CallableType '() $78 $78 $78)) + (let $80 (Udf 'Histogram.AdaptiveWardHistogram_Merge (Void) (VoidType) '"" $79)) + (return (Apply $80 $76 $77)) + )))) + (let $74 (OptionalReduce (FlatMap (Member $71 'Histogram0) $62) (Member $72 'Histogram0) $73)) + (let $75 (OptionalReduce (FlatMap (Member $71 'Histogram1) $63) (Member $72 'Histogram1) $73)) + (return (AsStruct '('Histogram0 $74) '('Histogram1 $75))) + ))))) + (return (FlatMap $64 (lambda '($81) (block '( + (let $82 (lambda '($83) (block '( + (let $84 (DataType 'Double)) + (let $85 (CallableType '() '((StructType '('"Bins" (ListType (StructType '('"Frequency" $84) '('"Position" $84)))) '('"Max" $84) '('"Min" $84) '('"WeightsSum" $84))) '((ResourceType '"Histogram.AdaptiveWard")))) + (let $86 (Udf 'Histogram.AdaptiveWardHistogram_GetResult (Void) (VoidType) '"" $85)) + (return (Just (Apply $86 $83))) + )))) + (return (AsList (AsStruct '('"key_histogram" (FlatMap (Member $81 'Histogram0) $82)) '('"value_histogram" (FlatMap (Member $81 'Histogram1) $82))))) + ))))) + ))))) + (let $18 '('"key_histogram" '"value_histogram")) + (let $19 '('('type) '('autoref) '('columns $18))) + (let $20 (ResPull! world $1 (Key) (Right! (MrReadTable! $17 $4 $16 $18 '())) $19 '"yt")) + (return (Commit! (Commit! $20 $1) $2 '('('"epoch" '"1")))) + ) + )"; + + CompileAndDisassemble(program); + } + + Y_UNIT_TEST(Parameters) { + const auto program = R"( + ( + (let $nameType (OptionalType (DataType 'String))) + (let $1 (Read! world (DataSource '"kikimr" '"local_ut") (Key '('table (String '"tmp/table"))) (Void) '())) + (let $2 (DataSink 'result)) + (let $5 (Write! (Left! $1) $2 (Key) (FlatMap (Filter (Right! $1) + (lambda '($9) (Coalesce (And + (== (Member $9 '"Group") (Parameter '"$Group" (DataType 'Uint32))) + (== (Member $9 '"Name") (Parameter '"$Name" $nameType))) (Bool 'false)))) + (lambda '($10) (AsList $10))) '('('type) '('autoref)))) + (let $6 (Read! (Commit! $5 $2) (DataSource '"kikimr" '"local_ut") + (Key '('table (String '"tmp/table"))) (Void) '())) + (let $7 (DataSink 'result)) + (let $8 (Write! (Left! $6) $7 (Key) (FlatMap (Filter (Right! $6) + (lambda '($11) (Coalesce (And + (== (Member $11 '"Group") (+ (Parameter '"$Group" (DataType 'Uint32)) (Int32 '"1"))) + (== (Member $11 '"Name") (Coalesce (Parameter '"$Name" $nameType) + (String '"Empty")))) (Bool 'false)))) + (lambda '($12) (AsList $12))) '('('type) '('autoref)))) + (return (Commit! $8 $7)) + ) + )"; + + const auto disassembled = CompileAndDisassemble(program); + UNIT_ASSERT(TString::npos != disassembled.find("(declare $Group (DataType 'Uint32))")); + UNIT_ASSERT(TString::npos != disassembled.find("(declare $Name (OptionalType (DataType 'String)))")); + } + + Y_UNIT_TEST(ParametersDifferentTypes) { + const auto program = R"( + ( + (let $1 (Read! world (DataSource '"kikimr" '"local_ut") (Key '('table (String '"tmp/table"))) (Void) '())) + (let $2 (DataSink 'result)) + (let $5 (Write! (Left! $1) $2 (Key) (FlatMap (Filter (Right! $1) + (lambda '($9) (Coalesce (And + (== (Member $9 '"Group") (Parameter '"$Group" (DataType 'Uint32))) + (== (Member $9 '"Name") (Parameter '"$Name" (OptionalType (DataType 'String))))) (Bool 'false)))) + (lambda '($10) (AsList $10))) '('('type) '('autoref)))) + (let $6 (Read! (Commit! $5 $2) (DataSource '"kikimr" '"local_ut") + (Key '('table (String '"tmp/table"))) (Void) '())) + (let $7 (DataSink 'result)) + (let $8 (Write! (Left! $6) $7 (Key) (FlatMap (Filter (Right! $6) + (lambda '($11) (Coalesce (And + (== (Member $11 '"Group") (+ (Parameter '"$Group" (DataType 'Uint32)) (Int32 '"1"))) + (== (Member $11 '"Name") (Coalesce (Parameter '"$Name" (OptionalType (DataType 'Int32))) + (String '"Empty")))) (Bool 'false)))) + (lambda '($12) (AsList $12))) '('('type) '('autoref)))) + (return (Commit! $8 $7)) + ) + )"; + + const auto disassembled = CompileAndDisassemble(program, false); + UNIT_ASSERT(TString::npos != disassembled.find("(declare $Group (DataType 'Uint32))")); + UNIT_ASSERT(TString::npos != disassembled.find("(declare $Name (OptionalType (DataType 'String)))")); + } +} + +} // namespace NYql diff --git a/yql/essentials/ast/yql_gc_nodes.cpp b/yql/essentials/ast/yql_gc_nodes.cpp new file mode 100644 index 00000000000..6259e9babe7 --- /dev/null +++ b/yql/essentials/ast/yql_gc_nodes.cpp @@ -0,0 +1,5 @@ +#include "yql_gc_nodes.h" + +namespace NYql { + +} diff --git a/yql/essentials/ast/yql_gc_nodes.h b/yql/essentials/ast/yql_gc_nodes.h new file mode 100644 index 00000000000..59c35d5a51b --- /dev/null +++ b/yql/essentials/ast/yql_gc_nodes.h @@ -0,0 +1,22 @@ +#pragma once + +#include <util/system/types.h> + +namespace NYql { + +struct TGcNodeSettings { + ui64 NodeCountThreshold = 1000; + double CollectRatio = 0.8; +}; + +struct TGcNodeStatistics { + ui64 CollectCount = 0; + ui64 TotalCollectedNodes = 0; +}; + +struct TGcNodeConfig { + TGcNodeSettings Settings; + TGcNodeStatistics Statistics; +}; + +} diff --git a/yql/essentials/ast/yql_pos_handle.h b/yql/essentials/ast/yql_pos_handle.h new file mode 100644 index 00000000000..e00cf63a78c --- /dev/null +++ b/yql/essentials/ast/yql_pos_handle.h @@ -0,0 +1,13 @@ +#pragma once + +#include <util/system/types.h> + +namespace NYql { + +struct TPositionHandle { + friend struct TExprContext; +private: + ui32 Handle = 0; // 0 is guaranteed to represent default-constructed TPosition +}; + +} diff --git a/yql/essentials/ast/yql_type_string.cpp b/yql/essentials/ast/yql_type_string.cpp new file mode 100644 index 00000000000..b9b6570cf00 --- /dev/null +++ b/yql/essentials/ast/yql_type_string.cpp @@ -0,0 +1,1494 @@ +#include "yql_type_string.h" +#include "yql_expr.h" +#include "yql_ast_escaping.h" + +#include <yql/essentials/parser/pg_catalog/catalog.h> +#include <library/cpp/containers/stack_vector/stack_vec.h> + +#include <util/string/cast.h> +#include <util/generic/map.h> +#include <util/generic/utility.h> +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> + + +#define EXPECT_AND_SKIP_TOKEN_IMPL(token, message, result) \ + do { \ + if (Y_LIKELY(Token == token)) { \ + GetNextToken(); \ + } else { \ + AddError(message); \ + return result; \ + } \ + } while (0); + +#define EXPECT_AND_SKIP_TOKEN(token, result) \ + EXPECT_AND_SKIP_TOKEN_IMPL(token, "Expected " #token, result) + + +namespace NYql { +namespace { + +enum EToken +{ + TOKEN_EOF = -1, + + // type keywords + TOKEN_TYPE_MIN = -2, + TOKEN_STRING = -3, + TOKEN_BOOL = -4, + TOKEN_INT32 = -6, + TOKEN_UINT32 = -7, + TOKEN_INT64 = -8, + TOKEN_UINT64 = -9, + TOKEN_FLOAT = -10, + TOKEN_DOUBLE = -11, + TOKEN_LIST = -12, + TOKEN_OPTIONAL = -13, + TOKEN_DICT = -14, + TOKEN_TUPLE = -15, + TOKEN_STRUCT = -16, + TOKEN_RESOURCE = -17, + TOKEN_VOID = -18, + TOKEN_CALLABLE = -19, + TOKEN_TAGGED = -20, + TOKEN_YSON = -21, + TOKEN_UTF8 = -22, + TOKEN_VARIANT = -23, + TOKEN_UNIT = -24, + TOKEN_STREAM = -25, + TOKEN_GENERIC = -26, + TOKEN_JSON = -27, + TOKEN_NULL = -28, + TOKEN_DATE = -29, + TOKEN_DATETIME = -30, + TOKEN_TIMESTAMP = -31, + TOKEN_INTERVAL = -32, + TOKEN_DECIMAL = -33, + TOKEN_INT8 = -34, + TOKEN_UINT8 = -35, + TOKEN_INT16 = -36, + TOKEN_UINT16 = -37, + TOKEN_TZDATE = -38, + TOKEN_TZDATETIME = -39, + TOKEN_TZTIMESTAMP = -40, + TOKEN_UUID = -41, + TOKEN_FLOW = -42, + TOKEN_SET = -43, + TOKEN_ENUM = -44, + TOKEN_EMPTYLIST = -45, + TOKEN_EMPTYDICT = -46, + TOKEN_TYPE_MAX = -47, + TOKEN_JSON_DOCUMENT = -48, + TOKEN_DYNUMBER = -49, + TOKEN_SCALAR = -50, + TOKEN_BLOCK = -51, + + // identifiers + TOKEN_IDENTIFIER = -100, + TOKEN_ESCAPED_IDENTIFIER = -101, + + // special + TOKEN_ARROW = -200, +}; + +bool IsTypeKeyword(int token) +{ + return token < TOKEN_TYPE_MIN && token > TOKEN_TYPE_MAX; +} + +EToken TokenTypeFromStr(TStringBuf str) +{ + static const THashMap<TStringBuf, EToken> map = { + { TStringBuf("String"), TOKEN_STRING }, + { TStringBuf("Bool"), TOKEN_BOOL }, + { TStringBuf("Int32"), TOKEN_INT32 }, + { TStringBuf("Uint32"), TOKEN_UINT32 }, + { TStringBuf("Int64"), TOKEN_INT64 }, + { TStringBuf("Uint64"), TOKEN_UINT64 }, + { TStringBuf("Float"), TOKEN_FLOAT }, + { TStringBuf("Double"), TOKEN_DOUBLE }, + { TStringBuf("List"), TOKEN_LIST }, + { TStringBuf("Optional"), TOKEN_OPTIONAL }, + { TStringBuf("Dict"), TOKEN_DICT }, + { TStringBuf("Tuple"), TOKEN_TUPLE }, + { TStringBuf("Struct"), TOKEN_STRUCT }, + { TStringBuf("Resource"), TOKEN_RESOURCE }, + { TStringBuf("Void"), TOKEN_VOID }, + { TStringBuf("Callable"), TOKEN_CALLABLE }, + { TStringBuf("Tagged"), TOKEN_TAGGED }, + { TStringBuf("Yson"), TOKEN_YSON }, + { TStringBuf("Utf8"), TOKEN_UTF8 }, + { TStringBuf("Variant"), TOKEN_VARIANT }, + { TStringBuf("Unit"), TOKEN_UNIT }, + { TStringBuf("Stream"), TOKEN_STREAM }, + { TStringBuf("Generic"), TOKEN_GENERIC }, + { TStringBuf("Json"), TOKEN_JSON }, + { TStringBuf("Date"), TOKEN_DATE }, + { TStringBuf("Datetime"), TOKEN_DATETIME }, + { TStringBuf("Timestamp"), TOKEN_TIMESTAMP }, + { TStringBuf("Interval"), TOKEN_INTERVAL }, + { TStringBuf("Null"), TOKEN_NULL }, + { TStringBuf("Decimal"), TOKEN_DECIMAL }, + { TStringBuf("Int8"), TOKEN_INT8 }, + { TStringBuf("Uint8"), TOKEN_UINT8 }, + { TStringBuf("Int16"), TOKEN_INT16 }, + { TStringBuf("Uint16"), TOKEN_UINT16 }, + { TStringBuf("TzDate"), TOKEN_TZDATE }, + { TStringBuf("TzDatetime"), TOKEN_TZDATETIME }, + { TStringBuf("TzTimestamp"), TOKEN_TZTIMESTAMP }, + { TStringBuf("Uuid"), TOKEN_UUID }, + { TStringBuf("Flow"), TOKEN_FLOW }, + { TStringBuf("Set"), TOKEN_SET }, + { TStringBuf("Enum"), TOKEN_ENUM }, + { TStringBuf("EmptyList"), TOKEN_EMPTYLIST }, + { TStringBuf("EmptyDict"), TOKEN_EMPTYDICT }, + { TStringBuf("JsonDocument"), TOKEN_JSON_DOCUMENT }, + { TStringBuf("DyNumber"), TOKEN_DYNUMBER }, + { TStringBuf("Block"), TOKEN_BLOCK}, + { TStringBuf("Scalar"), TOKEN_SCALAR}, + }; + + auto it = map.find(str); + if (it != map.end()) { + return it->second; + } + + return TOKEN_IDENTIFIER; +} + + +////////////////////////////////////////////////////////////////////////////// +// TTypeParser +////////////////////////////////////////////////////////////////////////////// +class TTypeParser +{ +public: + TTypeParser( + TStringBuf str, TIssues& issues, + TPosition position, TMemoryPool& pool) + : Str(str) + , Issues(issues) + , Position(position) + , Index(0) + , Pool(pool) + { + GetNextToken(); + } + + TAstNode* ParseTopLevelType() { + TAstNode* type = ParseType(); + if (type) { + EXPECT_AND_SKIP_TOKEN_IMPL( + TOKEN_EOF, "Expected end of string", nullptr); + } + return type; + } + +private: + TAstNode* ParseType() { + TAstNode* type = nullptr; + + switch (Token) { + case '(': return ParseCallableType(); + + case TOKEN_STRING: + case TOKEN_BOOL: + case TOKEN_INT8: + case TOKEN_UINT8: + case TOKEN_INT16: + case TOKEN_UINT16: + case TOKEN_INT32: + case TOKEN_UINT32: + case TOKEN_INT64: + case TOKEN_UINT64: + case TOKEN_FLOAT: + case TOKEN_DOUBLE: + case TOKEN_YSON: + case TOKEN_UTF8: + case TOKEN_JSON: + case TOKEN_DATE: + case TOKEN_DATETIME: + case TOKEN_TIMESTAMP: + case TOKEN_INTERVAL: + case TOKEN_TZDATE: + case TOKEN_TZDATETIME: + case TOKEN_TZTIMESTAMP: + case TOKEN_UUID: + case TOKEN_JSON_DOCUMENT: + case TOKEN_DYNUMBER: + type = MakeDataType(Identifier); + GetNextToken(); + break; + + case TOKEN_DECIMAL: + type = ParseDecimalType(); + break; + + case TOKEN_LIST: + type = ParseListType(); + break; + + case TOKEN_OPTIONAL: + type = ParseOptionalType(); + break; + + case TOKEN_DICT: + type = ParseDictType(); + break; + + case TOKEN_TUPLE: + type = ParseTupleType(); + break; + + case TOKEN_STRUCT: + type = ParseStructType(); + break; + + case TOKEN_RESOURCE: + type = ParseResourceType(); + break; + + case TOKEN_VOID: + type = MakeVoidType(); + GetNextToken(); + break; + + case TOKEN_NULL: + type = MakeNullType(); + GetNextToken(); + break; + + case TOKEN_EMPTYLIST: + type = MakeEmptyListType(); + GetNextToken(); + break; + + case TOKEN_EMPTYDICT: + type = MakeEmptyDictType(); + GetNextToken(); + break; + + case TOKEN_CALLABLE: + type = ParseCallableTypeWithKeyword(); + break; + + case TOKEN_TAGGED: + type = ParseTaggedType(); + break; + + case TOKEN_VARIANT: + type = ParseVariantType(); + break; + + case TOKEN_UNIT: + type = MakeUnitType(); + GetNextToken(); + break; + + case TOKEN_STREAM: + type = ParseStreamType(); + break; + + case TOKEN_FLOW: + type = ParseFlowType(); + break; + + case TOKEN_GENERIC: + type = MakeGenericType(); + GetNextToken(); + break; + + case TOKEN_SET: + type = ParseSetType(); + break; + + case TOKEN_ENUM: + type = ParseEnumType(); + break; + + case TOKEN_BLOCK: + type = ParseBlockType(); + break; + + case TOKEN_SCALAR: + type = ParseScalarType(); + break; + + default: + if (Identifier.empty()) { + return AddError("Expected type"); + } + + auto id = Identifier; + if (id.SkipPrefix("pg")) { + if (NPg::HasType(TString(id))) { + type = MakePgType(id); + GetNextToken(); + } + } else if (id.SkipPrefix("_pg")) { + if (NPg::HasType(TString(id)) && !id.StartsWith('_')) { + type = MakePgType(TString("_") + id); + GetNextToken(); + } + } + + if (!type) { + return AddError(TString("Unknown type: '") + Identifier + "\'"); + } + } + + if (type) { + while (Token == '?') { + type = MakeOptionalType(type); + GetNextToken(); + } + } + return type; + } + + char LookaheadNonSpaceChar() { + size_t i = Index; + while (i < Str.size() && isspace(Str[i])) { + i++; + } + return (i < Str.size()) ? Str[i] : -1; + } + + int GetNextToken() { + return Token = ReadNextToken(); + } + + int ReadNextToken() { + // skip spaces + while (!AtEnd() && isspace(Get())) { + Move(); + } + + TokenBegin = Position; + if (AtEnd()) { + return TOKEN_EOF; + } + + // clear last readed indentifier + Identifier = {}; + + char lastChar = Get(); + if (lastChar == '_' || isalnum(lastChar)) { // identifier + size_t start = Index; + while (!AtEnd()) { + lastChar = Get(); + if (lastChar == '_' || isalnum(lastChar)) Move(); + else break; + } + + Identifier = Str.SubString(start, Index - start); + return TokenTypeFromStr(Identifier); + } else if (lastChar == '\'') { // escaped identifier + Move(); // skip '\'' + if (AtEnd()) return TOKEN_EOF; + + UnescapedIdentifier.clear(); + TStringOutput sout(UnescapedIdentifier); + TStringBuf atom = Str.SubStr(Index); + size_t readBytes = 0; + EUnescapeResult unescapeResunt = + UnescapeArbitraryAtom(atom, '\'', &sout, &readBytes); + + if (unescapeResunt != EUnescapeResult::OK) return TOKEN_EOF; + + // skip already readed chars + while (readBytes-- != 0) { + Move(); + } + + if (AtEnd()) return TOKEN_EOF; + + Identifier = UnescapedIdentifier; + return TOKEN_ESCAPED_IDENTIFIER; + } else { + Move(); // skip last char + if (lastChar == '-' && !AtEnd() && Get() == '>') { + Move(); // skip '>' + return TOKEN_ARROW; + } + // otherwise, just return the last character as its ascii value + return lastChar; + } + } + + TAstNode* ParseCallableType() { + EXPECT_AND_SKIP_TOKEN('(', nullptr); + + TSmallVec<TAstNode*> args; + args.push_back(nullptr); // CallableType Atom + settings + return type + args.push_back(nullptr); + args.push_back(nullptr); + bool optArgsStarted = false; + bool namedArgsStarted = false; + ui32 optArgsCount = 0; + bool lastWasTypeStatement = false; + + // (1) parse argements + for (;;) { + if (Token == TOKEN_EOF) { + if (optArgsStarted) { + return AddError("Expected ']'"); + } + return AddError("Expected ')'"); + } + + if (Token == ']' || Token == ')') { + break; + } + + if (lastWasTypeStatement) { + EXPECT_AND_SKIP_TOKEN(',', nullptr); + lastWasTypeStatement = false; + } + + if (Token == '[') { + optArgsStarted = true; + GetNextToken(); // eat '[' + } else if (Token == ':') { + return AddError("Expected non empty argument name"); + } else if (IsTypeKeyword(Token) || Token == '(' || // '(' - begin of callable type + Token == TOKEN_IDENTIFIER || + Token == TOKEN_ESCAPED_IDENTIFIER) + { + TStringBuf argName; + ui32 argNameFlags = TNodeFlags::Default; + + if (LookaheadNonSpaceChar() == ':') { + namedArgsStarted = true; + argName = Identifier; + + if (Token == TOKEN_ESCAPED_IDENTIFIER) { + argNameFlags = TNodeFlags::ArbitraryContent; + } + + GetNextToken(); // eat name + EXPECT_AND_SKIP_TOKEN(':', nullptr); + + if (Token == TOKEN_EOF) { + return AddError("Expected type of named argument"); + } + } else { + if (namedArgsStarted) { + return AddError("Expected named argument, because of " + "previous argument(s) was named"); + } + } + + auto argType = ParseType(); + if (!argType) { + return nullptr; + } + lastWasTypeStatement = true; + + if (optArgsStarted) { + if (!argType->IsList() || argType->GetChildrenCount() == 0 || + !argType->GetChild(0)->IsAtom() || + argType->GetChild(0)->GetContent() != TStringBuf("OptionalType")) + { + return AddError("Optionals are only allowed in the optional arguments"); + } + optArgsCount++; + } + + ui32 argFlags = 0; + if (Token == '{') { + if (!ParseCallableArgFlags(argFlags)) return nullptr; + } + + TSmallVec<TAstNode*> argSettings; + argSettings.push_back(argType); + if (!argName.empty()) { + argSettings.push_back(MakeQuotedAtom(argName, argNameFlags)); + } + if (argFlags) { + if (argName.empty()) { + auto atom = MakeQuotedLiteralAtom(TStringBuf(""), TNodeFlags::ArbitraryContent); + argSettings.push_back(atom); + } + argSettings.push_back(MakeQuotedAtom(ToString(argFlags))); + } + args.push_back(MakeQuote( + MakeList(argSettings.data(), argSettings.size()))); + } else { + return AddError("Expected type or argument name"); + } + } + + if (optArgsStarted) { + EXPECT_AND_SKIP_TOKEN(']', nullptr); + } + + EXPECT_AND_SKIP_TOKEN(')', nullptr); + + // (2) expect '->' after arguments + EXPECT_AND_SKIP_TOKEN_IMPL( + TOKEN_ARROW, "Expected '->' after arguments", nullptr); + + // (3) parse return type + TAstNode* returnType = ParseType(); + if (!returnType) { + return nullptr; + } + + // (4) parse payload + TStringBuf payload; + if (Token == '{') { + if (!ParseCallablePayload(payload)) return nullptr; + } + + return MakeCallableType(args, optArgsCount, returnType, payload); + } + + // { Flags: f1 | f2 | f3 } + bool ParseCallableArgFlags(ui32& argFlags) { + GetNextToken(); // eat '{' + + if (Token != TOKEN_IDENTIFIER || Identifier != TStringBuf("Flags")) { + AddError("Expected Flags field"); + return false; + } + + GetNextToken(); // eat 'Flags' + EXPECT_AND_SKIP_TOKEN(':', false); + + for (;;) { + if (Token == TOKEN_IDENTIFIER) { + if (Identifier == TStringBuf("AutoMap")) { + argFlags |= TArgumentFlags::AutoMap; + } else { + AddError(TString("Unknown flag name: ") + Identifier); + return false; + } + GetNextToken(); // eat flag name + } else { + AddError("Expected flag name"); + return false; + } + + if (Token == '}') { + break; + } else if (Token == '|') { + GetNextToken(); // eat '|' + } else { + AddError("Expected '}' or '|'"); + } + } + + GetNextToken(); // eat '}' + return true; + } + + bool ParseCallablePayload(TStringBuf& payload) { + GetNextToken(); // eat '{' + + if (Token != TOKEN_IDENTIFIER && Identifier != TStringBuf("Payload")) { + AddError("Expected Payload field"); + return false; + } + + GetNextToken(); // eat 'Payload' + EXPECT_AND_SKIP_TOKEN(':', false); + + if (Token == TOKEN_IDENTIFIER || Token == TOKEN_ESCAPED_IDENTIFIER) { + payload = Identifier; + GetNextToken(); // eat payload data + } else { + AddError("Expected payload data"); + return false; + } + + EXPECT_AND_SKIP_TOKEN('}', false); + return true; + } + + TAstNode* ParseCallableTypeWithKeyword() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto type = ParseCallableType(); + if (!type) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return type; + } + + TAstNode* ParseListType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto itemType = ParseType(); + if (!itemType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeListType(itemType); + } + + TAstNode* ParseStreamType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto itemType = ParseType(); + if (!itemType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeStreamType(itemType); + } + + TAstNode* ParseFlowType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto itemType = ParseType(); + if (!itemType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeFlowType(itemType); + } + + TAstNode* ParseBlockType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto itemType = ParseType(); + if (!itemType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeBlockType(itemType); + } + + TAstNode* ParseScalarType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto itemType = ParseType(); + if (!itemType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeScalarType(itemType); + } + + TAstNode* ParseDecimalType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('(', nullptr); + + const auto precision = Identifier; + GetNextToken(); // eat keyword + + EXPECT_AND_SKIP_TOKEN(',', nullptr); + + const auto scale = Identifier; + GetNextToken(); // eat keyword + + EXPECT_AND_SKIP_TOKEN(')', nullptr); + + return MakeDecimalType(precision, scale); + } + + TAstNode* ParseOptionalType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto itemType = ParseType(); + if (!itemType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeOptionalType(itemType); + } + + TAstNode* ParseDictType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto keyType = ParseType(); + if (!keyType) return nullptr; + + EXPECT_AND_SKIP_TOKEN(',', nullptr); + + auto valueType = ParseType(); + if (!valueType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeDictType(keyType, valueType); + } + + TAstNode* ParseSetType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto keyType = ParseType(); + if (!keyType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeDictType(keyType, MakeVoidType()); + } + + TAstNode* ParseTupleTypeImpl() { + TSmallVec<TAstNode*> items; + items.push_back(nullptr); // reserve for TupleType + + if (Token != '>') { + for (;;) { + auto itemType = ParseType(); + if (!itemType) return nullptr; + + items.push_back(itemType); + + if (Token == '>') { + break; + } else if (Token == ',') { + GetNextToken(); + } else { + return AddError("Expected '>' or ','"); + } + } + } + + return MakeTupleType(items); + } + + TAstNode* ParseTupleType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + TAstNode* tupleType = ParseTupleTypeImpl(); + if (tupleType) { + EXPECT_AND_SKIP_TOKEN('>', nullptr); + } + return tupleType; + } + + TAstNode* ParseStructTypeImpl() { + TMap<TString, TAstNode*> members; + if (Token != '>') { + for (;;) { + TString name; + if (Token == TOKEN_IDENTIFIER || + Token == TOKEN_ESCAPED_IDENTIFIER) + { + name = Identifier; + } else { + return AddError("Expected struct member name"); + } + + if (name.empty()) { + return AddError("Empty name is not allowed"); + } else if (members.contains(name)) { + return AddError("Member name duplication"); + } + + GetNextToken(); // eat member name + EXPECT_AND_SKIP_TOKEN(':', nullptr); + + auto type = ParseType(); + if (!type) return nullptr; + + members.emplace(std::move(name), type); + + if (Token == '>') { + break; + } else if (Token == ',') { + GetNextToken(); + } else { + return AddError("Expected '>' or ','"); + } + } + } + + return MakeStructType(members); + } + + TAstNode* ParseStructType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + TAstNode* structType = ParseStructTypeImpl(); + if (structType) { + EXPECT_AND_SKIP_TOKEN('>', nullptr); + } + return structType; + } + + TAstNode* ParseVariantType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + TAstNode* underlyingType = nullptr; + if (Token == TOKEN_IDENTIFIER || Token == TOKEN_ESCAPED_IDENTIFIER) { + underlyingType = ParseStructTypeImpl(); + } else if (IsTypeKeyword(Token) || Token == '(') { + underlyingType = ParseTupleTypeImpl(); + } else { + return AddError("Expected type"); + } + + if (!underlyingType) return nullptr; + + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeVariantType(underlyingType); + } + + TAstNode* ParseEnumType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + TMap<TString, TAstNode*> members; + for (;;) { + TString name; + if (Token == TOKEN_IDENTIFIER || + Token == TOKEN_ESCAPED_IDENTIFIER) + { + name = Identifier; + } else { + return AddError("Expected name"); + } + + if (name.empty()) { + return AddError("Empty name is not allowed"); + } else if (members.contains(name)) { + return AddError("Member name duplication"); + } + + GetNextToken(); // eat member name + members.emplace(std::move(name), MakeVoidType()); + + if (Token == '>') { + break; + } else if (Token == ',') { + GetNextToken(); + } else { + return AddError("Expected '>' or ','"); + } + } + + auto underlyingType = MakeStructType(members); + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeVariantType(underlyingType); + } + + TAstNode* MakeCallableType( + TSmallVec<TAstNode*>& args, size_t optionalArgsCount, + TAstNode* returnType, TStringBuf payload) + { + args[0] = MakeLiteralAtom(TStringBuf("CallableType")); + TSmallVec<TAstNode*> mainSettings; + if (optionalArgsCount || !payload.empty()) { + mainSettings.push_back(optionalArgsCount + ? MakeQuotedAtom(ToString(optionalArgsCount)) + : MakeQuotedLiteralAtom(TStringBuf("0"))); + } + + if (!payload.empty()) { + mainSettings.push_back(MakeQuotedAtom(payload, TNodeFlags::ArbitraryContent)); + } + + args[1] = MakeQuote(MakeList(mainSettings.data(), mainSettings.size())); + + TSmallVec<TAstNode*> returnSettings; + returnSettings.push_back(returnType); + args[2] = MakeQuote(MakeList(returnSettings.data(), returnSettings.size())); + + return MakeList(args.data(), args.size()); + } + + TAstNode* MakeListType(TAstNode* itemType) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("ListType")), + itemType, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeStreamType(TAstNode* itemType) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("StreamType")), + itemType, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeFlowType(TAstNode* itemType) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("FlowType")), + itemType, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeBlockType(TAstNode* itemType) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("BlockType")), + itemType, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeScalarType(TAstNode* itemType) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("ScalarType")), + itemType, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeVariantType(TAstNode* underlyingType) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("VariantType")), + underlyingType, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeDictType(TAstNode* keyType, TAstNode* valueType) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("DictType")), + keyType, + valueType, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeTupleType(TSmallVec<TAstNode*>& items) { + items[0] = MakeLiteralAtom(TStringBuf("TupleType")); + return MakeList(items.data(), items.size()); + } + + TAstNode* MakeStructType(const TMap<TString, TAstNode*>& members) { + TSmallVec<TAstNode*> items; + items.push_back(MakeLiteralAtom(TStringBuf("StructType"))); + + for (const auto& member: members) { + TAstNode* memberType[] = { + MakeQuotedAtom(member.first, TNodeFlags::ArbitraryContent), // name + member.second, // type + }; + items.push_back(MakeQuote(MakeList(memberType, Y_ARRAY_SIZE(memberType)))); + } + + return MakeList(items.data(), items.size()); + } + + TAstNode* ParseResourceType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + if (Token != TOKEN_IDENTIFIER && Token != TOKEN_ESCAPED_IDENTIFIER) { + return AddError("Expected resource tag"); + } + + TStringBuf tag = Identifier; + if (tag.empty()) { + return AddError("Expected non empty resource tag"); + } + + GetNextToken(); // eat tag + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeResourceType(tag); + } + + TAstNode* ParseTaggedType() { + GetNextToken(); // eat keyword + EXPECT_AND_SKIP_TOKEN('<', nullptr); + + auto baseType = ParseType(); + if (!baseType) return nullptr; + + EXPECT_AND_SKIP_TOKEN(',', nullptr); + + if (Token != TOKEN_IDENTIFIER && Token != TOKEN_ESCAPED_IDENTIFIER) { + return AddError("Expected tag of type"); + } + + TStringBuf tag = Identifier; + if (tag.empty()) { + return AddError("Expected non empty tag of type"); + } + + GetNextToken(); // eat tag + EXPECT_AND_SKIP_TOKEN('>', nullptr); + return MakeTaggedType(baseType, tag); + } + + TAstNode* MakeResourceType(TStringBuf tag) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("ResourceType")), + MakeQuotedAtom(tag), + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeVoidType() { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("VoidType")) + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeNullType() { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("NullType")) + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeEmptyListType() { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("EmptyListType")) + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeEmptyDictType() { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("EmptyDictType")) + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeUnitType() { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("UnitType")) + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeGenericType() { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("GenericType")) + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeTaggedType(TAstNode* baseType, TStringBuf tag) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("TaggedType")), + baseType, + MakeQuotedAtom(tag) + }; + + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeDataType(TStringBuf type) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("DataType")), + MakeQuotedAtom(type), + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakePgType(TStringBuf type) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("PgType")), + MakeQuotedAtom(type), + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeDecimalType(TStringBuf precision, TStringBuf scale) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("DataType")), + MakeQuotedAtom(TStringBuf("Decimal")), + MakeQuotedAtom(precision), + MakeQuotedAtom(scale), + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeOptionalType(TAstNode* type) { + TAstNode* items[] = { + MakeLiteralAtom(TStringBuf("OptionalType")), + type, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeAtom(TStringBuf content, ui32 flags = TNodeFlags::Default) { + return TAstNode::NewAtom(Position, content, Pool, flags); + } + + TAstNode* MakeLiteralAtom(TStringBuf content, ui32 flags = TNodeFlags::Default) { + return TAstNode::NewLiteralAtom(Position, content, Pool, flags); + } + + TAstNode* MakeQuote(TAstNode* node) { + TAstNode* items[] = { + &TAstNode::QuoteAtom, + node, + }; + return MakeList(items, Y_ARRAY_SIZE(items)); + } + + TAstNode* MakeQuotedAtom(TStringBuf content, ui32 flags = TNodeFlags::Default) { + return MakeQuote(MakeAtom(content, flags)); + } + + TAstNode* MakeQuotedLiteralAtom(TStringBuf content, ui32 flags = TNodeFlags::Default) { + return MakeQuote(MakeLiteralAtom(content, flags)); + } + + TAstNode* MakeList(TAstNode** children, ui32 count) { + return TAstNode::NewList(Position, children, count, Pool); + } + + char Get() const { + return Str[Index]; + } + + bool AtEnd() const { + return Index >= Str.size(); + } + + void Move() { + if (AtEnd()) return; + + ++Index; + ++Position.Column; + + if (!AtEnd() && Str[Index] == '\n') { + Position.Row++; + Position.Column = 1; + } + } + + TAstNode* AddError(const TString& message) { + Issues.AddIssue(TIssue(TokenBegin, message)); + return nullptr; + } + +private: + TStringBuf Str; + TIssues& Issues; + TPosition TokenBegin, Position; + size_t Index; + int Token; + TString UnescapedIdentifier; + TStringBuf Identifier; + TMemoryPool& Pool; +}; + +////////////////////////////////////////////////////////////////////////////// +// TTypePrinter +////////////////////////////////////////////////////////////////////////////// +class TTypePrinter: public TTypeAnnotationVisitor +{ +public: + TTypePrinter(IOutputStream& out) + : Out_(out) + { + } + +private: + void Visit(const TUnitExprType& type) final { + TopLevel = false; + Y_UNUSED(type); + Out_ << TStringBuf("Unit"); + } + + void Visit(const TMultiExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Multi<"); + const auto& items = type.GetItems(); + for (ui32 i = 0; i < items.size(); ++i) { + if (i) { + Out_ << ','; + } + items[i]->Accept(*this); + } + Out_ << '>'; + } + + void Visit(const TTupleExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Tuple<"); + const auto& items = type.GetItems(); + for (ui32 i = 0; i < items.size(); ++i) { + if (i) { + Out_ << ','; + } + items[i]->Accept(*this); + } + Out_ << '>'; + } + + void Visit(const TStructExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Struct<"); + const auto& items = type.GetItems(); + for (ui32 i = 0; i < items.size(); ++i) { + if (i) { + Out_ << ','; + } + items[i]->Accept(*this); + } + Out_ << '>'; + } + + void Visit(const TItemExprType& type) final { + TopLevel = false; + EscapeArbitraryAtom(type.GetName(), '\'', &Out_); + Out_ << ':'; + type.GetItemType()->Accept(*this); + } + + void Visit(const TListExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("List<"); + type.GetItemType()->Accept(*this); + Out_ << '>'; + } + + void Visit(const TStreamExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Stream<"); + type.GetItemType()->Accept(*this); + Out_ << '>'; + } + + void Visit(const TFlowExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Flow<"); + type.GetItemType()->Accept(*this); + Out_ << '>'; + } + + void Visit(const TBlockExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Block<"); + type.GetItemType()->Accept(*this); + Out_ << '>'; + } + + void Visit(const TScalarExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Scalar<"); + type.GetItemType()->Accept(*this); + Out_ << '>'; + } + + void Visit(const TDataExprType& type) final { + TopLevel = false; + Out_ << type.GetName(); + if (const auto dataExprParamsType = dynamic_cast<const TDataExprParamsType*>(&type)) { + Out_ << '(' << dataExprParamsType->GetParamOne() << ',' << dataExprParamsType->GetParamTwo() << ')'; + } + } + + void Visit(const TPgExprType& type) final { + TopLevel = false; + TStringBuf name = type.GetName(); + if (!name.SkipPrefix("_")) { + Out_ << "pg" << name; + } else { + Out_ << "_pg" << name; + } + } + + void Visit(const TWorldExprType& type) final { + Y_UNUSED(type); + TopLevel = false; + Out_ << TStringBuf("World"); + } + + void Visit(const TOptionalExprType& type) final { + const TTypeAnnotationNode* itemType = type.GetItemType(); + if (TopLevel || itemType->GetKind() == ETypeAnnotationKind::Callable) { + TopLevel = false; + Out_ << TStringBuf("Optional<"); + itemType->Accept(*this); + Out_ << '>'; + } else { + TopLevel = false; + itemType->Accept(*this); + Out_ << '?'; + } + } + + void Visit(const TCallableExprType& type) final { + TopLevel = false; + const auto& args = type.GetArguments(); + ui32 argsCount = type.GetArgumentsSize(); + ui32 optArgsCount = + Min<ui32>(type.GetOptionalArgumentsCount(), argsCount); + + Out_ << TStringBuf("Callable<("); + for (ui32 i = 0; i < argsCount; ++i) { + if (i) { + Out_ << ','; + } + if (i == argsCount - optArgsCount) { + Out_ << '['; + } + const TCallableExprType::TArgumentInfo& argInfo = args[i]; + if (!argInfo.Name.empty()) { + EscapeArbitraryAtom(argInfo.Name, '\'', &Out_); + Out_ << ':'; + } + argInfo.Type->Accept(*this); + if (argInfo.Flags) { + Out_ << TStringBuf("{Flags:"); + if (argInfo.Flags & TArgumentFlags::AutoMap) { + Out_ << TStringBuf("AutoMap"); + } + Out_ << '}'; + } + } + + if (optArgsCount > 0) { + Out_ << ']'; + } + + Out_ << TStringBuf(")->"); + type.GetReturnType()->Accept(*this); + if (!type.GetPayload().empty()) { + Out_ << TStringBuf("{Payload:") << type.GetPayload() << '}'; + } + Out_ << '>'; + } + + void Visit(const TResourceExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Resource<"); + EscapeArbitraryAtom(type.GetTag(), '\'', &Out_); + Out_ << '>'; + } + + void Visit(const TTypeExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Type<"); + type.GetType()->Accept(*this); + Out_ << '>'; + } + + void Visit(const TDictExprType& type) final { + TopLevel = false; + if (type.GetPayloadType()->GetKind() == ETypeAnnotationKind::Void) { + Out_ << TStringBuf("Set<"); + type.GetKeyType()->Accept(*this); + Out_ << '>'; + } else { + Out_ << TStringBuf("Dict<"); + type.GetKeyType()->Accept(*this); + Out_ << ','; + type.GetPayloadType()->Accept(*this); + Out_ << '>'; + } + } + + void Visit(const TVoidExprType& type) final { + Y_UNUSED(type); + TopLevel = false; + Out_ << TStringBuf("Void"); + } + + void Visit(const TNullExprType& type) final { + Y_UNUSED(type); + TopLevel = false; + Out_ << TStringBuf("Null"); + } + + void Visit(const TEmptyListExprType& type) final { + Y_UNUSED(type); + TopLevel = false; + Out_ << TStringBuf("EmptyList"); + } + + void Visit(const TEmptyDictExprType& type) final { + Y_UNUSED(type); + TopLevel = false; + Out_ << TStringBuf("EmptyDict"); + } + + void Visit(const TGenericExprType& type) final { + Y_UNUSED(type); + TopLevel = false; + Out_ << TStringBuf("Generic"); + } + + void Visit(const TTaggedExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Tagged<"); + type.GetBaseType()->Accept(*this); + Out_ << ','; + EscapeArbitraryAtom(type.GetTag(), '\'', &Out_); + Out_ << '>'; + } + + void Visit(const TErrorExprType& type) final { + TopLevel = false; + Out_ << TStringBuf("Error<"); + auto pos = type.GetError().Position; + EscapeArbitraryAtom(pos.File.empty() ? "<main>" : pos.File, '\'', &Out_); + Out_ << ':'; + if (pos) { + Out_ << pos.Row << ':' << pos.Column << ':'; + } + + EscapeArbitraryAtom(type.GetError().GetMessage(), '\'', &Out_); + Out_ << '>'; + } + + void Visit(const TVariantExprType& type) final { + TopLevel = false; + auto underlyingType = type.GetUnderlyingType(); + if (underlyingType->GetKind() == ETypeAnnotationKind::Tuple) { + Out_ << TStringBuf("Variant<"); + auto tupleType = underlyingType->Cast<TTupleExprType>(); + const auto& items = tupleType->GetItems(); + for (ui32 i = 0; i < items.size(); ++i) { + if (i) { + Out_ << ','; + } + items[i]->Accept(*this); + } + } else { + auto srtuctType = underlyingType->Cast<TStructExprType>(); + const auto& items = srtuctType->GetItems(); + bool allVoid = true; + for (ui32 i = 0; i < items.size(); ++i) { + allVoid = allVoid && (items[i]->GetItemType()->GetKind() == ETypeAnnotationKind::Void); + } + + Out_ << (allVoid ? TStringBuf("Enum<") : TStringBuf("Variant<")); + for (ui32 i = 0; i < items.size(); ++i) { + if (i) { + Out_ << ','; + } + + if (allVoid) { + EscapeArbitraryAtom(items[i]->GetName(), '\'', &Out_); + } else { + items[i]->Accept(*this); + } + } + } + + Out_ << '>'; + } + +private: + IOutputStream& Out_; + bool TopLevel = true; +}; + +} // namespace + + +TAstNode* ParseType(TStringBuf str, TMemoryPool& pool, TIssues& issues, + TPosition position /* = TPosition(1, 1) */) +{ + TTypeParser parser(str, issues, position, pool); + return parser.ParseTopLevelType(); +} + +TString FormatType(const TTypeAnnotationNode* typeNode) +{ + TStringStream ss; + TTypePrinter printer(ss); + typeNode->Accept(printer); + return ss.Str(); +} + +} // namespace NYql diff --git a/yql/essentials/ast/yql_type_string.h b/yql/essentials/ast/yql_type_string.h new file mode 100644 index 00000000000..a8f3ec7f648 --- /dev/null +++ b/yql/essentials/ast/yql_type_string.h @@ -0,0 +1,15 @@ +#pragma once + +#include "yql_ast.h" +#include <yql/essentials/public/issue/yql_issue_manager.h> + +namespace NYql { + +class TTypeAnnotationNode; + +TAstNode* ParseType(TStringBuf str, TMemoryPool& pool, TIssues& issues, + TPosition position = {1, 1}); + +TString FormatType(const TTypeAnnotationNode* typeNode); + +} // namespace NYql diff --git a/yql/essentials/ast/yql_type_string_ut.cpp b/yql/essentials/ast/yql_type_string_ut.cpp new file mode 100644 index 00000000000..9b6db5f11e3 --- /dev/null +++ b/yql/essentials/ast/yql_type_string_ut.cpp @@ -0,0 +1,686 @@ +#include "yql_type_string.h" +#include "yql_expr.h" + +#include <library/cpp/testing/unittest/registar.h> + + +using namespace NYql; + +Y_UNIT_TEST_SUITE(TTypeString) +{ + void TestFail(const TStringBuf& prog, ui32 column, const TStringBuf& expectedError) { + TMemoryPool pool(4096); + TIssues errors; + auto res = ParseType(prog, pool, errors); + UNIT_ASSERT(res == nullptr); + UNIT_ASSERT(!errors.Empty()); + errors.PrintWithProgramTo(Cerr, "-memory-", TString(prog)); + UNIT_ASSERT_STRINGS_EQUAL(errors.begin()->GetMessage(), expectedError); + UNIT_ASSERT_VALUES_EQUAL(errors.begin()->Position.Column, column); + } + + void TestOk(const TStringBuf& prog, const TStringBuf& expectedType) { + TMemoryPool pool(4096); + TIssues errors; + auto res = ParseType(prog, pool, errors); + if (!res) { + errors.PrintWithProgramTo(Cerr, "-memory-", TString(prog)); + UNIT_FAIL(TStringBuilder() << "Parsing failed:" << Endl << prog); + } + UNIT_ASSERT_STRINGS_EQUAL(res->ToString(), expectedType); + } + + Y_UNIT_TEST(ParseEmpty) { + TestFail("", 1, "Expected type"); + } + + Y_UNIT_TEST(ParseDataTypes) { + TestOk("String", "(DataType 'String)"); + TestOk("Bool", "(DataType 'Bool)"); + TestOk("Uint8", "(DataType 'Uint8)"); + TestOk("Int8", "(DataType 'Int8)"); + TestOk("Uint16", "(DataType 'Uint16)"); + TestOk("Int16", "(DataType 'Int16)"); + TestOk("Int32", "(DataType 'Int32)"); + TestOk("Uint32", "(DataType 'Uint32)"); + TestOk("Int64", "(DataType 'Int64)"); + TestOk("Uint64", "(DataType 'Uint64)"); + TestOk("Float", "(DataType 'Float)"); + TestOk("Double", "(DataType 'Double)"); + TestOk("Yson", "(DataType 'Yson)"); + TestOk("Utf8", "(DataType 'Utf8)"); + TestOk("Json", "(DataType 'Json)"); + TestOk("Date", "(DataType 'Date)"); + TestOk("Datetime", "(DataType 'Datetime)"); + TestOk("Timestamp", "(DataType 'Timestamp)"); + TestOk("Interval", "(DataType 'Interval)"); + TestOk("TzDate", "(DataType 'TzDate)"); + TestOk("TzDatetime", "(DataType 'TzDatetime)"); + TestOk("TzTimestamp", "(DataType 'TzTimestamp)"); + TestOk("Uuid", "(DataType 'Uuid)"); + TestOk("Decimal(10,2)", "(DataType 'Decimal '10 '2)"); + } + + Y_UNIT_TEST(Multiline) { + TestOk(R"(Struct + < + name : String + , + age : Uint32 + >)", "(StructType " + "'('\"age\" (DataType 'Uint32)) " + "'('\"name\" (DataType 'String)))"); + } + + Y_UNIT_TEST(ParseNoArgsWithStringResult) { + TestOk("()->String", "(CallableType '() '((DataType 'String)))"); + TestOk("()->Utf8", "(CallableType '() '((DataType 'Utf8)))"); + } + + Y_UNIT_TEST(ParseNoArgsWithOptionalStringResult) { + TestOk("()->String?", + "(CallableType '() '((OptionalType (DataType 'String))))"); + TestOk("()->Yson?", + "(CallableType '() '((OptionalType (DataType 'Yson))))"); + } + + Y_UNIT_TEST(ParseOneArgWithDoubleResult) { + TestOk("(Int32)->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Int32))" + ")"); + TestOk("(Yson)->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Yson))" + ")"); + TestOk("(Utf8)->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Utf8))" + ")"); + } + + Y_UNIT_TEST(ParseTwoArgsWithOptionalByteResult) { + TestOk("(Int32?, String)->Uint8?", + "(CallableType '() '((OptionalType (DataType 'Uint8))) " + "'((OptionalType (DataType 'Int32))) " + "'((DataType 'String))" + ")"); + } + + Y_UNIT_TEST(ParseWithEmptyOptionalArgsStringResult) { + TestOk("([])->String", "(CallableType '() '((DataType 'String)))"); + } + + Y_UNIT_TEST(ParseWithOneOptionalArgDoubleResult) { + TestOk("([Int32?])->Double", + "(CallableType '('1) '((DataType 'Double)) " + "'((OptionalType (DataType 'Int32)))" + ")"); + } + + Y_UNIT_TEST(ParseOneReqAndOneOptionalArgsWithDoubleResult) { + TestOk("(String,[Int32?])->Double", + "(CallableType '('1) '((DataType 'Double)) " + "'((DataType 'String)) " + "'((OptionalType (DataType 'Int32)))" + ")"); + } + + Y_UNIT_TEST(ParseOneReqAndTwoOptionalArgsWithDoubleResult) { + TestOk("(String,[Int32?, Uint8?])->Double", + "(CallableType '('2) '((DataType 'Double)) " + "'((DataType 'String)) " + "'((OptionalType (DataType 'Int32))) " + "'((OptionalType (DataType 'Uint8)))" + ")"); + } + + Y_UNIT_TEST(ParseCallableArgWithDoubleResult) { + TestOk("(()->Uint8)->Double", + "(CallableType '() '((DataType 'Double)) " + "'((CallableType '() '((DataType 'Uint8))))" + ")"); + } + + Y_UNIT_TEST(ParseCallableOptionalArgWithDoubleResult) { + TestOk("([Optional<()->Uint8>])->Double", + "(CallableType '('1) '((DataType 'Double)) " + "'((OptionalType (CallableType '() '((DataType 'Uint8)))))" + ")"); + } + + Y_UNIT_TEST(ParseOptionalCallableArgWithDoubleResult) { + TestOk("(Optional<()->Uint8>)->Double", + "(CallableType '() '((DataType 'Double)) " + "'((OptionalType (CallableType '() '((DataType 'Uint8)))))" + ")"); + } + + Y_UNIT_TEST(ParseCallableWithNamedArgs) { + TestOk("(a:Uint8)->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Uint8) 'a)" + ")"); + TestOk("(List:Uint8)->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Uint8) 'List)" + ")"); + TestOk("('Dict':Uint8)->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Uint8) '\"Dict\")" + ")"); + TestOk("(a:Uint8,[b:Int32?])->Double", + "(CallableType '('1) '((DataType 'Double)) " + "'((DataType 'Uint8) 'a) " + "'((OptionalType (DataType 'Int32)) 'b)" + ")"); + TestOk("(Uint8,[b:Int32?])->Double", + "(CallableType '('1) '((DataType 'Double)) " + "'((DataType 'Uint8)) " + "'((OptionalType (DataType 'Int32)) 'b)" + ")"); + } + + Y_UNIT_TEST(ParseCallableWithArgFlags) { + TestOk("(Int32{Flags:AutoMap})->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Int32) '\"\" '1)" + ")"); + TestOk("(Int32?{Flags:AutoMap})->Double", + "(CallableType '() '((DataType 'Double)) " + "'((OptionalType (DataType 'Int32)) '\"\" '1)" + ")"); + TestOk("(x:Int32{Flags:AutoMap})->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Int32) 'x '1)" + ")"); + TestOk("(x:Int32{Flags:AutoMap}, [y:Uint32?{Flags: AutoMap}])->Double", + "(CallableType '('1) '((DataType 'Double)) " + "'((DataType 'Int32) 'x '1) " + "'((OptionalType (DataType 'Uint32)) 'y '1)" + ")"); + TestOk("(x:Int32{Flags: AutoMap | AutoMap})->Double", + "(CallableType '() '((DataType 'Double)) " + "'((DataType 'Int32) 'x '1)" + ")"); + } + + Y_UNIT_TEST(ParseCallableWithPayload) { + TestOk("(Int32)->Double{Payload:MyFunction}", + "(CallableType '('0 '\"MyFunction\") '((DataType 'Double)) " + "'((DataType 'Int32))" + ")"); + } + + Y_UNIT_TEST(ParseOptional) { + TestOk("Uint32?", "(OptionalType (DataType 'Uint32))"); + TestOk("Optional<Uint32>", "(OptionalType (DataType 'Uint32))"); + TestOk("Uint32??", "(OptionalType (OptionalType (DataType 'Uint32)))"); + TestOk("Optional<Uint32>?", "(OptionalType (OptionalType (DataType 'Uint32)))"); + TestOk("Optional<Uint32?>", "(OptionalType (OptionalType (DataType 'Uint32)))"); + TestOk("Optional<Optional<Uint32>>", "(OptionalType (OptionalType (DataType 'Uint32)))"); + } + + Y_UNIT_TEST(ParseCallableComplete) { + TestFail("(Uint32)->", 11, "Expected type"); + TestFail("(,)->", 2, "Expected type or argument name"); + TestFail("(Int32 Int32)->Int32", 8, "Expected ','"); + TestFail("([],)->Uint32", 4, "Expected ')'"); + TestFail("([)->Uint32", 3, "Expected ']'"); + TestFail("(])->Uint32", 2, "Expected ')'"); + TestFail("([,)->Uint32", 3, "Expected type or argument name"); + TestFail("([,])->Uint32", 3, "Expected type or argument name"); + TestFail("(->Uint32", 2, "Expected type or argument name"); + TestFail("([Uint32],Uint8)->Uint32", 9, "Optionals are only allowed in the optional arguments"); + TestFail("([Uint32?],Uint8)->Uint32", 11, "Expected ')'"); + TestFail("Callable<()>", 12, "Expected '->' after arguments"); + TestFail("Callable<()->", 14, "Expected type"); + TestFail("Callable<()->Uint32", 20, "Expected '>'"); + TestFail("(:Uint32)->Uint32", 2, "Expected non empty argument name"); + TestFail("(a:)->Uint32", 4, "Expected type"); + TestFail("(:)->Uint32", 2, "Expected non empty argument name"); + TestFail("(a:Uint32,Uint32)->Uint32", 11, "Expected named argument, because of previous argument(s) was named"); + TestFail("(Uint32{)->Uint32", 9, "Expected Flags field"); + TestFail("(Uint32})->Uint32", 8, "Expected ','"); + TestFail("(Uint32{})->Uint32", 9, "Expected Flags field"); + TestFail("(Uint32{Flags})->Uint32", 14, "Expected ':'"); + TestFail("(Uint32{Flags:})->Uint32", 15, "Expected flag name"); + TestFail("(Uint32{Flags:Map})->Uint32", 15, "Unknown flag name: Map"); + TestFail("(Uint32{Flags:|})->Uint32", 15, "Expected flag name"); + TestFail("(Uint32{Flags:AutoMap|})->Uint32", 23, "Expected flag name"); + TestFail("(Uint32{NonFlags:AutoMap})->Uint32", 9, "Expected Flags field"); + TestFail("(Uint32)->Uint32{", 18, "Expected Payload field"); + TestFail("(Uint32)->Uint32{}", 18, "Expected Payload field"); + TestFail("(Uint32)->Uint32}", 17, "Expected end of string"); + TestFail("(Uint32)->Uint32{Payload}", 25, "Expected ':'"); + TestFail("(Uint32)->Uint32{Payload:}", 26, "Expected payload data"); + } + + Y_UNIT_TEST(ParseCallableWithKeyword) { + TestOk("(Callable<()->String>) -> Callable<()->Uint32>", + "(CallableType '() '((CallableType '() '((DataType 'Uint32)))) " + "'((CallableType '() '((DataType 'String))))" + ")"); + } + + Y_UNIT_TEST(ParseListOfDataType) { + TestOk("(List<String>)->String", + "(CallableType '() '((DataType 'String)) " + "'((ListType (DataType 'String)))" + ")"); + } + + Y_UNIT_TEST(ParseStreamOfDataType) { + TestOk("(Stream<String>)->String", + "(CallableType '() '((DataType 'String)) " + "'((StreamType (DataType 'String)))" + ")"); + } + + Y_UNIT_TEST(ParseFlowOfDataType) { + TestOk("(Flow<String>)->String", + "(CallableType '() '((DataType 'String)) " + "'((FlowType (DataType 'String)))" + ")"); + } + + Y_UNIT_TEST(ParseVariantType) { + TestOk("Variant<String>", + "(VariantType (TupleType " + "(DataType 'String)" + "))"); + TestOk("Variant<String, Uint8>", + "(VariantType (TupleType " + "(DataType 'String) " + "(DataType 'Uint8)" + "))"); + TestOk("Variant<Name: String, Age: Int32>", + "(VariantType (StructType " + "'('\"Age\" (DataType 'Int32)) " + "'('\"Name\" (DataType 'String))" + "))"); + TestOk("Variant<'Some Name': String, 'Age': Int32>", + "(VariantType (StructType " + "'('\"Age\" (DataType 'Int32)) " + "'('\"Some Name\" (DataType 'String))" + "))"); + } + + Y_UNIT_TEST(ParseEnumType) { + TestOk("Enum<Name, Age>", + "(VariantType (StructType " + "'('\"Age\" (VoidType)) " + "'('\"Name\" (VoidType))" + "))"); + TestOk("Enum<'Some Name', 'Age'>", + "(VariantType (StructType " + "'('\"Age\" (VoidType)) " + "'('\"Some Name\" (VoidType))" + "))"); + } + + Y_UNIT_TEST(ParseListAsReturnType) { + TestOk("(String, String)->List<String>", + "(CallableType '() '((ListType (DataType 'String))) " + "'((DataType 'String)) " + "'((DataType 'String))" + ")"); + } + + Y_UNIT_TEST(ParseListOfOptionalDataType) { + TestOk("(List<String?>)->String", + "(CallableType '() '((DataType 'String)) " + "'((ListType " + "(OptionalType (DataType 'String))" + "))" + ")"); + } + + Y_UNIT_TEST(ParseOptionalListOfDataType) { + TestOk("(List<String>?)->String", + "(CallableType '() '((DataType 'String)) " + "'((OptionalType " + "(ListType (DataType 'String))" + "))" + ")"); + } + + Y_UNIT_TEST(ParseListOfListType) { + TestOk("(List<List<Uint32>>)->Uint32", + "(CallableType '() '((DataType 'Uint32)) " + "'((ListType " + "(ListType (DataType 'Uint32))" + "))" + ")"); + } + + Y_UNIT_TEST(ParseDictOfDataTypes) { + TestOk("(Dict<String, Uint32>)->Uint32", + "(CallableType '() '((DataType 'Uint32)) " + "'((DictType " + "(DataType 'String) " + "(DataType 'Uint32)" + "))" + ")"); + } + + Y_UNIT_TEST(ParseSetOfDataTypes) { + TestOk("(Set<String>)->Uint32", + "(CallableType '() '((DataType 'Uint32)) " + "'((DictType " + "(DataType 'String) " + "(VoidType)" + "))" + ")"); + } + + Y_UNIT_TEST(ParseListComplete) { + TestFail("(List<>)->Uint32", 7, "Expected type"); + TestFail("(List<Uint32,>)->Uint32", 13, "Expected '>'"); + } + + Y_UNIT_TEST(ParseVariantComplete) { + TestFail("Variant<>", 9, "Expected type"); + TestFail("Variant<Uint32,>", 16, "Expected type"); + TestFail("Variant<Uint32", 15, "Expected '>' or ','"); + + TestFail("Variant<name:>", 14, "Expected type"); + TestFail("Variant<name:String,>", 21, "Expected struct member name"); + TestFail("Variant<name:String", 20, "Expected '>' or ','"); + } + + Y_UNIT_TEST(ParseDictOfDictTypes) { + TestOk("(Dict<String, Dict<Uint32, Uint32>>)->Uint32", + "(CallableType '() '((DataType 'Uint32)) " + "'((DictType " + "(DataType 'String) " + "(DictType (DataType 'Uint32) (DataType 'Uint32))" + "))" + ")"); + } + + Y_UNIT_TEST(ParseDictComplete) { + TestFail("(Dict<>)->Uint32", 7, "Expected type"); + TestFail("(Dict<Uint32>)->Uint32", 13, "Expected ','"); + TestFail("(Dict<Uint32,>)->Uint32", 14, "Expected type"); + TestFail("(Dict<Uint32, String)->Uint32", 21, "Expected '>'"); + } + + Y_UNIT_TEST(ParseTupleOfDataTypes) { + TestOk("(Tuple<String, Uint32, Uint8>)->Uint32", + "(CallableType '() '((DataType 'Uint32)) " + "'((TupleType " + "(DataType 'String) " + "(DataType 'Uint32) " + "(DataType 'Uint8)" + "))" + ")"); + } + + Y_UNIT_TEST(ParseTupleComplete) { + TestFail("(Tuple<Uint32,>)->Uint32", 15, "Expected type"); + TestFail("(Tuple<Uint32)->Uint32", 14, "Expected '>' or ','"); + } + + Y_UNIT_TEST(ParseStructOfDataTypes) { + TestOk("(Struct<Name: String, Age: Uint32, Male: Bool>)->Uint32", + "(CallableType '() '((DataType 'Uint32)) " + "'((StructType " + "'('\"Age\" (DataType 'Uint32)) " + "'('\"Male\" (DataType 'Bool)) " + "'('\"Name\" (DataType 'String))" + "))" + ")"); + } + + Y_UNIT_TEST(ParseStructWithEscaping) { + TestOk("Struct<'My\\tName': String, 'My Age': Uint32>", + "(StructType " + "'('\"My\\tName\" (DataType 'String)) " + "'('\"My Age\" (DataType 'Uint32))" + ")"); + } + + Y_UNIT_TEST(ParseStructComplete) { + TestFail("(Struct<name>)->Uint32", 13, "Expected ':'"); + TestFail("(Struct<name:>)->Uint32", 14, "Expected type"); + TestFail("(Struct<name:String,>)->Uint32", 21, "Expected struct member name"); + TestFail("(Struct<name:String)->Uint32", 20, "Expected '>' or ','"); + } + + Y_UNIT_TEST(ParseResource) { + TestOk("Resource<aaa>", "(ResourceType 'aaa)"); + TestOk("(Resource<aaa>?)->Resource<bbb>", + "(CallableType '() '((ResourceType 'bbb)) " + "'((OptionalType (ResourceType 'aaa)))" + ")"); + } + + Y_UNIT_TEST(ParseVoid) { + TestOk("Void", "(VoidType)"); + TestOk("Void?", "(OptionalType (VoidType))"); + TestOk("(Void?)->Void", + "(CallableType '() '((VoidType)) " + "'((OptionalType (VoidType)))" + ")"); + } + + Y_UNIT_TEST(ParseNull) { + TestOk("Null", "(NullType)"); + TestOk("Null?", "(OptionalType (NullType))"); + TestOk("(Null?)->Null", + "(CallableType '() '((NullType)) " + "'((OptionalType (NullType)))" + ")"); + } + + Y_UNIT_TEST(ParseEmptyList) { + TestOk("EmptyList", "(EmptyListType)"); + TestOk("EmptyList?", "(OptionalType (EmptyListType))"); + TestOk("(EmptyList?)->EmptyList", + "(CallableType '() '((EmptyListType)) " + "'((OptionalType (EmptyListType)))" + ")"); + } + + Y_UNIT_TEST(ParseEmptyDict) { + TestOk("EmptyDict", "(EmptyDictType)"); + TestOk("EmptyDict?", "(OptionalType (EmptyDictType))"); + TestOk("(EmptyDict?)->EmptyDict", + "(CallableType '() '((EmptyDictType)) " + "'((OptionalType (EmptyDictType)))" + ")"); + } + + Y_UNIT_TEST(UnknownType) { + TestFail("(Yson2)->String", 2, "Unknown type: 'Yson2'"); + TestFail("()->", 5, "Expected type"); + } + + Y_UNIT_TEST(ParseTagged) { + TestOk("Tagged<Uint32, IdTag>", "(TaggedType (DataType 'Uint32) 'IdTag)"); + } + + Y_UNIT_TEST(ParseEmptyTuple) { + TestOk("Tuple<>", "(TupleType)"); + } + + Y_UNIT_TEST(ParseEmptyStruct) { + TestOk("Struct<>", "(StructType)"); + } + + void TestFormat(const TString& yql, const TString& expectedTypeStr) { + TMemoryPool pool(4096); + + TAstParseResult astRes = ParseAst(yql, &pool); + if (!astRes.IsOk()) { + astRes.Issues.PrintWithProgramTo(Cerr, "-memory-", yql); + UNIT_FAIL("Can't parse yql"); + } + + TExprContext ctx; + const TTypeAnnotationNode* type = CompileTypeAnnotation(*astRes.Root->GetChild(0), ctx); + if (!type) { + ctx.IssueManager.GetIssues().PrintWithProgramTo(Cerr, "-memory-", yql); + UNIT_FAIL("Can't compile types"); + } + + TString typeStr = FormatType(type); + UNIT_ASSERT_STRINGS_EQUAL(typeStr, expectedTypeStr); + } + + Y_UNIT_TEST(FormatUnit) { + TestFormat("(Unit)", "Unit"); + } + + Y_UNIT_TEST(FormatTuple) { + TestFormat("((Tuple " + " (Data Int32) " + " (Data Bool) " + " (Data String)" + "))", + "Tuple<Int32,Bool,String>"); + } + + Y_UNIT_TEST(FormatDataStruct) { + TestFormat("((Struct " + " (Item Name (Data String))" + " (Item Age (Data Uint32))" + " (Item Male (Data Bool))" + "))", + "Struct<'Age':Uint32,'Male':Bool,'Name':String>"); + } + + Y_UNIT_TEST(FormatDecimal) { + TestFormat("((Data Decimal 10 3))", "Decimal(10,3)"); + } + + Y_UNIT_TEST(FormatList) { + TestFormat("((List (Data String)))", "List<String>"); + } + + Y_UNIT_TEST(FormatStream) { + TestFormat("((Stream (Data String)))", "Stream<String>"); + } + + Y_UNIT_TEST(FormatFlow) { + TestFormat("((Flow (Data String)))", "Flow<String>"); + } + + Y_UNIT_TEST(FormatBlock) { + TestFormat("((Block (Data String)))", "Block<String>"); + } + + Y_UNIT_TEST(FormatScalar) { + TestFormat("((Scalar (Data String)))", "Scalar<String>"); + } + + Y_UNIT_TEST(FormatOptional) { + TestFormat("((Optional (Data Uint32)))", "Optional<Uint32>"); + TestFormat("((List (Optional (Data Uint32))))", "List<Uint32?>"); + } + + Y_UNIT_TEST(FormatVariant) { + TestFormat("((Variant (Tuple (Data String))))", "Variant<String>"); + } + + Y_UNIT_TEST(FormatEnum) { + TestFormat("((Variant (Struct (Item a Void) (Item b Void) )))", "Enum<'a','b'>"); + } + + Y_UNIT_TEST(FormatDict) { + TestFormat("((Dict " + " (Data String)" + " (Data Uint32)" + "))", + "Dict<String,Uint32>"); + } + + Y_UNIT_TEST(FormatSet) { + TestFormat("((Dict " + " (Data String)" + " Void" + "))", + "Set<String>"); + } + + Y_UNIT_TEST(FormatCallable) { + TestFormat("((Callable () " + " ((Data String))" + " ((Data Uint32))" + " ((Optional (Data Uint8)))" + "))", + "Callable<(Uint32,Uint8?)->String>"); + TestFormat("((Callable (1) " + " ((Data String))" + " ((Data Uint32))" + " ((Optional (Data Uint8)))" + "))", + "Callable<(Uint32,[Uint8?])->String>"); + TestFormat("((Callable (2) " + " ((Data String))" + " ((Optional (Data Uint32)))" + " ((Optional (Data Uint8)))" + "))", + "Callable<([Uint32?,Uint8?])->String>"); + } + + Y_UNIT_TEST(FormatOptionalCallable) { + TestFormat("((Optional (Callable () " + " ((Data String))" + " ((Optional (Data Uint8)))" + ")))", + "Optional<Callable<(Uint8?)->String>>"); + + TestFormat("((Optional (Optional (Callable () " + " ((Data String))" + " ((Optional (Data Uint8)))" + "))))", + "Optional<Optional<Callable<(Uint8?)->String>>>"); + } + + Y_UNIT_TEST(FormatCallableWithNamedArgs) { + TestFormat("((Callable () " + " ((Data String))" + " ((Data Uint32) x)" + " ((Data Uint8) y)" + "))", + "Callable<('x':Uint32,'y':Uint8)->String>"); + TestFormat("((Callable () " + " ((Data String))" + " ((Optional (Data Uint8)) a 1)" + "))", + "Callable<('a':Uint8?{Flags:AutoMap})->String>"); + } + + Y_UNIT_TEST(FormatCallableWithPayload) { + TestFormat("((Callable (0 MyFunction) " + " ((Data String))" + " ((Optional (Data Uint8)))" + "))", + "Callable<(Uint8?)->String{Payload:MyFunction}>"); + } + + Y_UNIT_TEST(FormatResource) { + TestFormat("((Resource aaa))", "Resource<'aaa'>"); + TestFormat("((Resource \"a b\"))", "Resource<'a b'>"); + TestFormat("((Resource \"a\\t\\n\\x01b\"))", "Resource<'a\\t\\n\\x01b'>"); + TestFormat("((Optional (Resource aaa)))", "Optional<Resource<'aaa'>>"); + } + + Y_UNIT_TEST(FormatTagged) { + TestFormat("((Tagged (Data String) aaa))", "Tagged<String,'aaa'>"); + TestFormat("((Tagged (Data String) \"a b\"))", "Tagged<String,'a b'>"); + TestFormat("((Tagged (Data String) \"a\\t\\n\\x01b\"))", "Tagged<String,'a\\t\\n\\x01b'>"); + } + + Y_UNIT_TEST(FormatPg) { + TestFormat("((Pg int4))", "pgint4"); + TestFormat("((Pg _int4))", "_pgint4"); + } + + Y_UNIT_TEST(FormatOptionalPg) { + TestFormat("((Optional (Pg int4)))", "Optional<pgint4>"); + TestFormat("((Optional (Pg _int4)))", "Optional<_pgint4>"); + } +} |