about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author a-romanov <Anton.Romanov@ydb.tech> 2023-08-02 20:21:37 +0300
committer a-romanov <Anton.Romanov@ydb.tech> 2023-08-02 20:21:37 +0300
commit 6e81914eac95d3ff33e3325a4aaacdc571165f88 (patch)
tree f43e0c225508a84acacbc3f50707a54a5c3ef43f
parent f9be57e932fd6758d10c7b96429d9c278f897ebd (diff)
download ydb-6e81914eac95d3ff33e3325a4aaacdc571165f88.tar.gz
YQL-16067 Implement LLVM for wide FromFlow.
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp 165
1 files changed, 143 insertions, 22 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp
index 0a76ee8f57..5d952a1269 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_flow.cpp
@@ -366,8 +366,8 @@ private:
const ui32 TempStateIndex;
};
-class TFromWideFlowWrapper : public TMutableComputationNode<TFromWideFlowWrapper> {
- typedef TMutableComputationNode<TFromWideFlowWrapper> TBaseComputation;
+class TFromWideFlowWrapper : public TCustomValueCodegeneratorNode<TFromWideFlowWrapper> {
+using TBaseComputation = TCustomValueCodegeneratorNode<TFromWideFlowWrapper>;
public:
class TStreamValue : public TComputationValue<TStreamValue> {
public:
@@ -381,11 +381,12 @@ public:
, StubsIndex(stubsIndex)
, ClientBuffer(nullptr)
{}
-
private:
- NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* result, ui32 width) override {
- Y_VERIFY_DEBUG(width == Width);
- auto valuePtrs = CompCtx.WideFields.data() + StubsIndex;
+ NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* result, ui32 width) final {
+ if (width != Width)
+ Throw(width, Width);
+
+ const auto valuePtrs = CompCtx.WideFields.data() + StubsIndex;
if (result != ClientBuffer) {
for (ui32 i = 0; i < width; ++i) {
valuePtrs[i] = result + i;
@@ -393,8 +394,7 @@ public:
ClientBuffer = result;
}
- EFetchResult status = WideFlow->FetchValues(CompCtx, valuePtrs);
- switch (status) {
+ switch (const auto status = WideFlow->FetchValues(CompCtx, valuePtrs)) {
case EFetchResult::Finish:
return NUdf::EFetchStatus::Finish;
case EFetchResult::Yield:
@@ -408,27 +408,147 @@ public:
IComputationWideFlowNode* const WideFlow;
const ui32 Width;
const ui32 StubsIndex;
- NUdf::TUnboxedValue* ClientBuffer;
+ const NUdf::TUnboxedValue* ClientBuffer;
};
- TFromWideFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* wideFlow, ui32 width)
+ class TStreamCodegenValue : public TComputationValue<TStreamCodegenValue> {
+ public:
+ using TBase = TComputationValue<TStreamCodegenValue>;
+ using TWideFetchPtr = NUdf::EFetchStatus (*)(TComputationContext*, NUdf::TUnboxedValuePod*, ui32);
+
+ TStreamCodegenValue(TMemoryUsageInfo* memInfo, TWideFetchPtr fetch, TComputationContext* ctx)
+ : TBase(memInfo), WideFetchFunc(fetch), Ctx(ctx)
+ {}
+ private:
+ NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* result, ui32 width) final {
+ return WideFetchFunc(Ctx, result, width);
+ }
+
+ const TWideFetchPtr WideFetchFunc;
+ TComputationContext* const Ctx;
+ };
+
+ TFromWideFlowWrapper(TComputationMutables& mutables, IComputationWideFlowNode* wideFlow, std::vector<EValueRepresentation>&& representations)
: TBaseComputation(mutables)
, WideFlow(wideFlow)
- , Width(width)
- , StubsIndex(mutables.IncrementWideFieldsIndex(width))
+ , Representations(std::move(representations))
+ , StubsIndex(mutables.IncrementWideFieldsIndex(Representations.size()))
{}
NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
- return ctx.HolderFactory.Create<TStreamValue>(ctx, WideFlow, Width, StubsIndex);
+#ifndef MKQL_DISABLE_CODEGEN
+ if (ctx.ExecuteLLVM && WideFetch)
+ return ctx.HolderFactory.Create<TStreamCodegenValue>(WideFetch, &ctx);
+#endif
+ return ctx.HolderFactory.Create<TStreamValue>(ctx, WideFlow, Representations.size(), StubsIndex);
}
-
private:
void RegisterDependencies() const final {
this->DependsOn(WideFlow);
}
+ [[noreturn]] static void Throw(ui32 requested, ui32 expected) {
+ TStringBuilder res;
+ res << "Requested " << requested << " fields, but expected " << expected;
+ UdfTerminate(res.data());
+ }
+#ifndef MKQL_DISABLE_CODEGEN
+ void GenerateFunctions(const NYql::NCodegen::ICodegen::TPtr& codegen) final {
+ WideFetchFunc = GenerateFetcher(codegen);
+ codegen->ExportSymbol(WideFetchFunc);
+ }
+
+ void FinalizeFunctions(const NYql::NCodegen::ICodegen::TPtr& codegen) final {
+ if (WideFetchFunc)
+ WideFetch = reinterpret_cast<TStreamCodegenValue::TWideFetchPtr>(codegen->GetPointerToFunction(WideFetchFunc));
+ }
+
+ Function* GenerateFetcher(const NYql::NCodegen::ICodegen::TPtr& codegen) const {
+ auto& module = codegen->GetModule();
+ auto& context = codegen->GetContext();
+
+ const auto& name = TBaseComputation::MakeName("WideFetch");
+ if (const auto f = module.getFunction(name.c_str()))
+ return f;
+
+ const auto valueType = Type::getInt128Ty(context);
+ const auto contextType = GetCompContextType(context);
+ const auto statusType = Type::getInt32Ty(context);
+ const auto indexType = Type::getInt32Ty(context);
+ const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), PointerType::getUnqual(valueType), indexType}, false);
+
+ TCodegenContext ctx(codegen);
+ ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
+
+ auto args = ctx.Func->arg_begin();
+
+ ctx.Ctx = &*args;
+ const auto valuesPtr = &*++args;
+ const auto width = &*++args;
+
+ const auto main = BasicBlock::Create(context, "main", ctx.Func);
+ const auto work = BasicBlock::Create(context, "work", ctx.Func);
+ const auto fail = BasicBlock::Create(context, "fail", ctx.Func);
+
+ auto block = main;
+
+ const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, width, ConstantInt::get(width->getType(), Representations.size()), "check", block);
+
+ BranchInst::Create(work, fail, check, block);
+
+ block = work;
+
+ std::vector<Value*> pointers(Representations.size());
+ for (auto i = 0U; i < pointers.size(); ++i) {
+ pointers[i] = GetElementPtrInst::CreateInBounds(valueType, valuesPtr, {ConstantInt::get(indexType, i)}, (TString("ptr_") += ToString(i)).c_str(), block);
+ SafeUnRefUnboxed(pointers[i], ctx, block);
+ }
+
+ const auto getres = GetNodeValues(WideFlow, ctx, block);
+
+ const auto yield = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, getres.first, ConstantInt::get(indexType, static_cast<i32>(EFetchResult::Yield)), "yield", block);
+ const auto special = SelectInst::Create(yield, ConstantInt::get(indexType, static_cast<ui32>(NUdf::EFetchStatus::Yield)), ConstantInt::get(indexType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), "special", block);
+
+ const auto good = BasicBlock::Create(context, "good", ctx.Func);
+ const auto done = BasicBlock::Create(context, "done", ctx.Func);
+
+ const auto result = PHINode::Create(statusType, 2U, "result", done);
+ result->addIncoming(special, block);
+
+ const auto row = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, getres.first, ConstantInt::get(indexType, static_cast<i32>(EFetchResult::One)), "row", block);
+ BranchInst::Create(good, done, row, block);
+
+ block = good;
+
+ for (auto i = 0U; i < pointers.size(); ++i) {
+ auto value = getres.second[i](ctx, block);
+ ValueAddRef(Representations[i], value, ctx, block);
+ new StoreInst(value, pointers[i], block);
+ }
+
+ result->addIncoming(ConstantInt::get(indexType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), block);
+ BranchInst::Create(done, block);
+
+ block = done;
+ ReturnInst::Create(context, result, block);
+
+ block = fail;
+
+ const auto throwFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TFromWideFlowWrapper::Throw));
+ const auto throwFuncType = FunctionType::get(Type::getVoidTy(context), { indexType, indexType }, false);
+ const auto throwFuncPtr = CastInst::Create(Instruction::IntToPtr, throwFunc, PointerType::getUnqual(throwFuncType), "thrower", block);
+ CallInst::Create(throwFuncType, throwFuncPtr, { width, ConstantInt::get(width->getType(), Representations.size()) }, "", block)->setTailCall();
+ new UnreachableInst(context, block);
+
+ return ctx.Func;
+ }
+
+ Function* WideFetchFunc = nullptr;
+
+ TStreamCodegenValue::TWideFetchPtr WideFetch = nullptr;
+#endif
IComputationWideFlowNode* const WideFlow;
- const ui32 Width;
+ const std::vector<EValueRepresentation> Representations;
const ui32 StubsIndex;
};
@@ -460,14 +580,15 @@ IComputationNode* WrapToFlow(TCallable& callable, const TComputationNodeFactoryC
IComputationNode* WrapFromFlow(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
- const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
- if (flowType->GetItemType()->IsMulti()) {
- auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
- MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
- ui32 width = GetWideComponentsCount(flowType);
- return new TFromWideFlowWrapper(ctx.Mutables, wideFlow, width);
+ const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
+ if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
+ const auto multiType = AS_TYPE(TMultiType, AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType())->GetItemType());
+ std::vector<EValueRepresentation> outputRepresentations(multiType->GetElementsCount());
+ for (auto i = 0U; i < outputRepresentations.size(); ++i)
+ outputRepresentations[i] = GetValueRepresentation(multiType->GetElementType(i));
+ return new TFromWideFlowWrapper(ctx.Mutables, wide, std::move(outputRepresentations));
}
- return new TFromFlowWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
+ return new TFromFlowWrapper(ctx.Mutables, flow);
}
}