diff options
author | Vitaly Stoyan <vvvv@ydb.tech> | 2024-07-08 14:34:05 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-08 16:34:05 +0500 |
commit | d20552c258db9a1b1212f16be58b617ce6945b2a (patch) | |
tree | b375f86eddb0eba6ef7c7c34c8119c98ba76253c | |
parent | 7073beb9ee3c43199529046c0b561da135185675 (diff) | |
download | ydb-d20552c258db9a1b1212f16be58b617ce6945b2a.tar.gz |
Less LLVM in block agg (#6359)
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp | 429 |
1 files changed, 230 insertions, 199 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp index 0de21c6be7..57d5d09fe2 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp @@ -456,72 +456,57 @@ size_t CalcMaxBlockLenForOutput(TType* out) { return CalcBlockLen(maxBlockItemSize); } +class TBlockCombineAllWrapperCodegenBase { +protected: +#ifndef MKQL_DISABLE_CODEGEN + class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> { + private: + using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>; + llvm::PointerType*const PointerType; + llvm::IntegerType*const IsFinishedType; + public: + std::vector<llvm::Type*> GetFieldsArray() { + std::vector<llvm::Type*> result = TBase::GetFields(); + result.emplace_back(PointerType); + result.emplace_back(IsFinishedType); + return result; + } -class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper> { -using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>; -public: - TBlockCombineAllWrapper(TComputationMutables& mutables, - IComputationWideFlowNode* flow, - std::optional<ui32> filterColumn, - size_t width, - std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams) - : TBaseComputation(mutables, flow, EValueRepresentation::Boxed) - , Flow_(flow) - , FilterColumn_(filterColumn) - , Width_(width) - , AggsParams_(std::move(aggsParams)) - , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) - { - MKQL_ENSURE(Width_ > 0, "Missing block length column"); - } - - EFetchResult DoCalculate(NUdf::TUnboxedValue& state, - TComputationContext& ctx, - NUdf::TUnboxedValue*const* output) const - { - auto& s = GetState(state, ctx); - if (s.IsFinished_) - return EFetchResult::Finish; + llvm::Constant* GetPointer() { + return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0); + } - for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) { - switch (Flow_->FetchValues(ctx, fields)) { - case EFetchResult::Yield: - return EFetchResult::Yield; - case EFetchResult::One: - s.ProcessInput(); - continue; - case EFetchResult::Finish: - break; - } - if (s.MakeOutput()) { - for (size_t i = 0; i < AggsParams_.size(); ++i) { - if (const auto out = output[i]) { - *out = s.Get(i); - } - } - return EFetchResult::One; - } - return EFetchResult::Finish; + llvm::Constant* GetIsFinished() { + return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1); } - } -#ifndef MKQL_DISABLE_CODEGEN - ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + + TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width) + : TBase(context) + , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width))) + , IsFinishedType(llvm::Type::getInt1Ty(Context)) + {} + }; + + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValuesImpl(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block, + IComputationWideFlowNode* flow, size_t width, size_t aggCount, + uintptr_t getStateMethodPtr, uint64_t makeStateMethodPtr, + uintptr_t processInputMethodPtr, uintptr_t makeOutputMethodPtr) const { auto& context = ctx.Codegen.GetContext(); const auto valueType = Type::getInt128Ty(context); const auto statusType = Type::getInt32Ty(context); const auto indexType = Type::getInt64Ty(context); const auto flagType = Type::getInt1Ty(context); - const auto arrayType = ArrayType::get(valueType, Width_); + const auto arrayType = ArrayType::get(valueType, width); const auto ptrValuesType = PointerType::getUnqual(arrayType); - TLLVMFieldsStructureState stateFields(context, Width_); + TLLVMFieldsStructureState stateFields(context, width); const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); const auto statePtrType = PointerType::getUnqual(stateType); const auto atTop = &ctx.Func->getEntryBlock().back(); - const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get)); + const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), getStateMethodPtr); const auto getType = FunctionType::get(valueType, {statePtrType, indexType}, false); const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop); @@ -540,7 +525,7 @@ public: const auto ptrType = PointerType::getUnqual(StructType::get(context)); const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); - const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockCombineAllWrapper::MakeState)); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), makeStateMethodPtr); const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false); const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block); @@ -567,7 +552,7 @@ public: const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block); SafeUnRefUnboxed(values, ctx, block); - const auto getres = GetNodeValues(Flow_, ctx, block); + const auto getres = GetNodeValues(flow, ctx, block); result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block); const auto way = SwitchInst::Create(getres.first, good, 2U, block); @@ -584,7 +569,7 @@ public: } new StoreInst(array, values, block); - const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput)); + const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), processInputMethodPtr); const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false); const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block); CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block); @@ -593,7 +578,7 @@ public: block = work; - const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeOutput)); + const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), makeOutputMethodPtr); const auto makeOutputType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false); const auto makeOutputPtr = CastInst::Create(Instruction::IntToPtr, makeOutputFunc, PointerType::getUnqual(makeOutputType), "make_output_func", block); const auto hasData = CallInst::Create(makeOutputType, makeOutputPtr, {stateArg, ctx.GetFactory()}, "make_output", block); @@ -605,7 +590,7 @@ public: block = over; - ICodegeneratorInlineWideNode::TGettersList getters(AggsParams_.size()); + ICodegeneratorInlineWideNode::TGettersList getters(aggCount); for (size_t idx = 0U; idx < getters.size(); ++idx) { getters[idx] = [idx, getType, getPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) { Y_UNUSED(ctx); @@ -616,6 +601,63 @@ public: return {result, std::move(getters)}; } #endif +}; + +class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>, + protected TBlockCombineAllWrapperCodegenBase { +using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>; +public: + TBlockCombineAllWrapper(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams) + : TBaseComputation(mutables, flow, EValueRepresentation::Boxed) + , Flow_(flow) + , FilterColumn_(filterColumn) + , Width_(width) + , AggsParams_(std::move(aggsParams)) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) + { + MKQL_ENSURE(Width_ > 0, "Missing block length column"); + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, + TComputationContext& ctx, + NUdf::TUnboxedValue*const* output) const + { + auto& s = GetState(state, ctx); + if (s.IsFinished_) + return EFetchResult::Finish; + + for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) { + switch (Flow_->FetchValues(ctx, fields)) { + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::One: + s.ProcessInput(); + continue; + case EFetchResult::Finish: + break; + } + if (s.MakeOutput()) { + for (size_t i = 0; i < AggsParams_.size(); ++i) { + if (const auto out = output[i]) { + *out = s.Get(i); + } + } + return EFetchResult::One; + } + return EFetchResult::Finish; + } + } +#ifndef MKQL_DISABLE_CODEGEN + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, AggsParams_.size(), + GetMethodPtr(&TState::Get), GetMethodPtr(&TBlockCombineAllWrapper::MakeState), + GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::MakeOutput)); + } +#endif private: struct TState : public TComputationValue<TState> { NUdf::TUnboxedValue* Pointer_ = nullptr; @@ -701,35 +743,6 @@ private: return Values_[index]; } }; -#ifndef MKQL_DISABLE_CODEGEN - class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> { - private: - using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>; - llvm::PointerType*const PointerType; - llvm::IntegerType*const IsFinishedType; - public: - std::vector<llvm::Type*> GetFieldsArray() { - std::vector<llvm::Type*> result = TBase::GetFields(); - result.emplace_back(PointerType); - result.emplace_back(IsFinishedType); - return result; - } - - llvm::Constant* GetPointer() { - return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0); - } - - llvm::Constant* GetIsFinished() { - return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1); - } - - TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width) - : TBase(context) - , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width))) - , IsFinishedType(llvm::Type::getInt1Ty(Context)) - {} - }; -#endif void RegisterDependencies() const final { FlowDependsOn(Flow_); } @@ -839,101 +852,61 @@ std::hash<TExternalFixedSizeKey> MakeHash(ui32 keyLength) { return std::hash<TExternalFixedSizeKey>(keyLength); } -template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived> -class THashedWrapperBase : public TStatefulWideFlowCodegeneratorNode<TDerived> { - using TComputationBase = TStatefulWideFlowCodegeneratorNode<TDerived>; - static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value; -public: - THashedWrapperBase(TComputationMutables& mutables, - IComputationWideFlowNode* flow, - std::optional<ui32> filterColumn, - size_t width, - const std::vector<TKeyParams>& keys, - size_t maxBlockLen, - ui32 keyLength, - std::vector<TAggParams<TAggregator>>&& aggsParams, - ui32 streamIndex, - std::vector<std::vector<ui32>>&& streams) - : TComputationBase(mutables, flow, EValueRepresentation::Boxed) - , Flow_(flow) - , FilterColumn_(filterColumn) - , Width_(width) - , OutputWidth_(keys.size() + aggsParams.size() + 1) - , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) - , Keys_(keys) - , MaxBlockLen_(maxBlockLen) - , AggsParams_(std::move(aggsParams)) - , KeyLength_(keyLength) - , StreamIndex_(streamIndex) - , Streams_(std::move(streams)) - { - MKQL_ENSURE(Width_ > 0, "Missing block length column"); - if constexpr (UseFilter) { - MKQL_ENSURE(filterColumn, "Missing filter column"); - MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize"); - } else { - MKQL_ENSURE(!filterColumn, "Unexpected filter column"); +class THashedWrapperCodegenBase { +protected: +#ifndef MKQL_DISABLE_CODEGEN + class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState { + private: + using TBase = TLLVMFieldsStructureBlockState; + llvm::IntegerType*const WritingOutputType; + llvm::IntegerType*const IsFinishedType; + protected: + using TBase::Context; + public: + std::vector<llvm::Type*> GetFieldsArray() { + std::vector<llvm::Type*> result = TBase::GetFieldsArray(); + result.emplace_back(WritingOutputType); + result.emplace_back(IsFinishedType); + return result; } - } - - EFetchResult DoCalculate(NUdf::TUnboxedValue& state, - TComputationContext& ctx, - NUdf::TUnboxedValue*const* output) const - { - auto& s = GetState(state, ctx); - if (!s.Count) { - if (s.IsFinished_) - return EFetchResult::Finish; - - while (!s.WritingOutput_) { - const auto fields = ctx.WideFields.data() + WideFieldsIndex_; - s.Values_.assign(s.Values_.size(), NUdf::TUnboxedValuePod()); - switch (Flow_->FetchValues(ctx, fields)) { - case EFetchResult::Yield: - return EFetchResult::Yield; - case EFetchResult::One: - s.ProcessInput(ctx.HolderFactory); - continue; - case EFetchResult::Finish: - break; - } - - if (s.Finish()) - break; - else - return EFetchResult::Finish; - } - if (!s.FillOutput(ctx.HolderFactory)) - return EFetchResult::Finish; + llvm::Constant* GetWritingOutput() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields); } - const auto sliceSize = s.Slice(); - for (size_t i = 0; i < OutputWidth_; ++i) { - if (const auto out = output[i]) { - *out = s.Get(sliceSize, ctx.HolderFactory, i); - } + llvm::Constant* GetIsFinished() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1); } - return EFetchResult::One; - } -#ifndef MKQL_DISABLE_CODEGEN - ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + + TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width) + : TBase(context, width) + , WritingOutputType(Type::getInt1Ty(Context)) + , IsFinishedType(Type::getInt1Ty(Context)) + {} + }; + + Y_NO_INLINE ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValuesImpl( + const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block, + IComputationWideFlowNode* flow, size_t width, size_t outputWidth, + uintptr_t getStateMethodPtr, uintptr_t makeStateMethodPtr, + uintptr_t processInputMethodPtr, uintptr_t finishMethodPtr, + uintptr_t fillOutputMethodPtr, uintptr_t sliceMethodPtr) const { auto& context = ctx.Codegen.GetContext(); const auto valueType = Type::getInt128Ty(context); const auto statusType = Type::getInt32Ty(context); const auto indexType = Type::getInt64Ty(context); const auto flagType = Type::getInt1Ty(context); - const auto arrayType = ArrayType::get(valueType, Width_); + const auto arrayType = ArrayType::get(valueType, width); const auto ptrValuesType = PointerType::getUnqual(arrayType); - TLLVMFieldsStructureState stateFields(context, Width_); + TLLVMFieldsStructureState stateFields(context, width); const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); const auto statePtrType = PointerType::getUnqual(stateType); const auto atTop = &ctx.Func->getEntryBlock().back(); - const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get)); + const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), getStateMethodPtr); const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false); const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop); @@ -959,7 +932,7 @@ public: const auto ptrType = PointerType::getUnqual(StructType::get(context)); const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); - const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THashedWrapperBase::MakeState)); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), makeStateMethodPtr); const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false); const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block); @@ -1001,7 +974,7 @@ public: const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block); SafeUnRefUnboxed(values, ctx, block); - const auto getres = GetNodeValues(Flow_, ctx, block); + const auto getres = GetNodeValues(flow, ctx, block); result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block); const auto way = SwitchInst::Create(getres.first, good, 2U, block); @@ -1018,7 +991,7 @@ public: } new StoreInst(array, values, block); - const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput)); + const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), processInputMethodPtr); const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false); const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block); CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block); @@ -1027,7 +1000,7 @@ public: block = stop; - const auto finishFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Finish)); + const auto finishFunc = ConstantInt::get(Type::getInt64Ty(context), finishMethodPtr); const auto finishType = FunctionType::get(flagType, {statePtrType}, false); const auto finishPtr = CastInst::Create(Instruction::IntToPtr, finishFunc, PointerType::getUnqual(finishType), "finish_func", block); const auto hasOutput = CallInst::Create(finishType, finishPtr, {stateArg}, "has_output", block); @@ -1038,7 +1011,7 @@ public: block = work; - const auto fillBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::FillOutput)); + const auto fillBlockFunc = ConstantInt::get(Type::getInt64Ty(context), fillOutputMethodPtr); const auto fillBlockType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false); const auto fillBlockPtr = CastInst::Create(Instruction::IntToPtr, fillBlockFunc, PointerType::getUnqual(fillBlockType), "fill_output_func", block); const auto hasData = CallInst::Create(fillBlockType, fillBlockPtr, {stateArg, ctx.GetFactory()}, "fill_output", block); @@ -1049,7 +1022,7 @@ public: block = fill; - const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Slice)); + const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), sliceMethodPtr); const auto sliceType = FunctionType::get(indexType, {statePtrType}, false); const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block); const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block); @@ -1062,7 +1035,7 @@ public: block = over; - ICodegeneratorInlineWideNode::TGettersList getters(OutputWidth_); + ICodegeneratorInlineWideNode::TGettersList getters(outputWidth); for (size_t idx = 0U; idx < getters.size(); ++idx) { getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) { const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block); @@ -1073,6 +1046,95 @@ public: return {result, std::move(getters)}; } #endif +}; + +template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived> +class THashedWrapperBase : public TStatefulWideFlowCodegeneratorNode<TDerived>, + protected THashedWrapperCodegenBase +{ + using TComputationBase = TStatefulWideFlowCodegeneratorNode<TDerived>; + static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value; +public: + THashedWrapperBase(TComputationMutables& mutables, + IComputationWideFlowNode* flow, + std::optional<ui32> filterColumn, + size_t width, + const std::vector<TKeyParams>& keys, + size_t maxBlockLen, + ui32 keyLength, + std::vector<TAggParams<TAggregator>>&& aggsParams, + ui32 streamIndex, + std::vector<std::vector<ui32>>&& streams) + : TComputationBase(mutables, flow, EValueRepresentation::Boxed) + , Flow_(flow) + , FilterColumn_(filterColumn) + , Width_(width) + , OutputWidth_(keys.size() + aggsParams.size() + 1) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) + , Keys_(keys) + , MaxBlockLen_(maxBlockLen) + , AggsParams_(std::move(aggsParams)) + , KeyLength_(keyLength) + , StreamIndex_(streamIndex) + , Streams_(std::move(streams)) + { + MKQL_ENSURE(Width_ > 0, "Missing block length column"); + if constexpr (UseFilter) { + MKQL_ENSURE(filterColumn, "Missing filter column"); + MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize"); + } else { + MKQL_ENSURE(!filterColumn, "Unexpected filter column"); + } + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, + TComputationContext& ctx, + NUdf::TUnboxedValue*const* output) const + { + auto& s = GetState(state, ctx); + if (!s.Count) { + if (s.IsFinished_) + return EFetchResult::Finish; + + while (!s.WritingOutput_) { + const auto fields = ctx.WideFields.data() + WideFieldsIndex_; + s.Values_.assign(s.Values_.size(), NUdf::TUnboxedValuePod()); + switch (Flow_->FetchValues(ctx, fields)) { + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::One: + s.ProcessInput(ctx.HolderFactory); + continue; + case EFetchResult::Finish: + break; + } + + if (s.Finish()) + break; + else + return EFetchResult::Finish; + } + + if (!s.FillOutput(ctx.HolderFactory)) + return EFetchResult::Finish; + } + + const auto sliceSize = s.Slice(); + for (size_t i = 0; i < OutputWidth_; ++i) { + if (const auto out = output[i]) { + *out = s.Get(sliceSize, ctx.HolderFactory, i); + } + } + return EFetchResult::One; + } +#ifndef MKQL_DISABLE_CODEGEN + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, OutputWidth_, + GetMethodPtr(&TState::Get), GetMethodPtr(&THashedWrapperBase::MakeState), + GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::Finish), + GetMethodPtr(&TState::FillOutput), GetMethodPtr(&TState::Slice)); + } +#endif private: struct TState : public TBlockState { bool WritingOutput_ = false; @@ -1567,37 +1629,6 @@ private: } }; private: -#ifndef MKQL_DISABLE_CODEGEN - class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState { - private: - using TBase = TLLVMFieldsStructureBlockState; - llvm::IntegerType*const WritingOutputType; - llvm::IntegerType*const IsFinishedType; - protected: - using TBase::Context; - public: - std::vector<llvm::Type*> GetFieldsArray() { - std::vector<llvm::Type*> result = TBase::GetFieldsArray(); - result.emplace_back(WritingOutputType); - result.emplace_back(IsFinishedType); - return result; - } - - llvm::Constant* GetWritingOutput() { - return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields); - } - - llvm::Constant* GetIsFinished() { - return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1); - } - - TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width) - : TBase(context, width) - , WritingOutputType(Type::getInt1Ty(Context)) - , IsFinishedType(Type::getInt1Ty(Context)) - {} - }; -#endif void RegisterDependencies() const final { this->FlowDependsOn(Flow_); } |