diff options
| author | vvvv <[email protected]> | 2024-11-07 04:19:26 +0300 |
|---|---|---|
| committer | vvvv <[email protected]> | 2024-11-07 04:29:50 +0300 |
| commit | 2661be00f3bc47590fda9218bf0386d6355c8c88 (patch) | |
| tree | 3d316c07519191283d31c5f537efc6aabb42a2f0 /yql/essentials/minikql/computation/mkql_computation_node.h | |
| parent | cf2a23963ac10add28c50cc114fbf48953eca5aa (diff) | |
Moved yql/minikql YQL-19206
init
[nodiff:caesar]
commit_hash:d1182ef7d430ccf7e4d37ed933c7126d7bd5d6e4
Diffstat (limited to 'yql/essentials/minikql/computation/mkql_computation_node.h')
| -rw-r--r-- | yql/essentials/minikql/computation/mkql_computation_node.h | 432 |
1 files changed, 432 insertions, 0 deletions
diff --git a/yql/essentials/minikql/computation/mkql_computation_node.h b/yql/essentials/minikql/computation/mkql_computation_node.h new file mode 100644 index 00000000000..da109d4234b --- /dev/null +++ b/yql/essentials/minikql/computation/mkql_computation_node.h @@ -0,0 +1,432 @@ +#pragma once + +#include "mkql_computation_node_list.h" +#include "mkql_spiller_factory.h" + +#include <yql/essentials/minikql/defs.h> +#include <yql/essentials/minikql/mkql_node.h> +#include <yql/essentials/minikql/mkql_node_visitor.h> +#include <yql/essentials/minikql/mkql_function_registry.h> +#include <yql/essentials/minikql/mkql_alloc.h> +#include <yql/essentials/minikql/mkql_stats_registry.h> +#include <yql/essentials/minikql/mkql_terminator.h> + +#include <yql/essentials/public/udf/udf_value.h> +#include <yql/essentials/public/udf/udf_validate.h> +#include <yql/essentials/public/udf/udf_value_builder.h> + +#include <library/cpp/cache/cache.h> +#include <library/cpp/random_provider/random_provider.h> +#include <library/cpp/time_provider/time_provider.h> + +#include <map> +#include <unordered_set> +#include <unordered_map> +#include <vector> + +namespace NKikimr { +namespace NMiniKQL { + +inline const TDefaultListRepresentation* GetDefaultListRepresentation(const NUdf::TUnboxedValuePod& value) { + return reinterpret_cast<const TDefaultListRepresentation*>(NUdf::TBoxedValueAccessor::GetListRepresentation(*value.AsBoxed())); +} + +enum class EGraphPerProcess { + Multi, + Single +}; + +struct TComputationOpts { + TComputationOpts(IStatsRegistry* stats) + : Stats(stats) + {} + + IStatsRegistry *const Stats; +}; + +struct TComputationOptsFull: public TComputationOpts { + TComputationOptsFull(IStatsRegistry* stats, TAllocState& allocState, const TTypeEnvironment& typeEnv, IRandomProvider& randomProvider, + ITimeProvider& timeProvider, NUdf::EValidatePolicy validatePolicy, const NUdf::ISecureParamsProvider* secureParamsProvider, NUdf::ICountersProvider* countersProvider) + : TComputationOpts(stats) + , AllocState(allocState) + , TypeEnv(typeEnv) + , RandomProvider(randomProvider) + , TimeProvider(timeProvider) + , ValidatePolicy(validatePolicy) + , SecureParamsProvider(secureParamsProvider) + , CountersProvider(countersProvider) + {} + + TAllocState& AllocState; + const TTypeEnvironment& TypeEnv; + IRandomProvider& RandomProvider; + ITimeProvider& TimeProvider; + NUdf::EValidatePolicy ValidatePolicy; + const NUdf::ISecureParamsProvider *const SecureParamsProvider; + NUdf::ICountersProvider *const CountersProvider; +}; + +struct TWideFieldsInitInfo { + ui32 MutablesIndex = 0; + ui32 WideFieldsIndex = 0; + std::set<ui32> Used; +}; + +struct TComputationMutables { + ui32 CurValueIndex = 0U; + std::vector<ui32> SerializableValues; // Indices of values that need to be saved in IComputationGraph::SaveGraphState() and restored in IComputationGraph::LoadGraphState(). + ui32 CurWideFieldsIndex = 0U; + std::vector<TWideFieldsInitInfo> WideFieldInitialize; + + void DeferWideFieldsInit(ui32 count, std::set<ui32> used) { + Y_DEBUG_ABORT_UNLESS(AllOf(used, [count](ui32 i) { return i < count; })); + WideFieldInitialize.push_back({CurValueIndex, CurWideFieldsIndex, std::move(used)}); + CurValueIndex += count; + CurWideFieldsIndex += count; + } + + ui32 IncrementWideFieldsIndex(ui32 addend) { + auto cur = CurWideFieldsIndex; + CurWideFieldsIndex += addend; + return cur; + } +}; + +class THolderFactory; + +// Do not reorder: used in LLVM! +struct TComputationContextLLVM { + const THolderFactory& HolderFactory; + IStatsRegistry *const Stats; + const std::unique_ptr<NUdf::TUnboxedValue[]> MutableValues; + const NUdf::IValueBuilder *const Builder; + float UsageAdjustor = 1.f; + ui32 RssCounter = 0U; + const NUdf::TSourcePosition* CalleePosition = nullptr; +}; + +struct TComputationContext : public TComputationContextLLVM { + IRandomProvider& RandomProvider; + ITimeProvider& TimeProvider; + bool ExecuteLLVM = false; + arrow::MemoryPool& ArrowMemoryPool; + std::vector<NUdf::TUnboxedValue*> WideFields; + const TTypeEnvironment& TypeEnv; + const TComputationMutables Mutables; + std::shared_ptr<ISpillerFactory> SpillerFactory; + const NUdf::ITypeInfoHelper::TPtr TypeInfoHelper; + NUdf::ICountersProvider *const CountersProvider; + const NUdf::ISecureParamsProvider *const SecureParamsProvider; + + TComputationContext(const THolderFactory& holderFactory, + const NUdf::IValueBuilder* builder, + const TComputationOptsFull& opts, + const TComputationMutables& mutables, + arrow::MemoryPool& arrowMemoryPool); + + ~TComputationContext(); + + // Returns true if current usage delta exceeds the memory limit + // The function automatically adjusts memory limit taking into account RSS delta between calls + template<bool TrackRss> + inline bool CheckAdjustedMemLimit(ui64 memLimit, ui64 initMemUsage); + + void UpdateUsageAdjustor(ui64 memLimit); +private: + ui64 InitRss = 0ULL; + ui64 LastRss = 0ULL; +#ifndef NDEBUG + TInstant LastPrintUsage; +#endif +}; + +class IArrowKernelComputationNode; + +class IComputationNode { +public: + typedef TIntrusivePtr<IComputationNode> TPtr; + typedef std::map<ui32, EValueRepresentation> TIndexesMap; + + virtual ~IComputationNode() {} + + virtual void InitNode(TComputationContext&) const = 0; + + virtual NUdf::TUnboxedValue GetValue(TComputationContext& compCtx) const = 0; + + virtual IComputationNode* AddDependence(const IComputationNode* node) = 0; + + virtual const IComputationNode* GetSource() const = 0; + + virtual void RegisterDependencies() const = 0; + + virtual ui32 GetIndex() const = 0; + virtual void CollectDependentIndexes(const IComputationNode* owner, TIndexesMap& dependencies) const = 0; + virtual ui32 GetDependencyWeight() const = 0; + virtual ui32 GetDependencesCount() const = 0; + + virtual bool IsTemporaryValue() const = 0; + + virtual EValueRepresentation GetRepresentation() const = 0; + + virtual void PrepareStageOne() = 0; + virtual void PrepareStageTwo() = 0; + + virtual TString DebugString() const = 0; + + virtual void Ref() = 0; + virtual void UnRef() = 0; + virtual ui32 RefCount() const = 0; + + virtual std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const; +}; + +class IComputationExternalNode : public IComputationNode { +public: + virtual NUdf::TUnboxedValue& RefValue(TComputationContext& compCtx) const = 0; + virtual void SetValue(TComputationContext& compCtx, NUdf::TUnboxedValue&& newValue) const = 0; + virtual void SetOwner(const IComputationNode* node) = 0; + + using TGetter = std::function<NUdf::TUnboxedValue(TComputationContext&)>; + virtual void SetGetter(TGetter&& getter) = 0; + virtual void InvalidateValue(TComputationContext& compCtx) const = 0; +}; + +enum class EFetchResult : i32 { + Finish = -1, + Yield = 0, + One = 1 +}; + +class IComputationWideFlowNode : public IComputationNode { +public: + virtual EFetchResult FetchValues(TComputationContext& compCtx, NUdf::TUnboxedValue*const* values) const = 0; +}; + +class IComputationWideFlowProxyNode : public IComputationWideFlowNode { +public: + using TFetcher = std::function<EFetchResult(TComputationContext&, NUdf::TUnboxedValue*const*)>; + virtual void SetFetcher(TFetcher&& fetcher) = 0; + virtual void SetOwner(const IComputationNode* node) = 0; + virtual void InvalidateValue(TComputationContext& compCtx) const = 0; +}; + +using TDatumProvider = std::function<arrow::Datum()>; + +TDatumProvider MakeDatumProvider(const arrow::Datum& datum); +TDatumProvider MakeDatumProvider(const IComputationNode* node, TComputationContext& ctx); + +class IArrowKernelComputationNode { +public: + virtual ~IArrowKernelComputationNode() = default; + + virtual TStringBuf GetKernelName() const = 0; + virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0; + virtual const std::vector<arrow::ValueDescr>& GetArgsDesc() const = 0; + virtual const IComputationNode* GetArgument(ui32 index) const = 0; +}; + +struct TArrowKernelsTopologyItem { + std::vector<ui32> Inputs; + std::unique_ptr<IArrowKernelComputationNode> Node; +}; + +struct TArrowKernelsTopology { + ui32 InputArgsCount = 0; + std::vector<TArrowKernelsTopologyItem> Items; +}; + +using TComputationNodePtrVector = std::vector<IComputationNode*, TMKQLAllocator<IComputationNode*>>; +using TComputationWideFlowNodePtrVector = std::vector<IComputationWideFlowNode*, TMKQLAllocator<IComputationWideFlowNode*>>; +using TComputationExternalNodePtrVector = std::vector<IComputationExternalNode*, TMKQLAllocator<IComputationExternalNode*>>; +using TConstComputationNodePtrVector = std::vector<const IComputationNode*, TMKQLAllocator<const IComputationNode*>>; +using TComputationNodePtrDeque = std::deque<IComputationNode::TPtr, TMKQLAllocator<IComputationNode::TPtr>>; +using TComputationNodeOnNodeMap = std::unordered_map<const IComputationNode*, IComputationNode*, std::hash<const IComputationNode*>, std::equal_to<const IComputationNode*>, TMKQLAllocator<std::pair<const IComputationNode *const, IComputationNode*>>>; + +class IComputationGraph { +public: + virtual ~IComputationGraph() {} + virtual void Prepare() = 0; + virtual NUdf::TUnboxedValue GetValue() = 0; + virtual TComputationContext& GetContext() = 0; + virtual IComputationExternalNode* GetEntryPoint(size_t index, bool require) = 0; + virtual const TArrowKernelsTopology* GetKernelsTopology() = 0; + virtual const TComputationNodePtrDeque& GetNodes() const = 0; + virtual void Invalidate() = 0; + virtual TMemoryUsageInfo& GetMemInfo() const = 0; + virtual const THolderFactory& GetHolderFactory() const = 0; + virtual ITerminator* GetTerminator() const = 0; + virtual bool SetExecuteLLVM(bool value) = 0; + virtual TString SaveGraphState() = 0; + virtual void LoadGraphState(TStringBuf state) = 0; +}; + +class TNodeFactory; +typedef std::function<IComputationNode* (TNode* node, bool pop)> TNodeLocator; +typedef std::function<void (IComputationNode*)> TNodePushBack; + +struct TComputationNodeFactoryContext { + TNodeLocator NodeLocator; + const IFunctionRegistry& FunctionRegistry; + const TTypeEnvironment& Env; + NUdf::ITypeInfoHelper::TPtr TypeInfoHelper; + NUdf::ICountersProvider* CountersProvider; + const NUdf::ISecureParamsProvider* SecureParamsProvider; + const TNodeFactory& NodeFactory; + const THolderFactory& HolderFactory; + const NUdf::IValueBuilder *const Builder; + NUdf::EValidateMode ValidateMode; + NUdf::EValidatePolicy ValidatePolicy; + EGraphPerProcess GraphPerProcess; + TComputationMutables& Mutables; + TComputationNodeOnNodeMap& ElementsCache; + const TNodePushBack NodePushBack; + + TComputationNodeFactoryContext( + const TNodeLocator& nodeLocator, + const IFunctionRegistry& functionRegistry, + const TTypeEnvironment& env, + NUdf::ITypeInfoHelper::TPtr typeInfoHelper, + NUdf::ICountersProvider* countersProvider, + const NUdf::ISecureParamsProvider* secureParamsProvider, + const TNodeFactory& nodeFactory, + const THolderFactory& holderFactory, + const NUdf::IValueBuilder* builder, + NUdf::EValidateMode validateMode, + NUdf::EValidatePolicy validatePolicy, + EGraphPerProcess graphPerProcess, + TComputationMutables& mutables, + TComputationNodeOnNodeMap& elementsCache, + TNodePushBack&& nodePushBack + ) + : NodeLocator(nodeLocator) + , FunctionRegistry(functionRegistry) + , Env(env) + , TypeInfoHelper(typeInfoHelper) + , CountersProvider(countersProvider) + , SecureParamsProvider(secureParamsProvider) + , NodeFactory(nodeFactory) + , HolderFactory(holderFactory) + , Builder(builder) + , ValidateMode(validateMode) + , ValidatePolicy(validatePolicy) + , GraphPerProcess(graphPerProcess) + , Mutables(mutables) + , ElementsCache(elementsCache) + , NodePushBack(std::move(nodePushBack)) + {} +}; + +using TComputationNodeFactory = std::function<IComputationNode* (TCallable&, const TComputationNodeFactoryContext&)>; +using TStreamEmitter = std::function<void(NUdf::TUnboxedValue&&)>; + +struct TPatternCacheEntry; + +struct TComputationPatternOpts { + TComputationPatternOpts(TAllocState& allocState, const TTypeEnvironment& env) + : AllocState(allocState) + , Env(env) + {} + + TComputationPatternOpts( + TAllocState& allocState, + const TTypeEnvironment& env, + TComputationNodeFactory factory, + const IFunctionRegistry* functionRegistry, + NUdf::EValidateMode validateMode, + NUdf::EValidatePolicy validatePolicy, + const TString& optLLVM, + EGraphPerProcess graphPerProcess, + IStatsRegistry* stats = nullptr, + NUdf::ICountersProvider* countersProvider = nullptr, + const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) + : AllocState(allocState) + , Env(env) + , Factory(factory) + , FunctionRegistry(functionRegistry) + , ValidateMode(validateMode) + , ValidatePolicy(validatePolicy) + , OptLLVM(optLLVM) + , GraphPerProcess(graphPerProcess) + , Stats(stats) + , CountersProvider(countersProvider) + , SecureParamsProvider(secureParamsProvider) + {} + + void SetOptions(TComputationNodeFactory factory, const IFunctionRegistry* functionRegistry, + NUdf::EValidateMode validateMode, NUdf::EValidatePolicy validatePolicy, + const TString& optLLVM, EGraphPerProcess graphPerProcess, IStatsRegistry* stats = nullptr, + NUdf::ICountersProvider* counters = nullptr, + const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) { + Factory = factory; + FunctionRegistry = functionRegistry; + ValidateMode = validateMode; + ValidatePolicy = validatePolicy; + OptLLVM = optLLVM; + GraphPerProcess = graphPerProcess; + Stats = stats; + CountersProvider = counters; + SecureParamsProvider = secureParamsProvider; + } + + void SetPatternEnv(std::shared_ptr<TPatternCacheEntry> cacheEnv) { + PatternEnv = std::move(cacheEnv); + } + + mutable std::shared_ptr<TPatternCacheEntry> PatternEnv; + TAllocState& AllocState; + const TTypeEnvironment& Env; + + TComputationNodeFactory Factory; + const IFunctionRegistry* FunctionRegistry = nullptr; + NUdf::EValidateMode ValidateMode = NUdf::EValidateMode::None; + NUdf::EValidatePolicy ValidatePolicy = NUdf::EValidatePolicy::Fail; + TString OptLLVM; + EGraphPerProcess GraphPerProcess = EGraphPerProcess::Multi; + IStatsRegistry* Stats = nullptr; + NUdf::ICountersProvider* CountersProvider = nullptr; + const NUdf::ISecureParamsProvider* SecureParamsProvider = nullptr; + + TComputationOptsFull ToComputationOptions(IRandomProvider& randomProvider, ITimeProvider& timeProvider, TAllocState* allocStatePtr = nullptr) const { + return TComputationOptsFull(Stats, allocStatePtr ? *allocStatePtr : AllocState, Env, randomProvider, timeProvider, ValidatePolicy, SecureParamsProvider, CountersProvider); + } +}; + +class IComputationPattern: public TAtomicRefCount<IComputationPattern> { +public: + typedef TIntrusivePtr<IComputationPattern> TPtr; + + virtual ~IComputationPattern() = default; + virtual void Compile(TString optLLVM, IStatsRegistry* stats) = 0; + virtual bool IsCompiled() const = 0; + virtual size_t CompiledCodeSize() const = 0; + virtual void RemoveCompiledCode() = 0; + virtual THolder<IComputationGraph> Clone(const TComputationOptsFull& compOpts) = 0; + virtual bool GetSuitableForCache() const = 0; +}; + +// node cookie's will be clean up when graph will be destroyed, explorer must not be changed/destroyed until that time +IComputationPattern::TPtr MakeComputationPattern( + TExploringNodeVisitor& explorer, + const TRuntimeNode& root, + const std::vector<TNode*>& entryPoints, + const TComputationPatternOpts& opts); + +std::unique_ptr<NUdf::ISecureParamsProvider> MakeSimpleSecureParamsProvider(const THashMap<TString, TString>& secureParams); + +using TCallableComputationNodeBuilder = std::function<IComputationNode* (TCallable&, const TComputationNodeFactoryContext& ctx)>; + +template<typename... Ts> +TCallableComputationNodeBuilder WrapComputationBuilder(IComputationNode* (*f)(const TComputationNodeFactoryContext&, Ts...)){ + return [f](TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == sizeof...(Ts), "Incorrect number of inputs"); + return CallComputationBuilderWithArgs(f, callable, ctx, std::make_index_sequence<sizeof...(Ts)>()); + }; +} +template<typename F, size_t... Is> +auto CallComputationBuilderWithArgs(F* f, TCallable& callable, const TComputationNodeFactoryContext& ctx, + const std::integer_sequence<size_t, Is...> &) { + return f(ctx, callable.GetInput(Is)...); +} + +} // namespace NMiniKQL +} // namespace NKikimr |
