diff options
author | a-romanov <Anton.Romanov@ydb.tech> | 2023-10-17 14:52:04 +0300 |
---|---|---|
committer | a-romanov <Anton.Romanov@ydb.tech> | 2023-10-17 15:23:28 +0300 |
commit | bb268f4050420167111455bf6682288627ffbff6 (patch) | |
tree | 7889c80a028d2035872abdfea5742a6b2ee2e828 | |
parent | 88a3de5de621a28adbd0dd40f0864a32f8396dcd (diff) | |
download | ydb-bb268f4050420167111455bf6682288627ffbff6.tar.gz |
YQL-15891 LLVM for BlockCombineAll.
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp | 301 |
1 files changed, 230 insertions, 71 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp index 4ccc1787cf..76d0ccc7e3 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp @@ -456,18 +456,20 @@ size_t CalcMaxBlockLenForOutput(TType* out) { } -class TBlockCombineAllWrapper : public TStatefulWideFlowComputationNode<TBlockCombineAllWrapper> { +class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper> { +using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>; public: TBlockCombineAllWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, std::optional<ui32> filterColumn, size_t width, std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams) - : TStatefulWideFlowComputationNode(mutables, flow, EValueRepresentation::Any) + : TBaseComputation(mutables, flow, EValueRepresentation::Boxed) , Flow_(flow) , FilterColumn_(filterColumn) , Width_(width) , AggsParams_(std::move(aggsParams)) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) { MKQL_ENSURE(Width_ > 0, "Missing block length column"); } @@ -477,88 +479,160 @@ public: NUdf::TUnboxedValue*const* output) const { auto& s = GetState(state, ctx); - if (s.IsFinished_) { + if (s.IsFinished_) return EFetchResult::Finish; - } - for (;;) { - auto result = Flow_->FetchValues(ctx, s.ValuePointers_.data()); - if (result == EFetchResult::Yield) { - return result; - } else if (result == EFetchResult::One) { - ui64 batchLength = GetBatchLength(s.Values_.data()); - if (!batchLength) { + for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) { + switch (Flow_->FetchValues(ctx, fields)) { + case EFetchResult::Yield: + return EFetchResult::Yield; + case EFetchResult::One: + s.ProcessInput(); continue; + case EFetchResult::Finish: + break; + } + if (s.MakeOutput()) { + for (size_t i = 0; i < AggsParams_.size(); ++i) { + if (const auto out = output[i]) { + *out = s.Pull(i); + } } + return EFetchResult::One; + } + return EFetchResult::Finish; + } + } +#ifndef MKQL_DISABLE_CODEGEN + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + auto& context = ctx.Codegen.GetContext(); - std::optional<ui64> filtered; - if (FilterColumn_) { - auto filterDatum = TArrowBlock::From(s.Values_[*FilterColumn_]).GetDatum(); - if (filterDatum.is_scalar()) { - if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) { - continue; - } - } else { - ui64 popCount = GetBitmapPopCount(filterDatum.array()); - if (popCount == 0) { - continue; - } + const auto valueType = Type::getInt128Ty(context); + const auto ptrValueType = PointerType::getUnqual(valueType); + const auto statusType = Type::getInt32Ty(context); + const auto indexType = Type::getInt64Ty(context); + const auto flagType = Type::getInt1Ty(context); + const auto arrayType = ArrayType::get(valueType, Width_); + const auto ptrValuesType = PointerType::getUnqual(arrayType); - if (popCount < batchLength) { - filtered = popCount; - } - } - } + TLLVMFieldsStructureState stateFields(context, Width_); + const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); + const auto statePtrType = PointerType::getUnqual(stateType); - s.HasValues_ = true; - char* ptr = s.AggStates_.data(); - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (output[i]) { - s.Aggs_[i]->AddMany(ptr, s.Values_.data(), batchLength, filtered); - } + const auto atTop = &ctx.Func->getEntryBlock().back(); - ptr += s.Aggs_[i]->StateSize; - } - } else { - s.IsFinished_ = true; - if (!s.HasValues_) { - return EFetchResult::Finish; - } + const auto pullFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Pull)); + const auto pullType = FunctionType::get(valueType, {statePtrType, indexType}, false); + const auto pullPtr = CastInst::Create(Instruction::IntToPtr, pullFunc, PointerType::getUnqual(pullType), "pull", atTop); - char* ptr = s.AggStates_.data(); - for (size_t i = 0; i < s.Aggs_.size(); ++i) { - if (auto* out = output[i]; out != nullptr) { - *out = s.Aggs_[i]->FinishOne(ptr); - s.Aggs_[i]->DestroyState(ptr); - } + const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop); + new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop); - ptr += s.Aggs_[i]->StateSize; - } + const auto make = BasicBlock::Create(context, "make", ctx.Func); + const auto main = BasicBlock::Create(context, "main", ctx.Func); + const auto read = BasicBlock::Create(context, "read", ctx.Func); + const auto good = BasicBlock::Create(context, "good", ctx.Func); + const auto work = BasicBlock::Create(context, "work", ctx.Func); + const auto over = BasicBlock::Create(context, "over", ctx.Func); - return EFetchResult::One; - } + BranchInst::Create(main, make, HasValue(statePtr, block), block); + block = make; + + const auto ptrType = PointerType::getUnqual(StructType::get(context)); + const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockCombineAllWrapper::MakeState)); + const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false); + const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); + CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block); + + BranchInst::Create(main, block); + + block = main; + + const auto state = new LoadInst(valueType, statePtr, "state", block); + const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block); + const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block); + + const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block); + const auto finished = new LoadInst(flagType, finishedPtr, "finished", block); + + const auto result = PHINode::Create(statusType, 3U, "result", over); + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block); + + BranchInst::Create(over, read, finished, block); + + block = read; + + const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block); + const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block); + SafeUnRefUnboxed(values, ctx, block); + + const auto getres = GetNodeValues(Flow_, ctx, block); + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block); + + const auto way = SwitchInst::Create(getres.first, good, 2U, block); + way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), work); + way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over); + + block = good; + + Value* array = UndefValue::get(arrayType); + for (auto idx = 0U; idx < getres.second.size(); ++idx) { + const auto value = getres.second[idx](ctx, block); + AddRefBoxed(value, ctx, block); + array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block); } + new StoreInst(array, values, block); - return EFetchResult::Finish; - } + const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput)); + const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false); + const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block); + CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block); + + BranchInst::Create(read, block); + + block = work; + + const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeOutput)); + const auto makeOutputType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false); + const auto makeOutputPtr = CastInst::Create(Instruction::IntToPtr, makeOutputFunc, PointerType::getUnqual(makeOutputType), "make_output_func", block); + const auto hasData = CallInst::Create(makeOutputType, makeOutputPtr, {stateArg, ctx.GetFactory()}, "make_output", block); + const auto output = SelectInst::Create(hasData, ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), "output", block); + new StoreInst(stateArg, stateOnStack, block); + + result->addIncoming(output, block); + BranchInst::Create(over, block); + + block = over; + ICodegeneratorInlineWideNode::TGettersList getters(AggsParams_.size()); + for (size_t idx = 0U; idx < getters.size(); ++idx) { + getters[idx] = [idx, pullType, pullPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) { + const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block); + return CallInst::Create(pullType, pullPtr, {stateArg, ConstantInt::get(indexType, idx)}, "pull", block); + }; + } + return {result, std::move(getters)}; + } +#endif private: struct TState : public TComputationValue<TState> { - std::vector<NUdf::TUnboxedValue> Values_; - std::vector<NUdf::TUnboxedValue*> ValuePointers_; - std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_; + NUdf::TUnboxedValue* Pointer_ = nullptr; bool IsFinished_ = false; bool HasValues_ = false; + TUnboxedValueVector Values_; + std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_; std::vector<char> AggStates_; + const std::optional<ui32> FilterColumn_; + const size_t Width_; - TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32>, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx) + TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx) : TComputationValue(memInfo) - , Values_(width) - , ValuePointers_(width) + , Values_(std::max(width, params.size())) + , FilterColumn_(filterColumn) + , Width_(width) { - for (size_t i = 0; i < width; ++i) { - ValuePointers_[i] = &Values_[i]; - } + Pointer_ = Values_.data(); ui32 totalStateSize = 0; for (const auto& p : params) { @@ -574,29 +648,114 @@ private: ptr += agg->StateSize; } } + + void ProcessInput() { + const ui64 batchLength = TArrowBlock::From(Values_[Width_ - 1U]).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + if (!batchLength) { + return; + } + + std::optional<ui64> filtered; + if (FilterColumn_) { + const auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum(); + if (filterDatum.is_scalar()) { + if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) { + return; + } + } else { + const ui64 popCount = GetBitmapPopCount(filterDatum.array()); + if (popCount == 0) { + return; + } + + if (popCount < batchLength) { + filtered = popCount; + } + } + } + + HasValues_ = true; + char* ptr = AggStates_.data(); + for (size_t i = 0; i < Aggs_.size(); ++i) { + Aggs_[i]->AddMany(ptr, Values_.data(), batchLength, filtered); + ptr += Aggs_[i]->StateSize; + } + } + + bool MakeOutput() { + IsFinished_ = true; + if (!HasValues_) + return false; + + char* ptr = AggStates_.data(); + for (size_t i = 0; i < Aggs_.size(); ++i) { + Values_[i] = Aggs_[i]->FinishOne(ptr); + Aggs_[i]->DestroyState(ptr); + ptr += Aggs_[i]->StateSize; + } + return true; + } + + NUdf::TUnboxedValuePod Pull(size_t index) { + return Values_[index].Release(); + } }; +#ifndef MKQL_DISABLE_CODEGEN + class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> { + private: + using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>; + llvm::PointerType*const PointerType; + llvm::IntegerType*const IsFinishedType; + public: + std::vector<llvm::Type*> GetFieldsArray() { + std::vector<llvm::Type*> result = TBase::GetFields(); + result.emplace_back(PointerType); + result.emplace_back(IsFinishedType); + return result; + } -private: + llvm::Constant* GetPointer() { + return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0); + } + + llvm::Constant* GetIsFinished() { + return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1); + } + + TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width) + : TBase(context) + , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width))) + , IsFinishedType(llvm::Type::getInt1Ty(Context)) + {} + }; +#endif void RegisterDependencies() const final { FlowDependsOn(Flow_); } + void MakeState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { + state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx); + } + TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { if (!state.HasValue()) { - state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx); + MakeState(state, ctx); + + auto& s = *static_cast<TState*>(state.AsBoxed().Get()); + const auto fields = ctx.WideFields.data() + WideFieldsIndex_; + for (size_t i = 0; i < s.Values_.size(); ++i) { + fields[i] = &s.Values_[i]; + } + return s; } return *static_cast<TState*>(state.AsBoxed().Get()); } - - ui64 GetBatchLength(const NUdf::TUnboxedValue* columns) const { - return TArrowBlock::From(columns[Width_ - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value; - } - private: - IComputationWideFlowNode* Flow_; - std::optional<ui32> FilterColumn_; + IComputationWideFlowNode *const Flow_; + const std::optional<ui32> FilterColumn_; const size_t Width_; const std::vector<TAggParams<IBlockAggregatorCombineAll>> AggsParams_; + const size_t WideFieldsIndex_; }; template <typename T> |