diff options
author | vvvv <[email protected]> | 2024-11-26 12:00:01 +0300 |
---|---|---|
committer | vvvv <[email protected]> | 2024-11-26 12:11:55 +0300 |
commit | 376f6f42244428b36cde31eaa8c7e89a90d7fd1b (patch) | |
tree | ae023e0a2285cfbf8ad4441f6e1a146c2dd1383f | |
parent | d7c900fd2ee9f48165550376be3c798be3299a59 (diff) |
Moved yql/public/purecalc YQL-19206
init
commit_hash:abf729827c312980464da21824f86ea1defe094c
107 files changed, 10946 insertions, 7 deletions
diff --git a/yql/essentials/public/purecalc/common/compile_mkql.cpp b/yql/essentials/public/purecalc/common/compile_mkql.cpp new file mode 100644 index 00000000000..743447ada9c --- /dev/null +++ b/yql/essentials/public/purecalc/common/compile_mkql.cpp @@ -0,0 +1,116 @@ +#include "compile_mkql.h" + +#include <yql/essentials/providers/common/mkql/yql_provider_mkql.h> +#include <yql/essentials/providers/common/mkql/yql_type_mkql.h> +#include <yql/essentials/core/yql_user_data_storage.h> +#include <yql/essentials/public/purecalc/common/names.h> + +#include <util/stream/file.h> + +namespace NYql::NPureCalc { + +namespace { + +NCommon::IMkqlCallableCompiler::TCompiler MakeSelfCallableCompiler() { + return [](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { + MKQL_ENSURE(node.ChildrenSize() == 1, "Self takes exactly 1 argument"); + const auto* argument = node.Child(0); + MKQL_ENSURE(argument->IsAtom(), "Self argument must be atom"); + ui32 inputIndex = 0; + MKQL_ENSURE(TryFromString(argument->Content(), inputIndex), "Self argument must be UI32"); + auto type = NCommon::BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); + NKikimr::NMiniKQL::TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), node.Content(), type); + call.Add(ctx.ProgramBuilder.NewDataLiteral<ui32>(inputIndex)); + return NKikimr::NMiniKQL::TRuntimeNode(call.Build(), false); + }; +} + +NCommon::IMkqlCallableCompiler::TCompiler MakeFilePathCallableCompiler(const TUserDataTable& userData) { + return [&](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { + const TString name(node.Child(0)->Content()); + auto block = TUserDataStorage::FindUserDataBlock(userData, TUserDataKey::File(name)); + if (!block) { + auto blockKey = TUserDataKey::File(GetDefaultFilePrefix() + name); + block = TUserDataStorage::FindUserDataBlock(userData, blockKey); + } + MKQL_ENSURE(block, "file not found: " << name); + MKQL_ENSURE(block->Type == EUserDataType::PATH, + "FilePath not supported for non-filesystem user data, name: " + << name << ", block type: " << block->Type); + return ctx.ProgramBuilder.NewDataLiteral<NKikimr::NUdf::EDataSlot::String>(block->Data); + }; +} + +NCommon::IMkqlCallableCompiler::TCompiler MakeFileContentCallableCompiler(const TUserDataTable& userData) { + return [&](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { + const TString name(node.Child(0)->Content()); + auto block = TUserDataStorage::FindUserDataBlock(userData, TUserDataKey::File(name)); + if (!block) { + auto blockKey = TUserDataKey::File(GetDefaultFilePrefix() + name); + block = TUserDataStorage::FindUserDataBlock(userData, blockKey); + } + MKQL_ENSURE(block, "file not found: " << name); + if (block->Type == EUserDataType::PATH) { + auto content = TFileInput(block->Data).ReadAll(); + return ctx.ProgramBuilder.NewDataLiteral<NKikimr::NUdf::EDataSlot::String>(content); + } else if (block->Type == EUserDataType::RAW_INLINE_DATA) { + return ctx.ProgramBuilder.NewDataLiteral<NKikimr::NUdf::EDataSlot::String>(block->Data); + } else { + // TODO support EUserDataType::URL + MKQL_ENSURE(false, "user data blocks of type URL are not supported by FileContent: " << name); + Y_UNREACHABLE(); + } + }; +} + +NCommon::IMkqlCallableCompiler::TCompiler MakeFolderPathCallableCompiler(const TUserDataTable& userData) { + return [&](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { + const TString name(node.Child(0)->Content()); + auto folderName = TUserDataStorage::MakeFolderName(name); + TMaybe<TString> folderPath; + for (const auto& x : userData) { + if (!x.first.Alias().StartsWith(folderName)) { + continue; + } + + MKQL_ENSURE(x.second.Type == EUserDataType::PATH, + "FilePath not supported for non-file data block, name: " + << x.first.Alias() << ", block type: " << x.second.Type); + + auto pathPrefixLength = x.second.Data.size() - (x.first.Alias().size() - folderName.size()); + auto newFolderPath = x.second.Data.substr(0, pathPrefixLength); + if (!folderPath) { + folderPath = newFolderPath; + } else { + MKQL_ENSURE(*folderPath == newFolderPath, + "file " << x.second.Data << " is out of directory " << *folderPath); + } + } + return ctx.ProgramBuilder.NewDataLiteral<NKikimr::NUdf::EDataSlot::String>(*folderPath); + }; +} + +} + +NKikimr::NMiniKQL::TRuntimeNode CompileMkql(const TExprNode::TPtr& exprRoot, TExprContext& exprCtx, + const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, const NKikimr::NMiniKQL::TTypeEnvironment& env, const TUserDataTable& userData) +{ + NCommon::TMkqlCommonCallableCompiler compiler; + + compiler.AddCallable(PurecalcInputCallableName, MakeSelfCallableCompiler()); + compiler.AddCallable(PurecalcBlockInputCallableName, MakeSelfCallableCompiler()); + compiler.OverrideCallable("FileContent", MakeFileContentCallableCompiler(userData)); + compiler.OverrideCallable("FilePath", MakeFilePathCallableCompiler(userData)); + compiler.OverrideCallable("FolderPath", MakeFolderPathCallableCompiler(userData)); + + // Prepare build context + + NKikimr::NMiniKQL::TProgramBuilder pgmBuilder(env, funcRegistry); + NCommon::TMkqlBuildContext buildCtx(compiler, pgmBuilder, exprCtx); + + // Build the root MKQL node + + return NCommon::MkqlBuildExpr(*exprRoot, buildCtx); +} + +} // NYql::NPureCalc diff --git a/yql/essentials/public/purecalc/common/compile_mkql.h b/yql/essentials/public/purecalc/common/compile_mkql.h new file mode 100644 index 00000000000..0b6c16aef52 --- /dev/null +++ b/yql/essentials/public/purecalc/common/compile_mkql.h @@ -0,0 +1,17 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/interface.h> +#include <yql/essentials/minikql/mkql_node.h> +#include <yql/essentials/ast/yql_expr.h> +#include <yql/essentials/core/yql_user_data.h> + +namespace NYql { + namespace NPureCalc { + /** + * Compile expr to mkql byte-code + */ + + NKikimr::NMiniKQL::TRuntimeNode CompileMkql(const TExprNode::TPtr& exprRoot, TExprContext& exprCtx, + const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, const NKikimr::NMiniKQL::TTypeEnvironment& env, const TUserDataTable& userData); + } +} diff --git a/yql/essentials/public/purecalc/common/fwd.cpp b/yql/essentials/public/purecalc/common/fwd.cpp new file mode 100644 index 00000000000..4214b6df83e --- /dev/null +++ b/yql/essentials/public/purecalc/common/fwd.cpp @@ -0,0 +1 @@ +#include "fwd.h" diff --git a/yql/essentials/public/purecalc/common/fwd.h b/yql/essentials/public/purecalc/common/fwd.h new file mode 100644 index 00000000000..22df90a6b29 --- /dev/null +++ b/yql/essentials/public/purecalc/common/fwd.h @@ -0,0 +1,56 @@ +#pragma once + +#include <util/generic/fwd.h> +#include <memory> + +namespace NYql::NPureCalc { + class TCompileError; + + template <typename> + class IConsumer; + + template <typename> + class IStream; + + class IProgramFactory; + + class IWorkerFactory; + + class IPullStreamWorkerFactory; + + class IPullListWorkerFactory; + + class IPushStreamWorkerFactory; + + class IWorker; + + class IPullStreamWorker; + + class IPullListWorker; + + class IPushStreamWorker; + + class TInputSpecBase; + + class TOutputSpecBase; + + class IProgram; + + template <typename, typename, typename> + class TProgramCommon; + + template <typename, typename> + class TPullStreamProgram; + + template <typename, typename> + class TPullListProgram; + + template <typename, typename> + class TPushStreamProgram; + + using IProgramFactoryPtr = TIntrusivePtr<IProgramFactory>; + using IWorkerFactoryPtr = std::shared_ptr<IWorkerFactory>; + using IPullStreamWorkerFactoryPtr = std::shared_ptr<IPullStreamWorkerFactory>; + using IPullListWorkerFactoryPtr = std::shared_ptr<IPullListWorkerFactory>; + using IPushStreamWorkerFactoryPtr = std::shared_ptr<IPushStreamWorkerFactory>; +} diff --git a/yql/essentials/public/purecalc/common/inspect_input.cpp b/yql/essentials/public/purecalc/common/inspect_input.cpp new file mode 100644 index 00000000000..9ca56da5dec --- /dev/null +++ b/yql/essentials/public/purecalc/common/inspect_input.cpp @@ -0,0 +1,33 @@ +#include "inspect_input.h" + +#include <yql/essentials/core/yql_expr_type_annotation.h> + +namespace NYql::NPureCalc { + bool TryFetchInputIndexFromSelf(const TExprNode& node, TExprContext& ctx, ui32 inputsCount, ui32& result) { + TIssueScopeGuard issueSope(ctx.IssueManager, [&]() { + return MakeIntrusive<TIssue>(ctx.GetPosition(node.Pos()), TStringBuilder() << "At function: " << node.Content()); + }); + + if (!EnsureArgsCount(node, 1, ctx)) { + return false; + } + + if (!EnsureAtom(*node.Child(0), ctx)) { + return false; + } + + if (!TryFromString(node.Child(0)->Content(), result)) { + auto message = TStringBuilder() << "Index " << TString{node.Child(0)->Content()}.Quote() << " isn't UI32"; + ctx.AddError(TIssue(ctx.GetPosition(node.Child(0)->Pos()), std::move(message))); + return false; + } + + if (result >= inputsCount) { + auto message = TStringBuilder() << "Invalid input index: " << result << " is out of range [0;" << inputsCount << ")"; + ctx.AddError(TIssue(ctx.GetPosition(node.Child(0)->Pos()), std::move(message))); + return false; + } + + return true; + } +} diff --git a/yql/essentials/public/purecalc/common/inspect_input.h b/yql/essentials/public/purecalc/common/inspect_input.h new file mode 100644 index 00000000000..558144865da --- /dev/null +++ b/yql/essentials/public/purecalc/common/inspect_input.h @@ -0,0 +1,7 @@ +#pragma once + +#include <yql/essentials/ast/yql_expr.h> + +namespace NYql::NPureCalc { + bool TryFetchInputIndexFromSelf(const TExprNode&, TExprContext&, ui32, ui32&); +} diff --git a/yql/essentials/public/purecalc/common/interface.cpp b/yql/essentials/public/purecalc/common/interface.cpp new file mode 100644 index 00000000000..6783ad407e0 --- /dev/null +++ b/yql/essentials/public/purecalc/common/interface.cpp @@ -0,0 +1,128 @@ +#include "interface.h" + +#include <yql/essentials/providers/common/codec/yql_codec_type_flags.h> +#include <yql/essentials/public/purecalc/common/logger_init.h> +#include <yql/essentials/public/purecalc/common/program_factory.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +TLoggingOptions::TLoggingOptions() + : LogLevel_(ELogPriority::TLOG_ERR) + , LogDestination(&Clog) +{ +} + +TLoggingOptions& TLoggingOptions::SetLogLevel(ELogPriority logLevel) { + LogLevel_ = logLevel; + return *this; +} + +TLoggingOptions& TLoggingOptions::SetLogDestination(IOutputStream* logDestination) { + LogDestination = logDestination; + return *this; +} + +TProgramFactoryOptions::TProgramFactoryOptions() + : UdfsDir_("") + , UserData_() + , LLVMSettings("OFF") + , BlockEngineSettings("disable") + , ExprOutputStream(nullptr) + , CountersProvider(nullptr) + , NativeYtTypeFlags(0) + , UseSystemColumns(false) + , UseWorkerPool(true) +{ +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetUDFsDir(TStringBuf dir) { + UdfsDir_ = dir; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::AddLibrary(NUserData::EDisposition disposition, TStringBuf name, TStringBuf content) { + auto& ref = UserData_.emplace_back(); + + ref.Type_ = NUserData::EType::LIBRARY; + ref.Disposition_ = disposition; + ref.Name_ = name; + ref.Content_ = content; + + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::AddFile(NUserData::EDisposition disposition, TStringBuf name, TStringBuf content) { + auto& ref = UserData_.emplace_back(); + + ref.Type_ = NUserData::EType::FILE; + ref.Disposition_ = disposition; + ref.Name_ = name; + ref.Content_ = content; + + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::AddUDF(NUserData::EDisposition disposition, TStringBuf name, TStringBuf content) { + auto& ref = UserData_.emplace_back(); + + ref.Type_ = NUserData::EType::UDF; + ref.Disposition_ = disposition; + ref.Name_ = name; + ref.Content_ = content; + + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetLLVMSettings(TStringBuf llvm_settings) { + LLVMSettings = llvm_settings; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetBlockEngineSettings(TStringBuf blockEngineSettings) { + BlockEngineSettings = blockEngineSettings; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetExprOutputStream(IOutputStream* exprOutputStream) { + ExprOutputStream = exprOutputStream; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetCountersProvider(NKikimr::NUdf::ICountersProvider* countersProvider) { + CountersProvider = countersProvider; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetUseNativeYtTypes(bool useNativeTypes) { + NativeYtTypeFlags = useNativeTypes ? NTCF_PRODUCTION : NTCF_NONE; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetNativeYtTypeFlags(ui64 nativeTypeFlags) { + NativeYtTypeFlags = nativeTypeFlags; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetDeterministicTimeProviderSeed(TMaybe<ui64> seed) { + DeterministicTimeProviderSeed = seed; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetUseSystemColumns(bool useSystemColumns) { + UseSystemColumns = useSystemColumns; + return *this; +} + +TProgramFactoryOptions& TProgramFactoryOptions::SetUseWorkerPool(bool useWorkerPool) { + UseWorkerPool = useWorkerPool; + return *this; +} + +void NYql::NPureCalc::ConfigureLogging(const TLoggingOptions& options) { + InitLogging(options); +} + +IProgramFactoryPtr NYql::NPureCalc::MakeProgramFactory(const TProgramFactoryOptions& options) { + return new TProgramFactory(options); +} diff --git a/yql/essentials/public/purecalc/common/interface.h b/yql/essentials/public/purecalc/common/interface.h new file mode 100644 index 00000000000..6e56c9aa3f9 --- /dev/null +++ b/yql/essentials/public/purecalc/common/interface.h @@ -0,0 +1,1180 @@ +#pragma once + +#include "fwd.h" +#include "wrappers.h" + +#include <yql/essentials/core/user_data/yql_user_data.h> + +#include <yql/essentials/public/udf/udf_value.h> +#include <yql/essentials/public/udf/udf_counter.h> +#include <yql/essentials/public/udf/udf_registrator.h> + +#include <yql/essentials/public/issue/yql_issue.h> +#include <library/cpp/yson/node/node.h> + +#include <library/cpp/logger/priority.h> + +#include <util/generic/ptr.h> +#include <util/generic/maybe.h> +#include <util/generic/hash_set.h> +#include <util/generic/string.h> +#include <util/stream/output.h> + +class ITimeProvider; + +namespace NKikimr { + namespace NMiniKQL { + class TScopedAlloc; + class IComputationGraph; + class IFunctionRegistry; + class TTypeEnvironment; + class TType; + class TStructType; + } +} + +namespace NYql { + namespace NPureCalc { + /** + * SQL or s-expression translation error. + */ + class TCompileError: public yexception { + private: + TString Yql_; + TString Issues_; + + public: + // TODO: maybe accept an actual list of issues here? + // See https://a.yandex-team.ru/arc/review/439403/details#comment-778237 + TCompileError(TString yql, TString issues) + : Yql_(std::move(yql)) + , Issues_(std::move(issues)) + { + } + + public: + /** + * Get the sql query which caused the error (if there is one available). + */ + const TString& GetYql() const { + return Yql_; + } + + /** + * Get detailed description for all errors and warnings that happened during sql translation. + */ + const TString& GetIssues() const { + return Issues_; + } + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * A generic input stream of objects. + */ + template <typename T> + class IStream { + public: + virtual ~IStream() = default; + + public: + /** + * Pops and returns a next value in the stream. If the stream is finished, should return some sentinel object. + * + * Depending on return type, this function may not transfer object ownership to a user. + * Thus, the stream may manage the returned object * itself. + * That is, the returned object's lifetime may be bound to the input stream lifetime; it may be destroyed + * upon calling Fetch() or upon destroying the stream, whichever happens first. + */ + virtual T Fetch() = 0; + }; + + /** + * Create a new stream which applies the given functor to the elements of the original stream. + */ + template <typename TOld, typename TNew, typename TFunctor> + inline THolder<IStream<TNew>> MapStream(THolder<IStream<TOld>> stream, TFunctor functor) { + return THolder(new NPrivate::TMappingStream<TNew, TOld, TFunctor>(std::move(stream), std::move(functor))); + }; + + /** + * Convert stream of objects into a stream of potentially incompatible objects. + * + * This conversion applies static cast to the output of the original stream. Use with caution! + */ + /// @{ + template < + typename TNew, typename TOld, + std::enable_if_t<!std::is_same<TNew, TOld>::value>* = nullptr> + inline THolder<IStream<TNew>> ConvertStreamUnsafe(THolder<IStream<TOld>> stream) { + return MapStream<TOld, TNew>(std::move(stream), [](TOld x) -> TNew { return static_cast<TNew>(x); }); + } + template <typename T> + inline THolder<IStream<T>> ConvertStreamUnsafe(THolder<IStream<T>> stream) { + return stream; + } + /// @} + + /** + * Convert stream of objects into a stream of compatible objects. + * + * Note: each conversion adds one level of indirection so avoid them if possible. + */ + template <typename TNew, typename TOld, std::enable_if_t<std::is_convertible<TOld, TNew>::value>* = nullptr> + inline THolder<IStream<TNew>> ConvertStream(THolder<IStream<TOld>> stream) { + return ConvertStreamUnsafe<TNew, TOld>(std::move(stream)); + } + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * A generic push consumer. + */ + template <typename T> + class IConsumer { + public: + virtual ~IConsumer() = default; + + public: + /** + * Feed an object to consumer. + * + * Depending on argument type, the consumer may not take ownership of the passed object; + * in that case it is the caller responsibility to manage the object lifetime after passing it to this method. + * + * The passed object can be destroyed after the consumer returns from this function; the consumer should + * not store pointer to the passed object or the passed object itself without taking all necessary precautions + * to ensure that the pointer or the object stays valid after returning. + */ + virtual void OnObject(T) = 0; + + /** + * Close the consumer and run finalization logic. Calling OnObject after calling this function is an error. + */ + virtual void OnFinish() = 0; + }; + + /** + * Create a new consumer which applies the given functor to objects before . + */ + template <typename TOld, typename TNew, typename TFunctor> + inline THolder<IConsumer<TNew>> MapConsumer(THolder<IConsumer<TOld>> stream, TFunctor functor) { + return THolder(new NPrivate::TMappingConsumer<TNew, TOld, TFunctor>(std::move(stream), std::move(functor))); + }; + + + /** + * Convert consumer of objects into a consumer of potentially incompatible objects. + * + * This conversion applies static cast to the input value. Use with caution. + */ + /// @{ + template < + typename TNew, typename TOld, + std::enable_if_t<!std::is_same<TNew, TOld>::value>* = nullptr> + inline THolder<IConsumer<TNew>> ConvertConsumerUnsafe(THolder<IConsumer<TOld>> consumer) { + return MapConsumer<TOld, TNew>(std::move(consumer), [](TNew x) -> TOld { return static_cast<TOld>(x); }); + } + template <typename T> + inline THolder<IConsumer<T>> ConvertConsumerUnsafe(THolder<IConsumer<T>> consumer) { + return consumer; + } + /// @} + + /** + * Convert consumer of objects into a consumer of compatible objects. + * + * Note: each conversion adds one level of indirection so avoid them if possible. + */ + template <typename TNew, typename TOld, std::enable_if_t<std::is_convertible<TNew, TOld>::value>* = nullptr> + inline THolder<IConsumer<TNew>> ConvertConsumer(THolder<IConsumer<TOld>> consumer) { + return ConvertConsumerUnsafe<TNew, TOld>(std::move(consumer)); + } + + /** + * Create a consumer which holds a non-owning pointer to the given consumer + * and passes all messages to the latter. + */ + template <typename T, typename C> + THolder<NPrivate::TNonOwningConsumer<T, C>> MakeNonOwningConsumer(C consumer) { + return MakeHolder<NPrivate::TNonOwningConsumer<T, C>>(consumer); + } + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Logging options. + */ + struct TLoggingOptions final { + public: + /// Logging level for messages generated during compilation. + ELogPriority LogLevel_; // TODO: rename to LogLevel + + /// Where to write log messages. + IOutputStream* LogDestination; + + public: + TLoggingOptions(); + /** + * Set a new logging level. + * + * @return reference to self, to allow method chaining. + */ + TLoggingOptions& SetLogLevel(ELogPriority); + + /** + * Set a new logging destination. + * + * @return reference to self, to allow method chaining. + */ + TLoggingOptions& SetLogDestination(IOutputStream*); + }; + + /** + * General options for program factory. + */ + struct TProgramFactoryOptions final { + public: + /// Path to a directory with compiled UDFs. Leave empty to disable loading external UDFs. + TString UdfsDir_; // TODO: rename to UDFDir + + /// List of available external resources, e.g. files, UDFs, libraries. + TVector<NUserData::TUserData> UserData_; // TODO: rename to UserData + + /// LLVM settings. Assign "OFF" to disable LLVM, empty string for default settings. + TString LLVMSettings; + + /// Block engine settings. Assign "force" to unconditionally enable + /// it, "disable" for turn it off and "auto" to left the final + /// decision to the platform heuristics. + TString BlockEngineSettings; + + /// Output stream to dump the compiled and optimized expressions. + IOutputStream* ExprOutputStream; + + /// Provider for generic counters which can be used to export statistics from UDFs. + NKikimr::NUdf::ICountersProvider* CountersProvider; + + /// YT Type V3 flags for Skiff/Yson serialization. + ui64 NativeYtTypeFlags; + + /// Seed for deterministic time provider + TMaybe<ui64> DeterministicTimeProviderSeed; + + /// Use special system columns to support tables naming (supports non empty ``TablePath()``/``TableName()``) + bool UseSystemColumns; + + /// Reuse allocated workers + bool UseWorkerPool; + + public: + TProgramFactoryOptions(); + + public: + /** + * Set a new path to a directory with UDFs. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetUDFsDir(TStringBuf); + + /** + * Add a new library to the UserData list. + * + * @param disposition where the resource resides, e.g. on filesystem, in memory, etc. + * NB: URL disposition is not supported. + * @param name name of the resource. + * @param content depending on disposition, either path to the resource or its content. + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& AddLibrary(NUserData::EDisposition disposition, TStringBuf name, TStringBuf content); + + /** + * Add a new file to the UserData list. + * + * @param disposition where the resource resides, e.g. on filesystem, in memory, etc. + * NB: URL disposition is not supported. + * @param name name of the resource. + * @param content depending on disposition, either path to the resource or its content. + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& AddFile(NUserData::EDisposition disposition, TStringBuf name, TStringBuf content); + + /** + * Add a new UDF to the UserData list. + * + * @param disposition where the resource resides, e.g. on filesystem, in memory, etc. + * NB: URL disposition is not supported. + * @param name name of the resource. + * @param content depending on disposition, either path to the resource or its content. + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& AddUDF(NUserData::EDisposition disposition, TStringBuf name, TStringBuf content); + + /** + * Set new LLVM settings. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetLLVMSettings(TStringBuf llvm_settings); + + /** + * Set new block engine settings. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetBlockEngineSettings(TStringBuf blockEngineSettings); + + /** + * Set the stream to dump the compiled and optimized expressions. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetExprOutputStream(IOutputStream* exprOutputStream); + + /** + * Set new counters provider. Passed pointer should stay alive for as long as the processor factory + * stays alive. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetCountersProvider(NKikimr::NUdf::ICountersProvider* countersProvider); + + /** + * Set new YT Type V3 mode. Deprecated method. Use SetNativeYtTypeFlags instead + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetUseNativeYtTypes(bool useNativeTypes); + + /** + * Set YT Type V3 flags. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetNativeYtTypeFlags(ui64 nativeTypeFlags); + + /** + * Set seed for deterministic time provider. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetDeterministicTimeProviderSeed(TMaybe<ui64> seed); + + /** + * Set new flag whether to allow using system columns or not. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetUseSystemColumns(bool useSystemColumns); + + /** + * Set new flag whether to allow reusing workers or not. + * + * @return reference to self, to allow method chaining. + */ + TProgramFactoryOptions& SetUseWorkerPool(bool useWorkerPool); + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * What exactly are we parsing: SQL or an s-expression. + */ + enum class ETranslationMode { + SQL /* "SQL" */, + SExpr /* "s-expression" */, + Mkql /* "mkql" */, + PG /* PostgreSQL */ + }; + + /** + * A facility for compiling sql and s-expressions and making programs from them. + */ + class IProgramFactory: public TThrRefBase { + protected: + virtual IPullStreamWorkerFactoryPtr MakePullStreamWorkerFactory(const TInputSpecBase&, const TOutputSpecBase&, TString, ETranslationMode, ui16) = 0; + virtual IPullListWorkerFactoryPtr MakePullListWorkerFactory(const TInputSpecBase&, const TOutputSpecBase&, TString, ETranslationMode, ui16) = 0; + virtual IPushStreamWorkerFactoryPtr MakePushStreamWorkerFactory(const TInputSpecBase&, const TOutputSpecBase&, TString, ETranslationMode, ui16) = 0; + + public: + /** + * Add new udf module. It's not specified whether adding new modules will affect existing programs + * (theoretical answer is 'no'). + */ + virtual void AddUdfModule(const TStringBuf&, NKikimr::NUdf::TUniquePtr<NKikimr::NUdf::IUdfModule>&&) = 0; + // TODO: support setting udf modules via factory options. + + /** + * Set new counters provider, override one that was specified via factory options. Note that existing + * programs will still reference the previous provider. + */ + virtual void SetCountersProvider(NKikimr::NUdf::ICountersProvider*) = 0; + // TODO: support setting providers via factory options. + + template <typename TInputSpec, typename TOutputSpec> + THolder<TPullStreamProgram<TInputSpec, TOutputSpec>> MakePullStreamProgram( + TInputSpec inputSpec, TOutputSpec outputSpec, TString query, ETranslationMode mode = ETranslationMode::SQL, ui16 syntaxVersion = 1 + ) { + auto workerFactory = MakePullStreamWorkerFactory(inputSpec, outputSpec, std::move(query), mode, syntaxVersion); + return MakeHolder<TPullStreamProgram<TInputSpec, TOutputSpec>>(std::move(inputSpec), std::move(outputSpec), workerFactory); + } + + template <typename TInputSpec, typename TOutputSpec> + THolder<TPullListProgram<TInputSpec, TOutputSpec>> MakePullListProgram( + TInputSpec inputSpec, TOutputSpec outputSpec, TString query, ETranslationMode mode = ETranslationMode::SQL, ui16 syntaxVersion = 1 + ) { + auto workerFactory = MakePullListWorkerFactory(inputSpec, outputSpec, std::move(query), mode, syntaxVersion); + return MakeHolder<TPullListProgram<TInputSpec, TOutputSpec>>(std::move(inputSpec), std::move(outputSpec), workerFactory); + } + + template <typename TInputSpec, typename TOutputSpec> + THolder<TPushStreamProgram<TInputSpec, TOutputSpec>> MakePushStreamProgram( + TInputSpec inputSpec, TOutputSpec outputSpec, TString query, ETranslationMode mode = ETranslationMode::SQL, ui16 syntaxVersion = 1 + ) { + auto workerFactory = MakePushStreamWorkerFactory(inputSpec, outputSpec, std::move(query), mode, syntaxVersion); + return MakeHolder<TPushStreamProgram<TInputSpec, TOutputSpec>>(std::move(inputSpec), std::move(outputSpec), workerFactory); + } + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * A facility for creating workers. Despite being a part of a public API, worker factory is not used directly. + */ + class IWorkerFactory: public std::enable_shared_from_this<IWorkerFactory> { + public: + virtual ~IWorkerFactory() = default; + /** + * Get input column names for specified input that are actually used in the query. + */ + virtual const THashSet<TString>& GetUsedColumns(ui32) const = 0; + /** + * Overload for single-input programs. + */ + virtual const THashSet<TString>& GetUsedColumns() const = 0; + + /** + * Make input type schema for specified input as deduced by program optimizer. This schema is equivalent + * to one provided by input spec up to the order of the fields in structures. + */ + virtual NYT::TNode MakeInputSchema(ui32) const = 0; + /** + * Overload for single-input programs. + */ + virtual NYT::TNode MakeInputSchema() const = 0; + + /** + * Make output type schema as deduced by program optimizer. If output spec provides its own schema, than + * this schema is equivalent to one provided by output spec up to the order of the fields in structures. + */ + /// @{ + /** + * Overload for single-table output programs (i.e. output type is struct). + */ + virtual NYT::TNode MakeOutputSchema() const = 0; + /** + * Overload for multi-table output programs (i.e. output type is variant over tuple). + */ + virtual NYT::TNode MakeOutputSchema(ui32) const = 0; + /** + * Overload for multi-table output programs (i.e. output type is variant over struct). + */ + virtual NYT::TNode MakeOutputSchema(TStringBuf) const = 0; + /// @} + + /** + * Make full output schema. For single-output programs returns struct type, for multi-output programs + * returns variant type. + * + * Warning: calling this function may result in extended memory usage for large number of output tables. + */ + virtual NYT::TNode MakeFullOutputSchema() const = 0; + + /** + * Get compilation issues + */ + virtual TIssues GetIssues() const = 0; + + /** + * Get precompiled mkql program + */ + virtual TString GetCompiledProgram() = 0; + + /** + * Return a worker to the factory for possible reuse + */ + virtual void ReturnWorker(IWorker* worker) = 0; + }; + + class TReleaseWorker { + public: + template <class T> + static inline void Destroy(T* t) noexcept { + t->Release(); + } + }; + + template <class T> + using TWorkerHolder = THolder<T, TReleaseWorker>; + + /** + * Factory for generating pull stream workers. + */ + class IPullStreamWorkerFactory: public IWorkerFactory { + public: + /** + * Create a new pull stream worker. + */ + virtual TWorkerHolder<IPullStreamWorker> MakeWorker() = 0; + }; + + /** + * Factory for generating pull list workers. + */ + class IPullListWorkerFactory: public IWorkerFactory { + public: + /** + * Create a new pull list worker. + */ + virtual TWorkerHolder<IPullListWorker> MakeWorker() = 0; + }; + + /** + * Factory for generating push stream workers. + */ + class IPushStreamWorkerFactory: public IWorkerFactory { + public: + /** + * Create a new push stream worker. + */ + virtual TWorkerHolder<IPushStreamWorker> MakeWorker() = 0; + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Worker is a central part of any program instance. It contains current computation state + * (called computation graph) and objects required to work with it, including an allocator for unboxed values. + * + * Usually, users do not interact with workers directly. They use program instance entry points such as streams + * and consumers instead. The only case when one would have to to interact with workers is when implementing + * custom io-specification. + */ + class IWorker { + protected: + friend class TReleaseWorker; + /** + * Cleanup the worker and return to a worker factory for reuse + */ + virtual void Release() = 0; + + public: + virtual ~IWorker() = default; + + public: + /** + * Number of inputs for this program. + */ + virtual ui32 GetInputsCount() const = 0; + + /** + * MiniKQL input struct type of specified input for this program. Type is equivalent to the deduced input + * schema (see IWorker::MakeInputSchema()) + * + * If ``original`` is set to ``true``, returns type without virtual system columns. + */ + virtual const NKikimr::NMiniKQL::TStructType* GetInputType(ui32, bool original = false) const = 0; + /** + * Overload for single-input programs. + */ + virtual const NKikimr::NMiniKQL::TStructType* GetInputType(bool original = false) const = 0; + + /** + * MiniKQL input struct type of the specified input for this program. + * The returned type is the actual type of the specified input node. + */ + virtual const NKikimr::NMiniKQL::TStructType* GetRawInputType(ui32) const = 0; + /** + * Overload for single-input programs. + */ + virtual const NKikimr::NMiniKQL::TStructType* GetRawInputType() const = 0; + + /** + * MiniKQL output struct type for this program. The returned type is equivalent to the deduced output + * schema (see IWorker::MakeFullOutputSchema()). + */ + virtual const NKikimr::NMiniKQL::TType* GetOutputType() const = 0; + + /** + * MiniKQL output struct type for this program. The returned type is + * the actual type of the root node. + */ + virtual const NKikimr::NMiniKQL::TType* GetRawOutputType() const = 0; + + /** + * Make input type schema for specified input as deduced by program optimizer. This schema is equivalent + * to one provided by input spec up to the order of the fields in structures. + */ + virtual NYT::TNode MakeInputSchema(ui32) const = 0; + /** + * Overload for single-input programs. + */ + virtual NYT::TNode MakeInputSchema() const = 0; + + /** + * Make output type schema as deduced by program optimizer. If output spec provides its own schema, than + * this schema is equivalent to one provided by output spec up to the order of the fields in structures. + */ + /// @{ + /** + * Overload for single-table output programs (i.e. output type is struct). + */ + virtual NYT::TNode MakeOutputSchema() const = 0; + /** + * Overload for multi-table output programs (i.e. output type is variant over tuple). + */ + virtual NYT::TNode MakeOutputSchema(ui32) const = 0; + /** + * Overload for multi-table output programs (i.e. output type is variant over struct). + */ + virtual NYT::TNode MakeOutputSchema(TStringBuf) const = 0; + /// @} + + /** + * Generates full output schema. For single-output programs returns struct type, for multi-output programs + * returns variant type. + * + * Warning: calling this function may result in extended memory usage for large number of output tables. + */ + virtual NYT::TNode MakeFullOutputSchema() const = 0; + + /** + * Get scoped alloc used in this worker. + */ + virtual NKikimr::NMiniKQL::TScopedAlloc& GetScopedAlloc() = 0; + + /** + * Get computation graph. + */ + virtual NKikimr::NMiniKQL::IComputationGraph& GetGraph() = 0; + + /** + * Get function registry for this worker. + */ + virtual const NKikimr::NMiniKQL::IFunctionRegistry& GetFunctionRegistry() const = 0; + + /** + * Get type environment for this worker. + */ + virtual NKikimr::NMiniKQL::TTypeEnvironment& GetTypeEnvironment() = 0; + + /** + * Get llvm settings for this worker. + */ + virtual const TString& GetLLVMSettings() const = 0; + + /** + * Get YT Type V3 flags + */ + virtual ui64 GetNativeYtTypeFlags() const = 0; + + /** + * Get time provider + */ + virtual ITimeProvider* GetTimeProvider() const = 0; + }; + + /** + * Worker which operates in pull stream mode. + */ + class IPullStreamWorker: public IWorker { + public: + /** + * Set input computation graph node for specified input. The passed unboxed value should be a stream of + * structs. It should be created via the allocator associated with this very worker. + * This function can only be called once for each input. + */ + virtual void SetInput(NKikimr::NUdf::TUnboxedValue&&, ui32) = 0; + + /** + * Get the output computation graph node. The returned node will be a stream of structs or variants. + * This function cannot be called before setting an input value. + */ + virtual NKikimr::NUdf::TUnboxedValue& GetOutput() = 0; + }; + + /** + * Worker which operates in pull list mode. + */ + class IPullListWorker: public IWorker { + public: + /** + * Set input computation graph node for specified input. The passed unboxed value should be a list of + * structs. It should be created via the allocator associated with this very worker. + * This function can only be called once for each index. + */ + virtual void SetInput(NKikimr::NUdf::TUnboxedValue&&, ui32) = 0; + + /** + * Get the output computation graph node. The returned node will be a list of structs or variants. + * This function cannot be called before setting an input value. + */ + virtual NKikimr::NUdf::TUnboxedValue& GetOutput() = 0; + + /** + * Get iterator over the output list. + */ + virtual NKikimr::NUdf::TUnboxedValue& GetOutputIterator() = 0; + + /** + * Reset iterator to the beginning of the output list. After calling this function, GetOutputIterator() + * will return a fresh iterator; all previously returned iterators will become invalid. + */ + virtual void ResetOutputIterator() = 0; + }; + + /** + * Worker which operates in push stream mode. + */ + class IPushStreamWorker: public IWorker { + public: + /** + * Set a consumer where the worker will relay its output. This function can only be called once. + */ + virtual void SetConsumer(THolder<IConsumer<const NKikimr::NUdf::TUnboxedValue*>>) = 0; + + /** + * Push new value to the graph, than feed all new output to the consumer. Values cannot be pushed before + * assigning a consumer. + */ + virtual void Push(NKikimr::NUdf::TUnboxedValue&&) = 0; + + /** + * Send finish event and clear the computation graph. No new values will be accepted. + */ + virtual void OnFinish() = 0; + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Input specifications describe format for program input. They carry information about input data schema + * as well as the knowledge about how to convert input structures into unboxed values (data format which can be + * processed by the YQL runtime). + * + * Input spec defines the arguments of the program's Apply method. For example, a program + * with the protobuf input spec will accept a stream of protobuf messages while a program with the + * yson spec will accept an input stream (binary or text one). + * + * See documentation for input and output spec traits for hints on how to implement a custom specs. + */ + class TInputSpecBase { + protected: + mutable TVector<THashMap<TString, NYT::TNode>> AllVirtualColumns_; + + public: + virtual ~TInputSpecBase() = default; + + public: + /** + * Get input data schemas in YQL format (NB: not a YT format). Each item of the returned vector must + * describe a structure. + * + * Format of each item is approximately this one: + * + * @code + * [ + * 'StructType', + * [ + * ["Field1Name", ["DataType", "Int32"]], + * ["Field2Name", ["DataType", "String"]], + * ... + * ] + * ] + * @endcode + */ + virtual const TVector<NYT::TNode>& GetSchemas() const = 0; + // TODO: make a neat schema builder + + /** + * Get virtual columns for each input. + * + * Key of each mapping is column name, value is data schema in YQL format. + */ + const TVector<THashMap<TString, NYT::TNode>>& GetAllVirtualColumns() const { + if (AllVirtualColumns_.empty()) { + AllVirtualColumns_ = TVector<THashMap<TString, NYT::TNode>>(GetSchemas().size()); + } + + return AllVirtualColumns_; + } + + virtual bool ProvidesBlocks() const { return false; } + }; + + /** + * Output specifications describe format for program output. Like input specifications, they cary knowledge + * about program output type and how to convert unboxed values into that type. + */ + class TOutputSpecBase { + private: + TMaybe<THashSet<TString>> OutputColumnsFilter_; + + public: + virtual ~TOutputSpecBase() = default; + + public: + /** + * Get output data schema in YQL format (NB: not a YT format). The returned value must describe a structure + * or a variant made of structures for fulti-table outputs (note: not all specs support multi-table output). + * + * See docs for the input spec's GetSchemas(). + * + * Also TNode entity could be returned (NYT::TNode::CreateEntity()), + * in which case output schema would be inferred from query and could be + * obtained by Program::GetOutputSchema() call. + */ + virtual const NYT::TNode& GetSchema() const = 0; + + /** + * Get an output columns filter. + * + * Output columns filter is a set of column names that should be left in the output. All columns that are + * not in this set will not be calculated. Depending on the output schema, they will be either removed + * completely (for optional columns) or filled with defaults (for required columns). + */ + const TMaybe<THashSet<TString>>& GetOutputColumnsFilter() const { + return OutputColumnsFilter_; + } + + /** + * Set new output columns filter. + */ + void SetOutputColumnsFilter(const TMaybe<THashSet<TString>>& outputColumnsFilter) { + OutputColumnsFilter_ = outputColumnsFilter; + } + + virtual bool AcceptsBlocks() const { return false; } + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Input spec traits provide information on how to process program input. + * + * Each input spec should create a template specialization for this class, in which it should provide several + * static variables and functions. + * + * For example, a hypothetical example of implementing a JSON input spec would look like this: + * + * @code + * class TJsonInputSpec: public TInputSpecBase { + * // whatever magic you require for this spec + * }; + * + * template <> + * class TInputSpecTraits<TJsonInputSpec> { + * // write here four constants, one typedef and three static functions described below + * }; + * @endcode + * + * @tparam T input spec type. + */ + template <typename T> + struct TInputSpecTraits { + /// Safety flag which should be set to false in all template specializations of this class. Attempt to + /// build a program using a spec with `IsPartial=true` will result in compilation error. + static const constexpr bool IsPartial = true; + + /// Indicates whether this spec supports pull stream mode. + static const constexpr bool SupportPullStreamMode = false; + /// Indicates whether this spec supports pull list mode. + static const constexpr bool SupportPullListMode = false; + /// Indicates whether this spec supports push stream mode. + static const constexpr bool SupportPushStreamMode = false; + + /// For push mode, indicates the return type of the builder's Process function. + using TConsumerType = void; + + /// For pull stream mode, should take an input spec, a pull stream worker and whatever the user passed + /// to the program's Apply function, create an unboxed values with a custom stream implementations + /// and pass it to the worker's SetInput function for each input. + template <typename ...A> + static void PreparePullStreamWorker(const T&, IPullStreamWorker*, A&&...) { + Y_UNREACHABLE(); + } + + /// For pull list mode, should take an input spec, a pull list worker and whatever the user passed + /// to the program's Apply function, create an unboxed values with a custom list implementations + /// and pass it to the worker's SetInput function for each input. + template <typename ...A> + static void PreparePullListWorker(const T&, IPullListWorker*, A&&...) { + Y_UNREACHABLE(); + } + + /// For push stream mode, should take an input spec and a worker and create a consumer which will + /// be returned to the user. The consumer should keep the worker alive until its own destruction. + /// The return type of this function should exactly match the one defined in ConsumerType typedef. + static TConsumerType MakeConsumer(const T&, TWorkerHolder<IPushStreamWorker>) { + Y_UNREACHABLE(); + } + }; + + /** + * Output spec traits provide information on how to process program output. Like with input specs, each output + * spec requires an appropriate template specialization of this class. + * + * @tparam T output spec type. + */ + template <typename T> + struct TOutputSpecTraits { + /// Safety flag which should be set to false in all template specializations of this class. Attempt to + /// build a program using a spec with `IsPartial=false` will result in compilation error. + static const constexpr bool IsPartial = true; + + /// Indicates whether this spec supports pull stream mode. + static const constexpr bool SupportPullStreamMode = false; + /// Indicates whether this spec supports pull list mode. + static const constexpr bool SupportPullListMode = false; + /// Indicates whether this spec supports push stream mode. + static const constexpr bool SupportPushStreamMode = false; + + /// For pull stream mode, indicates the return type of the program's Apply function. + using TPullStreamReturnType = void; + + /// For pull list mode, indicates the return type of the program's Apply function. + using TPullListReturnType = void; + + /// For pull stream mode, should take an output spec and a worker and build a stream which will be returned + /// to the user. The return type of this function must match the one specified in the PullStreamReturnType. + static TPullStreamReturnType ConvertPullStreamWorkerToOutputType(const T&, TWorkerHolder<IPullStreamWorker>) { + Y_UNREACHABLE(); + } + + /// For pull list mode, should take an output spec and a worker and build a list which will be returned + /// to the user. The return type of this function must match the one specified in the PullListReturnType. + static TPullListReturnType ConvertPullListWorkerToOutputType(const T&, TWorkerHolder<IPullListWorker>) { + Y_UNREACHABLE(); + } + + /// For push stream mode, should take an output spec, a worker and whatever arguments the user passed + /// to the program's Apply function, create a consumer for unboxed values and pass it to the worker's + /// SetConsumer function. + template <typename ...A> + static void SetConsumerToWorker(const T&, IPushStreamWorker*, A&&...) { + Y_UNREACHABLE(); + } + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + +#define NOT_SPEC_MSG(spec_type) "passed class should be derived from " spec_type " spec base" +#define PARTIAL_SPEC_MSG(spec_type) "this " spec_type " spec does not define its traits. Make sure you've passed " \ + "an " spec_type " spec and not some other object; also make sure you've included " \ + "all necessary headers. If you're developing a spec, make sure you have " \ + "a spec traits template specialization" +#define UNSUPPORTED_MODE_MSG(spec_type, mode) "this " spec_type " spec does not support " mode " mode" + + class IProgram { + public: + virtual ~IProgram() = default; + + public: + virtual const TInputSpecBase& GetInputSpecBase() const = 0; + virtual const TOutputSpecBase& GetOutputSpecBase() const = 0; + virtual const THashSet<TString>& GetUsedColumns(ui32) const = 0; + virtual const THashSet<TString>& GetUsedColumns() const = 0; + virtual NYT::TNode MakeInputSchema(ui32) const = 0; + virtual NYT::TNode MakeInputSchema() const = 0; + virtual NYT::TNode MakeOutputSchema() const = 0; + virtual NYT::TNode MakeOutputSchema(ui32) const = 0; + virtual NYT::TNode MakeOutputSchema(TStringBuf) const = 0; + virtual NYT::TNode MakeFullOutputSchema() const = 0; + virtual TIssues GetIssues() const = 0; + virtual TString GetCompiledProgram() = 0; + + inline void MergeUsedColumns(THashSet<TString>& columns, ui32 inputIndex) { + const auto& usedColumns = GetUsedColumns(inputIndex); + columns.insert(usedColumns.begin(), usedColumns.end()); + } + + inline void MergeUsedColumns(THashSet<TString>& columns) { + const auto& usedColumns = GetUsedColumns(); + columns.insert(usedColumns.begin(), usedColumns.end()); + } + }; + + template <typename TInputSpec, typename TOutputSpec, typename WorkerFactory> + class TProgramCommon: public IProgram { + static_assert(std::is_base_of<TInputSpecBase, TInputSpec>::value, NOT_SPEC_MSG("input")); + static_assert(std::is_base_of<TOutputSpecBase, TOutputSpec>::value, NOT_SPEC_MSG("output")); + + protected: + TInputSpec InputSpec_; + TOutputSpec OutputSpec_; + std::shared_ptr<WorkerFactory> WorkerFactory_; + + public: + explicit TProgramCommon( + TInputSpec inputSpec, + TOutputSpec outputSpec, + std::shared_ptr<WorkerFactory> workerFactory + ) + : InputSpec_(inputSpec) + , OutputSpec_(outputSpec) + , WorkerFactory_(std::move(workerFactory)) + { + } + + public: + const TInputSpec& GetInputSpec() const { + return InputSpec_; + } + + const TOutputSpec& GetOutputSpec() const { + return OutputSpec_; + } + + const TInputSpecBase& GetInputSpecBase() const override { + return InputSpec_; + } + + const TOutputSpecBase& GetOutputSpecBase() const override { + return OutputSpec_; + } + + const THashSet<TString>& GetUsedColumns(ui32 inputIndex) const override { + return WorkerFactory_->GetUsedColumns(inputIndex); + } + + const THashSet<TString>& GetUsedColumns() const override { + return WorkerFactory_->GetUsedColumns(); + } + + NYT::TNode MakeInputSchema(ui32 inputIndex) const override { + return WorkerFactory_->MakeInputSchema(inputIndex); + } + + NYT::TNode MakeInputSchema() const override { + return WorkerFactory_->MakeInputSchema(); + } + + NYT::TNode MakeOutputSchema() const override { + return WorkerFactory_->MakeOutputSchema(); + } + + NYT::TNode MakeOutputSchema(ui32 outputIndex) const override { + return WorkerFactory_->MakeOutputSchema(outputIndex); + } + + NYT::TNode MakeOutputSchema(TStringBuf outputName) const override { + return WorkerFactory_->MakeOutputSchema(outputName); + } + + NYT::TNode MakeFullOutputSchema() const override { + return WorkerFactory_->MakeFullOutputSchema(); + } + + TIssues GetIssues() const override { + return WorkerFactory_->GetIssues(); + } + + TString GetCompiledProgram() override { + return WorkerFactory_->GetCompiledProgram(); + } + }; + + template <typename TInputSpec, typename TOutputSpec> + class TPullStreamProgram final: public TProgramCommon<TInputSpec, TOutputSpec, IPullStreamWorkerFactory> { + using TProgramCommon<TInputSpec, TOutputSpec, IPullStreamWorkerFactory>::WorkerFactory_; + using TProgramCommon<TInputSpec, TOutputSpec, IPullStreamWorkerFactory>::InputSpec_; + using TProgramCommon<TInputSpec, TOutputSpec, IPullStreamWorkerFactory>::OutputSpec_; + + public: + using TProgramCommon<TInputSpec, TOutputSpec, IPullStreamWorkerFactory>::TProgramCommon; + + public: + template <typename ...T> + typename TOutputSpecTraits<TOutputSpec>::TPullStreamReturnType Apply(T&& ... t) { + static_assert(!TInputSpecTraits<TInputSpec>::IsPartial, PARTIAL_SPEC_MSG("input")); + static_assert(!TOutputSpecTraits<TOutputSpec>::IsPartial, PARTIAL_SPEC_MSG("output")); + static_assert(TInputSpecTraits<TInputSpec>::SupportPullStreamMode, UNSUPPORTED_MODE_MSG("input", "pull stream")); + static_assert(TOutputSpecTraits<TOutputSpec>::SupportPullStreamMode, UNSUPPORTED_MODE_MSG("output", "pull stream")); + + auto worker = WorkerFactory_->MakeWorker(); + TInputSpecTraits<TInputSpec>::PreparePullStreamWorker(InputSpec_, worker.Get(), std::forward<T>(t)...); + return TOutputSpecTraits<TOutputSpec>::ConvertPullStreamWorkerToOutputType(OutputSpec_, std::move(worker)); + } + }; + + template <typename TInputSpec, typename TOutputSpec> + class TPullListProgram final: public TProgramCommon<TInputSpec, TOutputSpec, IPullListWorkerFactory> { + using TProgramCommon<TInputSpec, TOutputSpec, IPullListWorkerFactory>::WorkerFactory_; + using TProgramCommon<TInputSpec, TOutputSpec, IPullListWorkerFactory>::InputSpec_; + using TProgramCommon<TInputSpec, TOutputSpec, IPullListWorkerFactory>::OutputSpec_; + + public: + using TProgramCommon<TInputSpec, TOutputSpec, IPullListWorkerFactory>::TProgramCommon; + + public: + template <typename ...T> + typename TOutputSpecTraits<TOutputSpec>::TPullListReturnType Apply(T&& ... t) { + static_assert(!TInputSpecTraits<TInputSpec>::IsPartial, PARTIAL_SPEC_MSG("input")); + static_assert(!TOutputSpecTraits<TOutputSpec>::IsPartial, PARTIAL_SPEC_MSG("output")); + static_assert(TInputSpecTraits<TInputSpec>::SupportPullListMode, UNSUPPORTED_MODE_MSG("input", "pull list")); + static_assert(TOutputSpecTraits<TOutputSpec>::SupportPullListMode, UNSUPPORTED_MODE_MSG("output", "pull list")); + + auto worker = WorkerFactory_->MakeWorker(); + TInputSpecTraits<TInputSpec>::PreparePullListWorker(InputSpec_, worker.Get(), std::forward<T>(t)...); + return TOutputSpecTraits<TOutputSpec>::ConvertPullListWorkerToOutputType(OutputSpec_, std::move(worker)); + } + }; + + template <typename TInputSpec, typename TOutputSpec> + class TPushStreamProgram final: public TProgramCommon<TInputSpec, TOutputSpec, IPushStreamWorkerFactory> { + using TProgramCommon<TInputSpec, TOutputSpec, IPushStreamWorkerFactory>::WorkerFactory_; + using TProgramCommon<TInputSpec, TOutputSpec, IPushStreamWorkerFactory>::InputSpec_; + using TProgramCommon<TInputSpec, TOutputSpec, IPushStreamWorkerFactory>::OutputSpec_; + + public: + using TProgramCommon<TInputSpec, TOutputSpec, IPushStreamWorkerFactory>::TProgramCommon; + + public: + template <typename ...T> + typename TInputSpecTraits<TInputSpec>::TConsumerType Apply(T&& ... t) { + static_assert(!TInputSpecTraits<TInputSpec>::IsPartial, PARTIAL_SPEC_MSG("input")); + static_assert(!TOutputSpecTraits<TOutputSpec>::IsPartial, PARTIAL_SPEC_MSG("output")); + static_assert(TInputSpecTraits<TInputSpec>::SupportPushStreamMode, UNSUPPORTED_MODE_MSG("input", "push stream")); + static_assert(TOutputSpecTraits<TOutputSpec>::SupportPushStreamMode, UNSUPPORTED_MODE_MSG("output", "push stream")); + + auto worker = WorkerFactory_->MakeWorker(); + TOutputSpecTraits<TOutputSpec>::SetConsumerToWorker(OutputSpec_, worker.Get(), std::forward<T>(t)...); + return TInputSpecTraits<TInputSpec>::MakeConsumer(InputSpec_, std::move(worker)); + } + }; + +#undef NOT_SPEC_MSG +#undef PARTIAL_SPEC_MSG +#undef UNSUPPORTED_MODE_MSG + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Configure global logging facilities. Affects all YQL modules. + */ + void ConfigureLogging(const TLoggingOptions& = {}); + + /** + * Create a new program factory. + * Custom logging initialization could be preformed by a call to the ConfigureLogging method beforehand. + * If the ConfigureLogging method has not been called the default logging initialization will be performed. + */ + IProgramFactoryPtr MakeProgramFactory(const TProgramFactoryOptions& = {}); + } +} + +Y_DECLARE_OUT_SPEC(inline, NYql::NPureCalc::TCompileError, stream, value) { + stream << value.AsStrBuf() << Endl << "Issues:" << Endl << value.GetIssues() << Endl << Endl << "Yql:" << Endl <<value.GetYql(); +} diff --git a/yql/essentials/public/purecalc/common/logger_init.cpp b/yql/essentials/public/purecalc/common/logger_init.cpp new file mode 100644 index 00000000000..a7da19d9f10 --- /dev/null +++ b/yql/essentials/public/purecalc/common/logger_init.cpp @@ -0,0 +1,32 @@ +#include "logger_init.h" + +#include <yql/essentials/utils/log/log.h> + +#include <atomic> + +namespace NYql { +namespace NPureCalc { + +namespace { + std::atomic_bool Initialized; +} + + void InitLogging(const TLoggingOptions& options) { + NLog::InitLogger(options.LogDestination); + auto& logger = NLog::YqlLogger(); + logger.SetDefaultPriority(options.LogLevel_); + for (int i = 0; i < NLog::EComponentHelpers::ToInt(NLog::EComponent::MaxValue); ++i) { + logger.SetComponentLevel((NLog::EComponent) i, (NLog::ELevel) options.LogLevel_); + } + Initialized = true; + } + + void EnsureLoggingInitialized() { + if (Initialized.load()) { + return; + } + InitLogging(TLoggingOptions()); + } + +} +} diff --git a/yql/essentials/public/purecalc/common/logger_init.h b/yql/essentials/public/purecalc/common/logger_init.h new file mode 100644 index 00000000000..039cbd44118 --- /dev/null +++ b/yql/essentials/public/purecalc/common/logger_init.h @@ -0,0 +1,10 @@ +#pragma once + +#include "interface.h" + +namespace NYql { + namespace NPureCalc { + void InitLogging(const TLoggingOptions& options); + void EnsureLoggingInitialized(); + } +} diff --git a/yql/essentials/public/purecalc/common/names.cpp b/yql/essentials/public/purecalc/common/names.cpp new file mode 100644 index 00000000000..5e8412a7b22 --- /dev/null +++ b/yql/essentials/public/purecalc/common/names.cpp @@ -0,0 +1,19 @@ +#include "names.h" + +#include <util/generic/strbuf.h> + +namespace NYql::NPureCalc { + const TStringBuf PurecalcSysColumnsPrefix = "_yql_sys_"; + const TStringBuf PurecalcSysColumnTablePath = "_yql_sys_tablepath"; + const TStringBuf PurecalcBlockColumnLength = "_yql_block_length"; + + const TStringBuf PurecalcDefaultCluster = "view"; + const TStringBuf PurecalcDefaultService = "data"; + + const TStringBuf PurecalcInputCallableName = "Self"; + const TStringBuf PurecalcInputTablePrefix = "Input"; + + const TStringBuf PurecalcBlockInputCallableName = "BlockSelf"; + + const TStringBuf PurecalcUdfModulePrefix = "<purecalc>::"; +} diff --git a/yql/essentials/public/purecalc/common/names.h b/yql/essentials/public/purecalc/common/names.h new file mode 100644 index 00000000000..b19c15ca4fe --- /dev/null +++ b/yql/essentials/public/purecalc/common/names.h @@ -0,0 +1,19 @@ +#pragma once + +#include <util/generic/fwd.h> + +namespace NYql::NPureCalc { + extern const TStringBuf PurecalcSysColumnsPrefix; + extern const TStringBuf PurecalcSysColumnTablePath; + extern const TStringBuf PurecalcBlockColumnLength; + + extern const TStringBuf PurecalcDefaultCluster; + extern const TStringBuf PurecalcDefaultService; + + extern const TStringBuf PurecalcInputCallableName; + extern const TStringBuf PurecalcInputTablePrefix; + + extern const TStringBuf PurecalcBlockInputCallableName; + + extern const TStringBuf PurecalcUdfModulePrefix; +} diff --git a/yql/essentials/public/purecalc/common/no_llvm/ya.make b/yql/essentials/public/purecalc/common/no_llvm/ya.make new file mode 100644 index 00000000000..96820516b77 --- /dev/null +++ b/yql/essentials/public/purecalc/common/no_llvm/ya.make @@ -0,0 +1,18 @@ +LIBRARY() + +INCLUDE(../ya.make.inc) + +PEERDIR( + contrib/ydb/library/yql/providers/yt/codec/codegen/no_llvm + yql/essentials/providers/config + yql/essentials/minikql/computation/no_llvm + yql/essentials/minikql/invoke_builtins/no_llvm + yql/essentials/minikql/comp_nodes/no_llvm + yql/essentials/minikql/codegen/no_llvm + yql/essentials/parser/pg_wrapper + yql/essentials/parser/pg_wrapper/interface + yql/essentials/sql/pg +) + +END() + diff --git a/yql/essentials/public/purecalc/common/processor_mode.cpp b/yql/essentials/public/purecalc/common/processor_mode.cpp new file mode 100644 index 00000000000..957cc2d7f42 --- /dev/null +++ b/yql/essentials/public/purecalc/common/processor_mode.cpp @@ -0,0 +1 @@ +#include "processor_mode.h" diff --git a/yql/essentials/public/purecalc/common/processor_mode.h b/yql/essentials/public/purecalc/common/processor_mode.h new file mode 100644 index 00000000000..9bec87cadc9 --- /dev/null +++ b/yql/essentials/public/purecalc/common/processor_mode.h @@ -0,0 +1,11 @@ +#pragma once + +namespace NYql { + namespace NPureCalc { + enum class EProcessorMode { + PullList, + PullStream, + PushStream + }; + } +} diff --git a/yql/essentials/public/purecalc/common/program_factory.cpp b/yql/essentials/public/purecalc/common/program_factory.cpp new file mode 100644 index 00000000000..8452dc3d003 --- /dev/null +++ b/yql/essentials/public/purecalc/common/program_factory.cpp @@ -0,0 +1,158 @@ +#include "program_factory.h" +#include "logger_init.h" +#include "names.h" +#include "worker_factory.h" + +#include <yql/essentials/utils/log/log.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +TProgramFactory::TProgramFactory(const TProgramFactoryOptions& options) + : Options_(options) + , ExprOutputStream_(Options_.ExprOutputStream) + , CountersProvider_(nullptr) +{ + EnsureLoggingInitialized(); + + if (!TryFromString(Options_.BlockEngineSettings, BlockEngineMode_)) { + ythrow TCompileError("", "") << "Unknown BlockEngineSettings value: expected " + << GetEnumAllNames<EBlockEngineMode>() + << ", but got: " + << Options_.BlockEngineSettings; + } + + NUserData::TUserData::UserDataToLibraries(Options_.UserData_, Modules_); + + UserData_ = GetYqlModuleResolver(ExprContext_, ModuleResolver_, Options_.UserData_, {}, {}); + + if (!ModuleResolver_) { + ythrow TCompileError("", ExprContext_.IssueManager.GetIssues().ToString()) << "failed to compile modules"; + } + + TVector<TString> UDFsPaths; + for (const auto& item: Options_.UserData_) { + if ( + item.Type_ == NUserData::EType::UDF && + item.Disposition_ == NUserData::EDisposition::FILESYSTEM + ) { + UDFsPaths.push_back(item.Content_); + } + } + + if (!Options_.UdfsDir_.empty()) { + NKikimr::NMiniKQL::FindUdfsInDir(Options_.UdfsDir_, &UDFsPaths); + } + + FuncRegistry_ = NKikimr::NMiniKQL::CreateFunctionRegistry( + &NYql::NBacktrace::KikimrBackTrace, NKikimr::NMiniKQL::CreateBuiltinRegistry(), false, UDFsPaths)->Clone(); + + NKikimr::NMiniKQL::FillStaticModules(*FuncRegistry_); +} + +TProgramFactory::~TProgramFactory() { +} + +void TProgramFactory::AddUdfModule( + const TStringBuf& moduleName, + NKikimr::NUdf::TUniquePtr<NKikimr::NUdf::IUdfModule>&& module +) { + FuncRegistry_->AddModule( + TString::Join(PurecalcUdfModulePrefix, moduleName), moduleName, std::move(module) + ); +} + +void TProgramFactory::SetCountersProvider(NKikimr::NUdf::ICountersProvider* provider) { + CountersProvider_ = provider; +} + +IPullStreamWorkerFactoryPtr TProgramFactory::MakePullStreamWorkerFactory( + const TInputSpecBase& inputSpec, + const TOutputSpecBase& outputSpec, + TString query, + ETranslationMode mode, + ui16 syntaxVersion +) { + return std::make_shared<TPullStreamWorkerFactory>(TWorkerFactoryOptions( + TIntrusivePtr<TProgramFactory>(this), + inputSpec, + outputSpec, + query, + FuncRegistry_, + ModuleResolver_, + UserData_, + Modules_, + Options_.LLVMSettings, + BlockEngineMode_, + ExprOutputStream_, + CountersProvider_, + mode, + syntaxVersion, + Options_.NativeYtTypeFlags, + Options_.DeterministicTimeProviderSeed, + Options_.UseSystemColumns, + Options_.UseWorkerPool + )); +} + +IPullListWorkerFactoryPtr TProgramFactory::MakePullListWorkerFactory( + const TInputSpecBase& inputSpec, + const TOutputSpecBase& outputSpec, + TString query, + ETranslationMode mode, + ui16 syntaxVersion +) { + return std::make_shared<TPullListWorkerFactory>(TWorkerFactoryOptions( + TIntrusivePtr<TProgramFactory>(this), + inputSpec, + outputSpec, + query, + FuncRegistry_, + ModuleResolver_, + UserData_, + Modules_, + Options_.LLVMSettings, + BlockEngineMode_, + ExprOutputStream_, + CountersProvider_, + mode, + syntaxVersion, + Options_.NativeYtTypeFlags, + Options_.DeterministicTimeProviderSeed, + Options_.UseSystemColumns, + Options_.UseWorkerPool + )); +} + +IPushStreamWorkerFactoryPtr TProgramFactory::MakePushStreamWorkerFactory( + const TInputSpecBase& inputSpec, + const TOutputSpecBase& outputSpec, + TString query, + ETranslationMode mode, + ui16 syntaxVersion +) { + if (inputSpec.GetSchemas().size() > 1) { + ythrow yexception() << "push stream mode doesn't support several inputs"; + } + + return std::make_shared<TPushStreamWorkerFactory>(TWorkerFactoryOptions( + TIntrusivePtr<TProgramFactory>(this), + inputSpec, + outputSpec, + query, + FuncRegistry_, + ModuleResolver_, + UserData_, + Modules_, + Options_.LLVMSettings, + BlockEngineMode_, + ExprOutputStream_, + CountersProvider_, + mode, + syntaxVersion, + Options_.NativeYtTypeFlags, + Options_.DeterministicTimeProviderSeed, + Options_.UseSystemColumns, + Options_.UseWorkerPool + )); +} diff --git a/yql/essentials/public/purecalc/common/program_factory.h b/yql/essentials/public/purecalc/common/program_factory.h new file mode 100644 index 00000000000..278d3e05a6a --- /dev/null +++ b/yql/essentials/public/purecalc/common/program_factory.h @@ -0,0 +1,48 @@ +#pragma once + +#include "interface.h" + +#include <yql/essentials/utils/backtrace/backtrace.h> +#include <yql/essentials/core/services/mounts/yql_mounts.h> + +#include <yql/essentials/ast/yql_expr.h> +#include <yql/essentials/core/yql_user_data.h> +#include <yql/essentials/minikql/mkql_function_registry.h> +#include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h> + +#include <util/generic/function.h> +#include <util/generic/ptr.h> +#include <util/generic/strbuf.h> + +namespace NYql { + namespace NPureCalc { + class TProgramFactory: public IProgramFactory { + private: + TProgramFactoryOptions Options_; + TExprContext ExprContext_; + TIntrusivePtr<NKikimr::NMiniKQL::IMutableFunctionRegistry> FuncRegistry_; + IModuleResolver::TPtr ModuleResolver_; + TUserDataTable UserData_; + EBlockEngineMode BlockEngineMode_; + IOutputStream* ExprOutputStream_; + THashMap<TString, TString> Modules_; + NKikimr::NUdf::ICountersProvider* CountersProvider_; + + public: + explicit TProgramFactory(const TProgramFactoryOptions&); + ~TProgramFactory() override; + + public: + void AddUdfModule( + const TStringBuf& moduleName, + NKikimr::NUdf::TUniquePtr<NKikimr::NUdf::IUdfModule>&& module + ) override; + + void SetCountersProvider(NKikimr::NUdf::ICountersProvider* provider) override; + + IPullStreamWorkerFactoryPtr MakePullStreamWorkerFactory(const TInputSpecBase&, const TOutputSpecBase&, TString, ETranslationMode, ui16) override; + IPullListWorkerFactoryPtr MakePullListWorkerFactory(const TInputSpecBase&, const TOutputSpecBase&, TString, ETranslationMode, ui16) override; + IPushStreamWorkerFactoryPtr MakePushStreamWorkerFactory(const TInputSpecBase&, const TOutputSpecBase&, TString, ETranslationMode, ui16) override; + }; + } +} diff --git a/yql/essentials/public/purecalc/common/transformations/align_output_schema.cpp b/yql/essentials/public/purecalc/common/transformations/align_output_schema.cpp new file mode 100644 index 00000000000..16cbeeabcc5 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/align_output_schema.cpp @@ -0,0 +1,122 @@ +#include "align_output_schema.h" + +#include <yql/essentials/public/purecalc/common/names.h> +#include <yql/essentials/public/purecalc/common/type_from_schema.h> +#include <yql/essentials/public/purecalc/common/transformations/utils.h> + +#include <yql/essentials/core/yql_expr_type_annotation.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +namespace { + class TOutputAligner : public TSyncTransformerBase { + private: + const TTypeAnnotationNode* OutputStruct_; + bool AcceptsBlocks_; + EProcessorMode ProcessorMode_; + + public: + explicit TOutputAligner( + const TTypeAnnotationNode* outputStruct, + bool acceptsBlocks, + EProcessorMode processorMode + ) + : OutputStruct_(outputStruct) + , AcceptsBlocks_(acceptsBlocks) + , ProcessorMode_(processorMode) + { + } + + public: + TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final { + output = input; + + const auto* expectedType = MakeExpectedType(ctx); + const auto* expectedItemType = MakeExpectedItemType(); + const auto* actualType = MakeActualType(input); + const auto* actualItemType = MakeActualItemType(input); + + // XXX: Tweak the obtained expression type, is the spec supports blocks: + // 1. Remove "_yql_block_length" attribute, since it's for internal usage. + // 2. Strip block container from the type to store its internal type. + if (AcceptsBlocks_) { + Y_ENSURE(actualItemType->GetKind() == ETypeAnnotationKind::Struct); + actualItemType = UnwrapBlockStruct(actualItemType->Cast<TStructExprType>(), ctx); + if (ProcessorMode_ == EProcessorMode::PullList) { + actualType = ctx.MakeType<TListExprType>(actualItemType); + } else { + actualType = ctx.MakeType<TStreamExprType>(actualItemType); + } + } + + if (!ValidateOutputType(actualItemType, expectedItemType, ctx)) { + return TStatus::Error; + } + + if (!expectedType) { + return TStatus::Ok; + } + + auto status = TryConvertTo(output, *actualType, *expectedType, ctx); + + if (status.Level == IGraphTransformer::TStatus::Repeat) { + status = IGraphTransformer::TStatus(IGraphTransformer::TStatus::Repeat, true); + } + + return status; + } + + void Rewind() final { + } + + private: + const TTypeAnnotationNode* MakeExpectedType(TExprContext& ctx) { + if (!OutputStruct_) { + return nullptr; + } + + switch (ProcessorMode_) { + case EProcessorMode::PullList: + return ctx.MakeType<TListExprType>(OutputStruct_); + case EProcessorMode::PullStream: + case EProcessorMode::PushStream: + return ctx.MakeType<TStreamExprType>(OutputStruct_); + } + + Y_ABORT("Unexpected"); + } + + const TTypeAnnotationNode* MakeExpectedItemType() { + return OutputStruct_; + } + + const TTypeAnnotationNode* MakeActualType(TExprNode::TPtr& input) { + return input->GetTypeAnn(); + } + + const TTypeAnnotationNode* MakeActualItemType(TExprNode::TPtr& input) { + auto actualType = MakeActualType(input); + switch (actualType->GetKind()) { + case ETypeAnnotationKind::Stream: + Y_ENSURE(ProcessorMode_ != EProcessorMode::PullList, + "processor mode mismatches the actual container type"); + return actualType->Cast<TStreamExprType>()->GetItemType(); + case ETypeAnnotationKind::List: + Y_ENSURE(ProcessorMode_ == EProcessorMode::PullList, + "processor mode mismatches the actual container type"); + return actualType->Cast<TListExprType>()->GetItemType(); + default: + Y_ABORT("unexpected return type"); + } + } + }; +} + +TAutoPtr<IGraphTransformer> NYql::NPureCalc::MakeOutputAligner( + const TTypeAnnotationNode* outputStruct, + bool acceptsBlocks, + EProcessorMode processorMode +) { + return new TOutputAligner(outputStruct, acceptsBlocks, processorMode); +} diff --git a/yql/essentials/public/purecalc/common/transformations/align_output_schema.h b/yql/essentials/public/purecalc/common/transformations/align_output_schema.h new file mode 100644 index 00000000000..da673aaede1 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/align_output_schema.h @@ -0,0 +1,25 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/processor_mode.h> + +#include <yql/essentials/core/yql_graph_transformer.h> +#include <yql/essentials/core/yql_type_annotation.h> + +namespace NYql { + namespace NPureCalc { + /** + * A transformer which converts an output type of the expression to the given type or reports an error. + * + * @param outputStruct destination output struct type. + * @param acceptsBlocks indicates, whether the output type need to be + * preprocessed. + * @param processorMode specifies the top-most container of the result. + * @return a graph transformer for type alignment. + */ + TAutoPtr<IGraphTransformer> MakeOutputAligner( + const TTypeAnnotationNode* outputStruct, + bool acceptsBlocks, + EProcessorMode processorMode + ); + } +} diff --git a/yql/essentials/public/purecalc/common/transformations/extract_used_columns.cpp b/yql/essentials/public/purecalc/common/transformations/extract_used_columns.cpp new file mode 100644 index 00000000000..9ff7a0df638 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/extract_used_columns.cpp @@ -0,0 +1,96 @@ +#include "extract_used_columns.h" + +#include <yql/essentials/public/purecalc/common/inspect_input.h> + +#include <yql/essentials/core/yql_expr_optimize.h> +#include <yql/essentials/core/expr_nodes/yql_expr_nodes.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +namespace { + class TUsedColumnsExtractor : public TSyncTransformerBase { + private: + TVector<THashSet<TString>>* const Destination_; + const TVector<THashSet<TString>>& AllColumns_; + TString NodeName_; + + bool CalculatedUsedFields_ = false; + + public: + TUsedColumnsExtractor( + TVector<THashSet<TString>>* destination, + const TVector<THashSet<TString>>& allColumns, + TString nodeName + ) + : Destination_(destination) + , AllColumns_(allColumns) + , NodeName_(std::move(nodeName)) + { + } + + TUsedColumnsExtractor(TVector<THashSet<TString>>*, TVector<THashSet<TString>>&&, TString) = delete; + + public: + TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final { + output = input; + + if (CalculatedUsedFields_) { + return IGraphTransformer::TStatus::Ok; + } + + bool hasError = false; + + *Destination_ = AllColumns_; + + VisitExpr(input, [&](const TExprNode::TPtr& inputExpr) { + NNodes::TExprBase node(inputExpr); + if (auto maybeExtract = node.Maybe<NNodes::TCoExtractMembers>()) { + auto extract = maybeExtract.Cast(); + const auto& arg = extract.Input().Ref(); + if (arg.IsCallable(NodeName_)) { + ui32 inputIndex; + if (!TryFetchInputIndexFromSelf(arg, ctx, AllColumns_.size(), inputIndex)) { + hasError = true; + return false; + } + + YQL_ENSURE(inputIndex < AllColumns_.size()); + + auto& destinationColumnsSet = (*Destination_)[inputIndex]; + const auto& allColumnsSet = AllColumns_[inputIndex]; + + destinationColumnsSet.clear(); + for (const auto& columnAtom : extract.Members()) { + TString name = TString(columnAtom.Value()); + YQL_ENSURE(allColumnsSet.contains(name), "unexpected column in the input struct"); + destinationColumnsSet.insert(name); + } + } + } + + return true; + }); + + if (hasError) { + return IGraphTransformer::TStatus::Error; + } + + CalculatedUsedFields_ = true; + + return IGraphTransformer::TStatus::Ok; + } + + void Rewind() final { + CalculatedUsedFields_ = false; + } + }; +} + +TAutoPtr<IGraphTransformer> NYql::NPureCalc::MakeUsedColumnsExtractor( + TVector<THashSet<TString>>* destination, + const TVector<THashSet<TString>>& allColumns, + const TString& nodeName +) { + return new TUsedColumnsExtractor(destination, allColumns, nodeName); +} diff --git a/yql/essentials/public/purecalc/common/transformations/extract_used_columns.h b/yql/essentials/public/purecalc/common/transformations/extract_used_columns.h new file mode 100644 index 00000000000..d0850e28b59 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/extract_used_columns.h @@ -0,0 +1,29 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/names.h> + +#include <yql/essentials/core/yql_graph_transformer.h> +#include <yql/essentials/core/yql_type_annotation.h> + +#include <util/generic/hash_set.h> +#include <util/generic/string.h> + +namespace NYql { + namespace NPureCalc { + /** + * Make transformation which builds sets of input columns from the given expression. + * + * @param destination a vector of string sets which will be populated with column names sets when + * transformation pipeline is launched. This pointer should contain a valid + * TVector<THashSet> instance. The transformation will overwrite its contents. + * @param allColumns vector of sets with all available columns for each input. + * @param nodeName name of the callable used to get input data, e.g. `Self`. + * @return an extractor which scans an input structs contents and populates destination. + */ + TAutoPtr<IGraphTransformer> MakeUsedColumnsExtractor( + TVector<THashSet<TString>>* destination, + const TVector<THashSet<TString>>& allColumns, + const TString& nodeName = TString{PurecalcInputCallableName} + ); + } +} diff --git a/yql/essentials/public/purecalc/common/transformations/output_columns_filter.cpp b/yql/essentials/public/purecalc/common/transformations/output_columns_filter.cpp new file mode 100644 index 00000000000..04181db7c83 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/output_columns_filter.cpp @@ -0,0 +1,100 @@ +#include "output_columns_filter.h" + +#include <yql/essentials/core/yql_expr_type_annotation.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +namespace { + class TOutputColumnsFilter: public TSyncTransformerBase { + private: + TMaybe<THashSet<TString>> Filter_; + bool Fired_; + + public: + explicit TOutputColumnsFilter(TMaybe<THashSet<TString>> filter) + : Filter_(std::move(filter)) + , Fired_(false) + { + } + + public: + void Rewind() override { + Fired_ = false; + } + + TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final { + output = input; + + if (Fired_ || Filter_.Empty()) { + return IGraphTransformer::TStatus::Ok; + } + + const TTypeAnnotationNode* returnType = output->GetTypeAnn(); + const TTypeAnnotationNode* returnItemType = nullptr; + switch (returnType->GetKind()) { + case ETypeAnnotationKind::Stream: + returnItemType = returnType->Cast<TStreamExprType>()->GetItemType(); + break; + case ETypeAnnotationKind::List: + returnItemType = returnType->Cast<TListExprType>()->GetItemType(); + break; + default: + Y_ABORT("unexpected return type"); + } + + if (returnItemType->GetKind() != ETypeAnnotationKind::Struct) { + ctx.AddError(TIssue(ctx.GetPosition(output->Pos()), "columns filter only supported for single-output programs")); + } + + const auto* returnItemStruct = returnItemType->Cast<TStructExprType>(); + + auto arg = ctx.NewArgument(TPositionHandle(), "row"); + TExprNode::TListType asStructItems; + for (const auto& x : returnItemStruct->GetItems()) { + TExprNode::TPtr value; + if (Filter_->contains(x->GetName())) { + value = ctx.Builder({}) + .Callable("Member") + .Add(0, arg) + .Atom(1, x->GetName()) + .Seal() + .Build(); + } else { + auto type = x->GetItemType(); + value = ctx.Builder({}) + .Callable(type->GetKind() == ETypeAnnotationKind::Optional ? "Nothing" : "Default") + .Add(0, ExpandType({}, *type, ctx)) + .Seal() + .Build(); + } + + auto item = ctx.Builder({}) + .List() + .Atom(0, x->GetName()) + .Add(1, value) + .Seal() + .Build(); + + asStructItems.push_back(item); + } + + auto body = ctx.NewCallable(TPositionHandle(), "AsStruct", std::move(asStructItems)); + auto lambda = ctx.NewLambda(TPositionHandle(), ctx.NewArguments(TPositionHandle(), {arg}), std::move(body)); + output = ctx.Builder(TPositionHandle()) + .Callable("Map") + .Add(0, output) + .Add(1, lambda) + .Seal() + .Build(); + + Fired_ = true; + + return IGraphTransformer::TStatus(IGraphTransformer::TStatus::Repeat, true); + } + }; +} + +TAutoPtr<IGraphTransformer> NYql::NPureCalc::MakeOutputColumnsFilter(const TMaybe<THashSet<TString>>& columns) { + return new TOutputColumnsFilter(columns); +} diff --git a/yql/essentials/public/purecalc/common/transformations/output_columns_filter.h b/yql/essentials/public/purecalc/common/transformations/output_columns_filter.h new file mode 100644 index 00000000000..85302d82feb --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/output_columns_filter.h @@ -0,0 +1,18 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/processor_mode.h> + +#include <yql/essentials/core/yql_graph_transformer.h> +#include <yql/essentials/core/yql_type_annotation.h> + +namespace NYql { + namespace NPureCalc { + /** + * A transformer which removes unwanted columns from output. + * + * @param columns remove all columns that are not in this set. + * @return a graph transformer for filtering output. + */ + TAutoPtr<IGraphTransformer> MakeOutputColumnsFilter(const TMaybe<THashSet<TString>>& columns); + } +} diff --git a/yql/essentials/public/purecalc/common/transformations/replace_table_reads.cpp b/yql/essentials/public/purecalc/common/transformations/replace_table_reads.cpp new file mode 100644 index 00000000000..141e92baf28 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/replace_table_reads.cpp @@ -0,0 +1,247 @@ +#include "replace_table_reads.h" + +#include <yql/essentials/public/purecalc/common/names.h> +#include <yql/essentials/public/purecalc/common/transformations/utils.h> + +#include <yql/essentials/core/yql_expr_optimize.h> +#include <yql/essentials/core/yql_expr_type_annotation.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +namespace { + class TTableReadsReplacer: public TSyncTransformerBase { + private: + const TVector<const TStructExprType*>& InputStructs_; + bool UseSystemColumns_; + EProcessorMode ProcessorMode_; + TString CallableName_; + TString TablePrefix_; + bool Complete_ = false; + + public: + explicit TTableReadsReplacer( + const TVector<const TStructExprType*>& inputStructs, + bool useSystemColumns, + EProcessorMode processorMode, + TString inputNodeName, + TString tablePrefix + ) + : InputStructs_(inputStructs) + , UseSystemColumns_(useSystemColumns) + , ProcessorMode_(processorMode) + , CallableName_(std::move(inputNodeName)) + , TablePrefix_(std::move(tablePrefix)) + { + } + + TTableReadsReplacer(TVector<const TStructExprType*>&&, TString, TString) = delete; + + public: + TStatus DoTransform(const TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final { + output = input; + if (Complete_) { + return TStatus::Ok; + } + + TOptimizeExprSettings settings(nullptr); + + auto status = OptimizeExpr(input, output, [&](const TExprNode::TPtr& node, TExprContext& ctx) -> TExprNode::TPtr { + if (node->IsCallable(NNodes::TCoRight::CallableName())) { + TIssueScopeGuard issueScope(ctx.IssueManager, [&]() { + return new TIssue(ctx.GetPosition(node->Pos()), TStringBuilder() << "At function: " << node->Content()); + }); + + if (!EnsureMinArgsCount(*node, 1, ctx)) { + return nullptr; + } + + if (node->Child(0)->IsCallable(NNodes::TCoCons::CallableName())) { + return node; + } + + if (!node->Child(0)->IsCallable(NNodes::TCoRead::CallableName())) { + ctx.AddError(TIssue(ctx.GetPosition(node->Child(0)->Pos()), TStringBuilder() << "Expected Read!")); + return nullptr; + } + + return BuildInputFromRead(node->Pos(), node->ChildPtr(0), ctx); + } else if (node->IsCallable(NNodes::TCoLeft::CallableName())) { + TIssueScopeGuard issueScope(ctx.IssueManager, [&]() { + return new TIssue(ctx.GetPosition(node->Pos()), TStringBuilder() << "At function: " << node->Content()); + }); + + if (!EnsureMinArgsCount(*node, 1, ctx)) { + return nullptr; + } + + if (!node->Child(0)->IsCallable(NNodes::TCoRead::CallableName())) { + ctx.AddError(TIssue(ctx.GetPosition(node->Child(0)->Pos()), TStringBuilder() << "Expected Read!")); + return nullptr; + } + + return node->Child(0)->HeadPtr(); + } + + return node; + }, ctx, settings); + + if (status.Level == TStatus::Ok) { + Complete_ = true; + } + return status; + } + + void Rewind() override { + Complete_ = false; + } + + private: + TExprNode::TPtr BuildInputFromRead(TPositionHandle replacePos, const TExprNode::TPtr& node, TExprContext& ctx) { + TIssueScopeGuard issueScope(ctx.IssueManager, [&]() { + return MakeIntrusive<TIssue>(ctx.GetPosition(node->Pos()), TStringBuilder() << "At function: " << node->Content()); + }); + + if (!EnsureMinArgsCount(*node, 3, ctx)) { + return nullptr; + } + + const auto source = node->ChildPtr(2); + if (source->IsCallable(NNodes::TCoKey::CallableName())) { + return BuildInputFromKey(replacePos, source, ctx); + } + if (source->IsCallable("DataTables")) { + return BuildInputFromDataTables(replacePos, source, ctx); + } + + ctx.AddError(TIssue(ctx.GetPosition(source->Pos()), TStringBuilder() << "Unsupported read source: " << source->Content())); + + return nullptr; + } + + TExprNode::TPtr BuildInputFromKey(TPositionHandle replacePos, const TExprNode::TPtr& node, TExprContext& ctx) { + TIssueScopeGuard issueScope(ctx.IssueManager, [&]() { + return MakeIntrusive<TIssue>(ctx.GetPosition(node->Pos()), TStringBuilder() << "At function: " << node->Content()); + }); + + ui32 inputIndex; + TExprNode::TPtr inputTableName; + + if (!TryFetchInputIndexFromKey(node, ctx, inputIndex, inputTableName)) { + return nullptr; + } + + YQL_ENSURE(inputTableName->IsCallable(NNodes::TCoString::CallableName())); + + auto inputNode = ctx.Builder(replacePos) + .Callable(CallableName_) + .Atom(0, ToString(inputIndex)) + .Seal() + .Build(); + + if (inputNode->IsCallable(PurecalcBlockInputCallableName)) { + const auto inputStruct = InputStructs_[inputIndex]->Cast<TStructExprType>(); + const auto blocksLambda = NodeFromBlocks(replacePos, inputStruct, ctx); + bool wrapLMap = ProcessorMode_ == EProcessorMode::PullList; + inputNode = ApplyToIterable(replacePos, inputNode, blocksLambda, wrapLMap, ctx); + } + + if (UseSystemColumns_) { + auto mapLambda = ctx.Builder(replacePos) + .Lambda() + .Param("row") + .Callable(0, NNodes::TCoAddMember::CallableName()) + .Arg(0, "row") + .Atom(1, PurecalcSysColumnTablePath) + .Add(2, inputTableName) + .Seal() + .Seal() + .Build(); + + return ctx.Builder(replacePos) + .Callable(NNodes::TCoMap::CallableName()) + .Add(0, std::move(inputNode)) + .Add(1, std::move(mapLambda)) + .Seal() + .Build(); + } + + return inputNode; + } + + TExprNode::TPtr BuildInputFromDataTables(TPositionHandle replacePos, const TExprNode::TPtr& node, TExprContext& ctx) { + TIssueScopeGuard issueScope(ctx.IssueManager, [&]() { + return MakeIntrusive<TIssue>(ctx.GetPosition(node->Pos()), TStringBuilder() << "At function: " << node->Content()); + }); + + if (InputStructs_.empty()) { + ctx.AddError(TIssue(ctx.GetPosition(node->Pos()), "No inputs provided by input spec")); + return nullptr; + } + + if (!EnsureArgsCount(*node, 0, ctx)) { + return nullptr; + } + + auto builder = ctx.Builder(replacePos); + + if (InputStructs_.size() > 1) { + auto listBuilder = builder.List(); + + for (ui32 i = 0; i < InputStructs_.size(); ++i) { + listBuilder.Callable(i, CallableName_).Atom(0, ToString(i)).Seal(); + } + + return listBuilder.Seal().Build(); + } + + return builder.Callable(CallableName_).Atom(0, "0").Seal().Build(); + } + + bool TryFetchInputIndexFromKey(const TExprNode::TPtr& node, TExprContext& ctx, ui32& resultIndex, TExprNode::TPtr& resultTableName) { + if (!EnsureArgsCount(*node, 1, ctx)) { + return false; + } + + const auto* keyArg = node->Child(0); + if (!keyArg->IsList() || keyArg->ChildrenSize() != 2 || !keyArg->Child(0)->IsAtom("table") || + !keyArg->Child(1)->IsCallable(NNodes::TCoString::CallableName())) + { + ctx.AddError(TIssue(ctx.GetPosition(keyArg->Pos()), "Expected single table name")); + return false; + } + + resultTableName = keyArg->ChildPtr(1); + + auto tableName = resultTableName->Child(0)->Content(); + + if (!tableName.StartsWith(TablePrefix_)) { + ctx.AddError(TIssue(ctx.GetPosition(resultTableName->Child(0)->Pos()), + TStringBuilder() << "Invalid table name " << TString{tableName}.Quote() << ": prefix must be " << TablePrefix_.Quote())); + return false; + } + + tableName.SkipPrefix(TablePrefix_); + + if (!tableName) { + resultIndex = 0; + } else if (!TryFromString(tableName, resultIndex)) { + ctx.AddError(TIssue(ctx.GetPosition(resultTableName->Child(0)->Pos()), + TStringBuilder() << "Invalid table name " << TString{tableName}.Quote() << ": suffix must be UI32 number")); + return false; + } + + return true; + } + }; +} + +TAutoPtr<IGraphTransformer> NYql::NPureCalc::MakeTableReadsReplacer( + const TVector<const TStructExprType*>& inputStructs, + bool useSystemColumns, + EProcessorMode processorMode, + TString callableName, + TString tablePrefix +) { + return new TTableReadsReplacer(inputStructs, useSystemColumns, processorMode, std::move(callableName), std::move(tablePrefix)); +} diff --git a/yql/essentials/public/purecalc/common/transformations/replace_table_reads.h b/yql/essentials/public/purecalc/common/transformations/replace_table_reads.h new file mode 100644 index 00000000000..33bc7174ac4 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/replace_table_reads.h @@ -0,0 +1,30 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/names.h> +#include <yql/essentials/public/purecalc/common/processor_mode.h> + +#include <yql/essentials/core/yql_graph_transformer.h> + +namespace NYql::NPureCalc { + /** + * SQL translation would generate a standard Read! call to read each input table. It will than generate + * a Right! call to get the table data from a tuple returned by Read!. This transformation replaces any Right! + * call with a call to special function used to get input data. + * + * Each table name must starts with the specified prefix and ends with an index of program input (e.g. `Input0`). + * Name without numeric suffix is an alias for the first input. + * + * @param inputStructs types of each input. + * @param useSystemColumns whether to allow special system columns in input structs. + * @param callableName name of the special callable used to get input data (e.g. `Self`). + * @param tablePrefix required prefix for all table names (e.g. `Input`). + * @param return a graph transformer for replacing table reads. + */ + TAutoPtr<IGraphTransformer> MakeTableReadsReplacer( + const TVector<const TStructExprType*>& inputStructs, + bool useSystemColumns, + EProcessorMode processorMode, + TString callableName = TString{PurecalcInputCallableName}, + TString tablePrefix = TString{PurecalcInputTablePrefix} + ); +} diff --git a/yql/essentials/public/purecalc/common/transformations/root_to_blocks.cpp b/yql/essentials/public/purecalc/common/transformations/root_to_blocks.cpp new file mode 100644 index 00000000000..07c959d1077 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/root_to_blocks.cpp @@ -0,0 +1,65 @@ +#include "root_to_blocks.h" + +#include <yql/essentials/public/purecalc/common/transformations/utils.h> + +#include <yql/essentials/core/yql_expr_type_annotation.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +namespace { + +class TRootToBlocks: public TSyncTransformerBase { +private: + bool AcceptsBlocks_; + EProcessorMode ProcessorMode_; + bool Wrapped_; + +public: + explicit TRootToBlocks(bool acceptsBlocks, EProcessorMode processorMode) + : AcceptsBlocks_(acceptsBlocks) + , ProcessorMode_(processorMode) + , Wrapped_(false) + { + } + +public: + void Rewind() override { + Wrapped_ = false; + } + + TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final { + if (Wrapped_ || !AcceptsBlocks_) { + return IGraphTransformer::TStatus::Ok; + } + + const TTypeAnnotationNode* returnItemType; + const TTypeAnnotationNode* returnType = input->GetTypeAnn(); + if (ProcessorMode_ == EProcessorMode::PullList) { + Y_ENSURE(returnType->GetKind() == ETypeAnnotationKind::List); + returnItemType = returnType->Cast<TListExprType>()->GetItemType(); + } else { + Y_ENSURE(returnType->GetKind() == ETypeAnnotationKind::Stream); + returnItemType = returnType->Cast<TStreamExprType>()->GetItemType(); + } + + Y_ENSURE(returnItemType->GetKind() == ETypeAnnotationKind::Struct); + const TStructExprType* structType = returnItemType->Cast<TStructExprType>(); + const auto blocksLambda = NodeToBlocks(input->Pos(), structType, ctx); + bool wrapLMap = ProcessorMode_ == EProcessorMode::PullList; + output = ApplyToIterable(input->Pos(), input, blocksLambda, wrapLMap, ctx); + + Wrapped_ = true; + + return IGraphTransformer::TStatus(IGraphTransformer::TStatus::Repeat, true); + } +}; + +} // namespace + +TAutoPtr<IGraphTransformer> NYql::NPureCalc::MakeRootToBlocks( + bool acceptsBlocks, + EProcessorMode processorMode +) { + return new TRootToBlocks(acceptsBlocks, processorMode); +} diff --git a/yql/essentials/public/purecalc/common/transformations/root_to_blocks.h b/yql/essentials/public/purecalc/common/transformations/root_to_blocks.h new file mode 100644 index 00000000000..13a7a9dfc11 --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/root_to_blocks.h @@ -0,0 +1,22 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/processor_mode.h> + +#include <yql/essentials/core/yql_graph_transformer.h> + +namespace NYql { + namespace NPureCalc { + /** + * A transformer which rewrite the root to respect block types. + * + * @param acceptsBlock allows using this transformer in pipeline and + * skip this phase if no block output is required. + * @param processorMode specifies the top-most container of the result. + * @return a graph transformer for rewriting the root node. + */ + TAutoPtr<IGraphTransformer> MakeRootToBlocks( + bool acceptsBlocks, + EProcessorMode processorMode + ); + } +} diff --git a/yql/essentials/public/purecalc/common/transformations/type_annotation.cpp b/yql/essentials/public/purecalc/common/transformations/type_annotation.cpp new file mode 100644 index 00000000000..3d322e8b32e --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/type_annotation.cpp @@ -0,0 +1,251 @@ +#include "type_annotation.h" + +#include <yql/essentials/public/purecalc/common/interface.h> +#include <yql/essentials/public/purecalc/common/inspect_input.h> +#include <yql/essentials/public/purecalc/common/names.h> +#include <yql/essentials/public/purecalc/common/transformations/utils.h> + +#include <yql/essentials/core/type_ann/type_ann_core.h> +#include <yql/essentials/core/yql_expr_type_annotation.h> + +#include <util/generic/fwd.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +namespace { + class TTypeAnnotatorBase: public TSyncTransformerBase { + public: + using THandler = std::function<TStatus(const TExprNode::TPtr&, TExprNode::TPtr&, TExprContext&)>; + + TTypeAnnotatorBase(TTypeAnnotationContextPtr typeAnnotationContext) + { + OriginalTransformer_.reset(CreateExtCallableTypeAnnotationTransformer(*typeAnnotationContext).Release()); + } + + TStatus DoTransform(TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) final { + if (input->Type() == TExprNode::Callable) { + if (auto handler = Handlers_.FindPtr(input->Content())) { + return (*handler)(input, output, ctx); + } + } + + auto status = OriginalTransformer_->Transform(input, output, ctx); + + YQL_ENSURE(status.Level != IGraphTransformer::TStatus::Async, "Async type check is not supported"); + + return status; + } + + void Rewind() final { + OriginalTransformer_->Rewind(); + } + + protected: + void AddHandler(std::initializer_list<TStringBuf> names, THandler handler) { + for (auto name: names) { + YQL_ENSURE(Handlers_.emplace(name, handler).second, "Duplicate handler for " << name); + } + } + + template <class TDerived> + THandler Hndl(TStatus(TDerived::* handler)(const TExprNode::TPtr&, TExprNode::TPtr&, TExprContext&)) { + return [this, handler] (TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) { + return (static_cast<TDerived*>(this)->*handler)(input, output, ctx); + }; + } + + template <class TDerived> + THandler Hndl(TStatus(TDerived::* handler)(const TExprNode::TPtr&, TExprContext&)) { + return [this, handler] (TExprNode::TPtr input, TExprNode::TPtr& /*output*/, TExprContext& ctx) { + return (static_cast<TDerived*>(this)->*handler)(input, ctx); + }; + } + + private: + std::shared_ptr<IGraphTransformer> OriginalTransformer_; + THashMap<TStringBuf, THandler> Handlers_; + }; + + class TTypeAnnotator : public TTypeAnnotatorBase { + private: + TTypeAnnotationContextPtr TypeAnnotationContext_; + const TVector<const TStructExprType*>& InputStructs_; + TVector<const TStructExprType*>& RawInputTypes_; + EProcessorMode ProcessorMode_; + TString InputNodeName_; + + public: + TTypeAnnotator( + TTypeAnnotationContextPtr typeAnnotationContext, + const TVector<const TStructExprType*>& inputStructs, + TVector<const TStructExprType*>& rawInputTypes, + EProcessorMode processorMode, + TString nodeName + ) + : TTypeAnnotatorBase(typeAnnotationContext) + , TypeAnnotationContext_(typeAnnotationContext) + , InputStructs_(inputStructs) + , RawInputTypes_(rawInputTypes) + , ProcessorMode_(processorMode) + , InputNodeName_(std::move(nodeName)) + { + AddHandler({InputNodeName_}, Hndl(&TTypeAnnotator::HandleInputNode)); + AddHandler({NNodes::TCoTableName::CallableName()}, Hndl(&TTypeAnnotator::HandleTableName)); + AddHandler({NNodes::TCoTablePath::CallableName()}, Hndl(&TTypeAnnotator::HandleTablePath)); + AddHandler({NNodes::TCoHoppingTraits::CallableName()}, Hndl(&TTypeAnnotator::HandleHoppingTraits)); + } + + TTypeAnnotator(TTypeAnnotationContextPtr, TVector<const TStructExprType*>&&, EProcessorMode, TString) = delete; + + private: + TStatus HandleInputNode(const TExprNode::TPtr& input, TExprContext& ctx) { + ui32 inputIndex; + if (!TryFetchInputIndexFromSelf(*input, ctx, InputStructs_.size(), inputIndex)) { + return IGraphTransformer::TStatus::Error; + } + + YQL_ENSURE(inputIndex < InputStructs_.size()); + + auto itemType = InputStructs_[inputIndex]; + + // XXX: Tweak the input expression type, if the spec supports blocks: + // 1. Add "_yql_block_length" attribute for internal usage. + // 2. Add block container to wrap the actual item type. + if (input->IsCallable(PurecalcBlockInputCallableName)) { + itemType = WrapBlockStruct(itemType, ctx); + } + + RawInputTypes_[inputIndex] = itemType; + + TColumnOrder columnOrder; + for (const auto& i : itemType->GetItems()) { + columnOrder.AddColumn(TString(i->GetName())); + } + + if (ProcessorMode_ != EProcessorMode::PullList) { + input->SetTypeAnn(ctx.MakeType<TStreamExprType>(itemType)); + } else { + input->SetTypeAnn(ctx.MakeType<TListExprType>(itemType)); + } + + TypeAnnotationContext_->SetColumnOrder(*input, columnOrder, ctx); + return TStatus::Ok; + } + + TStatus HandleTableName(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { + if (!EnsureMinMaxArgsCount(*input, 1, 2, ctx)) { + return TStatus::Error; + } + + if (input->ChildrenSize() > 1) { + if (!EnsureAtom(input->Tail(), ctx)) { + return TStatus::Error; + } + + if (input->Tail().Content() != PurecalcDefaultService) { + ctx.AddError( + TIssue( + ctx.GetPosition(input->Tail().Pos()), + TStringBuilder() << "Unsupported system: " << input->Tail().Content())); + return TStatus::Error; + } + } + + if (input->Head().IsCallable(NNodes::TCoDependsOn::CallableName())) { + if (!EnsureArgsCount(input->Head(), 1, ctx)) { + return TStatus::Error; + } + + if (!TryBuildTableNameNode(input->Pos(), input->Head().HeadPtr(), output, ctx)) { + return TStatus::Error; + } + } else { + if (!EnsureSpecificDataType(input->Head(), EDataSlot::String, ctx)) { + return TStatus::Error; + } + output = input->HeadPtr(); + } + + return TStatus::Repeat; + } + + TStatus HandleTablePath(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { + if (!EnsureArgsCount(*input, 1, ctx)) { + return TStatus::Error; + } + + if (!EnsureDependsOn(input->Head(), ctx)) { + return TStatus::Error; + } + + if (!EnsureArgsCount(input->Head(), 1, ctx)) { + return TStatus::Error; + } + + if (!TryBuildTableNameNode(input->Pos(), input->Head().HeadPtr(), output, ctx)) { + return TStatus::Error; + } + + return TStatus::Repeat; + } + + TStatus HandleHoppingTraits(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { + Y_UNUSED(output); + if (input->ChildrenSize() == 1) { + auto children = input->ChildrenList(); + auto falseArg = ctx.Builder(input->Pos()) + .Atom("false") + .Seal() + .Build(); + children.emplace_back(falseArg); + input->ChangeChildrenInplace(std::move(children)); + return TStatus::Repeat; + } + + return TStatus::Ok; + } + + private: + bool TryBuildTableNameNode( + TPositionHandle position, const TExprNode::TPtr& row, TExprNode::TPtr& result, TExprContext& ctx) + { + if (!EnsureStructType(*row, ctx)) { + return false; + } + + const auto* structType = row->GetTypeAnn()->Cast<TStructExprType>(); + + if (auto pos = structType->FindItem(PurecalcSysColumnTablePath)) { + if (!EnsureSpecificDataType(row->Pos(), *structType->GetItems()[*pos]->GetItemType(), EDataSlot::String, ctx)) { + return false; + } + + result = ctx.Builder(position) + .Callable(NNodes::TCoMember::CallableName()) + .Add(0, row) + .Atom(1, PurecalcSysColumnTablePath) + .Seal() + .Build(); + } else { + result = ctx.Builder(position) + .Callable(NNodes::TCoString::CallableName()) + .Atom(0, "") + .Seal() + .Build(); + } + + return true; + } + }; +} + +TAutoPtr<IGraphTransformer> NYql::NPureCalc::MakeTypeAnnotationTransformer( + TTypeAnnotationContextPtr typeAnnotationContext, + const TVector<const TStructExprType*>& inputStructs, + TVector<const TStructExprType*>& rawInputTypes, + EProcessorMode processorMode, + const TString& nodeName +) { + return new TTypeAnnotator(typeAnnotationContext, inputStructs, rawInputTypes, processorMode, nodeName); +} diff --git a/yql/essentials/public/purecalc/common/transformations/type_annotation.h b/yql/essentials/public/purecalc/common/transformations/type_annotation.h new file mode 100644 index 00000000000..87649fd231a --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/type_annotation.h @@ -0,0 +1,30 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/names.h> +#include <yql/essentials/public/purecalc/common/processor_mode.h> + +#include <yql/essentials/core/yql_graph_transformer.h> +#include <yql/essentials/core/yql_type_annotation.h> + +namespace NYql { + namespace NPureCalc { + /** + * Build type annotation transformer that is aware of type of the input rows. + * + * @param typeAnnotationContext current context. + * @param inputStructs types of each input. + * @param rawInputStructs container to store the resulting input item type. + * @param processorMode current processor mode. This will affect generated input type, + * e.g. list node or struct node. + * @param nodeName name of the callable used to get input data, e.g. `Self`. + * @return a graph transformer for type annotation. + */ + TAutoPtr<IGraphTransformer> MakeTypeAnnotationTransformer( + TTypeAnnotationContextPtr typeAnnotationContext, + const TVector<const TStructExprType*>& inputStructs, + TVector<const TStructExprType*>& rawInputStructs, + EProcessorMode processorMode, + const TString& nodeName = TString{PurecalcInputCallableName} + ); + } +} diff --git a/yql/essentials/public/purecalc/common/transformations/utils.cpp b/yql/essentials/public/purecalc/common/transformations/utils.cpp new file mode 100644 index 00000000000..4e2da41835c --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/utils.cpp @@ -0,0 +1,179 @@ +#include "utils.h" + +#include <yql/essentials/public/purecalc/common/names.h> +#include <yql/essentials/core/yql_expr_type_annotation.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +TExprNode::TPtr NYql::NPureCalc::NodeFromBlocks( + const TPositionHandle& pos, + const TStructExprType* structType, + TExprContext& ctx +) { + const auto items = structType->GetItems(); + Y_ENSURE(items.size() > 0); + return ctx.Builder(pos) + .Lambda() + .Param("stream") + .Callable(0, "FromFlow") + .Callable(0, "NarrowMap") + .Callable(0, "WideFromBlocks") + .Callable(0, "ExpandMap") + .Callable(0, "ToFlow") + .Arg(0, "stream") + .Seal() + .Lambda(1) + .Param("item") + .Do([&](TExprNodeBuilder& lambda) -> TExprNodeBuilder& { + ui32 i = 0; + for (const auto& item : items) { + lambda.Callable(i++, "Member") + .Arg(0, "item") + .Atom(1, item->GetName()) + .Seal(); + } + lambda.Callable(i, "Member") + .Arg(0, "item") + .Atom(1, PurecalcBlockColumnLength) + .Seal(); + return lambda; + }) + .Seal() + .Seal() + .Seal() + .Lambda(1) + .Params("fields", items.size()) + .Callable("AsStruct") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + ui32 i = 0; + for (const auto& item : items) { + parent.List(i) + .Atom(0, item->GetName()) + .Arg(1, "fields", i++) + .Seal(); + } + return parent; + }) + .Seal() + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); +} + +TExprNode::TPtr NYql::NPureCalc::NodeToBlocks( + const TPositionHandle& pos, + const TStructExprType* structType, + TExprContext& ctx +) { + const auto items = structType->GetItems(); + Y_ENSURE(items.size() > 0); + return ctx.Builder(pos) + .Lambda() + .Param("stream") + .Callable("FromFlow") + .Callable(0, "NarrowMap") + .Callable(0, "WideToBlocks") + .Callable(0, "ExpandMap") + .Callable(0, "ToFlow") + .Arg(0, "stream") + .Seal() + .Lambda(1) + .Param("item") + .Do([&](TExprNodeBuilder& lambda) -> TExprNodeBuilder& { + ui32 i = 0; + for (const auto& item : items) { + lambda.Callable(i++, "Member") + .Arg(0, "item") + .Atom(1, item->GetName()) + .Seal(); + } + return lambda; + }) + .Seal() + .Seal() + .Seal() + .Lambda(1) + .Params("fields", items.size() + 1) + .Callable("AsStruct") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + ui32 i = 0; + for (const auto& item : items) { + parent.List(i) + .Atom(0, item->GetName()) + .Arg(1, "fields", i++) + .Seal(); + } + parent.List(i) + .Atom(0, PurecalcBlockColumnLength) + .Arg(1, "fields", i) + .Seal(); + return parent; + }) + .Seal() + .Seal() + .Seal() + .Seal() + .Seal() + .Build(); +} + +TExprNode::TPtr NYql::NPureCalc::ApplyToIterable( + const TPositionHandle& pos, + const TExprNode::TPtr iterable, + const TExprNode::TPtr lambda, + bool wrapLMap, + TExprContext& ctx +) { + if (wrapLMap) { + return ctx.Builder(pos) + .Callable("LMap") + .Add(0, iterable) + .Lambda(1) + .Param("stream") + .Apply(lambda) + .With(0, "stream") + .Seal() + .Seal() + .Seal() + .Build(); + } else { + return ctx.Builder(pos) + .Apply(lambda) + .With(0, iterable) + .Seal() + .Build(); + } +} + +const TStructExprType* NYql::NPureCalc::WrapBlockStruct( + const TStructExprType* structType, + TExprContext& ctx +) { + TVector<const TItemExprType*> members; + for (const auto& item : structType->GetItems()) { + const auto blockItemType = ctx.MakeType<TBlockExprType>(item->GetItemType()); + members.push_back(ctx.MakeType<TItemExprType>(item->GetName(), blockItemType)); + } + const auto scalarItemType = ctx.MakeType<TScalarExprType>(ctx.MakeType<TDataExprType>(EDataSlot::Uint64)); + members.push_back(ctx.MakeType<TItemExprType>(PurecalcBlockColumnLength, scalarItemType)); + return ctx.MakeType<TStructExprType>(members); +} + +const TStructExprType* NYql::NPureCalc::UnwrapBlockStruct( + const TStructExprType* structType, + TExprContext& ctx +) { + TVector<const TItemExprType*> members; + for (const auto& item : structType->GetItems()) { + if (item->GetName() == PurecalcBlockColumnLength) { + continue; + } + bool isScalarUnused; + const auto blockItemType = GetBlockItemType(*item->GetItemType(), isScalarUnused); + members.push_back(ctx.MakeType<TItemExprType>(item->GetName(), blockItemType)); + } + return ctx.MakeType<TStructExprType>(members); +} diff --git a/yql/essentials/public/purecalc/common/transformations/utils.h b/yql/essentials/public/purecalc/common/transformations/utils.h new file mode 100644 index 00000000000..cc8849b7e3a --- /dev/null +++ b/yql/essentials/public/purecalc/common/transformations/utils.h @@ -0,0 +1,83 @@ +#pragma once + +#include <yql/essentials/core/yql_graph_transformer.h> + +namespace NYql { + namespace NPureCalc { + /** + * A transformer which wraps the given input node with the pipeline + * converting the input type to the block one. + * + * @param pos the position of the given node to be rewritten. + * @param structType the item type of the container provided by the node. + * @param ctx the context to make ExprNode rewrites. + * @return the resulting ExprNode. + */ + TExprNode::TPtr NodeFromBlocks( + const TPositionHandle& pos, + const TStructExprType* structType, + TExprContext& ctx + ); + + /** + * A transformer which wraps the given root node with the pipeline + * converting the output type to the block one. + * + * @param pos the position of the given node to be rewritten. + * @param structType the item type of the container provided by the node. + * @param ctx the context to make ExprNode rewrites. + * @return the resulting ExprNode. + */ + TExprNode::TPtr NodeToBlocks( + const TPositionHandle& pos, + const TStructExprType* structType, + TExprContext& ctx + ); + + /** + * A transformer to apply the given lambda to the given iterable (either + * list or stream). If the iterable is list, the lambda should be passed + * to the <LMap> callable; if the iterable is stream, the lambda should + * be applied right to the iterable. + * + * @param pos the position of the given node to be rewritten. + * @param iterable the node, that provides the iterable to be processed. + * @param lambda the node, that provides lambda to be applied. + * @param wrapLMap indicator to wrap the result with LMap callable. + * @oaram ctx the context to make ExprNode rewrites. + */ + TExprNode::TPtr ApplyToIterable( + const TPositionHandle& pos, + const TExprNode::TPtr iterable, + const TExprNode::TPtr lambda, + bool wrapLMap, + TExprContext& ctx + ); + + /** + * A helper which wraps the items of the given struct with the block + * type container and appends the new item for _yql_block_length column. + * + * @param structType original struct to be wrapped. + * @param ctx the context to make ExprType rewrite. + * @return the new struct with block items. + */ + const TStructExprType* WrapBlockStruct( + const TStructExprType* structType, + TExprContext& ctx + ); + + /** + * A helper which unwraps the block container from the items of the + * given struct and removes the item for _yql_block_length column. + * + * @param structType original struct to be unwrapped. + * @param ctx the context to make ExprType rewrite. + * @return the new struct without block items. + */ + const TStructExprType* UnwrapBlockStruct( + const TStructExprType* structType, + TExprContext& ctx + ); + } +} diff --git a/yql/essentials/public/purecalc/common/type_from_schema.cpp b/yql/essentials/public/purecalc/common/type_from_schema.cpp new file mode 100644 index 00000000000..373283a1a8e --- /dev/null +++ b/yql/essentials/public/purecalc/common/type_from_schema.cpp @@ -0,0 +1,255 @@ +#include "type_from_schema.h" + +#include <library/cpp/yson/node/node_io.h> + +#include <yql/essentials/core/yql_expr_type_annotation.h> +#include <yql/essentials/providers/common/schema/expr/yql_expr_schema.h> + +namespace { + using namespace NYql; + +#define REPORT(...) ctx.AddError(TIssue(TString(TStringBuilder() << __VA_ARGS__))) + + bool CheckStruct(const TStructExprType* got, const TStructExprType* expected, TExprContext& ctx) { + auto status = true; + + if (expected) { + for (const auto* gotNamedItem : got->GetItems()) { + auto expectedIndex = expected->FindItem(gotNamedItem->GetName()); + if (expectedIndex) { + const auto* gotItem = gotNamedItem->GetItemType(); + const auto* expectedItem = expected->GetItems()[*expectedIndex]->GetItemType(); + + auto arg = ctx.NewArgument(TPositionHandle(), "arg"); + auto fieldConversionStatus = TrySilentConvertTo(arg, *gotItem, *expectedItem, ctx); + if (fieldConversionStatus.Level == IGraphTransformer::TStatus::Error) { + REPORT("Item " << TString{gotNamedItem->GetName()}.Quote() << " expected to be " << + *expectedItem << ", but got " << *gotItem); + status = false; + } + } else { + REPORT("Got unexpected item " << TString{gotNamedItem->GetName()}.Quote()); + status = false; + } + } + + for (const auto* expectedNamedItem : expected->GetItems()) { + if (expectedNamedItem->GetItemType()->GetKind() == ETypeAnnotationKind::Optional) { + continue; + } + if (!got->FindItem(expectedNamedItem->GetName())) { + REPORT("Expected item " << TString{expectedNamedItem->GetName()}.Quote()); + status = false; + } + } + } + + return status; + } + + bool CheckVariantContent(const TStructExprType* got, const TStructExprType* expected, TExprContext& ctx) { + auto status = true; + + if (expected) { + for (const auto* gotNamedItem : got->GetItems()) { + if (!expected->FindItem(gotNamedItem->GetName())) { + REPORT("Got unexpected alternative " << TString{gotNamedItem->GetName()}.Quote()); + status = false; + } + } + + for (const auto* expectedNamedItem : expected->GetItems()) { + if (!got->FindItem(expectedNamedItem->GetName())) { + REPORT("Expected alternative " << TString{expectedNamedItem->GetName()}.Quote()); + status = false; + } + } + } + + for (const auto* gotNamedItem : got->GetItems()) { + const auto* gotItem = gotNamedItem->GetItemType(); + auto expectedIndex = expected ? expected->FindItem(gotNamedItem->GetName()) : Nothing(); + const auto* expectedItem = expected && expectedIndex ? expected->GetItems()[*expectedIndex]->GetItemType() : nullptr; + + TIssueScopeGuard issueScope(ctx.IssueManager, [&]() { + return new TIssue(TPosition(), TStringBuilder() << "Alternative " << TString{gotNamedItem->GetName()}.Quote()); + }); + + if (expectedItem && expectedItem->GetKind() != gotItem->GetKind()) { + REPORT("Expected to be " << expectedItem->GetKind() << ", but got " << gotItem->GetKind()); + status = false; + } + + if (gotItem->GetKind() != ETypeAnnotationKind::Struct) { + REPORT("Expected to be Struct, but got " << gotItem->GetKind()); + status = false; + } + + const auto* gotStruct = gotItem->Cast<TStructExprType>(); + const auto* expectedStruct = expectedItem ? expectedItem->Cast<TStructExprType>() : nullptr; + + if (!CheckStruct(gotStruct, expectedStruct, ctx)) { + status = false; + } + } + + return status; + } + + bool CheckVariantContent(const TTupleExprType* got, const TTupleExprType* expected, TExprContext& ctx) { + if (expected && expected->GetSize() != got->GetSize()) { + REPORT("Expected to have " << expected->GetSize() << " alternatives, but got " << got->GetSize()); + return false; + } + + auto status = true; + + for (size_t i = 0; i < got->GetSize(); i++) { + const auto* gotItem = got->GetItems()[i]; + const auto* expectedItem = expected ? expected->GetItems()[i] : nullptr; + + TIssueScopeGuard issueScope(ctx.IssueManager, [i]() { + return new TIssue(TPosition(), TStringBuilder() << "Alternative #" << i); + }); + + if (expectedItem && expectedItem->GetKind() != gotItem->GetKind()) { + REPORT("Expected " << expectedItem->GetKind() << ", but got " << gotItem->GetKind()); + status = false; + } + + if (gotItem->GetKind() != ETypeAnnotationKind::Struct) { + REPORT("Expected Struct, but got " << gotItem->GetKind()); + status = false; + } + + const auto* gotStruct = gotItem->Cast<TStructExprType>(); + const auto* expectedStruct = expectedItem ? expectedItem->Cast<TStructExprType>() : nullptr; + + if (!CheckStruct(gotStruct, expectedStruct, ctx)) { + status = false; + } + } + + return status; + } + + bool CheckVariant(const TVariantExprType* got, const TVariantExprType* expected, TExprContext& ctx) { + if (expected && expected->GetUnderlyingType()->GetKind() != got->GetUnderlyingType()->GetKind()) { + REPORT("Expected Variant over " << expected->GetUnderlyingType()->GetKind() << + ", but got Variant over " << got->GetUnderlyingType()->GetKind()); + return false; + } + + switch (got->GetUnderlyingType()->GetKind()) { + case ETypeAnnotationKind::Struct: + { + const auto* gotStruct = got->GetUnderlyingType()->Cast<TStructExprType>(); + const auto* expectedStruct = expected ? expected->GetUnderlyingType()->Cast<TStructExprType>() : nullptr; + return CheckVariantContent(gotStruct, expectedStruct, ctx); + } + case ETypeAnnotationKind::Tuple: + { + const auto* gotTuple = got->GetUnderlyingType()->Cast<TTupleExprType>(); + const auto* expectedTuple = expected ? expected->GetUnderlyingType()->Cast<TTupleExprType>() : nullptr; + return CheckVariantContent(gotTuple, expectedTuple, ctx); + } + default: + Y_UNREACHABLE(); + } + + return false; + } + + bool CheckSchema(const TTypeAnnotationNode* got, const TTypeAnnotationNode* expected, TExprContext& ctx, bool allowVariant) { + if (expected && expected->GetKind() != got->GetKind()) { + REPORT("Expected " << expected->GetKind() << ", but got " << got->GetKind()); + return false; + } + + switch (got->GetKind()) { + case ETypeAnnotationKind::Struct: + { + TIssueScopeGuard issueScope(ctx.IssueManager, []() { return new TIssue(TPosition(), "Toplevel struct"); }); + + const auto* gotStruct = got->Cast<TStructExprType>(); + const auto* expectedStruct = expected ? expected->Cast<TStructExprType>() : nullptr; + + if (!gotStruct->Validate(TPositionHandle(), ctx)) { + return false; + } + + return CheckStruct(gotStruct, expectedStruct, ctx); + } + case ETypeAnnotationKind::Variant: + if (allowVariant) { + TIssueScopeGuard issueScope(ctx.IssueManager, []() { return new TIssue(TPosition(), "Toplevel variant"); }); + + const auto* gotVariant = got->Cast<TVariantExprType>(); + const auto* expectedVariant = expected ? expected->Cast<TVariantExprType>() : nullptr; + + if (!gotVariant->Validate(TPositionHandle(), ctx)) { + return false; + } + + return CheckVariant(gotVariant, expectedVariant, ctx); + } + [[fallthrough]]; + default: + if (allowVariant) { + REPORT("Expected Struct or Variant, but got " << got->GetKind()); + } else { + REPORT("Expected Struct, but got " << got->GetKind()); + } + return false; + } + } +} + +namespace NYql::NPureCalc { + const TTypeAnnotationNode* MakeTypeFromSchema(const NYT::TNode& yson, TExprContext& ctx) { + const auto* type = NCommon::ParseTypeFromYson(yson, ctx); + + if (!type) { + ythrow TCompileError("", ctx.IssueManager.GetIssues().ToString()) + << "Incorrect schema: " << NYT::NodeToYsonString(yson, NYson::EYsonFormat::Text); + } + + return type; + } + + const TStructExprType* ExtendStructType( + const TStructExprType* type, const THashMap<TString, NYT::TNode>& extraColumns, TExprContext& ctx) + { + if (extraColumns.empty()) { + return type; + } + + auto items = type->GetItems(); + for (const auto& pair : extraColumns) { + items.push_back(ctx.MakeType<TItemExprType>(pair.first, MakeTypeFromSchema(pair.second, ctx))); + } + + auto result = ctx.MakeType<TStructExprType>(items); + + if (!result->Validate(TPosition(), ctx)) { + ythrow TCompileError("", ctx.IssueManager.GetIssues().ToString()) << "Incorrect extended struct type"; + } + + return result; + } + + bool ValidateInputSchema(const TTypeAnnotationNode* type, TExprContext& ctx) { + TIssueScopeGuard issueScope(ctx.IssueManager, []() { return new TIssue(TPosition(), "Input schema"); }); + return CheckSchema(type, nullptr, ctx, false); + } + + bool ValidateOutputSchema(const TTypeAnnotationNode* type, TExprContext& ctx) { + TIssueScopeGuard issueScope(ctx.IssueManager, []() { return new TIssue(TPosition(), "Output schema"); }); + return CheckSchema(type, nullptr, ctx, true); + } + + bool ValidateOutputType(const TTypeAnnotationNode* type, const TTypeAnnotationNode* expected, TExprContext& ctx) { + TIssueScopeGuard issueScope(ctx.IssueManager, []() { return new TIssue(TPosition(), "Program return type"); }); + return CheckSchema(type, expected, ctx, true); + } +} diff --git a/yql/essentials/public/purecalc/common/type_from_schema.h b/yql/essentials/public/purecalc/common/type_from_schema.h new file mode 100644 index 00000000000..b957aad9a10 --- /dev/null +++ b/yql/essentials/public/purecalc/common/type_from_schema.h @@ -0,0 +1,36 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/interface.h> + +#include <yql/essentials/ast/yql_expr.h> + +#include <library/cpp/yson/node/node.h> + +namespace NYql { + namespace NPureCalc { + /** + * Load struct type from yson. Use methods below to check returned type for correctness. + */ + const TTypeAnnotationNode* MakeTypeFromSchema(const NYT::TNode&, TExprContext&); + + /** + * Extend struct type with additional columns. Type of each extra column is loaded from yson. + */ + const TStructExprType* ExtendStructType(const TStructExprType*, const THashMap<TString, NYT::TNode>&, TExprContext&); + + /** + * Check if the given type can be used as an input schema, i.e. it is a struct. + */ + bool ValidateInputSchema(const TTypeAnnotationNode* type, TExprContext& ctx); + + /** + * Check if the given type can be used as an output schema, i.e. it is a struct or a variant of structs. + */ + bool ValidateOutputSchema(const TTypeAnnotationNode* type, TExprContext& ctx); + + /** + * Check if output type can be silently converted to the expected type. + */ + bool ValidateOutputType(const TTypeAnnotationNode* type, const TTypeAnnotationNode* expected, TExprContext& ctx); + } +} diff --git a/yql/essentials/public/purecalc/common/worker.cpp b/yql/essentials/public/purecalc/common/worker.cpp new file mode 100644 index 00000000000..f670458c728 --- /dev/null +++ b/yql/essentials/public/purecalc/common/worker.cpp @@ -0,0 +1,613 @@ +#include "worker.h" +#include "compile_mkql.h" + +#include <yql/essentials/ast/yql_expr.h> +#include <yql/essentials/core/yql_user_data.h> +#include <yql/essentials/core/yql_user_data_storage.h> +#include <yql/essentials/providers/common/comp_nodes/yql_factory.h> +#include <yql/essentials/public/purecalc/common/names.h> +#include <yql/essentials/minikql/mkql_function_registry.h> +#include <yql/essentials/minikql/mkql_node.h> +#include <yql/essentials/minikql/mkql_node_builder.h> +#include <yql/essentials/minikql/mkql_node_cast.h> +#include <yql/essentials/minikql/mkql_node_visitor.h> +#include <yql/essentials/minikql/mkql_node_serialization.h> +#include <yql/essentials/minikql/mkql_program_builder.h> +#include <yql/essentials/minikql/comp_nodes/mkql_factories.h> +#include <yql/essentials/minikql/computation/mkql_computation_node.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h> +#include <yql/essentials/providers/common/mkql/yql_provider_mkql.h> +#include <yql/essentials/providers/common/mkql/yql_type_mkql.h> + +#include <library/cpp/random_provider/random_provider.h> +#include <library/cpp/time_provider/time_provider.h> + +#include <util/stream/file.h> +#include <yql/essentials/minikql/computation/mkql_custom_list.h> +#include <yql/essentials/parser/pg_wrapper/interface/comp_factory.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +TWorkerGraph::TWorkerGraph( + const TExprNode::TPtr& exprRoot, + TExprContext& exprCtx, + const TString& serializedProgram, + const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, + const TUserDataTable& userData, + const TVector<const TStructExprType*>& inputTypes, + const TVector<const TStructExprType*>& originalInputTypes, + const TVector<const TStructExprType*>& rawInputTypes, + const TTypeAnnotationNode* outputType, + const TTypeAnnotationNode* rawOutputType, + const TString& LLVMSettings, + NKikimr::NUdf::ICountersProvider* countersProvider, + ui64 nativeYtTypeFlags, + TMaybe<ui64> deterministicTimeProviderSeed +) + : ScopedAlloc_(__LOCATION__, NKikimr::TAlignedPagePoolCounters(), funcRegistry.SupportsSizedAllocators()) + , Env_(ScopedAlloc_) + , FuncRegistry_(funcRegistry) + , RandomProvider_(CreateDefaultRandomProvider()) + , TimeProvider_(deterministicTimeProviderSeed ? + CreateDeterministicTimeProvider(*deterministicTimeProviderSeed) : + CreateDefaultTimeProvider()) + , LLVMSettings_(LLVMSettings) + , NativeYtTypeFlags_(nativeYtTypeFlags) +{ + // Build the root MKQL node + + NKikimr::NMiniKQL::TRuntimeNode rootNode; + if (exprRoot) { + rootNode = CompileMkql(exprRoot, exprCtx, FuncRegistry_, Env_, userData); + } else { + rootNode = NKikimr::NMiniKQL::DeserializeRuntimeNode(serializedProgram, Env_); + } + + // Prepare container for input nodes + + const ui32 inputsCount = inputTypes.size(); + + YQL_ENSURE(inputTypes.size() == originalInputTypes.size()); + + SelfNodes_.resize(inputsCount, nullptr); + + YQL_ENSURE(SelfNodes_.size() == inputsCount); + + // Setup struct types + + NKikimr::NMiniKQL::TProgramBuilder pgmBuilder(Env_, FuncRegistry_); + for (ui32 i = 0; i < inputsCount; ++i) { + const auto* type = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *inputTypes[i], pgmBuilder)); + const auto* originalType = type; + const auto* rawType = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *rawInputTypes[i], pgmBuilder)); + if (inputTypes[i] != originalInputTypes[i]) { + YQL_ENSURE(inputTypes[i]->GetSize() >= originalInputTypes[i]->GetSize()); + originalType = static_cast<NKikimr::NMiniKQL::TStructType*>(NCommon::BuildType(TPositionHandle(), *originalInputTypes[i], pgmBuilder)); + } + + InputTypes_.push_back(type); + OriginalInputTypes_.push_back(originalType); + RawInputTypes_.push_back(rawType); + } + + if (outputType) { + OutputType_ = NCommon::BuildType(TPositionHandle(), *outputType, pgmBuilder); + } + if (rawOutputType) { + RawOutputType_ = NCommon::BuildType(TPositionHandle(), *rawOutputType, pgmBuilder); + } + + if (!exprRoot) { + auto outMkqlType = rootNode.GetStaticType(); + if (outMkqlType->GetKind() == NKikimr::NMiniKQL::TType::EKind::List) { + outMkqlType = static_cast<NKikimr::NMiniKQL::TListType*>(outMkqlType)->GetItemType(); + } else if (outMkqlType->GetKind() == NKikimr::NMiniKQL::TType::EKind::Stream) { + outMkqlType = static_cast<NKikimr::NMiniKQL::TStreamType*>(outMkqlType)->GetItemType(); + } else { + ythrow TCompileError("", "") << "unexpected mkql output type " << NKikimr::NMiniKQL::TType::KindAsStr(outMkqlType->GetKind()); + } + if (OutputType_) { + if (!OutputType_->IsSameType(*outMkqlType)) { + ythrow TCompileError("", "") << "precompiled program output type doesn't match the output schema"; + } + } else { + OutputType_ = outMkqlType; + RawOutputType_ = outMkqlType; + } + } + + // Compile computation pattern + + const THashSet<NKikimr::NMiniKQL::TInternName> selfCallableNames = { + Env_.InternName(PurecalcInputCallableName), + Env_.InternName(PurecalcBlockInputCallableName) + }; + + NKikimr::NMiniKQL::TExploringNodeVisitor explorer; + explorer.Walk(rootNode.GetNode(), Env_); + + auto compositeNodeFactory = NKikimr::NMiniKQL::GetCompositeWithBuiltinFactory( + {NKikimr::NMiniKQL::GetYqlFactory(), NYql::GetPgFactory()} + ); + + auto nodeFactory = [&]( + NKikimr::NMiniKQL::TCallable& callable, const NKikimr::NMiniKQL::TComputationNodeFactoryContext& ctx + ) -> NKikimr::NMiniKQL::IComputationNode* { + if (selfCallableNames.contains(callable.GetType()->GetNameStr())) { + YQL_ENSURE(callable.GetInputsCount() == 1, "Self takes exactly 1 argument"); + const auto inputIndex = AS_VALUE(NKikimr::NMiniKQL::TDataLiteral, callable.GetInput(0))->AsValue().Get<ui32>(); + YQL_ENSURE(inputIndex < inputsCount, "Self index is out of range"); + YQL_ENSURE(!SelfNodes_[inputIndex], "Self can be called at most once with each index"); + return SelfNodes_[inputIndex] = new NKikimr::NMiniKQL::TExternalComputationNode(ctx.Mutables); + } + else { + return compositeNodeFactory(callable, ctx); + } + }; + + NKikimr::NMiniKQL::TComputationPatternOpts computationPatternOpts( + ScopedAlloc_.Ref(), + Env_, + nodeFactory, + &funcRegistry, + NKikimr::NUdf::EValidateMode::None, + NKikimr::NUdf::EValidatePolicy::Exception, + LLVMSettings, + NKikimr::NMiniKQL::EGraphPerProcess::Multi, + nullptr, + countersProvider); + + ComputationPattern_ = NKikimr::NMiniKQL::MakeComputationPattern( + explorer, + rootNode, + { rootNode.GetNode() }, + computationPatternOpts); + + ComputationGraph_ = ComputationPattern_->Clone( + computationPatternOpts.ToComputationOptions(*RandomProvider_, *TimeProvider_)); + + ComputationGraph_->Prepare(); + + // Scoped alloc acquires itself on construction. We need to release it before returning control to user. + // Note that scoped alloc releases itself on destruction so it is no problem if the above code throws. + ScopedAlloc_.Release(); +} + +TWorkerGraph::~TWorkerGraph() { + // Remember, we've released scoped alloc in constructor? Now, we need to acquire it back before destroying. + ScopedAlloc_.Acquire(); +} + +template <typename TBase> +TWorker<TBase>::TWorker( + TWorkerFactoryPtr factory, + const TExprNode::TPtr& exprRoot, + TExprContext& exprCtx, + const TString& serializedProgram, + const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, + const TUserDataTable& userData, + const TVector<const TStructExprType*>& inputTypes, + const TVector<const TStructExprType*>& originalInputTypes, + const TVector<const TStructExprType*>& rawInputTypes, + const TTypeAnnotationNode* outputType, + const TTypeAnnotationNode* rawOutputType, + const TString& LLVMSettings, + NKikimr::NUdf::ICountersProvider* countersProvider, + ui64 nativeYtTypeFlags, + TMaybe<ui64> deterministicTimeProviderSeed +) + : WorkerFactory_(std::move(factory)) + , Graph_(exprRoot, exprCtx, serializedProgram, funcRegistry, userData, + inputTypes, originalInputTypes, rawInputTypes, outputType, rawOutputType, + LLVMSettings, countersProvider, nativeYtTypeFlags, deterministicTimeProviderSeed) +{ +} + +template <typename TBase> +inline ui32 TWorker<TBase>::GetInputsCount() const { + return Graph_.InputTypes_.size(); +} + +template <typename TBase> +inline const NKikimr::NMiniKQL::TStructType* TWorker<TBase>::GetInputType(ui32 inputIndex, bool original) const { + const auto& container = original ? Graph_.OriginalInputTypes_ : Graph_.InputTypes_; + + YQL_ENSURE(inputIndex < container.size(), "invalid input index (" << inputIndex << ") in GetInputType call"); + + return container[inputIndex]; +} + +template <typename TBase> +inline const NKikimr::NMiniKQL::TStructType* TWorker<TBase>::GetInputType(bool original) const { + const auto& container = original ? Graph_.OriginalInputTypes_ : Graph_.InputTypes_; + + YQL_ENSURE(container.size() == 1, "GetInputType() can be used only for single-input programs"); + + return container[0]; +} + +template <typename TBase> +inline const NKikimr::NMiniKQL::TStructType* TWorker<TBase>::GetRawInputType(ui32 inputIndex) const { + const auto& container = Graph_.RawInputTypes_; + YQL_ENSURE(inputIndex < container.size(), "invalid input index (" << inputIndex << ") in GetInputType call"); + return container[inputIndex]; +} + +template <typename TBase> +inline const NKikimr::NMiniKQL::TStructType* TWorker<TBase>::GetRawInputType() const { + const auto& container = Graph_.RawInputTypes_; + YQL_ENSURE(container.size() == 1, "GetInputType() can be used only for single-input programs"); + return container[0]; +} + +template <typename TBase> +inline const NKikimr::NMiniKQL::TType* TWorker<TBase>::GetOutputType() const { + return Graph_.OutputType_; +} + +template <typename TBase> +inline const NKikimr::NMiniKQL::TType* TWorker<TBase>::GetRawOutputType() const { + return Graph_.RawOutputType_; +} + +template <typename TBase> +NYT::TNode TWorker<TBase>::MakeInputSchema(ui32 inputIndex) const { + auto p = WorkerFactory_.lock(); + YQL_ENSURE(p, "Access to destroyed worker factory"); + return p->MakeInputSchema(inputIndex); +} + +template <typename TBase> +NYT::TNode TWorker<TBase>::MakeInputSchema() const { + auto p = WorkerFactory_.lock(); + YQL_ENSURE(p, "Access to destroyed worker factory"); + return p->MakeInputSchema(); +} + +template <typename TBase> +NYT::TNode TWorker<TBase>::MakeOutputSchema() const { + auto p = WorkerFactory_.lock(); + YQL_ENSURE(p, "Access to destroyed worker factory"); + return p->MakeOutputSchema(); +} + +template <typename TBase> +NYT::TNode TWorker<TBase>::MakeOutputSchema(ui32) const { + auto p = WorkerFactory_.lock(); + YQL_ENSURE(p, "Access to destroyed worker factory"); + return p->MakeOutputSchema(); +} + +template <typename TBase> +NYT::TNode TWorker<TBase>::MakeOutputSchema(TStringBuf) const { + auto p = WorkerFactory_.lock(); + YQL_ENSURE(p, "Access to destroyed worker factory"); + return p->MakeOutputSchema(); +} + +template <typename TBase> +NYT::TNode TWorker<TBase>::MakeFullOutputSchema() const { + auto p = WorkerFactory_.lock(); + YQL_ENSURE(p, "Access to destroyed worker factory"); + return p->MakeFullOutputSchema(); +} + +template <typename TBase> +inline NKikimr::NMiniKQL::TScopedAlloc& TWorker<TBase>::GetScopedAlloc() { + return Graph_.ScopedAlloc_; +} + +template <typename TBase> +inline NKikimr::NMiniKQL::IComputationGraph& TWorker<TBase>::GetGraph() { + return *Graph_.ComputationGraph_; +} + +template <typename TBase> +inline const NKikimr::NMiniKQL::IFunctionRegistry& +TWorker<TBase>::GetFunctionRegistry() const { + return Graph_.FuncRegistry_; +} + +template <typename TBase> +inline NKikimr::NMiniKQL::TTypeEnvironment& +TWorker<TBase>::GetTypeEnvironment() { + return Graph_.Env_; +} + +template <typename TBase> +inline const TString& TWorker<TBase>::GetLLVMSettings() const { + return Graph_.LLVMSettings_; +} + +template <typename TBase> +inline ui64 TWorker<TBase>::GetNativeYtTypeFlags() const { + return Graph_.NativeYtTypeFlags_; +} + +template <typename TBase> +ITimeProvider* TWorker<TBase>::GetTimeProvider() const { + return Graph_.TimeProvider_.Get(); +} + +template <typename TBase> +void TWorker<TBase>::Release() { + if (auto p = WorkerFactory_.lock()) { + p->ReturnWorker(this); + } else { + delete this; + } +} + +TPullStreamWorker::~TPullStreamWorker() { + auto guard = Guard(GetScopedAlloc()); + Output_.Clear(); +} + +void TPullStreamWorker::SetInput(NKikimr::NUdf::TUnboxedValue&& value, ui32 inputIndex) { + const auto inputsCount = Graph_.SelfNodes_.size(); + + if (Y_UNLIKELY(inputIndex >= inputsCount)) { + ythrow yexception() << "invalid input index (" << inputIndex << ") in SetInput call"; + } + + if (HasInput_.size() < inputsCount) { + HasInput_.resize(inputsCount, false); + } + + if (Y_UNLIKELY(HasInput_[inputIndex])) { + ythrow yexception() << "input value for #" << inputIndex << " input is already set"; + } + + auto selfNode = Graph_.SelfNodes_[inputIndex]; + + if (selfNode) { + YQL_ENSURE(value); + selfNode->SetValue(Graph_.ComputationGraph_->GetContext(), std::move(value)); + } + + HasInput_[inputIndex] = true; + + if (CheckAllInputsSet()) { + Output_ = Graph_.ComputationGraph_->GetValue(); + } +} + +NKikimr::NUdf::TUnboxedValue& TPullStreamWorker::GetOutput() { + if (Y_UNLIKELY(!CheckAllInputsSet())) { + ythrow yexception() << "some input values have not been set"; + } + + return Output_; +} + +void TPullStreamWorker::Release() { + with_lock(GetScopedAlloc()) { + Output_ = NKikimr::NUdf::TUnboxedValue::Invalid(); + for (auto selfNode: Graph_.SelfNodes_) { + if (selfNode) { + selfNode->SetValue(Graph_.ComputationGraph_->GetContext(), NKikimr::NUdf::TUnboxedValue::Invalid()); + } + } + } + HasInput_.clear(); + TWorker<IPullStreamWorker>::Release(); +} + +TPullListWorker::~TPullListWorker() { + auto guard = Guard(GetScopedAlloc()); + Output_.Clear(); + OutputIterator_.Clear(); +} + +void TPullListWorker::SetInput(NKikimr::NUdf::TUnboxedValue&& value, ui32 inputIndex) { + const auto inputsCount = Graph_.SelfNodes_.size(); + + if (Y_UNLIKELY(inputIndex >= inputsCount)) { + ythrow yexception() << "invalid input index (" << inputIndex << ") in SetInput call"; + } + + if (HasInput_.size() < inputsCount) { + HasInput_.resize(inputsCount, false); + } + + if (Y_UNLIKELY(HasInput_[inputIndex])) { + ythrow yexception() << "input value for #" << inputIndex << " input is already set"; + } + + auto selfNode = Graph_.SelfNodes_[inputIndex]; + + if (selfNode) { + YQL_ENSURE(value); + selfNode->SetValue(Graph_.ComputationGraph_->GetContext(), std::move(value)); + } + + HasInput_[inputIndex] = true; + + if (CheckAllInputsSet()) { + Output_ = Graph_.ComputationGraph_->GetValue(); + ResetOutputIterator(); + } +} + +NKikimr::NUdf::TUnboxedValue& TPullListWorker::GetOutput() { + if (Y_UNLIKELY(!CheckAllInputsSet())) { + ythrow yexception() << "some input values have not been set"; + } + + return Output_; +} + +NKikimr::NUdf::TUnboxedValue& TPullListWorker::GetOutputIterator() { + if (Y_UNLIKELY(!CheckAllInputsSet())) { + ythrow yexception() << "some input values have not been set"; + } + + return OutputIterator_; +} + +void TPullListWorker::ResetOutputIterator() { + if (Y_UNLIKELY(!CheckAllInputsSet())) { + ythrow yexception() << "some input values have not been set"; + } + + OutputIterator_ = Output_.GetListIterator(); +} + +void TPullListWorker::Release() { + with_lock(GetScopedAlloc()) { + Output_ = NKikimr::NUdf::TUnboxedValue::Invalid(); + OutputIterator_ = NKikimr::NUdf::TUnboxedValue::Invalid(); + + for (auto selfNode: Graph_.SelfNodes_) { + if (selfNode) { + selfNode->SetValue(Graph_.ComputationGraph_->GetContext(), NKikimr::NUdf::TUnboxedValue::Invalid()); + } + } + } + HasInput_.clear(); + TWorker<IPullListWorker>::Release(); +} + +namespace { + class TPushStream final: public NKikimr::NMiniKQL::TCustomListValue { + private: + mutable bool HasIterator_ = false; + bool HasValue_ = false; + bool IsFinished_ = false; + NKikimr::NUdf::TUnboxedValue Value_ = NKikimr::NUdf::TUnboxedValue::Invalid(); + + public: + using TCustomListValue::TCustomListValue; + + public: + void SetValue(NKikimr::NUdf::TUnboxedValue&& value) { + Value_ = std::move(value); + HasValue_ = true; + } + + void SetFinished() { + IsFinished_ = true; + } + + NKikimr::NUdf::TUnboxedValue GetListIterator() const override { + YQL_ENSURE(!HasIterator_, "only one pass over input is supported"); + HasIterator_ = true; + return NKikimr::NUdf::TUnboxedValuePod(const_cast<TPushStream*>(this)); + } + + NKikimr::NUdf::EFetchStatus Fetch(NKikimr::NUdf::TUnboxedValue& result) override { + if (IsFinished_) { + return NKikimr::NUdf::EFetchStatus::Finish; + } else if (!HasValue_) { + return NKikimr::NUdf::EFetchStatus::Yield; + } else { + result = std::move(Value_); + HasValue_ = false; + return NKikimr::NUdf::EFetchStatus::Ok; + } + } + }; +} + +void TPushStreamWorker::FeedToConsumer() { + auto value = Graph_.ComputationGraph_->GetValue(); + + for (;;) { + NKikimr::NUdf::TUnboxedValue item; + auto status = value.Fetch(item); + + if (status != NKikimr::NUdf::EFetchStatus::Ok) { + break; + } + + Consumer_->OnObject(&item); + } +} + +NYql::NUdf::IBoxedValue* TPushStreamWorker::GetPushStream() const { + auto& ctx = Graph_.ComputationGraph_->GetContext(); + NUdf::TUnboxedValue pushStream = SelfNode_->GetValue(ctx); + + if (Y_UNLIKELY(pushStream.IsInvalid())) { + SelfNode_->SetValue(ctx, Graph_.ComputationGraph_->GetHolderFactory().Create<TPushStream>()); + pushStream = SelfNode_->GetValue(ctx); + } + + return pushStream.AsBoxed().Get(); +} + +void TPushStreamWorker::SetConsumer(THolder<IConsumer<const NKikimr::NUdf::TUnboxedValue*>> consumer) { + auto guard = Guard(GetScopedAlloc()); + const auto inputsCount = Graph_.SelfNodes_.size(); + + YQL_ENSURE(inputsCount < 2, "push stream mode doesn't support several inputs"); + YQL_ENSURE(!Consumer_, "consumer is already set"); + + Consumer_ = std::move(consumer); + + if (inputsCount == 1) { + SelfNode_ = Graph_.SelfNodes_[0]; + } + + if (SelfNode_) { + SelfNode_->SetValue( + Graph_.ComputationGraph_->GetContext(), + Graph_.ComputationGraph_->GetHolderFactory().Create<TPushStream>()); + } + + FeedToConsumer(); +} + +void TPushStreamWorker::Push(NKikimr::NUdf::TUnboxedValue&& value) { + YQL_ENSURE(Consumer_, "consumer is not set"); + YQL_ENSURE(!Finished_, "OnFinish has already been sent to the consumer; no new values can be pushed"); + + if (Y_LIKELY(SelfNode_)) { + static_cast<TPushStream*>(GetPushStream())->SetValue(std::move(value)); + } + + FeedToConsumer(); +} + +void TPushStreamWorker::OnFinish() { + YQL_ENSURE(Consumer_, "consumer is not set"); + YQL_ENSURE(!Finished_, "already finished"); + + if (Y_LIKELY(SelfNode_)) { + static_cast<TPushStream*>(GetPushStream())->SetFinished(); + } + + FeedToConsumer(); + + Consumer_->OnFinish(); + + Finished_ = true; +} + +void TPushStreamWorker::Release() { + with_lock(GetScopedAlloc()) { + Consumer_.Destroy(); + if (SelfNode_) { + SelfNode_->SetValue(Graph_.ComputationGraph_->GetContext(), NKikimr::NUdf::TUnboxedValue::Invalid()); + } + SelfNode_ = nullptr; + } + Finished_ = false; + TWorker<IPushStreamWorker>::Release(); +} + + +namespace NYql { + namespace NPureCalc { + template + class TWorker<IPullStreamWorker>; + + template + class TWorker<IPullListWorker>; + + template + class TWorker<IPushStreamWorker>; + } +} diff --git a/yql/essentials/public/purecalc/common/worker.h b/yql/essentials/public/purecalc/common/worker.h new file mode 100644 index 00000000000..07b8dfa2e79 --- /dev/null +++ b/yql/essentials/public/purecalc/common/worker.h @@ -0,0 +1,178 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/interface.h> + +#include <yql/essentials/public/udf/udf_value.h> +#include <yql/essentials/ast/yql_expr.h> +#include <yql/essentials/core/yql_user_data.h> +#include <yql/essentials/minikql/mkql_alloc.h> +#include <yql/essentials/minikql/mkql_node.h> +#include <yql/essentials/minikql/mkql_node_visitor.h> +#include <yql/essentials/minikql/computation/mkql_computation_node.h> +#include <yql/essentials/providers/common/mkql/yql_provider_mkql.h> + +#include <memory> + +namespace NYql { + namespace NPureCalc { + struct TWorkerGraph { + TWorkerGraph( + const TExprNode::TPtr& exprRoot, + TExprContext& exprCtx, + const TString& serializedProgram, + const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, + const TUserDataTable& userData, + const TVector<const TStructExprType*>& inputTypes, + const TVector<const TStructExprType*>& originalInputTypes, + const TVector<const TStructExprType*>& rawInputTypes, + const TTypeAnnotationNode* outputType, + const TTypeAnnotationNode* rawOutputType, + const TString& LLVMSettings, + NKikimr::NUdf::ICountersProvider* countersProvider, + ui64 nativeYtTypeFlags, + TMaybe<ui64> deterministicTimeProviderSeed + ); + + ~TWorkerGraph(); + + NKikimr::NMiniKQL::TScopedAlloc ScopedAlloc_; + NKikimr::NMiniKQL::TTypeEnvironment Env_; + const NKikimr::NMiniKQL::IFunctionRegistry& FuncRegistry_; + TIntrusivePtr<IRandomProvider> RandomProvider_; + TIntrusivePtr<ITimeProvider> TimeProvider_; + NKikimr::NMiniKQL::IComputationPattern::TPtr ComputationPattern_; + THolder<NKikimr::NMiniKQL::IComputationGraph> ComputationGraph_; + TString LLVMSettings_; + ui64 NativeYtTypeFlags_; + TMaybe<TString> TimestampColumn_; + const NKikimr::NMiniKQL::TType* OutputType_; + const NKikimr::NMiniKQL::TType* RawOutputType_; + TVector<NKikimr::NMiniKQL::IComputationExternalNode*> SelfNodes_; + TVector<const NKikimr::NMiniKQL::TStructType*> InputTypes_; + TVector<const NKikimr::NMiniKQL::TStructType*> OriginalInputTypes_; + TVector<const NKikimr::NMiniKQL::TStructType*> RawInputTypes_; + }; + + template <typename TBase> + class TWorker: public TBase { + public: + using TWorkerFactoryPtr = std::weak_ptr<IWorkerFactory>; + private: + // Worker factory implementation should stay alive for this worker to operate correctly. + TWorkerFactoryPtr WorkerFactory_; + + protected: + TWorkerGraph Graph_; + + public: + TWorker( + TWorkerFactoryPtr factory, + const TExprNode::TPtr& exprRoot, + TExprContext& exprCtx, + const TString& serializedProgram, + const NKikimr::NMiniKQL::IFunctionRegistry& funcRegistry, + const TUserDataTable& userData, + const TVector<const TStructExprType*>& inputTypes, + const TVector<const TStructExprType*>& originalInputTypes, + const TVector<const TStructExprType*>& rawInputTypes, + const TTypeAnnotationNode* outputType, + const TTypeAnnotationNode* rawOutputType, + const TString& LLVMSettings, + NKikimr::NUdf::ICountersProvider* countersProvider, + ui64 nativeYtTypeFlags, + TMaybe<ui64> deterministicTimeProviderSeed + ); + + public: + ui32 GetInputsCount() const override; + const NKikimr::NMiniKQL::TStructType* GetInputType(ui32, bool) const override; + const NKikimr::NMiniKQL::TStructType* GetInputType(bool) const override; + const NKikimr::NMiniKQL::TStructType* GetRawInputType(ui32) const override; + const NKikimr::NMiniKQL::TStructType* GetRawInputType() const override; + const NKikimr::NMiniKQL::TType* GetOutputType() const override; + const NKikimr::NMiniKQL::TType* GetRawOutputType() const override; + NYT::TNode MakeInputSchema() const override; + NYT::TNode MakeInputSchema(ui32) const override; + NYT::TNode MakeOutputSchema() const override; + NYT::TNode MakeOutputSchema(ui32) const override; + NYT::TNode MakeOutputSchema(TStringBuf) const override; + NYT::TNode MakeFullOutputSchema() const override; + NKikimr::NMiniKQL::TScopedAlloc& GetScopedAlloc() override; + NKikimr::NMiniKQL::IComputationGraph& GetGraph() override; + const NKikimr::NMiniKQL::IFunctionRegistry& GetFunctionRegistry() const override; + NKikimr::NMiniKQL::TTypeEnvironment& GetTypeEnvironment() override; + const TString& GetLLVMSettings() const override; + ui64 GetNativeYtTypeFlags() const override; + ITimeProvider* GetTimeProvider() const override; + protected: + void Release() override; + }; + + class TPullStreamWorker final: public TWorker<IPullStreamWorker> { + private: + NKikimr::NUdf::TUnboxedValue Output_ = NKikimr::NUdf::TUnboxedValue::Invalid(); + TVector<bool> HasInput_; + + inline bool CheckAllInputsSet() { + return AllOf(HasInput_, [](bool x) { return x; }); + } + + public: + using TWorker::TWorker; + ~TPullStreamWorker(); + + public: + void SetInput(NKikimr::NUdf::TUnboxedValue&&, ui32) override; + NKikimr::NUdf::TUnboxedValue& GetOutput() override; + + protected: + void Release() override; + }; + + class TPullListWorker final: public TWorker<IPullListWorker> { + private: + NKikimr::NUdf::TUnboxedValue Output_ = NKikimr::NUdf::TUnboxedValue::Invalid(); + NKikimr::NUdf::TUnboxedValue OutputIterator_ = NKikimr::NUdf::TUnboxedValue::Invalid(); + TVector<bool> HasInput_; + + inline bool CheckAllInputsSet() { + return AllOf(HasInput_, [](bool x) { return x; }); + } + + public: + using TWorker::TWorker; + ~TPullListWorker(); + + public: + void SetInput(NKikimr::NUdf::TUnboxedValue&&, ui32) override; + NKikimr::NUdf::TUnboxedValue& GetOutput() override; + NKikimr::NUdf::TUnboxedValue& GetOutputIterator() override; + void ResetOutputIterator() override; + + protected: + void Release() override; + }; + + class TPushStreamWorker final: public TWorker<IPushStreamWorker> { + private: + THolder<IConsumer<const NKikimr::NUdf::TUnboxedValue*>> Consumer_{}; + bool Finished_ = false; + NKikimr::NMiniKQL::IComputationExternalNode* SelfNode_ = nullptr; + + public: + using TWorker::TWorker; + + private: + void FeedToConsumer(); + NYql::NUdf::IBoxedValue* GetPushStream() const; + + public: + void SetConsumer(THolder<IConsumer<const NKikimr::NUdf::TUnboxedValue*>>) override; + void Push(NKikimr::NUdf::TUnboxedValue&&) override; + void OnFinish() override; + + protected: + void Release() override; + }; + } +} diff --git a/yql/essentials/public/purecalc/common/worker_factory.cpp b/yql/essentials/public/purecalc/common/worker_factory.cpp new file mode 100644 index 00000000000..173f73b7beb --- /dev/null +++ b/yql/essentials/public/purecalc/common/worker_factory.cpp @@ -0,0 +1,532 @@ +#include "worker_factory.h" + +#include "type_from_schema.h" +#include "worker.h" +#include "compile_mkql.h" + +#include <yql/essentials/sql/sql.h> +#include <yql/essentials/ast/yql_expr.h> +#include <yql/essentials/core/yql_expr_optimize.h> +#include <yql/essentials/core/yql_type_helpers.h> +#include <yql/essentials/core/peephole_opt/yql_opt_peephole_physical.h> +#include <yql/essentials/providers/common/codec/yql_codec.h> +#include <yql/essentials/providers/common/udf_resolve/yql_simple_udf_resolver.h> +#include <yql/essentials/providers/common/arrow_resolve/yql_simple_arrow_resolver.h> +#include <yql/essentials/providers/common/schema/expr/yql_expr_schema.h> +#include <yql/essentials/providers/common/provider/yql_provider.h> +#include <yql/essentials/providers/common/provider/yql_provider_names.h> +#include <yql/essentials/providers/config/yql_config_provider.h> +#include <yql/essentials/minikql/mkql_node.h> +#include <yql/essentials/minikql/mkql_node_serialization.h> +#include <yql/essentials/minikql/mkql_alloc.h> +#include <yql/essentials/minikql/aligned_page_pool.h> +#include <yql/essentials/core/services/yql_transform_pipeline.h> +#include <yql/essentials/public/purecalc/common/names.h> +#include <yql/essentials/public/purecalc/common/transformations/type_annotation.h> +#include <yql/essentials/public/purecalc/common/transformations/align_output_schema.h> +#include <yql/essentials/public/purecalc/common/transformations/extract_used_columns.h> +#include <yql/essentials/public/purecalc/common/transformations/output_columns_filter.h> +#include <yql/essentials/public/purecalc/common/transformations/replace_table_reads.h> +#include <yql/essentials/public/purecalc/common/transformations/root_to_blocks.h> +#include <yql/essentials/public/purecalc/common/transformations/utils.h> +#include <yql/essentials/utils/log/log.h> +#include <util/stream/trace.h> + +using namespace NYql; +using namespace NYql::NPureCalc; + +template <typename TBase> +TWorkerFactory<TBase>::TWorkerFactory(TWorkerFactoryOptions options, EProcessorMode processorMode) + : Factory_(std::move(options.Factory)) + , FuncRegistry_(std::move(options.FuncRegistry)) + , UserData_(std::move(options.UserData)) + , LLVMSettings_(std::move(options.LLVMSettings)) + , BlockEngineMode_(options.BlockEngineMode) + , ExprOutputStream_(options.ExprOutputStream) + , CountersProvider_(options.CountersProvider_) + , NativeYtTypeFlags_(options.NativeYtTypeFlags_) + , DeterministicTimeProviderSeed_(options.DeterministicTimeProviderSeed_) + , UseSystemColumns_(options.UseSystemColumns) + , UseWorkerPool_(options.UseWorkerPool) +{ + // Prepare input struct types and extract all column names from inputs + + const auto& inputSchemas = options.InputSpec.GetSchemas(); + const auto& allVirtualColumns = options.InputSpec.GetAllVirtualColumns(); + + YQL_ENSURE(inputSchemas.size() == allVirtualColumns.size()); + + const auto inputsCount = inputSchemas.size(); + + for (ui32 i = 0; i < inputsCount; ++i) { + const auto* originalInputType = MakeTypeFromSchema(inputSchemas[i], ExprContext_); + if (!ValidateInputSchema(originalInputType, ExprContext_)) { + ythrow TCompileError("", ExprContext_.IssueManager.GetIssues().ToString()) << "invalid schema for #" << i << " input"; + } + + const auto* originalStructType = originalInputType->template Cast<TStructExprType>(); + const auto* structType = ExtendStructType(originalStructType, allVirtualColumns[i], ExprContext_); + + InputTypes_.push_back(structType); + OriginalInputTypes_.push_back(originalStructType); + RawInputTypes_.push_back(originalStructType); + + auto& columnsSet = AllColumns_.emplace_back(); + for (const auto* structItem : structType->GetItems()) { + columnsSet.insert(TString(structItem->GetName())); + + if (!UseSystemColumns_ && structItem->GetName().StartsWith(PurecalcSysColumnsPrefix)) { + ythrow TCompileError("", ExprContext_.IssueManager.GetIssues().ToString()) + << "#" << i << " input provides system column " << structItem->GetName() + << ", but it is forbidden by options"; + } + } + } + + // Prepare output type + + auto outputSchema = options.OutputSpec.GetSchema(); + if (!outputSchema.IsNull()) { + OutputType_ = MakeTypeFromSchema(outputSchema, ExprContext_); + if (!ValidateOutputSchema(OutputType_, ExprContext_)) { + ythrow TCompileError("", ExprContext_.IssueManager.GetIssues().ToString()) << "invalid output schema"; + } + } else { + OutputType_ = nullptr; + } + + RawOutputType_ = OutputType_; + + // Translate + + if (options.TranslationMode_ == ETranslationMode::Mkql) { + SerializedProgram_ = TString{options.Query}; + } else { + ExprRoot_ = Compile(options.Query, options.TranslationMode_, + options.ModuleResolver, options.SyntaxVersion_, options.Modules, + options.InputSpec, options.OutputSpec, processorMode); + + RawOutputType_ = GetSequenceItemType(ExprRoot_->Pos(), ExprRoot_->GetTypeAnn(), true, ExprContext_); + + // Deduce output type if it wasn't provided by output spec + + if (!OutputType_) { + OutputType_ = RawOutputType_; + // XXX: Tweak the obtained expression type, is the spec supports blocks: + // 1. Remove "_yql_block_length" attribute, since it's for internal usage. + // 2. Strip block container from the type to store its internal type. + if (options.OutputSpec.AcceptsBlocks()) { + Y_ENSURE(OutputType_->GetKind() == ETypeAnnotationKind::Struct); + OutputType_ = UnwrapBlockStruct(OutputType_->Cast<TStructExprType>(), ExprContext_); + } + } + if (!OutputType_) { + ythrow TCompileError("", ExprContext_.IssueManager.GetIssues().ToString()) << "cannot deduce output schema"; + } + } +} + +template <typename TBase> +TExprNode::TPtr TWorkerFactory<TBase>::Compile( + TStringBuf query, + ETranslationMode mode, + IModuleResolver::TPtr moduleResolver, + ui16 syntaxVersion, + const THashMap<TString, TString>& modules, + const TInputSpecBase& inputSpec, + const TOutputSpecBase& outputSpec, + EProcessorMode processorMode +) { + if (mode == ETranslationMode::PG && processorMode != EProcessorMode::PullList) { + ythrow TCompileError("", "") << "only PullList mode is compatible to PostgreSQL syntax"; + } + + // Prepare type annotation context + + TTypeAnnotationContextPtr typeContext; + + typeContext = MakeIntrusive<TTypeAnnotationContext>(); + typeContext->RandomProvider = CreateDefaultRandomProvider(); + typeContext->TimeProvider = DeterministicTimeProviderSeed_ ? + CreateDeterministicTimeProvider(*DeterministicTimeProviderSeed_) : + CreateDefaultTimeProvider(); + typeContext->UdfResolver = NCommon::CreateSimpleUdfResolver(FuncRegistry_.Get()); + typeContext->ArrowResolver = MakeSimpleArrowResolver(*FuncRegistry_.Get()); + typeContext->UserDataStorage = MakeIntrusive<TUserDataStorage>(nullptr, UserData_, nullptr, nullptr); + typeContext->Modules = moduleResolver; + typeContext->BlockEngineMode = BlockEngineMode_; + auto configProvider = CreateConfigProvider(*typeContext, nullptr, ""); + typeContext->AddDataSource(ConfigProviderName, configProvider); + typeContext->Initialize(ExprContext_); + + if (auto modules = dynamic_cast<TModuleResolver*>(moduleResolver.get())) { + modules->AttachUserData(typeContext->UserDataStorage); + } + + // Parse SQL/s-expr into AST + + TAstParseResult astRes; + + if (mode == ETranslationMode::SQL || mode == ETranslationMode::PG) { + NSQLTranslation::TTranslationSettings settings; + + typeContext->DeprecatedSQL = (syntaxVersion == 0); + if (mode == ETranslationMode::PG) { + settings.PgParser = true; + } + + settings.SyntaxVersion = syntaxVersion; + settings.V0Behavior = NSQLTranslation::EV0Behavior::Disable; + settings.Mode = NSQLTranslation::ESqlMode::LIMITED_VIEW; + settings.DefaultCluster = PurecalcDefaultCluster; + settings.ClusterMapping[settings.DefaultCluster] = PurecalcDefaultService; + settings.ModuleMapping = modules; + settings.EnableGenericUdfs = true; + settings.File = "generated.sql"; + settings.Flags = { + "AnsiOrderByLimitInUnionAll", + "AnsiRankForNullableKeys", + "DisableAnsiOptionalAs", + "DisableCoalesceJoinKeysOnQualifiedAll", + "DisableUnorderedSubqueries", + "FlexibleTypes" + }; + if (BlockEngineMode_ != EBlockEngineMode::Disable) { + settings.Flags.insert("EmitAggApply"); + } + for (const auto& [key, block] : UserData_) { + TStringBuf alias(key.Alias()); + if (block.Usage.Test(EUserDataBlockUsage::Library) && !alias.StartsWith("/lib")) { + alias.SkipPrefix("/home/"); + settings.Libraries.emplace(alias); + } + } + + astRes = SqlToYql(TString(query), settings); + } else { + astRes = ParseAst(TString(query)); + } + + if (!astRes.IsOk()) { + ythrow TCompileError(TString(query), astRes.Issues.ToString()) << "failed to parse " << mode; + } + + ExprContext_.IssueManager.AddIssues(astRes.Issues); + + if (ETraceLevel::TRACE_DETAIL <= StdDbgLevel()) { + Cdbg << "Before optimization:" << Endl; + astRes.Root->PrettyPrintTo(Cdbg, TAstPrintFlags::PerLine | TAstPrintFlags::ShortQuote | TAstPrintFlags::AdaptArbitraryContent); + } + + // Translate AST into expression + + TExprNode::TPtr exprRoot; + if (!CompileExpr(*astRes.Root, exprRoot, ExprContext_, moduleResolver.get(), nullptr, 0, syntaxVersion)) { + TStringStream astStr; + astRes.Root->PrettyPrintTo(astStr, TAstPrintFlags::ShortQuote | TAstPrintFlags::PerLine); + ythrow TCompileError(astStr.Str(), ExprContext_.IssueManager.GetIssues().ToString()) << "failed to compile"; + } + + + // Prepare transformation pipeline + THolder<IGraphTransformer> calcTransformer = CreateFunctorTransformer([&](TExprNode::TPtr input, TExprNode::TPtr& output, TExprContext& ctx) + -> IGraphTransformer::TStatus + { + output = input; + auto valueNode = input->HeadPtr(); + + auto peepHole = MakePeepholeOptimization(typeContext); + auto status = SyncTransform(*peepHole, valueNode, ctx); + if (status != IGraphTransformer::TStatus::Ok) { + return status; + } + + TStringStream out; + NYson::TYsonWriter writer(&out, NYson::EYsonFormat::Text, ::NYson::EYsonType::Node, true); + writer.OnBeginMap(); + + writer.OnKeyedItem("Data"); + + TWorkerGraph graph( + valueNode, + ctx, + {}, + *FuncRegistry_, + UserData_, + {}, + {}, + {}, + valueNode->GetTypeAnn(), + valueNode->GetTypeAnn(), + LLVMSettings_, + CountersProvider_, + NativeYtTypeFlags_, + DeterministicTimeProviderSeed_ + ); + + with_lock (graph.ScopedAlloc_) { + const auto value = graph.ComputationGraph_->GetValue(); + NCommon::WriteYsonValue(writer, value, const_cast<NKikimr::NMiniKQL::TType*>(graph.OutputType_), nullptr); + } + writer.OnEndMap(); + + auto ysonAtom = ctx.NewAtom(TPositionHandle(), out.Str()); + input->SetResult(std::move(ysonAtom)); + return IGraphTransformer::TStatus::Ok; + }); + + const TString& selfName = TString(inputSpec.ProvidesBlocks() + ? PurecalcBlockInputCallableName + : PurecalcInputCallableName); + + TTransformationPipeline pipeline(typeContext); + + pipeline.Add(MakeTableReadsReplacer(InputTypes_, UseSystemColumns_, processorMode, selfName), + "ReplaceTableReads", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, + "Replace reads from tables"); + pipeline.AddServiceTransformers(); + pipeline.AddPreTypeAnnotation(); + pipeline.AddExpressionEvaluation(*FuncRegistry_, calcTransformer.Get()); + pipeline.AddIOAnnotation(); + pipeline.AddTypeAnnotationTransformer(MakeTypeAnnotationTransformer(typeContext, InputTypes_, RawInputTypes_, processorMode, selfName)); + pipeline.AddPostTypeAnnotation(); + pipeline.Add(CreateFunctorTransformer( + [&](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { + return OptimizeExpr(input, output, [](const TExprNode::TPtr& node, TExprContext&) -> TExprNode::TPtr { + if (node->IsCallable("Unordered") && node->Child(0)->IsCallable({ + PurecalcInputCallableName, PurecalcBlockInputCallableName + })) { + return node->ChildPtr(0); + } + return node; + }, ctx, TOptimizeExprSettings(nullptr)); + }), "Unordered", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, + "Unordered optimizations"); + pipeline.Add(CreateFunctorTransformer( + [&](const TExprNode::TPtr& input, TExprNode::TPtr& output, TExprContext& ctx) { + return OptimizeExpr(input, output, [](const TExprNode::TPtr& node, TExprContext&) -> TExprNode::TPtr { + if (node->IsCallable("Right!") && node->Head().IsCallable("Cons!")) { + return node->Head().ChildPtr(1); + } + + return node; + }, ctx, TOptimizeExprSettings(nullptr)); + }), "Cons", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, + "Cons optimizations"); + pipeline.Add(MakeOutputColumnsFilter(outputSpec.GetOutputColumnsFilter()), + "Filter", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, + "Filter output columns"); + pipeline.Add(MakeRootToBlocks(outputSpec.AcceptsBlocks(), processorMode), + "RootToBlocks", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, + "Rewrite the root if the output spec accepts blocks"); + pipeline.Add(MakeOutputAligner(OutputType_, outputSpec.AcceptsBlocks(), processorMode), + "Convert", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, + "Align return type of the program to output schema"); + pipeline.AddCommonOptimization(); + pipeline.AddFinalCommonOptimization(); + pipeline.Add(MakeUsedColumnsExtractor(&UsedColumns_, AllColumns_), + "ExtractColumns", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, + "Extract used columns"); + pipeline.Add(MakePeepholeOptimization(typeContext), + "PeepHole", EYqlIssueCode::TIssuesIds_EIssueCode_DEFAULT_ERROR, + "Peephole optimizations"); + pipeline.AddCheckExecution(false); + + // Apply optimizations + + auto transformer = pipeline.Build(); + auto status = SyncTransform(*transformer, exprRoot, ExprContext_); + auto transformStats = transformer->GetStatistics(); + TStringStream out; + NYson::TYsonWriter writer(&out, NYson::EYsonFormat::Pretty); + NCommon::TransformerStatsToYson("", transformStats, writer); + YQL_CLOG(DEBUG, Core) << "Transform stats: " << out.Str(); + if (status == IGraphTransformer::TStatus::Error) { + ythrow TCompileError("", ExprContext_.IssueManager.GetIssues().ToString()) << "Failed to optimize"; + } + + IOutputStream* exprOut = nullptr; + if (ExprOutputStream_) { + exprOut = ExprOutputStream_; + } else if (ETraceLevel::TRACE_DETAIL <= StdDbgLevel()) { + exprOut = &Cdbg; + } + + if (exprOut) { + *exprOut << "After optimization:" << Endl; + ConvertToAst(*exprRoot, ExprContext_, 0, true).Root + ->PrettyPrintTo(*exprOut, TAstPrintFlags::PerLine + | TAstPrintFlags::ShortQuote + | TAstPrintFlags::AdaptArbitraryContent); + } + return exprRoot; +} + +template <typename TBase> +NYT::TNode TWorkerFactory<TBase>::MakeInputSchema(ui32 inputIndex) const { + Y_ENSURE( + inputIndex < InputTypes_.size(), + "invalid input index (" << inputIndex << ") in MakeInputSchema call"); + + return NCommon::TypeToYsonNode(InputTypes_[inputIndex]); +} + +template <typename TBase> +NYT::TNode TWorkerFactory<TBase>::MakeInputSchema() const { + Y_ENSURE( + InputTypes_.size() == 1, + "MakeInputSchema() can be used only with single-input programs"); + + return NCommon::TypeToYsonNode(InputTypes_[0]); +} + +template <typename TBase> +NYT::TNode TWorkerFactory<TBase>::MakeOutputSchema() const { + Y_ENSURE(OutputType_, "MakeOutputSchema() cannot be used with precompiled programs"); + Y_ENSURE( + OutputType_->GetKind() == ETypeAnnotationKind::Struct, + "MakeOutputSchema() cannot be used with multi-output programs"); + + return NCommon::TypeToYsonNode(OutputType_); +} + +template <typename TBase> +NYT::TNode TWorkerFactory<TBase>::MakeOutputSchema(ui32 index) const { + Y_ENSURE(OutputType_, "MakeOutputSchema() cannot be used with precompiled programs"); + Y_ENSURE( + OutputType_->GetKind() == ETypeAnnotationKind::Variant, + "MakeOutputSchema(ui32) cannot be used with single-output programs"); + + auto vtype = OutputType_->template Cast<TVariantExprType>(); + + Y_ENSURE( + vtype->GetUnderlyingType()->GetKind() == ETypeAnnotationKind::Tuple, + "MakeOutputSchema(ui32) cannot be used to process variants over struct"); + + auto ttype = vtype->GetUnderlyingType()->template Cast<TTupleExprType>(); + + Y_ENSURE( + index < ttype->GetSize(), + "Invalid table index " << index); + + return NCommon::TypeToYsonNode(ttype->GetItems()[index]); +} + +template <typename TBase> +NYT::TNode TWorkerFactory<TBase>::MakeOutputSchema(TStringBuf tableName) const { + Y_ENSURE(OutputType_, "MakeOutputSchema() cannot be used with precompiled programs"); + Y_ENSURE( + OutputType_->GetKind() == ETypeAnnotationKind::Variant, + "MakeOutputSchema(TStringBuf) cannot be used with single-output programs"); + + auto vtype = OutputType_->template Cast<TVariantExprType>(); + + Y_ENSURE( + vtype->GetUnderlyingType()->GetKind() == ETypeAnnotationKind::Struct, + "MakeOutputSchema(TStringBuf) cannot be used to process variants over tuple"); + + auto stype = vtype->GetUnderlyingType()->template Cast<TStructExprType>(); + + auto index = stype->FindItem(tableName); + + Y_ENSURE( + index.Defined(), + "Invalid table index " << TString{tableName}.Quote()); + + return NCommon::TypeToYsonNode(stype->GetItems()[*index]->GetItemType()); +} + +template <typename TBase> +NYT::TNode TWorkerFactory<TBase>::MakeFullOutputSchema() const { + Y_ENSURE(OutputType_, "MakeFullOutputSchema() cannot be used with precompiled programs"); + return NCommon::TypeToYsonNode(OutputType_); +} + +template <typename TBase> +const THashSet<TString>& TWorkerFactory<TBase>::GetUsedColumns(ui32 inputIndex) const { + Y_ENSURE( + inputIndex < UsedColumns_.size(), + "invalid input index (" << inputIndex << ") in GetUsedColumns call"); + + return UsedColumns_[inputIndex]; +} + +template <typename TBase> +const THashSet<TString>& TWorkerFactory<TBase>::GetUsedColumns() const { + Y_ENSURE( + UsedColumns_.size() == 1, + "GetUsedColumns() can be used only with single-input programs"); + + return UsedColumns_[0]; +} + +template <typename TBase> +TIssues TWorkerFactory<TBase>::GetIssues() const { + return ExprContext_.IssueManager.GetCompletedIssues(); +} + +template <typename TBase> +TString TWorkerFactory<TBase>::GetCompiledProgram() { + if (ExprRoot_) { + NKikimr::NMiniKQL::TScopedAlloc alloc(__LOCATION__, NKikimr::TAlignedPagePoolCounters(), + FuncRegistry_->SupportsSizedAllocators()); + NKikimr::NMiniKQL::TTypeEnvironment env(alloc); + + auto rootNode = CompileMkql(ExprRoot_, ExprContext_, *FuncRegistry_, env, UserData_); + return NKikimr::NMiniKQL::SerializeRuntimeNode(rootNode, env); + } + + return SerializedProgram_; +} + +template <typename TBase> +void TWorkerFactory<TBase>::ReturnWorker(IWorker* worker) { + THolder<IWorker> tmp(worker); + if (UseWorkerPool_) { + WorkerPool_.push_back(std::move(tmp)); + } +} + + +#define DEFINE_WORKER_MAKER(MODE) \ + TWorkerHolder<I##MODE##Worker> T##MODE##WorkerFactory::MakeWorker() { \ + if (!WorkerPool_.empty()) { \ + auto res = std::move(WorkerPool_.back()); \ + WorkerPool_.pop_back(); \ + return TWorkerHolder<I##MODE##Worker>((I##MODE##Worker *)res.Release()); \ + } \ + return TWorkerHolder<I##MODE##Worker>(new T##MODE##Worker( \ + weak_from_this(), \ + ExprRoot_, \ + ExprContext_, \ + SerializedProgram_, \ + *FuncRegistry_, \ + UserData_, \ + InputTypes_, \ + OriginalInputTypes_, \ + RawInputTypes_, \ + OutputType_, \ + RawOutputType_, \ + LLVMSettings_, \ + CountersProvider_, \ + NativeYtTypeFlags_, \ + DeterministicTimeProviderSeed_ \ + )); \ + } + +DEFINE_WORKER_MAKER(PullStream) +DEFINE_WORKER_MAKER(PullList) +DEFINE_WORKER_MAKER(PushStream) + +namespace NYql { + namespace NPureCalc { + template + class TWorkerFactory<IPullStreamWorkerFactory>; + + template + class TWorkerFactory<IPullListWorkerFactory>; + + template + class TWorkerFactory<IPushStreamWorkerFactory>; + } +} diff --git a/yql/essentials/public/purecalc/common/worker_factory.h b/yql/essentials/public/purecalc/common/worker_factory.h new file mode 100644 index 00000000000..d2600413aa9 --- /dev/null +++ b/yql/essentials/public/purecalc/common/worker_factory.h @@ -0,0 +1,168 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/interface.h> + +#include "processor_mode.h" + +#include <util/generic/ptr.h> +#include <yql/essentials/ast/yql_expr.h> +#include <yql/essentials/core/yql_user_data.h> +#include <yql/essentials/minikql/mkql_function_registry.h> +#include <yql/essentials/core/yql_type_annotation.h> +#include <utility> + +namespace NYql { + namespace NPureCalc { + struct TWorkerFactoryOptions { + IProgramFactoryPtr Factory; + const TInputSpecBase& InputSpec; + const TOutputSpecBase& OutputSpec; + TStringBuf Query; + TIntrusivePtr<NKikimr::NMiniKQL::IMutableFunctionRegistry> FuncRegistry; + IModuleResolver::TPtr ModuleResolver; + const TUserDataTable& UserData; + const THashMap<TString, TString>& Modules; + TString LLVMSettings; + EBlockEngineMode BlockEngineMode; + IOutputStream* ExprOutputStream; + NKikimr::NUdf::ICountersProvider* CountersProvider_; + ETranslationMode TranslationMode_; + ui16 SyntaxVersion_; + ui64 NativeYtTypeFlags_; + TMaybe<ui64> DeterministicTimeProviderSeed_; + bool UseSystemColumns; + bool UseWorkerPool; + + TWorkerFactoryOptions( + IProgramFactoryPtr Factory, + const TInputSpecBase& InputSpec, + const TOutputSpecBase& OutputSpec, + TStringBuf Query, + TIntrusivePtr<NKikimr::NMiniKQL::IMutableFunctionRegistry> FuncRegistry, + IModuleResolver::TPtr ModuleResolver, + const TUserDataTable& UserData, + const THashMap<TString, TString>& Modules, + TString LLVMSettings, + EBlockEngineMode BlockEngineMode, + IOutputStream* ExprOutputStream, + NKikimr::NUdf::ICountersProvider* CountersProvider, + ETranslationMode translationMode, + ui16 syntaxVersion, + ui64 nativeYtTypeFlags, + TMaybe<ui64> deterministicTimeProviderSeed, + bool useSystemColumns, + bool useWorkerPool + ) + : Factory(std::move(Factory)) + , InputSpec(InputSpec) + , OutputSpec(OutputSpec) + , Query(Query) + , FuncRegistry(std::move(FuncRegistry)) + , ModuleResolver(std::move(ModuleResolver)) + , UserData(UserData) + , Modules(Modules) + , LLVMSettings(std::move(LLVMSettings)) + , BlockEngineMode(BlockEngineMode) + , ExprOutputStream(ExprOutputStream) + , CountersProvider_(CountersProvider) + , TranslationMode_(translationMode) + , SyntaxVersion_(syntaxVersion) + , NativeYtTypeFlags_(nativeYtTypeFlags) + , DeterministicTimeProviderSeed_(deterministicTimeProviderSeed) + , UseSystemColumns(useSystemColumns) + , UseWorkerPool(useWorkerPool) + { + } + }; + + template <typename TBase> + class TWorkerFactory: public TBase { + private: + IProgramFactoryPtr Factory_; + + protected: + TIntrusivePtr<NKikimr::NMiniKQL::IMutableFunctionRegistry> FuncRegistry_; + const TUserDataTable& UserData_; + TExprContext ExprContext_; + TExprNode::TPtr ExprRoot_; + TString SerializedProgram_; + TVector<const TStructExprType*> InputTypes_; + TVector<const TStructExprType*> OriginalInputTypes_; + TVector<const TStructExprType*> RawInputTypes_; + const TTypeAnnotationNode* OutputType_; + const TTypeAnnotationNode* RawOutputType_; + TVector<THashSet<TString>> AllColumns_; + TVector<THashSet<TString>> UsedColumns_; + TString LLVMSettings_; + EBlockEngineMode BlockEngineMode_; + IOutputStream* ExprOutputStream_; + NKikimr::NUdf::ICountersProvider* CountersProvider_; + ui64 NativeYtTypeFlags_; + TMaybe<ui64> DeterministicTimeProviderSeed_; + bool UseSystemColumns_; + bool UseWorkerPool_; + TVector<THolder<IWorker>> WorkerPool_; + + public: + TWorkerFactory(TWorkerFactoryOptions, EProcessorMode); + + public: + NYT::TNode MakeInputSchema(ui32) const override; + NYT::TNode MakeInputSchema() const override; + NYT::TNode MakeOutputSchema() const override; + NYT::TNode MakeOutputSchema(ui32) const override; + NYT::TNode MakeOutputSchema(TStringBuf) const override; + NYT::TNode MakeFullOutputSchema() const override; + const THashSet<TString>& GetUsedColumns(ui32 inputIndex) const override; + const THashSet<TString>& GetUsedColumns() const override; + TIssues GetIssues() const override; + TString GetCompiledProgram() override; + + protected: + void ReturnWorker(IWorker* worker) override; + + private: + TExprNode::TPtr Compile(TStringBuf query, + ETranslationMode mode, + IModuleResolver::TPtr moduleResolver, + ui16 syntaxVersion, + const THashMap<TString, TString>& modules, + const TInputSpecBase& inputSpec, + const TOutputSpecBase& outputSpec, + EProcessorMode processorMode); + }; + + class TPullStreamWorkerFactory final: public TWorkerFactory<IPullStreamWorkerFactory> { + public: + explicit TPullStreamWorkerFactory(TWorkerFactoryOptions options) + : TWorkerFactory(std::move(options), EProcessorMode::PullStream) + { + } + + public: + TWorkerHolder<IPullStreamWorker> MakeWorker() override; + }; + + class TPullListWorkerFactory final: public TWorkerFactory<IPullListWorkerFactory> { + public: + explicit TPullListWorkerFactory(TWorkerFactoryOptions options) + : TWorkerFactory(std::move(options), EProcessorMode::PullList) + { + } + + public: + TWorkerHolder<IPullListWorker> MakeWorker() override; + }; + + class TPushStreamWorkerFactory final: public TWorkerFactory<IPushStreamWorkerFactory> { + public: + explicit TPushStreamWorkerFactory(TWorkerFactoryOptions options) + : TWorkerFactory(std::move(options), EProcessorMode::PushStream) + { + } + + public: + TWorkerHolder<IPushStreamWorker> MakeWorker() override; + }; + } +} diff --git a/yql/essentials/public/purecalc/common/wrappers.cpp b/yql/essentials/public/purecalc/common/wrappers.cpp new file mode 100644 index 00000000000..c808d7b3940 --- /dev/null +++ b/yql/essentials/public/purecalc/common/wrappers.cpp @@ -0,0 +1 @@ +#include "wrappers.h" diff --git a/yql/essentials/public/purecalc/common/wrappers.h b/yql/essentials/public/purecalc/common/wrappers.h new file mode 100644 index 00000000000..4d65e012716 --- /dev/null +++ b/yql/essentials/public/purecalc/common/wrappers.h @@ -0,0 +1,70 @@ +#pragma once + +#include "fwd.h" + +#include <util/generic/ptr.h> + +namespace NYql::NPureCalc::NPrivate { + template <typename TNew, typename TOld, typename TFunctor> + class TMappingStream final: public IStream<TNew> { + private: + THolder<IStream<TOld>> Old_; + TFunctor Functor_; + + public: + TMappingStream(THolder<IStream<TOld>> old, TFunctor functor) + : Old_(std::move(old)) + , Functor_(std::move(functor)) + { + } + + public: + TNew Fetch() override { + return Functor_(Old_->Fetch()); + } + }; + + template <typename TNew, typename TOld, typename TFunctor> + class TMappingConsumer final: public IConsumer<TNew> { + private: + THolder<IConsumer<TOld>> Old_; + TFunctor Functor_; + + public: + TMappingConsumer(THolder<IConsumer<TOld>> old, TFunctor functor) + : Old_(std::move(old)) + , Functor_(std::move(functor)) + { + } + + public: + void OnObject(TNew object) override { + Old_->OnObject(Functor_(object)); + } + + void OnFinish() override { + Old_->OnFinish(); + } + }; + + template <typename T, typename C> + class TNonOwningConsumer final: public IConsumer<T> { + private: + C Consumer; + + public: + explicit TNonOwningConsumer(const C& consumer) + : Consumer(consumer) + { + } + + public: + void OnObject(T t) override { + Consumer->OnObject(t); + } + + void OnFinish() override { + Consumer->OnFinish(); + } + }; +} diff --git a/yql/essentials/public/purecalc/common/ya.make b/yql/essentials/public/purecalc/common/ya.make new file mode 100644 index 00000000000..98d002f2a59 --- /dev/null +++ b/yql/essentials/public/purecalc/common/ya.make @@ -0,0 +1,21 @@ +LIBRARY() + +INCLUDE(ya.make.inc) + +PEERDIR( + contrib/ydb/library/yql/providers/yt/codec/codegen + yql/essentials/providers/config + yql/essentials/minikql/computation/llvm14 + yql/essentials/minikql/invoke_builtins/llvm14 + yql/essentials/minikql/comp_nodes/llvm14 + yql/essentials/parser/pg_wrapper + yql/essentials/parser/pg_wrapper/interface + yql/essentials/sql/pg +) + +END() + +RECURSE( + no_llvm +) + diff --git a/yql/essentials/public/purecalc/common/ya.make.inc b/yql/essentials/public/purecalc/common/ya.make.inc new file mode 100644 index 00000000000..eb9387da4f8 --- /dev/null +++ b/yql/essentials/public/purecalc/common/ya.make.inc @@ -0,0 +1,52 @@ +SRCDIR( + yql/essentials/public/purecalc/common +) + +ADDINCL( + yql/essentials/public/purecalc/common +) + +SRCS( + compile_mkql.cpp + fwd.cpp + inspect_input.cpp + interface.cpp + logger_init.cpp + names.cpp + processor_mode.cpp + program_factory.cpp + transformations/align_output_schema.cpp + transformations/extract_used_columns.cpp + transformations/output_columns_filter.cpp + transformations/replace_table_reads.cpp + transformations/root_to_blocks.cpp + transformations/type_annotation.cpp + transformations/utils.cpp + type_from_schema.cpp + worker.cpp + worker_factory.cpp + wrappers.cpp +) + +PEERDIR( + yql/essentials/ast + yql/essentials/core/services + yql/essentials/core/services/mounts + yql/essentials/core/user_data + yql/essentials/utils/backtrace + yql/essentials/utils/log + yql/essentials/core + yql/essentials/core/type_ann + yql/essentials/providers/common/codec + yql/essentials/providers/common/comp_nodes + yql/essentials/providers/common/mkql + yql/essentials/providers/common/provider + yql/essentials/providers/common/schema/expr + yql/essentials/providers/common/udf_resolve + yql/essentials/providers/common/arrow_resolve +) + +YQL_LAST_ABI_VERSION() + +GENERATE_ENUM_SERIALIZATION(interface.h) + diff --git a/yql/essentials/public/purecalc/examples/protobuf/main.cpp b/yql/essentials/public/purecalc/examples/protobuf/main.cpp new file mode 100644 index 00000000000..2cf9ff47360 --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf/main.cpp @@ -0,0 +1,133 @@ +#include <yql/essentials/public/purecalc/examples/protobuf/main.pb.h> + +#include <yql/essentials/public/purecalc/purecalc.h> +#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h> +#include <yql/essentials/public/purecalc/helpers/stream/stream_from_vector.h> + +using namespace NYql::NPureCalc; +using namespace NExampleProtos; + +void PullStreamExample(IProgramFactoryPtr); +void PushStreamExample(IProgramFactoryPtr); +void PrecompileExample(IProgramFactoryPtr factory); +THolder<IStream<TInput*>> MakeInput(); + +class TConsumer: public IConsumer<TOutput*> { +public: + void OnObject(TOutput* message) override { + Cout << "path = " << message->GetPath() << Endl; + Cout << "host = " << message->GetHost() << Endl; + } + + void OnFinish() override { + Cout << "end" << Endl; + } +}; + +const char* Query = R"( + $a = (SELECT * FROM Input); + $b = (SELECT CAST(Url::GetTail(Url) AS Utf8) AS Path, CAST(Url::GetHost(Url) AS Utf8) AS Host, Ip FROM $a); + $c = (SELECT Path, Host FROM $b WHERE Path IS NOT NULL AND Host IS NOT NULL AND Ip::IsIPv4(Ip::FromString(Ip))); + $d = (SELECT Unwrap(Path) AS Path, Unwrap(Host) AS Host FROM $c); + SELECT * FROM $d; +)"; + +int main(int argc, char** argv) { + try { + auto factory = MakeProgramFactory( + TProgramFactoryOptions().SetUDFsDir(argc > 1 ? argv[1] : "../../../../udfs")); + + Cout << "Pull stream:" << Endl; + PullStreamExample(factory); + + Cout << Endl; + Cout << "Push stream:" << Endl; + PushStreamExample(factory); + + Cout << Endl; + Cout << "Pull stream with pre-compilation:" << Endl; + PrecompileExample(factory); + } catch (const TCompileError& err) { + Cerr << err.GetIssues() << Endl; + Cerr << err.what() << Endl; + } +} + +void PullStreamExample(IProgramFactoryPtr factory) { + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<TInput>(), + TProtobufOutputSpec<TOutput>(), + Query, + ETranslationMode::SQL); + + auto result = program->Apply(MakeInput()); + + while (auto* message = result->Fetch()) { + Cout << "path = " << message->GetPath() << Endl; + Cout << "host = " << message->GetHost() << Endl; + } +} + +void PushStreamExample(IProgramFactoryPtr factory) { + auto program = factory->MakePushStreamProgram( + TProtobufInputSpec<TInput>(), + TProtobufOutputSpec<TOutput>(), + Query, + ETranslationMode::SQL); + + auto consumer = program->Apply(MakeHolder<TConsumer>()); + + auto input = MakeInput(); + while (auto* message = input->Fetch()) { + consumer->OnObject(message); + } + consumer->OnFinish(); +} + +void PrecompileExample(IProgramFactoryPtr factory) { + TString prg; + { + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<TInput>(), + TProtobufOutputSpec<TOutput>(), + Query, + ETranslationMode::SQL); + + prg = program->GetCompiledProgram(); + } + + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<TInput>(), + TProtobufOutputSpec<TOutput>(), + prg, + ETranslationMode::Mkql); + + auto result = program->Apply(MakeInput()); + + while (auto* message = result->Fetch()) { + Cout << "path = " << message->GetPath() << Endl; + Cout << "host = " << message->GetHost() << Endl; + } +} + +THolder<IStream<TInput*>> MakeInput() { + TVector<TInput> input; + + { + auto& message = input.emplace_back(); + message.SetUrl("https://news.yandex.ru/Moscow/index.html?from=index"); + message.SetIp("83.220.231.160"); + } + { + auto& message = input.emplace_back(); + message.SetUrl("https://music.yandex.ru/radio/"); + message.SetIp("83.220.231.161"); + } + { + auto& message = input.emplace_back(); + message.SetUrl("https://yandex.ru/maps/?ll=141.475401%2C11.581666&spn=1.757813%2C1.733096&z=7&l=map%2Cstv%2Csta&mode=search&panorama%5Bpoint%5D=141.476317%2C11.582710&panorama%5Bdirection%5D=177.241445%2C-15.219821&panorama%5Bspan%5D=107.410156%2C61.993317"); + message.SetIp("::ffff:77.75.155.3"); + } + + return StreamFromVector(std::move(input)); +} diff --git a/yql/essentials/public/purecalc/examples/protobuf/main.proto b/yql/essentials/public/purecalc/examples/protobuf/main.proto new file mode 100644 index 00000000000..54fd15e226d --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf/main.proto @@ -0,0 +1,11 @@ +package NExampleProtos; + +message TInput { + required string Url = 1; + required string Ip = 2; +} + +message TOutput { + required string Path = 1; + required string Host = 2; +} diff --git a/yql/essentials/public/purecalc/examples/protobuf/ut/canondata/exectest.run_protobuf_/log.out b/yql/essentials/public/purecalc/examples/protobuf/ut/canondata/exectest.run_protobuf_/log.out new file mode 100644 index 00000000000..1ec34e485d2 --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf/ut/canondata/exectest.run_protobuf_/log.out @@ -0,0 +1,18 @@ +Pull stream: +path = /Moscow/index.html?from=index +host = news.yandex.ru +path = /radio/ +host = music.yandex.ru + +Push stream: +path = /Moscow/index.html?from=index +host = news.yandex.ru +path = /radio/ +host = music.yandex.ru +end + +Pull stream with pre-compilation: +path = /Moscow/index.html?from=index +host = news.yandex.ru +path = /radio/ +host = music.yandex.ru diff --git a/yql/essentials/public/purecalc/examples/protobuf/ut/canondata/result.json b/yql/essentials/public/purecalc/examples/protobuf/ut/canondata/result.json new file mode 100644 index 00000000000..96a5814765e --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf/ut/canondata/result.json @@ -0,0 +1,5 @@ +{ + "exectest.run[protobuf]": { + "uri": "file://exectest.run_protobuf_/log.out" + } +} diff --git a/yql/essentials/public/purecalc/examples/protobuf/ut/ya.make b/yql/essentials/public/purecalc/examples/protobuf/ut/ya.make new file mode 100644 index 00000000000..55feb21d95c --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf/ut/ya.make @@ -0,0 +1,15 @@ +IF (NOT SANITIZER_TYPE AND NOT OPENSOURCE) + +EXECTEST() + +RUN(protobuf ${ARCADIA_BUILD_ROOT}/yql/essentials/udfs STDOUT log.out CANONIZE_LOCALLY log.out) + +DEPENDS( + yql/essentials/public/purecalc/examples/protobuf + yql/essentials/udfs/common/url_base + yql/essentials/udfs/common/ip_base +) + +END() + +ENDIF() diff --git a/yql/essentials/public/purecalc/examples/protobuf/ya.make b/yql/essentials/public/purecalc/examples/protobuf/ya.make new file mode 100644 index 00000000000..c50a3c4af25 --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf/ya.make @@ -0,0 +1,27 @@ +PROGRAM() + +SRCS( + main.proto + main.cpp +) + +PEERDIR( + yql/essentials/public/purecalc + yql/essentials/public/purecalc/io_specs/protobuf + yql/essentials/public/purecalc/helpers/stream +) + + + YQL_LAST_ABI_VERSION() + + +END() + +RECURSE_ROOT_RELATIVE( + yql/essentials/udfs/common/url_base + yql/essentials/udfs/common/ip_base +) + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/public/purecalc/examples/protobuf_pull_list/main.cpp b/yql/essentials/public/purecalc/examples/protobuf_pull_list/main.cpp new file mode 100644 index 00000000000..b3e27cec10f --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf_pull_list/main.cpp @@ -0,0 +1,75 @@ +#include <yql/essentials/public/purecalc/examples/protobuf_pull_list/main.pb.h> + +#include <yql/essentials/public/purecalc/purecalc.h> +#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h> +#include <yql/essentials/public/purecalc/helpers/stream/stream_from_vector.h> + +using namespace NYql::NPureCalc; +using namespace NExampleProtos; + +const char* Query = R"( + SELECT + Url, + COUNT(*) AS Hits + FROM + Input + GROUP BY + Url + ORDER BY + Url +)"; + +THolder<IStream<TInput*>> MakeInput(); + +int main() { + try { + auto factory = MakeProgramFactory(); + + auto program = factory->MakePullListProgram( + TProtobufInputSpec<TInput>(), + TProtobufOutputSpec<TOutput>(), + Query, + ETranslationMode::SQL + ); + + auto result = program->Apply(MakeInput()); + + while (auto* message = result->Fetch()) { + Cout << "url = " << message->GetUrl() << Endl; + Cout << "hits = " << message->GetHits() << Endl; + } + } catch (TCompileError& e) { + Cout << e.GetIssues(); + } +} + +THolder<IStream<TInput*>> MakeInput() { + TVector<TInput> input; + + { + auto& message = input.emplace_back(); + message.SetUrl("https://yandex.ru/a"); + } + { + auto& message = input.emplace_back(); + message.SetUrl("https://yandex.ru/a"); + } + { + auto& message = input.emplace_back(); + message.SetUrl("https://yandex.ru/b"); + } + { + auto& message = input.emplace_back(); + message.SetUrl("https://yandex.ru/c"); + } + { + auto& message = input.emplace_back(); + message.SetUrl("https://yandex.ru/b"); + } + { + auto& message = input.emplace_back(); + message.SetUrl("https://yandex.ru/b"); + } + + return StreamFromVector(std::move(input)); +} diff --git a/yql/essentials/public/purecalc/examples/protobuf_pull_list/main.proto b/yql/essentials/public/purecalc/examples/protobuf_pull_list/main.proto new file mode 100644 index 00000000000..2766c4b8c0c --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf_pull_list/main.proto @@ -0,0 +1,10 @@ +package NExampleProtos; + +message TInput { + required string Url = 1; +} + +message TOutput { + required string Url = 1; + required uint64 Hits = 2; +} diff --git a/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/canondata/exectest.run_protobuf_pull_list_/log.out b/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/canondata/exectest.run_protobuf_pull_list_/log.out new file mode 100644 index 00000000000..0a799ed4b09 --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/canondata/exectest.run_protobuf_pull_list_/log.out @@ -0,0 +1,6 @@ +url = https://yandex.ru/a +hits = 2 +url = https://yandex.ru/b +hits = 3 +url = https://yandex.ru/c +hits = 1 diff --git a/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/canondata/result.json b/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/canondata/result.json new file mode 100644 index 00000000000..668467cc850 --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/canondata/result.json @@ -0,0 +1,6 @@ +{ + "exectest.run[protobuf_pull_list]": { + "checksum": "29bf513fe0ca6f81ae076213a1c7801c", + "uri": "file://exectest.run_protobuf_pull_list_/log.out" + } +}
\ No newline at end of file diff --git a/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/ya.make b/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/ya.make new file mode 100644 index 00000000000..3da0d508d17 --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf_pull_list/ut/ya.make @@ -0,0 +1,9 @@ +EXECTEST() + +RUN(protobuf_pull_list STDOUT log.out CANONIZE_LOCALLY log.out) + +DEPENDS( + yql/essentials/public/purecalc/examples/protobuf_pull_list +) + +END() diff --git a/yql/essentials/public/purecalc/examples/protobuf_pull_list/ya.make b/yql/essentials/public/purecalc/examples/protobuf_pull_list/ya.make new file mode 100644 index 00000000000..a102f5fb2ca --- /dev/null +++ b/yql/essentials/public/purecalc/examples/protobuf_pull_list/ya.make @@ -0,0 +1,20 @@ +PROGRAM() + +SRCS( + main.proto + main.cpp +) + +PEERDIR( + yql/essentials/public/purecalc + yql/essentials/public/purecalc/io_specs/protobuf + yql/essentials/public/purecalc/helpers/stream +) + +YQL_LAST_ABI_VERSION() + +END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/public/purecalc/examples/ya.make b/yql/essentials/public/purecalc/examples/ya.make new file mode 100644 index 00000000000..9c0e9259a08 --- /dev/null +++ b/yql/essentials/public/purecalc/examples/ya.make @@ -0,0 +1,4 @@ +RECURSE( + protobuf + protobuf_pull_list +) diff --git a/yql/essentials/public/purecalc/helpers/protobuf/schema_from_proto.cpp b/yql/essentials/public/purecalc/helpers/protobuf/schema_from_proto.cpp new file mode 100644 index 00000000000..6927c46240c --- /dev/null +++ b/yql/essentials/public/purecalc/helpers/protobuf/schema_from_proto.cpp @@ -0,0 +1,202 @@ +#include "schema_from_proto.h" + +#include <yt/yt_proto/yt/formats/extension.pb.h> + +#include <util/generic/algorithm.h> +#include <util/generic/string.h> +#include <util/string/printf.h> +#include <util/string/vector.h> + +namespace pb = google::protobuf; + +namespace NYql { + namespace NPureCalc { + + TProtoSchemaOptions::TProtoSchemaOptions() + : EnumPolicy(EEnumPolicy::Int32) + , ListIsOptional(false) + { + } + + TProtoSchemaOptions& TProtoSchemaOptions::SetEnumPolicy(EEnumPolicy policy) { + EnumPolicy = policy; + return *this; + } + + TProtoSchemaOptions& TProtoSchemaOptions::SetListIsOptional(bool value) { + ListIsOptional = value; + return *this; + } + + TProtoSchemaOptions& TProtoSchemaOptions::SetFieldRenames( + THashMap<TString, TString> fieldRenames + ) { + FieldRenames = std::move(fieldRenames); + return *this; + } + + namespace { + EEnumFormatType EnumFormatTypeWithYTFlag(const pb::FieldDescriptor& enumField, EEnumFormatType defaultEnumFormatType) { + auto flags = enumField.options().GetRepeatedExtension(NYT::flags); + for (auto flag : flags) { + if (flag == NYT::EWrapperFieldFlag::ENUM_INT) { + return EEnumFormatType::Int32; + } else if (flag == NYT::EWrapperFieldFlag::ENUM_STRING) { + return EEnumFormatType::String; + } + } + return defaultEnumFormatType; + } + } + + EEnumFormatType EnumFormatType(const pb::FieldDescriptor& enumField, EEnumPolicy enumPolicy) { + switch (enumPolicy) { + case EEnumPolicy::Int32: + return EEnumFormatType::Int32; + case EEnumPolicy::String: + return EEnumFormatType::String; + case EEnumPolicy::YTFlagDefaultInt32: + return EnumFormatTypeWithYTFlag(enumField, EEnumFormatType::Int32); + case EEnumPolicy::YTFlagDefaultString: + return EnumFormatTypeWithYTFlag(enumField, EEnumFormatType::String); + } + } + + namespace { + const char* FormatTypeName(const pb::FieldDescriptor* field, EEnumPolicy enumPolicy) { + switch (field->type()) { + case pb::FieldDescriptor::TYPE_DOUBLE: + return "Double"; + case pb::FieldDescriptor::TYPE_FLOAT: + return "Float"; + case pb::FieldDescriptor::TYPE_INT64: + case pb::FieldDescriptor::TYPE_SFIXED64: + case pb::FieldDescriptor::TYPE_SINT64: + return "Int64"; + case pb::FieldDescriptor::TYPE_UINT64: + case pb::FieldDescriptor::TYPE_FIXED64: + return "Uint64"; + case pb::FieldDescriptor::TYPE_INT32: + case pb::FieldDescriptor::TYPE_SFIXED32: + case pb::FieldDescriptor::TYPE_SINT32: + return "Int32"; + case pb::FieldDescriptor::TYPE_UINT32: + case pb::FieldDescriptor::TYPE_FIXED32: + return "Uint32"; + case pb::FieldDescriptor::TYPE_BOOL: + return "Bool"; + case pb::FieldDescriptor::TYPE_STRING: + return "Utf8"; + case pb::FieldDescriptor::TYPE_BYTES: + return "String"; + case pb::FieldDescriptor::TYPE_ENUM: + switch (EnumFormatType(*field, enumPolicy)) { + case EEnumFormatType::Int32: + return "Int32"; + case EEnumFormatType::String: + return "String"; + } + default: + ythrow yexception() << "Unsupported protobuf type: " << field->type_name() + << ", field: " << field->name() << ", " << int(field->type()); + } + } + } + + NYT::TNode MakeSchemaFromProto(const pb::Descriptor& descriptor, TVector<const pb::Descriptor*>& nested, const TProtoSchemaOptions& options) { + if (Find(nested, &descriptor) != nested.end()) { + TVector<TString> nestedNames; + for (const auto* d : nested) { + nestedNames.push_back(d->full_name()); + } + nestedNames.push_back(descriptor.full_name()); + ythrow yexception() << Sprintf("recursive messages are not supported (%s)", + JoinStrings(nestedNames, "->").c_str()); + } + nested.push_back(&descriptor); + + auto items = NYT::TNode::CreateList(); + for (int fieldNo = 0; fieldNo < descriptor.field_count(); ++fieldNo) { + const auto& fieldDescriptor = *descriptor.field(fieldNo); + + auto name = fieldDescriptor.name(); + if ( + auto renamePtr = options.FieldRenames.FindPtr(name); + nested.size() == 1 && renamePtr + ) { + name = *renamePtr; + } + + NYT::TNode itemType; + if (fieldDescriptor.type() == pb::FieldDescriptor::TYPE_MESSAGE) { + itemType = MakeSchemaFromProto(*fieldDescriptor.message_type(), nested, options); + } else { + itemType = NYT::TNode::CreateList(); + itemType.Add("DataType"); + itemType.Add(FormatTypeName(&fieldDescriptor, options.EnumPolicy)); + } + switch (fieldDescriptor.label()) { + case pb::FieldDescriptor::LABEL_OPTIONAL: + { + auto optionalType = NYT::TNode::CreateList(); + optionalType.Add("OptionalType"); + optionalType.Add(std::move(itemType)); + itemType = std::move(optionalType); + } + break; + case pb::FieldDescriptor::LABEL_REQUIRED: + break; + case pb::FieldDescriptor::LABEL_REPEATED: + { + auto listType = NYT::TNode::CreateList(); + listType.Add("ListType"); + listType.Add(std::move(itemType)); + itemType = std::move(listType); + if (options.ListIsOptional) { + itemType = NYT::TNode::CreateList().Add("OptionalType").Add(std::move(itemType)); + } + } + break; + default: + ythrow yexception() << "Unknown protobuf label: " << (ui32)fieldDescriptor.label() << ", field: " << name; + } + + auto itemNode = NYT::TNode::CreateList(); + itemNode.Add(name); + itemNode.Add(std::move(itemType)); + + items.Add(std::move(itemNode)); + } + auto root = NYT::TNode::CreateList(); + root.Add("StructType"); + root.Add(std::move(items)); + + nested.pop_back(); + return root; + } + + NYT::TNode MakeSchemaFromProto(const pb::Descriptor& descriptor, const TProtoSchemaOptions& options) { + TVector<const pb::Descriptor*> nested; + return MakeSchemaFromProto(descriptor, nested, options); + } + + NYT::TNode MakeVariantSchemaFromProtos(const TVector<const pb::Descriptor*>& descriptors, const TProtoSchemaOptions& options) { + Y_ENSURE(options.FieldRenames.empty(), "Renames are not supported in variant mode"); + + auto tupleItems = NYT::TNode::CreateList(); + for (auto descriptor : descriptors) { + tupleItems.Add(MakeSchemaFromProto(*descriptor, options)); + } + + auto tupleType = NYT::TNode::CreateList(); + tupleType.Add("TupleType"); + tupleType.Add(std::move(tupleItems)); + + auto variantType = NYT::TNode::CreateList(); + variantType.Add("VariantType"); + variantType.Add(std::move(tupleType)); + + return variantType; + } + } +} diff --git a/yql/essentials/public/purecalc/helpers/protobuf/schema_from_proto.h b/yql/essentials/public/purecalc/helpers/protobuf/schema_from_proto.h new file mode 100644 index 00000000000..168c654ac78 --- /dev/null +++ b/yql/essentials/public/purecalc/helpers/protobuf/schema_from_proto.h @@ -0,0 +1,60 @@ +#pragma once + +#include <library/cpp/yson/node/node.h> + +#include <util/generic/hash.h> +#include <util/generic/string.h> + +#include <google/protobuf/descriptor.h> + + +namespace NYql { + namespace NPureCalc { + enum class EEnumPolicy { + Int32, + String, + YTFlagDefaultInt32, + YTFlagDefaultString + }; + + enum class EEnumFormatType { + Int32, + String + }; + + /** + * Options that customize building of struct type from protobuf descriptor. + */ + struct TProtoSchemaOptions { + public: + EEnumPolicy EnumPolicy; + bool ListIsOptional; + THashMap<TString, TString> FieldRenames; + + public: + TProtoSchemaOptions(); + + public: + TProtoSchemaOptions& SetEnumPolicy(EEnumPolicy); + + TProtoSchemaOptions& SetListIsOptional(bool); + + TProtoSchemaOptions& SetFieldRenames( + THashMap<TString, TString> fieldRenames + ); + }; + + EEnumFormatType EnumFormatType(const google::protobuf::FieldDescriptor& enumField, EEnumPolicy enumPolicy); + + /** + * Build struct type from a protobuf descriptor. The returned yson can be loaded into a struct annotation node + * using the ParseTypeFromYson function. + */ + NYT::TNode MakeSchemaFromProto(const google::protobuf::Descriptor&, const TProtoSchemaOptions& = {}); + + /** + * Build variant over tuple type from protobuf descriptors. + */ + NYT::TNode MakeVariantSchemaFromProtos(const TVector<const google::protobuf::Descriptor*>&, const TProtoSchemaOptions& = {}); + } +} diff --git a/yql/essentials/public/purecalc/helpers/protobuf/ya.make b/yql/essentials/public/purecalc/helpers/protobuf/ya.make new file mode 100644 index 00000000000..11300baba84 --- /dev/null +++ b/yql/essentials/public/purecalc/helpers/protobuf/ya.make @@ -0,0 +1,14 @@ +LIBRARY() + +SRCS( + schema_from_proto.cpp +) + +PEERDIR( + contrib/libs/protobuf + library/cpp/yson/node + yt/yt_proto/yt/formats + yt/yt_proto/yt/formats +) + +END() diff --git a/yql/essentials/public/purecalc/helpers/stream/stream_from_vector.cpp b/yql/essentials/public/purecalc/helpers/stream/stream_from_vector.cpp new file mode 100644 index 00000000000..e1aed5d6899 --- /dev/null +++ b/yql/essentials/public/purecalc/helpers/stream/stream_from_vector.cpp @@ -0,0 +1 @@ +#include "stream_from_vector.h" diff --git a/yql/essentials/public/purecalc/helpers/stream/stream_from_vector.h b/yql/essentials/public/purecalc/helpers/stream/stream_from_vector.h new file mode 100644 index 00000000000..a2a50558003 --- /dev/null +++ b/yql/essentials/public/purecalc/helpers/stream/stream_from_vector.h @@ -0,0 +1,40 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/interface.h> + +namespace NYql { + namespace NPureCalc { + namespace NPrivate { + template <typename T> + class TVectorStream final: public IStream<T*> { + private: + size_t I_; + TVector<T> Data_; + + public: + explicit TVectorStream(TVector<T> data) + : I_(0) + , Data_(std::move(data)) + { + } + + public: + T* Fetch() override { + if (I_ >= Data_.size()) { + return nullptr; + } else { + return &Data_[I_++]; + } + } + }; + } + + /** + * Convert vector into a purecalc stream. + */ + template <typename T> + THolder<IStream<T*>> StreamFromVector(TVector<T> data) { + return MakeHolder<NPrivate::TVectorStream<T>>(std::move(data)); + } + } +} diff --git a/yql/essentials/public/purecalc/helpers/stream/ya.make b/yql/essentials/public/purecalc/helpers/stream/ya.make new file mode 100644 index 00000000000..f40bb9af559 --- /dev/null +++ b/yql/essentials/public/purecalc/helpers/stream/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +SRCS( + stream_from_vector.cpp +) + +PEERDIR( + yql/essentials/public/purecalc/common +) + +YQL_LAST_ABI_VERSION() + +END() diff --git a/yql/essentials/public/purecalc/helpers/ya.make b/yql/essentials/public/purecalc/helpers/ya.make new file mode 100644 index 00000000000..49cff31687f --- /dev/null +++ b/yql/essentials/public/purecalc/helpers/ya.make @@ -0,0 +1,8 @@ +LIBRARY() + +PEERDIR( + yql/essentials/public/purecalc/helpers/protobuf + yql/essentials/public/purecalc/helpers/stream +) + +END() diff --git a/yql/essentials/public/purecalc/io_specs/arrow/spec.cpp b/yql/essentials/public/purecalc/io_specs/arrow/spec.cpp new file mode 100644 index 00000000000..e7b755cb195 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/arrow/spec.cpp @@ -0,0 +1,576 @@ +#include "spec.h" + +#include <yql/essentials/public/purecalc/common/names.h> + +#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> +#include <yql/essentials/minikql/computation/mkql_custom_list.h> +#include <yql/essentials/minikql/mkql_node_cast.h> +#include <yql/essentials/public/udf/arrow/udf_arrow_helpers.h> +#include <yql/essentials/utils/yql_panic.h> + +using namespace NYql::NPureCalc; +using namespace NKikimr::NUdf; +using namespace NKikimr::NMiniKQL; + +using IArrowIStream = typename TInputSpecTraits<TArrowInputSpec>::IInputStream; +using InputItemType = typename TInputSpecTraits<TArrowInputSpec>::TInputItemType; +using OutputItemType = typename TOutputSpecTraits<TArrowOutputSpec>::TOutputItemType; +using PullListReturnType = typename TOutputSpecTraits<TArrowOutputSpec>::TPullListReturnType; +using PullStreamReturnType = typename TOutputSpecTraits<TArrowOutputSpec>::TPullStreamReturnType; +using ConsumerType = typename TInputSpecTraits<TArrowInputSpec>::TConsumerType; + +namespace { + +template <typename T> +inline TVector<THolder<T>> VectorFromHolder(THolder<T> holder) { + TVector<THolder<T>> result; + result.push_back(std::move(holder)); + return result; +} + + +class TArrowIStreamImpl : public IArrowIStream { +private: + IArrowIStream* Underlying_; + // If we own Underlying_, than Owned_ == Underlying_; + // otherwise Owned_ is nullptr. + THolder<IArrowIStream> Owned_; + + TArrowIStreamImpl(IArrowIStream* underlying, THolder<IArrowIStream> owned) + : Underlying_(underlying) + , Owned_(std::move(owned)) + { + } + +public: + TArrowIStreamImpl(THolder<IArrowIStream> stream) + : TArrowIStreamImpl(stream.Get(), nullptr) + { + Owned_ = std::move(stream); + } + + TArrowIStreamImpl(IArrowIStream* stream) + : TArrowIStreamImpl(stream, nullptr) + { + } + + InputItemType Fetch() { + return Underlying_->Fetch(); + } +}; + + +/** + * Converts input Datums to unboxed values. + */ +class TArrowInputConverter { +protected: + const THolderFactory& Factory_; + TVector<ui32> DatumToMemberIDMap_; + size_t BatchLengthID_; + +public: + explicit TArrowInputConverter( + const TArrowInputSpec& inputSpec, + ui32 index, + IWorker* worker + ) + : Factory_(worker->GetGraph().GetHolderFactory()) + { + const NYT::TNode& inputSchema = inputSpec.GetSchema(index); + // Deduce the schema from the input MKQL type, if no is + // provided by <inputSpec>. + const NYT::TNode& schema = inputSchema.IsEntity() + ? worker->MakeInputSchema(index) + : inputSchema; + + const auto* type = worker->GetRawInputType(index); + + Y_ENSURE(type->IsStruct()); + Y_ENSURE(schema.ChildAsString(0) == "StructType"); + + const auto& members = schema.ChildAsList(1); + DatumToMemberIDMap_.resize(members.size()); + + for (size_t i = 0; i < DatumToMemberIDMap_.size(); i++) { + const auto& name = members[i].ChildAsString(0); + const auto& memberIndex = type->FindMemberIndex(name); + Y_ENSURE(memberIndex); + DatumToMemberIDMap_[i] = *memberIndex; + } + const auto& batchLengthID = type->FindMemberIndex(PurecalcBlockColumnLength); + Y_ENSURE(batchLengthID); + BatchLengthID_ = *batchLengthID; + } + + void DoConvert(arrow::compute::ExecBatch* batch, TUnboxedValue& result) { + size_t nvalues = DatumToMemberIDMap_.size(); + Y_ENSURE(nvalues == static_cast<size_t>(batch->num_values())); + + TUnboxedValue* datums = nullptr; + result = Factory_.CreateDirectArrayHolder(nvalues + 1, datums); + for (size_t i = 0; i < nvalues; i++) { + const ui32 id = DatumToMemberIDMap_[i]; + datums[id] = Factory_.CreateArrowBlock(std::move(batch->values[i])); + } + arrow::Datum length(std::make_shared<arrow::UInt64Scalar>(batch->length)); + datums[BatchLengthID_] = Factory_.CreateArrowBlock(std::move(length)); + } +}; + + +/** + * Converts unboxed values to output Datums (single-output program case). + */ +class TArrowOutputConverter { +protected: + const THolderFactory& Factory_; + TVector<ui32> DatumToMemberIDMap_; + THolder<arrow::compute::ExecBatch> Batch_; + size_t BatchLengthID_; + +public: + explicit TArrowOutputConverter( + const TArrowOutputSpec& outputSpec, + IWorker* worker + ) + : Factory_(worker->GetGraph().GetHolderFactory()) + { + Batch_.Reset(new arrow::compute::ExecBatch); + + const NYT::TNode& outputSchema = outputSpec.GetSchema(); + // Deduce the schema from the output MKQL type, if no is + // provided by <outputSpec>. + const NYT::TNode& schema = outputSchema.IsEntity() + ? worker->MakeOutputSchema() + : outputSchema; + + const auto* type = worker->GetRawOutputType(); + + Y_ENSURE(type->IsStruct()); + Y_ENSURE(schema.ChildAsString(0) == "StructType"); + + const auto* stype = AS_TYPE(NKikimr::NMiniKQL::TStructType, type); + + const auto& members = schema.ChildAsList(1); + DatumToMemberIDMap_.resize(members.size()); + + for (size_t i = 0; i < DatumToMemberIDMap_.size(); i++) { + const auto& name = members[i].ChildAsString(0); + const auto& memberIndex = stype->FindMemberIndex(name); + Y_ENSURE(memberIndex); + DatumToMemberIDMap_[i] = *memberIndex; + } + const auto& batchLengthID = stype->FindMemberIndex(PurecalcBlockColumnLength); + Y_ENSURE(batchLengthID); + BatchLengthID_ = *batchLengthID; + } + + OutputItemType DoConvert(TUnboxedValue value) { + OutputItemType batch = Batch_.Get(); + size_t nvalues = DatumToMemberIDMap_.size(); + + const auto& sizeDatum = TArrowBlock::From(value.GetElement(BatchLengthID_)).GetDatum(); + Y_ENSURE(sizeDatum.is_scalar()); + const auto& sizeScalar = sizeDatum.scalar(); + const auto& sizeData = arrow::internal::checked_cast<const arrow::UInt64Scalar&>(*sizeScalar); + const int64_t length = sizeData.value; + + TVector<arrow::Datum> datums(nvalues); + for (size_t i = 0; i < nvalues; i++) { + const ui32 id = DatumToMemberIDMap_[i]; + const auto& datum = TArrowBlock::From(value.GetElement(id)).GetDatum(); + datums[i] = datum; + if (datum.is_scalar()) { + continue; + } + Y_ENSURE(datum.length() == length); + } + + *batch = arrow::compute::ExecBatch(std::move(datums), length); + return batch; + } +}; + + +/** + * List (or, better, stream) of unboxed values. + * Used as an input value in pull workers. + */ +class TArrowListValue final: public TCustomListValue { +private: + mutable bool HasIterator_ = false; + THolder<IArrowIStream> Underlying_; + IWorker* Worker_; + TArrowInputConverter Converter_; + TScopedAlloc& ScopedAlloc_; + +public: + TArrowListValue( + TMemoryUsageInfo* memInfo, + const TArrowInputSpec& inputSpec, + ui32 index, + THolder<IArrowIStream> underlying, + IWorker* worker + ) + : TCustomListValue(memInfo) + , Underlying_(std::move(underlying)) + , Worker_(worker) + , Converter_(inputSpec, index, Worker_) + , ScopedAlloc_(Worker_->GetScopedAlloc()) + { + } + + ~TArrowListValue() override { + { + // This list value stored in the worker's computation graph and + // destroyed upon the computation graph's destruction. This brings + // us to an interesting situation: scoped alloc is acquired, worker + // and computation graph are half-way destroyed, and now it's our + // turn to die. The problem is, the underlying stream may own + // another worker. This happens when chaining programs. Now, to + // destroy that worker correctly, we need to release our scoped + // alloc (because that worker has its own computation graph and + // scoped alloc). + // By the way, note that we shouldn't interact with the worker here + // because worker is in the middle of its own destruction. So we're + // using our own reference to the scoped alloc. That reference is + // alive because scoped alloc destroyed after computation graph. + auto unguard = Unguard(ScopedAlloc_); + Underlying_.Destroy(); + } + } + + TUnboxedValue GetListIterator() const override { + YQL_ENSURE(!HasIterator_, "Only one pass over input is supported"); + HasIterator_ = true; + return TUnboxedValuePod(const_cast<TArrowListValue*>(this)); + } + + bool Next(TUnboxedValue& result) override { + arrow::compute::ExecBatch* batch; + { + auto unguard = Unguard(ScopedAlloc_); + batch = Underlying_->Fetch(); + } + + if (!batch) { + return false; + } + + Converter_.DoConvert(batch, result); + return true; + } + + EFetchStatus Fetch(TUnboxedValue& result) override { + if (Next(result)) { + return EFetchStatus::Ok; + } else { + return EFetchStatus::Finish; + } + } +}; + + +/** + * Arrow input stream for unboxed value lists. + */ +class TArrowListImpl final: public IStream<OutputItemType> { +protected: + TWorkerHolder<IPullListWorker> WorkerHolder_; + TArrowOutputConverter Converter_; + +public: + explicit TArrowListImpl( + const TArrowOutputSpec& outputSpec, + TWorkerHolder<IPullListWorker> worker + ) + : WorkerHolder_(std::move(worker)) + , Converter_(outputSpec, WorkerHolder_.Get()) + { + } + + OutputItemType Fetch() override { + TBindTerminator bind(WorkerHolder_->GetGraph().GetTerminator()); + + with_lock(WorkerHolder_->GetScopedAlloc()) { + TUnboxedValue value; + + if (!WorkerHolder_->GetOutputIterator().Next(value)) { + return TOutputSpecTraits<TArrowOutputSpec>::StreamSentinel; + } + + return Converter_.DoConvert(value); + } + } +}; + + +/** + * Arrow input stream for unboxed value streams. + */ +class TArrowStreamImpl final: public IStream<OutputItemType> { +protected: + TWorkerHolder<IPullStreamWorker> WorkerHolder_; + TArrowOutputConverter Converter_; + +public: + explicit TArrowStreamImpl(const TArrowOutputSpec& outputSpec, TWorkerHolder<IPullStreamWorker> worker) + : WorkerHolder_(std::move(worker)) + , Converter_(outputSpec, WorkerHolder_.Get()) + { + } + + OutputItemType Fetch() override { + TBindTerminator bind(WorkerHolder_->GetGraph().GetTerminator()); + + with_lock(WorkerHolder_->GetScopedAlloc()) { + TUnboxedValue value; + + auto status = WorkerHolder_->GetOutput().Fetch(value); + YQL_ENSURE(status != EFetchStatus::Yield, "Yield is not supported in pull mode"); + + if (status == EFetchStatus::Finish) { + return TOutputSpecTraits<TArrowOutputSpec>::StreamSentinel; + } + + return Converter_.DoConvert(value); + } + } +}; + + +/** + * Consumer which converts Datums to unboxed values and relays them to the + * worker. Used as a return value of the push processor's Process function. + */ +class TArrowConsumerImpl final: public IConsumer<arrow::compute::ExecBatch*> { +private: + TWorkerHolder<IPushStreamWorker> WorkerHolder_; + TArrowInputConverter Converter_; + +public: + explicit TArrowConsumerImpl( + const TArrowInputSpec& inputSpec, + TWorkerHolder<IPushStreamWorker> worker + ) + : TArrowConsumerImpl(inputSpec, 0, std::move(worker)) + { + } + + explicit TArrowConsumerImpl( + const TArrowInputSpec& inputSpec, + ui32 index, + TWorkerHolder<IPushStreamWorker> worker + ) + : WorkerHolder_(std::move(worker)) + , Converter_(inputSpec, index, WorkerHolder_.Get()) + { + } + + void OnObject(arrow::compute::ExecBatch* batch) override { + TBindTerminator bind(WorkerHolder_->GetGraph().GetTerminator()); + + with_lock(WorkerHolder_->GetScopedAlloc()) { + TUnboxedValue result; + Converter_.DoConvert(batch, result); + WorkerHolder_->Push(std::move(result)); + } + } + + void OnFinish() override { + TBindTerminator bind(WorkerHolder_->GetGraph().GetTerminator()); + + with_lock(WorkerHolder_->GetScopedAlloc()) { + WorkerHolder_->OnFinish(); + } + } +}; + + +/** + * Push relay used to convert generated unboxed value to a Datum and push it to + * the user's consumer. + */ +class TArrowPushRelayImpl: public IConsumer<const TUnboxedValue*> { +private: + THolder<IConsumer<OutputItemType>> Underlying_; + IWorker* Worker_; + TArrowOutputConverter Converter_; + +public: + TArrowPushRelayImpl( + const TArrowOutputSpec& outputSpec, + IPushStreamWorker* worker, + THolder<IConsumer<OutputItemType>> underlying + ) + : Underlying_(std::move(underlying)) + , Worker_(worker) + , Converter_(outputSpec, Worker_) + { + } + + // XXX: If you've read a comment in the TArrowListValue's destructor, you + // may be wondering why don't we do the same trick here. Well, that's + // because in push mode, consumer is destroyed before acquiring scoped alloc + // and destroying computation graph. + + void OnObject(const TUnboxedValue* value) override { + OutputItemType message = Converter_.DoConvert(*value); + auto unguard = Unguard(Worker_->GetScopedAlloc()); + Underlying_->OnObject(message); + } + + void OnFinish() override { + auto unguard = Unguard(Worker_->GetScopedAlloc()); + Underlying_->OnFinish(); + } +}; + + +template <typename TWorker> +void PrepareWorkerImpl(const TArrowInputSpec& inputSpec, TWorker* worker, + TVector<THolder<TArrowIStreamImpl>>&& streams +) { + YQL_ENSURE(worker->GetInputsCount() == streams.size(), + "number of input streams should match number of inputs provided by spec"); + + with_lock(worker->GetScopedAlloc()) { + auto& holderFactory = worker->GetGraph().GetHolderFactory(); + for (ui32 i = 0; i < streams.size(); i++) { + auto input = holderFactory.template Create<TArrowListValue>( + inputSpec, i, std::move(streams[i]), worker); + worker->SetInput(std::move(input), i); + } + } +} + +} // namespace + + +TArrowInputSpec::TArrowInputSpec(const TVector<NYT::TNode>& schemas) + : Schemas_(schemas) +{ +} + +const TVector<NYT::TNode>& TArrowInputSpec::GetSchemas() const { + return Schemas_; +} + +const NYT::TNode& TArrowInputSpec::GetSchema(ui32 index) const { + return Schemas_[index]; +} + +void TInputSpecTraits<TArrowInputSpec>::PreparePullListWorker( + const TArrowInputSpec& inputSpec, IPullListWorker* worker, + IArrowIStream* stream +) { + TInputSpecTraits<TArrowInputSpec>::PreparePullListWorker( + inputSpec, worker, TVector<IArrowIStream*>({stream})); +} + +void TInputSpecTraits<TArrowInputSpec>::PreparePullListWorker( + const TArrowInputSpec& inputSpec, IPullListWorker* worker, + const TVector<IArrowIStream*>& streams +) { + TVector<THolder<TArrowIStreamImpl>> wrappers; + for (ui32 i = 0; i < streams.size(); i++) { + wrappers.push_back(MakeHolder<TArrowIStreamImpl>(streams[i])); + } + PrepareWorkerImpl(inputSpec, worker, std::move(wrappers)); +} + +void TInputSpecTraits<TArrowInputSpec>::PreparePullListWorker( + const TArrowInputSpec& inputSpec, IPullListWorker* worker, + THolder<IArrowIStream> stream +) { + TInputSpecTraits<TArrowInputSpec>::PreparePullListWorker(inputSpec, worker, + VectorFromHolder<IArrowIStream>(std::move(stream))); +} + +void TInputSpecTraits<TArrowInputSpec>::PreparePullListWorker( + const TArrowInputSpec& inputSpec, IPullListWorker* worker, + TVector<THolder<IArrowIStream>>&& streams +) { + TVector<THolder<TArrowIStreamImpl>> wrappers; + for (ui32 i = 0; i < streams.size(); i++) { + wrappers.push_back(MakeHolder<TArrowIStreamImpl>(std::move(streams[i]))); + } + PrepareWorkerImpl(inputSpec, worker, std::move(wrappers)); +} + + +void TInputSpecTraits<TArrowInputSpec>::PreparePullStreamWorker( + const TArrowInputSpec& inputSpec, IPullStreamWorker* worker, + IArrowIStream* stream +) { + TInputSpecTraits<TArrowInputSpec>::PreparePullStreamWorker( + inputSpec, worker, TVector<IArrowIStream*>({stream})); +} + +void TInputSpecTraits<TArrowInputSpec>::PreparePullStreamWorker( + const TArrowInputSpec& inputSpec, IPullStreamWorker* worker, + const TVector<IArrowIStream*>& streams +) { + TVector<THolder<TArrowIStreamImpl>> wrappers; + for (ui32 i = 0; i < streams.size(); i++) { + wrappers.push_back(MakeHolder<TArrowIStreamImpl>(streams[i])); + } + PrepareWorkerImpl(inputSpec, worker, std::move(wrappers)); +} + +void TInputSpecTraits<TArrowInputSpec>::PreparePullStreamWorker( + const TArrowInputSpec& inputSpec, IPullStreamWorker* worker, + THolder<IArrowIStream> stream +) { + TInputSpecTraits<TArrowInputSpec>::PreparePullStreamWorker( + inputSpec, worker, VectorFromHolder<IArrowIStream>(std::move(stream))); +} + +void TInputSpecTraits<TArrowInputSpec>::PreparePullStreamWorker( + const TArrowInputSpec& inputSpec, IPullStreamWorker* worker, + TVector<THolder<IArrowIStream>>&& streams +) { + TVector<THolder<TArrowIStreamImpl>> wrappers; + for (ui32 i = 0; i < streams.size(); i++) { + wrappers.push_back(MakeHolder<TArrowIStreamImpl>(std::move(streams[i]))); + } + PrepareWorkerImpl(inputSpec, worker, std::move(wrappers)); +} + + +ConsumerType TInputSpecTraits<TArrowInputSpec>::MakeConsumer( + const TArrowInputSpec& inputSpec, TWorkerHolder<IPushStreamWorker> worker +) { + return MakeHolder<TArrowConsumerImpl>(inputSpec, std::move(worker)); +} + + +TArrowOutputSpec::TArrowOutputSpec(const NYT::TNode& schema) + : Schema_(schema) +{ +} + +const NYT::TNode& TArrowOutputSpec::GetSchema() const { + return Schema_; +} + + +PullListReturnType TOutputSpecTraits<TArrowOutputSpec>::ConvertPullListWorkerToOutputType( + const TArrowOutputSpec& outputSpec, TWorkerHolder<IPullListWorker> worker +) { + return MakeHolder<TArrowListImpl>(outputSpec, std::move(worker)); +} + +PullStreamReturnType TOutputSpecTraits<TArrowOutputSpec>::ConvertPullStreamWorkerToOutputType( + const TArrowOutputSpec& outputSpec, TWorkerHolder<IPullStreamWorker> worker +) { + return MakeHolder<TArrowStreamImpl>(outputSpec, std::move(worker)); +} + +void TOutputSpecTraits<TArrowOutputSpec>::SetConsumerToWorker( + const TArrowOutputSpec& outputSpec, IPushStreamWorker* worker, + THolder<IConsumer<TOutputItemType>> consumer +) { + worker->SetConsumer(MakeHolder<TArrowPushRelayImpl>(outputSpec, worker, std::move(consumer))); +} diff --git a/yql/essentials/public/purecalc/io_specs/arrow/spec.h b/yql/essentials/public/purecalc/io_specs/arrow/spec.h new file mode 100644 index 00000000000..42780b1a376 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/arrow/spec.h @@ -0,0 +1,130 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/interface.h> +#include <arrow/compute/kernel.h> + +namespace NYql { +namespace NPureCalc { + +/** + * Processing mode for working with Apache Arrow batches inputs. + * + * In this mode purecalc accept pointers to abstract Arrow ExecBatches and + * processes them. All Datums in batches should respect the given YT schema + * (the one you pass to the constructor of the input spec). + * + * All working modes are implemented. In pull list and pull stream modes a + * program would accept a pointer to a single stream object or vector of + * pointers of stream objects of Arrow ExecBatch pointers. In push mode, a + * program will return a consumer of pointers to Arrow ExecBatch. + * + * The program synopsis follows: + * + * @code + * ... TPullListProgram::Apply(IStream<arrow::compute::ExecBatch*>*); + * ... TPullListProgram::Apply(TVector<IStream<arrow::compute::ExecBatch*>*>); + * ... TPullStreamProgram::Apply(IStream<arrow::compute::ExecBatch*>*); + * ... TPullStreamProgram::Apply(TVector<IStream<arrow::compute::ExecBatch*>*>); + * TConsumer<arrow::compute::ExecBatch*> TPushStreamProgram::Apply(...); + * @endcode + */ + +class TArrowInputSpec: public TInputSpecBase { +private: + const TVector<NYT::TNode> Schemas_; + +public: + explicit TArrowInputSpec(const TVector<NYT::TNode>& schemas); + const TVector<NYT::TNode>& GetSchemas() const override; + const NYT::TNode& GetSchema(ui32 index) const; + bool ProvidesBlocks() const override { return true; } +}; + +/** + * Processing mode for working with Apache Arrow batches outputs. + * + * In this mode purecalc yields pointers to abstract Arrow ExecBatches. All + * Datums in generated batches respects the given YT schema. + * + * Note that one should not expect that the returned pointer will be valid + * forever; in can (and will) become outdated once a new output is + * requested/pushed. + * + * All working modes are implemented. In pull stream and pull list modes a + * program will return a pointer to a stream of pointers to Arrow ExecBatches. + * In push mode, it will accept a single consumer of pointers to Arrow ExecBatch. + * + * The program synopsis follows: + * + * @code + * IStream<arrow::compute::ExecBatch*> TPullStreamProgram::Apply(...); + * IStream<arrow::compute::ExecBatch*> TPullListProgram::Apply(...); + * ... TPushStreamProgram::Apply(TConsumer<arrow::compute::ExecBatch*>); + * @endcode + */ + +class TArrowOutputSpec: public TOutputSpecBase { +private: + const NYT::TNode Schema_; + +public: + explicit TArrowOutputSpec(const NYT::TNode& schema); + const NYT::TNode& GetSchema() const override; + bool AcceptsBlocks() const override { return true; } +}; + +template <> +struct TInputSpecTraits<TArrowInputSpec> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullListMode = true; + static const constexpr bool SupportPullStreamMode = true; + static const constexpr bool SupportPushStreamMode = true; + + using TInputItemType = arrow::compute::ExecBatch*; + using IInputStream = IStream<TInputItemType>; + using TConsumerType = THolder<IConsumer<TInputItemType>>; + + static void PreparePullListWorker(const TArrowInputSpec&, IPullListWorker*, + IInputStream*); + static void PreparePullListWorker(const TArrowInputSpec&, IPullListWorker*, + THolder<IInputStream>); + static void PreparePullListWorker(const TArrowInputSpec&, IPullListWorker*, + const TVector<IInputStream*>&); + static void PreparePullListWorker(const TArrowInputSpec&, IPullListWorker*, + TVector<THolder<IInputStream>>&&); + + static void PreparePullStreamWorker(const TArrowInputSpec&, IPullStreamWorker*, + IInputStream*); + static void PreparePullStreamWorker(const TArrowInputSpec&, IPullStreamWorker*, + THolder<IInputStream>); + static void PreparePullStreamWorker(const TArrowInputSpec&, IPullStreamWorker*, + const TVector<IInputStream*>&); + static void PreparePullStreamWorker(const TArrowInputSpec&, IPullStreamWorker*, + TVector<THolder<IInputStream>>&&); + + static TConsumerType MakeConsumer(const TArrowInputSpec&, TWorkerHolder<IPushStreamWorker>); +}; + +template <> +struct TOutputSpecTraits<TArrowOutputSpec> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullListMode = true; + static const constexpr bool SupportPullStreamMode = true; + static const constexpr bool SupportPushStreamMode = true; + + using TOutputItemType = arrow::compute::ExecBatch*; + using IOutputStream = IStream<TOutputItemType>; + using TPullListReturnType = THolder<IOutputStream>; + using TPullStreamReturnType = THolder<IOutputStream>; + + static const constexpr TOutputItemType StreamSentinel = nullptr; + + static TPullListReturnType ConvertPullListWorkerToOutputType(const TArrowOutputSpec&, TWorkerHolder<IPullListWorker>); + static TPullStreamReturnType ConvertPullStreamWorkerToOutputType(const TArrowOutputSpec&, TWorkerHolder<IPullStreamWorker>); + static void SetConsumerToWorker(const TArrowOutputSpec&, IPushStreamWorker*, THolder<IConsumer<TOutputItemType>>); +}; + +} // namespace NPureCalc +} // namespace NYql diff --git a/yql/essentials/public/purecalc/io_specs/arrow/ut/test_spec.cpp b/yql/essentials/public/purecalc/io_specs/arrow/ut/test_spec.cpp new file mode 100644 index 00000000000..fa1ca5171a4 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/arrow/ut/test_spec.cpp @@ -0,0 +1,419 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <yql/essentials/public/purecalc/common/interface.h> +#include <yql/essentials/public/purecalc/io_specs/arrow/spec.h> +#include <yql/essentials/public/purecalc/ut/lib/helpers.h> +#include <yql/essentials/core/yql_type_annotation.h> + +#include <yql/essentials/public/udf/arrow/udf_arrow_helpers.h> +#include <arrow/array/builder_primitive.h> + +namespace { + +#define Y_UNIT_TEST_ADD_BLOCK_TEST(N, MODE) \ + TCurrentTest::AddTest(#N ":BlockEngineMode=" #MODE, \ + static_cast<void (*)(NUnitTest::TTestContext&)>(&N<NYql::EBlockEngineMode::MODE>), false); + +#define Y_UNIT_TEST_BLOCKS(N) \ + template<NYql::EBlockEngineMode BlockEngineMode> \ + void N(NUnitTest::TTestContext&); \ + struct TTestRegistration##N { \ + TTestRegistration##N() { \ + Y_UNIT_TEST_ADD_BLOCK_TEST(N, Disable) \ + Y_UNIT_TEST_ADD_BLOCK_TEST(N, Auto) \ + Y_UNIT_TEST_ADD_BLOCK_TEST(N, Force) \ + } \ + }; \ + static TTestRegistration##N testRegistration##N; \ + template<NYql::EBlockEngineMode BlockEngineMode> \ + void N(NUnitTest::TTestContext&) + +NYql::NPureCalc::TProgramFactoryOptions TestOptions(NYql::EBlockEngineMode mode) { + static const TMap<NYql::EBlockEngineMode, const TString> mode2settings = { + {NYql::EBlockEngineMode::Disable, "disable"}, + {NYql::EBlockEngineMode::Auto, "auto"}, + {NYql::EBlockEngineMode::Force, "force"}, + }; + auto options = NYql::NPureCalc::TProgramFactoryOptions(); + options.SetBlockEngineSettings(mode2settings.at(mode)); + return options; +} + + +template <typename T> +struct TVectorStream: public NYql::NPureCalc::IStream<T*> { + TVector<T> Data_; + size_t Index_ = 0; + +public: + TVectorStream(TVector<T> items) + : Data_(std::move(items)) + { + } + + T* Fetch() override { + return Index_ < Data_.size() ? &Data_[Index_++] : nullptr; + } +}; + + +template<typename T> +struct TVectorConsumer: public NYql::NPureCalc::IConsumer<T*> { + TVector<T>& Data_; + size_t Index_ = 0; + +public: + TVectorConsumer(TVector<T>& items) + : Data_(items) + { + } + + void OnObject(T* t) override { + Index_++; + Data_.push_back(*t); + } + + void OnFinish() override { + UNIT_ASSERT_GT(Index_, 0); + } +}; + + +using ExecBatchStreamImpl = TVectorStream<arrow::compute::ExecBatch>; +using ExecBatchConsumerImpl = TVectorConsumer<arrow::compute::ExecBatch>; + +template <typename TBuilder> +arrow::Datum MakeArrayDatumFromVector( + const TVector<typename TBuilder::value_type>& data, + const TVector<bool>& valid +) { + TBuilder builder; + ARROW_OK(builder.Reserve(data.size())); + ARROW_OK(builder.AppendValues(data, valid)); + return arrow::Datum(ARROW_RESULT(builder.Finish())); +} + +template <typename TValue> +TVector<TValue> MakeVectorFromArrayDatum( + const arrow::Datum& datum, + const int64_t dsize +) { + Y_ENSURE(datum.is_array(), "ExecBatch layout doesn't respect the schema"); + + const auto& array = *datum.array(); + Y_ENSURE(array.length == dsize, + "Array Datum size differs from the given ExecBatch size"); + Y_ENSURE(array.GetNullCount() == 0, + "Null values conversion is not supported"); + Y_ENSURE(array.buffers.size() == 2, + "Array Datum layout doesn't respect the schema"); + + const TValue* adata1 = array.GetValuesSafe<TValue>(1); + return TVector<TValue>(adata1, adata1 + dsize); +} + +arrow::compute::ExecBatch MakeBatch(ui64 bsize, i64 value, ui64 init = 1) { + TVector<uint64_t> data1(bsize); + TVector<int64_t> data2(bsize); + TVector<bool> valid(bsize); + std::iota(data1.begin(), data1.end(), init); + std::fill(data2.begin(), data2.end(), value); + std::fill(valid.begin(), valid.end(), true); + + TVector<arrow::Datum> batchArgs = { + MakeArrayDatumFromVector<arrow::UInt64Builder>(data1, valid), + MakeArrayDatumFromVector<arrow::Int64Builder>(data2, valid) + }; + + return arrow::compute::ExecBatch(std::move(batchArgs), bsize); +} + +TVector<std::tuple<ui64, i64>> CanonBatches(const TVector<arrow::compute::ExecBatch>& batches) { + TVector<std::tuple<ui64, i64>> result; + for (const auto& batch : batches) { + const auto bsize = batch.length; + + const auto& avec1 = MakeVectorFromArrayDatum<ui64>(batch.values[0], bsize); + const auto& avec2 = MakeVectorFromArrayDatum<i64>(batch.values[1], bsize); + + for (auto i = 0; i < bsize; i++) { + result.push_back(std::make_tuple(avec1[i], avec2[i])); + } + } + std::sort(result.begin(), result.end()); + return result; +} + +} // namespace + + +Y_UNIT_TEST_SUITE(TestSimplePullListArrowIO) { + Y_UNIT_TEST_BLOCKS(TestSingleInput) { + using namespace NYql::NPureCalc; + + TVector<TString> fields = {"uint64", "int64"}; + auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields); + + auto factory = MakeProgramFactory(TestOptions(BlockEngineMode)); + + try { + auto program = factory->MakePullListProgram( + TArrowInputSpec({schema}), + TArrowOutputSpec(schema), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + const TVector<arrow::compute::ExecBatch> input({MakeBatch(9, 19)}); + const auto canonInput = CanonBatches(input); + ExecBatchStreamImpl items(input); + + auto stream = program->Apply(&items); + + TVector<arrow::compute::ExecBatch> output; + while (arrow::compute::ExecBatch* batch = stream->Fetch()) { + output.push_back(*batch); + } + const auto canonOutput = CanonBatches(output); + UNIT_ASSERT_EQUAL(canonInput, canonOutput); + } catch (const TCompileError& error) { + UNIT_FAIL(error.GetIssues()); + } + } + + Y_UNIT_TEST_BLOCKS(TestMultiInput) { + using namespace NYql::NPureCalc; + + TVector<TString> fields = {"uint64", "int64"}; + auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields); + + auto factory = MakeProgramFactory(TestOptions(BlockEngineMode)); + + try { + auto program = factory->MakePullListProgram( + TArrowInputSpec({schema, schema}), + TArrowOutputSpec(schema), + R"( + SELECT * FROM Input0 + UNION ALL + SELECT * FROM Input1 + )", + ETranslationMode::SQL + ); + + TVector<arrow::compute::ExecBatch> inputs = { + MakeBatch(9, 19), + MakeBatch(7, 17) + }; + const auto canonInputs = CanonBatches(inputs); + + ExecBatchStreamImpl items0({inputs[0]}); + ExecBatchStreamImpl items1({inputs[1]}); + + const TVector<IStream<arrow::compute::ExecBatch*>*> items({&items0, &items1}); + + auto stream = program->Apply(items); + + TVector<arrow::compute::ExecBatch> output; + while (arrow::compute::ExecBatch* batch = stream->Fetch()) { + output.push_back(*batch); + } + const auto canonOutput = CanonBatches(output); + UNIT_ASSERT_EQUAL(canonInputs, canonOutput); + } catch (const TCompileError& error) { + UNIT_FAIL(error.GetIssues()); + } + } +} + + +Y_UNIT_TEST_SUITE(TestMorePullListArrowIO) { + Y_UNIT_TEST_BLOCKS(TestInc) { + using namespace NYql::NPureCalc; + + TVector<TString> fields = {"uint64", "int64"}; + auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields); + + auto factory = MakeProgramFactory(TestOptions(BlockEngineMode)); + + try { + auto program = factory->MakePullListProgram( + TArrowInputSpec({schema}), + TArrowOutputSpec(schema), + R"(SELECT + uint64 + 1 as uint64, + int64 - 2 as int64, + FROM Input)", + ETranslationMode::SQL + ); + + const TVector<arrow::compute::ExecBatch> input({MakeBatch(9, 19)}); + const auto canonInput = CanonBatches(input); + ExecBatchStreamImpl items(input); + + auto stream = program->Apply(&items); + + TVector<arrow::compute::ExecBatch> output; + while (arrow::compute::ExecBatch* batch = stream->Fetch()) { + output.push_back(*batch); + } + const auto canonOutput = CanonBatches(output); + const TVector<arrow::compute::ExecBatch> check({MakeBatch(9, 17, 2)}); + const auto canonCheck = CanonBatches(check); + UNIT_ASSERT_EQUAL(canonCheck, canonOutput); + } catch (const TCompileError& error) { + UNIT_FAIL(error.GetIssues()); + } + } +} + + +Y_UNIT_TEST_SUITE(TestSimplePullStreamArrowIO) { + Y_UNIT_TEST_BLOCKS(TestSingleInput) { + using namespace NYql::NPureCalc; + + TVector<TString> fields = {"uint64", "int64"}; + auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields); + + auto factory = MakeProgramFactory(TestOptions(BlockEngineMode)); + + try { + auto program = factory->MakePullStreamProgram( + TArrowInputSpec({schema}), + TArrowOutputSpec(schema), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + const TVector<arrow::compute::ExecBatch> input({MakeBatch(9, 19)}); + const auto canonInput = CanonBatches(input); + ExecBatchStreamImpl items(input); + + auto stream = program->Apply(&items); + + TVector<arrow::compute::ExecBatch> output; + while (arrow::compute::ExecBatch* batch = stream->Fetch()) { + output.push_back(*batch); + } + const auto canonOutput = CanonBatches(output); + UNIT_ASSERT_EQUAL(canonInput, canonOutput); + } catch (const TCompileError& error) { + UNIT_FAIL(error.GetIssues()); + } + } +} + + +Y_UNIT_TEST_SUITE(TestMorePullStreamArrowIO) { + Y_UNIT_TEST_BLOCKS(TestInc) { + using namespace NYql::NPureCalc; + + TVector<TString> fields = {"uint64", "int64"}; + auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields); + + auto factory = MakeProgramFactory(TestOptions(BlockEngineMode)); + + try { + auto program = factory->MakePullStreamProgram( + TArrowInputSpec({schema}), + TArrowOutputSpec(schema), + R"(SELECT + uint64 + 1 as uint64, + int64 - 2 as int64, + FROM Input)", + ETranslationMode::SQL + ); + + const TVector<arrow::compute::ExecBatch> input({MakeBatch(9, 19)}); + const auto canonInput = CanonBatches(input); + ExecBatchStreamImpl items(input); + + auto stream = program->Apply(&items); + + TVector<arrow::compute::ExecBatch> output; + while (arrow::compute::ExecBatch* batch = stream->Fetch()) { + output.push_back(*batch); + } + const auto canonOutput = CanonBatches(output); + const TVector<arrow::compute::ExecBatch> check({MakeBatch(9, 17, 2)}); + const auto canonCheck = CanonBatches(check); + UNIT_ASSERT_EQUAL(canonCheck, canonOutput); + } catch (const TCompileError& error) { + UNIT_FAIL(error.GetIssues()); + } + } +} + + +Y_UNIT_TEST_SUITE(TestPushStreamArrowIO) { + Y_UNIT_TEST_BLOCKS(TestAllColumns) { + using namespace NYql::NPureCalc; + + TVector<TString> fields = {"uint64", "int64"}; + auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields); + + auto factory = MakeProgramFactory(TestOptions(BlockEngineMode)); + + try { + auto program = factory->MakePushStreamProgram( + TArrowInputSpec({schema}), + TArrowOutputSpec(schema), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + arrow::compute::ExecBatch input = MakeBatch(9, 19); + const auto canonInput = CanonBatches({input}); + TVector<arrow::compute::ExecBatch> output; + + auto consumer = program->Apply(MakeHolder<ExecBatchConsumerImpl>(output)); + + UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnObject(&input); }()); + UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnFinish(); }()); + + const auto canonOutput = CanonBatches(output); + UNIT_ASSERT_EQUAL(canonInput, canonOutput); + } catch (const TCompileError& error) { + UNIT_FAIL(error.GetIssues()); + } + } +} + +Y_UNIT_TEST_SUITE(TestMorePushStreamArrowIO) { + Y_UNIT_TEST_BLOCKS(TestInc) { + using namespace NYql::NPureCalc; + + TVector<TString> fields = {"uint64", "int64"}; + auto schema = NYql::NPureCalc::NPrivate::GetSchema(fields); + + auto factory = MakeProgramFactory(TestOptions(BlockEngineMode)); + + try { + auto program = factory->MakePushStreamProgram( + TArrowInputSpec({schema}), + TArrowOutputSpec(schema), + R"(SELECT + uint64 + 1 as uint64, + int64 - 2 as int64, + FROM Input)", + ETranslationMode::SQL + ); + + arrow::compute::ExecBatch input = MakeBatch(9, 19); + const auto canonInput = CanonBatches({input}); + TVector<arrow::compute::ExecBatch> output; + + auto consumer = program->Apply(MakeHolder<ExecBatchConsumerImpl>(output)); + + UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnObject(&input); }()); + UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnFinish(); }()); + + const auto canonOutput = CanonBatches(output); + const TVector<arrow::compute::ExecBatch> check({MakeBatch(9, 17, 2)}); + const auto canonCheck = CanonBatches(check); + UNIT_ASSERT_EQUAL(canonCheck, canonOutput); + } catch (const TCompileError& error) { + UNIT_FAIL(error.GetIssues()); + } + } +} diff --git a/yql/essentials/public/purecalc/io_specs/arrow/ut/ya.make b/yql/essentials/public/purecalc/io_specs/arrow/ut/ya.make new file mode 100644 index 00000000000..ad7eb5881f5 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/arrow/ut/ya.make @@ -0,0 +1,20 @@ +UNITTEST() + +SIZE(MEDIUM) + +TIMEOUT(300) + +PEERDIR( + yql/essentials/public/udf/service/exception_policy + yql/essentials/public/purecalc + yql/essentials/public/purecalc/io_specs/arrow + yql/essentials/public/purecalc/ut/lib +) + +YQL_LAST_ABI_VERSION() + +SRCS( + test_spec.cpp +) + +END() diff --git a/yql/essentials/public/purecalc/io_specs/arrow/ya.make b/yql/essentials/public/purecalc/io_specs/arrow/ya.make new file mode 100644 index 00000000000..f98059dcebe --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/arrow/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +PEERDIR( + yql/essentials/public/purecalc/common +) + +INCLUDE(ya.make.inc) + +END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/public/purecalc/io_specs/arrow/ya.make.inc b/yql/essentials/public/purecalc/io_specs/arrow/ya.make.inc new file mode 100644 index 00000000000..37ff3be849e --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/arrow/ya.make.inc @@ -0,0 +1,13 @@ +SRCDIR( + yql/essentials/public/purecalc/io_specs/arrow +) + +ADDINCL( + yql/essentials/public/purecalc/io_specs/arrow +) + +YQL_LAST_ABI_VERSION() + +SRCS( + spec.cpp +) diff --git a/yql/essentials/public/purecalc/io_specs/protobuf/proto_variant.cpp b/yql/essentials/public/purecalc/io_specs/protobuf/proto_variant.cpp new file mode 100644 index 00000000000..90f0b339ca6 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf/proto_variant.cpp @@ -0,0 +1 @@ +#include "proto_variant.h" diff --git a/yql/essentials/public/purecalc/io_specs/protobuf/proto_variant.h b/yql/essentials/public/purecalc/io_specs/protobuf/proto_variant.h new file mode 100644 index 00000000000..0692440ca1c --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf/proto_variant.h @@ -0,0 +1,80 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/interface.h> + +#include <array> + +namespace NYql::NPureCalc::NPrivate { + using TProtoRawMultiOutput = std::pair<ui32, google::protobuf::Message*>; + + template <typename... T> + using TProtoMultiOutput = std::variant<T*...>; + + template <size_t I, typename... T> + using TProtoOutput = std::add_pointer_t<typename TTypeList<T...>::template TGet<I>>; + + template <size_t I, typename... T> + TProtoMultiOutput<T...> InitProtobufsVariant(google::protobuf::Message* ptr) { + static_assert(std::conjunction_v<std::is_base_of<google::protobuf::Message, T>...>); + return TProtoMultiOutput<T...>(std::in_place_index<I>, static_cast<TProtoOutput<I, T...>>(ptr)); + } + + template <typename... T> + class TProtobufsMappingBase { + public: + TProtobufsMappingBase() + : InitFuncs_(BuildInitFuncs(std::make_index_sequence<sizeof...(T)>())) + { + } + + private: + typedef TProtoMultiOutput<T...> (*initfunc)(google::protobuf::Message*); + + template <size_t... I> + inline std::array<initfunc, sizeof...(T)> BuildInitFuncs(std::index_sequence<I...>) { + return {&InitProtobufsVariant<I, T...>...}; + } + + protected: + const std::array<initfunc, sizeof...(T)> InitFuncs_; + }; + + template <typename... T> + class TProtobufsMappingStream: public IStream<TProtoMultiOutput<T...>>, public TProtobufsMappingBase<T...> { + public: + TProtobufsMappingStream(THolder<IStream<TProtoRawMultiOutput>> oldStream) + : OldStream_(std::move(oldStream)) + { + } + + public: + TProtoMultiOutput<T...> Fetch() override { + auto&& oldItem = OldStream_->Fetch(); + return this->InitFuncs_[oldItem.first](oldItem.second); + } + + private: + THolder<IStream<TProtoRawMultiOutput>> OldStream_; + }; + + template <typename... T> + class TProtobufsMappingConsumer: public IConsumer<TProtoRawMultiOutput>, public TProtobufsMappingBase<T...> { + public: + TProtobufsMappingConsumer(THolder<IConsumer<TProtoMultiOutput<T...>>> oldConsumer) + : OldConsumer_(std::move(oldConsumer)) + { + } + + public: + void OnObject(TProtoRawMultiOutput oldItem) override { + OldConsumer_->OnObject(this->InitFuncs_[oldItem.first](oldItem.second)); + } + + void OnFinish() override { + OldConsumer_->OnFinish(); + } + + private: + THolder<IConsumer<TProtoMultiOutput<T...>>> OldConsumer_; + }; +} diff --git a/yql/essentials/public/purecalc/io_specs/protobuf/spec.cpp b/yql/essentials/public/purecalc/io_specs/protobuf/spec.cpp new file mode 100644 index 00000000000..91de6c290a3 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf/spec.cpp @@ -0,0 +1 @@ +#include "spec.h" diff --git a/yql/essentials/public/purecalc/io_specs/protobuf/spec.h b/yql/essentials/public/purecalc/io_specs/protobuf/spec.h new file mode 100644 index 00000000000..0e1a97f632a --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf/spec.h @@ -0,0 +1,147 @@ +#pragma once + +#include "proto_variant.h" + +#include <yql/essentials/public/purecalc/io_specs/protobuf_raw/spec.h> + +namespace NYql { + namespace NPureCalc { + /** + * Processing mode for working with non-raw protobuf messages. + * + * @tparam T message type. + */ + template <typename T> + class TProtobufInputSpec: public TProtobufRawInputSpec { + static_assert(std::is_base_of<google::protobuf::Message, T>::value, + "should be derived from google::protobuf::Message"); + public: + TProtobufInputSpec( + const TMaybe<TString>& timestampColumn = Nothing(), + const TProtoSchemaOptions& options = {} + ) + : TProtobufRawInputSpec(*T::descriptor(), timestampColumn, options) + { + } + }; + + /** + * Processing mode for working with non-raw protobuf messages. + * + * @tparam T message type. + */ + template <typename T> + class TProtobufOutputSpec: public TProtobufRawOutputSpec { + static_assert(std::is_base_of<google::protobuf::Message, T>::value, + "should be derived from google::protobuf::Message"); + public: + TProtobufOutputSpec( + const TProtoSchemaOptions& options = {}, + google::protobuf::Arena* arena = nullptr + ) + : TProtobufRawOutputSpec(*T::descriptor(), nullptr, options, arena) + { + } + }; + + /** + * Processing mode for working with non-raw protobuf messages and several outputs. + */ + template <typename... T> + class TProtobufMultiOutputSpec: public TProtobufRawMultiOutputSpec { + static_assert( + std::conjunction_v<std::is_base_of<google::protobuf::Message, T>...>, + "all types should be derived from google::protobuf::Message"); + public: + TProtobufMultiOutputSpec( + const TProtoSchemaOptions& options = {}, + TMaybe<TVector<google::protobuf::Arena*>> arenas = {} + ) + : TProtobufRawMultiOutputSpec({T::descriptor()...}, Nothing(), options, std::move(arenas)) + { + } + }; + + template <typename T> + struct TInputSpecTraits<TProtobufInputSpec<T>> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullStreamMode = true; + static const constexpr bool SupportPullListMode = true; + static const constexpr bool SupportPushStreamMode = true; + + using TConsumerType = THolder<IConsumer<T*>>; + + static void PreparePullStreamWorker(const TProtobufInputSpec<T>& inputSpec, IPullStreamWorker* worker, THolder<IStream<T*>> stream) { + auto raw = ConvertStream<google::protobuf::Message*>(std::move(stream)); + TInputSpecTraits<TProtobufRawInputSpec>::PreparePullStreamWorker(inputSpec, worker, std::move(raw)); + } + + static void PreparePullListWorker(const TProtobufInputSpec<T>& inputSpec, IPullListWorker* worker, THolder<IStream<T*>> stream) { + auto raw = ConvertStream<google::protobuf::Message*>(std::move(stream)); + TInputSpecTraits<TProtobufRawInputSpec>::PreparePullListWorker(inputSpec, worker, std::move(raw)); + } + + static TConsumerType MakeConsumer(const TProtobufInputSpec<T>& inputSpec, TWorkerHolder<IPushStreamWorker> worker) { + auto raw = TInputSpecTraits<TProtobufRawInputSpec>::MakeConsumer(inputSpec, std::move(worker)); + return ConvertConsumer<T*>(std::move(raw)); + } + }; + + template <typename T> + struct TOutputSpecTraits<TProtobufOutputSpec<T>> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullStreamMode = true; + static const constexpr bool SupportPullListMode = true; + static const constexpr bool SupportPushStreamMode = true; + + using TOutputItemType = T*; + using TPullStreamReturnType = THolder<IStream<TOutputItemType>>; + using TPullListReturnType = THolder<IStream<TOutputItemType>>; + + static TPullStreamReturnType ConvertPullStreamWorkerToOutputType(const TProtobufOutputSpec<T>& outputSpec, TWorkerHolder<IPullStreamWorker> worker) { + auto raw = TOutputSpecTraits<TProtobufRawOutputSpec>::ConvertPullStreamWorkerToOutputType(outputSpec, std::move(worker)); + return ConvertStreamUnsafe<TOutputItemType>(std::move(raw)); + } + + static TPullListReturnType ConvertPullListWorkerToOutputType(const TProtobufOutputSpec<T>& outputSpec, TWorkerHolder<IPullListWorker> worker) { + auto raw = TOutputSpecTraits<TProtobufRawOutputSpec>::ConvertPullListWorkerToOutputType(outputSpec, std::move(worker)); + return ConvertStreamUnsafe<TOutputItemType>(std::move(raw)); + } + + static void SetConsumerToWorker(const TProtobufOutputSpec<T>& outputSpec, IPushStreamWorker* worker, THolder<IConsumer<T*>> consumer) { + auto raw = ConvertConsumerUnsafe<google::protobuf::Message*>(std::move(consumer)); + TOutputSpecTraits<TProtobufRawOutputSpec>::SetConsumerToWorker(outputSpec, worker, std::move(raw)); + } + }; + + template <typename... T> + struct TOutputSpecTraits<TProtobufMultiOutputSpec<T...>> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullStreamMode = true; + static const constexpr bool SupportPullListMode = true; + static const constexpr bool SupportPushStreamMode = true; + + using TOutputItemType = std::variant<T*...>; + using TPullStreamReturnType = THolder<IStream<TOutputItemType>>; + using TPullListReturnType = THolder<IStream<TOutputItemType>>; + + static TPullStreamReturnType ConvertPullStreamWorkerToOutputType(const TProtobufMultiOutputSpec<T...>& outputSpec, TWorkerHolder<IPullStreamWorker> worker) { + auto raw = TOutputSpecTraits<TProtobufRawMultiOutputSpec>::ConvertPullStreamWorkerToOutputType(outputSpec, std::move(worker)); + return THolder(new NPrivate::TProtobufsMappingStream<T...>(std::move(raw))); + } + + static TPullListReturnType ConvertPullListWorkerToOutputType(const TProtobufMultiOutputSpec<T...>& outputSpec, TWorkerHolder<IPullListWorker> worker) { + auto raw = TOutputSpecTraits<TProtobufRawMultiOutputSpec>::ConvertPullListWorkerToOutputType(outputSpec, std::move(worker)); + return THolder(new NPrivate::TProtobufsMappingStream<T...>(std::move(raw))); + } + + static void SetConsumerToWorker(const TProtobufMultiOutputSpec<T...>& outputSpec, IPushStreamWorker* worker, THolder<IConsumer<TOutputItemType>> consumer) { + auto wrapper = MakeHolder<NPrivate::TProtobufsMappingConsumer<T...>>(std::move(consumer)); + TOutputSpecTraits<TProtobufRawMultiOutputSpec>::SetConsumerToWorker(outputSpec, worker, std::move(wrapper)); + } + }; + } +} diff --git a/yql/essentials/public/purecalc/io_specs/protobuf/ut/test_spec.cpp b/yql/essentials/public/purecalc/io_specs/protobuf/ut/test_spec.cpp new file mode 100644 index 00000000000..923a4f5bd8f --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf/ut/test_spec.cpp @@ -0,0 +1,996 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <yql/essentials/public/purecalc/common/interface.h> +#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h> +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> +#include <library/cpp/protobuf/util/pb_io.h> +#include <util/generic/xrange.h> + +namespace { + TMaybe<NPureCalcProto::TAllTypes> allTypesMessage; + + NPureCalcProto::TAllTypes& GetCanonicalMessage() { + if (!allTypesMessage) { + allTypesMessage = NPureCalcProto::TAllTypes(); + + allTypesMessage->SetFDouble(1); + allTypesMessage->SetFFloat(2); + allTypesMessage->SetFInt64(3); + allTypesMessage->SetFSfixed64(4); + allTypesMessage->SetFSint64(5); + allTypesMessage->SetFUint64(6); + allTypesMessage->SetFFixed64(7); + allTypesMessage->SetFInt32(8); + allTypesMessage->SetFSfixed32(9); + allTypesMessage->SetFSint32(10); + allTypesMessage->SetFUint32(11); + allTypesMessage->SetFFixed32(12); + allTypesMessage->SetFBool(true); + allTypesMessage->SetFString("asd"); + allTypesMessage->SetFBytes("dsa"); + } + + return allTypesMessage.GetRef(); + } + + template <typename T1, typename T2> + void AssertEqualToCanonical(const T1& got, const T2& expected) { + UNIT_ASSERT_EQUAL(expected.GetFDouble(), got.GetFDouble()); + UNIT_ASSERT_EQUAL(expected.GetFFloat(), got.GetFFloat()); + UNIT_ASSERT_EQUAL(expected.GetFInt64(), got.GetFInt64()); + UNIT_ASSERT_EQUAL(expected.GetFSfixed64(), got.GetFSfixed64()); + UNIT_ASSERT_EQUAL(expected.GetFSint64(), got.GetFSint64()); + UNIT_ASSERT_EQUAL(expected.GetFUint64(), got.GetFUint64()); + UNIT_ASSERT_EQUAL(expected.GetFFixed64(), got.GetFFixed64()); + UNIT_ASSERT_EQUAL(expected.GetFInt32(), got.GetFInt32()); + UNIT_ASSERT_EQUAL(expected.GetFSfixed32(), got.GetFSfixed32()); + UNIT_ASSERT_EQUAL(expected.GetFSint32(), got.GetFSint32()); + UNIT_ASSERT_EQUAL(expected.GetFUint32(), got.GetFUint32()); + UNIT_ASSERT_EQUAL(expected.GetFFixed32(), got.GetFFixed32()); + UNIT_ASSERT_EQUAL(expected.GetFBool(), got.GetFBool()); + UNIT_ASSERT_EQUAL(expected.GetFString(), got.GetFString()); + UNIT_ASSERT_EQUAL(expected.GetFBytes(), got.GetFBytes()); + } + + template <typename T> + void AssertEqualToCanonical(const T& got) { + AssertEqualToCanonical(got, GetCanonicalMessage()); + } + + TString SerializeToTextFormatAsString(const google::protobuf::Message& message) { + TString result; + { + TStringOutput output(result); + SerializeToTextFormat(message, output); + } + return result; + } + + template <typename T> + void AssertProtoEqual(const T& actual, const T& expected) { + UNIT_ASSERT_VALUES_EQUAL(SerializeToTextFormatAsString(actual), SerializeToTextFormatAsString(expected)); + } +} + +class TAllTypesStreamImpl: public NYql::NPureCalc::IStream<NPureCalcProto::TAllTypes*> { +private: + int I_ = 0; + NPureCalcProto::TAllTypes Message_ = GetCanonicalMessage(); + +public: + NPureCalcProto::TAllTypes* Fetch() override { + if (I_ > 0) { + return nullptr; + } else { + I_ += 1; + return &Message_; + } + } +}; + +class TSimpleMessageStreamImpl: public NYql::NPureCalc::IStream<NPureCalcProto::TSimpleMessage*> { +public: + TSimpleMessageStreamImpl(i32 value) + { + Message_.SetX(value); + } + + NPureCalcProto::TSimpleMessage* Fetch() override { + if (Exhausted_) { + return nullptr; + } else { + Exhausted_ = true; + return &Message_; + } + } + +private: + NPureCalcProto::TSimpleMessage Message_; + bool Exhausted_ = false; +}; + +class TAllTypesConsumerImpl: public NYql::NPureCalc::IConsumer<NPureCalcProto::TAllTypes*> { +private: + int I_ = 0; + +public: + void OnObject(NPureCalcProto::TAllTypes* t) override { + I_ += 1; + AssertEqualToCanonical(*t); + } + + void OnFinish() override { + UNIT_ASSERT(I_ > 0); + } +}; + +class TStringMessageStreamImpl: public NYql::NPureCalc::IStream<NPureCalcProto::TStringMessage*> { +private: + int I_ = 0; + NPureCalcProto::TStringMessage Message_{}; + +public: + NPureCalcProto::TStringMessage* Fetch() override { + if (I_ >= 3) { + return nullptr; + } else { + Message_.SetX(TString("-") * I_); + I_ += 1; + return &Message_; + } + } +}; + +class TSimpleMessageConsumerImpl: public NYql::NPureCalc::IConsumer<NPureCalcProto::TSimpleMessage*> { +private: + TVector<int>* Buf_; + +public: + TSimpleMessageConsumerImpl(TVector<int>* buf) + : Buf_(buf) + { + } + +public: + void OnObject(NPureCalcProto::TSimpleMessage* t) override { + Buf_->push_back(t->GetX()); + } + + void OnFinish() override { + Buf_->push_back(-100); + } +}; + +using TMessagesVariant = std::variant<NPureCalcProto::TSplitted1*, NPureCalcProto::TSplitted2*, NPureCalcProto::TStringMessage*>; + +class TVariantConsumerImpl: public NYql::NPureCalc::IConsumer<TMessagesVariant> { +public: + using TType0 = TVector<std::pair<i32, TString>>; + using TType1 = TVector<std::pair<ui32, TString>>; + using TType2 = TVector<TString>; + +public: + TVariantConsumerImpl(TType0* q0, TType1* q1, TType2* q2, int* v) + : Queue0_(q0) + , Queue1_(q1) + , Queue2_(q2) + , Value_(v) + { + } + + void OnObject(TMessagesVariant value) override { + if (auto* p = std::get_if<0>(&value)) { + Queue0_->push_back({(*p)->GetBInt(), std::move(*(*p)->MutableBString())}); + } else if (auto* p = std::get_if<1>(&value)) { + Queue1_->push_back({(*p)->GetCUint(), std::move(*(*p)->MutableCString())}); + } else if (auto* p = std::get_if<2>(&value)) { + Queue2_->push_back(std::move(*(*p)->MutableX())); + } else { + Y_ABORT("invalid variant alternative"); + } + } + + void OnFinish() override { + *Value_ = 42; + } + +private: + TType0* Queue0_; + TType1* Queue1_; + TType2* Queue2_; + int* Value_; +}; + +class TUnsplittedStreamImpl: public NYql::NPureCalc::IStream<NPureCalcProto::TUnsplitted*> { +public: + TUnsplittedStreamImpl() + { + Message_.SetAInt(-23); + Message_.SetAUint(111); + Message_.SetAString("Hello!"); + } + +public: + NPureCalcProto::TUnsplitted* Fetch() override { + switch (I_) { + case 0: + ++I_; + return &Message_; + case 1: + ++I_; + Message_.SetABool(false); + return &Message_; + case 2: + ++I_; + Message_.SetABool(true); + return &Message_; + default: + return nullptr; + } + } + +private: + NPureCalcProto::TUnsplitted Message_; + ui32 I_ = 0; +}; + +template<typename T> +struct TVectorConsumer: public NYql::NPureCalc::IConsumer<T*> { + TVector<T> Data; + + void OnObject(T* t) override { + Data.push_back(*t); + } + + void OnFinish() override { + } +}; + +template <typename T> +struct TVectorStream: public NYql::NPureCalc::IStream<T*> { + TVector<T> Data; + size_t Index = 0; + +public: + T* Fetch() override { + return Index < Data.size() ? &Data[Index++] : nullptr; + } +}; + +Y_UNIT_TEST_SUITE(TestProtoIO) { + Y_UNIT_TEST(TestAllTypes) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + { + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TAllTypes>(), + TProtobufOutputSpec<NPureCalcProto::TAllTypes>(), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + auto stream = program->Apply(MakeHolder<TAllTypesStreamImpl>()); + + NPureCalcProto::TAllTypes* message; + + UNIT_ASSERT(message = stream->Fetch()); + AssertEqualToCanonical(*message); + UNIT_ASSERT(!stream->Fetch()); + } + + { + auto program = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TAllTypes>(), + TProtobufOutputSpec<NPureCalcProto::TAllTypes>(), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + auto stream = program->Apply(MakeHolder<TAllTypesStreamImpl>()); + + NPureCalcProto::TAllTypes* message; + + UNIT_ASSERT(message = stream->Fetch()); + AssertEqualToCanonical(*message); + UNIT_ASSERT(!stream->Fetch()); + } + + { + auto program = factory->MakePushStreamProgram( + TProtobufInputSpec<NPureCalcProto::TAllTypes>(), + TProtobufOutputSpec<NPureCalcProto::TAllTypes>(), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + auto consumer = program->Apply(MakeHolder<TAllTypesConsumerImpl>()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnObject(&GetCanonicalMessage()); }()); + UNIT_ASSERT_NO_EXCEPTION([&](){ consumer->OnFinish(); }()); + } + } + + template <typename T> + void CheckPassThroughYql(T& testInput, google::protobuf::Arena* arena = nullptr) { + using namespace NYql::NPureCalc; + + auto resetArena = [arena]() { + if (arena != nullptr) { + arena->Reset(); + } + }; + + auto factory = MakeProgramFactory(); + + { + auto program = factory->MakePushStreamProgram( + TProtobufInputSpec<T>(), + TProtobufOutputSpec<T>({}, arena), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + auto resultConsumer = MakeHolder<TVectorConsumer<T>>(); + auto* resultConsumerPtr = resultConsumer.Get(); + auto sourceConsumer = program->Apply(std::move(resultConsumer)); + + sourceConsumer->OnObject(&testInput); + UNIT_ASSERT_VALUES_EQUAL(1, resultConsumerPtr->Data.size()); + AssertProtoEqual(resultConsumerPtr->Data[0], testInput); + + resultConsumerPtr->Data.clear(); + sourceConsumer->OnObject(&testInput); + UNIT_ASSERT_VALUES_EQUAL(1, resultConsumerPtr->Data.size()); + AssertProtoEqual(resultConsumerPtr->Data[0], testInput); + } + resetArena(); + + { + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<T>(), + TProtobufOutputSpec<T>({}, arena), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + auto sourceStream = MakeHolder<TVectorStream<T>>(); + auto* sourceStreamPtr = sourceStream.Get(); + auto resultStream = program->Apply(std::move(sourceStream)); + + sourceStreamPtr->Data.push_back(testInput); + T* resultMessage; + UNIT_ASSERT(resultMessage = resultStream->Fetch()); + AssertProtoEqual(*resultMessage, testInput); + UNIT_ASSERT(!resultStream->Fetch()); + + UNIT_ASSERT_VALUES_EQUAL(resultMessage->GetArena(), arena); + } + resetArena(); + + { + auto program = factory->MakePullListProgram( + TProtobufInputSpec<T>(), + TProtobufOutputSpec<T>({}, arena), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + auto sourceStream = MakeHolder<TVectorStream<T>>(); + auto* sourceStreamPtr = sourceStream.Get(); + auto resultStream = program->Apply(std::move(sourceStream)); + + sourceStreamPtr->Data.push_back(testInput); + T* resultMessage; + UNIT_ASSERT(resultMessage = resultStream->Fetch()); + AssertProtoEqual(*resultMessage, testInput); + UNIT_ASSERT(!resultStream->Fetch()); + + UNIT_ASSERT_VALUES_EQUAL(resultMessage->GetArena(), arena); + } + resetArena(); + } + + template <typename T> + void CheckMessageIsInvalid(const TString& expectedExceptionMessage) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&]() { + factory->MakePushStreamProgram(TProtobufInputSpec<T>(), TProtobufOutputSpec<T>(), "SELECT * FROM Input", ETranslationMode::SQL); + }(), yexception, expectedExceptionMessage); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&]() { + factory->MakePullStreamProgram(TProtobufInputSpec<T>(), TProtobufOutputSpec<T>(), "SELECT * FROM Input", ETranslationMode::SQL); + }(), yexception, expectedExceptionMessage); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&]() { + factory->MakePullListProgram(TProtobufInputSpec<T>(), TProtobufOutputSpec<T>(), "SELECT * FROM Input", ETranslationMode::SQL); + }(), yexception, expectedExceptionMessage); + } + + Y_UNIT_TEST(TestSimpleNested) { + NPureCalcProto::TSimpleNested input; + input.SetX(10); + { + auto* item = input.MutableY(); + *item = GetCanonicalMessage(); + item->SetFUint64(100); + } + CheckPassThroughYql(input); + } + + Y_UNIT_TEST(TestOptionalNested) { + NPureCalcProto::TOptionalNested input; + { + auto* item = input.MutableX(); + *item = GetCanonicalMessage(); + item->SetFUint64(100); + } + CheckPassThroughYql(input); + } + + Y_UNIT_TEST(TestSimpleRepeated) { + NPureCalcProto::TSimpleRepeated input; + input.SetX(20); + input.AddY(100); + input.AddY(200); + input.AddY(300); + CheckPassThroughYql(input); + } + + Y_UNIT_TEST(TestNestedRepeated) { + NPureCalcProto::TNestedRepeated input; + input.SetX(20); + { + auto* item = input.MutableY()->Add(); + item->SetX(100); + { + auto* y = item->MutableY(); + *y = GetCanonicalMessage(); + y->SetFUint64(1000); + } + } + { + auto* item = input.MutableY()->Add(); + item->SetX(200); + { + auto* y = item->MutableY(); + *y = GetCanonicalMessage(); + y->SetFUint64(2000); + } + } + CheckPassThroughYql(input); + } + + Y_UNIT_TEST(TestMessageWithEnum) { + NPureCalcProto::TMessageWithEnum input; + input.AddEnumValue(NPureCalcProto::TMessageWithEnum::VALUE1); + input.AddEnumValue(NPureCalcProto::TMessageWithEnum::VALUE2); + CheckPassThroughYql(input); + } + + Y_UNIT_TEST(TestRecursive) { + CheckMessageIsInvalid<NPureCalcProto::TRecursive>("NPureCalcProto.TRecursive->NPureCalcProto.TRecursive"); + } + + Y_UNIT_TEST(TestRecursiveIndirectly) { + CheckMessageIsInvalid<NPureCalcProto::TRecursiveIndirectly>( + "NPureCalcProto.TRecursiveIndirectly->NPureCalcProto.TRecursiveIndirectly.TNested->NPureCalcProto.TRecursiveIndirectly"); + } + + Y_UNIT_TEST(TestColumnsFilter) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + auto filter = THashSet<TString>({"FFixed64", "FBool", "FBytes"}); + + NPureCalcProto::TOptionalAllTypes canonicalMessage; + canonicalMessage.SetFFixed64(GetCanonicalMessage().GetFFixed64()); + canonicalMessage.SetFBool(GetCanonicalMessage().GetFBool()); + canonicalMessage.SetFBytes(GetCanonicalMessage().GetFBytes()); + + { + auto inputSpec = TProtobufInputSpec<NPureCalcProto::TAllTypes>(); + auto outputSpec = TProtobufOutputSpec<NPureCalcProto::TOptionalAllTypes>(); + outputSpec.SetOutputColumnsFilter(filter); + + auto program = factory->MakePullStreamProgram( + inputSpec, + outputSpec, + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + UNIT_ASSERT_EQUAL(program->GetUsedColumns(), filter); + + auto stream = program->Apply(MakeHolder<TAllTypesStreamImpl>()); + + NPureCalcProto::TOptionalAllTypes* message; + + UNIT_ASSERT(message = stream->Fetch()); + AssertEqualToCanonical(*message, canonicalMessage); + UNIT_ASSERT(!stream->Fetch()); + } + } + + Y_UNIT_TEST(TestColumnsFilterWithOptionalFields) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + auto fields = THashSet<TString>({"FFixed64", "FBool", "FBytes"}); + + NPureCalcProto::TOptionalAllTypes canonicalMessage; + canonicalMessage.SetFFixed64(GetCanonicalMessage().GetFFixed64()); + canonicalMessage.SetFBool(GetCanonicalMessage().GetFBool()); + canonicalMessage.SetFBytes(GetCanonicalMessage().GetFBytes()); + + { + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TAllTypes>(), + TProtobufOutputSpec<NPureCalcProto::TOptionalAllTypes>(), + "SELECT FFixed64, FBool, FBytes FROM Input", + ETranslationMode::SQL + ); + + UNIT_ASSERT_EQUAL(program->GetUsedColumns(), fields); + + auto stream = program->Apply(MakeHolder<TAllTypesStreamImpl>()); + + NPureCalcProto::TOptionalAllTypes* message; + + UNIT_ASSERT(message = stream->Fetch()); + AssertEqualToCanonical(*message, canonicalMessage); + UNIT_ASSERT(!stream->Fetch()); + } + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TAllTypes>(), + TProtobufOutputSpec<NPureCalcProto::TAllTypes>(), + "SELECT FFixed64, FBool, FBytes FROM Input", + ETranslationMode::SQL + ); + }(), TCompileError, "Failed to optimize"); + } + + Y_UNIT_TEST(TestUsedColumns) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + auto allFields = THashSet<TString>(); + + for (auto i: xrange(NPureCalcProto::TOptionalAllTypes::descriptor()->field_count())) { + allFields.emplace(NPureCalcProto::TOptionalAllTypes::descriptor()->field(i)->name()); + } + + { + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TAllTypes>(), + TProtobufOutputSpec<NPureCalcProto::TOptionalAllTypes>(), + "SELECT * FROM Input", + ETranslationMode::SQL + ); + + UNIT_ASSERT_EQUAL(program->GetUsedColumns(), allFields); + } + } + + Y_UNIT_TEST(TestChaining) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + TString sql1 = "SELECT UNWRAP(X || CAST(\"HI\" AS Utf8)) AS X FROM Input"; + TString sql2 = "SELECT LENGTH(X) AS X FROM Input"; + + { + auto program1 = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + sql1, + ETranslationMode::SQL + ); + + auto program2 = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TSimpleMessage>(), + sql2, + ETranslationMode::SQL + ); + + auto input = MakeHolder<TStringMessageStreamImpl>(); + auto intermediate = program1->Apply(std::move(input)); + auto output = program2->Apply(std::move(intermediate)); + + TVector<int> expected = {2, 3, 4}; + TVector<int> actual{}; + + while (auto *x = output->Fetch()) { + actual.push_back(x->GetX()); + } + + UNIT_ASSERT_EQUAL(expected, actual); + } + + { + auto program1 = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + sql1, + ETranslationMode::SQL + ); + + auto program2 = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TSimpleMessage>(), + sql2, + ETranslationMode::SQL + ); + + auto input = MakeHolder<TStringMessageStreamImpl>(); + auto intermediate = program1->Apply(std::move(input)); + auto output = program2->Apply(std::move(intermediate)); + + TVector<int> expected = {2, 3, 4}; + TVector<int> actual{}; + + while (auto *x = output->Fetch()) { + actual.push_back(x->GetX()); + } + + UNIT_ASSERT_EQUAL(expected, actual); + } + + { + auto program1 = factory->MakePushStreamProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + sql1, + ETranslationMode::SQL + ); + + auto program2 = factory->MakePushStreamProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TSimpleMessage>(), + sql2, + ETranslationMode::SQL + ); + + TVector<int> expected = {2, 3, 4, -100}; + TVector<int> actual{}; + + auto consumer = MakeHolder<TSimpleMessageConsumerImpl>(&actual); + auto intermediate = program2->Apply(std::move(consumer)); + auto input = program1->Apply(std::move(intermediate)); + + NPureCalcProto::TStringMessage Message; + + Message.SetX(""); + input->OnObject(&Message); + + Message.SetX("1"); + input->OnObject(&Message); + + Message.SetX("22"); + input->OnObject(&Message); + + input->OnFinish(); + + UNIT_ASSERT_EQUAL(expected, actual); + } + } + + Y_UNIT_TEST(TestTimestampColumn) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(TProgramFactoryOptions() + .SetDeterministicTimeProviderSeed(1)); // seconds + + NPureCalcProto::TOptionalAllTypes canonicalMessage; + + { + auto inputSpec = TProtobufInputSpec<NPureCalcProto::TAllTypes>("MyTimestamp"); + auto outputSpec = TProtobufOutputSpec<NPureCalcProto::TOptionalAllTypes>(); + + auto program = factory->MakePullStreamProgram( + inputSpec, + outputSpec, + "SELECT MyTimestamp AS FFixed64 FROM Input", + ETranslationMode::SQL + ); + + auto stream = program->Apply(MakeHolder<TAllTypesStreamImpl>()); + + NPureCalcProto::TOptionalAllTypes* message; + + UNIT_ASSERT(message = stream->Fetch()); + UNIT_ASSERT_VALUES_EQUAL(message->GetFFixed64(), 1000000); // microseconds + UNIT_ASSERT(!stream->Fetch()); + } + } + + Y_UNIT_TEST(TestTableNames) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(TProgramFactoryOptions().SetUseSystemColumns(true)); + + auto runTest = [&](TStringBuf tableName, i32 value) { + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TSimpleMessage>(), + TProtobufOutputSpec<NPureCalcProto::TNamedSimpleMessage>(), + TString::Join("SELECT TableName() AS Name, X FROM ", tableName), + ETranslationMode::SQL + ); + + auto stream = program->Apply(MakeHolder<TSimpleMessageStreamImpl>(value)); + auto message = stream->Fetch(); + + UNIT_ASSERT(message); + UNIT_ASSERT_VALUES_EQUAL(message->GetX(), value); + UNIT_ASSERT_VALUES_EQUAL(message->GetName(), tableName); + UNIT_ASSERT(!stream->Fetch()); + }; + + runTest("Input", 37); + runTest("Input0", -23); + } + + void CheckMultiOutputs(TMaybe<TVector<google::protobuf::Arena*>> arenas) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + TString sExpr = R"( +( + (let $type (ParseType '"Variant<Struct<BInt:Int32,BString:Utf8>, Struct<CUint:Uint32,CString:Utf8>, Struct<X:Utf8>>")) + (let $stream (Self '0)) + (return (FlatMap (Self '0) (lambda '(x) (block '( + (let $cond (Member x 'ABool)) + (let $item0 (Variant (AsStruct '('BInt (Member x 'AInt)) '('BString (Member x 'AString))) '0 $type)) + (let $item1 (Variant (AsStruct '('CUint (Member x 'AUint)) '('CString (Member x 'AString))) '1 $type)) + (let $item2 (Variant (AsStruct '('X (Utf8 'Error))) '2 $type)) + (return (If (Exists $cond) (If (Unwrap $cond) (AsList $item0) (AsList $item1)) (AsList $item2))) + ))))) +) + )"; + + { + auto program = factory->MakePushStreamProgram( + TProtobufInputSpec<NPureCalcProto::TUnsplitted>(), + TProtobufMultiOutputSpec<NPureCalcProto::TSplitted1, NPureCalcProto::TSplitted2, NPureCalcProto::TStringMessage>( + {}, arenas + ), + sExpr, + ETranslationMode::SExpr + ); + + TVariantConsumerImpl::TType0 queue0; + TVariantConsumerImpl::TType1 queue1; + TVariantConsumerImpl::TType2 queue2; + int finalValue = 0; + + auto consumer = MakeHolder<TVariantConsumerImpl>(&queue0, &queue1, &queue2, &finalValue); + auto input = program->Apply(std::move(consumer)); + + NPureCalcProto::TUnsplitted message; + message.SetAInt(-13); + message.SetAUint(47); + message.SetAString("first message"); + message.SetABool(true); + + input->OnObject(&message); + UNIT_ASSERT(queue0.size() == 1 && queue1.empty() && queue2.empty() && finalValue == 0); + + message.SetABool(false); + message.SetAString("second message"); + + input->OnObject(&message); + UNIT_ASSERT(queue0.size() == 1 && queue1.size() == 1 && queue2.empty() && finalValue == 0); + + message.ClearABool(); + + input->OnObject(&message); + UNIT_ASSERT(queue0.size() == 1 && queue1.size() == 1 && queue2.size() == 1 && finalValue == 0); + + input->OnFinish(); + UNIT_ASSERT(queue0.size() == 1 && queue1.size() == 1 && queue2.size() == 1 && finalValue == 42); + + TVariantConsumerImpl::TType0 expected0 = {{-13, "first message"}}; + UNIT_ASSERT_EQUAL(queue0, expected0); + + TVariantConsumerImpl::TType1 expected1 = {{47, "second message"}}; + UNIT_ASSERT_EQUAL(queue1, expected1); + + TVariantConsumerImpl::TType2 expected2 = {{"Error"}}; + UNIT_ASSERT_EQUAL(queue2, expected2); + } + + { + auto program1 = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TUnsplitted>(), + TProtobufMultiOutputSpec<NPureCalcProto::TSplitted1, NPureCalcProto::TSplitted2, NPureCalcProto::TStringMessage>( + {}, arenas + ), + sExpr, + ETranslationMode::SExpr + ); + + auto program2 = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TUnsplitted>(), + TProtobufMultiOutputSpec<NPureCalcProto::TSplitted1, NPureCalcProto::TSplitted2, NPureCalcProto::TStringMessage>( + {}, arenas + ), + sExpr, + ETranslationMode::SExpr + ); + + auto input1 = MakeHolder<TUnsplittedStreamImpl>(); + auto output1 = program1->Apply(std::move(input1)); + + auto input2 = MakeHolder<TUnsplittedStreamImpl>(); + auto output2 = program2->Apply(std::move(input2)); + + decltype(output1->Fetch()) variant1; + decltype(output2->Fetch()) variant2; + +#define ASSERT_EQUAL_FIELDS(X1, X2, I, F, E) \ + UNIT_ASSERT_EQUAL(X1.index(), I); \ + UNIT_ASSERT_EQUAL(X2.index(), I); \ + UNIT_ASSERT_EQUAL(std::get<I>(X1)->Get##F(), E); \ + UNIT_ASSERT_EQUAL(std::get<I>(X2)->Get##F(), E) + + variant1 = output1->Fetch(); + variant2 = output2->Fetch(); + ASSERT_EQUAL_FIELDS(variant1, variant2, 2, X, "Error"); + ASSERT_EQUAL_FIELDS(variant1, variant2, 2, Arena, (arenas.Defined() ? arenas->at(2) : nullptr)); + + variant1 = output1->Fetch(); + variant2 = output2->Fetch(); + ASSERT_EQUAL_FIELDS(variant1, variant2, 1, CUint, 111); + ASSERT_EQUAL_FIELDS(variant1, variant2, 1, CString, "Hello!"); + ASSERT_EQUAL_FIELDS(variant1, variant2, 1, Arena, (arenas.Defined() ? arenas->at(1) : nullptr)); + + variant1 = output1->Fetch(); + variant2 = output2->Fetch(); + ASSERT_EQUAL_FIELDS(variant1, variant2, 0, BInt, -23); + ASSERT_EQUAL_FIELDS(variant1, variant2, 0, BString, "Hello!"); + ASSERT_EQUAL_FIELDS(variant1, variant2, 0, Arena, (arenas.Defined() ? arenas->at(0) : nullptr)); + + variant1 = output1->Fetch(); + variant2 = output2->Fetch(); + UNIT_ASSERT_EQUAL(variant1.index(), 0); + UNIT_ASSERT_EQUAL(variant2.index(), 0); + UNIT_ASSERT_EQUAL(std::get<0>(variant1), nullptr); + UNIT_ASSERT_EQUAL(std::get<0>(variant1), nullptr); + +#undef ASSERT_EQUAL_FIELDS + } + } + + Y_UNIT_TEST(TestMultiOutputs) { + CheckMultiOutputs(Nothing()); + } + + Y_UNIT_TEST(TestSupportedTypes) { + + } + + Y_UNIT_TEST(TestProtobufArena) { + { + NPureCalcProto::TNestedRepeated input; + input.SetX(20); + { + auto* item = input.MutableY()->Add(); + item->SetX(100); + { + auto* y = item->MutableY(); + *y = GetCanonicalMessage(); + y->SetFUint64(1000); + } + } + { + auto* item = input.MutableY()->Add(); + item->SetX(200); + { + auto* y = item->MutableY(); + *y = GetCanonicalMessage(); + y->SetFUint64(2000); + } + } + + google::protobuf::Arena arena; + CheckPassThroughYql(input, &arena); + } + + { + google::protobuf::Arena arena1; + google::protobuf::Arena arena2; + TVector<google::protobuf::Arena*> arenas{&arena1, &arena2, &arena1}; + CheckMultiOutputs(arenas); + } + } + + Y_UNIT_TEST(TestFieldRenames) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + TString query = "SELECT InputAlias AS OutputAlias FROM Input"; + + auto inputProtoOptions = TProtoSchemaOptions(); + inputProtoOptions.SetFieldRenames({{"X", "InputAlias"}}); + + auto inputSpec = TProtobufInputSpec<NPureCalcProto::TSimpleMessage>( + Nothing(), std::move(inputProtoOptions) + ); + + auto outputProtoOptions = TProtoSchemaOptions(); + outputProtoOptions.SetFieldRenames({{"X", "OutputAlias"}}); + + auto outputSpec = TProtobufOutputSpec<NPureCalcProto::TSimpleMessage>( + std::move(outputProtoOptions) + ); + + { + auto program = factory->MakePullStreamProgram( + inputSpec, outputSpec, query, ETranslationMode::SQL + ); + + auto input = MakeHolder<TSimpleMessageStreamImpl>(1); + auto output = program->Apply(std::move(input)); + + TVector<int> expected = {1}; + TVector<int> actual; + + while (auto* x = output->Fetch()) { + actual.push_back(x->GetX()); + } + + UNIT_ASSERT_VALUES_EQUAL(expected, actual); + } + + { + auto program = factory->MakePullListProgram( + inputSpec, outputSpec, query, ETranslationMode::SQL + ); + + auto input = MakeHolder<TSimpleMessageStreamImpl>(1); + auto output = program->Apply(std::move(input)); + + TVector<int> expected = {1}; + TVector<int> actual; + + while (auto* x = output->Fetch()) { + actual.push_back(x->GetX()); + } + + UNIT_ASSERT_VALUES_EQUAL(expected, actual); + } + + { + auto program = factory->MakePushStreamProgram( + inputSpec, outputSpec, query, ETranslationMode::SQL + ); + + TVector<int> expected = {1, -100}; + TVector<int> actual; + + auto consumer = MakeHolder<TSimpleMessageConsumerImpl>(&actual); + auto input = program->Apply(std::move(consumer)); + + NPureCalcProto::TSimpleMessage Message; + + Message.SetX(1); + input->OnObject(&Message); + + input->OnFinish(); + + UNIT_ASSERT_VALUES_EQUAL(expected, actual); + } + } +} diff --git a/yql/essentials/public/purecalc/io_specs/protobuf/ut/ya.make b/yql/essentials/public/purecalc/io_specs/protobuf/ut/ya.make new file mode 100644 index 00000000000..2519816d02e --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf/ut/ya.make @@ -0,0 +1,23 @@ +IF (NOT SANITIZER_TYPE) + +UNITTEST() + +PEERDIR( + library/cpp/protobuf/util + yql/essentials/public/udf/service/exception_policy + yql/essentials/public/purecalc + yql/essentials/public/purecalc/io_specs/protobuf + yql/essentials/public/purecalc/ut/protos +) + +SIZE(MEDIUM) + +YQL_LAST_ABI_VERSION() + +SRCS( + test_spec.cpp +) + +END() + +ENDIF() diff --git a/yql/essentials/public/purecalc/io_specs/protobuf/ya.make b/yql/essentials/public/purecalc/io_specs/protobuf/ya.make new file mode 100644 index 00000000000..b9441ceecf4 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf/ya.make @@ -0,0 +1,19 @@ +LIBRARY() + +PEERDIR( + yql/essentials/public/purecalc/common + yql/essentials/public/purecalc/io_specs/protobuf_raw +) + +SRCS( + spec.cpp + proto_variant.cpp +) + +YQL_LAST_ABI_VERSION() + +END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/public/purecalc/io_specs/protobuf_raw/proto_holder.cpp b/yql/essentials/public/purecalc/io_specs/protobuf_raw/proto_holder.cpp new file mode 100644 index 00000000000..95adbc4de95 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf_raw/proto_holder.cpp @@ -0,0 +1 @@ +#include "proto_holder.h" diff --git a/yql/essentials/public/purecalc/io_specs/protobuf_raw/proto_holder.h b/yql/essentials/public/purecalc/io_specs/protobuf_raw/proto_holder.h new file mode 100644 index 00000000000..7d4d843bfcf --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf_raw/proto_holder.h @@ -0,0 +1,31 @@ +#pragma once + +#include <google/protobuf/arena.h> + +#include <util/generic/ptr.h> + +#include <type_traits> + +namespace NYql::NPureCalc { + class TProtoDestroyer { + public: + template <typename T> + static inline void Destroy(T* t) noexcept { + if (t->GetArena() == nullptr) { + CheckedDelete(t); + } + } + }; + + template <typename TProto> + concept IsProtoMessage = std::is_base_of_v<NProtoBuf::Message, TProto>; + + template <IsProtoMessage TProto> + using TProtoHolder = THolder<TProto, TProtoDestroyer>; + + template <IsProtoMessage TProto, typename... TArgs> + TProtoHolder<TProto> MakeProtoHolder(NProtoBuf::Arena* arena, TArgs&&... args) { + auto* ptr = NProtoBuf::Arena::CreateMessage<TProto>(arena, std::forward<TArgs>(args)...); + return TProtoHolder<TProto>(ptr); + } +} diff --git a/yql/essentials/public/purecalc/io_specs/protobuf_raw/spec.cpp b/yql/essentials/public/purecalc/io_specs/protobuf_raw/spec.cpp new file mode 100644 index 00000000000..0a3cc41427f --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf_raw/spec.cpp @@ -0,0 +1,1064 @@ +#include "proto_holder.h" +#include "spec.h" + +#include <yql/essentials/public/udf/udf_value.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> +#include <yql/essentials/minikql/computation/mkql_custom_list.h> +#include <yql/essentials/minikql/mkql_string_util.h> +#include <yql/essentials/utils/yql_panic.h> +#include <google/protobuf/reflection.h> + +using namespace NYql; +using namespace NYql::NPureCalc; +using namespace google::protobuf; +using namespace NKikimr::NUdf; +using namespace NKikimr::NMiniKQL; + +TProtobufRawInputSpec::TProtobufRawInputSpec( + const Descriptor& descriptor, + const TMaybe<TString>& timestampColumn, + const TProtoSchemaOptions& options +) + : Descriptor_(descriptor) + , TimestampColumn_(timestampColumn) + , SchemaOptions_(options) +{ +} + +const TVector<NYT::TNode>& TProtobufRawInputSpec::GetSchemas() const { + if (SavedSchemas_.size() == 0) { + SavedSchemas_.push_back(MakeSchemaFromProto(Descriptor_, SchemaOptions_)); + if (TimestampColumn_) { + auto timestampType = NYT::TNode::CreateList(); + timestampType.Add("DataType"); + timestampType.Add("Uint64"); + auto timestamp = NYT::TNode::CreateList(); + timestamp.Add(*TimestampColumn_); + timestamp.Add(timestampType); + SavedSchemas_.back().AsList()[1].AsList().push_back(timestamp); + } + } + + return SavedSchemas_; +} + +const Descriptor& TProtobufRawInputSpec::GetDescriptor() const { + return Descriptor_; +} + +const TMaybe<TString>& TProtobufRawInputSpec::GetTimestampColumn() const { + return TimestampColumn_; +} + +const TProtoSchemaOptions& TProtobufRawInputSpec::GetSchemaOptions() const { + return SchemaOptions_; +} + +TProtobufRawOutputSpec::TProtobufRawOutputSpec( + const Descriptor& descriptor, + MessageFactory* factory, + const TProtoSchemaOptions& options, + Arena* arena +) + : Descriptor_(descriptor) + , Factory_(factory) + , SchemaOptions_(options) + , Arena_(arena) +{ + SchemaOptions_.ListIsOptional = true; +} + +const NYT::TNode& TProtobufRawOutputSpec::GetSchema() const { + if (!SavedSchema_) { + SavedSchema_ = MakeSchemaFromProto(Descriptor_, SchemaOptions_); + } + + return SavedSchema_.GetRef(); +} + +const Descriptor& TProtobufRawOutputSpec::GetDescriptor() const { + return Descriptor_; +} + +void TProtobufRawOutputSpec::SetFactory(MessageFactory* factory) { + Factory_ = factory; +} + +MessageFactory* TProtobufRawOutputSpec::GetFactory() const { + return Factory_; +} + +void TProtobufRawOutputSpec::SetArena(Arena* arena) { + Arena_ = arena; +} + +Arena* TProtobufRawOutputSpec::GetArena() const { + return Arena_; +} + +const TProtoSchemaOptions& TProtobufRawOutputSpec::GetSchemaOptions() const { + return SchemaOptions_; +} + +TProtobufRawMultiOutputSpec::TProtobufRawMultiOutputSpec( + TVector<const Descriptor*> descriptors, + TMaybe<TVector<MessageFactory*>> factories, + const TProtoSchemaOptions& options, + TMaybe<TVector<Arena*>> arenas +) + : Descriptors_(std::move(descriptors)) + , SchemaOptions_(options) +{ + if (factories) { + Y_ENSURE(factories->size() == Descriptors_.size(), "number of factories must match number of descriptors"); + Factories_ = std::move(*factories); + } else { + Factories_ = TVector<MessageFactory*>(Descriptors_.size(), nullptr); + } + + if (arenas) { + Y_ENSURE(arenas->size() == Descriptors_.size(), "number of arenas must match number of descriptors"); + Arenas_ = std::move(*arenas); + } else { + Arenas_ = TVector<Arena*>(Descriptors_.size(), nullptr); + } +} + +const NYT::TNode& TProtobufRawMultiOutputSpec::GetSchema() const { + if (SavedSchema_.IsUndefined()) { + SavedSchema_ = MakeVariantSchemaFromProtos(Descriptors_, SchemaOptions_); + } + + return SavedSchema_; +} + +const Descriptor& TProtobufRawMultiOutputSpec::GetDescriptor(ui32 index) const { + Y_ENSURE(index < Descriptors_.size(), "invalid output index"); + + return *Descriptors_[index]; +} + +void TProtobufRawMultiOutputSpec::SetFactory(ui32 index, MessageFactory* factory) { + Y_ENSURE(index < Factories_.size(), "invalid output index"); + + Factories_[index] = factory; +} + +MessageFactory* TProtobufRawMultiOutputSpec::GetFactory(ui32 index) const { + Y_ENSURE(index < Factories_.size(), "invalid output index"); + + return Factories_[index]; +} + +void TProtobufRawMultiOutputSpec::SetArena(ui32 index, Arena* arena) { + Y_ENSURE(index < Arenas_.size(), "invalid output index"); + + Arenas_[index] = arena; +} + +Arena* TProtobufRawMultiOutputSpec::GetArena(ui32 index) const { + Y_ENSURE(index < Arenas_.size(), "invalid output index"); + + return Arenas_[index]; +} + +ui32 TProtobufRawMultiOutputSpec::GetOutputsNumber() const { + return static_cast<ui32>(Descriptors_.size()); +} + +const TProtoSchemaOptions& TProtobufRawMultiOutputSpec::GetSchemaOptions() const { + return SchemaOptions_; +} + +namespace { + struct TFieldMapping { + TString Name; + const FieldDescriptor* Field; + TVector<TFieldMapping> NestedFields; + }; + + /** + * Fills a tree of field mappings from the given yql struct type to protobuf message. + * + * @param fromType source yql type. + * @param toType target protobuf message type. + * @param mappings destination vector will be filled with field descriptors. Order of descriptors will match + * the order of field names. + */ + void FillFieldMappings( + const TStructType* fromType, + const Descriptor& toType, + TVector<TFieldMapping>& mappings, + const TMaybe<TString>& timestampColumn, + bool listIsOptional, + const THashMap<TString, TString>& fieldRenames + ) { + THashMap<TString, TString> inverseFieldRenames; + + for (const auto& [source, target]: fieldRenames) { + auto [iterator, emplaced] = inverseFieldRenames.emplace(target, source); + Y_ENSURE(emplaced, "Duplicate rename field found: " << source << " -> " << target); + } + + mappings.resize(fromType->GetMembersCount()); + for (ui32 i = 0; i < fromType->GetMembersCount(); ++i) { + TString fieldName(fromType->GetMemberName(i)); + if (auto fieldRenamePtr = inverseFieldRenames.FindPtr(fieldName)) { + fieldName = *fieldRenamePtr; + } + + mappings[i].Name = fieldName; + mappings[i].Field = toType.FindFieldByName(fieldName); + YQL_ENSURE( + mappings[i].Field || timestampColumn && *timestampColumn == fieldName, + "Missing field: " << fieldName); + + const auto* fieldType = fromType->GetMemberType(i); + if (fieldType->GetKind() == NKikimr::NMiniKQL::TType::EKind::List) { + const auto* listType = static_cast<const NKikimr::NMiniKQL::TListType*>(fieldType); + fieldType = listType->GetItemType(); + } else if (fieldType->GetKind() == NKikimr::NMiniKQL::TType::EKind::Optional) { + const auto* optionalType = static_cast<const NKikimr::NMiniKQL::TOptionalType*>(fieldType); + fieldType = optionalType->GetItemType(); + + if (listIsOptional) { + if (fieldType->GetKind() == NKikimr::NMiniKQL::TType::EKind::List) { + const auto* listType = static_cast<const NKikimr::NMiniKQL::TListType*>(fieldType); + fieldType = listType->GetItemType(); + } + } + } + YQL_ENSURE(fieldType->GetKind() == NKikimr::NMiniKQL::TType::EKind::Struct || + fieldType->GetKind() == NKikimr::NMiniKQL::TType::EKind::Data, + "unsupported field kind [" << fieldType->GetKindAsStr() << "], field [" << fieldName << "]"); + if (fieldType->GetKind() == NKikimr::NMiniKQL::TType::EKind::Struct) { + FillFieldMappings(static_cast<const NKikimr::NMiniKQL::TStructType*>(fieldType), + *mappings[i].Field->message_type(), + mappings[i].NestedFields, Nothing(), listIsOptional, {}); + } + } + } + + /** + * Extract field values from the given protobuf message into an array of unboxed values. + * + * @param factory to create nested unboxed values. + * @param source source protobuf message. + * @param destination destination array of unboxed values. Each element in the array corresponds to a field + * in the protobuf message. + * @param mappings vector of protobuf field descriptors which denotes relation between fields of the + * source message and elements of the destination array. + * @param scratch temporary string which will be used during conversion. + */ + void FillInputValue( + const THolderFactory& factory, + const Message* source, + TUnboxedValue* destination, + const TVector<TFieldMapping>& mappings, + const TMaybe<TString>& timestampColumn, + ITimeProvider* timeProvider, + EEnumPolicy enumPolicy + ) { + TString scratch; + auto reflection = source->GetReflection(); + for (ui32 i = 0; i < mappings.size(); ++i) { + auto mapping = mappings[i]; + if (!mapping.Field) { + YQL_ENSURE(timestampColumn && mapping.Name == *timestampColumn); + destination[i] = TUnboxedValuePod(timeProvider->Now().MicroSeconds()); + continue; + } + + const auto type = mapping.Field->type(); + if (mapping.Field->label() == FieldDescriptor::LABEL_REPEATED) { + const auto size = static_cast<ui32>(reflection->FieldSize(*source, mapping.Field)); + if (size == 0) { + destination[i] = factory.GetEmptyContainerLazy(); + } else { + TUnboxedValue* inplace = nullptr; + destination[i] = factory.CreateDirectArrayHolder(size, inplace); + for (ui32 j = 0; j < size; ++j) { + switch (type) { + case FieldDescriptor::TYPE_DOUBLE: + inplace[j] = TUnboxedValuePod(reflection->GetRepeatedDouble(*source, mapping.Field, j)); + break; + + case FieldDescriptor::TYPE_FLOAT: + inplace[j] = TUnboxedValuePod(reflection->GetRepeatedFloat(*source, mapping.Field, j)); + break; + + case FieldDescriptor::TYPE_INT64: + case FieldDescriptor::TYPE_SFIXED64: + case FieldDescriptor::TYPE_SINT64: + inplace[j] = TUnboxedValuePod(reflection->GetRepeatedInt64(*source, mapping.Field, j)); + break; + + case FieldDescriptor::TYPE_ENUM: + switch (EnumFormatType(*mapping.Field, enumPolicy)) { + case EEnumFormatType::Int32: + inplace[j] = TUnboxedValuePod(reflection->GetRepeatedEnumValue(*source, mapping.Field, j)); + break; + case EEnumFormatType::String: + inplace[j] = MakeString(reflection->GetRepeatedEnum(*source, mapping.Field, j)->name()); + break; + } + break; + + case FieldDescriptor::TYPE_UINT64: + case FieldDescriptor::TYPE_FIXED64: + inplace[j] = TUnboxedValuePod(reflection->GetRepeatedUInt64(*source, mapping.Field, j)); + break; + + case FieldDescriptor::TYPE_INT32: + case FieldDescriptor::TYPE_SFIXED32: + case FieldDescriptor::TYPE_SINT32: + inplace[j] = TUnboxedValuePod(reflection->GetRepeatedInt32(*source, mapping.Field, j)); + break; + + case FieldDescriptor::TYPE_UINT32: + case FieldDescriptor::TYPE_FIXED32: + inplace[j] = TUnboxedValuePod(reflection->GetRepeatedUInt32(*source, mapping.Field, j)); + break; + + case FieldDescriptor::TYPE_BOOL: + inplace[j] = TUnboxedValuePod(reflection->GetRepeatedBool(*source, mapping.Field, j)); + break; + + case FieldDescriptor::TYPE_STRING: + inplace[j] = MakeString(reflection->GetRepeatedStringReference(*source, mapping.Field, j, &scratch)); + break; + + case FieldDescriptor::TYPE_BYTES: + inplace[j] = MakeString(reflection->GetRepeatedStringReference(*source, mapping.Field, j, &scratch)); + break; + + case FieldDescriptor::TYPE_MESSAGE: + { + const Message& nestedMessage = reflection->GetRepeatedMessage(*source, mapping.Field, j); + TUnboxedValue* nestedValues = nullptr; + inplace[j] = factory.CreateDirectArrayHolder(static_cast<ui32>(mapping.NestedFields.size()), + nestedValues); + FillInputValue(factory, &nestedMessage, nestedValues, mapping.NestedFields, Nothing(), timeProvider, enumPolicy); + } + break; + + default: + ythrow yexception() << "Unsupported protobuf type: " << mapping.Field->type_name() << ", field: " << mapping.Field->name(); + } + } + } + } else { + if (!reflection->HasField(*source, mapping.Field)) { + continue; + } + + switch (type) { + case FieldDescriptor::TYPE_DOUBLE: + destination[i] = TUnboxedValuePod(reflection->GetDouble(*source, mapping.Field)); + break; + + case FieldDescriptor::TYPE_FLOAT: + destination[i] = TUnboxedValuePod(reflection->GetFloat(*source, mapping.Field)); + break; + + case FieldDescriptor::TYPE_INT64: + case FieldDescriptor::TYPE_SFIXED64: + case FieldDescriptor::TYPE_SINT64: + destination[i] = TUnboxedValuePod(reflection->GetInt64(*source, mapping.Field)); + break; + + case FieldDescriptor::TYPE_ENUM: + switch (EnumFormatType(*mapping.Field, enumPolicy)) { + case EEnumFormatType::Int32: + destination[i] = TUnboxedValuePod(reflection->GetEnumValue(*source, mapping.Field)); + break; + case EEnumFormatType::String: + destination[i] = MakeString(reflection->GetEnum(*source, mapping.Field)->name()); + break; + } + break; + + case FieldDescriptor::TYPE_UINT64: + case FieldDescriptor::TYPE_FIXED64: + destination[i] = TUnboxedValuePod(reflection->GetUInt64(*source, mapping.Field)); + break; + + case FieldDescriptor::TYPE_INT32: + case FieldDescriptor::TYPE_SFIXED32: + case FieldDescriptor::TYPE_SINT32: + destination[i] = TUnboxedValuePod(reflection->GetInt32(*source, mapping.Field)); + break; + + case FieldDescriptor::TYPE_UINT32: + case FieldDescriptor::TYPE_FIXED32: + destination[i] = TUnboxedValuePod(reflection->GetUInt32(*source, mapping.Field)); + break; + + case FieldDescriptor::TYPE_BOOL: + destination[i] = TUnboxedValuePod(reflection->GetBool(*source, mapping.Field)); + break; + + case FieldDescriptor::TYPE_STRING: + destination[i] = MakeString(reflection->GetStringReference(*source, mapping.Field, &scratch)); + break; + + case FieldDescriptor::TYPE_BYTES: + destination[i] = MakeString(reflection->GetStringReference(*source, mapping.Field, &scratch)); + break; + case FieldDescriptor::TYPE_MESSAGE: + { + const Message& nestedMessage = reflection->GetMessage(*source, mapping.Field); + TUnboxedValue* nestedValues = nullptr; + destination[i] = factory.CreateDirectArrayHolder(static_cast<ui32>(mapping.NestedFields.size()), + nestedValues); + FillInputValue(factory, &nestedMessage, nestedValues, mapping.NestedFields, Nothing(), timeProvider, enumPolicy); + } + break; + + default: + ythrow yexception() << "Unsupported protobuf type: " << mapping.Field->type_name() + << ", field: " << mapping.Field->name(); + } + } + } + } + + + /** + * Convert unboxed value to protobuf. + * + * @param source unboxed value to extract data from. Type of the value should be struct. It's UB to pass + * a non-struct value here. + * @param destination destination message. Data in this message will be overwritten + * by data from unboxed value. + * @param mappings vector of protobuf field descriptors which denotes relation between struct fields + * and message fields. For any i-th element of this vector, type of the i-th element of + * the unboxed structure must match type of the field pointed by descriptor. Size of this + * vector should match the number of fields in the struct. + */ + void FillOutputMessage( + const TUnboxedValue& source, + Message* destination, + const TVector<TFieldMapping>& mappings, + EEnumPolicy enumPolicy + ) { + auto reflection = destination->GetReflection(); + for (ui32 i = 0; i < mappings.size(); ++i) { + const auto& mapping = mappings[i]; + const auto& cell = source.GetElement(i); + if (!cell) { + reflection->ClearField(destination, mapping.Field); + continue; + } + const auto type = mapping.Field->type(); + if (mapping.Field->label() == FieldDescriptor::LABEL_REPEATED) { + const auto iter = cell.GetListIterator(); + reflection->ClearField(destination, mapping.Field); + for (TUnboxedValue item; iter.Next(item);) { + switch (mapping.Field->type()) { + case FieldDescriptor::TYPE_DOUBLE: + reflection->AddDouble(destination, mapping.Field, item.Get<double>()); + break; + + case FieldDescriptor::TYPE_FLOAT: + reflection->AddFloat(destination, mapping.Field, item.Get<float>()); + break; + + case FieldDescriptor::TYPE_INT64: + case FieldDescriptor::TYPE_SFIXED64: + case FieldDescriptor::TYPE_SINT64: + reflection->AddInt64(destination, mapping.Field, item.Get<i64>()); + break; + + case FieldDescriptor::TYPE_ENUM: { + switch (EnumFormatType(*mapping.Field, enumPolicy)) { + case EEnumFormatType::Int32: + reflection->AddEnumValue(destination, mapping.Field, item.Get<i32>()); + break; + case EEnumFormatType::String: { + auto enumValueDescriptor = mapping.Field->enum_type()->FindValueByName(TString(item.AsStringRef())); + if (!enumValueDescriptor) { + enumValueDescriptor = mapping.Field->default_value_enum(); + } + reflection->AddEnum(destination, mapping.Field, enumValueDescriptor); + break; + } + } + break; + } + + case FieldDescriptor::TYPE_UINT64: + case FieldDescriptor::TYPE_FIXED64: + reflection->AddUInt64(destination, mapping.Field, item.Get<ui64>()); + break; + + case FieldDescriptor::TYPE_INT32: + case FieldDescriptor::TYPE_SFIXED32: + case FieldDescriptor::TYPE_SINT32: + reflection->AddInt32(destination, mapping.Field, item.Get<i32>()); + break; + + case FieldDescriptor::TYPE_UINT32: + case FieldDescriptor::TYPE_FIXED32: + reflection->AddUInt32(destination, mapping.Field, item.Get<ui32>()); + break; + + case FieldDescriptor::TYPE_BOOL: + reflection->AddBool(destination, mapping.Field, item.Get<bool>()); + break; + + case FieldDescriptor::TYPE_STRING: + reflection->AddString(destination, mapping.Field, TString(item.AsStringRef())); + break; + + case FieldDescriptor::TYPE_BYTES: + reflection->AddString(destination, mapping.Field, TString(item.AsStringRef())); + break; + + case FieldDescriptor::TYPE_MESSAGE: + { + auto* nestedMessage = reflection->AddMessage(destination, mapping.Field); + FillOutputMessage(item, nestedMessage, mapping.NestedFields, enumPolicy); + } + break; + + default: + ythrow yexception() << "Unsupported protobuf type: " + << mapping.Field->type_name() << ", field: " << mapping.Field->name(); + } + } + } else { + switch (type) { + case FieldDescriptor::TYPE_DOUBLE: + reflection->SetDouble(destination, mapping.Field, cell.Get<double>()); + break; + + case FieldDescriptor::TYPE_FLOAT: + reflection->SetFloat(destination, mapping.Field, cell.Get<float>()); + break; + + case FieldDescriptor::TYPE_INT64: + case FieldDescriptor::TYPE_SFIXED64: + case FieldDescriptor::TYPE_SINT64: + reflection->SetInt64(destination, mapping.Field, cell.Get<i64>()); + break; + + case FieldDescriptor::TYPE_ENUM: { + switch (EnumFormatType(*mapping.Field, enumPolicy)) { + case EEnumFormatType::Int32: + reflection->SetEnumValue(destination, mapping.Field, cell.Get<i32>()); + break; + case EEnumFormatType::String: { + auto enumValueDescriptor = mapping.Field->enum_type()->FindValueByName(TString(cell.AsStringRef())); + if (!enumValueDescriptor) { + enumValueDescriptor = mapping.Field->default_value_enum(); + } + reflection->SetEnum(destination, mapping.Field, enumValueDescriptor); + break; + } + } + break; + } + + case FieldDescriptor::TYPE_UINT64: + case FieldDescriptor::TYPE_FIXED64: + reflection->SetUInt64(destination, mapping.Field, cell.Get<ui64>()); + break; + + case FieldDescriptor::TYPE_INT32: + case FieldDescriptor::TYPE_SFIXED32: + case FieldDescriptor::TYPE_SINT32: + reflection->SetInt32(destination, mapping.Field, cell.Get<i32>()); + break; + + case FieldDescriptor::TYPE_UINT32: + case FieldDescriptor::TYPE_FIXED32: + reflection->SetUInt32(destination, mapping.Field, cell.Get<ui32>()); + break; + + case FieldDescriptor::TYPE_BOOL: + reflection->SetBool(destination, mapping.Field, cell.Get<bool>()); + break; + + case FieldDescriptor::TYPE_STRING: + reflection->SetString(destination, mapping.Field, TString(cell.AsStringRef())); + break; + + case FieldDescriptor::TYPE_BYTES: + reflection->SetString(destination, mapping.Field, TString(cell.AsStringRef())); + break; + + case FieldDescriptor::TYPE_MESSAGE: + { + auto* nestedMessage = reflection->MutableMessage(destination, mapping.Field); + FillOutputMessage(cell, nestedMessage, mapping.NestedFields, enumPolicy); + } + break; + + default: + ythrow yexception() << "Unsupported protobuf type: " + << mapping.Field->type_name() << ", field: " << mapping.Field->name(); + } + } + } + } + + /** + * Converts input messages to unboxed values. + */ + class TInputConverter { + protected: + IWorker* Worker_; + TVector<TFieldMapping> Mappings_; + TPlainContainerCache Cache_; + TMaybe<TString> TimestampColumn_; + EEnumPolicy EnumPolicy_ = EEnumPolicy::Int32; + + public: + explicit TInputConverter(const TProtobufRawInputSpec& inputSpec, IWorker* worker) + : Worker_(worker) + , TimestampColumn_(inputSpec.GetTimestampColumn()) + , EnumPolicy_(inputSpec.GetSchemaOptions().EnumPolicy) + { + FillFieldMappings( + Worker_->GetInputType(), inputSpec.GetDescriptor(), + Mappings_, TimestampColumn_, + inputSpec.GetSchemaOptions().ListIsOptional, + inputSpec.GetSchemaOptions().FieldRenames + ); + } + + public: + void DoConvert(const Message* message, TUnboxedValue& result) { + auto& holderFactory = Worker_->GetGraph().GetHolderFactory(); + TUnboxedValue* items = nullptr; + result = Cache_.NewArray(holderFactory, static_cast<ui32>(Mappings_.size()), items); + FillInputValue(holderFactory, message, items, Mappings_, TimestampColumn_, Worker_->GetTimeProvider(), EnumPolicy_); + } + + void ClearCache() { + Cache_.Clear(); + } + }; + + template <typename TOutputSpec> + using OutputItemType = typename TOutputSpecTraits<TOutputSpec>::TOutputItemType; + + template <typename TOutputSpec> + class TOutputConverter; + + /** + * Converts unboxed values to output messages (single-output program case). + */ + template <> + class TOutputConverter<TProtobufRawOutputSpec> { + protected: + IWorker* Worker_; + TVector<TFieldMapping> OutputColumns_; + TProtoHolder<Message> Message_; + EEnumPolicy EnumPolicy_ = EEnumPolicy::Int32; + + public: + explicit TOutputConverter(const TProtobufRawOutputSpec& outputSpec, IWorker* worker) + : Worker_(worker) + , EnumPolicy_(outputSpec.GetSchemaOptions().EnumPolicy) + { + if (!Worker_->GetOutputType()->IsStruct()) { + ythrow yexception() << "protobuf output spec does not support multiple outputs"; + } + + FillFieldMappings( + static_cast<const NKikimr::NMiniKQL::TStructType*>(Worker_->GetOutputType()), + outputSpec.GetDescriptor(), + OutputColumns_, + Nothing(), + outputSpec.GetSchemaOptions().ListIsOptional, + outputSpec.GetSchemaOptions().FieldRenames + ); + + auto* factory = outputSpec.GetFactory(); + + if (!factory) { + factory = MessageFactory::generated_factory(); + } + + Message_.Reset(factory->GetPrototype(&outputSpec.GetDescriptor())->New(outputSpec.GetArena())); + } + + OutputItemType<TProtobufRawOutputSpec> DoConvert(TUnboxedValue value) { + FillOutputMessage(value, Message_.Get(), OutputColumns_, EnumPolicy_); + return Message_.Get(); + } + }; + + /* + * Converts unboxed values to output type (multi-output programs case). + */ + template <> + class TOutputConverter<TProtobufRawMultiOutputSpec> { + protected: + IWorker* Worker_; + TVector<TVector<TFieldMapping>> OutputColumns_; + TVector<TProtoHolder<Message>> Messages_; + EEnumPolicy EnumPolicy_ = EEnumPolicy::Int32; + + public: + explicit TOutputConverter(const TProtobufRawMultiOutputSpec& outputSpec, IWorker* worker) + : Worker_(worker) + , EnumPolicy_(outputSpec.GetSchemaOptions().EnumPolicy) + { + const auto* outputType = Worker_->GetOutputType(); + Y_ENSURE(outputType->IsVariant(), "protobuf multi-output spec requires multi-output program"); + const auto* variantType = static_cast<const NKikimr::NMiniKQL::TVariantType*>(outputType); + Y_ENSURE( + variantType->GetUnderlyingType()->IsTuple(), + "protobuf multi-output spec requires variant over tuple as program output type" + ); + Y_ENSURE( + outputSpec.GetOutputsNumber() == variantType->GetAlternativesCount(), + "number of outputs provided by spec does not match number of variant alternatives" + ); + + auto defaultFactory = MessageFactory::generated_factory(); + + for (ui32 i = 0; i < variantType->GetAlternativesCount(); ++i) { + const auto* type = variantType->GetAlternativeType(i); + Y_ASSERT(type->IsStruct()); + Y_ASSERT(OutputColumns_.size() == i && Messages_.size() == i); + + OutputColumns_.push_back({}); + + FillFieldMappings( + static_cast<const NKikimr::NMiniKQL::TStructType*>(type), + outputSpec.GetDescriptor(i), + OutputColumns_.back(), + Nothing(), + outputSpec.GetSchemaOptions().ListIsOptional, + {} + ); + + auto factory = outputSpec.GetFactory(i); + if (!factory) { + factory = defaultFactory; + } + + Messages_.push_back(TProtoHolder<Message>( + factory->GetPrototype(&outputSpec.GetDescriptor(i))->New(outputSpec.GetArena(i)) + )); + } + } + + OutputItemType<TProtobufRawMultiOutputSpec> DoConvert(TUnboxedValue value) { + auto index = value.GetVariantIndex(); + auto msgPtr = Messages_[index].Get(); + FillOutputMessage(value.GetVariantItem(), msgPtr, OutputColumns_[index], EnumPolicy_); + return {index, msgPtr}; + } + }; + + /** + * List (or, better, stream) of unboxed values. Used as an input value in pull workers. + */ + class TProtoListValue final: public TCustomListValue { + private: + mutable bool HasIterator_ = false; + THolder<IStream<Message*>> Underlying_; + TInputConverter Converter_; + IWorker* Worker_; + TScopedAlloc& ScopedAlloc_; + + public: + TProtoListValue( + TMemoryUsageInfo* memInfo, + const TProtobufRawInputSpec& inputSpec, + THolder<IStream<Message*>> underlying, + IWorker* worker + ) + : TCustomListValue(memInfo) + , Underlying_(std::move(underlying)) + , Converter_(inputSpec, worker) + , Worker_(worker) + , ScopedAlloc_(Worker_->GetScopedAlloc()) + { + } + + ~TProtoListValue() override { + { + // This list value stored in the worker's computation graph and destroyed upon the computation + // graph's destruction. This brings us to an interesting situation: scoped alloc is acquired, + // worker and computation graph are half-way destroyed, and now it's our turn to die. The problem is, + // the underlying stream may own another worker. This happens when chaining programs. Now, to destroy + // that worker correctly, we need to release our scoped alloc (because that worker has its own + // computation graph and scoped alloc). + // By the way, note that we shouldn't interact with the worker here because worker is in the middle of + // its own destruction. So we're using our own reference to the scoped alloc. That reference is alive + // because scoped alloc destroyed after computation graph. + auto unguard = Unguard(ScopedAlloc_); + Underlying_.Destroy(); + } + } + + public: + TUnboxedValue GetListIterator() const override { + YQL_ENSURE(!HasIterator_, "Only one pass over input is supported"); + HasIterator_ = true; + return TUnboxedValuePod(const_cast<TProtoListValue*>(this)); + } + + bool Next(TUnboxedValue& result) override { + const Message* message; + { + auto unguard = Unguard(ScopedAlloc_); + message = Underlying_->Fetch(); + } + + if (!message) { + return false; + } + + Converter_.DoConvert(message, result); + + return true; + } + + EFetchStatus Fetch(TUnboxedValue& result) override { + if (Next(result)) { + return EFetchStatus::Ok; + } else { + return EFetchStatus::Finish; + } + } + }; + + /** + * Consumer which converts messages to unboxed values and relays them to the worker. Used as a return value + * of the push processor's Process function. + */ + class TProtoConsumerImpl final: public IConsumer<Message*> { + private: + TWorkerHolder<IPushStreamWorker> WorkerHolder_; + TInputConverter Converter_; + + public: + explicit TProtoConsumerImpl( + const TProtobufRawInputSpec& inputSpec, + TWorkerHolder<IPushStreamWorker> worker + ) + : WorkerHolder_(std::move(worker)) + , Converter_(inputSpec, WorkerHolder_.Get()) + { + } + + ~TProtoConsumerImpl() override { + with_lock(WorkerHolder_->GetScopedAlloc()) { + Converter_.ClearCache(); + } + } + + public: + void OnObject(Message* message) override { + TBindTerminator bind(WorkerHolder_->GetGraph().GetTerminator()); + + with_lock(WorkerHolder_->GetScopedAlloc()) { + TUnboxedValue result; + Converter_.DoConvert(message, result); + WorkerHolder_->Push(std::move(result)); + } + } + + void OnFinish() override { + TBindTerminator bind(WorkerHolder_->GetGraph().GetTerminator()); + + with_lock(WorkerHolder_->GetScopedAlloc()) { + WorkerHolder_->OnFinish(); + } + } + }; + + /** + * Protobuf input stream for unboxed value streams. + */ + template <typename TOutputSpec> + class TRawProtoStreamImpl final: public IStream<OutputItemType<TOutputSpec>> { + protected: + TWorkerHolder<IPullStreamWorker> WorkerHolder_; + TOutputConverter<TOutputSpec> Converter_; + + public: + explicit TRawProtoStreamImpl(const TOutputSpec& outputSpec, TWorkerHolder<IPullStreamWorker> worker) + : WorkerHolder_(std::move(worker)) + , Converter_(outputSpec, WorkerHolder_.Get()) + { + } + + public: + OutputItemType<TOutputSpec> Fetch() override { + TBindTerminator bind(WorkerHolder_->GetGraph().GetTerminator()); + + with_lock(WorkerHolder_->GetScopedAlloc()) { + TUnboxedValue value; + + auto status = WorkerHolder_->GetOutput().Fetch(value); + + YQL_ENSURE(status != EFetchStatus::Yield, "Yield is not supported in pull mode"); + + if (status == EFetchStatus::Finish) { + return TOutputSpecTraits<TOutputSpec>::StreamSentinel; + } + + return Converter_.DoConvert(value); + } + } + }; + + /** + * Protobuf input stream for unboxed value lists. + */ + template <typename TOutputSpec> + class TRawProtoListImpl final: public IStream<OutputItemType<TOutputSpec>> { + protected: + TWorkerHolder<IPullListWorker> WorkerHolder_; + TOutputConverter<TOutputSpec> Converter_; + + public: + explicit TRawProtoListImpl(const TOutputSpec& outputSpec, TWorkerHolder<IPullListWorker> worker) + : WorkerHolder_(std::move(worker)) + , Converter_(outputSpec, WorkerHolder_.Get()) + { + } + + public: + OutputItemType<TOutputSpec> Fetch() override { + TBindTerminator bind(WorkerHolder_->GetGraph().GetTerminator()); + + with_lock(WorkerHolder_->GetScopedAlloc()) { + TUnboxedValue value; + + if (!WorkerHolder_->GetOutputIterator().Next(value)) { + return TOutputSpecTraits<TOutputSpec>::StreamSentinel; + } + + return Converter_.DoConvert(value); + } + } + }; + + /** + * Push relay used to convert generated unboxed value to a message and push it to the user's consumer. + */ + template <typename TOutputSpec> + class TPushRelayImpl: public IConsumer<const TUnboxedValue*> { + private: + THolder<IConsumer<OutputItemType<TOutputSpec>>> Underlying_; + TOutputConverter<TOutputSpec> Converter_; + IWorker* Worker_; + + public: + TPushRelayImpl( + const TOutputSpec& outputSpec, + IPushStreamWorker* worker, + THolder<IConsumer<OutputItemType<TOutputSpec>>> underlying + ) + : Underlying_(std::move(underlying)) + , Converter_(outputSpec, worker) + , Worker_(worker) + { + } + + // If you've read a comment in the TProtoListValue's destructor, you may be wondering why don't we do the + // same trick here. Well, that's because in push mode, consumer is destroyed before acquiring scoped alloc and + // destroying computation graph. + + public: + void OnObject(const TUnboxedValue* value) override { + OutputItemType<TOutputSpec> message = Converter_.DoConvert(*value); + auto unguard = Unguard(Worker_->GetScopedAlloc()); + Underlying_->OnObject(message); + } + + void OnFinish() override { + auto unguard = Unguard(Worker_->GetScopedAlloc()); + Underlying_->OnFinish(); + } + }; +} + +using ConsumerType = TInputSpecTraits<TProtobufRawInputSpec>::TConsumerType; + +void TInputSpecTraits<TProtobufRawInputSpec>::PreparePullStreamWorker( + const TProtobufRawInputSpec& inputSpec, + IPullStreamWorker* worker, + THolder<IStream<Message*>> stream +) { + with_lock(worker->GetScopedAlloc()) { + worker->SetInput( + worker->GetGraph().GetHolderFactory().Create<TProtoListValue>(inputSpec, std::move(stream), worker), 0); + } +} + +void TInputSpecTraits<TProtobufRawInputSpec>::PreparePullListWorker( + const TProtobufRawInputSpec& inputSpec, + IPullListWorker* worker, + THolder<IStream<Message*>> stream +) { + with_lock(worker->GetScopedAlloc()) { + worker->SetInput( + worker->GetGraph().GetHolderFactory().Create<TProtoListValue>(inputSpec, std::move(stream), worker), 0); + } +} + +ConsumerType TInputSpecTraits<TProtobufRawInputSpec>::MakeConsumer( + const TProtobufRawInputSpec& inputSpec, + TWorkerHolder<IPushStreamWorker> worker +) { + return MakeHolder<TProtoConsumerImpl>(inputSpec, std::move(worker)); +} + +template <typename TOutputSpec> +using PullStreamReturnType = typename TOutputSpecTraits<TOutputSpec>::TPullStreamReturnType; +template <typename TOutputSpec> +using PullListReturnType = typename TOutputSpecTraits<TOutputSpec>::TPullListReturnType; + +PullStreamReturnType<TProtobufRawOutputSpec> TOutputSpecTraits<TProtobufRawOutputSpec>::ConvertPullStreamWorkerToOutputType( + const TProtobufRawOutputSpec& outputSpec, + TWorkerHolder<IPullStreamWorker> worker +) { + return MakeHolder<TRawProtoStreamImpl<TProtobufRawOutputSpec>>(outputSpec, std::move(worker)); +} + +PullListReturnType<TProtobufRawOutputSpec> TOutputSpecTraits<TProtobufRawOutputSpec>::ConvertPullListWorkerToOutputType( + const TProtobufRawOutputSpec& outputSpec, + TWorkerHolder<IPullListWorker> worker +) { + return MakeHolder<TRawProtoListImpl<TProtobufRawOutputSpec>>(outputSpec, std::move(worker)); +} + +void TOutputSpecTraits<TProtobufRawOutputSpec>::SetConsumerToWorker( + const TProtobufRawOutputSpec& outputSpec, + IPushStreamWorker* worker, + THolder<IConsumer<TOutputItemType>> consumer +) { + worker->SetConsumer(MakeHolder<TPushRelayImpl<TProtobufRawOutputSpec>>(outputSpec, worker, std::move(consumer))); +} + +PullStreamReturnType<TProtobufRawMultiOutputSpec> TOutputSpecTraits<TProtobufRawMultiOutputSpec>::ConvertPullStreamWorkerToOutputType( + const TProtobufRawMultiOutputSpec& outputSpec, + TWorkerHolder<IPullStreamWorker> worker +) { + return MakeHolder<TRawProtoStreamImpl<TProtobufRawMultiOutputSpec>>(outputSpec, std::move(worker)); +} + +PullListReturnType<TProtobufRawMultiOutputSpec> TOutputSpecTraits<TProtobufRawMultiOutputSpec>::ConvertPullListWorkerToOutputType( + const TProtobufRawMultiOutputSpec& outputSpec, + TWorkerHolder<IPullListWorker> worker +) { + return MakeHolder<TRawProtoListImpl<TProtobufRawMultiOutputSpec>>(outputSpec, std::move(worker)); +} + +void TOutputSpecTraits<TProtobufRawMultiOutputSpec>::SetConsumerToWorker( + const TProtobufRawMultiOutputSpec& outputSpec, + IPushStreamWorker* worker, + THolder<IConsumer<TOutputItemType>> consumer +) { + worker->SetConsumer(MakeHolder<TPushRelayImpl<TProtobufRawMultiOutputSpec>>(outputSpec, worker, std::move(consumer))); +} diff --git a/yql/essentials/public/purecalc/io_specs/protobuf_raw/spec.h b/yql/essentials/public/purecalc/io_specs/protobuf_raw/spec.h new file mode 100644 index 00000000000..436b243bffd --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf_raw/spec.h @@ -0,0 +1,257 @@ +#pragma once + +#include <yql/essentials/public/purecalc/common/interface.h> +#include <yql/essentials/public/purecalc/helpers/protobuf/schema_from_proto.h> + +#include <google/protobuf/message.h> + +#include <util/generic/maybe.h> + +namespace NYql { + namespace NPureCalc { + /** + * Processing mode for working with raw protobuf message inputs. + * + * In this mode purecalc accept pointers to abstract protobuf messages and processes them using the reflection + * mechanism. All passed messages should have the same descriptor (the one you pass to the constructor + * of the input spec). + * + * All working modes are implemented. In pull stream and pull list modes a program would accept a single object + * stream of const protobuf messages. In push mode, a program will return a consumer of const protobuf messages. + * + * The program synopsis follows: + * + * @code + * ... TPullStreamProgram::Apply(IStream<google::protobuf::Message*>); + * ... TPullListProgram::Apply(IStream<google::protobuf::Message*>); + * TConsumer<google::protobuf::Message*> TPushStreamProgram::Apply(...); + * @endcode + */ + class TProtobufRawInputSpec: public TInputSpecBase { + private: + const google::protobuf::Descriptor& Descriptor_; + const TMaybe<TString> TimestampColumn_; + const TProtoSchemaOptions SchemaOptions_; + mutable TVector<NYT::TNode> SavedSchemas_; + + public: + /** + * Build input spec and associate the given message descriptor. + */ + explicit TProtobufRawInputSpec( + const google::protobuf::Descriptor& descriptor, + const TMaybe<TString>& timestampColumn = Nothing(), + const TProtoSchemaOptions& options = {} + ); + + public: + const TVector<NYT::TNode>& GetSchemas() const override; + + /** + * Get the descriptor associated with this spec. + */ + const google::protobuf::Descriptor& GetDescriptor() const; + + const TMaybe<TString>& GetTimestampColumn() const; + + /* + * Get options that customize input struct type building. + */ + const TProtoSchemaOptions& GetSchemaOptions() const; + }; + + /** + * Processing mode for working with raw protobuf message outputs. + * + * In this mode purecalc yields pointers to abstract protobuf messages. All generated messages share the same + * descriptor so they can be safely converted into an appropriate message type. + * + * Note that one should not expect that the returned pointer will be valid forever; in can (and will) become + * outdated once a new output is requested/pushed. + * + * All working modes are implemented. In pull stream and pull list modes a program will return an object + * stream of non-const protobuf messages. In push mode, it will accept a single consumer of non-const + * messages. + * + * The program synopsis follows: + * + * @code + * IStream<google::protobuf::Message*> TPullStreamProgram::Apply(...); + * IStream<google::protobuf::Message*> TPullListProgram::Apply(...); + * ... TPushStreamProgram::Apply(TConsumer<google::protobuf::Message*>); + * @endcode + */ + class TProtobufRawOutputSpec: public TOutputSpecBase { + private: + const google::protobuf::Descriptor& Descriptor_; + google::protobuf::MessageFactory* Factory_; + TProtoSchemaOptions SchemaOptions_; + google::protobuf::Arena* Arena_; + mutable TMaybe<NYT::TNode> SavedSchema_; + + public: + /** + * Build output spec and associate the given message descriptor and maybe the given message factory. + */ + explicit TProtobufRawOutputSpec( + const google::protobuf::Descriptor& descriptor, + google::protobuf::MessageFactory* = nullptr, + const TProtoSchemaOptions& options = {}, + google::protobuf::Arena* arena = nullptr + ); + + public: + const NYT::TNode& GetSchema() const override; + + /** + * Get the descriptor associated with this spec. + */ + const google::protobuf::Descriptor& GetDescriptor() const; + + /** + * Set a new message factory which will be used to generate messages. Pass a null pointer to use the + * default factory. + */ + void SetFactory(google::protobuf::MessageFactory*); + + /** + * Get the message factory which is currently associated with this spec. + */ + google::protobuf::MessageFactory* GetFactory() const; + + /** + * Set a new arena which will be used to generate messages. Pass a null pointer to create on the heap. + */ + void SetArena(google::protobuf::Arena*); + + /** + * Get the arena which is currently associated with this spec. + */ + google::protobuf::Arena* GetArena() const; + + /** + * Get options that customize output struct type building. + */ + const TProtoSchemaOptions& GetSchemaOptions() const; + }; + + /** + * Processing mode for working with raw protobuf messages and several outputs. + * + * The program synopsis follows: + * + * @code + * IStream<std::pair<ui32, google::protobuf::Message*>> TPullStreamProgram::Apply(...); + * IStream<std::pair<ui32, google::protobuf::Message*>> TPullListProgram::Apply(...); + * ... TPushStreamProgram::Apply(TConsumer<std::pair<ui32, google::protobuf::Message*>>); + * @endcode + */ + class TProtobufRawMultiOutputSpec: public TOutputSpecBase { + private: + TVector<const google::protobuf::Descriptor*> Descriptors_; + TVector<google::protobuf::MessageFactory*> Factories_; + const TProtoSchemaOptions SchemaOptions_; + TVector<google::protobuf::Arena*> Arenas_; + mutable NYT::TNode SavedSchema_; + + public: + TProtobufRawMultiOutputSpec( + TVector<const google::protobuf::Descriptor*>, + TMaybe<TVector<google::protobuf::MessageFactory*>> = {}, + const TProtoSchemaOptions& options = {}, + TMaybe<TVector<google::protobuf::Arena*>> arenas = {} + ); + + public: + const NYT::TNode& GetSchema() const override; + + /** + * Get the descriptor associated with given output. + */ + const google::protobuf::Descriptor& GetDescriptor(ui32) const; + + /** + * Set a new message factory for given output. It will be used to generate messages for this output. + */ + void SetFactory(ui32, google::protobuf::MessageFactory*); + + /** + * Get the message factory which is currently associated with given output. + */ + google::protobuf::MessageFactory* GetFactory(ui32) const; + + /** + * Set a new arena for given output. It will be used to generate messages for this output. + */ + void SetArena(ui32, google::protobuf::Arena*); + + /** + * Get the arena which is currently associated with given output. + */ + google::protobuf::Arena* GetArena(ui32) const; + + /** + * Get number of outputs for this spec. + */ + ui32 GetOutputsNumber() const; + + /** + * Get options that customize output struct type building. + */ + const TProtoSchemaOptions& GetSchemaOptions() const; + }; + + template <> + struct TInputSpecTraits<TProtobufRawInputSpec> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullStreamMode = true; + static const constexpr bool SupportPullListMode = true; + static const constexpr bool SupportPushStreamMode = true; + + using TConsumerType = THolder<IConsumer<google::protobuf::Message*>>; + + static void PreparePullStreamWorker(const TProtobufRawInputSpec&, IPullStreamWorker*, THolder<IStream<google::protobuf::Message*>>); + static void PreparePullListWorker(const TProtobufRawInputSpec&, IPullListWorker*, THolder<IStream<google::protobuf::Message*>>); + static TConsumerType MakeConsumer(const TProtobufRawInputSpec&, TWorkerHolder<IPushStreamWorker>); + }; + + template <> + struct TOutputSpecTraits<TProtobufRawOutputSpec> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullStreamMode = true; + static const constexpr bool SupportPullListMode = true; + static const constexpr bool SupportPushStreamMode = true; + + using TOutputItemType = google::protobuf::Message*; + using TPullStreamReturnType = THolder<IStream<TOutputItemType>>; + using TPullListReturnType = THolder<IStream<TOutputItemType>>; + + static const constexpr TOutputItemType StreamSentinel = nullptr; + + static TPullStreamReturnType ConvertPullStreamWorkerToOutputType(const TProtobufRawOutputSpec&, TWorkerHolder<IPullStreamWorker>); + static TPullListReturnType ConvertPullListWorkerToOutputType(const TProtobufRawOutputSpec&, TWorkerHolder<IPullListWorker>); + static void SetConsumerToWorker(const TProtobufRawOutputSpec&, IPushStreamWorker*, THolder<IConsumer<TOutputItemType>>); + }; + + template <> + struct TOutputSpecTraits<TProtobufRawMultiOutputSpec> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullStreamMode = true; + static const constexpr bool SupportPullListMode = true; + static const constexpr bool SupportPushStreamMode = true; + + using TOutputItemType = std::pair<ui32, google::protobuf::Message*>; + using TPullStreamReturnType = THolder<IStream<TOutputItemType>>; + using TPullListReturnType = THolder<IStream<TOutputItemType>>; + + static const constexpr TOutputItemType StreamSentinel = {0, nullptr}; + + static TPullStreamReturnType ConvertPullStreamWorkerToOutputType(const TProtobufRawMultiOutputSpec&, TWorkerHolder<IPullStreamWorker>); + static TPullListReturnType ConvertPullListWorkerToOutputType(const TProtobufRawMultiOutputSpec&, TWorkerHolder<IPullListWorker>); + static void SetConsumerToWorker(const TProtobufRawMultiOutputSpec&, IPushStreamWorker*, THolder<IConsumer<TOutputItemType>>); + }; + } +} diff --git a/yql/essentials/public/purecalc/io_specs/protobuf_raw/ya.make b/yql/essentials/public/purecalc/io_specs/protobuf_raw/ya.make new file mode 100644 index 00000000000..db3fab7e7a5 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/protobuf_raw/ya.make @@ -0,0 +1,16 @@ +LIBRARY() + +PEERDIR( + yql/essentials/public/purecalc/common + yql/essentials/public/purecalc/helpers/protobuf +) + +SRCS( + proto_holder.cpp + spec.cpp + spec.h +) + +YQL_LAST_ABI_VERSION() + +END() diff --git a/yql/essentials/public/purecalc/io_specs/ut/ya.make b/yql/essentials/public/purecalc/io_specs/ut/ya.make new file mode 100644 index 00000000000..70b8bb521d2 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/ut/ya.make @@ -0,0 +1,4 @@ +RECURSE( + ../arrow/ut + ../protobuf/ut +) diff --git a/yql/essentials/public/purecalc/io_specs/ya.make b/yql/essentials/public/purecalc/io_specs/ya.make new file mode 100644 index 00000000000..30c7d259103 --- /dev/null +++ b/yql/essentials/public/purecalc/io_specs/ya.make @@ -0,0 +1,9 @@ +RECURSE( + arrow + protobuf + protobuf_raw +) + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/public/purecalc/no_llvm/purecalc.h b/yql/essentials/public/purecalc/no_llvm/purecalc.h new file mode 100644 index 00000000000..9b281a7caa7 --- /dev/null +++ b/yql/essentials/public/purecalc/no_llvm/purecalc.h @@ -0,0 +1,4 @@ +#pragma once + +#include "common/interface.h" + diff --git a/yql/essentials/public/purecalc/no_llvm/ya.make b/yql/essentials/public/purecalc/no_llvm/ya.make new file mode 100644 index 00000000000..2b3bd870314 --- /dev/null +++ b/yql/essentials/public/purecalc/no_llvm/ya.make @@ -0,0 +1,30 @@ +LIBRARY() + +ADDINCL( + yql/essentials/public/purecalc +) + +SRCDIR( + yql/essentials/public/purecalc +) + +SRCS( + purecalc.cpp +) + +PEERDIR( + yql/essentials/public/udf/service/exception_policy + yql/essentials/public/purecalc/common/no_llvm + contrib/ydb/library/yql/providers/yt/codec/codegen/no_llvm + yql/essentials/minikql/codegen/no_llvm + yql/essentials/minikql/computation/no_llvm + yql/essentials/minikql/invoke_builtins/no_llvm + yql/essentials/minikql/comp_nodes/no_llvm +) + +YQL_LAST_ABI_VERSION() + +PROVIDES(YQL_PURECALC) + +END() + diff --git a/yql/essentials/public/purecalc/purecalc.cpp b/yql/essentials/public/purecalc/purecalc.cpp new file mode 100644 index 00000000000..80cfd39d963 --- /dev/null +++ b/yql/essentials/public/purecalc/purecalc.cpp @@ -0,0 +1 @@ +#include "purecalc.h" diff --git a/yql/essentials/public/purecalc/purecalc.h b/yql/essentials/public/purecalc/purecalc.h new file mode 100644 index 00000000000..83bd8a7b842 --- /dev/null +++ b/yql/essentials/public/purecalc/purecalc.h @@ -0,0 +1,3 @@ +#pragma once + +#include "common/interface.h" diff --git a/yql/essentials/public/purecalc/ut/empty_stream.h b/yql/essentials/public/purecalc/ut/empty_stream.h new file mode 100644 index 00000000000..8d10e647aee --- /dev/null +++ b/yql/essentials/public/purecalc/ut/empty_stream.h @@ -0,0 +1,20 @@ +#pragma once + +#include <yql/essentials/public/purecalc/purecalc.h> + +namespace NYql { + namespace NPureCalc { + template <typename T> + class TEmptyStreamImpl: public IStream<T> { + public: + T Fetch() override { + return nullptr; + } + }; + + template <typename T> + THolder<IStream<T>> EmptyStream() { + return MakeHolder<TEmptyStreamImpl<T>>(); + } + } +} diff --git a/yql/essentials/public/purecalc/ut/fake_spec.cpp b/yql/essentials/public/purecalc/ut/fake_spec.cpp new file mode 100644 index 00000000000..b56f7cfdfd5 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/fake_spec.cpp @@ -0,0 +1,36 @@ +#include "fake_spec.h" + +namespace NYql { + namespace NPureCalc { + NYT::TNode MakeFakeSchema(bool pg) { + auto itemType = NYT::TNode::CreateList(); + itemType.Add(pg ? "PgType" : "DataType"); + itemType.Add(pg ? "int4" : "Int32"); + + auto itemNode = NYT::TNode::CreateList(); + itemNode.Add("Name"); + itemNode.Add(std::move(itemType)); + + auto items = NYT::TNode::CreateList(); + items.Add(std::move(itemNode)); + + auto schema = NYT::TNode::CreateList(); + schema.Add("StructType"); + schema.Add(std::move(items)); + + return schema; + } + + TFakeInputSpec FakeIS(ui32 inputsNumber, bool pg) { + auto spec = TFakeInputSpec(); + spec.Schemas = TVector<NYT::TNode>(inputsNumber, MakeFakeSchema(pg)); + return spec; + } + + TFakeOutputSpec FakeOS(bool pg) { + auto spec = TFakeOutputSpec(); + spec.Schema = MakeFakeSchema(pg); + return spec; + } + } +} diff --git a/yql/essentials/public/purecalc/ut/fake_spec.h b/yql/essentials/public/purecalc/ut/fake_spec.h new file mode 100644 index 00000000000..3cb1457f01d --- /dev/null +++ b/yql/essentials/public/purecalc/ut/fake_spec.h @@ -0,0 +1,54 @@ +#pragma once + +#include <yql/essentials/public/purecalc/purecalc.h> + +namespace NYql { + namespace NPureCalc { + class TFakeInputSpec: public TInputSpecBase { + public: + TVector<NYT::TNode> Schemas = {NYT::TNode::CreateList()}; + + public: + const TVector<NYT::TNode>& GetSchemas() const override { + return Schemas; + } + }; + + class TFakeOutputSpec: public TOutputSpecBase { + public: + NYT::TNode Schema = NYT::TNode::CreateList(); + + public: + const NYT::TNode& GetSchema() const override { + return Schema; + } + }; + + template <> + struct TInputSpecTraits<TFakeInputSpec> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullStreamMode = false; + static const constexpr bool SupportPullListMode = false; + static const constexpr bool SupportPushStreamMode = false; + + using TConsumerType = void; + }; + + template <> + struct TOutputSpecTraits<TFakeOutputSpec> { + static const constexpr bool IsPartial = false; + + static const constexpr bool SupportPullStreamMode = false; + static const constexpr bool SupportPullListMode = false; + static const constexpr bool SupportPushStreamMode = false; + + using TPullStreamReturnType = void; + using TPullListReturnType = void; + }; + + NYT::TNode MakeFakeSchema(bool pg = false); + TFakeInputSpec FakeIS(ui32 inputsNumber = 1, bool pg = false); + TFakeOutputSpec FakeOS(bool pg = false); + } +} diff --git a/yql/essentials/public/purecalc/ut/lib/helpers.cpp b/yql/essentials/public/purecalc/ut/lib/helpers.cpp new file mode 100644 index 00000000000..cef9a995235 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/lib/helpers.cpp @@ -0,0 +1,55 @@ +#include "helpers.h" + +#include <library/cpp/yson/writer.h> + +#include <library/cpp/yson/node/node_visitor.h> + +#include <util/string/ascii.h> +#include <util/generic/hash_set.h> + + +namespace NYql { + namespace NPureCalc { + namespace NPrivate { + NYT::TNode GetSchema( + const TVector<TString>& fields, + const TVector<TString>& optionalFields + ) { + THashSet<TString> optionalFilter { + optionalFields.begin(), optionalFields.end() + }; + + NYT::TNode members {NYT::TNode::CreateList()}; + + auto addField = [&] (const TString& name, const TString& type) { + auto typeNode = NYT::TNode::CreateList() + .Add("DataType") + .Add(type); + + if (optionalFilter.contains(name)) { + typeNode = NYT::TNode::CreateList() + .Add("OptionalType") + .Add(typeNode); + } + + members.Add(NYT::TNode::CreateList() + .Add(name) + .Add(typeNode) + ); + }; + + for (const auto& field: fields) { + TString type {field}; + type[0] = AsciiToUpper(type[0]); + addField(field, type); + } + + NYT::TNode schema = NYT::TNode::CreateList() + .Add("StructType") + .Add(members); + + return schema; + } + } + } +} diff --git a/yql/essentials/public/purecalc/ut/lib/helpers.h b/yql/essentials/public/purecalc/ut/lib/helpers.h new file mode 100644 index 00000000000..53a22661ec3 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/lib/helpers.h @@ -0,0 +1,18 @@ +#pragma once + +#include <library/cpp/yson/node/node.h> +#include <util/generic/string.h> +#include <util/generic/vector.h> +#include <util/stream/str.h> + + +namespace NYql { + namespace NPureCalc { + namespace NPrivate { + NYT::TNode GetSchema( + const TVector<TString>& fields, + const TVector<TString>& optionalFields = {} + ); + } + } +} diff --git a/yql/essentials/public/purecalc/ut/lib/ya.make b/yql/essentials/public/purecalc/ut/lib/ya.make new file mode 100644 index 00000000000..36134a2940b --- /dev/null +++ b/yql/essentials/public/purecalc/ut/lib/ya.make @@ -0,0 +1,14 @@ +LIBRARY() + +PEERDIR( + contrib/libs/apache/arrow + library/cpp/yson + library/cpp/yson/node +) + +SRCS( + helpers.cpp + helpers.h +) + +END() diff --git a/yql/essentials/public/purecalc/ut/protos/test_structs.proto b/yql/essentials/public/purecalc/ut/protos/test_structs.proto new file mode 100644 index 00000000000..66593005a5e --- /dev/null +++ b/yql/essentials/public/purecalc/ut/protos/test_structs.proto @@ -0,0 +1,122 @@ +package NPureCalcProto; + +message TUnparsed { + required string S = 1; +} + +message TParsed { + required int32 A = 1; + optional int32 B = 2; + required int32 C = 3; +} + +message TPartial { + required int32 X = 1; +} + +message TSimpleMessage { + required int32 X = 1; +} + +message TNamedSimpleMessage { + required int32 X = 1; + required bytes Name = 2; +} + +message TStringMessage { + required string X = 1; +} + +message TAllTypes { + required double FDouble = 1; + required float FFloat = 2; + required int64 FInt64 = 3; + required sfixed64 FSfixed64 = 4; + required sint64 FSint64 = 5; + required uint64 FUint64 = 6; + required fixed64 FFixed64 = 7; + required int32 FInt32 = 8; + required sfixed32 FSfixed32 = 9; + required sint32 FSint32 = 10; + required uint32 FUint32 = 11; + required fixed32 FFixed32 = 12; + required bool FBool = 13; + required string FString = 14; + required bytes FBytes = 15; +} + +message TOptionalAllTypes { + optional double FDouble = 1; + optional float FFloat = 2; + optional int64 FInt64 = 3; + optional sfixed64 FSfixed64 = 4; + optional sint64 FSint64 = 5; + optional uint64 FUint64 = 6; + optional fixed64 FFixed64 = 7; + optional int32 FInt32 = 8; + optional sfixed32 FSfixed32 = 9; + optional sint32 FSint32 = 10; + optional uint32 FUint32 = 11; + optional fixed32 FFixed32 = 12; + optional bool FBool = 13; + optional string FString = 14; + optional bytes FBytes = 15; +} + +message TSimpleNested { + required int32 X = 1; + required TAllTypes Y = 2; +} + +message TOptionalNested { + optional TAllTypes X = 1; +} + +message TSimpleRepeated { + required int32 X = 1; + repeated int32 Y = 2; +} + +message TNestedRepeated { + required int32 X = 1; + repeated TSimpleNested Y = 2; +} + +message TRecursive { + required int32 X = 1; + required TRecursive Nested = 2; +} + +message TRecursiveIndirectly { + message TNested { + required TRecursiveIndirectly Nested = 1; + } + + required int32 X = 1; + repeated TNested Nested = 2; +} + +message TMessageWithEnum { + enum ETestEnum { + VALUE1 = 0; + VALUE2 = 1; + } + repeated ETestEnum EnumValue = 1; +} + +message TUnsplitted { + required int32 AInt = 1; + required uint32 AUint = 2; + required string AString = 3; + optional bool ABool = 4; +} + +message TSplitted1 { + required int32 BInt = 1; + required string BString = 2; +} + +message TSplitted2 { + required uint32 CUint = 1; + required string CString = 2; +} diff --git a/yql/essentials/public/purecalc/ut/protos/ya.make b/yql/essentials/public/purecalc/ut/protos/ya.make new file mode 100644 index 00000000000..a455ff2fba2 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/protos/ya.make @@ -0,0 +1,9 @@ +PROTO_LIBRARY() + +SRCS( + test_structs.proto +) + +EXCLUDE_TAGS(GO_PROTO) + +END() diff --git a/yql/essentials/public/purecalc/ut/test_eval.cpp b/yql/essentials/public/purecalc/ut/test_eval.cpp new file mode 100644 index 00000000000..38ad7cc952d --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_eval.cpp @@ -0,0 +1,30 @@ +#include <yql/essentials/public/purecalc/purecalc.h> +#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h> +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> +#include <yql/essentials/public/purecalc/ut/empty_stream.h> + +#include <library/cpp/testing/unittest/registar.h> + +Y_UNIT_TEST_SUITE(TestEval) { + Y_UNIT_TEST(TestEvalExpr) { + using namespace NYql::NPureCalc; + + auto options = TProgramFactoryOptions(); + auto factory = MakeProgramFactory(options); + + auto program = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + "SELECT Unwrap(cast(EvaluateExpr('foo' || 'bar') as Utf8)) AS X", + ETranslationMode::SQL + ); + + auto stream = program->Apply(EmptyStream<NPureCalcProto::TStringMessage*>()); + + NPureCalcProto::TStringMessage* message; + + UNIT_ASSERT(message = stream->Fetch()); + UNIT_ASSERT_EQUAL(message->GetX(), "foobar"); + UNIT_ASSERT(!stream->Fetch()); + } +} diff --git a/yql/essentials/public/purecalc/ut/test_mixed_allocators.cpp b/yql/essentials/public/purecalc/ut/test_mixed_allocators.cpp new file mode 100644 index 00000000000..797f3c5b512 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_mixed_allocators.cpp @@ -0,0 +1,139 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> +#include <yql/essentials/minikql/mkql_string_util.h> + +#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h> +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> + +using namespace NYql::NPureCalc; + +namespace { + class TStatelessInputSpec : public TInputSpecBase { + public: + TStatelessInputSpec() + : Schemas_({NYT::TNode::CreateList() + .Add("StructType") + .Add(NYT::TNode::CreateList() + .Add(NYT::TNode::CreateList() + .Add("InputValue") + .Add(NYT::TNode::CreateList() + .Add("DataType") + .Add("Utf8") + ) + ) + ) + }) + {}; + + const TVector<NYT::TNode>& GetSchemas() const override { + return Schemas_; + } + + private: + const TVector<NYT::TNode> Schemas_; + }; + + class TStatelessInputConsumer : public IConsumer<const NYql::NUdf::TUnboxedValue&> { + public: + TStatelessInputConsumer(TWorkerHolder<IPushStreamWorker> worker) + : Worker_(std::move(worker)) + {} + + void OnObject(const NYql::NUdf::TUnboxedValue& value) override { + with_lock (Worker_->GetScopedAlloc()) { + NYql::NUdf::TUnboxedValue* items = nullptr; + NYql::NUdf::TUnboxedValue result = Worker_->GetGraph().GetHolderFactory().CreateDirectArrayHolder(1, items); + + items[0] = value; + + Worker_->Push(std::move(result)); + + // Clear graph after each object because + // values allocated on another allocator and should be released + Worker_->GetGraph().Invalidate(); + } + } + + void OnFinish() override { + with_lock(Worker_->GetScopedAlloc()) { + Worker_->OnFinish(); + } + } + + private: + TWorkerHolder<IPushStreamWorker> Worker_; + }; + + class TStatelessConsumer : public IConsumer<NPureCalcProto::TStringMessage*> { + const TString ExpectedData_; + const ui64 ExpectedRows_; + ui64 RowId_ = 0; + + public: + TStatelessConsumer(const TString& expectedData, ui64 expectedRows) + : ExpectedData_(expectedData) + , ExpectedRows_(expectedRows) + {} + + void OnObject(NPureCalcProto::TStringMessage* message) override { + UNIT_ASSERT_VALUES_EQUAL_C(ExpectedData_, message->GetX(), RowId_); + RowId_++; + } + + void OnFinish() override { + UNIT_ASSERT_VALUES_EQUAL(ExpectedRows_, RowId_); + } + }; +} + +template <> +struct TInputSpecTraits<TStatelessInputSpec> { + static constexpr bool IsPartial = false; + static constexpr bool SupportPushStreamMode = true; + + using TConsumerType = THolder<IConsumer<const NYql::NUdf::TUnboxedValue&>>; + + static TConsumerType MakeConsumer(const TStatelessInputSpec&, TWorkerHolder<IPushStreamWorker> worker) { + return MakeHolder<TStatelessInputConsumer>(std::move(worker)); + } +}; + +Y_UNIT_TEST_SUITE(TestMixedAllocators) { + Y_UNIT_TEST(TestPushStream) { + const auto targetString = "large string >= 14 bytes"; + const auto factory = MakeProgramFactory(); + const auto sql = TStringBuilder() << "SELECT InputValue AS X FROM Input WHERE InputValue = \"" << targetString << "\";"; + + const auto program = factory->MakePushStreamProgram( + TStatelessInputSpec(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + sql + ); + + const ui64 numberRows = 5; + const auto inputConsumer = program->Apply(MakeHolder<TStatelessConsumer>(targetString, numberRows)); + NKikimr::NMiniKQL::TScopedAlloc alloc(__LOCATION__, NKikimr::TAlignedPagePoolCounters(), true, false); + + const auto pushString = [&](TString inputValue) { + NYql::NUdf::TUnboxedValue stringValue; + with_lock(alloc) { + stringValue = NKikimr::NMiniKQL::MakeString(inputValue); + alloc.Ref().LockObject(stringValue); + } + + inputConsumer->OnObject(stringValue); + + with_lock(alloc) { + alloc.Ref().UnlockObject(stringValue); + stringValue.Clear(); + } + }; + + for (ui64 i = 0; i < numberRows; ++i) { + pushString(targetString); + pushString("another large string >= 14 bytes"); + } + inputConsumer->OnFinish(); + } +} diff --git a/yql/essentials/public/purecalc/ut/test_pg.cpp b/yql/essentials/public/purecalc/ut/test_pg.cpp new file mode 100644 index 00000000000..3d26cfbd1be --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_pg.cpp @@ -0,0 +1,71 @@ +#include <yql/essentials/public/purecalc/purecalc.h> + +#include "fake_spec.h" + +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +Y_UNIT_TEST_SUITE(TestPg) { + using namespace NYql::NPureCalc; + + Y_UNIT_TEST(TestPgCompile) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + SELECT * FROM "Input"; + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(1,true), FakeOS(true), sql, ETranslationMode::PG); + }()); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullStreamProgram(FakeIS(1,true), FakeOS(true), sql, ETranslationMode::PG); + }(), TCompileError, "PullList mode"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePushStreamProgram(FakeIS(1, true), FakeOS(true), sql, ETranslationMode::PG); + }(), TCompileError, "PullList mode"); + } + + Y_UNIT_TEST(TestSqlWrongTableName) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + SELECT * FROM WrongTable; + )"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullListProgram(FakeIS(1, true), FakeOS(true), sql, ETranslationMode::PG); + }(), TCompileError, "Failed to optimize"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullStreamProgram(FakeIS(1, true), FakeOS(true), sql, ETranslationMode::PG); + }(), TCompileError, "PullList mode"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePushStreamProgram(FakeIS(1, true), FakeOS(true), sql, ETranslationMode::PG); + }(), TCompileError, "PullList mode"); + } + + Y_UNIT_TEST(TestInvalidSql) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + Just some invalid SQL; + )"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullListProgram(FakeIS(1, true), FakeOS(true), sql, ETranslationMode::PG); + }(), TCompileError, "failed to parse PG"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullStreamProgram(FakeIS(1, true), FakeOS(true), sql, ETranslationMode::PG); + }(), TCompileError, "PullList mode"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePushStreamProgram(FakeIS(1, true), FakeOS(true), sql, ETranslationMode::PG); + }(), TCompileError, "PullList mode"); + } +} diff --git a/yql/essentials/public/purecalc/ut/test_pool.cpp b/yql/essentials/public/purecalc/ut/test_pool.cpp new file mode 100644 index 00000000000..b3de36cbf5f --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_pool.cpp @@ -0,0 +1,184 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <yql/essentials/public/purecalc/common/interface.h> +#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h> +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> +#include <library/cpp/protobuf/util/pb_io.h> + +#include <util/string/cast.h> + +using namespace NYql::NPureCalc; + +namespace { + class TStringMessageStreamImpl: public IStream<NPureCalcProto::TStringMessage*> { + private: + ui32 I_ = 0; + NPureCalcProto::TStringMessage Message_{}; + + public: + NPureCalcProto::TStringMessage* Fetch() override { + if (I_ >= 3) { + return nullptr; + } else { + Message_.SetX(ToString(I_)); + ++I_; + return &Message_; + } + } + }; + + class TStringMessageConsumerImpl: public IConsumer<NPureCalcProto::TStringMessage*> { + private: + TVector<TString>* Buf_; + + public: + TStringMessageConsumerImpl(TVector<TString>* buf) + : Buf_(buf) + { + } + + public: + void OnObject(NPureCalcProto::TStringMessage* t) override { + Buf_->push_back(t->GetX()); + } + + void OnFinish() override { + } + }; + +} + +Y_UNIT_TEST_SUITE(TestWorkerPool) { + static TString sql = "SELECT 'abc'u || X AS X FROM Input"; + + static TVector<TString> expected{"abc0", "abc1", "abc2"}; + + void TestPullStreamImpl(bool useWorkerPool) { + auto factory = MakeProgramFactory(TProgramFactoryOptions().SetUseWorkerPool(useWorkerPool)); + + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + sql, + ETranslationMode::SQL + ); + + auto check = [](IStream<NPureCalcProto::TStringMessage*>* output) { + TVector<TString> actual; + while (auto *x = output->Fetch()) { + actual.push_back(x->GetX()); + } + + UNIT_ASSERT_VALUES_EQUAL(expected, actual); + }; + + // Sequential use + for (size_t i = 0; i < 2; ++i) { + auto output = program->Apply(MakeHolder<TStringMessageStreamImpl>()); + check(output.Get()); + } + // Parallel use + { + auto output1 = program->Apply(MakeHolder<TStringMessageStreamImpl>()); + auto output2 = program->Apply(MakeHolder<TStringMessageStreamImpl>()); + check(output1.Get()); + check(output2.Get()); + } + } + + Y_UNIT_TEST(TestPullStreamUseWorkerPool) { + TestPullStreamImpl(true); + } + + Y_UNIT_TEST(TestPullStreamNoWorkerPool) { + TestPullStreamImpl(false); + } + + void TestPullListImpl(bool useWorkerPool) { + auto factory = MakeProgramFactory(TProgramFactoryOptions().SetUseWorkerPool(useWorkerPool)); + + auto program = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + sql, + ETranslationMode::SQL + ); + + auto check = [](IStream<NPureCalcProto::TStringMessage*>* output) { + TVector<TString> actual; + while (auto *x = output->Fetch()) { + actual.push_back(x->GetX()); + } + + UNIT_ASSERT_VALUES_EQUAL(expected, actual); + }; + + // Sequential use + for (size_t i = 0; i < 2; ++i) { + auto output = program->Apply(MakeHolder<TStringMessageStreamImpl>()); + check(output.Get()); + } + // Parallel use + { + auto output1 = program->Apply(MakeHolder<TStringMessageStreamImpl>()); + auto output2 = program->Apply(MakeHolder<TStringMessageStreamImpl>()); + check(output1.Get()); + check(output2.Get()); + } + } + + Y_UNIT_TEST(TestPullListUseWorkerPool) { + TestPullListImpl(true); + } + + Y_UNIT_TEST(TestPullListNoWorkerPool) { + TestPullListImpl(false); + } + + void TestPushStreamImpl(bool useWorkerPool) { + auto factory = MakeProgramFactory(TProgramFactoryOptions().SetUseWorkerPool(useWorkerPool)); + + auto program = factory->MakePushStreamProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + sql, + ETranslationMode::SQL + ); + + auto check = [](IConsumer<NPureCalcProto::TStringMessage*>* input, const TVector<TString>& result) { + NPureCalcProto::TStringMessage message; + for (auto s: {"0", "1", "2"}) { + message.SetX(s); + input->OnObject(&message); + } + input->OnFinish(); + + UNIT_ASSERT_VALUES_EQUAL(expected, result); + }; + + // Sequential use + for (size_t i = 0; i < 2; ++i) { + TVector<TString> actual; + auto input = program->Apply(MakeHolder<TStringMessageConsumerImpl>(&actual)); + check(input.Get(), actual); + } + + // Parallel use + { + TVector<TString> actual1; + auto input1 = program->Apply(MakeHolder<TStringMessageConsumerImpl>(&actual1)); + TVector<TString> actual2; + auto input2 = program->Apply(MakeHolder<TStringMessageConsumerImpl>(&actual2)); + check(input1.Get(), actual1); + check(input2.Get(), actual2); + } + } + + Y_UNIT_TEST(TestPushStreamUseWorkerPool) { + TestPushStreamImpl(true); + } + + Y_UNIT_TEST(TestPushStreamNoWorkerPool) { + TestPushStreamImpl(false); + } +} diff --git a/yql/essentials/public/purecalc/ut/test_schema.cpp b/yql/essentials/public/purecalc/ut/test_schema.cpp new file mode 100644 index 00000000000..9763e52b005 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_schema.cpp @@ -0,0 +1 @@ +#include <library/cpp/testing/unittest/registar.h> diff --git a/yql/essentials/public/purecalc/ut/test_sexpr.cpp b/yql/essentials/public/purecalc/ut/test_sexpr.cpp new file mode 100644 index 00000000000..9c50dd1f291 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_sexpr.cpp @@ -0,0 +1,55 @@ +#include <yql/essentials/public/purecalc/purecalc.h> + +#include "fake_spec.h" + +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +Y_UNIT_TEST_SUITE(TestSExpr) { + Y_UNIT_TEST(TestSExprCompile) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + auto expr = TString(R"( + ( + (return (Self '0)) + ) + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullStreamProgram(FakeIS(), FakeOS(), expr, ETranslationMode::SExpr); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), expr, ETranslationMode::SExpr); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePushStreamProgram(FakeIS(), FakeOS(), expr, ETranslationMode::SExpr); + }()); + } + + Y_UNIT_TEST(TestInvalidSExpr) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + Some totally invalid SExpr + )"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SExpr); + }(), TCompileError, "failed to parse s-expression"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SExpr); + }(), TCompileError, "failed to parse s-expression"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePushStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SExpr); + }(), TCompileError, "failed to parse s-expression"); + } +} diff --git a/yql/essentials/public/purecalc/ut/test_sql.cpp b/yql/essentials/public/purecalc/ut/test_sql.cpp new file mode 100644 index 00000000000..64ec760f9ec --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_sql.cpp @@ -0,0 +1,205 @@ +#include <yql/essentials/public/purecalc/purecalc.h> + +#include "fake_spec.h" + +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +Y_UNIT_TEST_SUITE(TestSql) { + using namespace NYql::NPureCalc; + + Y_UNIT_TEST(TestSqlCompile) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + SELECT * FROM Input; + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + + auto program = factory->MakePushStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + auto expectedIssues = TString(R"(<main>: Warning: Type annotation, code: 1030 + generated.sql:2:13: Warning: At function: PersistableRepr + generated.sql:2:13: Warning: Persistable required. Atom, key, world, datasink, datasource, callable, resource, stream and lambda are not persistable, code: 1104 +)"); + + UNIT_ASSERT_VALUES_EQUAL(expectedIssues, program->GetIssues().ToString()); + } + + Y_UNIT_TEST(TestSqlCompileSingleUnnamedInput) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + SELECT * FROM TABLES() + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePushStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + } + + Y_UNIT_TEST(TestSqlCompileNamedMultiinputs) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + SELECT * FROM Input0 + UNION ALL + SELECT * FROM Input1 + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(2), FakeOS(), sql, ETranslationMode::SQL); + }()); + } + + Y_UNIT_TEST(TestSqlCompileUnnamedMultiinputs) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + $t0, $t1, $t2 = PROCESS TABLES(); + SELECT * FROM $t0 + UNION ALL + SELECT * FROM $t1 + UNION ALL + SELECT * FROM $t2 + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(3), FakeOS(), sql, ETranslationMode::SQL); + }()); + } + + Y_UNIT_TEST(TestSqlCompileWithWarning) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + $x = 1; + $y = 2; + SELECT $x as Name FROM Input; + )"); + + auto expectedIssues = TString(R"(generated.sql:3:13: Warning: Symbol $y is not used, code: 4527 +<main>: Warning: Type annotation, code: 1030 + generated.sql:4:13: Warning: At function: PersistableRepr + generated.sql:4:13: Warning: Persistable required. Atom, key, world, datasink, datasource, callable, resource, stream and lambda are not persistable, code: 1104 +)"); + + auto program = factory->MakePushStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + UNIT_ASSERT_VALUES_EQUAL(expectedIssues, program->GetIssues().ToString()); + } + + Y_UNIT_TEST(TestSqlWrongTableName) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + SELECT * FROM WrongTable; + )"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }(), TCompileError, "Failed to optimize"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }(), TCompileError, "Failed to optimize"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePushStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }(), TCompileError, "Failed to optimize"); + } + + Y_UNIT_TEST(TestAllocateLargeStringOnEvaluate) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + $data = Length(EvaluateExpr("long string" || " very loooong string")); + SELECT $data as Name FROM Input; + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePushStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + } + + Y_UNIT_TEST(TestInvalidSql) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + Just some invalid SQL; + )"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }(), TCompileError, "failed to parse SQL"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }(), TCompileError, "failed to parse SQL"); + + UNIT_ASSERT_EXCEPTION_CONTAINS([&](){ + factory->MakePushStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }(), TCompileError, "failed to parse SQL"); + } + + Y_UNIT_TEST(TestUseProcess) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + $processor = ($row) -> ($row); + + PROCESS Input using $processor(TableRow()); + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePushStreamProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + } + + Y_UNIT_TEST(TestUseCodegen) { + auto factory = MakeProgramFactory(); + + auto sql = TString(R"( + $processor = ($row) -> { + $lambda = EvaluateCode(LambdaCode(($row) -> ($row))); + return $lambda($row); + }; + + PROCESS Input using $processor(TableRow()); + )"); + + UNIT_ASSERT_NO_EXCEPTION([&](){ + factory->MakePullListProgram(FakeIS(), FakeOS(), sql, ETranslationMode::SQL); + }()); + } +} diff --git a/yql/essentials/public/purecalc/ut/test_udf.cpp b/yql/essentials/public/purecalc/ut/test_udf.cpp new file mode 100644 index 00000000000..732917739e7 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_udf.cpp @@ -0,0 +1,195 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <yql/essentials/public/purecalc/purecalc.h> +#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h> +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> +#include <yql/essentials/public/udf/udf_counter.h> +#include <yql/essentials/public/udf/udf_type_builder.h> +#include <library/cpp/testing/unittest/registar.h> + +class TMyModule : public NKikimr::NUdf::IUdfModule { +public: + class TFunc : public NKikimr::NUdf::TBoxedValue { + public: + TFunc(NKikimr::NUdf::TCounter counter, NKikimr::NUdf::TScopedProbe scopedProbe) + : Counter_(counter) + , ScopedProbe_(scopedProbe) + {} + + NKikimr::NUdf::TUnboxedValue Run(const NKikimr::NUdf::IValueBuilder* valueBuilder, const NKikimr::NUdf::TUnboxedValuePod* args) const override { + Y_UNUSED(valueBuilder); + with_lock(ScopedProbe_) { + Counter_.Inc(); + return NKikimr::NUdf::TUnboxedValuePod(args[0].Get<i32>()); + } + } + + private: + mutable NKikimr::NUdf::TCounter Counter_; + mutable NKikimr::NUdf::TScopedProbe ScopedProbe_; + }; + + void GetAllFunctions(NKikimr::NUdf::IFunctionsSink& sink) const override { + Y_UNUSED(sink); + } + + void BuildFunctionTypeInfo( + const NKikimr::NUdf::TStringRef& name, + NKikimr::NUdf::TType* userType, + const NKikimr::NUdf::TStringRef& typeConfig, + ui32 flags, + NKikimr::NUdf::IFunctionTypeInfoBuilder& builder) const override { + Y_UNUSED(userType); + Y_UNUSED(typeConfig); + Y_UNUSED(flags); + if (name == NKikimr::NUdf::TStringRef::Of("Func")) { + builder.SimpleSignature<i32(i32)>(); + builder.Implementation(new TFunc( + builder.GetCounter("FuncCalls",true), + builder.GetScopedProbe("FuncTime") + )); + } + } + + void CleanupOnTerminate() const override { + } +}; + +class TMyCountersProvider : public NKikimr::NUdf::ICountersProvider, public NKikimr::NUdf::IScopedProbeHost { +public: + TMyCountersProvider(i64* calls, TString* log) + : Calls_(calls) + , Log_(log) + {} + + NKikimr::NUdf::TCounter GetCounter(const NKikimr::NUdf::TStringRef& module, const NKikimr::NUdf::TStringRef& name, bool deriv) override { + UNIT_ASSERT_VALUES_EQUAL(module, "MyModule"); + UNIT_ASSERT_VALUES_EQUAL(name, "FuncCalls"); + UNIT_ASSERT_VALUES_EQUAL(deriv, true); + return NKikimr::NUdf::TCounter(Calls_); + } + + NKikimr::NUdf::TScopedProbe GetScopedProbe(const NKikimr::NUdf::TStringRef& module, const NKikimr::NUdf::TStringRef& name) override { + UNIT_ASSERT_VALUES_EQUAL(module, "MyModule"); + UNIT_ASSERT_VALUES_EQUAL(name, "FuncTime"); + return NKikimr::NUdf::TScopedProbe(Log_ ? this : nullptr, Log_); + } + + void Acquire(void* cookie) override { + UNIT_ASSERT(cookie == Log_); + *Log_ += "Enter\n"; + } + + void Release(void* cookie) override { + UNIT_ASSERT(cookie == Log_); + *Log_ += "Exit\n"; + } + +private: + i64* Calls_; + TString* Log_; +}; + +namespace NPureCalcProto { + class TUnparsed; + class TParsed; +} + +class TDocInput : public NYql::NPureCalc::IStream<NPureCalcProto::TUnparsed*> { +public: + NPureCalcProto::TUnparsed* Fetch() override { + if (Extracted) { + return nullptr; + } + + Extracted = true; + Msg.SetS("foo"); + return &Msg; + } + +public: + NPureCalcProto::TUnparsed Msg; + bool Extracted = false; +}; + +Y_UNIT_TEST_SUITE(TestUdf) { + Y_UNIT_TEST(TestCounters) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + i64 callCounter = 0; + TMyCountersProvider myCountersProvider(&callCounter, nullptr); + factory->AddUdfModule("MyModule", new TMyModule); + factory->SetCountersProvider(&myCountersProvider); + + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TUnparsed>(), + TProtobufOutputSpec<NPureCalcProto::TParsed>(), + "select MyModule::Func(1) as A, 2 as B, 3 as C from Input", + ETranslationMode::SQL); + + auto out = program->Apply(MakeHolder<TDocInput>()); + auto* message = out->Fetch(); + UNIT_ASSERT(message); + UNIT_ASSERT_VALUES_EQUAL(message->GetA(), 1); + UNIT_ASSERT_VALUES_EQUAL(message->GetB(), 2); + UNIT_ASSERT_VALUES_EQUAL(message->GetC(), 3); + UNIT_ASSERT_VALUES_EQUAL(callCounter, 1); + UNIT_ASSERT(!out->Fetch()); + } + + Y_UNIT_TEST(TestCountersFilteredColumns) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + i64 callCounter = 0; + TMyCountersProvider myCountersProvider(&callCounter, nullptr); + factory->AddUdfModule("MyModule", new TMyModule); + factory->SetCountersProvider(&myCountersProvider); + + auto ospec = TProtobufOutputSpec<NPureCalcProto::TParsed>(); + ospec.SetOutputColumnsFilter(THashSet<TString>({"B", "C"})); + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TUnparsed>(), + ospec, + "select MyModule::Func(1) as A, 2 as B, 3 as C from Input", + ETranslationMode::SQL); + + auto out = program->Apply(MakeHolder<TDocInput>()); + auto* message = out->Fetch(); + UNIT_ASSERT(message); + UNIT_ASSERT_VALUES_EQUAL(message->GetA(), 0); + UNIT_ASSERT_VALUES_EQUAL(message->GetB(), 2); + UNIT_ASSERT_VALUES_EQUAL(message->GetC(), 3); + UNIT_ASSERT_VALUES_EQUAL(callCounter, 0); + UNIT_ASSERT(!out->Fetch()); + } + + Y_UNIT_TEST(TestScopedProbes) { + using namespace NYql::NPureCalc; + + auto factory = MakeProgramFactory(); + + TString log; + TMyCountersProvider myCountersProvider(nullptr, &log); + factory->AddUdfModule("MyModule", new TMyModule); + factory->SetCountersProvider(&myCountersProvider); + + auto program = factory->MakePullStreamProgram( + TProtobufInputSpec<NPureCalcProto::TUnparsed>(), + TProtobufOutputSpec<NPureCalcProto::TParsed>(), + "select MyModule::Func(1) as A, 2 as B, 3 as C from Input", + ETranslationMode::SQL); + + auto out = program->Apply(MakeHolder<TDocInput>()); + auto* message = out->Fetch(); + UNIT_ASSERT(message); + UNIT_ASSERT_VALUES_EQUAL(message->GetA(), 1); + UNIT_ASSERT_VALUES_EQUAL(message->GetB(), 2); + UNIT_ASSERT_VALUES_EQUAL(message->GetC(), 3); + UNIT_ASSERT_VALUES_EQUAL(log, "Enter\nExit\n"); + UNIT_ASSERT(!out->Fetch()); + } +} diff --git a/yql/essentials/public/purecalc/ut/test_user_data.cpp b/yql/essentials/public/purecalc/ut/test_user_data.cpp new file mode 100644 index 00000000000..b87940ab6b2 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/test_user_data.cpp @@ -0,0 +1,62 @@ +#include <yql/essentials/public/purecalc/purecalc.h> +#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h> +#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h> +#include <yql/essentials/public/purecalc/ut/empty_stream.h> + +#include <library/cpp/testing/unittest/registar.h> + +Y_UNIT_TEST_SUITE(TestUserData) { + Y_UNIT_TEST(TestUserData) { + using namespace NYql::NPureCalc; + + auto options = TProgramFactoryOptions() + .AddFile(NYql::NUserData::EDisposition::INLINE, "my_file.txt", "my content!"); + + auto factory = MakeProgramFactory(options); + + auto program = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + "SELECT UNWRAP(CAST(FileContent(\"my_file.txt\") AS Utf8)) AS X", + ETranslationMode::SQL + ); + + auto stream = program->Apply(EmptyStream<NPureCalcProto::TStringMessage*>()); + + NPureCalcProto::TStringMessage* message; + + UNIT_ASSERT(message = stream->Fetch()); + UNIT_ASSERT_EQUAL(message->GetX(), "my content!"); + UNIT_ASSERT(!stream->Fetch()); + } + + Y_UNIT_TEST(TestUserDataLibrary) { + using namespace NYql::NPureCalc; + + try { + auto options = TProgramFactoryOptions() + .AddLibrary(NYql::NUserData::EDisposition::INLINE, "a.sql", "$x = 1; EXPORT $x;") + .AddLibrary(NYql::NUserData::EDisposition::INLINE, "b.sql", "IMPORT a SYMBOLS $x; $y = CAST($x + 1 AS String); EXPORT $y;"); + + auto factory = MakeProgramFactory(options); + + auto program = factory->MakePullListProgram( + TProtobufInputSpec<NPureCalcProto::TStringMessage>(), + TProtobufOutputSpec<NPureCalcProto::TStringMessage>(), + "IMPORT b SYMBOLS $y; SELECT CAST($y AS Utf8) ?? '' AS X;", + ETranslationMode::SQL + ); + + auto stream = program->Apply(EmptyStream<NPureCalcProto::TStringMessage*>()); + + NPureCalcProto::TStringMessage* message; + + UNIT_ASSERT(message = stream->Fetch()); + UNIT_ASSERT_EQUAL(message->GetX(), "2"); + UNIT_ASSERT(!stream->Fetch()); + } catch (const TCompileError& e) { + Cerr << e; + throw e; + } + } +} diff --git a/yql/essentials/public/purecalc/ut/ya.make b/yql/essentials/public/purecalc/ut/ya.make new file mode 100644 index 00000000000..474280a8e27 --- /dev/null +++ b/yql/essentials/public/purecalc/ut/ya.make @@ -0,0 +1,28 @@ +UNITTEST() + +SRCS( + empty_stream.h + fake_spec.cpp + fake_spec.h + test_schema.cpp + test_sexpr.cpp + test_sql.cpp + test_pg.cpp + test_udf.cpp + test_user_data.cpp + test_eval.cpp + test_pool.cpp + test_mixed_allocators.cpp +) + +PEERDIR( + yql/essentials/public/purecalc + yql/essentials/public/purecalc/io_specs/protobuf + yql/essentials/public/purecalc/ut/protos +) + +SIZE(MEDIUM) + +YQL_LAST_ABI_VERSION() + +END() diff --git a/yql/essentials/public/purecalc/ya.make b/yql/essentials/public/purecalc/ya.make new file mode 100644 index 00000000000..e7f3ff8818f --- /dev/null +++ b/yql/essentials/public/purecalc/ya.make @@ -0,0 +1,28 @@ +LIBRARY() + +SRCS( + purecalc.cpp +) + +PEERDIR( + yql/essentials/public/udf/service/exception_policy + yql/essentials/public/purecalc/common +) + +YQL_LAST_ABI_VERSION() + +PROVIDES(YQL_PURECALC) + +END() + +RECURSE( + common + examples + helpers + io_specs + no_llvm +) + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/public/ya.make b/yql/essentials/public/ya.make index 8cb77bcdb11..25709b4555b 100644 --- a/yql/essentials/public/ya.make +++ b/yql/essentials/public/ya.make @@ -2,6 +2,7 @@ RECURSE( decimal fastcheck issue + purecalc result_format types udf diff --git a/yql/essentials/tools/purebench/purebench.cpp b/yql/essentials/tools/purebench/purebench.cpp index d175a271498..1ec81317b80 100644 --- a/yql/essentials/tools/purebench/purebench.cpp +++ b/yql/essentials/tools/purebench/purebench.cpp @@ -1,10 +1,10 @@ #include <library/cpp/svnversion/svnversion.h> #include <library/cpp/getopt/last_getopt.h> -#include <contrib/ydb/library/yql/public/purecalc/purecalc.h> -#include <contrib/ydb/library/yql/public/purecalc/io_specs/mkql/spec.h> -#include <contrib/ydb/library/yql/public/purecalc/io_specs/arrow/spec.h> -#include <contrib/ydb/library/yql/public/purecalc/helpers/stream/stream_from_vector.h> +#include <yql/essentials/public/purecalc/purecalc.h> +#include <yt/yql/purecalc/io_specs/mkql/spec.h> +#include <yql/essentials/public/purecalc/io_specs/arrow/spec.h> +#include <yql/essentials/public/purecalc/helpers/stream/stream_from_vector.h> #include <yql/essentials/utils/log/log.h> #include <yql/essentials/utils/backtrace/backtrace.h> diff --git a/yql/essentials/tools/purebench/ya.make b/yql/essentials/tools/purebench/ya.make index 055d01743ff..7068203bae1 100644 --- a/yql/essentials/tools/purebench/ya.make +++ b/yql/essentials/tools/purebench/ya.make @@ -23,9 +23,9 @@ PEERDIR( yql/essentials/public/udf/service/exception_policy library/cpp/skiff library/cpp/yson - contrib/ydb/library/yql/public/purecalc/io_specs/mkql - contrib/ydb/library/yql/public/purecalc/io_specs/arrow - contrib/ydb/library/yql/public/purecalc + yt/yql/purecalc/io_specs/mkql + yql/essentials/public/purecalc/io_specs/arrow + yql/essentials/public/purecalc ) YQL_LAST_ABI_VERSION() |