diff options
author | a-romanov <Anton.Romanov@ydb.tech> | 2023-08-02 20:21:37 +0300 |
---|---|---|
committer | a-romanov <Anton.Romanov@ydb.tech> | 2023-08-02 20:21:37 +0300 |
commit | 6e81914eac95d3ff33e3325a4aaacdc571165f88 (patch) | |
tree | f43e0c225508a84acacbc3f50707a54a5c3ef43f | |
parent | f9be57e932fd6758d10c7b96429d9c278f897ebd (diff) | |
download | ydb-6e81914eac95d3ff33e3325a4aaacdc571165f88.tar.gz |
YQL-16067 Implement LLVM for wide FromFlow.
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp | 165 |
1 files changed, 143 insertions, 22 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp index 0a76ee8f57..5d952a1269 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp @@ -366,8 +366,8 @@ private: const ui32 TempStateIndex; }; -class TFromWideFlowWrapper : public TMutableComputationNode<TFromWideFlowWrapper> { - typedef TMutableComputationNode<TFromWideFlowWrapper> TBaseComputation; +class TFromWideFlowWrapper : public TCustomValueCodegeneratorNode<TFromWideFlowWrapper> { +using TBaseComputation = TCustomValueCodegeneratorNode<TFromWideFlowWrapper>; public: class TStreamValue : public TComputationValue<TStreamValue> { public: @@ -381,11 +381,12 @@ public: , StubsIndex(stubsIndex) , ClientBuffer(nullptr) {} - private: - NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* result, ui32 width) override { - Y_VERIFY_DEBUG(width == Width); - auto valuePtrs = CompCtx.WideFields.data() + StubsIndex; + NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* result, ui32 width) final { + if (width != Width) + Throw(width, Width); + + const auto valuePtrs = CompCtx.WideFields.data() + StubsIndex; if (result != ClientBuffer) { for (ui32 i = 0; i < width; ++i) { valuePtrs[i] = result + i; @@ -393,8 +394,7 @@ public: ClientBuffer = result; } - EFetchResult status = WideFlow->FetchValues(CompCtx, valuePtrs); - switch (status) { + switch (const auto status = WideFlow->FetchValues(CompCtx, valuePtrs)) { case EFetchResult::Finish: return NUdf::EFetchStatus::Finish; case EFetchResult::Yield: @@ -408,27 +408,147 @@ public: IComputationWideFlowNode* const WideFlow; const ui32 Width; const ui32 StubsIndex; - NUdf::TUnboxedValue* ClientBuffer; + const NUdf::TUnboxedValue* ClientBuffer; }; - TFromWideFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* wideFlow, ui32 width) + class TStreamCodegenValue : public TComputationValue<TStreamCodegenValue> { + public: + using TBase = TComputationValue<TStreamCodegenValue>; + using TWideFetchPtr = NUdf::EFetchStatus (*)(TComputationContext*, NUdf::TUnboxedValuePod*, ui32); + + TStreamCodegenValue(TMemoryUsageInfo* memInfo, TWideFetchPtr fetch, TComputationContext* ctx) + : TBase(memInfo), WideFetchFunc(fetch), Ctx(ctx) + {} + private: + NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* result, ui32 width) final { + return WideFetchFunc(Ctx, result, width); + } + + const TWideFetchPtr WideFetchFunc; + TComputationContext* const Ctx; + }; + + TFromWideFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* wideFlow, std::vector<EValueRepresentation>&& representations) : TBaseComputation(mutables) , WideFlow(wideFlow) - , Width(width) - , StubsIndex(mutables.IncrementWideFieldsIndex(width)) + , Representations(std::move(representations)) + , StubsIndex(mutables.IncrementWideFieldsIndex(Representations.size())) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - return ctx.HolderFactory.Create<TStreamValue>(ctx, WideFlow, Width, StubsIndex); +#ifndef MKQL_DISABLE_CODEGEN + if (ctx.ExecuteLLVM && WideFetch) + return ctx.HolderFactory.Create<TStreamCodegenValue>(WideFetch, &ctx); +#endif + return ctx.HolderFactory.Create<TStreamValue>(ctx, WideFlow, Representations.size(), StubsIndex); } - private: void RegisterDependencies() const final { this->DependsOn(WideFlow); } + [[noreturn]] static void Throw(ui32 requested, ui32 expected) { + TStringBuilder res; + res << "Requested " << requested << " fields, but expected " << expected; + UdfTerminate(res.data()); + } +#ifndef MKQL_DISABLE_CODEGEN + void GenerateFunctions(const NYql::NCodegen::ICodegen::TPtr& codegen) final { + WideFetchFunc = GenerateFetcher(codegen); + codegen->ExportSymbol(WideFetchFunc); + } + + void FinalizeFunctions(const NYql::NCodegen::ICodegen::TPtr& codegen) final { + if (WideFetchFunc) + WideFetch = reinterpret_cast<TStreamCodegenValue::TWideFetchPtr>(codegen->GetPointerToFunction(WideFetchFunc)); + } + + Function* GenerateFetcher(const NYql::NCodegen::ICodegen::TPtr& codegen) const { + auto& module = codegen->GetModule(); + auto& context = codegen->GetContext(); + + const auto& name = TBaseComputation::MakeName("WideFetch"); + if (const auto f = module.getFunction(name.c_str())) + return f; + + const auto valueType = Type::getInt128Ty(context); + const auto contextType = GetCompContextType(context); + const auto statusType = Type::getInt32Ty(context); + const auto indexType = Type::getInt32Ty(context); + const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), PointerType::getUnqual(valueType), indexType}, false); + + TCodegenContext ctx(codegen); + ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee()); + + auto args = ctx.Func->arg_begin(); + + ctx.Ctx = &*args; + const auto valuesPtr = &*++args; + const auto width = &*++args; + + const auto main = BasicBlock::Create(context, "main", ctx.Func); + const auto work = BasicBlock::Create(context, "work", ctx.Func); + const auto fail = BasicBlock::Create(context, "fail", ctx.Func); + + auto block = main; + + const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, width, ConstantInt::get(width->getType(), Representations.size()), "check", block); + + BranchInst::Create(work, fail, check, block); + + block = work; + + std::vector<Value*> pointers(Representations.size()); + for (auto i = 0U; i < pointers.size(); ++i) { + pointers[i] = GetElementPtrInst::CreateInBounds(valueType, valuesPtr, {ConstantInt::get(indexType, i)}, (TString("ptr_") += ToString(i)).c_str(), block); + SafeUnRefUnboxed(pointers[i], ctx, block); + } + + const auto getres = GetNodeValues(WideFlow, ctx, block); + + const auto yield = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, getres.first, ConstantInt::get(indexType, static_cast<i32>(EFetchResult::Yield)), "yield", block); + const auto special = SelectInst::Create(yield, ConstantInt::get(indexType, static_cast<ui32>(NUdf::EFetchStatus::Yield)), ConstantInt::get(indexType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), "special", block); + + const auto good = BasicBlock::Create(context, "good", ctx.Func); + const auto done = BasicBlock::Create(context, "done", ctx.Func); + + const auto result = PHINode::Create(statusType, 2U, "result", done); + result->addIncoming(special, block); + + const auto row = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, getres.first, ConstantInt::get(indexType, static_cast<i32>(EFetchResult::One)), "row", block); + BranchInst::Create(good, done, row, block); + + block = good; + + for (auto i = 0U; i < pointers.size(); ++i) { + auto value = getres.second[i](ctx, block); + ValueAddRef(Representations[i], value, ctx, block); + new StoreInst(value, pointers[i], block); + } + + result->addIncoming(ConstantInt::get(indexType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), block); + BranchInst::Create(done, block); + + block = done; + ReturnInst::Create(context, result, block); + + block = fail; + + const auto throwFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TFromWideFlowWrapper::Throw)); + const auto throwFuncType = FunctionType::get(Type::getVoidTy(context), { indexType, indexType }, false); + const auto throwFuncPtr = CastInst::Create(Instruction::IntToPtr, throwFunc, PointerType::getUnqual(throwFuncType), "thrower", block); + CallInst::Create(throwFuncType, throwFuncPtr, { width, ConstantInt::get(width->getType(), Representations.size()) }, "", block)->setTailCall(); + new UnreachableInst(context, block); + + return ctx.Func; + } + + Function* WideFetchFunc = nullptr; + + TStreamCodegenValue::TWideFetchPtr WideFetch = nullptr; +#endif IComputationWideFlowNode* const WideFlow; - const ui32 Width; + const std::vector<EValueRepresentation> Representations; const ui32 StubsIndex; }; @@ -460,14 +580,15 @@ IComputationNode* WrapToFlow(TCallable& callable, const TComputationNodeFactoryC IComputationNode* WrapFromFlow(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount()); - const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); - if (flowType->GetItemType()->IsMulti()) { - auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); - MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); - ui32 width = GetWideComponentsCount(flowType); - return new TFromWideFlowWrapper(ctx.Mutables, wideFlow, width); + const auto flow = LocateNode(ctx.NodeLocator, callable, 0); + if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) { + const auto multiType = AS_TYPE(TMultiType, AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType())->GetItemType()); + std::vector<EValueRepresentation> outputRepresentations(multiType->GetElementsCount()); + for (auto i = 0U; i < outputRepresentations.size(); ++i) + outputRepresentations[i] = GetValueRepresentation(multiType->GetElementType(i)); + return new TFromWideFlowWrapper(ctx.Mutables, wide, std::move(outputRepresentations)); } - return new TFromFlowWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0)); + return new TFromFlowWrapper(ctx.Mutables, flow); } } |