diff options
author | atarasov5 <[email protected]> | 2025-08-28 14:59:55 +0300 |
---|---|---|
committer | atarasov5 <[email protected]> | 2025-08-28 15:45:33 +0300 |
commit | 642fe48387994c15621318e16f98eac8d11a301c (patch) | |
tree | f362d0234fe8e4e263bc5b8ad64b12788a6b7039 /yql/essentials/minikql/comp_nodes/mkql_element.cpp | |
parent | 81d828c32c8d5477cb2f0ce5da06a1a8d9392ca3 (diff) |
YQL-20340: Fix getelem comp node
commit_hash:4b93115d4e3d46770946a7a462c7413d6183282f
Diffstat (limited to 'yql/essentials/minikql/comp_nodes/mkql_element.cpp')
-rw-r--r-- | yql/essentials/minikql/comp_nodes/mkql_element.cpp | 127 |
1 files changed, 105 insertions, 22 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_element.cpp b/yql/essentials/minikql/comp_nodes/mkql_element.cpp index 58ab35d8037..c9013b98fc0 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_element.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_element.cpp @@ -8,6 +8,46 @@ namespace NMiniKQL { namespace { +enum class EOptionalityHandlerStrategy { + // Just return child as is. + ReturnChildAsIs, + // Return child and add optionality to it. + AddOptionalToChild, + // Return child but set optionality to (tuple & child) intersection. + IntersectOptionals, +}; + +inline bool IsOptionalOrNull(const TType* type) { + return type->IsOptional() || type->IsNull() || type->IsPg(); +} + +// The strategy is based on tuple and its child optionality. +// Tuple<X> -> return child as is (ReturnChildAsIs). +// Tuple<X?> -> return child as is (ReturnChildAsIs). +// Tuple<X>? -> return child and add extra optional level (AddOptionalToChild). +// Tuple<X?>? -> return child as is BUT set mask to (tuple & child) intersection (IntersectOptionals). +EOptionalityHandlerStrategy GetStrategyBasedOnTupleType(TType* tupleType, TType* elementType) { + if (!tupleType->IsOptional()) { + return EOptionalityHandlerStrategy::ReturnChildAsIs; + } else if (IsOptionalOrNull(elementType)) { + return EOptionalityHandlerStrategy::IntersectOptionals; + } else { + return EOptionalityHandlerStrategy::AddOptionalToChild; + } + Y_UNREACHABLE(); +} + +constexpr bool IsTupleOptional(EOptionalityHandlerStrategy strategy) { + switch (strategy) { + case EOptionalityHandlerStrategy::ReturnChildAsIs: + return false; + case EOptionalityHandlerStrategy::AddOptionalToChild: + case EOptionalityHandlerStrategy::IntersectOptionals: + return true; + } + Y_UNREACHABLE(); +} + template <bool IsOptional> class TElementsWrapper : public TMutableCodegeneratorNode<TElementsWrapper<IsOptional>> { typedef TMutableCodegeneratorNode<TElementsWrapper<IsOptional>> TBaseComputation; @@ -64,9 +104,9 @@ private: IComputationNode* const Array; }; -template <bool IsOptional> -class TElementWrapper : public TMutableCodegeneratorPtrNode<TElementWrapper<IsOptional>> { - typedef TMutableCodegeneratorPtrNode<TElementWrapper<IsOptional>> TBaseComputation; +template <EOptionalityHandlerStrategy Strategy> +class TElementWrapper : public TMutableCodegeneratorPtrNode<TElementWrapper<Strategy>> { + typedef TMutableCodegeneratorPtrNode<TElementWrapper<Strategy>> TBaseComputation; public: TElementWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* cache, IComputationNode* array, ui32 index) : TBaseComputation(mutables, kind), Cache(cache), Array(array), Index(index) @@ -75,19 +115,33 @@ public: NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { if (Cache->GetDependencesCount() > 1U) { const auto cache = Cache->GetValue(ctx); - if (IsOptional && !cache) { + if (IsTupleOptional(Strategy) && !cache) { return NUdf::TUnboxedValue(); } - if (const auto elements = cache.Get<ui64>()) { - return reinterpret_cast<const NUdf::TUnboxedValuePod*>(elements)[Index]; + const auto elements = cache.Get<ui64>(); + if (elements) { + auto element = reinterpret_cast<const NUdf::TUnboxedValuePod*>(elements)[Index]; + if constexpr (Strategy == EOptionalityHandlerStrategy::IntersectOptionals) { + return element; + } else if constexpr (Strategy == EOptionalityHandlerStrategy::AddOptionalToChild) { + return element.MakeOptional(); + } else if constexpr (Strategy == EOptionalityHandlerStrategy::ReturnChildAsIs) { + return element; + } else { + static_assert(false, "Unsupported type."); + } } } const auto& array = Array->GetValue(ctx); - if constexpr (IsOptional) { + if constexpr (Strategy == EOptionalityHandlerStrategy::IntersectOptionals) { return array ? array.GetElement(Index) : NUdf::TUnboxedValue(); - } else { + } else if constexpr (Strategy == EOptionalityHandlerStrategy::AddOptionalToChild) { + return array ? NUdf::TUnboxedValue(array.GetElement(Index).MakeOptional()) : NUdf::TUnboxedValue(); + } else if constexpr (Strategy == EOptionalityHandlerStrategy::ReturnChildAsIs) { return array.GetElement(Index); + } else { + static_assert(false, "Unsupported type."); } } @@ -95,9 +149,10 @@ public: void DoGenerateGetElement(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const { auto& context = ctx.Codegen.GetContext(); + const auto valueType = Type::getInt128Ty(context); const auto array = GetNodeValue(Array, ctx, block); const auto index = ConstantInt::get(Type::getInt32Ty(context), Index); - if constexpr (IsOptional) { + if constexpr (IsTupleOptional(Strategy)) { const auto good = BasicBlock::Create(context, "good", ctx.Func); const auto zero = BasicBlock::Create(context, "zero", ctx.Func); const auto exit = BasicBlock::Create(context, "exit", ctx.Func); @@ -110,15 +165,21 @@ public: block = good; CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElement>(pointer, array, ctx.Codegen, block, index); + if constexpr (Strategy == EOptionalityHandlerStrategy::AddOptionalToChild) { + const auto load = new LoadInst(valueType, pointer, "load", block); + new StoreInst(MakeOptional(context, load, block), pointer, block); + } if (Array->IsTemporaryValue()) CleanupBoxed(array, ctx, block); BranchInst::Create(exit, block); block = exit; - } else { + } else if constexpr (Strategy == EOptionalityHandlerStrategy::ReturnChildAsIs){ CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElement>(pointer, array, ctx.Codegen, block, index); if (Array->IsTemporaryValue()) CleanupBoxed(array, ctx, block); + } else { + static_assert(false, "Unhandled case."); } } @@ -134,7 +195,7 @@ public: const auto slow = BasicBlock::Create(context, "slow", ctx.Func); const auto done = BasicBlock::Create(context, "done", ctx.Func); - if constexpr (IsOptional) { + if constexpr (IsTupleOptional(Strategy)) { const auto zero = ConstantInt::get(cache->getType(), 0ULL); const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, cache, zero, "check", block); @@ -160,8 +221,14 @@ public: const auto index = ConstantInt::get(Type::getInt32Ty(context), this->Index); const auto ptr = GetElementPtrInst::CreateInBounds(cache->getType(), elements, {index}, "ptr", block); const auto item = new LoadInst(cache->getType(), ptr, "item", block); + ValueAddRef(this->GetRepresentation(), item, ctx, block); - new StoreInst(item, pointer, block); + if constexpr (Strategy == EOptionalityHandlerStrategy::AddOptionalToChild) { + new StoreInst(MakeOptional(context, item, block), pointer, block); + } else { + new StoreInst(item, pointer, block); + } + BranchInst::Create(done, block); block = slow; @@ -200,17 +267,19 @@ IComputationNode* WrapNth(TCallable& callable, const TComputationNodeFactoryCont const auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1U)); const auto index = indexData->AsValue().Get<ui32>(); MKQL_ENSURE(index < tupleType->GetElementsCount(), "Bad tuple index"); - + auto nthStrategy = GetStrategyBasedOnTupleType(input.GetStaticType(), tupleType->GetElementType(index)); const auto tuple = LocateNode(ctx.NodeLocator, callable, 0); const auto ins = ctx.ElementsCache.emplace(tuple, nullptr); if (ins.second) { ctx.NodePushBack(ins.first->second = WrapElements(tuple, ctx, isOptional)); } - - if (isOptional) { - return new TElementWrapper<true>(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index); - } else { - return new TElementWrapper<false>(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index); + switch (nthStrategy) { + case EOptionalityHandlerStrategy::ReturnChildAsIs: + return new TElementWrapper<EOptionalityHandlerStrategy::ReturnChildAsIs>(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index); + case EOptionalityHandlerStrategy::IntersectOptionals: + return new TElementWrapper<EOptionalityHandlerStrategy::IntersectOptionals>(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index); + case EOptionalityHandlerStrategy::AddOptionalToChild: + return new TElementWrapper<EOptionalityHandlerStrategy::AddOptionalToChild>(ctx.Mutables, GetValueRepresentation(tupleType->GetElementType(index)), ins.first->second, tuple, index); } } @@ -228,10 +297,24 @@ IComputationNode* WrapMember(TCallable& callable, const TComputationNodeFactoryC if (ins.second) { ctx.NodePushBack(ins.first->second = WrapElements(structObj, ctx, isOptional)); } - if (isOptional) { - return new TElementWrapper<true>(ctx.Mutables, GetValueRepresentation(structType->GetMemberType(index)), ins.first->second, structObj, index); - } else { - return new TElementWrapper<false>(ctx.Mutables, GetValueRepresentation(structType->GetMemberType(index)), ins.first->second, structObj, index); + + auto nthStrategy = GetStrategyBasedOnTupleType(input.GetStaticType(), structType->GetMemberType(index)); + switch (nthStrategy) { + case EOptionalityHandlerStrategy::ReturnChildAsIs: + return new TElementWrapper<EOptionalityHandlerStrategy::ReturnChildAsIs>( + ctx.Mutables, + GetValueRepresentation(structType->GetMemberType(index)), + ins.first->second, structObj, index); + case EOptionalityHandlerStrategy::AddOptionalToChild: + return new TElementWrapper<EOptionalityHandlerStrategy::AddOptionalToChild>( + ctx.Mutables, + GetValueRepresentation(structType->GetMemberType(index)), + ins.first->second, structObj, index); + case EOptionalityHandlerStrategy::IntersectOptionals: + return new TElementWrapper<EOptionalityHandlerStrategy::IntersectOptionals>( + ctx.Mutables, + GetValueRepresentation(structType->GetMemberType(index)), + ins.first->second, structObj, index); } } |