diff options
author | a-romanov <Anton.Romanov@ydb.tech> | 2023-09-26 11:42:29 +0300 |
---|---|---|
committer | a-romanov <Anton.Romanov@ydb.tech> | 2023-09-26 12:22:16 +0300 |
commit | 0f66bece4999d5dbabdb090d47fc9a240f856730 (patch) | |
tree | 293a7c289727cfcbca412aaf3e7b65120d0359b6 | |
parent | 7fcf8b470c7f4b97acf6f2833afa770c1372f231 (diff) | |
download | ydb-0f66bece4999d5dbabdb090d47fc9a240f856730.tar.gz |
YQL-15891 LLVM for some blocks nodes.
4 files changed, 331 insertions, 49 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp index ebe61f6ebf..ed79f31732 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp @@ -271,15 +271,13 @@ private: struct TState : public TBlockState { size_t Rows_ = 0; bool IsFinished_ = false; - NUdf::TUnboxedValue**const Fields_; std::vector<std::unique_ptr<IArrayBuilder>> Builders_; TState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const TVector<TType*>& types, size_t maxLength, NUdf::TUnboxedValue**const fields) : TBlockState(memInfo, types.size() + 1U) - , Fields_(fields) , Builders_(types.size()) { for (size_t i = 0; i < types.size(); ++i) { - Fields_[i] = &Values[i]; + fields[i] = &Values[i]; Builders_[i] = MakeArrayBuilder(TTypeInfoHelper(), types[i], ctx.ArrowMemoryPool, maxLength, &ctx.Builder->GetPgBuilder()); } } @@ -306,6 +304,7 @@ private: private: using TBase = TLLVMFieldsStructure<TComputationValue<TState>>; llvm::IntegerType*const CountType; + llvm::PointerType*const PointerType; llvm::ArrayType*const SkipSpaceType; llvm::IntegerType*const RowsType; llvm::IntegerType*const IsFinishedType; @@ -315,6 +314,7 @@ private: std::vector<llvm::Type*> GetFieldsArray() { std::vector<llvm::Type*> result = TBase::GetFields(); result.emplace_back(CountType); + result.emplace_back(PointerType); result.emplace_back(SkipSpaceType); result.emplace_back(RowsType); result.emplace_back(IsFinishedType); @@ -325,17 +325,22 @@ private: return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0); } + llvm::Constant* GetPointer() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1); + } + llvm::Constant* GetRows() { - return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 2); + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3); } llvm::Constant* GetIsFinished() { - return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3); + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 4); } TLLVMFieldsStructureState(llvm::LLVMContext& context) : TBase(context) , CountType(Type::getInt64Ty(Context)) + , PointerType(PointerType::getUnqual(Type::getInt128Ty(Context))) , SkipSpaceType(ArrayType::get(Type::getInt128Ty(Context), 3U)) // Skip std::vectors Values & Arrays , RowsType(Type::getInt64Ty(Context)) , IsFinishedType(Type::getInt1Ty(Context)) @@ -821,10 +826,11 @@ private: const std::vector<arrow::ValueDescr> EmptyDesc_; }; -class TAsScalarWrapper : public TMutableComputationNode<TAsScalarWrapper> { +class TAsScalarWrapper : public TMutableCodegeneratorNode<TAsScalarWrapper> { +using TBaseComputation = TMutableCodegeneratorNode<TAsScalarWrapper>; public: TAsScalarWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* type) - : TMutableComputationNode(mutables) + : TBaseComputation(mutables, EValueRepresentation::Boxed) , Arg_(arg) , Type_(type) { @@ -833,31 +839,59 @@ public: } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - auto value = Arg_->GetValue(ctx); - arrow::Datum result = ConvertScalar(Type_, value, ctx.ArrowMemoryPool); - return ctx.HolderFactory.CreateArrowBlock(std::move(result)); + return AsScalar(Arg_->GetValue(ctx).Release(), ctx); } +#ifndef MKQL_DISABLE_CODEGEN + Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { + auto& context = ctx.Codegen.GetContext(); + + const auto value = GetNodeValue(Arg_, ctx, block); + const auto ptrType = PointerType::getUnqual(StructType::get(context)); + const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); + const auto asScalarFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TAsScalarWrapper::AsScalar)); + + if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) { + const auto asScalarType = FunctionType::get(Type::getInt128Ty(context), {self->getType(), value->getType(), ctx.Ctx->getType()}, false); + const auto asScalarFuncPtr = CastInst::Create(Instruction::IntToPtr, asScalarFunc, PointerType::getUnqual(asScalarType), "function", block); + return CallInst::Create(asScalarType, asScalarFuncPtr, {self, value, ctx.Ctx}, "scalar", block); + } else { + const auto valuePtr = new AllocaInst(value->getType(), 0U, "value", block); + new StoreInst(value, valuePtr, block); + const auto asScalarType = FunctionType::get(Type::getVoidTy(context), {self->getType(), valuePtr->getType(), valuePtr->getType(), ctx.Ctx->getType()}, false); + const auto asScalarFuncPtr = CastInst::Create(Instruction::IntToPtr, asScalarFunc, PointerType::getUnqual(asScalarType), "function", block); + CallInst::Create(asScalarType, asScalarFuncPtr, {self, valuePtr, valuePtr, ctx.Ctx}, "", block); + return new LoadInst(value->getType(), valuePtr, "result", block); + } + } +#endif +private: std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final { - auto value = Arg_->GetValue(ctx); - arrow::Datum result = ConvertScalar(Type_, value, ctx.ArrowMemoryPool); - return std::make_unique<TPrecomputedArrowNode>(result, "AsScalar"); + return std::make_unique<TPrecomputedArrowNode>(DoAsScalar(Arg_->GetValue(ctx).Release(), ctx), "AsScalar"); + } + + arrow::Datum DoAsScalar(const NUdf::TUnboxedValuePod value, TComputationContext& ctx) const { + const NUdf::TUnboxedValue v(value); + return ConvertScalar(Type_, v, ctx.ArrowMemoryPool); + } + + NUdf::TUnboxedValuePod AsScalar(const NUdf::TUnboxedValuePod value, TComputationContext& ctx) const { + return ctx.HolderFactory.CreateArrowBlock(DoAsScalar(value, ctx)); } -private: void RegisterDependencies() const final { DependsOn(Arg_); } -private: IComputationNode* const Arg_; TType* Type_; }; -class TReplicateScalarWrapper : public TMutableComputationNode<TReplicateScalarWrapper> { +class TReplicateScalarWrapper : public TMutableCodegeneratorNode<TReplicateScalarWrapper> { +using TBaseComputation = TMutableCodegeneratorNode<TReplicateScalarWrapper>; public: TReplicateScalarWrapper(TComputationMutables& mutables, IComputationNode* value, IComputationNode* count, TType* type) - : TMutableComputationNode(mutables) + : TBaseComputation(mutables, EValueRepresentation::Boxed) , Value_(value) , Count_(count) , Type_(type) @@ -867,58 +901,304 @@ public: } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - return ctx.HolderFactory.CreateArrowBlock(DoReplicate(ctx)); + const auto value = Value_->GetValue(ctx).Release(); + const auto count = Count_->GetValue(ctx).Release(); + return Replicate(value, count, ctx); } +#ifndef MKQL_DISABLE_CODEGEN + Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { + auto& context = ctx.Codegen.GetContext(); + const auto value = GetNodeValue(Value_, ctx, block); + const auto count = GetNodeValue(Count_, ctx, block); + + const auto ptrType = PointerType::getUnqual(StructType::get(context)); + const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); + const auto replicateFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TReplicateScalarWrapper::Replicate)); + + if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) { + const auto replicateType = FunctionType::get(Type::getInt128Ty(context), {self->getType(), value->getType(), count->getType(), ctx.Ctx->getType()}, false); + const auto replicateFuncPtr = CastInst::Create(Instruction::IntToPtr, replicateFunc, PointerType::getUnqual(replicateType), "function", block); + return CallInst::Create(replicateType, replicateFuncPtr, {self, value, count, ctx.Ctx}, "replicate", block); + } else { + const auto valuePtr = new AllocaInst(value->getType(), 0U, "value", block); + const auto countPtr = new AllocaInst(count->getType(), 0U, "count", block); + new StoreInst(value, valuePtr, block); + new StoreInst(count, countPtr, block); + const auto replicateType = FunctionType::get(Type::getVoidTy(context), {self->getType(), valuePtr->getType(), valuePtr->getType(), countPtr->getType(), ctx.Ctx->getType()}, false); + const auto replicateFuncPtr = CastInst::Create(Instruction::IntToPtr, replicateFunc, PointerType::getUnqual(replicateType), "function", block); + CallInst::Create(replicateType, replicateFuncPtr, {self, valuePtr, valuePtr, countPtr, ctx.Ctx}, "", block); + return new LoadInst(value->getType(), valuePtr, "result", block); + } + } +#endif +private: std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final { - return std::make_unique<TPrecomputedArrowNode>(DoReplicate(ctx), "ReplicateScalar"); + const auto value = Value_->GetValue(ctx).Release(); + const auto count = Count_->GetValue(ctx).Release(); + return std::make_unique<TPrecomputedArrowNode>(DoReplicate(value, count, ctx), "ReplicateScalar"); } -private: - arrow::Datum DoReplicate(TComputationContext& ctx) const { - auto value = TArrowBlock::From(Value_->GetValue(ctx)).GetDatum().scalar(); - const ui64 count = TArrowBlock::From(Count_->GetValue(ctx)).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + arrow::Datum DoReplicate(const NUdf::TUnboxedValuePod val, const NUdf::TUnboxedValuePod cnt, TComputationContext& ctx) const { + const NUdf::TUnboxedValue v(val), c(cnt); + const auto value = TArrowBlock::From(v).GetDatum().scalar(); + const ui64 count = TArrowBlock::From(c).GetDatum().scalar_as<arrow::UInt64Scalar>().value; - auto reader = MakeBlockReader(TTypeInfoHelper(), Type_); - auto builder = MakeArrayBuilder(TTypeInfoHelper(), Type_, ctx.ArrowMemoryPool, count, &ctx.Builder->GetPgBuilder()); + const auto reader = MakeBlockReader(TTypeInfoHelper(), Type_); + const auto builder = MakeArrayBuilder(TTypeInfoHelper(), Type_, ctx.ArrowMemoryPool, count, &ctx.Builder->GetPgBuilder()); TBlockItem item = reader->GetScalarItem(*value); builder->Add(item, count); return builder->Build(true); } + NUdf::TUnboxedValuePod Replicate(const NUdf::TUnboxedValuePod value, const NUdf::TUnboxedValuePod count, TComputationContext& ctx) const { + return ctx.HolderFactory.CreateArrowBlock(DoReplicate(value, count, ctx)); + } + void RegisterDependencies() const final { DependsOn(Value_); DependsOn(Count_); } -private: IComputationNode* const Value_; IComputationNode* const Count_; TType* Type_; }; -class TBlockExpandChunkedWrapper : public TStatefulWideFlowBlockComputationNode<TBlockExpandChunkedWrapper> { +class TBlockExpandChunkedWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedWrapper> { +using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedWrapper>; public: - TBlockExpandChunkedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, ui32 width) - : TStatefulWideFlowBlockComputationNode(mutables, flow, width) + TBlockExpandChunkedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, size_t width) + : TBaseComputation(mutables, flow, EValueRepresentation::Boxed) , Flow_(flow) + , Width_(width) + , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(Width_)) { } - EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, - NUdf::TUnboxedValue*const* output) const - { - Y_UNUSED(state); - return Flow_->FetchValues(ctx, output); + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + auto& s = GetState(state, ctx); + const auto fields = ctx.WideFields.data() + WideFieldsIndex_; + for (size_t i = 0; i < Width_; ++i) + fields[i] = output[i] ? &s.Values[i] : nullptr; + + while (!s.Count) { + s.Values.assign(s.Values.size(), NUdf::TUnboxedValuePod()); + if (const auto result = Flow_->FetchValues(ctx, fields); result != EFetchResult::One) + return result; + s.FillArrays(); + } + + const auto sliceSize = s.Slice(); + for (size_t i = 0; i < Width_; ++i) { + if (const auto out = output[i]) { + *out = s.Get(sliceSize, ctx.HolderFactory, i); + } + } + return EFetchResult::One; } +#ifndef MKQL_DISABLE_CODEGEN + ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { + auto& context = ctx.Codegen.GetContext(); + const auto valueType = Type::getInt128Ty(context); + const auto ptrValueType = PointerType::getUnqual(valueType); + const auto statusType = Type::getInt32Ty(context); + const auto indexType = Type::getInt64Ty(context); + const auto arrayType = ArrayType::get(valueType, Width_); + const auto ptrValuesType = PointerType::getUnqual(ArrayType::get(valueType, Width_)); + + TLLVMFieldsStructureState stateFields(context, Width_); + const auto stateType = StructType::get(context, stateFields.GetFieldsArray()); + const auto statePtrType = PointerType::getUnqual(stateType); + + const auto atTop = &ctx.Func->getEntryBlock().back(); + + const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::Get)); + const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false); + const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop); + + const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop); + const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop); + + new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop); + new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop); + + const auto name = "GetCount"; + ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&GetCount)); + const auto getCountType = NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget() ? + FunctionType::get(indexType, { valueType }, false): + FunctionType::get(indexType, { ptrValueType }, false); + const auto getCount = ctx.Codegen.GetModule().getOrInsertFunction(name, getCountType); + + const auto make = BasicBlock::Create(context, "make", ctx.Func); + const auto main = BasicBlock::Create(context, "main", ctx.Func); + const auto loop = BasicBlock::Create(context, "loop", ctx.Func); + const auto more = BasicBlock::Create(context, "more", ctx.Func); + const auto work = BasicBlock::Create(context, "work", ctx.Func); + const auto fill = BasicBlock::Create(context, "fill", ctx.Func); + const auto over = BasicBlock::Create(context, "over", ctx.Func); + + BranchInst::Create(main, make, HasValue(statePtr, block), block); + block = make; + + const auto ptrType = PointerType::getUnqual(StructType::get(context)); + const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); + const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockExpandChunkedWrapper::MakeState)); + const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false); + const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block); + CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block); + BranchInst::Create(main, block); + + block = main; + + const auto state = new LoadInst(valueType, statePtr, "state", block); + const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block); + const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block); + + const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block); + + BranchInst::Create(loop, block); + + block = loop; + + const auto count = new LoadInst(indexType, countPtr, "count", block); + + const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "next", block); + + BranchInst::Create(more, fill, next, block); + + block = more; + + const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block); + const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block); + SafeUnRefUnboxed(values, ctx, block); + + const auto getres = GetNodeValues(Flow_, ctx, block); + + const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block); + + const auto result = PHINode::Create(statusType, 2U, "result", over); + result->addIncoming(getres.first, block); + + BranchInst::Create(over, work, special, block); + + block = work; + + const auto countValue = getres.second.back()(ctx, block); + const auto tailPtr = GetElementPtrInst::CreateInBounds(arrayType, values, { ConstantInt::get(indexType, 0), ConstantInt::get(indexType, Width_ - 1U) }, "tail_ptr", block); + new StoreInst(countValue, tailPtr, block); + AddRefBoxed(countValue, ctx, block); + + const auto height = CallInst::Create(getCount, { WrapArgumentForWindows(countValue, ctx, block) }, "height", block); + new StoreInst(height, countPtr, block); + + const auto makeBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::FillArrays)); + const auto makeBlockType = FunctionType::get(indexType, {statePtrType}, false); + const auto makeBlockPtr = CastInst::Create(Instruction::IntToPtr, makeBlockFunc, PointerType::getUnqual(makeBlockType), "fill_arrays_func", block); + CallInst::Create(makeBlockType, makeBlockPtr, {stateArg}, "", block); + + BranchInst::Create(loop, block); + + block = fill; + + const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::Slice)); + const auto sliceType = FunctionType::get(indexType, {statePtrType}, false); + const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block); + const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block); + new StoreInst(slice, heightPtr, block); + new StoreInst(stateArg, stateOnStack, block); + + result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block); + + BranchInst::Create(over, block); + + block = over; + + ICodegeneratorInlineWideNode::TGettersList getters(Width_); + for (size_t idx = 0U; idx < getters.size(); ++idx) { + getters[idx] = [idx, width = Width_, getType, getPtr, heightPtr, indexType, arrayType, ptrValuesType, stateType, statePtrType, stateOnStack, getBlocks = getres.second](const TCodegenContext& ctx, BasicBlock*& block) { + auto& context = ctx.Codegen.GetContext(); + const auto init = BasicBlock::Create(context, "init", ctx.Func); + const auto call = BasicBlock::Create(context, "call", ctx.Func); + + TLLVMFieldsStructureState stateFields(context, width); + + const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block); + const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block); + const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block); + const auto index = ConstantInt::get(indexType, idx); + const auto pointer = GetElementPtrInst::CreateInBounds(arrayType, values, { ConstantInt::get(indexType, 0), index }, "pointer", block); + + BranchInst::Create(call, init, HasValue(pointer, block), block); + + block = init; + + const auto value = getBlocks[idx](ctx, block); + new StoreInst(value, pointer, block); + AddRefBoxed(value, ctx, block); + + BranchInst::Create(call, block); + + block = call; + + const auto heightArg = new LoadInst(indexType, heightPtr, "height", block); + return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), index}, "get", block); + }; + } + return {result, std::move(getters)}; + } +#endif private: +#ifndef MKQL_DISABLE_CODEGEN + class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> { + private: + using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>; + llvm::IntegerType*const CountType; + llvm::PointerType*const PointerType; + protected: + using TBase::Context; + public: + std::vector<llvm::Type*> GetFieldsArray() { + std::vector<llvm::Type*> result = TBase::GetFields(); + result.emplace_back(CountType); + result.emplace_back(PointerType); + return result; + } + + llvm::Constant* GetCount() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0); + } + + llvm::Constant* GetPointer() { + return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1); + } + + TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width) + : TBase(context) + , CountType(Type::getInt64Ty(Context)) + , PointerType(PointerType::getUnqual(ArrayType::get(Type::getInt128Ty(Context), width))) + {} + }; +#endif + void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const { + state = ctx.HolderFactory.Create<TBlockState>(Width_); + } + + TBlockState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { + if (!state.HasValue()) + MakeState(ctx, state); + return *static_cast<TBlockState*>(state.AsBoxed().Get()); + } + void RegisterDependencies() const final { FlowDependsOn(Flow_); } IComputationWideFlowNode* const Flow_; + const size_t Width_; + const size_t WideFieldsIndex_; }; } // namespace @@ -935,7 +1215,7 @@ IComputationNode* WrapWideToBlocks(TCallable& callable, const TComputationNodeFa const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); const auto wideComponents = GetWideComponents(flowType); TVector<TType*> items(wideComponents.begin(), wideComponents.end()); - auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); return new TWideToBlocksWrapper(ctx.Mutables, wideFlow, std::move(items)); @@ -961,9 +1241,8 @@ IComputationNode* WrapWideFromBlocks(TCallable& callable, const TComputationNode items.push_back(blockType->GetItemType()); } - auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); - return new TWideFromBlocksWrapper(ctx.Mutables, wideFlow, std::move(items)); } @@ -979,8 +1258,8 @@ IComputationNode* WrapReplicateScalar(TCallable& callable, const TComputationNod const auto valueType = AS_TYPE(TBlockType, callable.GetInput(0).GetStaticType()); MKQL_ENSURE(valueType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as first arg"); - auto value = LocateNode(ctx.NodeLocator, callable, 0); - auto count = LocateNode(ctx.NodeLocator, callable, 1); + const auto value = LocateNode(ctx.NodeLocator, callable, 0); + const auto count = LocateNode(ctx.NodeLocator, callable, 1); return new TReplicateScalarWrapper(ctx.Mutables, value, count, valueType->GetItemType()); } @@ -990,9 +1269,8 @@ IComputationNode* WrapBlockExpandChunked(TCallable& callable, const TComputation const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); const auto wideComponents = GetWideComponents(flowType); - auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); - return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size()); } diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp index 542031997e..1b206c5a28 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp @@ -150,9 +150,10 @@ Y_UNIT_TEST_LLVM(TestWideToBlocks) { } namespace { +template<bool LLVM> void TestChunked(bool withBlockExpand) { - TSetup<false> setup; - auto& pb = *setup.PgmBuilder; + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id); const auto boolType = pb.NewDataType(NUdf::TDataType<bool>::Id); @@ -239,12 +240,12 @@ void TestChunked(bool withBlockExpand) { } // namespace -Y_UNIT_TEST(TestBlockExpandChunked) { - TestChunked(true); +Y_UNIT_TEST_LLVM(TestBlockExpandChunked) { + TestChunked<LLVM>(true); } -Y_UNIT_TEST(TestWideFromBlocksForChunked) { - TestChunked(false); +Y_UNIT_TEST_LLVM(TestWideFromBlocksForChunked) { + TestChunked<LLVM>(false); } Y_UNIT_TEST(TestScalar) { diff --git a/ydb/library/yql/minikql/computation/mkql_block_impl.cpp b/ydb/library/yql/minikql/computation/mkql_block_impl.cpp index 443cb4b970..bc7447bc71 100644 --- a/ydb/library/yql/minikql/computation/mkql_block_impl.cpp +++ b/ydb/library/yql/minikql/computation/mkql_block_impl.cpp @@ -245,10 +245,11 @@ const IComputationNode* TBlockFuncNode::TArrowNode::GetArgument(ui32 index) cons return Parent_->ArgsNodes[index]; } - TBlockState::TBlockState(TMemoryUsageInfo* memInfo, size_t width) : TBase(memInfo), Values(width), Arrays(width - 1ULL) -{} +{ + Pointer_ = Values.data(); +} void TBlockState::FillArrays() { auto& counterDatum = TArrowBlock::From(Values.back()).GetDatum(); diff --git a/ydb/library/yql/minikql/computation/mkql_block_impl.h b/ydb/library/yql/minikql/computation/mkql_block_impl.h index 383108430c..bc06810e02 100644 --- a/ydb/library/yql/minikql/computation/mkql_block_impl.h +++ b/ydb/library/yql/minikql/computation/mkql_block_impl.h @@ -87,6 +87,8 @@ struct TBlockState : public TComputationValue<TBlockState> { using TBase = TComputationValue<TBlockState>; ui64 Count = 0; + NUdf::TUnboxedValue* Pointer_ = nullptr; + TUnboxedValueVector Values; std::vector<std::deque<std::shared_ptr<arrow::ArrayData>>> Arrays; |