about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: a-romanov <Anton.Romanov@ydb.tech>, 2023-10-17 14:52:04 +0300
committer: a-romanov <Anton.Romanov@ydb.tech>, 2023-10-17 15:23:28 +0300
commit: bb268f4050420167111455bf6682288627ffbff6 (patch)
tree: 7889c80a028d2035872abdfea5742a6b2ee2e828
parent: 88a3de5de621a28adbd0dd40f0864a32f8396dcd (diff)
download: ydb-bb268f4050420167111455bf6682288627ffbff6.tar.gz
YQL-15891 LLVM for BlockCombineAll.
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp 301
1 files changed, 230 insertions, 71 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
index 4ccc1787cf..76d0ccc7e3 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
@@ -456,18 +456,20 @@ size_t CalcMaxBlockLenForOutput(TType* out) {
}
-class TBlockCombineAllWrapper : public TStatefulWideFlowComputationNode<TBlockCombineAllWrapper> {
+class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper> {
+using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>;
public:
TBlockCombineAllWrapper(TComputationMutables& mutables,
IComputationWideFlowNode* flow,
std::optional<ui32> filterColumn,
size_t width,
std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams)
- : TStatefulWideFlowComputationNode(mutables, flow, EValueRepresentation::Any)
+ : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
, Flow_(flow)
, FilterColumn_(filterColumn)
, Width_(width)
, AggsParams_(std::move(aggsParams))
+ , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
{
MKQL_ENSURE(Width_ > 0, "Missing block length column");
}
@@ -477,88 +479,160 @@ public:
NUdf::TUnboxedValue*const* output) const
{
auto& s = GetState(state, ctx);
- if (s.IsFinished_) {
+ if (s.IsFinished_)
return EFetchResult::Finish;
- }
- for (;;) {
- auto result = Flow_->FetchValues(ctx, s.ValuePointers_.data());
- if (result == EFetchResult::Yield) {
- return result;
- } else if (result == EFetchResult::One) {
- ui64 batchLength = GetBatchLength(s.Values_.data());
- if (!batchLength) {
+ for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) {
+ switch (Flow_->FetchValues(ctx, fields)) {
+ case EFetchResult::Yield:
+ return EFetchResult::Yield;
+ case EFetchResult::One:
+ s.ProcessInput();
continue;
+ case EFetchResult::Finish:
+ break;
+ }
+ if (s.MakeOutput()) {
+ for (size_t i = 0; i < AggsParams_.size(); ++i) {
+ if (const auto out = output[i]) {
+ *out = s.Pull(i);
+ }
}
+ return EFetchResult::One;
+ }
+ return EFetchResult::Finish;
+ }
+ }
+#ifndef MKQL_DISABLE_CODEGEN
+ ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
+ auto& context = ctx.Codegen.GetContext();
- std::optional<ui64> filtered;
- if (FilterColumn_) {
- auto filterDatum = TArrowBlock::From(s.Values_[*FilterColumn_]).GetDatum();
- if (filterDatum.is_scalar()) {
- if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) {
- continue;
- }
- } else {
- ui64 popCount = GetBitmapPopCount(filterDatum.array());
- if (popCount == 0) {
- continue;
- }
+ const auto valueType = Type::getInt128Ty(context);
+ const auto ptrValueType = PointerType::getUnqual(valueType);
+ const auto statusType = Type::getInt32Ty(context);
+ const auto indexType = Type::getInt64Ty(context);
+ const auto flagType = Type::getInt1Ty(context);
+ const auto arrayType = ArrayType::get(valueType, Width_);
+ const auto ptrValuesType = PointerType::getUnqual(arrayType);
- if (popCount < batchLength) {
- filtered = popCount;
- }
- }
- }
+ TLLVMFieldsStructureState stateFields(context, Width_);
+ const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
+ const auto statePtrType = PointerType::getUnqual(stateType);
- s.HasValues_ = true;
- char* ptr = s.AggStates_.data();
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (output[i]) {
- s.Aggs_[i]->AddMany(ptr, s.Values_.data(), batchLength, filtered);
- }
+ const auto atTop = &ctx.Func->getEntryBlock().back();
- ptr += s.Aggs_[i]->StateSize;
- }
- } else {
- s.IsFinished_ = true;
- if (!s.HasValues_) {
- return EFetchResult::Finish;
- }
+ const auto pullFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Pull));
+ const auto pullType = FunctionType::get(valueType, {statePtrType, indexType}, false);
+ const auto pullPtr = CastInst::Create(Instruction::IntToPtr, pullFunc, PointerType::getUnqual(pullType), "pull", atTop);
- char* ptr = s.AggStates_.data();
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
- if (auto* out = output[i]; out != nullptr) {
- *out = s.Aggs_[i]->FinishOne(ptr);
- s.Aggs_[i]->DestroyState(ptr);
- }
+ const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
+ new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
- ptr += s.Aggs_[i]->StateSize;
- }
+ const auto make = BasicBlock::Create(context, "make", ctx.Func);
+ const auto main = BasicBlock::Create(context, "main", ctx.Func);
+ const auto read = BasicBlock::Create(context, "read", ctx.Func);
+ const auto good = BasicBlock::Create(context, "good", ctx.Func);
+ const auto work = BasicBlock::Create(context, "work", ctx.Func);
+ const auto over = BasicBlock::Create(context, "over", ctx.Func);
- return EFetchResult::One;
- }
+ BranchInst::Create(main, make, HasValue(statePtr, block), block);
+ block = make;
+
+ const auto ptrType = PointerType::getUnqual(StructType::get(context));
+ const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
+ const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockCombineAllWrapper::MakeState));
+ const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false);
+ const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
+ CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block);
+
+ BranchInst::Create(main, block);
+
+ block = main;
+
+ const auto state = new LoadInst(valueType, statePtr, "state", block);
+ const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
+ const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
+
+ const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
+ const auto finished = new LoadInst(flagType, finishedPtr, "finished", block);
+
+ const auto result = PHINode::Create(statusType, 3U, "result", over);
+ result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
+
+ BranchInst::Create(over, read, finished, block);
+
+ block = read;
+
+ const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
+ const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
+ SafeUnRefUnboxed(values, ctx, block);
+
+ const auto getres = GetNodeValues(Flow_, ctx, block);
+ result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
+
+ const auto way = SwitchInst::Create(getres.first, good, 2U, block);
+ way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), work);
+ way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
+
+ block = good;
+
+ Value* array = UndefValue::get(arrayType);
+ for (auto idx = 0U; idx < getres.second.size(); ++idx) {
+ const auto value = getres.second[idx](ctx, block);
+ AddRefBoxed(value, ctx, block);
+ array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
}
+ new StoreInst(array, values, block);
- return EFetchResult::Finish;
- }
+ const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput));
+ const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false);
+ const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block);
+ CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
+
+ BranchInst::Create(read, block);
+
+ block = work;
+
+ const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeOutput));
+ const auto makeOutputType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false);
+ const auto makeOutputPtr = CastInst::Create(Instruction::IntToPtr, makeOutputFunc, PointerType::getUnqual(makeOutputType), "make_output_func", block);
+ const auto hasData = CallInst::Create(makeOutputType, makeOutputPtr, {stateArg, ctx.GetFactory()}, "make_output", block);
+ const auto output = SelectInst::Create(hasData, ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), "output", block);
+ new StoreInst(stateArg, stateOnStack, block);
+
+ result->addIncoming(output, block);
+ BranchInst::Create(over, block);
+
+ block = over;
+ ICodegeneratorInlineWideNode::TGettersList getters(AggsParams_.size());
+ for (size_t idx = 0U; idx < getters.size(); ++idx) {
+ getters[idx] = [idx, pullType, pullPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
+ const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
+ return CallInst::Create(pullType, pullPtr, {stateArg, ConstantInt::get(indexType, idx)}, "pull", block);
+ };
+ }
+ return {result, std::move(getters)};
+ }
+#endif
private:
struct TState : public TComputationValue<TState> {
- std::vector<NUdf::TUnboxedValue> Values_;
- std::vector<NUdf::TUnboxedValue*> ValuePointers_;
- std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_;
+ NUdf::TUnboxedValue* Pointer_ = nullptr;
bool IsFinished_ = false;
bool HasValues_ = false;
+ TUnboxedValueVector Values_;
+ std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_;
std::vector<char> AggStates_;
+ const std::optional<ui32> FilterColumn_;
+ const size_t Width_;
- TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32>, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx)
+ TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx)
: TComputationValue(memInfo)
- , Values_(width)
- , ValuePointers_(width)
+ , Values_(std::max(width, params.size()))
+ , FilterColumn_(filterColumn)
+ , Width_(width)
{
- for (size_t i = 0; i < width; ++i) {
- ValuePointers_[i] = &Values_[i];
- }
+ Pointer_ = Values_.data();
ui32 totalStateSize = 0;
for (const auto& p : params) {
@@ -574,29 +648,114 @@ private:
ptr += agg->StateSize;
}
}
+
+ void ProcessInput() {
+ const ui64 batchLength = TArrowBlock::From(Values_[Width_ - 1U]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
+ if (!batchLength) {
+ return;
+ }
+
+ std::optional<ui64> filtered;
+ if (FilterColumn_) {
+ const auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum();
+ if (filterDatum.is_scalar()) {
+ if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) {
+ return;
+ }
+ } else {
+ const ui64 popCount = GetBitmapPopCount(filterDatum.array());
+ if (popCount == 0) {
+ return;
+ }
+
+ if (popCount < batchLength) {
+ filtered = popCount;
+ }
+ }
+ }
+
+ HasValues_ = true;
+ char* ptr = AggStates_.data();
+ for (size_t i = 0; i < Aggs_.size(); ++i) {
+ Aggs_[i]->AddMany(ptr, Values_.data(), batchLength, filtered);
+ ptr += Aggs_[i]->StateSize;
+ }
+ }
+
+ bool MakeOutput() {
+ IsFinished_ = true;
+ if (!HasValues_)
+ return false;
+
+ char* ptr = AggStates_.data();
+ for (size_t i = 0; i < Aggs_.size(); ++i) {
+ Values_[i] = Aggs_[i]->FinishOne(ptr);
+ Aggs_[i]->DestroyState(ptr);
+ ptr += Aggs_[i]->StateSize;
+ }
+ return true;
+ }
+
+ NUdf::TUnboxedValuePod Pull(size_t index) {
+ return Values_[index].Release();
+ }
};
+#ifndef MKQL_DISABLE_CODEGEN
+ class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> {
+ private:
+ using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>;
+ llvm::PointerType*const PointerType;
+ llvm::IntegerType*const IsFinishedType;
+ public:
+ std::vector<llvm::Type*> GetFieldsArray() {
+ std::vector<llvm::Type*> result = TBase::GetFields();
+ result.emplace_back(PointerType);
+ result.emplace_back(IsFinishedType);
+ return result;
+ }
-private:
+ llvm::Constant* GetPointer() {
+ return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
+ }
+
+ llvm::Constant* GetIsFinished() {
+ return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
+ }
+
+ TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
+ : TBase(context)
+ , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width)))
+ , IsFinishedType(llvm::Type::getInt1Ty(Context))
+ {}
+ };
+#endif
void RegisterDependencies() const final {
FlowDependsOn(Flow_);
}
+ void MakeState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
+ state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
+ }
+
+ // Returns this node's TState, creating it through MakeState() on first use.
+ // On creation, the wide-field slots reserved for this node are pointed at the
+ // state's Values_ entries, so Flow_->FetchValues(ctx, fields) fills Values_.
TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
if (!state.HasValue()) {
- state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
+ MakeState(state, ctx);
+
+ auto& s = *static_cast<TState*>(state.AsBoxed().Get());
+ const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
+ // Only Width_ slots were reserved by IncrementWideFieldsIndex(width), while
+ // Values_ holds max(width, number of aggregators) entries. Binding more than
+ // Width_ slots would write past this node's region of ctx.WideFields.
+ for (size_t i = 0; i < Width_; ++i) {
+ fields[i] = &s.Values_[i];
+ }
+ return s;
}
return *static_cast<TState*>(state.AsBoxed().Get());
}
-
- ui64 GetBatchLength(const NUdf::TUnboxedValue* columns) const {
- return TArrowBlock::From(columns[Width_ - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
- }
-
private:
- IComputationWideFlowNode* Flow_;
- std::optional<ui32> FilterColumn_;
+ IComputationWideFlowNode *const Flow_;
+ const std::optional<ui32> FilterColumn_;
const size_t Width_;
const std::vector<TAggParams<IBlockAggregatorCombineAll>> AggsParams_;
+ const size_t WideFieldsIndex_;
};
template <typename T>