about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Vitaly Stoyan <vvvv@ydb.tech> 2024-07-08 14:34:05 +0300
committer GitHub <noreply@github.com> 2024-07-08 16:34:05 +0500
commit d20552c258db9a1b1212f16be58b617ce6945b2a (patch)
tree b375f86eddb0eba6ef7c7c34c8119c98ba76253c
parent 7073beb9ee3c43199529046c0b561da135185675 (diff)
download ydb-d20552c258db9a1b1212f16be58b617ce6945b2a.tar.gz
Less LLVM in block agg (#6359)
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp 429
1 files changed, 230 insertions, 199 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
index 0de21c6be7..57d5d09fe2 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp
@@ -456,72 +456,57 @@ size_t CalcMaxBlockLenForOutput(TType* out) {
return CalcBlockLen(maxBlockItemSize);
}
+class TBlockCombineAllWrapperCodegenBase {
+protected:
+#ifndef MKQL_DISABLE_CODEGEN
+ class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> {
+ private:
+ using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>;
+ llvm::PointerType*const PointerType;
+ llvm::IntegerType*const IsFinishedType;
+ public:
+ std::vector<llvm::Type*> GetFieldsArray() {
+ std::vector<llvm::Type*> result = TBase::GetFields();
+ result.emplace_back(PointerType);
+ result.emplace_back(IsFinishedType);
+ return result;
+ }
-class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper> {
-using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>;
-public:
- TBlockCombineAllWrapper(TComputationMutables& mutables,
- IComputationWideFlowNode* flow,
- std::optional<ui32> filterColumn,
- size_t width,
- std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams)
- : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
- , Flow_(flow)
- , FilterColumn_(filterColumn)
- , Width_(width)
- , AggsParams_(std::move(aggsParams))
- , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
- {
- MKQL_ENSURE(Width_ > 0, "Missing block length column");
- }
-
- EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
- TComputationContext& ctx,
- NUdf::TUnboxedValue*const* output) const
- {
- auto& s = GetState(state, ctx);
- if (s.IsFinished_)
- return EFetchResult::Finish;
+ llvm::Constant* GetPointer() {
+ return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
+ }
- for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) {
- switch (Flow_->FetchValues(ctx, fields)) {
- case EFetchResult::Yield:
- return EFetchResult::Yield;
- case EFetchResult::One:
- s.ProcessInput();
- continue;
- case EFetchResult::Finish:
- break;
- }
- if (s.MakeOutput()) {
- for (size_t i = 0; i < AggsParams_.size(); ++i) {
- if (const auto out = output[i]) {
- *out = s.Get(i);
- }
- }
- return EFetchResult::One;
- }
- return EFetchResult::Finish;
+ llvm::Constant* GetIsFinished() {
+ return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
}
- }
-#ifndef MKQL_DISABLE_CODEGEN
- ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
+
+ TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
+ : TBase(context)
+ , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width)))
+ , IsFinishedType(llvm::Type::getInt1Ty(Context))
+ {}
+ };
+
+ ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValuesImpl(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block,
+ IComputationWideFlowNode* flow, size_t width, size_t aggCount,
+ uintptr_t getStateMethodPtr, uint64_t makeStateMethodPtr,
+ uintptr_t processInputMethodPtr, uintptr_t makeOutputMethodPtr) const {
auto& context = ctx.Codegen.GetContext();
const auto valueType = Type::getInt128Ty(context);
const auto statusType = Type::getInt32Ty(context);
const auto indexType = Type::getInt64Ty(context);
const auto flagType = Type::getInt1Ty(context);
- const auto arrayType = ArrayType::get(valueType, Width_);
+ const auto arrayType = ArrayType::get(valueType, width);
const auto ptrValuesType = PointerType::getUnqual(arrayType);
- TLLVMFieldsStructureState stateFields(context, Width_);
+ TLLVMFieldsStructureState stateFields(context, width);
const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
const auto statePtrType = PointerType::getUnqual(stateType);
const auto atTop = &ctx.Func->getEntryBlock().back();
- const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get));
+ const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), getStateMethodPtr);
const auto getType = FunctionType::get(valueType, {statePtrType, indexType}, false);
const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
@@ -540,7 +525,7 @@ public:
const auto ptrType = PointerType::getUnqual(StructType::get(context));
const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
- const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockCombineAllWrapper::MakeState));
+ const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), makeStateMethodPtr);
const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false);
const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block);
@@ -567,7 +552,7 @@ public:
const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
SafeUnRefUnboxed(values, ctx, block);
- const auto getres = GetNodeValues(Flow_, ctx, block);
+ const auto getres = GetNodeValues(flow, ctx, block);
result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
const auto way = SwitchInst::Create(getres.first, good, 2U, block);
@@ -584,7 +569,7 @@ public:
}
new StoreInst(array, values, block);
- const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput));
+ const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), processInputMethodPtr);
const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false);
const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block);
CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
@@ -593,7 +578,7 @@ public:
block = work;
- const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeOutput));
+ const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), makeOutputMethodPtr);
const auto makeOutputType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false);
const auto makeOutputPtr = CastInst::Create(Instruction::IntToPtr, makeOutputFunc, PointerType::getUnqual(makeOutputType), "make_output_func", block);
const auto hasData = CallInst::Create(makeOutputType, makeOutputPtr, {stateArg, ctx.GetFactory()}, "make_output", block);
@@ -605,7 +590,7 @@ public:
block = over;
- ICodegeneratorInlineWideNode::TGettersList getters(AggsParams_.size());
+ ICodegeneratorInlineWideNode::TGettersList getters(aggCount);
for (size_t idx = 0U; idx < getters.size(); ++idx) {
getters[idx] = [idx, getType, getPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
Y_UNUSED(ctx);
@@ -616,6 +601,63 @@ public:
return {result, std::move(getters)};
}
#endif
+};
+
+class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>,
+ protected TBlockCombineAllWrapperCodegenBase {
+using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>;
+public:
+ TBlockCombineAllWrapper(TComputationMutables& mutables,
+ IComputationWideFlowNode* flow,
+ std::optional<ui32> filterColumn,
+ size_t width,
+ std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams)
+ : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
+ , Flow_(flow)
+ , FilterColumn_(filterColumn)
+ , Width_(width)
+ , AggsParams_(std::move(aggsParams))
+ , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
+ {
+ MKQL_ENSURE(Width_ > 0, "Missing block length column");
+ }
+
+ EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
+ TComputationContext& ctx,
+ NUdf::TUnboxedValue*const* output) const
+ {
+ auto& s = GetState(state, ctx);
+ if (s.IsFinished_)
+ return EFetchResult::Finish;
+
+ for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) {
+ switch (Flow_->FetchValues(ctx, fields)) {
+ case EFetchResult::Yield:
+ return EFetchResult::Yield;
+ case EFetchResult::One:
+ s.ProcessInput();
+ continue;
+ case EFetchResult::Finish:
+ break;
+ }
+ if (s.MakeOutput()) {
+ for (size_t i = 0; i < AggsParams_.size(); ++i) {
+ if (const auto out = output[i]) {
+ *out = s.Get(i);
+ }
+ }
+ return EFetchResult::One;
+ }
+ return EFetchResult::Finish;
+ }
+ }
+#ifndef MKQL_DISABLE_CODEGEN
+ ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
+ return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, AggsParams_.size(),
+ GetMethodPtr(&TState::Get), GetMethodPtr(&TBlockCombineAllWrapper::MakeState),
+ GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::MakeOutput));
+ }
+#endif
private:
struct TState : public TComputationValue<TState> {
NUdf::TUnboxedValue* Pointer_ = nullptr;
@@ -701,35 +743,6 @@ private:
return Values_[index];
}
};
-#ifndef MKQL_DISABLE_CODEGEN
- class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> {
- private:
- using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>;
- llvm::PointerType*const PointerType;
- llvm::IntegerType*const IsFinishedType;
- public:
- std::vector<llvm::Type*> GetFieldsArray() {
- std::vector<llvm::Type*> result = TBase::GetFields();
- result.emplace_back(PointerType);
- result.emplace_back(IsFinishedType);
- return result;
- }
-
- llvm::Constant* GetPointer() {
- return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
- }
-
- llvm::Constant* GetIsFinished() {
- return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
- }
-
- TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
- : TBase(context)
- , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width)))
- , IsFinishedType(llvm::Type::getInt1Ty(Context))
- {}
- };
-#endif
void RegisterDependencies() const final {
FlowDependsOn(Flow_);
}
@@ -839,101 +852,61 @@ std::hash<TExternalFixedSizeKey> MakeHash(ui32 keyLength) {
return std::hash<TExternalFixedSizeKey>(keyLength);
}
-template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived>
-class THashedWrapperBase : public TStatefulWideFlowCodegeneratorNode<TDerived> {
- using TComputationBase = TStatefulWideFlowCodegeneratorNode<TDerived>;
- static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value;
-public:
- THashedWrapperBase(TComputationMutables& mutables,
- IComputationWideFlowNode* flow,
- std::optional<ui32> filterColumn,
- size_t width,
- const std::vector<TKeyParams>& keys,
- size_t maxBlockLen,
- ui32 keyLength,
- std::vector<TAggParams<TAggregator>>&& aggsParams,
- ui32 streamIndex,
- std::vector<std::vector<ui32>>&& streams)
- : TComputationBase(mutables, flow, EValueRepresentation::Boxed)
- , Flow_(flow)
- , FilterColumn_(filterColumn)
- , Width_(width)
- , OutputWidth_(keys.size() + aggsParams.size() + 1)
- , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
- , Keys_(keys)
- , MaxBlockLen_(maxBlockLen)
- , AggsParams_(std::move(aggsParams))
- , KeyLength_(keyLength)
- , StreamIndex_(streamIndex)
- , Streams_(std::move(streams))
- {
- MKQL_ENSURE(Width_ > 0, "Missing block length column");
- if constexpr (UseFilter) {
- MKQL_ENSURE(filterColumn, "Missing filter column");
- MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize");
- } else {
- MKQL_ENSURE(!filterColumn, "Unexpected filter column");
+class THashedWrapperCodegenBase {
+protected:
+#ifndef MKQL_DISABLE_CODEGEN
+ class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState {
+ private:
+ using TBase = TLLVMFieldsStructureBlockState;
+ llvm::IntegerType*const WritingOutputType;
+ llvm::IntegerType*const IsFinishedType;
+ protected:
+ using TBase::Context;
+ public:
+ std::vector<llvm::Type*> GetFieldsArray() {
+ std::vector<llvm::Type*> result = TBase::GetFieldsArray();
+ result.emplace_back(WritingOutputType);
+ result.emplace_back(IsFinishedType);
+ return result;
}
- }
-
- EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
- TComputationContext& ctx,
- NUdf::TUnboxedValue*const* output) const
- {
- auto& s = GetState(state, ctx);
- if (!s.Count) {
- if (s.IsFinished_)
- return EFetchResult::Finish;
-
- while (!s.WritingOutput_) {
- const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
- s.Values_.assign(s.Values_.size(), NUdf::TUnboxedValuePod());
- switch (Flow_->FetchValues(ctx, fields)) {
- case EFetchResult::Yield:
- return EFetchResult::Yield;
- case EFetchResult::One:
- s.ProcessInput(ctx.HolderFactory);
- continue;
- case EFetchResult::Finish:
- break;
- }
-
- if (s.Finish())
- break;
- else
- return EFetchResult::Finish;
- }
- if (!s.FillOutput(ctx.HolderFactory))
- return EFetchResult::Finish;
+ llvm::Constant* GetWritingOutput() {
+ return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields);
}
- const auto sliceSize = s.Slice();
- for (size_t i = 0; i < OutputWidth_; ++i) {
- if (const auto out = output[i]) {
- *out = s.Get(sliceSize, ctx.HolderFactory, i);
- }
+ llvm::Constant* GetIsFinished() {
+ return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1);
}
- return EFetchResult::One;
- }
-#ifndef MKQL_DISABLE_CODEGEN
- ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
+
+ TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
+ : TBase(context, width)
+ , WritingOutputType(Type::getInt1Ty(Context))
+ , IsFinishedType(Type::getInt1Ty(Context))
+ {}
+ };
+
+ Y_NO_INLINE ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValuesImpl(
+ const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block,
+ IComputationWideFlowNode* flow, size_t width, size_t outputWidth,
+ uintptr_t getStateMethodPtr, uintptr_t makeStateMethodPtr,
+ uintptr_t processInputMethodPtr, uintptr_t finishMethodPtr,
+ uintptr_t fillOutputMethodPtr, uintptr_t sliceMethodPtr) const {
auto& context = ctx.Codegen.GetContext();
const auto valueType = Type::getInt128Ty(context);
const auto statusType = Type::getInt32Ty(context);
const auto indexType = Type::getInt64Ty(context);
const auto flagType = Type::getInt1Ty(context);
- const auto arrayType = ArrayType::get(valueType, Width_);
+ const auto arrayType = ArrayType::get(valueType, width);
const auto ptrValuesType = PointerType::getUnqual(arrayType);
- TLLVMFieldsStructureState stateFields(context, Width_);
+ TLLVMFieldsStructureState stateFields(context, width);
const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
const auto statePtrType = PointerType::getUnqual(stateType);
const auto atTop = &ctx.Func->getEntryBlock().back();
- const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get));
+ const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), getStateMethodPtr);
const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false);
const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
@@ -959,7 +932,7 @@ public:
const auto ptrType = PointerType::getUnqual(StructType::get(context));
const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
- const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THashedWrapperBase::MakeState));
+ const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), makeStateMethodPtr);
const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false);
const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block);
@@ -1001,7 +974,7 @@ public:
const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
SafeUnRefUnboxed(values, ctx, block);
- const auto getres = GetNodeValues(Flow_, ctx, block);
+ const auto getres = GetNodeValues(flow, ctx, block);
result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
const auto way = SwitchInst::Create(getres.first, good, 2U, block);
@@ -1018,7 +991,7 @@ public:
}
new StoreInst(array, values, block);
- const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput));
+ const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), processInputMethodPtr);
const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false);
const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block);
CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
@@ -1027,7 +1000,7 @@ public:
block = stop;
- const auto finishFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Finish));
+ const auto finishFunc = ConstantInt::get(Type::getInt64Ty(context), finishMethodPtr);
const auto finishType = FunctionType::get(flagType, {statePtrType}, false);
const auto finishPtr = CastInst::Create(Instruction::IntToPtr, finishFunc, PointerType::getUnqual(finishType), "finish_func", block);
const auto hasOutput = CallInst::Create(finishType, finishPtr, {stateArg}, "has_output", block);
@@ -1038,7 +1011,7 @@ public:
block = work;
- const auto fillBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::FillOutput));
+ const auto fillBlockFunc = ConstantInt::get(Type::getInt64Ty(context), fillOutputMethodPtr);
const auto fillBlockType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false);
const auto fillBlockPtr = CastInst::Create(Instruction::IntToPtr, fillBlockFunc, PointerType::getUnqual(fillBlockType), "fill_output_func", block);
const auto hasData = CallInst::Create(fillBlockType, fillBlockPtr, {stateArg, ctx.GetFactory()}, "fill_output", block);
@@ -1049,7 +1022,7 @@ public:
block = fill;
- const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Slice));
+ const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), sliceMethodPtr);
const auto sliceType = FunctionType::get(indexType, {statePtrType}, false);
const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block);
const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block);
@@ -1062,7 +1035,7 @@ public:
block = over;
- ICodegeneratorInlineWideNode::TGettersList getters(OutputWidth_);
+ ICodegeneratorInlineWideNode::TGettersList getters(outputWidth);
for (size_t idx = 0U; idx < getters.size(); ++idx) {
getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
@@ -1073,6 +1046,95 @@ public:
return {result, std::move(getters)};
}
#endif
+};
+
+template <typename TKey, typename TAggregator, typename TFixedAggState, bool UseSet, bool UseFilter, bool Finalize, bool Many, typename TDerived>
+class THashedWrapperBase : public TStatefulWideFlowCodegeneratorNode<TDerived>,
+ protected THashedWrapperCodegenBase
+{
+ using TComputationBase = TStatefulWideFlowCodegeneratorNode<TDerived>;
+ static constexpr bool UseArena = !InlineAggState && std::is_same<TFixedAggState, TStateArena>::value;
+public:
+ THashedWrapperBase(TComputationMutables& mutables,
+ IComputationWideFlowNode* flow,
+ std::optional<ui32> filterColumn,
+ size_t width,
+ const std::vector<TKeyParams>& keys,
+ size_t maxBlockLen,
+ ui32 keyLength,
+ std::vector<TAggParams<TAggregator>>&& aggsParams,
+ ui32 streamIndex,
+ std::vector<std::vector<ui32>>&& streams)
+ : TComputationBase(mutables, flow, EValueRepresentation::Boxed)
+ , Flow_(flow)
+ , FilterColumn_(filterColumn)
+ , Width_(width)
+ , OutputWidth_(keys.size() + aggsParams.size() + 1)
+ , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
+ , Keys_(keys)
+ , MaxBlockLen_(maxBlockLen)
+ , AggsParams_(std::move(aggsParams))
+ , KeyLength_(keyLength)
+ , StreamIndex_(streamIndex)
+ , Streams_(std::move(streams))
+ {
+ MKQL_ENSURE(Width_ > 0, "Missing block length column");
+ if constexpr (UseFilter) {
+ MKQL_ENSURE(filterColumn, "Missing filter column");
+ MKQL_ENSURE(!Finalize, "Filter isn't compatible with Finalize");
+ } else {
+ MKQL_ENSURE(!filterColumn, "Unexpected filter column");
+ }
+ }
+
+ EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
+ TComputationContext& ctx,
+ NUdf::TUnboxedValue*const* output) const
+ {
+ auto& s = GetState(state, ctx);
+ if (!s.Count) {
+ if (s.IsFinished_)
+ return EFetchResult::Finish;
+
+ while (!s.WritingOutput_) {
+ const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
+ s.Values_.assign(s.Values_.size(), NUdf::TUnboxedValuePod());
+ switch (Flow_->FetchValues(ctx, fields)) {
+ case EFetchResult::Yield:
+ return EFetchResult::Yield;
+ case EFetchResult::One:
+ s.ProcessInput(ctx.HolderFactory);
+ continue;
+ case EFetchResult::Finish:
+ break;
+ }
+
+ if (s.Finish())
+ break;
+ else
+ return EFetchResult::Finish;
+ }
+
+ if (!s.FillOutput(ctx.HolderFactory))
+ return EFetchResult::Finish;
+ }
+
+ const auto sliceSize = s.Slice();
+ for (size_t i = 0; i < OutputWidth_; ++i) {
+ if (const auto out = output[i]) {
+ *out = s.Get(sliceSize, ctx.HolderFactory, i);
+ }
+ }
+ return EFetchResult::One;
+ }
+#ifndef MKQL_DISABLE_CODEGEN
+ ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
+ return DoGenGetValuesImpl(ctx, statePtr, block, Flow_, Width_, OutputWidth_,
+ GetMethodPtr(&TState::Get), GetMethodPtr(&THashedWrapperBase::MakeState),
+ GetMethodPtr(&TState::ProcessInput), GetMethodPtr(&TState::Finish),
+ GetMethodPtr(&TState::FillOutput), GetMethodPtr(&TState::Slice));
+ }
+#endif
private:
struct TState : public TBlockState {
bool WritingOutput_ = false;
@@ -1567,37 +1629,6 @@ private:
}
};
private:
-#ifndef MKQL_DISABLE_CODEGEN
- class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState {
- private:
- using TBase = TLLVMFieldsStructureBlockState;
- llvm::IntegerType*const WritingOutputType;
- llvm::IntegerType*const IsFinishedType;
- protected:
- using TBase::Context;
- public:
- std::vector<llvm::Type*> GetFieldsArray() {
- std::vector<llvm::Type*> result = TBase::GetFieldsArray();
- result.emplace_back(WritingOutputType);
- result.emplace_back(IsFinishedType);
- return result;
- }
-
- llvm::Constant* GetWritingOutput() {
- return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields);
- }
-
- llvm::Constant* GetIsFinished() {
- return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1);
- }
-
- TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
- : TBase(context, width)
- , WritingOutputType(Type::getInt1Ty(Context))
- , IsFinishedType(Type::getInt1Ty(Context))
- {}
- };
-#endif
void RegisterDependencies() const final {
this->FlowDependsOn(Flow_);
}