diff options
| author | vvvv <[email protected]> | 2025-10-09 12:25:18 +0300 |
|---|---|---|
| committer | vvvv <[email protected]> | 2025-10-09 12:57:17 +0300 |
| commit | cb77d014972b2cdb27d2e6d979fc3a2772b27ad4 (patch) | |
| tree | 7f3bcd8ce71c6bd0f3ccc11e31b9f665475b819e /yql/essentials/minikql/mkql_program_builder.cpp | |
| parent | d58a8990d353b051c27e1069141117fdfde64358 (diff) | |
YQL-20086 minikql
commit_hash:e96f7390db5fcbe7e9f64f898141a263ad522daa
Diffstat (limited to 'yql/essentials/minikql/mkql_program_builder.cpp')
| -rw-r--r-- | yql/essentials/minikql/mkql_program_builder.cpp | 1848 |
1 files changed, 1061 insertions, 787 deletions
diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp index 765f04ea359..52e201899b2 100644 --- a/yql/essentials/minikql/mkql_program_builder.cpp +++ b/yql/essentials/minikql/mkql_program_builder.cpp @@ -44,25 +44,25 @@ struct TDataFunctionFlags { }; }; -#define MKQL_BAD_TYPE_VISIT(NodeType, ScriptName) \ - void Visit(NodeType& node) override { \ - Y_UNUSED(node); \ +#define MKQL_BAD_TYPE_VISIT(NodeType, ScriptName) \ + void Visit(NodeType& node) override { \ + Y_UNUSED(node); \ MKQL_ENSURE(false, "Can't convert " #NodeType " to " ScriptName " object"); \ } -class TPythonTypeChecker : public TExploringNodeVisitor { +class TPythonTypeChecker: public TExploringNodeVisitor { using TExploringNodeVisitor::Visit; MKQL_BAD_TYPE_VISIT(TAnyType, "Python"); }; -class TLuaTypeChecker : public TExploringNodeVisitor { +class TLuaTypeChecker: public TExploringNodeVisitor { using TExploringNodeVisitor::Visit; MKQL_BAD_TYPE_VISIT(TVoidType, "Lua"); MKQL_BAD_TYPE_VISIT(TAnyType, "Lua"); MKQL_BAD_TYPE_VISIT(TVariantType, "Lua"); }; -class TJavascriptTypeChecker : public TExploringNodeVisitor { +class TJavascriptTypeChecker: public TExploringNodeVisitor { using TExploringNodeVisitor::Visit; MKQL_BAD_TYPE_VISIT(TAnyType, "Javascript"); }; @@ -70,9 +70,9 @@ class TJavascriptTypeChecker : public TExploringNodeVisitor { #undef MKQL_BAD_TYPE_VISIT void EnsureScriptSpecificTypes( - EScriptType scriptType, - TCallableType* funcType, - std::vector<TNode*>& nodeStack) + EScriptType scriptType, + TCallableType* funcType, + std::vector<TNode*>& nodeStack) { switch (scriptType) { case EScriptType::Lua: @@ -97,71 +97,74 @@ void EnsureScriptSpecificTypes( return TPythonTypeChecker().Walk(funcType, nodeStack); case EScriptType::Javascript: return TJavascriptTypeChecker().Walk(funcType, nodeStack); - default: - MKQL_ENSURE(false, "Unknown script type " << static_cast<ui32>(scriptType)); + default: + MKQL_ENSURE(false, "Unknown script type " << static_cast<ui32>(scriptType)); } } ui32 GetNumericSchemeTypeLevel(NUdf::TDataTypeId typeId) { switch (typeId) { - case NUdf::TDataType<ui8>::Id: - return 0; - case NUdf::TDataType<i8>::Id: - return 1; - case NUdf::TDataType<ui16>::Id: - return 2; - case NUdf::TDataType<i16>::Id: - return 3; - case NUdf::TDataType<ui32>::Id: - return 4; - case NUdf::TDataType<i32>::Id: - return 5; - case NUdf::TDataType<ui64>::Id: - return 6; - case NUdf::TDataType<i64>::Id: - return 7; - case NUdf::TDataType<float>::Id: - return 8; - case NUdf::TDataType<double>::Id: - return 9; - default: - ythrow yexception() << "Unknown numeric type: " << typeId; + case NUdf::TDataType<ui8>::Id: + return 0; + case NUdf::TDataType<i8>::Id: + return 1; + case NUdf::TDataType<ui16>::Id: + return 2; + case NUdf::TDataType<i16>::Id: + return 3; + case NUdf::TDataType<ui32>::Id: + return 4; + case NUdf::TDataType<i32>::Id: + return 5; + case NUdf::TDataType<ui64>::Id: + return 6; + case NUdf::TDataType<i64>::Id: + return 7; + case NUdf::TDataType<float>::Id: + return 8; + case NUdf::TDataType<double>::Id: + return 9; + default: + ythrow yexception() << "Unknown numeric type: " << typeId; } } NUdf::TDataTypeId GetNumericSchemeTypeByLevel(ui32 level) { switch (level) { - case 0: - return NUdf::TDataType<ui8>::Id; - case 1: - return NUdf::TDataType<i8>::Id; - case 2: - return NUdf::TDataType<ui16>::Id; - case 3: - return NUdf::TDataType<i16>::Id; - case 4: - return NUdf::TDataType<ui32>::Id; - case 5: - return NUdf::TDataType<i32>::Id; - case 6: - return NUdf::TDataType<ui64>::Id; - case 7: - return NUdf::TDataType<i64>::Id; - case 8: - return NUdf::TDataType<float>::Id; - case 9: - return NUdf::TDataType<double>::Id; - default: - ythrow yexception() << "Unknown numeric level: " << level; + case 0: + return NUdf::TDataType<ui8>::Id; + case 1: + return NUdf::TDataType<i8>::Id; + case 2: + return NUdf::TDataType<ui16>::Id; + case 3: + return NUdf::TDataType<i16>::Id; + case 4: + return NUdf::TDataType<ui32>::Id; + case 5: + return NUdf::TDataType<i32>::Id; + case 6: + return NUdf::TDataType<ui64>::Id; + case 7: + return NUdf::TDataType<i64>::Id; + case 8: + return NUdf::TDataType<float>::Id; + case 9: + return NUdf::TDataType<double>::Id; + default: + ythrow yexception() << "Unknown numeric level: " << level; } } NUdf::TDataTypeId MakeNumericDataSuperType(NUdf::TDataTypeId typeId1, NUdf::TDataTypeId typeId2) { - return typeId1 == typeId2 ? typeId1 : - GetNumericSchemeTypeByLevel(std::max(GetNumericSchemeTypeLevel(typeId1), GetNumericSchemeTypeLevel(typeId2))); + if (typeId1 == typeId2) { + return typeId1; + } else { + return GetNumericSchemeTypeByLevel(std::max(GetNumericSchemeTypeLevel(typeId1), GetNumericSchemeTypeLevel(typeId2))); + } } -template<bool IsFilter> +template <bool IsFilter> bool CollectOptionalElements(const TType* type, std::vector<std::string_view>& test, std::vector<std::pair<std::string_view, TType*>>& output) { const auto structType = AS_TYPE(TStructType, type); test.reserve(structType->GetMembersCount()); @@ -182,7 +185,7 @@ bool CollectOptionalElements(const TType* type, std::vector<std::string_view>& t return multiOptional; } -template<bool IsFilter> +template <bool IsFilter> bool CollectOptionalElements(const TType* type, std::vector<ui32>& test, std::vector<TType*>& output) { const auto typleType = AS_TYPE(TTupleType, type); test.reserve(typleType->GetElementsCount()); @@ -255,7 +258,8 @@ static std::vector<TType*> ValidateBlockItems(const TArrayRef<TType* const>& wid } MKQL_ENSURE(isScalar, "Last column should be scalar"); - MKQL_ENSURE(AS_TYPE(TDataType, itemType)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64"); + MKQL_ENSURE(AS_TYPE(TDataType, itemType)->GetSchemeType() == NUdf::TDataType<ui64>::Id, + "Expected Uint64"); return items; } @@ -264,7 +268,8 @@ static std::vector<TType*> ValidateBlockItems(const TArrayRef<TType* const>& wid std::string_view ScriptTypeAsStr(EScriptType type) { switch (type) { #define MKQL_SCRIPT_TYPE_CASE(name, value, ...) \ - case EScriptType::name: return std::string_view(#name); + case EScriptType::name: \ + return std::string_view(#name); MKQL_SCRIPT_TYPES(MKQL_SCRIPT_TYPE_CASE) @@ -277,8 +282,9 @@ std::string_view ScriptTypeAsStr(EScriptType type) { EScriptType ScriptTypeFromStr(std::string_view str) { TString lowerStr = TString(str); lowerStr.to_lower(); -#define MKQL_SCRIPT_TYPE_FROM_STR(name, value, lowerName, allowSuffix) \ - if ((allowSuffix && lowerStr.StartsWith(#lowerName)) || lowerStr == #lowerName) return EScriptType::name; +#define MKQL_SCRIPT_TYPE_FROM_STR(name, value, lowerName, allowSuffix) \ + if ((allowSuffix && lowerStr.StartsWith(#lowerName)) || lowerStr == #lowerName) \ + return EScriptType::name; MKQL_SCRIPT_TYPES(MKQL_SCRIPT_TYPE_FROM_STR) #undef MKQL_SCRIPT_TYPE_FROM_STR @@ -288,21 +294,21 @@ EScriptType ScriptTypeFromStr(std::string_view str) { bool IsCustomPython(EScriptType type) { return type == EScriptType::CustomPython || - type == EScriptType::CustomPython2 || - type == EScriptType::CustomPython3; + type == EScriptType::CustomPython2 || + type == EScriptType::CustomPython3; } bool IsSystemPython(EScriptType type) { - return type == EScriptType::SystemPython2 - || type == EScriptType::SystemPython3 - || type == EScriptType::SystemPython3_8 - || type == EScriptType::SystemPython3_9 - || type == EScriptType::SystemPython3_10 - || type == EScriptType::SystemPython3_11 - || type == EScriptType::SystemPython3_12 - || type == EScriptType::SystemPython3_13 - || type == EScriptType::Python - || type == EScriptType::Python2; + return type == EScriptType::SystemPython2 || + type == EScriptType::SystemPython3 || + type == EScriptType::SystemPython3_8 || + type == EScriptType::SystemPython3_9 || + type == EScriptType::SystemPython3_10 || + type == EScriptType::SystemPython3_11 || + type == EScriptType::SystemPython3_12 || + type == EScriptType::SystemPython3_13 || + type == EScriptType::Python || + type == EScriptType::Python2; } EScriptType CanonizeScriptType(EScriptType type) { @@ -319,8 +325,9 @@ EScriptType CanonizeScriptType(EScriptType type) { void EnsureDataOrOptionalOfData(TRuntimeNode node) { MKQL_ENSURE(node.GetStaticType()->IsData() || - node.GetStaticType()->IsOptional() && AS_TYPE(TOptionalType, node.GetStaticType()) - ->GetItemType()->IsData(), "Expected data or optional of data"); + node.GetStaticType()->IsOptional() && AS_TYPE(TOptionalType, node.GetStaticType()) + ->GetItemType() + ->IsData(), "Expected data or optional of data"); } std::vector<TType*> ValidateBlockType(const TType* type, bool unwrap) { @@ -339,12 +346,13 @@ std::vector<TType*> ValidateBlockFlowType(const TType* flowType, bool unwrap) { } TProgramBuilder::TProgramBuilder(const TTypeEnvironment& env, const IFunctionRegistry& functionRegistry, - bool voidWithEffects, NYql::TLangVersion langver) + bool voidWithEffects, NYql::TLangVersion langver) : TTypeBuilder(env) , FunctionRegistry_(functionRegistry) , VoidWithEffects_(voidWithEffects) , LangVer_(langver) -{} +{ +} const TTypeEnvironment& TProgramBuilder::GetTypeEnvironment() const { return Env_; @@ -362,10 +370,10 @@ TType* TProgramBuilder::ChooseCommonType(TType* type1, TType* type2) { return isOptional1 ? type1 : type2; } - MKQL_ENSURE(! - ((NUdf::GetDataTypeInfo(*data1->GetDataSlot()).Features | NUdf::GetDataTypeInfo(*data2->GetDataSlot()).Features) & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)), - "Not same date types: " << *type1 << " and " << *type2 - ); + MKQL_ENSURE(!((NUdf::GetDataTypeInfo(*data1->GetDataSlot()).Features | + NUdf::GetDataTypeInfo(*data2->GetDataSlot()).Features) & + (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)), + "Not same date types: " << *type1 << " and " << *type2); const auto data = NewDataType(MakeNumericDataSuperType(data1->GetSchemeType(), data2->GetSchemeType())); return isOptional1 || isOptional2 ? NewOptionalType(data) : data; @@ -386,11 +394,10 @@ TType* TProgramBuilder::BuildArithmeticCommonType(TType* type1, TType* type2) { return NewOptionalType(features1 & NUdf::EDataTypeFeatures::IntegralType ? data2 : data1); } else if ( features1 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType) && - features2 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType) - ) { + features2 & (NUdf::EDataTypeFeatures::DateType | NUdf::EDataTypeFeatures::TzDateType)) { const auto used = ((features1 | features2) & NUdf::EDataTypeFeatures::ExtDateType) - ? NewDataType(NUdf::EDataSlot::Interval64) - : NewDataType(NUdf::EDataSlot::Interval); + ? NewDataType(NUdf::EDataSlot::Interval64) + : NewDataType(NUdf::EDataSlot::Interval); return isOptional ? NewOptionalType(used) : used; } else if (data1->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) { MKQL_ENSURE(data1->IsSameType(*data2), "Must be same type."); @@ -470,8 +477,7 @@ TRuntimeNode TProgramBuilder::RemoveMember(TRuntimeNode structObj, const std::st for (ui32 i = 0, e = oldTypeDetailed.GetMembersCount(); i < e; ++i) { if (oldTypeDetailed.GetMemberName(i) != memberName) { newTypeBuilder.Add(oldTypeDetailed.GetMemberName(i), oldTypeDetailed.GetMemberType(i)); - } - else { + } else { memberIndex = i; } } @@ -550,7 +556,7 @@ TRuntimeNode TProgramBuilder::Enumerate(TRuntimeNode list, TRuntimeNode start, T MKQL_ENSURE(AS_TYPE(TDataType, start)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64 as start"); MKQL_ENSURE(AS_TYPE(TDataType, step)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64 as step"); - const std::array<TType*, 2U> tupleTypes = {{ NewDataType(NUdf::EDataSlot::Uint64), itemType }}; + const std::array<TType*, 2U> tupleTypes = {{NewDataType(NUdf::EDataSlot::Uint64), itemType}}; const auto returnType = NewListType(NewTupleType(tupleTypes)); TCallableBuilder callableBuilder(Env_, __func__, returnType); @@ -603,15 +609,19 @@ TRuntimeNode TProgramBuilder::Fold1(TRuntimeNode list, const TUnaryLambda& init, } TRuntimeNode TProgramBuilder::Reduce(TRuntimeNode list, TRuntimeNode state1, - const TBinaryLambda& handler1, - const TUnaryLambda& handler2, - TRuntimeNode state3, - const TBinaryLambda& handler3) { + const TBinaryLambda& handler1, + const TUnaryLambda& handler2, + TRuntimeNode state3, + const TBinaryLambda& handler3) { const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsList() || listType->IsStream(), "Expected list or stream"); - const auto itemType = listType->IsList()? - static_cast<const TListType&>(*listType).GetItemType(): - static_cast<const TStreamType&>(*listType).GetItemType(); + TType* itemType; + if (listType->IsList()) { + itemType = static_cast<const TListType&>(*listType).GetItemType(); + } else { + itemType = static_cast<const TStreamType&>(*listType).GetItemType(); + } + ThrowIfListOfVoid(itemType); const auto state1NodeArg = Arg(state1.GetStaticType()); @@ -641,8 +651,8 @@ TRuntimeNode TProgramBuilder::Reduce(TRuntimeNode list, TRuntimeNode state1, } TRuntimeNode TProgramBuilder::Condense(TRuntimeNode flow, TRuntimeNode state, - const TBinaryLambda& switcher, - const TBinaryLambda& handler, bool useCtx) { + const TBinaryLambda& switcher, + const TBinaryLambda& handler, bool useCtx) { const auto flowType = flow.GetStaticType(); if (flowType->IsList()) { @@ -652,9 +662,12 @@ TRuntimeNode TProgramBuilder::Condense(TRuntimeNode flow, TRuntimeNode state, MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream."); - const auto itemType = flowType->IsFlow() ? - static_cast<const TFlowType&>(*flowType).GetItemType(): - static_cast<const TStreamType&>(*flowType).GetItemType(); + TType* itemType; + if (flowType->IsFlow()) { + itemType = static_cast<const TFlowType&>(*flowType).GetItemType(); + } else { + itemType = static_cast<const TStreamType&>(*flowType).GetItemType(); + } const auto itemArg = Arg(itemType); const auto stateArg = Arg(state.GetStaticType()); @@ -662,7 +675,8 @@ TRuntimeNode TProgramBuilder::Condense(TRuntimeNode flow, TRuntimeNode state, const auto newState = handler(itemArg, stateArg); MKQL_ENSURE(newState.GetStaticType()->IsSameType(*state.GetStaticType()), "State type is changed by the handler"); - TCallableBuilder callableBuilder(Env_, __func__, flowType->IsFlow() ? NewFlowType(state.GetStaticType()) : NewStreamType(state.GetStaticType())); + auto returnType = flowType->IsFlow() ? NewFlowType(state.GetStaticType()) : NewStreamType(state.GetStaticType()); + TCallableBuilder callableBuilder(Env_, __func__, returnType); callableBuilder.Add(flow); callableBuilder.Add(state); callableBuilder.Add(itemArg); @@ -677,8 +691,8 @@ TRuntimeNode TProgramBuilder::Condense(TRuntimeNode flow, TRuntimeNode state, } TRuntimeNode TProgramBuilder::Condense1(TRuntimeNode flow, const TUnaryLambda& init, - const TBinaryLambda& switcher, - const TBinaryLambda& handler, bool useCtx) { + const TBinaryLambda& switcher, + const TBinaryLambda& handler, bool useCtx) { const auto flowType = flow.GetStaticType(); if (flowType->IsList()) { @@ -688,9 +702,12 @@ TRuntimeNode TProgramBuilder::Condense1(TRuntimeNode flow, const TUnaryLambda& i MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream."); - const auto itemType = flowType->IsFlow() ? - static_cast<const TFlowType&>(*flowType).GetItemType(): - static_cast<const TStreamType&>(*flowType).GetItemType(); + TType* itemType; + if (flowType->IsFlow()) { + itemType = static_cast<const TFlowType&>(*flowType).GetItemType(); + } else { + itemType = static_cast<const TStreamType&>(*flowType).GetItemType(); + } const auto itemArg = Arg(itemType); const auto initState = init(itemArg); @@ -700,7 +717,14 @@ TRuntimeNode TProgramBuilder::Condense1(TRuntimeNode flow, const TUnaryLambda& i MKQL_ENSURE(newState.GetStaticType()->IsSameType(*initState.GetStaticType()), "State type is changed by the handler"); - TCallableBuilder callableBuilder(Env_, __func__, flowType->IsFlow() ? NewFlowType(newState.GetStaticType()) : NewStreamType(newState.GetStaticType())); + TType* returnType; + if (flowType->IsFlow()) { + returnType = NewFlowType(newState.GetStaticType()); + } else { + returnType = NewStreamType(newState.GetStaticType()); + } + + TCallableBuilder callableBuilder(Env_, __func__, returnType); callableBuilder.Add(flow); callableBuilder.Add(itemArg); callableBuilder.Add(initState); @@ -715,9 +739,9 @@ TRuntimeNode TProgramBuilder::Condense1(TRuntimeNode flow, const TUnaryLambda& i } TRuntimeNode TProgramBuilder::Squeeze(TRuntimeNode stream, TRuntimeNode state, - const TBinaryLambda& handler, - const TUnaryLambda& save, - const TUnaryLambda& load) { + const TBinaryLambda& handler, + const TUnaryLambda& save, + const TUnaryLambda& load) { const auto streamType = stream.GetStaticType(); MKQL_ENSURE(streamType->IsStream(), "Expected stream"); @@ -754,9 +778,9 @@ TRuntimeNode TProgramBuilder::Squeeze(TRuntimeNode stream, TRuntimeNode state, } TRuntimeNode TProgramBuilder::Squeeze1(TRuntimeNode stream, const TUnaryLambda& init, - const TBinaryLambda& handler, - const TUnaryLambda& save, - const TUnaryLambda& load) { + const TBinaryLambda& handler, + const TUnaryLambda& save, + const TUnaryLambda& load) { const auto streamType = stream.GetStaticType(); MKQL_ENSURE(streamType->IsStream(), "Expected stream"); @@ -814,9 +838,7 @@ TRuntimeNode TProgramBuilder::MapNext(TRuntimeNode list, const TBinaryLambda& ha const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsStream() || listType->IsFlow(), "Expected stream or flow"); - const auto itemType = listType->IsFlow() ? - AS_TYPE(TFlowType, listType)->GetItemType(): - AS_TYPE(TStreamType, listType)->GetItemType(); + const auto itemType = listType->IsFlow() ? AS_TYPE(TFlowType, listType)->GetItemType() : AS_TYPE(TStreamType, listType)->GetItemType(); ThrowIfListOfVoid(itemType); @@ -827,9 +849,12 @@ TRuntimeNode TProgramBuilder::MapNext(TRuntimeNode list, const TBinaryLambda& ha const auto newItem = handler(itemArg, nextItemArg); - const auto resultListType = listType->IsFlow() ? - (TType*)TFlowType::Create(newItem.GetStaticType(), Env_): - (TType*)TStreamType::Create(newItem.GetStaticType(), Env_); + TType* resultListType; + if (listType->IsFlow()) { + resultListType = (TType*)TFlowType::Create(newItem.GetStaticType(), Env_); + } else { + resultListType = (TType*)TStreamType::Create(newItem.GetStaticType(), Env_); + } TCallableBuilder callableBuilder(Env_, __func__, resultListType); callableBuilder.Add(list); @@ -844,9 +869,7 @@ TRuntimeNode TProgramBuilder::BuildExtract(TRuntimeNode list, const std::string_ const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsList() || listType->IsOptional(), "Expected list or optional."); - const auto itemType = listType->IsList() ? - AS_TYPE(TListType, listType)->GetItemType(): - AS_TYPE(TOptionalType, listType)->GetItemType(); + const auto itemType = listType->IsList() ? AS_TYPE(TListType, listType)->GetItemType() : AS_TYPE(TOptionalType, listType)->GetItemType(); const auto lambda = [&](TRuntimeNode item) { return itemType->IsStruct() ? Member(item, name) : Nth(item, ::FromString<ui32>(name)); @@ -874,11 +897,14 @@ TRuntimeNode TProgramBuilder::ChainMap(TRuntimeNode list, TRuntimeNode state, co const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream"); - const auto itemType = listType->IsFlow() ? - AS_TYPE(TFlowType, listType)->GetItemType(): - listType->IsList() ? - AS_TYPE(TListType, listType)->GetItemType(): - AS_TYPE(TStreamType, listType)->GetItemType(); + TType* itemType; + if (listType->IsFlow()) { + itemType = AS_TYPE(TFlowType, listType)->GetItemType(); + } else if (listType->IsList()) { + itemType = AS_TYPE(TListType, listType)->GetItemType(); + } else { + itemType = AS_TYPE(TStreamType, listType)->GetItemType(); + } ThrowIfListOfVoid(itemType); @@ -909,33 +935,33 @@ TRuntimeNode TProgramBuilder::ChainMap(TRuntimeNode list, TRuntimeNode state, co TRuntimeNode TProgramBuilder::Chain1Map(TRuntimeNode list, const TUnaryLambda& init, const TBinaryLambda& handler) { return Chain1Map(list, - [&](TRuntimeNode item) -> TRuntimeNodePair { + [&](TRuntimeNode item) -> TRuntimeNodePair { const auto result = init(item); - return {result, result}; - }, - [&](TRuntimeNode item, TRuntimeNode state) -> TRuntimeNodePair { + return {result, result}; }, + [&](TRuntimeNode item, TRuntimeNode state) -> TRuntimeNodePair { const auto result = handler(item, state); - return {result, result}; - } - ); + return {result, result}; }); } TRuntimeNode TProgramBuilder::Chain1Map(TRuntimeNode list, const TUnarySplitLambda& init, const TBinarySplitLambda& handler) { const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream"); - const auto itemType = listType->IsFlow() ? - AS_TYPE(TFlowType, listType)->GetItemType(): - listType->IsList() ? - AS_TYPE(TListType, listType)->GetItemType(): - AS_TYPE(TStreamType, listType)->GetItemType(); + TType* itemType; + if (listType->IsFlow()) { + itemType = AS_TYPE(TFlowType, listType)->GetItemType(); + } else if (listType->IsList()) { + itemType = AS_TYPE(TListType, listType)->GetItemType(); + } else { + itemType = AS_TYPE(TStreamType, listType)->GetItemType(); + } ThrowIfListOfVoid(itemType); const auto itemArg = Arg(itemType); const auto initItemAndState = init(itemArg); const auto resultItemType = std::get<0U>(initItemAndState).GetStaticType(); - const auto stateType = std::get<1U>(initItemAndState).GetStaticType();; + const auto stateType = std::get<1U>(initItemAndState).GetStaticType(); TType* resultListType = nullptr; if (listType->IsFlow()) { resultListType = TFlowType::Create(resultItemType, Env_); @@ -999,8 +1025,9 @@ TRuntimeNode TProgramBuilder::Last(TRuntimeNode list) { } TRuntimeNode TProgramBuilder::Nanvl(TRuntimeNode data, TRuntimeNode dataIfNaN) { - const std::array<TRuntimeNode, 2> args = {{ data, dataIfNaN }}; - return Invoke(__func__, BuildArithmeticCommonType(data.GetStaticType(), dataIfNaN.GetStaticType()), args); + const std::array<TRuntimeNode, 2> args = {{data, dataIfNaN}}; + return Invoke(__func__, + BuildArithmeticCommonType(data.GetStaticType(), dataIfNaN.GetStaticType()), args); } TRuntimeNode TProgramBuilder::FlatMap(TRuntimeNode list, const TUnaryLambda& handler) @@ -1054,7 +1081,7 @@ TRuntimeNode TProgramBuilder::SkipWhileInclusive(TRuntimeNode list, const TUnary } TRuntimeNode TProgramBuilder::BuildListSort(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode ascending, - const TUnaryLambda& keyExtractor) + const TUnaryLambda& keyExtractor) { const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsList(), "Expected list."); @@ -1089,7 +1116,7 @@ TRuntimeNode TProgramBuilder::BuildListSort(const std::string_view& callableName } TRuntimeNode TProgramBuilder::BuildListNth(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode n, TRuntimeNode ascending, - const TUnaryLambda& keyExtractor) + const TUnaryLambda& keyExtractor) { const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsList(), "Expected list."); @@ -1127,29 +1154,37 @@ TRuntimeNode TProgramBuilder::BuildListNth(const std::string_view& callableName, } TRuntimeNode TProgramBuilder::BuildSort(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode ascending, - const TUnaryLambda& keyExtractor) + const TUnaryLambda& keyExtractor) { if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) { const bool isFlow = flowType->IsFlow(); - const auto condense = isFlow ? - SqueezeToList(Map(flow, [&](TRuntimeNode item) { return Pickle(item); }), NewEmptyOptionalDataLiteral(NUdf::TDataType<ui64>::Id)) : - Condense1(flow, - [this](TRuntimeNode item) { return AsList(item); }, - [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); }, - [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); } - ); + TRuntimeNode condense; + if (isFlow) { + const auto pickle = [&](TRuntimeNode item) { return Pickle(item); }; + condense = SqueezeToList(Map(flow, pickle), NewEmptyOptionalDataLiteral(NUdf::TDataType<ui64>::Id)); + } else { + condense = Condense1(flow, + [this](TRuntimeNode item) { return AsList(item); }, + [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); }, + [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); }); + } const auto finalKeyExtractor = isFlow ? [&](TRuntimeNode item) { - auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType(); - return keyExtractor(Unpickle(itemType, item)); - } : keyExtractor; + auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType(); + return keyExtractor(Unpickle(itemType, item)); + } + : keyExtractor; return FlatMap(condense, [&](TRuntimeNode list) { auto sorted = BuildSort("UnstableSort", Steal(list), ascending, finalKeyExtractor); - return isFlow ? Map(LazyList(sorted), [&](TRuntimeNode item) { - auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType(); - return Unpickle(itemType, item); - }) : sorted; + if (isFlow) { + return Map(LazyList(sorted), [&](TRuntimeNode item) { + auto itemType = AS_TYPE(TFlowType, flowType)->GetItemType(); + return Unpickle(itemType, item); + }); + } else { + return sorted; + } }); } @@ -1157,16 +1192,14 @@ TRuntimeNode TProgramBuilder::BuildSort(const std::string_view& callableName, TR } TRuntimeNode TProgramBuilder::BuildNth(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode n, TRuntimeNode ascending, - const TUnaryLambda& keyExtractor) + const TUnaryLambda& keyExtractor) { if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) { return FlatMap(Condense1(flow, - [this](TRuntimeNode item) { return AsList(item); }, - [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); }, - [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); } - ), - [&](TRuntimeNode list) { return BuildNth(callableName, list, n, ascending, keyExtractor); } - ); + [this](TRuntimeNode item) { return AsList(item); }, + [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); }, + [this](TRuntimeNode item, TRuntimeNode state) { return Append(state, item); }), + [&](TRuntimeNode list) { return BuildNth(callableName, list, n, ascending, keyExtractor); }); } return BuildListNth(callableName, flow, n, ascending, keyExtractor); @@ -1196,7 +1229,7 @@ TRuntimeNode TProgramBuilder::BuildTake(const std::string_view& callableName, TR return TRuntimeNode(callableBuilder.Build(), false); } -template<bool IsFilter, bool OnStruct> +template <bool IsFilter, bool OnStruct> TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list) { const auto listType = list.GetStaticType(); @@ -1221,7 +1254,7 @@ TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list) { std::vector<TRuntimeNode> checkMembers; checkMembers.reserve(members.size()); std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers), - [=, this](const auto& i){ return Exists(Element(item, i)); }); + [=, this](const auto& i) { return Exists(Element(item, i)); }); return And(checkMembers); }; @@ -1231,19 +1264,25 @@ TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list) { if (const auto filteredItemType = NewArrayType(filteredItems); multiOptional) { return BuildFilterNulls<OnStruct>(list, members, filteredItems); } else { - resultType = listType->IsFlow() ? - NewFlowType(filteredItemType): - listType->IsList() ? - NewListType(filteredItemType): - listType->IsStream() ? NewStreamType(filteredItemType) : NewOptionalType(filteredItemType); + if (listType->IsFlow()) { + resultType = NewFlowType(filteredItemType); + } else if (listType->IsList()) { + resultType = NewListType(filteredItemType); + } else if (listType->IsStream()) { + resultType = NewStreamType(filteredItemType); + } else { + resultType = NewOptionalType(filteredItemType); + } } } return Filter(list, predicate, resultType); } -template<bool IsFilter, bool OnStruct> -TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members) { +template <bool IsFilter, bool OnStruct> +TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, + const TArrayRef<std::conditional_t< + OnStruct, const std::string_view, const ui32>>& members) { if (members.empty()) { return list; } @@ -1267,7 +1306,7 @@ TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRe TRuntimeNode::TList checkMembers; checkMembers.reserve(members.size()); std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers), - [=, this](const auto& i){ return Exists(Element(item, i)); }); + [=, this](const auto& i) { return Exists(Element(item, i)); }); return And(checkMembers); }; @@ -1279,25 +1318,31 @@ TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRe return BuildFilterNulls<OnStruct>(list, members, filteredItems); } else { const auto filteredItemType = NewArrayType(filteredItems); - resultType = listType->IsFlow() ? - NewFlowType(filteredItemType): - listType->IsList() ? - NewListType(filteredItemType): - listType->IsStream() ? NewStreamType(filteredItemType) : NewOptionalType(filteredItemType); + if (listType->IsFlow()) { + resultType = NewFlowType(filteredItemType); + } else if (listType->IsList()) { + resultType = NewListType(filteredItemType); + } else if (listType->IsStream()) { + resultType = NewStreamType(filteredItemType); + } else { + resultType = NewOptionalType(filteredItemType); + } } } return Filter(list, predicate, resultType); } -template<bool OnStruct> -TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members, - const std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems) { +template <bool OnStruct> +TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, + const TArrayRef<std::conditional_t<OnStruct, const std::string_view, const ui32>>& members, + const std::conditional_t< + OnStruct, std::vector<std::pair<std::string_view, TType*>>, std::vector<TType*>>& filteredItems) { return FlatMap(list, [&](TRuntimeNode item) { TRuntimeNode::TList checkMembers; checkMembers.reserve(members.size()); std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers), - [=, this](const auto& i){ return this->Element(item, i); }); + [=, this](const auto& i) { return this->Element(item, i); }); return IfPresent(checkMembers, [&](TRuntimeNode::TList items) { std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TRuntimeNode>>, TRuntimeNode::TList> row; @@ -1305,25 +1350,23 @@ TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRe auto j = 0U; if constexpr (OnStruct) { std::transform(filteredItems.cbegin(), filteredItems.cend(), std::back_inserter(row), - [&](const std::pair<std::string_view, TType*>& i) { - const auto& member = i.first; - const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), member); - return std::make_pair(member, passtrought ? Element(item, member) : items[j++]); - } - ); + [&](const std::pair<std::string_view, TType*>& i) { + const auto& member = i.first; + const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), member); + return std::make_pair(member, passtrought ? Element(item, member) : items[j++]); + }); return NewOptional(NewStruct(row)); } else { auto i = 0U; std::generate_n(std::back_inserter(row), filteredItems.size(), - [&]() { - const auto index = i++; - const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), index); - return passtrought ? Element(item, index) : items[j++]; - } - ); + [&]() { + const auto index = i++; + const bool passtrought = members.cend() == std::find(members.cbegin(), members.cend(), index); + return passtrought ? Element(item, index) : items[j++]; + }); return NewOptional(NewTuple(row)); } - }, NewEmptyOptional(NewOptionalType(NewArrayType(filteredItems)))); + }, NewEmptyOptional(NewOptionalType(NewArrayType(filteredItems)))); }); } @@ -1425,7 +1468,8 @@ TRuntimeNode TProgramBuilder::LazyList(TRuntimeNode list) { TRuntimeNode TProgramBuilder::ForwardList(TRuntimeNode stream) { const auto type = stream.GetStaticType(); MKQL_ENSURE(type->IsStream() || type->IsFlow(), "Expected flow or stream."); - TCallableBuilder callableBuilder(Env_, __func__, NewListType(type->IsFlow() ? AS_TYPE(TFlowType, stream)->GetItemType() : AS_TYPE(TStreamType, stream)->GetItemType())); + const auto itemType = type->IsFlow() ? AS_TYPE(TFlowType, stream)->GetItemType() : AS_TYPE(TStreamType, stream)->GetItemType(); + TCallableBuilder callableBuilder(Env_, __func__, NewListType(itemType)); callableBuilder.Add(stream); return TRuntimeNode(callableBuilder.Build(), false); } @@ -1433,8 +1477,15 @@ TRuntimeNode TProgramBuilder::ForwardList(TRuntimeNode stream) { TRuntimeNode TProgramBuilder::ToFlow(TRuntimeNode stream) { const auto type = stream.GetStaticType(); MKQL_ENSURE(type->IsStream() || type->IsList() || type->IsOptional(), "Expected stream, list or optional."); - const auto itemType = type->IsStream() ? AS_TYPE(TStreamType, stream)->GetItemType() : - type->IsList() ? AS_TYPE(TListType, stream)->GetItemType() : AS_TYPE(TOptionalType, stream)->GetItemType(); + TType* itemType; + if (type->IsStream()) { + itemType = AS_TYPE(TStreamType, stream)->GetItemType(); + } else if (type->IsList()) { + itemType = AS_TYPE(TListType, stream)->GetItemType(); + } else { + itemType = AS_TYPE(TOptionalType, stream)->GetItemType(); + } + TCallableBuilder callableBuilder(Env_, __func__, NewFlowType(itemType)); callableBuilder.Add(stream); return TRuntimeNode(callableBuilder.Build(), false); @@ -1600,7 +1651,7 @@ TRuntimeNode TProgramBuilder::ReplicateScalar(TRuntimeNode value, TRuntimeNode c MKQL_ENSURE(countType->GetItemType()->IsData(), "Expected scalar data as second argument"); MKQL_ENSURE(AS_TYPE(TDataType, countType->GetItemType())->GetSchemeType() == - NUdf::TDataType<ui64>::Id, "Expected scalar ui64 as second argument"); + NUdf::TDataType<ui64>::Id, "Expected scalar ui64 as second argument"); auto outputType = NewBlockType(valueType->GetItemType(), TBlockType::EShape::Many); @@ -1615,7 +1666,8 @@ TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode stream, ui32 bitmapInde MKQL_ENSURE(blockItemTypes.size() >= 2, "Expected at least two input columns"); MKQL_ENSURE(bitmapIndex < blockItemTypes.size() - 1, "Invalid bitmap index"); - MKQL_ENSURE(AS_TYPE(TDataType, blockItemTypes[bitmapIndex])->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected Bool as bitmap column type"); + MKQL_ENSURE(AS_TYPE(TDataType, blockItemTypes[bitmapIndex])->GetSchemeType() == NUdf::TDataType<bool>::Id, + "Expected Bool as bitmap column type"); const auto wideComponents = GetWideComponents(stream.GetStaticType()); MKQL_ENSURE(wideComponents.size() == blockItemTypes.size(), "Unexpected tuple size"); @@ -1657,7 +1709,8 @@ TRuntimeNode TProgramBuilder::BlockCoalesce(TRuntimeNode first, TRuntimeNode sec auto firstItemType = firstType->GetItemType(); auto secondItemType = secondType->GetItemType(); - MKQL_ENSURE(firstItemType->IsOptional() || firstItemType->IsPg(), TStringBuilder() << "Expecting Optional or Pg type as first argument, but got: " << *firstItemType); + MKQL_ENSURE(firstItemType->IsOptional() || firstItemType->IsPg(), + TStringBuilder() << "Expecting Optional or Pg type as first argument, but got: " << *firstItemType); if (!firstItemType->IsSameType(*secondItemType)) { bool firstOptional; @@ -1705,8 +1758,7 @@ TRuntimeNode TProgramBuilder::BlockNth(TRuntimeNode tuple, ui32 index) { bool isOptional; const auto type = AS_TYPE(TTupleType, UnpackOptional(blockType->GetItemType(), isOptional)); - MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index << - " is not less than " << type->GetElementsCount()); + MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index << " is not less than " << type->GetElementsCount()); auto itemType = type->GetElementType(index); if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) { itemType = TOptionalType::Create(itemType, Env_); @@ -1815,9 +1867,9 @@ TRuntimeNode TProgramBuilder::ListFromRange(TRuntimeNode start, TRuntimeNode end MKQL_ENSURE(start.GetStaticType()->IsData(), "Expected data"); MKQL_ENSURE(end.GetStaticType()->IsSameType(*start.GetStaticType()), "Mismatch type"); MKQL_ENSURE(IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType()) || - IsDateType(AS_TYPE(TDataType, start)->GetSchemeType()) || - IsTzDateType(AS_TYPE(TDataType, start)->GetSchemeType()) || - IsIntervalType(AS_TYPE(TDataType, start)->GetSchemeType()), + IsDateType(AS_TYPE(TDataType, start)->GetSchemeType()) || + IsTzDateType(AS_TYPE(TDataType, start)->GetSchemeType()) || + IsIntervalType(AS_TYPE(TDataType, start)->GetSchemeType()), "Expected numeric, date or tzdate"); if (IsNumericType(AS_TYPE(TDataType, start)->GetSchemeType())) { @@ -1834,9 +1886,9 @@ TRuntimeNode TProgramBuilder::ListFromRange(TRuntimeNode start, TRuntimeNode end } TRuntimeNode TProgramBuilder::Switch(TRuntimeNode stream, - const TArrayRef<const TSwitchInput>& handlerInputs, - std::function<TRuntimeNode(ui32 index, TRuntimeNode item)> handler, - ui64 memoryLimitBytes, TType* returnType) { + const TArrayRef<const TSwitchInput>& handlerInputs, + std::function<TRuntimeNode(ui32 index, TRuntimeNode item)> handler, + ui64 memoryLimitBytes, TType* returnType) { MKQL_ENSURE(stream.GetStaticType()->IsStream() || stream.GetStaticType()->IsFlow(), "Expected stream or flow."); std::vector<TRuntimeNode> argNodes(handlerInputs.size()); std::vector<TRuntimeNode> outputNodes(handlerInputs.size()); @@ -1878,7 +1930,7 @@ TRuntimeNode TProgramBuilder::Reverse(TRuntimeNode list) { const auto listType = UnpackOptional(list, isOptional); if (isOptional) { - return Map(list, [&](TRuntimeNode unpacked) { return Reverse(unpacked); } ); + return Map(list, [&](TRuntimeNode unpacked) { return Reverse(unpacked); }); } const auto listDetailedType = AS_TYPE(TListType, listType); @@ -1918,7 +1970,11 @@ TRuntimeNode TProgramBuilder::WideSort(TRuntimeNode flow, const std::vector<std: return BuildWideTopOrSort(__func__, flow, Nothing(), keys, /*isBlocks=*/false); } -TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, TRuntimeNode stream, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys, bool isBlocks) { +TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callableName, + TRuntimeNode stream, + TMaybe<TRuntimeNode> count, + const std::vector<std::pair<ui32, TRuntimeNode>>& keys, + bool isBlocks) { if (isBlocks) { return BuildWideTopOrSortImpl(callableName, stream, count, keys, TType::EKind::Stream); } else { @@ -1926,7 +1982,11 @@ TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callabl } } -TRuntimeNode TProgramBuilder::BuildWideTopOrSortImpl(const std::string_view& callableName, TRuntimeNode stream, TMaybe<TRuntimeNode> count, const std::vector<std::pair<ui32, TRuntimeNode>>& keys, TType::EKind streamKind) { +TRuntimeNode TProgramBuilder::BuildWideTopOrSortImpl(const std::string_view& callableName, + TRuntimeNode stream, + TMaybe<TRuntimeNode> count, + const std::vector<std::pair<ui32, TRuntimeNode>>& keys, + TType::EKind streamKind) { MKQL_ENSURE(stream.GetStaticType()->GetKind() == streamKind, "Mismatched input type"); const auto width = GetWideComponentsCount(stream.GetStaticType()); MKQL_ENSURE(!keys.empty() && keys.size() <= width, "Unexpected keys count: " << keys.size()); @@ -1972,14 +2032,10 @@ TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntim }; return FlatMap(Condense1(Map(flow, cacheKeyExtractor), - [&](TRuntimeNode item) { return AsList(item); }, - [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); }, - [&](TRuntimeNode item, TRuntimeNode state) { - return KeepTop(count, state, item, ascending, getKey); - } - ), - [&](TRuntimeNode list) { return Map(Top(list, count, ascending, getKey), getItem); } - ); + [&](TRuntimeNode item) { return AsList(item); }, + [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); }, + [&](TRuntimeNode item, TRuntimeNode state) { return KeepTop(count, state, item, ascending, getKey); }), + [&](TRuntimeNode list) { return Map(Top(list, count, ascending, getKey), getItem); }); } return BuildListNth(__func__, flow, count, ascending, keyExtractor); @@ -1994,20 +2050,20 @@ TRuntimeNode TProgramBuilder::TopSort(TRuntimeNode flow, TRuntimeNode count, TRu }; return FlatMap(Condense1(Map(flow, cacheKeyExtractor), - [&](TRuntimeNode item) { return AsList(item); }, - [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); }, - [&](TRuntimeNode item, TRuntimeNode state) { - return KeepTop(count, state, item, ascending, getKey); - } - ), - [&](TRuntimeNode list) { return Map(TopSort(list, count, ascending, getKey), getItem); } - ); + [&](TRuntimeNode item) { return AsList(item); }, + [this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); }, + [&](TRuntimeNode item, TRuntimeNode state) { return KeepTop(count, state, item, ascending, getKey); }), + [&](TRuntimeNode list) { return Map(TopSort(list, count, ascending, getKey), getItem); }); } return BuildListNth(__func__, flow, count, ascending, keyExtractor); } -TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRuntimeNode item, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) { +TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, + TRuntimeNode list, + TRuntimeNode item, + TRuntimeNode ascending, + const TUnaryLambda& keyExtractor) { const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsList(), "Expected list."); @@ -2049,11 +2105,13 @@ TRuntimeNode TProgramBuilder::KeepTop(TRuntimeNode count, TRuntimeNode list, TRu } TRuntimeNode TProgramBuilder::Contains(TRuntimeNode dict, TRuntimeNode key) { - if (!dict.GetStaticType()->IsDict()) + if (!dict.GetStaticType()->IsDict()) { return DataCompare(__func__, dict, key); + } const auto keyType = AS_TYPE(TDictType, dict.GetStaticType())->GetKeyType(); - MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType()); + MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), + "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType()); TCallableBuilder callableBuilder(Env_, __func__, NewDataType(NUdf::TDataType<bool>::Id)); callableBuilder.Add(dict); @@ -2064,7 +2122,8 @@ TRuntimeNode TProgramBuilder::Contains(TRuntimeNode dict, TRuntimeNode key) { TRuntimeNode TProgramBuilder::Lookup(TRuntimeNode dict, TRuntimeNode key) { const auto dictType = AS_TYPE(TDictType, dict.GetStaticType()); const auto keyType = dictType->GetKeyType(); - MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType()); + MKQL_ENSURE(keyType->IsSameType(*key.GetStaticType()), + "Key type mismatch. Requred: " << *keyType << ", but got: " << *key.GetStaticType()); TCallableBuilder callableBuilder(Env_, __func__, NewOptionalType(dictType->GetPayloadType())); callableBuilder.Add(dict); @@ -2074,7 +2133,7 @@ TRuntimeNode TProgramBuilder::Lookup(TRuntimeNode dict, TRuntimeNode key) { TRuntimeNode TProgramBuilder::DictItems(TRuntimeNode dict) { const auto dictTypeChecked = AS_TYPE(TDictType, dict.GetStaticType()); - const auto itemType = NewTupleType({ dictTypeChecked->GetKeyType(), dictTypeChecked->GetPayloadType() }); + const auto itemType = NewTupleType({dictTypeChecked->GetKeyType(), dictTypeChecked->GetPayloadType()}); TCallableBuilder callableBuilder(Env_, __func__, NewListType(itemType)); callableBuilder.Add(dict); return TRuntimeNode(callableBuilder.Build(), false); @@ -2109,36 +2168,43 @@ TRuntimeNode TProgramBuilder::JoinDict(TRuntimeNode dict1, bool isMulti1, TRunti const auto dict2type = AS_TYPE(TDictType, dict2); MKQL_ENSURE(dict1type->GetKeyType()->IsSameType(*dict2type->GetKeyType()), "Dict key types must be the same"); - if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi) + if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi) { MKQL_ENSURE(dict1type->GetPayloadType()->IsVoid(), "Void required for first dict payload."); - else if (isMulti1) + } else if (isMulti1) { MKQL_ENSURE(dict1type->GetPayloadType()->IsList(), "List required for first dict payload."); + } - if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi) + if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi) { MKQL_ENSURE(dict2type->GetPayloadType()->IsVoid(), "Void required for second dict payload."); - else if (isMulti2) + } else if (isMulti2) { MKQL_ENSURE(dict2type->GetPayloadType()->IsList(), "List required for second dict payload."); + } - std::array<TType*, 2> tupleItems = {{ dict1type->GetPayloadType(), dict2type->GetPayloadType() }}; - if (isMulti1 && tupleItems.front()->IsList()) + std::array<TType*, 2> tupleItems = {{dict1type->GetPayloadType(), dict2type->GetPayloadType()}}; + if (isMulti1 && tupleItems.front()->IsList()) { tupleItems.front() = AS_TYPE(TListType, tupleItems.front())->GetItemType(); + } - if (isMulti2 && tupleItems.back()->IsList()) + if (isMulti2 && tupleItems.back()->IsList()) { tupleItems.back() = AS_TYPE(TListType, tupleItems.back())->GetItemType(); + } - if (IsLeftOptional(joinKind)) + if (IsLeftOptional(joinKind)) { tupleItems.front() = NewOptionalType(tupleItems.front()); + } - if (IsRightOptional(joinKind)) + if (IsRightOptional(joinKind)) { tupleItems.back() = NewOptionalType(tupleItems.back()); + } TType* itemType; - if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi) + if (joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::LeftSemi) { itemType = tupleItems.front(); - else if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi) + } else if (joinKind == EJoinKind::RightOnly || joinKind == EJoinKind::RightSemi) { itemType = tupleItems.back(); - else + } else { itemType = NewTupleType(tupleItems); + } const auto returnType = NewListType(itemType); TCallableBuilder callableBuilder(Env_, __func__, returnType); @@ -2151,9 +2217,9 @@ TRuntimeNode TProgramBuilder::JoinDict(TRuntimeNode dict1, bool isMulti1, TRunti } TRuntimeNode TProgramBuilder::GraceJoinCommon(const TStringBuf& funcName, TRuntimeNode flowLeft, TRuntimeNode flowRight, EJoinKind joinKind, - const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns, - const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) { - + const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns, + const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, + TType* returnType, EAnyJoinSettings anyJoinSettings) { MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified"); if (flowRight) { MKQL_ENSURE(!rightKeyColumns.empty(), "At least one key column must be specified"); @@ -2162,17 +2228,20 @@ TRuntimeNode TProgramBuilder::GraceJoinCommon(const TStringBuf& funcName, TRunti TRuntimeNode::TList leftKeyColumnsNodes, rightKeyColumnsNodes, leftRenamesNodes, rightRenamesNodes; leftKeyColumnsNodes.reserve(leftKeyColumns.size()); - std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); }); + std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), + [this](const ui32 idx) { return NewDataLiteral(idx); }); rightKeyColumnsNodes.reserve(rightKeyColumns.size()); - std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(), std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); }); + std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(), std::back_inserter(rightKeyColumnsNodes), + [this](const ui32 idx) { return NewDataLiteral(idx); }); leftRenamesNodes.reserve(leftRenames.size()); - std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); }); + std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), + [this](const ui32 idx) { return NewDataLiteral(idx); }); rightRenamesNodes.reserve(rightRenames.size()); - std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); }); - + std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), + [this](const ui32 idx) { return NewDataLiteral(idx); }); TCallableBuilder callableBuilder(Env_, funcName, returnType); callableBuilder.Add(flowLeft); @@ -2190,45 +2259,47 @@ TRuntimeNode TProgramBuilder::GraceJoinCommon(const TStringBuf& funcName, TRunti } TRuntimeNode TProgramBuilder::GraceJoin(TRuntimeNode flowLeft, TRuntimeNode flowRight, EJoinKind joinKind, - const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns, - const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) { - - return GraceJoinCommon(__func__, flowLeft, flowRight, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings); + const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns, + const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, + TType* returnType, EAnyJoinSettings anyJoinSettings) { + return GraceJoinCommon(__func__, flowLeft, flowRight, joinKind, leftKeyColumns, rightKeyColumns, + leftRenames, rightRenames, returnType, anyJoinSettings); } -TRuntimeNode TProgramBuilder::GraceSelfJoin(TRuntimeNode flowLeft, EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& rightKeyColumns, - const TArrayRef<const ui32>& leftRenames, const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings ) { - - return GraceJoinCommon(__func__, flowLeft, {}, joinKind, leftKeyColumns, rightKeyColumns, leftRenames, rightRenames, returnType, anyJoinSettings); +TRuntimeNode TProgramBuilder::GraceSelfJoin(TRuntimeNode flowLeft, EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns, + const TArrayRef<const ui32>& rightKeyColumns, const TArrayRef<const ui32>& leftRenames, + const TArrayRef<const ui32>& rightRenames, TType* returnType, EAnyJoinSettings anyJoinSettings) { + return GraceJoinCommon(__func__, flowLeft, {}, joinKind, leftKeyColumns, rightKeyColumns, + leftRenames, rightRenames, returnType, anyJoinSettings); } TRuntimeNode TProgramBuilder::ToSortedDict(TRuntimeNode list, bool all, const TUnaryLambda& keySelector, - const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { + const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { return ToDict(list, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint); } TRuntimeNode TProgramBuilder::ToHashedDict(TRuntimeNode list, bool all, const TUnaryLambda& keySelector, - const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { + const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { return ToDict(list, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint); } TRuntimeNode TProgramBuilder::SqueezeToSortedDict(TRuntimeNode stream, bool all, const TUnaryLambda& keySelector, - const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { + const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { return SqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint); } TRuntimeNode TProgramBuilder::SqueezeToHashedDict(TRuntimeNode stream, bool all, const TUnaryLambda& keySelector, - const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { + const TUnaryLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { return SqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint); } TRuntimeNode TProgramBuilder::NarrowSqueezeToSortedDict(TRuntimeNode stream, bool all, const TNarrowLambda& keySelector, - const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { + const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { return NarrowSqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint); } TRuntimeNode TProgramBuilder::NarrowSqueezeToHashedDict(TRuntimeNode stream, bool all, const TNarrowLambda& keySelector, - const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { + const TNarrowLambda& payloadSelector, bool isCompact, ui64 itemsCountHint) { return NarrowSqueezeToDict(stream, all, keySelector, payloadSelector, __func__, isCompact, itemsCountHint); } @@ -2278,9 +2349,8 @@ TRuntimeNode TProgramBuilder::BuildExtend(const std::string_view& callableName, MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected either flow, list or stream"); for (ui32 i = 1; i < lists.size(); ++i) { auto listType2 = lists[i].GetStaticType(); - MKQL_ENSURE(listType->IsSameType(*listType2), "Types of flows are different, left: " << - PrintNode(listType, true) << ", right: " << - PrintNode(listType2, true)); + MKQL_ENSURE(listType->IsSameType(*listType2), + "Types of flows are different, left: " << PrintNode(listType, true) << ", right: " << PrintNode(listType2, true)); } TCallableBuilder callableBuilder(Env_, callableName, listType); @@ -2299,57 +2369,57 @@ TRuntimeNode TProgramBuilder::OrderedExtend(const TArrayRef<const TRuntimeNode>& return BuildExtend(__func__, lists); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::String>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<const char*>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Utf8>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TUtf8>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Yson>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TYson>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Json>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TJson>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::JsonDocument>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TJsonDocument>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Uuid>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TUuid>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Date>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDate>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Datetime>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDatetime>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Timestamp>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TTimestamp>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Interval>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TInterval>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::DyNumber>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDyNumber>::Id, Env_), true); } @@ -2358,22 +2428,22 @@ TRuntimeNode TProgramBuilder::NewDecimalLiteral(NYql::NDecimal::TInt128 data, ui return TRuntimeNode(TDataLiteral::Create(NUdf::TUnboxedValuePod(data), TDataDecimalType::Create(precision, scale, Env_), Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Date32>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDate32>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Datetime64>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TDatetime64>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Timestamp64>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TTimestamp64>::Id, Env_), true); } -template<> +template <> TRuntimeNode TProgramBuilder::NewDataLiteral<NUdf::EDataSlot::Interval64>(const NUdf::TStringRef& data) const { return TRuntimeNode(BuildDataLiteral(data, NUdf::TDataType<NUdf::TInterval64>::Id, Env_), true); } @@ -2477,7 +2547,6 @@ TRuntimeNode TProgramBuilder::NewEmptyTuple() { return TRuntimeNode(Env_.GetEmptyTupleLazy(), true); } - TRuntimeNode TProgramBuilder::NewTuple(TType* tupleType, const TArrayRef<const TRuntimeNode>& elements) { MKQL_ENSURE(tupleType->IsTuple(), "Expected tuple type"); @@ -2494,7 +2563,6 @@ TRuntimeNode TProgramBuilder::NewTuple(const TArrayRef<const TRuntimeNode>& elem return NewTuple(NewTupleType(types), elements); } - TRuntimeNode TProgramBuilder::NewVariant(TRuntimeNode item, ui32 index, TType* variantType) { const auto type = AS_TYPE(TVariantType, variantType); MKQL_ENSURE(type->GetUnderlyingType()->IsTuple(), "Expected tuple as underlying type"); @@ -2744,9 +2812,8 @@ TRuntimeNode TProgramBuilder::MutDictItems(TType* dictType, TRuntimeNode mdict) MKQL_ENSURE(dictType->IsDict(), "Expected dict"); ValidateMutDictType(mdict.GetStaticType()); - TType* listItemType = NewTupleType({ - AS_TYPE(TDictType, dictType)->GetKeyType(), - AS_TYPE(TDictType, dictType)->GetPayloadType()}); + TType* listItemType = NewTupleType({AS_TYPE(TDictType, dictType)->GetKeyType(), + AS_TYPE(TDictType, dictType)->GetPayloadType()}); auto listType = NewListType(listItemType); auto retType = NewTupleType({mdict.GetStaticType(), listType}); @@ -2815,7 +2882,8 @@ TRuntimeNode TProgramBuilder::Coalesce(TRuntimeNode data, TRuntimeNode defaultDa if (!dataType->IsSameType(*defaultData.GetStaticType())) { bool isOptionalDefault; const auto defaultDataType = UnpackOptional(defaultData, isOptionalDefault); - MKQL_ENSURE(dataType->IsSameType(*defaultDataType), TStringBuilder() << "Mismatch operand types. Left: " << *dataType << ", right: " << *defaultDataType); + MKQL_ENSURE(dataType->IsSameType(*defaultDataType), + TStringBuilder() << "Mismatch operand types. Left: " << *dataType << ", right: " << *defaultDataType); } TCallableBuilder callableBuilder(Env_, __func__, defaultData.GetStaticType()); @@ -2833,7 +2901,8 @@ TRuntimeNode TProgramBuilder::Unwrap(TRuntimeNode optional, TRuntimeNode message MKQL_ENSURE(messageType->IsData(), "Expected data"); const auto& messageTypeData = static_cast<const TDataType&>(*messageType); - MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8."); + MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || + messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8."); TCallableBuilder callableBuilder(Env_, __func__, underlyingType); callableBuilder.Add(optional); @@ -2845,50 +2914,56 @@ TRuntimeNode TProgramBuilder::Unwrap(TRuntimeNode optional, TRuntimeNode message } TRuntimeNode TProgramBuilder::Increment(TRuntimeNode data) { - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; bool isOptional; const auto type = UnpackOptionalData(data, isOptional); - if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) + if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) { return Invoke(__func__, data.GetStaticType(), args); + } - return Invoke(TString("Inc_") += ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first), data.GetStaticType(), args); + const auto suffix = ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first); + return Invoke(TString("Inc_") += suffix, data.GetStaticType(), args); } TRuntimeNode TProgramBuilder::Decrement(TRuntimeNode data) { - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; bool isOptional; const auto type = UnpackOptionalData(data, isOptional); - if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) + if (type->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) { return Invoke(__func__, data.GetStaticType(), args); + } - return Invoke(TString("Dec_") += ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first), data.GetStaticType(), args); + auto suffix = ::ToString(static_cast<TDataDecimalType*>(type)->GetParams().first); + return Invoke(TString("Dec_") += suffix, data.GetStaticType(), args); } TRuntimeNode TProgramBuilder::Abs(TRuntimeNode data) { - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; return Invoke(__func__, data.GetStaticType(), args); } TRuntimeNode TProgramBuilder::Plus(TRuntimeNode data) { - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; return Invoke(__func__, data.GetStaticType(), args); } TRuntimeNode TProgramBuilder::Minus(TRuntimeNode data) { - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; return Invoke(__func__, data.GetStaticType(), args); } TRuntimeNode TProgramBuilder::Add(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; bool isOptionalLeft; const auto leftType = UnpackOptionalData(data1, isOptionalLeft); - if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) - return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args); + if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) { + auto commonType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()); + return Invoke(__func__, commonType, args); + } const auto decimalType = static_cast<TDataDecimalType*>(leftType); bool isOptionalRight; @@ -2899,13 +2974,15 @@ TRuntimeNode TProgramBuilder::Add(TRuntimeNode data1, TRuntimeNode data2) { } TRuntimeNode TProgramBuilder::Sub(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; bool isOptionalLeft; const auto leftType = UnpackOptionalData(data1, isOptionalLeft); - if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) - return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args); + if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) { + auto commonType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()); + return Invoke(__func__, commonType, args); + } const auto decimalType = static_cast<TDataDecimalType*>(leftType); bool isOptionalRight; @@ -2916,14 +2993,15 @@ TRuntimeNode TProgramBuilder::Sub(TRuntimeNode data1, TRuntimeNode data2) { } TRuntimeNode TProgramBuilder::Mul(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args); } TRuntimeNode TProgramBuilder::Div(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; auto resultType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()); - if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) { + if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & + (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) { resultType = NewOptionalType(resultType); } return Invoke(__func__, resultType, args); @@ -2934,10 +3012,11 @@ TRuntimeNode TProgramBuilder::DecimalDiv(TRuntimeNode data1, TRuntimeNode data2) const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft)); const auto rightType = UnpackOptionalData(data2, isOptionalRight); - if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) + if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) { MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch"); - else + } else { MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch"); + } const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType; @@ -2952,10 +3031,11 @@ TRuntimeNode TProgramBuilder::DecimalMod(TRuntimeNode data1, TRuntimeNode data2) const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft)); const auto rightType = UnpackOptionalData(data2, isOptionalRight); - if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) + if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) { MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch"); - else + } else { MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch"); + } const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType; @@ -2970,10 +3050,11 @@ TRuntimeNode TProgramBuilder::DecimalMul(TRuntimeNode data1, TRuntimeNode data2) const auto leftType = static_cast<TDataDecimalType*>(UnpackOptionalData(data1, isOptionalLeft)); const auto rightType = UnpackOptionalData(data2, isOptionalRight); - if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) + if (rightType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) { MKQL_ENSURE(static_cast<TDataDecimalType*>(rightType)->IsSameType(*leftType), "Operands type mismatch"); - else + } else { MKQL_ENSURE(NUdf::GetDataTypeInfo(*rightType->GetDataSlot()).Features & NUdf::IntegralType, "Operands type mismatch"); + } const auto returnType = isOptionalLeft || isOptionalRight ? NewOptionalType(leftType) : leftType; @@ -2992,54 +3073,55 @@ TRuntimeNode TProgramBuilder::NotAllOf(TRuntimeNode list, const TUnaryLambda& pr } TRuntimeNode TProgramBuilder::BitNot(TRuntimeNode data) { - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; return Invoke(__func__, data.GetStaticType(), args); } TRuntimeNode TProgramBuilder::CountBits(TRuntimeNode data) { - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; return Invoke(__func__, data.GetStaticType(), args); } TRuntimeNode TProgramBuilder::BitAnd(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args); } TRuntimeNode TProgramBuilder::BitOr(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args); } TRuntimeNode TProgramBuilder::BitXor(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; return Invoke(__func__, BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()), args); } TRuntimeNode TProgramBuilder::ShiftLeft(TRuntimeNode arg, TRuntimeNode bits) { - const std::array<TRuntimeNode, 2> args = {{ arg, bits }}; + const std::array<TRuntimeNode, 2> args = {{arg, bits}}; return Invoke(__func__, arg.GetStaticType(), args); } TRuntimeNode TProgramBuilder::RotLeft(TRuntimeNode arg, TRuntimeNode bits) { - const std::array<TRuntimeNode, 2> args = {{ arg, bits }}; + const std::array<TRuntimeNode, 2> args = {{arg, bits}}; return Invoke(__func__, arg.GetStaticType(), args); } TRuntimeNode TProgramBuilder::ShiftRight(TRuntimeNode arg, TRuntimeNode bits) { - const std::array<TRuntimeNode, 2> args = {{ arg, bits }}; + const std::array<TRuntimeNode, 2> args = {{arg, bits}}; return Invoke(__func__, arg.GetStaticType(), args); } TRuntimeNode TProgramBuilder::RotRight(TRuntimeNode arg, TRuntimeNode bits) { - const std::array<TRuntimeNode, 2> args = {{ arg, bits }}; + const std::array<TRuntimeNode, 2> args = {{arg, bits}}; return Invoke(__func__, arg.GetStaticType(), args); } TRuntimeNode TProgramBuilder::Mod(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; auto resultType = BuildArithmeticCommonType(data1.GetStaticType(), data2.GetStaticType()); - if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) { + if (resultType->IsData() && !(NUdf::GetDataTypeInfo(*static_cast<TDataType*>(resultType)->GetDataSlot()).Features & + (NUdf::EDataTypeFeatures::FloatType | NUdf::EDataTypeFeatures::DecimalType))) { resultType = NewOptionalType(resultType); } return Invoke(__func__, resultType, args); @@ -3047,14 +3129,20 @@ TRuntimeNode TProgramBuilder::Mod(TRuntimeNode data1, TRuntimeNode data2) { TRuntimeNode TProgramBuilder::BuildMinMax(const std::string_view& callableName, const TRuntimeNode* data, size_t size) { switch (size) { - case 0U: return NewNull(); - case 1U: return *data; - case 2U: return InvokeBinary(callableName, ChooseCommonType(data[0U].GetStaticType(), data[1U].GetStaticType()), data[0U], data[1U]); - default: break; + case 0U: + return NewNull(); + case 1U: + return *data; + case 2U: { + const auto commonType = ChooseCommonType(data[0U].GetStaticType(), data[1U].GetStaticType()); + return InvokeBinary(callableName, commonType, data[0U], data[1U]); + } + default: + break; } const auto half = size >> 1U; - const std::array<TRuntimeNode, 2U> args = {{ BuildMinMax(callableName, data, half), BuildMinMax(callableName, data + half, size - half) }}; + const std::array<TRuntimeNode, 2U> args = {{BuildMinMax(callableName, data, half), BuildMinMax(callableName, data + half, size - half)}}; return BuildMinMax(callableName, args.data(), args.size()); } @@ -3121,20 +3209,20 @@ TRuntimeNode TProgramBuilder::BuildBlockDecimalBinary(const std::string_view& ca } TRuntimeNode TProgramBuilder::Min(const TArrayRef<const TRuntimeNode>& args) { - return BuildMinMax(__func__, args.data(), args.size()); + return BuildMinMax(__func__, args.data(), args.size()); } TRuntimeNode TProgramBuilder::Max(const TArrayRef<const TRuntimeNode>& args) { - return BuildMinMax(__func__, args.data(), args.size()); + return BuildMinMax(__func__, args.data(), args.size()); } TRuntimeNode TProgramBuilder::Min(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2U> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2U> args = {{data1, data2}}; return Min(args); } TRuntimeNode TProgramBuilder::Max(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2U> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2U> args = {{data1, data2}}; return Max(args); } @@ -3163,7 +3251,7 @@ TRuntimeNode TProgramBuilder::GreaterOrEqual(TRuntimeNode data1, TRuntimeNode da } TRuntimeNode TProgramBuilder::InvokeBinary(const std::string_view& callableName, TType* type, TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; return Invoke(callableName, type, args); } @@ -3183,20 +3271,33 @@ TRuntimeNode TProgramBuilder::DataCompare(const std::string_view& callableName, const auto& lDec = static_cast<TDataDecimalType*>(leftType)->GetParams(); const auto& rDec = static_cast<TDataDecimalType*>(rightType)->GetParams(); if (lDec.second < rDec.second) { - left = ToDecimal(left, std::min<ui8>(lDec.first + rDec.second - lDec.second, NYql::NDecimal::MaxPrecision), rDec.second); + left = ToDecimal(left, + std::min<ui8>(lDec.first + rDec.second - lDec.second, NYql::NDecimal::MaxPrecision), + rDec.second); } else if (lDec.second > rDec.second) { - right = ToDecimal(right, std::min<ui8>(rDec.first + lDec.second - rDec.second, NYql::NDecimal::MaxPrecision), lDec.second); + right = ToDecimal(right, + std::min<ui8>(rDec.first + lDec.second - rDec.second, NYql::NDecimal::MaxPrecision), + lDec.second); } - } else if (lId == NUdf::TDataType<NUdf::TDecimal>::Id && NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).Features & NUdf::EDataTypeFeatures::IntegralType) { + } else if (lId == NUdf::TDataType<NUdf::TDecimal>::Id && + NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).Features & NUdf::EDataTypeFeatures::IntegralType) { const auto scale = static_cast<TDataDecimalType*>(leftType)->GetParams().second; - right = ToDecimal(right, std::min<ui8>(NYql::NDecimal::MaxPrecision, NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).DecimalDigits + scale), scale); - } else if (rId == NUdf::TDataType<NUdf::TDecimal>::Id && NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).Features & NUdf::EDataTypeFeatures::IntegralType) { + right = ToDecimal(right, + std::min<ui8>(NYql::NDecimal::MaxPrecision, + NUdf::GetDataTypeInfo(NUdf::GetDataSlot(rId)).DecimalDigits + scale), + scale); + } else if (rId == NUdf::TDataType<NUdf::TDecimal>::Id && + NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).Features & NUdf::EDataTypeFeatures::IntegralType) { const auto scale = static_cast<TDataDecimalType*>(rightType)->GetParams().second; - left = ToDecimal(left, std::min<ui8>(NYql::NDecimal::MaxPrecision, NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).DecimalDigits + scale), scale); + left = ToDecimal(left, + std::min<ui8>(NYql::NDecimal::MaxPrecision, + NUdf::GetDataTypeInfo(NUdf::GetDataSlot(lId)).DecimalDigits + scale), + scale); } - const std::array<TRuntimeNode, 2> args = {{ left, right }}; - const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(NewDataType(NUdf::TDataType<bool>::Id)) : NewDataType(NUdf::TDataType<bool>::Id); + const std::array<TRuntimeNode, 2> args = {{left, right}}; + const auto boolType = NewDataType(NUdf::TDataType<bool>::Id); + const auto resultType = isOptionalLeft || isOptionalRight ? NewOptionalType(boolType) : boolType; return Invoke(callableName, resultType, args); } @@ -3276,7 +3377,8 @@ TRuntimeNode TProgramBuilder::If(TRuntimeNode condition, TRuntimeNode thenBranch return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::Ensure(TRuntimeNode value, TRuntimeNode predicate, TRuntimeNode message, const std::string_view& file, ui32 row, ui32 column) { +TRuntimeNode TProgramBuilder::Ensure(TRuntimeNode value, TRuntimeNode predicate, TRuntimeNode message, const std::string_view& file, + ui32 row, ui32 column) { bool isOptional; const auto unpackedType = UnpackOptionalData(predicate, isOptional); MKQL_ENSURE(unpackedType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool"); @@ -3285,7 +3387,8 @@ TRuntimeNode TProgramBuilder::Ensure(TRuntimeNode value, TRuntimeNode predicate, MKQL_ENSURE(messageType->IsData(), "Expected data"); const auto& messageTypeData = static_cast<const TDataType&>(*messageType); - MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8."); + MKQL_ENSURE(messageTypeData.GetSchemeType() == NUdf::TDataType<char*>::Id || + messageTypeData.GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected string or utf8."); TCallableBuilder callableBuilder(Env_, __func__, value.GetStaticType()); callableBuilder.Add(value); @@ -3336,26 +3439,23 @@ TRuntimeNode TProgramBuilder::IfPresent(TRuntimeNode::TList optionals, const TNa case 0U: return thenBranch({}); case 1U: - return IfPresent(optionals.front(), [&](TRuntimeNode unwrap){ return thenBranch({unwrap}); }, elseBranch); + return IfPresent(optionals.front(), [&](TRuntimeNode unwrap) { return thenBranch({unwrap}); }, elseBranch); default: break; - } const auto first = optionals.front(); optionals.erase(optionals.cbegin()); return IfPresent(first, - [&](TRuntimeNode head) { - return IfPresent(optionals, - [&](TRuntimeNode::TList tail) { - tail.insert(tail.cbegin(), head); - return thenBranch(tail); - }, - elseBranch - ); - }, - elseBranch - ); + [&](TRuntimeNode head) { + return IfPresent(optionals, + [&](TRuntimeNode::TList tail) { + tail.insert(tail.cbegin(), head); + return thenBranch(tail); + }, + elseBranch); + }, + elseBranch); } TRuntimeNode TProgramBuilder::Not(TRuntimeNode data) { @@ -3377,8 +3477,10 @@ TRuntimeNode TProgramBuilder::BuildLogical(const std::string_view& callableName, MKQL_ENSURE(!args.empty(), "Empty logical args."); switch (args.size()) { - case 1U: return args.front(); - case 2U: return BuildBinaryLogical(callableName, args.front(), args.back()); + case 1U: + return args.front(); + case 2U: + return BuildBinaryLogical(callableName, args.front(), args.back()); } const auto half = (args.size() + 1U) >> 1U; @@ -3425,7 +3527,7 @@ TRuntimeNode TProgramBuilder::NextMTRand(TRuntimeNode rand) { auto resType = AS_TYPE(TResourceType, rand); MKQL_ENSURE(resType->GetTag() == RandomMTResource, "Expected MTRand resource"); - const std::array<TType*, 2U> tupleTypes = {{ NewDataType(NUdf::TDataType<ui64>::Id), rand.GetStaticType() }}; + const std::array<TType*, 2U> tupleTypes = {{NewDataType(NUdf::TDataType<ui64>::Id), rand.GetStaticType()}}; auto returnType = NewTupleType(tupleTypes); TCallableBuilder callableBuilder(Env_, __func__, returnType); @@ -3460,12 +3562,13 @@ TRuntimeNode TProgramBuilder::AggrMax(TRuntimeNode data1, TRuntimeNode data2) { } TRuntimeNode TProgramBuilder::AggrAdd(TRuntimeNode data1, TRuntimeNode data2) { - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; bool isOptionalLeft; const auto leftType = UnpackOptionalData(data1, isOptionalLeft); - if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) + if (leftType->GetSchemeType() != NUdf::TDataType<NUdf::TDecimal>::Id) { return Invoke(__func__, data1.GetStaticType(), args); + } const auto decimalType = static_cast<TDataDecimalType*>(leftType); bool isOptionalRight; @@ -3474,7 +3577,8 @@ TRuntimeNode TProgramBuilder::AggrAdd(TRuntimeNode data1, TRuntimeNode data2) { return Invoke(TString("AggrAdd_") += ::ToString(decimalType->GetParams().first), data1.GetStaticType(), args); } -TRuntimeNode TProgramBuilder::QueueCreate(TRuntimeNode initCapacity, TRuntimeNode initSize, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) { +TRuntimeNode TProgramBuilder::QueueCreate(TRuntimeNode initCapacity, TRuntimeNode initSize, + const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) { auto resType = AS_TYPE(TResourceType, returnType); const auto tag = resType->GetTag(); @@ -3515,7 +3619,8 @@ TRuntimeNode TProgramBuilder::QueuePop(TRuntimeNode resource) { return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::QueuePeek(TRuntimeNode resource, TRuntimeNode index, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) { +TRuntimeNode TProgramBuilder::QueuePeek(TRuntimeNode resource, TRuntimeNode index, + const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) { MKQL_ENSURE(returnType->IsOptional(), "Expected optional type as result of QueuePeek"); auto resType = AS_TYPE(TResourceType, resource); auto indexType = AS_TYPE(TDataType, index); @@ -3531,7 +3636,8 @@ TRuntimeNode TProgramBuilder::QueuePeek(TRuntimeNode resource, TRuntimeNode inde return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::QueueRange(TRuntimeNode resource, TRuntimeNode begin, TRuntimeNode end, const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) { +TRuntimeNode TProgramBuilder::QueueRange(TRuntimeNode resource, TRuntimeNode begin, TRuntimeNode end, + const TArrayRef<const TRuntimeNode>& dependentNodes, TType* returnType) { MKQL_ENSURE(returnType->IsList(), "Expected list type as result of QueueRange"); auto resType = AS_TYPE(TResourceType, resource); @@ -3589,7 +3695,8 @@ TRuntimeNode TProgramBuilder::FromYsonSimpleType(TRuntimeNode input, NUdf::TData return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::TryWeakMemberFromDict(TRuntimeNode other, TRuntimeNode rest, NUdf::TDataTypeId schemeType, const std::string_view& memberName) { +TRuntimeNode TProgramBuilder::TryWeakMemberFromDict(TRuntimeNode other, TRuntimeNode rest, + NUdf::TDataTypeId schemeType, const std::string_view& memberName) { auto resDataType = NewDataType(schemeType); auto resultType = NewOptionalType(resDataType); @@ -3631,14 +3738,26 @@ TRuntimeNode TProgramBuilder::AddTimezone(TRuntimeNode utc, TRuntimeNode id) { MKQL_ENSURE(dataType2->GetSchemeType() == NUdf::TDataType<ui16>::Id, "Expected ui16"); NUdf::EDataSlot tzType; switch (*dataType1->GetDataSlot()) { - case NUdf::EDataSlot::Date: tzType = NUdf::EDataSlot::TzDate; break; - case NUdf::EDataSlot::Datetime: tzType = NUdf::EDataSlot::TzDatetime; break; - case NUdf::EDataSlot::Timestamp: tzType = NUdf::EDataSlot::TzTimestamp; break; - case NUdf::EDataSlot::Date32: tzType = NUdf::EDataSlot::TzDate32; break; - case NUdf::EDataSlot::Datetime64: tzType = NUdf::EDataSlot::TzDatetime64; break; - case NUdf::EDataSlot::Timestamp64: tzType = NUdf::EDataSlot::TzTimestamp64; break; - default: - ythrow yexception() << "Unknown date type: " << *dataType1->GetDataSlot(); + case NUdf::EDataSlot::Date: + tzType = NUdf::EDataSlot::TzDate; + break; + case NUdf::EDataSlot::Datetime: + tzType = NUdf::EDataSlot::TzDatetime; + break; + case NUdf::EDataSlot::Timestamp: + tzType = NUdf::EDataSlot::TzTimestamp; + break; + case NUdf::EDataSlot::Date32: + tzType = NUdf::EDataSlot::TzDate32; + break; + case NUdf::EDataSlot::Datetime64: + tzType = NUdf::EDataSlot::TzDatetime64; + break; + case NUdf::EDataSlot::Timestamp64: + tzType = NUdf::EDataSlot::TzTimestamp64; + break; + default: + ythrow yexception() << "Unknown date type: " << *dataType1->GetDataSlot(); } auto resultType = NewOptionalType(NewDataType(tzType)); @@ -3651,18 +3770,31 @@ TRuntimeNode TProgramBuilder::AddTimezone(TRuntimeNode utc, TRuntimeNode id) { TRuntimeNode TProgramBuilder::RemoveTimezone(TRuntimeNode local) { bool isOptional1; const auto dataType1 = UnpackOptionalData(local, isOptional1); - MKQL_ENSURE((NUdf::GetDataTypeInfo(*dataType1->GetDataSlot()).Features & NUdf::TzDateType), "Expected date with timezone type"); + MKQL_ENSURE((NUdf::GetDataTypeInfo(*dataType1->GetDataSlot()).Features & NUdf::TzDateType), + "Expected date with timezone type"); NUdf::EDataSlot type; switch (*dataType1->GetDataSlot()) { - case NUdf::EDataSlot::TzDate: type = NUdf::EDataSlot::Date; break; - case NUdf::EDataSlot::TzDatetime: type = NUdf::EDataSlot::Datetime; break; - case NUdf::EDataSlot::TzTimestamp: type = NUdf::EDataSlot::Timestamp; break; - case NUdf::EDataSlot::TzDate32: type = NUdf::EDataSlot::Date32; break; - case NUdf::EDataSlot::TzDatetime64: type = NUdf::EDataSlot::Datetime64; break; - case NUdf::EDataSlot::TzTimestamp64: type = NUdf::EDataSlot::Timestamp64; break; - default: - ythrow yexception() << "Unknown date with timezone type: " << *dataType1->GetDataSlot(); + case NUdf::EDataSlot::TzDate: + type = NUdf::EDataSlot::Date; + break; + case NUdf::EDataSlot::TzDatetime: + type = NUdf::EDataSlot::Datetime; + break; + case NUdf::EDataSlot::TzTimestamp: + type = NUdf::EDataSlot::Timestamp; + break; + case NUdf::EDataSlot::TzDate32: + type = NUdf::EDataSlot::Date32; + break; + case NUdf::EDataSlot::TzDatetime64: + type = NUdf::EDataSlot::Datetime64; + break; + case NUdf::EDataSlot::TzTimestamp64: + type = NUdf::EDataSlot::Timestamp64; + break; + default: + ythrow yexception() << "Unknown date with timezone type: " << *dataType1->GetDataSlot(); } return Convert(local, NewDataType(type, isOptional1)); @@ -3672,8 +3804,7 @@ TRuntimeNode TProgramBuilder::Nth(TRuntimeNode tuple, ui32 index) { bool isOptional; const auto type = AS_TYPE(TTupleType, UnpackOptional(tuple.GetStaticType(), isOptional)); - MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index << - " is not less than " << type->GetElementsCount()); + MKQL_ENSURE(index < type->GetElementsCount(), "Index out of range: " << index << " is not less than " << type->GetElementsCount()); auto itemType = type->GetElementType(index); if (isOptional && !itemType->IsOptional() && !itemType->IsNull() && !itemType->IsPg()) { itemType = TOptionalType::Create(itemType, Env_); @@ -3847,14 +3978,18 @@ TRuntimeNode TProgramBuilder::UnaryDataFunction(TRuntimeNode data, const std::st } TRuntimeNode TProgramBuilder::ToDict(TRuntimeNode list, bool multi, const TUnaryLambda& keySelector, - const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint) + const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint) { bool isOptional; const auto type = UnpackOptional(list, isOptional); MKQL_ENSURE(type->IsList(), "Expected list."); if (isOptional) { - return Map(list, [&](TRuntimeNode unpacked) { return ToDict(unpacked, multi, keySelector, payloadSelector, callableName, isCompact, itemsCountHint); } ); + auto lambda = [&](TRuntimeNode unpacked) { + return ToDict(unpacked, multi, keySelector, payloadSelector, callableName, isCompact, itemsCountHint); + }; + + return Map(list, lambda); } const auto itemType = AS_TYPE(TListType, type)->GetItemType(); @@ -3883,7 +4018,7 @@ TRuntimeNode TProgramBuilder::ToDict(TRuntimeNode list, bool multi, const TUnary } TRuntimeNode TProgramBuilder::SqueezeToDict(TRuntimeNode stream, bool multi, const TUnaryLambda& keySelector, - const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint) + const TUnaryLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint) { const auto type = stream.GetStaticType(); MKQL_ENSURE(type->IsStream() || type->IsFlow(), "Expected stream or flow."); @@ -3903,8 +4038,8 @@ TRuntimeNode TProgramBuilder::SqueezeToDict(TRuntimeNode stream, bool multi, con auto dictType = TDictType::Create(keyType, payloadType, Env_); auto returnType = type->IsFlow() - ? (TType*) TFlowType::Create(dictType, Env_) - : (TType*) TStreamType::Create(dictType, Env_); + ? (TType*)TFlowType::Create(dictType, Env_) + : (TType*)TStreamType::Create(dictType, Env_); TCallableBuilder callableBuilder(Env_, callableName, returnType); callableBuilder.Add(stream); callableBuilder.Add(itemArg); @@ -3917,14 +4052,15 @@ TRuntimeNode TProgramBuilder::SqueezeToDict(TRuntimeNode stream, bool multi, con } TRuntimeNode TProgramBuilder::NarrowSqueezeToDict(TRuntimeNode flow, bool multi, const TNarrowLambda& keySelector, - const TNarrowLambda& payloadSelector, std::string_view callableName, bool isCompact, ui64 itemsCountHint) + const TNarrowLambda& payloadSelector, std::string_view callableName, + bool isCompact, ui64 itemsCountHint) { const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType())); TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&]() { return Arg(wideComponents[i++]); }); const auto key = keySelector(itemArgs); const auto keyType = key.GetStaticType(); @@ -3939,7 +4075,8 @@ TRuntimeNode TProgramBuilder::NarrowSqueezeToDict(TRuntimeNode flow, bool multi, const auto returnType = TFlowType::Create(dictType, Env_); TCallableBuilder callableBuilder(Env_, callableName, returnType); callableBuilder.Add(flow); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); callableBuilder.Add(key); callableBuilder.Add(payload); callableBuilder.Add(NewDataLiteral(multi)); @@ -3955,7 +4092,8 @@ void TProgramBuilder::ThrowIfListOfVoid(TType* type) { TRuntimeNode TProgramBuilder::BuildFlatMap(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler) { const auto listType = list.GetStaticType(); - MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsOptional() || listType->IsStream(), "Expected flow, list, stream or optional"); + MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsOptional() || listType->IsStream(), + "Expected flow, list, stream or optional"); if (listType->IsOptional()) { const auto itemArg = Arg(AS_TYPE(TOptionalType, listType)->GetItemType()); @@ -3964,14 +4102,18 @@ TRuntimeNode TProgramBuilder::BuildFlatMap(const std::string_view& callableName, MKQL_ENSURE(type->IsList() || type->IsOptional() || type->IsStream() || type->IsFlow(), "Expected flow, list, stream or optional"); return IfPresent(list, [&](TRuntimeNode item) { return handler(item); - }, type->IsOptional() ? NewEmptyOptional(type) : type->IsList() ? NewEmptyList(AS_TYPE(TListType, type)->GetItemType()) : EmptyIterator(type)); + }, type->IsOptional() ? NewEmptyOptional(type) : type->IsList() ? NewEmptyList(AS_TYPE(TListType, type)->GetItemType()) + : EmptyIterator(type)); } - const auto itemType = listType->IsFlow() ? - AS_TYPE(TFlowType, listType)->GetItemType(): - listType->IsList() ? - AS_TYPE(TListType, listType)->GetItemType(): - AS_TYPE(TStreamType, listType)->GetItemType(); + TType* itemType; + if (listType->IsFlow()) { + itemType = AS_TYPE(TFlowType, listType)->GetItemType(); + } else if (listType->IsList()) { + itemType = AS_TYPE(TListType, listType)->GetItemType(); + } else { + itemType = AS_TYPE(TStreamType, listType)->GetItemType(); + } ThrowIfListOfVoid(itemType); const auto itemArg = Arg(itemType); @@ -3991,11 +4133,15 @@ TRuntimeNode TProgramBuilder::BuildFlatMap(const std::string_view& callableName, THROW yexception() << "Expected flow, list or stream."; } - const auto resultListType = listType->IsFlow() || type->IsFlow() ? - TFlowType::Create(retItemType, Env_): - listType->IsList() ? - (TType*)TListType::Create(retItemType, Env_): - (TType*)TStreamType::Create(retItemType, Env_); + TType* resultListType; + if (listType->IsFlow() || type->IsFlow()) { + resultListType = TFlowType::Create(retItemType, Env_); + } else if (listType->IsList()) { + resultListType = TListType::Create(retItemType, Env_); + } else { + resultListType = TStreamType::Create(retItemType, Env_); + } + TCallableBuilder callableBuilder(Env_, callableName, resultListType); callableBuilder.Add(list); callableBuilder.Add(itemArg); @@ -4017,8 +4163,13 @@ TRuntimeNode TProgramBuilder::MultiMap(TRuntimeNode list, const TExpandLambda& h const auto retItemType = newList.front().GetStaticType(); MKQL_ENSURE(retItemType->IsSameType(*newList.back().GetStaticType()), "Must be same type."); - const auto resultListType = listType->IsFlow() ? - (TType*)TFlowType::Create(retItemType, Env_) : (TType*)TListType::Create(retItemType, Env_); + TType* resultListType; + if (listType->IsFlow()) { + resultListType = TFlowType::Create(retItemType, Env_); + } else { + resultListType = TListType::Create(retItemType, Env_); + } + TCallableBuilder callableBuilder(Env_, __func__, resultListType); callableBuilder.Add(list); callableBuilder.Add(itemArg); @@ -4032,7 +4183,7 @@ TRuntimeNode TProgramBuilder::NarrowMultiMap(TRuntimeNode flow, const TWideLambd TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&]() { return Arg(wideComponents[i++]); }); const auto newList = handler(itemArgs); @@ -4042,8 +4193,10 @@ TRuntimeNode TProgramBuilder::NarrowMultiMap(TRuntimeNode flow, const TWideLambd TCallableBuilder callableBuilder(Env_, __func__, NewFlowType(newList.front().GetStaticType())); callableBuilder.Add(flow); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(newList.cbegin(), newList.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(newList.cbegin(), newList.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); return TRuntimeNode(callableBuilder.Build(), false); } @@ -4054,12 +4207,14 @@ TRuntimeNode TProgramBuilder::ExpandMap(TRuntimeNode flow, const TExpandLambda& std::vector<TType*> tupleItems; tupleItems.reserve(newItems.size()); - std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); + std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), + std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); TCallableBuilder callableBuilder(Env_, __func__, NewFlowType(NewMultiType(tupleItems))); callableBuilder.Add(flow); callableBuilder.Add(itemArg); - std::for_each(newItems.cbegin(), newItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(newItems.cbegin(), newItems.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); return TRuntimeNode(callableBuilder.Build(), false); } @@ -4070,18 +4225,22 @@ TRuntimeNode TProgramBuilder::WideMap(TRuntimeNode flowOrStream, const TWideLamb TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), + [&]() { return Arg(wideComponents[i++]); }); const auto newItems = handler(itemArgs); std::vector<TType*> tupleItems; tupleItems.reserve(newItems.size()); - std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); + std::transform(newItems.cbegin(), newItems.cend(), std::back_inserter(tupleItems), + std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); - auto fillCallableBuilder = [&] (TCallableBuilder& builder, TRuntimeNode input) { + auto fillCallableBuilder = [&](TCallableBuilder& builder, TRuntimeNode input) { builder.Add(input); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(builder), std::placeholders::_1)); - std::for_each(newItems.cbegin(), newItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(builder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(builder), std::placeholders::_1)); + std::for_each(newItems.cbegin(), newItems.cend(), + std::bind(&TCallableBuilder::Add, std::ref(builder), std::placeholders::_1)); return TRuntimeNode(builder.Build(), false); }; @@ -4105,17 +4264,19 @@ TRuntimeNode TProgramBuilder::WideChain1Map(TRuntimeNode flow, const TWideLambda TRuntimeNode::TList inputArgs; inputArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(inputArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(inputArgs), wideComponents.size(), [&]() { return Arg(wideComponents[i++]); }); const auto initItems = init(inputArgs); std::vector<TType*> tupleItems; tupleItems.reserve(initItems.size()); - std::transform(initItems.cbegin(), initItems.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); + std::transform(initItems.cbegin(), initItems.cend(), std::back_inserter(tupleItems), + std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); TRuntimeNode::TList outputArgs; outputArgs.reserve(tupleItems.size()); - std::transform(tupleItems.cbegin(), tupleItems.cend(), std::back_inserter(outputArgs), std::bind(&TProgramBuilder::Arg, this, std::placeholders::_1)); + std::transform(tupleItems.cbegin(), tupleItems.cend(), std::back_inserter(outputArgs), + std::bind(&TProgramBuilder::Arg, this, std::placeholders::_1)); const auto updateItems = update(inputArgs, outputArgs); @@ -4123,10 +4284,14 @@ TRuntimeNode TProgramBuilder::WideChain1Map(TRuntimeNode flow, const TWideLambda TCallableBuilder callableBuilder(Env_, __func__, NewFlowType(NewMultiType(tupleItems))); callableBuilder.Add(flow); - std::for_each(inputArgs.cbegin(), inputArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(initItems.cbegin(), initItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(outputArgs.cbegin(), outputArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(updateItems.cbegin(), updateItems.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(inputArgs.cbegin(), inputArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(initItems.cbegin(), initItems.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(outputArgs.cbegin(), outputArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(updateItems.cbegin(), updateItems.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); return TRuntimeNode(callableBuilder.Build(), false); } @@ -4136,13 +4301,14 @@ TRuntimeNode TProgramBuilder::NarrowMap(TRuntimeNode flow, const TNarrowLambda& TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&]() { return Arg(wideComponents[i++]); }); const auto newItem = handler(itemArgs); TCallableBuilder callableBuilder(Env_, __func__, NewFlowType(newItem.GetStaticType())); callableBuilder.Add(flow); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); callableBuilder.Add(newItem); return TRuntimeNode(callableBuilder.Build(), false); } @@ -4153,7 +4319,8 @@ TRuntimeNode TProgramBuilder::NarrowFlatMap(TRuntimeNode flow, const TNarrowLamb TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), + [&]() { return Arg(wideComponents[i++]); }); const auto newList = handler(itemArgs); const auto type = newList.GetStaticType(); @@ -4173,7 +4340,8 @@ TRuntimeNode TProgramBuilder::NarrowFlatMap(TRuntimeNode flow, const TNarrowLamb TCallableBuilder callableBuilder(Env_, __func__, NewFlowType(retItemType)); callableBuilder.Add(flow); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); callableBuilder.Add(newList); return TRuntimeNode(callableBuilder.Build(), false); } @@ -4184,13 +4352,14 @@ TRuntimeNode TProgramBuilder::BuildWideFilter(const std::string_view& callableNa TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&]() { return Arg(wideComponents[i++]); }); const auto predicate = handler(itemArgs); TCallableBuilder callableBuilder(Env_, callableName, flow.GetStaticType()); callableBuilder.Add(flow); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); callableBuilder.Add(predicate); return TRuntimeNode(callableBuilder.Build(), false); } @@ -4221,13 +4390,14 @@ TRuntimeNode TProgramBuilder::WideFilter(TRuntimeNode flow, TRuntimeNode limit, TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&]() { return Arg(wideComponents[i++]); }); const auto predicate = handler(itemArgs); TCallableBuilder callableBuilder(Env_, __func__, flow.GetStaticType()); callableBuilder.Add(flow); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); callableBuilder.Add(predicate); callableBuilder.Add(limit); return TRuntimeNode(callableBuilder.Build(), false); @@ -4239,11 +4409,15 @@ TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream."); const auto outputType = resultType ? resultType : listType; - const auto itemType = listType->IsFlow() ? - AS_TYPE(TFlowType, listType)->GetItemType(): - listType->IsList() ? - AS_TYPE(TListType, listType)->GetItemType(): - AS_TYPE(TStreamType, listType)->GetItemType(); + TType* itemType; + if (listType->IsFlow()) { + itemType = AS_TYPE(TFlowType, listType)->GetItemType(); + } else if (listType->IsList()) { + itemType = AS_TYPE(TListType, listType)->GetItemType(); + } else { + itemType = AS_TYPE(TStreamType, listType)->GetItemType(); + } + ThrowIfListOfVoid(itemType); const auto itemArg = Arg(itemType); @@ -4261,18 +4435,23 @@ TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, TRuntimeNode list, TRuntimeNode limit, const TUnaryLambda& handler, TType* resultType) +TRuntimeNode TProgramBuilder::BuildFilter(const std::string_view& callableName, TRuntimeNode list, + TRuntimeNode limit, const TUnaryLambda& handler, TType* resultType) { const auto listType = list.GetStaticType(); MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream(), "Expected flow, list or stream."); MKQL_ENSURE(limit.GetStaticType()->IsData(), "Expected data"); const auto outputType = resultType ? resultType : listType; - const auto itemType = listType->IsFlow() ? - AS_TYPE(TFlowType, listType)->GetItemType(): - listType->IsList() ? - AS_TYPE(TListType, listType)->GetItemType(): - AS_TYPE(TStreamType, listType)->GetItemType(); + TType* itemType; + if (listType->IsFlow()) { + itemType = AS_TYPE(TFlowType, listType)->GetItemType(); + } else if (listType->IsList()) { + itemType = AS_TYPE(TListType, listType)->GetItemType(); + } else { + itemType = AS_TYPE(TStreamType, listType)->GetItemType(); + } + ThrowIfListOfVoid(itemType); const auto itemArg = Arg(itemType); @@ -4296,13 +4475,11 @@ TRuntimeNode TProgramBuilder::Filter(TRuntimeNode list, const TUnaryLambda& hand const auto type = list.GetStaticType(); if (type->IsOptional()) { - return - IfPresent(list, - [&](TRuntimeNode item) { - return If(handler(item), item, NewEmptyOptional(resultType), resultType); - }, - NewEmptyOptional(resultType) - ); + return IfPresent(list, + [&](TRuntimeNode item) { + return If(handler(item), item, NewEmptyOptional(resultType), resultType); + }, + NewEmptyOptional(resultType)); } return BuildFilter(__func__, list, handler, resultType); @@ -4377,34 +4554,41 @@ TRuntimeNode TProgramBuilder::PartialSort(TRuntimeNode list, TRuntimeNode n, con TRuntimeNode TProgramBuilder::BuildMap(const std::string_view& callableName, TRuntimeNode list, const TUnaryLambda& handler) { const auto listType = list.GetStaticType(); - MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream() || listType->IsOptional(), "Expected flow, list, stream or optional"); + MKQL_ENSURE(listType->IsFlow() || listType->IsList() || listType->IsStream() || listType->IsOptional(), + "Expected flow, list, stream or optional"); if (listType->IsOptional()) { const auto itemArg = Arg(AS_TYPE(TOptionalType, listType)->GetItemType()); const auto newItem = handler(itemArg); return IfPresent(list, - [&](TRuntimeNode item) { return NewOptional(handler(item)); }, - NewEmptyOptional(NewOptionalType(newItem.GetStaticType())) - ); + [&](TRuntimeNode item) { return NewOptional(handler(item)); }, + NewEmptyOptional(NewOptionalType(newItem.GetStaticType()))); } - const auto itemType = listType->IsFlow() ? - AS_TYPE(TFlowType, listType)->GetItemType(): - listType->IsList() ? - AS_TYPE(TListType, listType)->GetItemType(): - AS_TYPE(TStreamType, listType)->GetItemType(); + TType* itemType; + if (listType->IsFlow()) { + itemType = AS_TYPE(TFlowType, listType)->GetItemType(); + } else if (listType->IsList()) { + itemType = AS_TYPE(TListType, listType)->GetItemType(); + } else { + itemType = AS_TYPE(TStreamType, listType)->GetItemType(); + } ThrowIfListOfVoid(itemType); const auto itemArg = Arg(itemType); const auto newItem = handler(itemArg); - const auto resultListType = listType->IsFlow() ? - (TType*)TFlowType::Create(newItem.GetStaticType(), Env_): - listType->IsList() ? - (TType*)TListType::Create(newItem.GetStaticType(), Env_): - (TType*)TStreamType::Create(newItem.GetStaticType(), Env_); + TType* resultListType; + if (listType->IsFlow()) { + resultListType = TFlowType::Create(newItem.GetStaticType(), Env_); + } else if (listType->IsList()) { + resultListType = TListType::Create(newItem.GetStaticType(), Env_); + } else { + resultListType = TStreamType::Create(newItem.GetStaticType(), Env_); + } + TCallableBuilder callableBuilder(Env_, callableName, resultListType); callableBuilder.Add(list); callableBuilder.Add(itemArg); @@ -4437,8 +4621,7 @@ TRuntimeNode TProgramBuilder::Udf( const std::string_view& funcName, TRuntimeNode runConfig, TType* userType, - const std::string_view& typeConfig -) + const std::string_view& typeConfig) { TRuntimeNode userTypeNode = userType ? TRuntimeNode(userType, true) : TRuntimeNode(Env_.GetVoidLazy()->GetType(), true); const ui32 flags = NUdf::IUdfModule::TFlags::TypesOnly; @@ -4547,11 +4730,11 @@ TRuntimeNode TProgramBuilder::ScriptUdf( } TRuntimeNode TProgramBuilder::Apply(TRuntimeNode callableNode, const TArrayRef<const TRuntimeNode>& args, - const std::string_view& file, ui32 row, ui32 column, ui32 dependentCount) { + const std::string_view& file, ui32 row, ui32 column, ui32 dependentCount) { MKQL_ENSURE(dependentCount <= args.size(), "Too many dependent nodes"); ui32 usedArgs = args.size() - dependentCount; MKQL_ENSURE(!callableNode.IsImmediate() && callableNode.GetNode()->GetType()->IsCallable(), - "Expected callable"); + "Expected callable"); auto callable = static_cast<TCallable*>(callableNode.GetNode()); TType* returnType = callable->GetType()->GetReturnType(); @@ -4566,7 +4749,7 @@ TRuntimeNode TProgramBuilder::Apply(TRuntimeNode callableNode, const TArrayRef<c TRuntimeNode arg = args[i]; MKQL_ENSURE(arg.GetStaticType()->IsConvertableTo(*argType), "Argument type mismatch for argument " << i << ": runtime " << argType->GetKindAsStr() - << " with static " << arg.GetStaticType()->GetKindAsStr()); + << " with static " << arg.GetStaticType()->GetKindAsStr()); } TCallableBuilder callableBuilder(Env_, "Apply2", callableType->GetReturnType()); @@ -4576,7 +4759,7 @@ TRuntimeNode TProgramBuilder::Apply(TRuntimeNode callableNode, const TArrayRef<c callableBuilder.Add(NewDataLiteral(row)); callableBuilder.Add(NewDataLiteral(column)); - for (const auto& arg: args) { + for (const auto& arg : args) { callableBuilder.Add(arg); } @@ -4584,9 +4767,9 @@ TRuntimeNode TProgramBuilder::Apply(TRuntimeNode callableNode, const TArrayRef<c } TRuntimeNode TProgramBuilder::Apply( - TRuntimeNode callableNode, - const TArrayRef<const TRuntimeNode>& args, - ui32 dependentCount) { + TRuntimeNode callableNode, + const TArrayRef<const TRuntimeNode>& args, + ui32 dependentCount) { return Apply(callableNode, args, {}, 0, 0, dependentCount); } @@ -4626,22 +4809,22 @@ TRuntimeNode TProgramBuilder::Concat(TRuntimeNode data1, TRuntimeNode data2) { TRuntimeNode TProgramBuilder::AggrConcat(TRuntimeNode data1, TRuntimeNode data2) { MKQL_ENSURE(data1.GetStaticType()->IsSameType(*data2.GetStaticType()), "Operands type mismatch."); - const std::array<TRuntimeNode, 2> args = {{ data1, data2 }}; + const std::array<TRuntimeNode, 2> args = {{data1, data2}}; return Invoke(__func__, data1.GetStaticType(), args); } TRuntimeNode TProgramBuilder::Substring(TRuntimeNode data, TRuntimeNode start, TRuntimeNode count) { - const std::array<TRuntimeNode, 3U> args = {{ data, start, count }}; + const std::array<TRuntimeNode, 3U> args = {{data, start, count}}; return Invoke(__func__, data.GetStaticType(), args); } TRuntimeNode TProgramBuilder::Find(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos) { - const std::array<TRuntimeNode, 3U> args = {{ haystack, needle, pos }}; + const std::array<TRuntimeNode, 3U> args = {{haystack, needle, pos}}; return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui32>::Id)), args); } TRuntimeNode TProgramBuilder::RFind(TRuntimeNode haystack, TRuntimeNode needle, TRuntimeNode pos) { - const std::array<TRuntimeNode, 3U> args = {{ haystack, needle, pos }}; + const std::array<TRuntimeNode, 3U> args = {{haystack, needle, pos}}; return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui32>::Id)), args); } @@ -4658,19 +4841,24 @@ TRuntimeNode TProgramBuilder::StringContains(TRuntimeNode string, TRuntimeNode p TDataType* type1 = UnpackOptionalData(string, isOpt1); TDataType* type2 = UnpackOptionalData(pattern, isOpt2); MKQL_ENSURE(type1->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id || - type1->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as first argument"); + type1->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as first argument"); MKQL_ENSURE(type2->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id || - type2->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as second argument"); + type2->GetSchemeType() == NUdf::TDataType<char*>::Id, "Expecting string as second argument"); return DataCompare(__func__, string, pattern); } TRuntimeNode TProgramBuilder::ByteAt(TRuntimeNode data, TRuntimeNode index) { - const std::array<TRuntimeNode, 2U> args = {{ data, index }}; + const std::array<TRuntimeNode, 2U> args = {{data, index}}; return Invoke(__func__, NewOptionalType(NewDataType(NUdf::TDataType<ui8>::Id)), args); } TRuntimeNode TProgramBuilder::Size(TRuntimeNode data) { - return UnaryDataFunction(data, __func__, TDataFunctionFlags::HasUi32Result | TDataFunctionFlags::AllowNull | TDataFunctionFlags::AllowOptionalArgs | TDataFunctionFlags::CommonOptionalResult); + return UnaryDataFunction(data, + __func__, + TDataFunctionFlags::HasUi32Result | + TDataFunctionFlags::AllowNull | + TDataFunctionFlags::AllowOptionalArgs | + TDataFunctionFlags::CommonOptionalResult); } template <bool Utf8> @@ -4688,7 +4876,8 @@ TRuntimeNode TProgramBuilder::FromString(TRuntimeNode data, TType* type) { bool isOptional; const auto sourceType = UnpackOptionalData(data, isOptional); const auto targetType = UnpackOptionalData(type, isOptional); - MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String"); + MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, + "Expected String"); MKQL_ENSURE(targetType->GetSchemeType() != 0, "Null is not allowed"); TCallableBuilder callableBuilder(Env_, __func__, type); callableBuilder.Add(data); @@ -4705,7 +4894,8 @@ TRuntimeNode TProgramBuilder::StrictFromString(TRuntimeNode data, TType* type) { bool isOptional; const auto sourceType = UnpackOptionalData(data, isOptional); const auto targetType = UnpackOptionalData(type, isOptional); - MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, "Expected String"); + MKQL_ENSURE(sourceType->GetSchemeType() == NUdf::TDataType<char*>::Id || sourceType->GetSchemeType() == NUdf::TDataType<NUdf::TUtf8>::Id, + "Expected String"); MKQL_ENSURE(targetType->GetSchemeType() != 0, "Null is not allowed"); TCallableBuilder callableBuilder(Env_, __func__, type); callableBuilder.Add(data); @@ -4719,7 +4909,11 @@ TRuntimeNode TProgramBuilder::StrictFromString(TRuntimeNode data, TType* type) { } TRuntimeNode TProgramBuilder::ToBytes(TRuntimeNode data) { - return UnaryDataFunction(data, __func__, TDataFunctionFlags::HasStringResult | TDataFunctionFlags::AllowOptionalArgs | TDataFunctionFlags::CommonOptionalResult); + return UnaryDataFunction(data, + __func__, + TDataFunctionFlags::HasStringResult | + TDataFunctionFlags::AllowOptionalArgs | + TDataFunctionFlags::CommonOptionalResult); } TRuntimeNode TProgramBuilder::FromBytes(TRuntimeNode data, TType* targetType) { @@ -4743,12 +4937,12 @@ TRuntimeNode TProgramBuilder::FromBytes(TRuntimeNode data, TType* targetType) { } TRuntimeNode TProgramBuilder::InversePresortString(TRuntimeNode data) { - const std::array<TRuntimeNode, 1U> args = {{ data }}; + const std::array<TRuntimeNode, 1U> args = {{data}}; return Invoke(__func__, NewDataType(NUdf::TDataType<char*>::Id), args); } TRuntimeNode TProgramBuilder::InverseString(TRuntimeNode data) { - const std::array<TRuntimeNode, 1U> args = {{ data }}; + const std::array<TRuntimeNode, 1U> args = {{data}}; return Invoke(__func__, NewDataType(NUdf::TDataType<char*>::Id), args); } @@ -4802,7 +4996,8 @@ TRuntimeNode TProgramBuilder::CurrentUtcDatetime(const TArrayRef<const TRuntimeN TRuntimeNode TProgramBuilder::CurrentUtcTimestamp(const TArrayRef<const TRuntimeNode>& args) { return Coalesce(ToIntegral(Now(args), NewDataType(NUdf::TDataType<NUdf::TTimestamp>::Id, true)), - TRuntimeNode(BuildDataLiteral(NUdf::TUnboxedValuePod(ui64(NUdf::MAX_TIMESTAMP - 1ULL)), NUdf::TDataType<NUdf::TTimestamp>::Id, Env_), true)); + TRuntimeNode(BuildDataLiteral(NUdf::TUnboxedValuePod(ui64(NUdf::MAX_TIMESTAMP - 1ULL)), + NUdf::TDataType<NUdf::TTimestamp>::Id, Env_), true)); } TRuntimeNode TProgramBuilder::Pickle(TRuntimeNode data) { @@ -4847,7 +5042,7 @@ TRuntimeNode TProgramBuilder::Convert(TRuntimeNode data, TType* type) { bool isOptional; const auto dataType = UnpackOptionalData(data, isOptional); - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) { const auto targetSchemeType = UnpackOptionalData(type, isOptional)->GetSchemeType(); @@ -4865,9 +5060,10 @@ TRuntimeNode TProgramBuilder::ToDecimal(TRuntimeNode data, ui8 precision, ui8 sc auto dataType = UnpackOptionalData(data, isOptional); TType* decimal = TDataDecimalType::Create(precision, scale, Env_); - if (isOptional) + if (isOptional) { decimal = TOptionalType::Create(decimal, Env_); - const std::array<TRuntimeNode, 1> args = {{ data }}; + } + const std::array<TRuntimeNode, 1> args = {{data}}; if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) { const auto& params = static_cast<const TDataDecimalType*>(dataType)->GetParams(); @@ -4877,7 +5073,7 @@ TRuntimeNode TProgramBuilder::ToDecimal(TRuntimeNode data, ui8 precision, ui8 sc return Invoke("ScaleUp_" + ::ToString(scale - params.second), decimal, args); } else if (params.second > scale) { TRuntimeNode scaled = Invoke("ScaleDown_" + ::ToString(params.second - scale), decimal, args); - return Invoke("CheckBounds_" + ::ToString(precision), decimal, {{ scaled }}); + return Invoke("CheckBounds_" + ::ToString(precision), decimal, {{scaled}}); } else if (precision < params.first) { return Invoke("CheckBounds_" + ::ToString(precision), decimal, args); } else if (precision > params.first) { @@ -4888,10 +5084,11 @@ TRuntimeNode TProgramBuilder::ToDecimal(TRuntimeNode data, ui8 precision, ui8 sc } else { const auto digits = NUdf::GetDataTypeInfo(*dataType->GetDataSlot()).DecimalDigits; MKQL_ENSURE(digits, "Can't cast into Decimal."); - if (digits <= precision && !scale) + if (digits <= precision && !scale) { return Invoke(__func__, decimal, args); - else + } else { return ToDecimal(ToDecimal(data, digits, 0), precision, scale); + } } } @@ -4901,11 +5098,12 @@ TRuntimeNode TProgramBuilder::ToIntegral(TRuntimeNode data, TType* type) { if (dataType->GetSchemeType() == NUdf::TDataType<NUdf::TDecimal>::Id) { const auto& params = static_cast<const TDataDecimalType*>(dataType)->GetParams(); - if (params.second) + if (params.second) { return ToIntegral(ToDecimal(data, params.first - params.second, 0), type); + } } - const std::array<TRuntimeNode, 1> args = {{ data }}; + const std::array<TRuntimeNode, 1> args = {{data}}; return Invoke(__func__, type, args); } @@ -4930,10 +5128,12 @@ TRuntimeNode TProgramBuilder::AsList(const TArrayRef<const TRuntimeNode>& items) } TRuntimeNode TProgramBuilder::MapJoinCore(TRuntimeNode flow, TRuntimeNode dict, EJoinKind joinKind, - const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftRenames, - const TArrayRef<const ui32>& rightRenames, TType* returnType) { - - MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly, "Unsupported join kind"); + const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftRenames, + const TArrayRef<const ui32>& rightRenames, TType* returnType) { + MKQL_ENSURE(joinKind == EJoinKind::Inner || + joinKind == EJoinKind::Left || + joinKind == EJoinKind::LeftSemi || + joinKind == EJoinKind::LeftOnly, "Unsupported join kind"); MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified"); MKQL_ENSURE(leftRenames.size() % 2U == 0U, "Expected even count"); MKQL_ENSURE(rightRenames.size() % 2U == 0U, "Expected even count"); @@ -4941,13 +5141,16 @@ TRuntimeNode TProgramBuilder::MapJoinCore(TRuntimeNode flow, TRuntimeNode dict, TRuntimeNode::TList leftKeyColumnsNodes, leftRenamesNodes, rightRenamesNodes; leftKeyColumnsNodes.reserve(leftKeyColumns.size()); - std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { return NewDataLiteral(idx); }); + std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), std::back_inserter(leftKeyColumnsNodes), + [this](const ui32 idx) { return NewDataLiteral(idx); }); leftRenamesNodes.reserve(leftRenames.size()); - std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); }); + std::transform(leftRenames.cbegin(), leftRenames.cend(), std::back_inserter(leftRenamesNodes), + [this](const ui32 idx) { return NewDataLiteral(idx); }); rightRenamesNodes.reserve(rightRenames.size()); - std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), [this](const ui32 idx) { return NewDataLiteral(idx); }); + std::transform(rightRenames.cbegin(), rightRenames.cend(), std::back_inserter(rightRenamesNodes), + [this](const ui32 idx) { return NewDataLiteral(idx); }); TCallableBuilder callableBuilder(Env_, __func__, returnType); callableBuilder.Add(flow); @@ -4961,11 +5164,10 @@ TRuntimeNode TProgramBuilder::MapJoinCore(TRuntimeNode flow, TRuntimeNode dict, } TRuntimeNode TProgramBuilder::CommonJoinCore(TRuntimeNode flow, EJoinKind joinKind, - const TArrayRef<const ui32>& leftColumns, const TArrayRef<const ui32>& rightColumns, - const TArrayRef<const ui32>& requiredColumns, const TArrayRef<const ui32>& keyColumns, - ui64 memLimit, std::optional<ui32> sortedTableOrder, - EAnyJoinSettings anyJoinSettings, const ui32 tableIndexField, TType* returnType) { - + const TArrayRef<const ui32>& leftColumns, const TArrayRef<const ui32>& rightColumns, + const TArrayRef<const ui32>& requiredColumns, const TArrayRef<const ui32>& keyColumns, + ui64 memLimit, std::optional<ui32> sortedTableOrder, + EAnyJoinSettings anyJoinSettings, const ui32 tableIndexField, TType* returnType) { MKQL_ENSURE(leftColumns.size() % 2U == 0U, "Expected even count"); MKQL_ENSURE(rightColumns.size() % 2U == 0U, "Expected even count"); @@ -4973,11 +5175,11 @@ TRuntimeNode TProgramBuilder::CommonJoinCore(TRuntimeNode flow, EJoinKind joinKi leftOutputColumnsNodes, rightOutputColumnsNodes, keyColumnsNodes; bool s = false; - for (const auto idx : leftColumns) { + for (const auto idx : leftColumns) { ((s = !s) ? leftInputColumnsNodes : leftOutputColumnsNodes).emplace_back(NewDataLiteral(idx)); } - for (const auto idx : rightColumns) { + for (const auto idx : rightColumns) { ((s = !s) ? rightInputColumnsNodes : rightOutputColumnsNodes).emplace_back(NewDataLiteral(idx)); } @@ -4986,14 +5188,14 @@ TRuntimeNode TProgramBuilder::CommonJoinCore(TRuntimeNode flow, EJoinKind joinKi requiredColumnsNodes.reserve(requiredColumns.size()); std::transform(requiredColumns.cbegin(), requiredColumns.cend(), std::back_inserter(requiredColumnsNodes), - std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1)); + std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1)); const std::unordered_set<ui32> keyIndices(keyColumns.cbegin(), keyColumns.cend()); MKQL_ENSURE(keyIndices.size() == keyColumns.size(), "Duplication of key columns."); keyColumnsNodes.reserve(keyColumns.size()); std::transform(keyColumns.cbegin(), keyColumns.cend(), std::back_inserter(keyColumnsNodes), - std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1)); + std::bind(&TProgramBuilder::NewDataLiteral<ui32>, this, std::placeholders::_1)); TCallableBuilder callableBuilder(Env_, __func__, returnType); callableBuilder.Add(flow); @@ -5011,120 +5213,158 @@ TRuntimeNode TProgramBuilder::CommonJoinCore(TRuntimeNode flow, EJoinKind joinKi return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::WideCombiner(TRuntimeNode flow, i64 memLimit, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) { +TRuntimeNode TProgramBuilder::WideCombiner(TRuntimeNode flow, i64 memLimit, const TWideLambda& extractor, + const TBinaryWideLambda& init, const TTernaryWideLambda& update, + const TBinaryWideLambda& finish) { const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType())); TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), + [&]() { return Arg(wideComponents[i++]); }); const auto keys = extractor(itemArgs); TRuntimeNode::TList keyArgs; keyArgs.reserve(keys.size()); - std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), + [&](TRuntimeNode key) { return Arg(key.GetStaticType()); }); const auto first = init(keyArgs, itemArgs); TRuntimeNode::TList stateArgs; stateArgs.reserve(first.size()); - std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } ); + std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), + [&](TRuntimeNode state) { return Arg(state.GetStaticType()); }); const auto next = update(keyArgs, itemArgs, stateArgs); MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size."); TRuntimeNode::TList finishKeyArgs; finishKeyArgs.reserve(keys.size()); - std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), + [&](TRuntimeNode key) { return Arg(key.GetStaticType()); }); TRuntimeNode::TList finishStateArgs; finishStateArgs.reserve(next.size()); - std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } ); + std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), + [&](TRuntimeNode state) { return Arg(state.GetStaticType()); }); const auto output = finish(finishKeyArgs, finishStateArgs); std::vector<TType*> tupleItems; tupleItems.reserve(output.size()); - std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); + std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), + std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); TCallableBuilder callableBuilder(Env_, __func__, NewFlowType(NewMultiType(tupleItems))); callableBuilder.Add(flow); callableBuilder.Add(NewDataLiteral(memLimit)); callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size()))); callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size()))); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - return TRuntimeNode(callableBuilder.Build(), false); -} - -TRuntimeNode TProgramBuilder::WideLastCombinerCommon(const TStringBuf& funcName, TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) { + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(keys.cbegin(), keys.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(keyArgs.cbegin(), keyArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(first.cbegin(), first.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(stateArgs.cbegin(), stateArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(next.cbegin(), next.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(output.cbegin(), output.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + return TRuntimeNode(callableBuilder.Build(), false); +} + +TRuntimeNode TProgramBuilder::WideLastCombinerCommon(const TStringBuf& funcName, TRuntimeNode flow, + const TWideLambda& extractor, const TBinaryWideLambda& init, + const TTernaryWideLambda& update, const TBinaryWideLambda& finish) { const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType())); TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), + [&]() { return Arg(wideComponents[i++]); }); const auto keys = extractor(itemArgs); TRuntimeNode::TList keyArgs; keyArgs.reserve(keys.size()); - std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), + [&](TRuntimeNode key) { return Arg(key.GetStaticType()); }); const auto first = init(keyArgs, itemArgs); TRuntimeNode::TList stateArgs; stateArgs.reserve(first.size()); - std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } ); + std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), + [&](TRuntimeNode state) { return Arg(state.GetStaticType()); }); const auto next = update(keyArgs, itemArgs, stateArgs); MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size."); TRuntimeNode::TList finishKeyArgs; finishKeyArgs.reserve(keys.size()); - std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), + [&](TRuntimeNode key) { return Arg(key.GetStaticType()); }); TRuntimeNode::TList finishStateArgs; finishStateArgs.reserve(next.size()); - std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } ); + std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), + [&](TRuntimeNode state) { return Arg(state.GetStaticType()); }); const auto output = finish(finishKeyArgs, finishStateArgs); std::vector<TType*> tupleItems; tupleItems.reserve(output.size()); - std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); + std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), + std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); TCallableBuilder callableBuilder(Env_, funcName, NewFlowType(NewMultiType(tupleItems))); callableBuilder.Add(flow); callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size()))); callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size()))); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - return TRuntimeNode(callableBuilder.Build(), false); -} - -TRuntimeNode TProgramBuilder::WideLastCombiner(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) { + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(keys.cbegin(), keys.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(keyArgs.cbegin(), keyArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(first.cbegin(), first.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(stateArgs.cbegin(), stateArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(next.cbegin(), next.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(output.cbegin(), output.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + return TRuntimeNode(callableBuilder.Build(), false); +} + +TRuntimeNode TProgramBuilder::WideLastCombiner(TRuntimeNode flow, const TWideLambda& extractor, + const TBinaryWideLambda& init, const TTernaryWideLambda& update, + const TBinaryWideLambda& finish) { return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish); } -TRuntimeNode TProgramBuilder::WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) { +TRuntimeNode TProgramBuilder::WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& extractor, + const TBinaryWideLambda& init, const TTernaryWideLambda& update, + const TBinaryWideLambda& finish) { if constexpr (RuntimeVersion < 49U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; } @@ -5132,20 +5372,23 @@ TRuntimeNode TProgramBuilder::WideLastCombinerWithSpilling(TRuntimeNode flow, co return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish); } -TRuntimeNode TProgramBuilder::WideCondense1(TRuntimeNode flow, const TWideLambda& init, const TWideSwitchLambda& switcher, const TBinaryWideLambda& update, bool useCtx) { +TRuntimeNode TProgramBuilder::WideCondense1(TRuntimeNode flow, const TWideLambda& init, const TWideSwitchLambda& switcher, + const TBinaryWideLambda& update, bool useCtx) { const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType())); TRuntimeNode::TList itemArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), + [&]() { return Arg(wideComponents[i++]); }); const auto first = init(itemArgs); TRuntimeNode::TList stateArgs; stateArgs.reserve(first.size()); - std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } ); + std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), + [&](TRuntimeNode state) { return Arg(state.GetStaticType()); }); const auto chop = switcher(itemArgs, stateArgs); @@ -5154,15 +5397,20 @@ TRuntimeNode TProgramBuilder::WideCondense1(TRuntimeNode flow, const TWideLambda std::vector<TType*> tupleItems; tupleItems.reserve(next.size()); - std::transform(next.cbegin(), next.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); + std::transform(next.cbegin(), next.cend(), std::back_inserter(tupleItems), + std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); TCallableBuilder callableBuilder(Env_, __func__, NewFlowType(NewMultiType(tupleItems))); callableBuilder.Add(flow); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(first.cbegin(), first.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(stateArgs.cbegin(), stateArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); callableBuilder.Add(chop); - std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(next.cbegin(), next.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); if (useCtx) { callableBuilder.Add(NewDataLiteral<bool>(useCtx)); } @@ -5171,11 +5419,11 @@ TRuntimeNode TProgramBuilder::WideCondense1(TRuntimeNode flow, const TWideLambda } TRuntimeNode TProgramBuilder::CombineCore(TRuntimeNode stream, - const TUnaryLambda& keyExtractor, - const TBinaryLambda& init, - const TTernaryLambda& update, - const TBinaryLambda& finish, - ui64 memLimit) + const TUnaryLambda& keyExtractor, + const TBinaryLambda& init, + const TTernaryLambda& update, + const TBinaryLambda& finish, + ui64 memLimit) { const bool isStream = stream.GetStaticType()->IsStream(); const auto itemType = isStream ? AS_TYPE(TStreamType, stream)->GetItemType() : AS_TYPE(TFlowType, stream)->GetItemType(); @@ -5223,9 +5471,9 @@ TRuntimeNode TProgramBuilder::CombineCore(TRuntimeNode stream, } TRuntimeNode TProgramBuilder::GroupingCore(TRuntimeNode stream, - const TBinaryLambda& groupSwitch, - const TUnaryLambda& keyExtractor, - const TUnaryLambda& handler) + const TBinaryLambda& groupSwitch, + const TUnaryLambda& keyExtractor, + const TUnaryLambda& handler) { auto itemType = AS_TYPE(TStreamType, stream)->GetItemType(); @@ -5237,7 +5485,7 @@ TRuntimeNode TProgramBuilder::GroupingCore(TRuntimeNode stream, TRuntimeNode groupSwitchResult = groupSwitch(groupSwitchKeyArg, groupSwitchItemArg); MKQL_ENSURE(AS_TYPE(TDataType, groupSwitchResult)->GetSchemeType() == NUdf::TDataType<bool>::Id, - "Expected bool type"); + "Expected bool type"); TRuntimeNode handlerItemArg; TRuntimeNode handlerResult; @@ -5248,7 +5496,7 @@ TRuntimeNode TProgramBuilder::GroupingCore(TRuntimeNode stream, itemType = handlerResult.GetStaticType(); } - const std::array<TType*, 2U> tupleItems = {{ keyExtractorResult.GetStaticType(), NewStreamType(itemType) }}; + const std::array<TType*, 2U> tupleItems = {{keyExtractorResult.GetStaticType(), NewStreamType(itemType)}}; const auto finishType = NewStreamType(NewTupleType(tupleItems)); TCallableBuilder callableBuilder(Env_, __func__, finishType); @@ -5266,7 +5514,8 @@ TRuntimeNode TProgramBuilder::GroupingCore(TRuntimeNode stream, return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::Chopper(TRuntimeNode flow, const TUnaryLambda& keyExtractor, const TBinaryLambda& groupSwitch, const TBinaryLambda& groupHandler) { +TRuntimeNode TProgramBuilder::Chopper(TRuntimeNode flow, const TUnaryLambda& keyExtractor, + const TBinaryLambda& groupSwitch, const TBinaryLambda& groupHandler) { const auto flowType = flow.GetStaticType(); MKQL_ENSURE(flowType->IsFlow() || flowType->IsStream(), "Expected flow or stream."); @@ -5295,21 +5544,21 @@ TRuntimeNode TProgramBuilder::Chopper(TRuntimeNode flow, const TUnaryLambda& key return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::WideChopper(TRuntimeNode flow, const TWideLambda& extractor, const TWideSwitchLambda& groupSwitch, - const std::function<TRuntimeNode (TRuntimeNode::TList, TRuntimeNode)>& groupHandler -) { +TRuntimeNode TProgramBuilder::WideChopper(TRuntimeNode flow, const TWideLambda& extractor, + const TWideSwitchLambda& groupSwitch, + const std::function<TRuntimeNode(TRuntimeNode::TList, TRuntimeNode)>& groupHandler) { const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType())); TRuntimeNode::TList itemArgs, keyArgs; itemArgs.reserve(wideComponents.size()); auto i = 0U; - std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&]() { return Arg(wideComponents[i++]); }); const auto keys = extractor(itemArgs); keyArgs.reserve(keys.size()); - std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key) { return Arg(key.GetStaticType()); }); const auto groupSwitchResult = groupSwitch(keyArgs, itemArgs); @@ -5318,9 +5567,12 @@ TRuntimeNode TProgramBuilder::WideChopper(TRuntimeNode flow, const TWideLambda& TCallableBuilder callableBuilder(Env_, __func__, output.GetStaticType()); callableBuilder.Add(flow); - std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); - std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(keys.cbegin(), keys.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(keyArgs.cbegin(), keyArgs.cend(), + std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); callableBuilder.Add(groupSwitchResult); callableBuilder.Add(input); callableBuilder.Add(output); @@ -5328,14 +5580,14 @@ TRuntimeNode TProgramBuilder::WideChopper(TRuntimeNode flow, const TWideLambda& } TRuntimeNode TProgramBuilder::HoppingCore(TRuntimeNode list, - const TUnaryLambda& timeExtractor, - const TUnaryLambda& init, - const TBinaryLambda& update, - const TUnaryLambda& save, - const TUnaryLambda& load, - const TBinaryLambda& merge, - const TBinaryLambda& finish, - TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay) + const TUnaryLambda& timeExtractor, + const TUnaryLambda& init, + const TBinaryLambda& update, + const TUnaryLambda& save, + const TUnaryLambda& load, + const TBinaryLambda& merge, + const TBinaryLambda& finish, + TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay) { auto streamType = AS_TYPE(TStreamType, list); auto itemType = AS_TYPE(TStructType, streamType->GetItemType()); @@ -5399,16 +5651,16 @@ TRuntimeNode TProgramBuilder::HoppingCore(TRuntimeNode list, } TRuntimeNode TProgramBuilder::MultiHoppingCore(TRuntimeNode list, - const TUnaryLambda& keyExtractor, - const TUnaryLambda& timeExtractor, - const TUnaryLambda& init, - const TBinaryLambda& update, - const TUnaryLambda& save, - const TUnaryLambda& load, - const TBinaryLambda& merge, - const TTernaryLambda& finish, - TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay, - TRuntimeNode dataWatermarks, TRuntimeNode watermarksMode) + const TUnaryLambda& keyExtractor, + const TUnaryLambda& timeExtractor, + const TUnaryLambda& init, + const TBinaryLambda& update, + const TUnaryLambda& save, + const TUnaryLambda& load, + const TBinaryLambda& merge, + const TTernaryLambda& finish, + TRuntimeNode hop, TRuntimeNode interval, TRuntimeNode delay, + TRuntimeNode dataWatermarks, TRuntimeNode watermarksMode) { auto streamType = AS_TYPE(TStreamType, list); auto itemType = AS_TYPE(TStructType, streamType->GetItemType()); @@ -5487,9 +5739,15 @@ TRuntimeNode TProgramBuilder::Default(TType* type) { } const auto scheme = targetType->GetSchemeType(); - const auto value = scheme == NUdf::TDataType<NUdf::TUuid>::Id ? - Env_.NewStringValue("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"sv) : - scheme == NUdf::TDataType<NUdf::TDyNumber>::Id ? NUdf::TUnboxedValuePod::Embedded("\1") : NUdf::TUnboxedValuePod::Zero(); + NUdf::TUnboxedValue value; + if (scheme == NUdf::TDataType<NUdf::TUuid>::Id) { + value = Env_.NewStringValue("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"sv); + } else if (scheme == NUdf::TDataType<NUdf::TDyNumber>::Id) { + value = NUdf::TUnboxedValuePod::Embedded("\1"); + } else { + value = NUdf::TUnboxedValuePod::Zero(); + } + return TRuntimeNode(TDataLiteral::Create(value, targetType, Env_), true); } @@ -5536,11 +5794,11 @@ TRuntimeNode TProgramBuilder::Cast(TRuntimeNode arg, TType* type) { const auto options = NKikimr::NUdf::GetCastResult(*sourceType->GetDataSlot(), *targetType->GetDataSlot()); MKQL_ENSURE((*options & NKikimr::NUdf::ECastOptions::Undefined) || - !(*options & NKikimr::NUdf::ECastOptions::Impossible), - "Impossible to cast " << *static_cast<TType*>(sourceType) << " into " << *static_cast<TType*>(targetType)); + !(*options & NKikimr::NUdf::ECastOptions::Impossible), + "Impossible to cast " << *static_cast<TType*>(sourceType) << " into " << *static_cast<TType*>(targetType)); const bool useToIntegral = (*options & NKikimr::NUdf::ECastOptions::Undefined) || - (*options & NKikimr::NUdf::ECastOptions::MayFail); + (*options & NKikimr::NUdf::ECastOptions::MayFail); return useToIntegral ? ToIntegral(arg, type) : Convert(arg, type); } @@ -5551,17 +5809,17 @@ TRuntimeNode TProgramBuilder::RangeCreate(TRuntimeNode list) { auto tupleType = static_cast<TTupleType*>(itemType); MKQL_ENSURE(tupleType->GetElementsCount() == 2, - "Expecting list ot 2-element tuples, got: " << tupleType->GetElementsCount() << " elements"); + "Expecting list ot 2-element tuples, got: " << tupleType->GetElementsCount() << " elements"); MKQL_ENSURE(tupleType->GetElementType(0)->IsSameType(*tupleType->GetElementType(1)), - "Expecting list ot 2-element tuples of same type"); + "Expecting list ot 2-element tuples of same type"); MKQL_ENSURE(tupleType->GetElementType(0)->IsTuple(), - "Expecting range boundary to be tuple"); + "Expecting range boundary to be tuple"); auto boundaryType = static_cast<TTupleType*>(tupleType->GetElementType(0)); MKQL_ENSURE(boundaryType->GetElementsCount() >= 2, - "Range boundary should have at least 2 components, got: " << boundaryType->GetElementsCount()); + "Range boundary should have at least 2 components, got: " << boundaryType->GetElementsCount()); auto lastComp = boundaryType->GetElementType(boundaryType->GetElementsCount() - 1); std::vector<TType*> outputComponents; @@ -5597,8 +5855,8 @@ TRuntimeNode TProgramBuilder::RangeMultiply(const TArrayRef<const TRuntimeNode>& unlimited = true; } else { MKQL_ENSURE(args.front().GetStaticType()->IsData() && - static_cast<TDataType*>(args.front().GetStaticType())->GetSchemeType() == NUdf::TDataType<ui64>::Id, - "Expected ui64 as first argument"); + static_cast<TDataType*>(args.front().GetStaticType())->GetSchemeType() == NUdf::TDataType<ui64>::Id, + "Expected ui64 as first argument"); } std::vector<TType*> outputComponents; @@ -5680,18 +5938,17 @@ TRuntimeNode TProgramBuilder::Round(const std::string_view& callableName, TRunti MKQL_ENSURE(sourceType->IsData(), "Expecting first arg to be of Data type"); MKQL_ENSURE(targetType->IsData(), "Expecting second arg to be Data type"); - const auto ss = *static_cast<TDataType*>(sourceType)->GetDataSlot(); const auto ts = *static_cast<TDataType*>(targetType)->GetDataSlot(); const auto options = NKikimr::NUdf::GetCastResult(ss, ts); MKQL_ENSURE(!(*options & NKikimr::NUdf::ECastOptions::Impossible), - "Impossible to cast " << *sourceType << " into " << *targetType); + "Impossible to cast " << *sourceType << " into " << *targetType); MKQL_ENSURE(*options & (NKikimr::NUdf::ECastOptions::MayFail | NKikimr::NUdf::ECastOptions::MayLoseData | NKikimr::NUdf::ECastOptions::AnywayLoseData), - "Rounding from " << *sourceType << " to " << *targetType << " is trivial"); + "Rounding from " << *sourceType << " to " << *targetType << " is trivial"); TCallableBuilder callableBuilder(Env_, callableName, TOptionalType::Create(targetType, Env_)); callableBuilder.Add(source); @@ -5750,8 +6007,8 @@ TRuntimeNode TProgramBuilder::PgConst(TPgType* pgType, const std::string_view& v } TRuntimeNode TProgramBuilder::PgResolvedCall(bool useContext, const std::string_view& name, - ui32 id, const TArrayRef<const TRuntimeNode>& args, - TType* returnType, bool rangeFunction) { + ui32 id, const TArrayRef<const TRuntimeNode>& args, + TType* returnType, bool rangeFunction) { TCallableBuilder callableBuilder(Env_, __func__, returnType); callableBuilder.Add(NewDataLiteral(useContext)); callableBuilder.Add(NewDataLiteral(rangeFunction)); @@ -5764,7 +6021,7 @@ TRuntimeNode TProgramBuilder::PgResolvedCall(bool useContext, const std::string_ } TRuntimeNode TProgramBuilder::BlockPgResolvedCall(const std::string_view& name, ui32 id, - const TArrayRef<const TRuntimeNode>& args, TType* returnType) { + const TArrayRef<const TRuntimeNode>& args, TType* returnType) { TCallableBuilder callableBuilder(Env_, __func__, returnType); callableBuilder.Add(NewDataLiteral<NUdf::EDataSlot::String>(name)); callableBuilder.Add(NewDataLiteral(id)); @@ -5902,7 +6159,7 @@ TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType* } TRuntimeNode TProgramBuilder::BuildBlockCombineAll(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn, - const TArrayRef<const TAggInfo>& aggs, TType* returnType) { + const TArrayRef<const TAggInfo>& aggs, TType* returnType) { const auto inputType = input.GetStaticType(); MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type"); MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type"); @@ -5932,7 +6189,7 @@ TRuntimeNode TProgramBuilder::BuildBlockCombineAll(const std::string_view& calla } TRuntimeNode TProgramBuilder::BlockCombineAll(TRuntimeNode stream, std::optional<ui32> filterColumn, - const TArrayRef<const TAggInfo>& aggs, TType* returnType) { + const TArrayRef<const TAggInfo>& aggs, TType* returnType) { MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type"); MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type"); @@ -5944,8 +6201,9 @@ TRuntimeNode TProgramBuilder::BlockCombineAll(TRuntimeNode stream, std::optional } } -TRuntimeNode TProgramBuilder::BuildBlockCombineHashed(const std::string_view& callableName, TRuntimeNode input, std::optional<ui32> filterColumn, - const TArrayRef<ui32>& keys, const TArrayRef<const TAggInfo>& aggs, TType* returnType) { +TRuntimeNode TProgramBuilder::BuildBlockCombineHashed(const std::string_view& callableName, TRuntimeNode input, + std::optional<ui32> filterColumn, const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, TType* returnType) { const auto inputType = input.GetStaticType(); MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type"); MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type"); @@ -5980,21 +6238,27 @@ TRuntimeNode TProgramBuilder::BuildBlockCombineHashed(const std::string_view& ca return TRuntimeNode(builder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode stream, std::optional<ui32> filterColumn, const TArrayRef<ui32>& keys, - const TArrayRef<const TAggInfo>& aggs, TType* returnType) { +TRuntimeNode TProgramBuilder::BlockCombineHashed(TRuntimeNode stream, std::optional<ui32> filterColumn, + const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, + TType* returnType) { MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type"); MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type"); if constexpr (RuntimeVersion < 52U) { const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType()); - return FromFlow(BuildBlockCombineHashed(__func__, ToFlow(stream), filterColumn, keys, aggs, flowReturnType)); + return FromFlow(BuildBlockCombineHashed(__func__, + ToFlow(stream), filterColumn, + keys, aggs, flowReturnType)); } else { return BuildBlockCombineHashed(__func__, stream, filterColumn, keys, aggs, returnType); } } -TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys, - const TArrayRef<const TAggInfo>& aggs, TType* returnType) { +TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, + const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, + TType* returnType) { const auto inputType = input.GetStaticType(); MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type"); MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type"); @@ -6024,20 +6288,25 @@ TRuntimeNode TProgramBuilder::BuildBlockMergeFinalizeHashed(const std::string_vi } TRuntimeNode TProgramBuilder::BlockMergeFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys, - const TArrayRef<const TAggInfo>& aggs, TType* returnType) { + const TArrayRef<const TAggInfo>& aggs, TType* returnType) { MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type"); MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type"); if constexpr (RuntimeVersion < 52U) { const auto flowReturnType = NewFlowType(AS_TYPE(TStreamType, returnType)->GetItemType()); - return FromFlow(BuildBlockMergeFinalizeHashed(__func__, ToFlow(stream), keys, aggs, flowReturnType)); + return FromFlow(BuildBlockMergeFinalizeHashed(__func__, + ToFlow(stream), keys, + aggs, flowReturnType)); } else { return BuildBlockMergeFinalizeHashed(__func__, stream, keys, aggs, returnType); } } -TRuntimeNode TProgramBuilder::BuildBlockMergeManyFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, const TArrayRef<ui32>& keys, - const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) { +TRuntimeNode TProgramBuilder::BuildBlockMergeManyFinalizeHashed(const std::string_view& callableName, TRuntimeNode input, + const TArrayRef<ui32>& keys, + const TArrayRef<const TAggInfo>& aggs, + ui32 streamIndex, + const TVector<TVector<ui32>>& streams, TType* returnType) { const auto inputType = input.GetStaticType(); MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(), "Expected either stream or flow as input type"); MKQL_ENSURE(returnType->IsStream() || returnType->IsFlow(), "Expected either stream or flow as return type"); @@ -6079,7 +6348,10 @@ TRuntimeNode TProgramBuilder::BuildBlockMergeManyFinalizeHashed(const std::strin } TRuntimeNode TProgramBuilder::BlockMergeManyFinalizeHashed(TRuntimeNode stream, const TArrayRef<ui32>& keys, - const TArrayRef<const TAggInfo>& aggs, ui32 streamIndex, const TVector<TVector<ui32>>& streams, TType* returnType) { + const TArrayRef<const TAggInfo>& aggs, + ui32 streamIndex, + const TVector<TVector<ui32>>& streams, + TType* returnType) { MKQL_ENSURE(stream.GetStaticType()->IsStream(), "Expected stream as input type"); MKQL_ENSURE(returnType->IsStream(), "Expected stream as return type"); @@ -6141,7 +6413,11 @@ TRuntimeNode TProgramBuilder::BlockStorage(TRuntimeNode list, TType* returnType) return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockMapJoinIndex(TRuntimeNode blockStorage, TType* listItemType, const TArrayRef<const ui32>& keyColumns, bool any, TType* returnType) { +TRuntimeNode TProgramBuilder::BlockMapJoinIndex(TRuntimeNode blockStorage, + TType* listItemType, + const TArrayRef<const ui32>& keyColumns, + bool any, + TType* returnType) { if constexpr (RuntimeVersion < 62U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; } @@ -6161,9 +6437,9 @@ TRuntimeNode TProgramBuilder::BlockMapJoinIndex(TRuntimeNode blockStorage, TType TRuntimeNode::TList keyColumnsNodes; keyColumnsNodes.reserve(keyColumns.size()); std::transform(keyColumns.cbegin(), keyColumns.cend(), - std::back_inserter(keyColumnsNodes), [this](const ui32 idx) { - return NewDataLiteral(idx); - }); + std::back_inserter(keyColumnsNodes), [this](const ui32 idx) { + return NewDataLiteral(idx); + }); TCallableBuilder callableBuilder(Env_, __func__, returnType); callableBuilder.Add(blockStorage); @@ -6174,10 +6450,13 @@ TRuntimeNode TProgramBuilder::BlockMapJoinIndex(TRuntimeNode blockStorage, TType return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntimeNode rightBlockStorage, TType* rightListItemType, EJoinKind joinKind, - const TArrayRef<const ui32>& leftKeyColumns, const TArrayRef<const ui32>& leftKeyDrops, - const TArrayRef<const ui32>& rightKeyColumns, const TArrayRef<const ui32>& rightKeyDrops, TType* returnType -) { +TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntimeNode rightBlockStorage, + TType* rightListItemType, EJoinKind joinKind, + const TArrayRef<const ui32>& leftKeyColumns, + const TArrayRef<const ui32>& leftKeyDrops, + const TArrayRef<const ui32>& rightKeyColumns, + const TArrayRef<const ui32>& rightKeyDrops, + TType* returnType) { if constexpr (RuntimeVersion < 62U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; } @@ -6191,7 +6470,7 @@ TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntime } MKQL_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left || - joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::Cross, + joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly || joinKind == EJoinKind::Cross, "Unsupported join kind"); MKQL_ENSURE(leftKeyColumns.size() == rightKeyColumns.size(), "Key column count mismatch"); if (joinKind != EJoinKind::Cross) { @@ -6210,30 +6489,30 @@ TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode leftStream, TRuntime TRuntimeNode::TList leftKeyColumnsNodes; leftKeyColumnsNodes.reserve(leftKeyColumns.size()); std::transform(leftKeyColumns.cbegin(), leftKeyColumns.cend(), - std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { - return NewDataLiteral(idx); - }); + std::back_inserter(leftKeyColumnsNodes), [this](const ui32 idx) { + return NewDataLiteral(idx); + }); TRuntimeNode::TList leftKeyDropsNodes; leftKeyDropsNodes.reserve(leftKeyDrops.size()); std::transform(leftKeyDrops.cbegin(), leftKeyDrops.cend(), - std::back_inserter(leftKeyDropsNodes), [this](const ui32 idx) { - return NewDataLiteral(idx); - }); + std::back_inserter(leftKeyDropsNodes), [this](const ui32 idx) { + return NewDataLiteral(idx); + }); TRuntimeNode::TList rightKeyColumnsNodes; rightKeyColumnsNodes.reserve(rightKeyColumns.size()); std::transform(rightKeyColumns.cbegin(), rightKeyColumns.cend(), - std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) { - return NewDataLiteral(idx); - }); + std::back_inserter(rightKeyColumnsNodes), [this](const ui32 idx) { + return NewDataLiteral(idx); + }); TRuntimeNode::TList rightKeyDropsNodes; rightKeyDropsNodes.reserve(leftKeyDrops.size()); std::transform(rightKeyDrops.cbegin(), rightKeyDrops.cend(), - std::back_inserter(rightKeyDropsNodes), [this](const ui32 idx) { - return NewDataLiteral(idx); - }); + std::back_inserter(rightKeyDropsNodes), [this](const ui32 idx) { + return NewDataLiteral(idx); + }); TCallableBuilder callableBuilder(Env_, __func__, returnType); callableBuilder.Add(leftStream); @@ -6253,18 +6532,18 @@ using namespace NYql::NMatchRecognize; TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuilder& programBuilder) { const auto& env = programBuilder.GetTypeEnvironment(); TTupleLiteralBuilder patternBuilder(env); - for (const auto& term: pattern) { + for (const auto& term : pattern) { TTupleLiteralBuilder termBuilder(env); - for (const auto& factor: term) { + for (const auto& factor : term) { TTupleLiteralBuilder factorBuilder(env); - factorBuilder.Add(std::visit(TOverloaded { - [&](const TString& s) { - return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s); - }, - [&](const TRowPattern& pattern) { - return PatternToRuntimeNode(pattern, programBuilder); - }, - }, factor.Primary)); + factorBuilder.Add(std::visit(TOverloaded{ + [&](const TString& s) { + return programBuilder.NewDataLiteral<NUdf::EDataSlot::String>(s); + }, + [&](const TRowPattern& pattern) { + return PatternToRuntimeNode(pattern, programBuilder); + }, + }, factor.Primary)); factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMin)); factorBuilder.Add(programBuilder.NewDataLiteral(factor.QuantityMax)); factorBuilder.Add(programBuilder.NewDataLiteral(factor.Greedy)); @@ -6277,7 +6556,7 @@ TRuntimeNode PatternToRuntimeNode(const TRowPattern& pattern, const TProgramBuil return {patternBuilder.Build(), true}; }; -} //namespace +} // namespace TRuntimeNode TProgramBuilder::MatchRecognizeCore( TRuntimeNode inputStream, @@ -6290,19 +6569,16 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( const TVector<TTernaryLambda>& getDefines, bool streamingMode, const NYql::NMatchRecognize::TAfterMatchSkipTo& skipTo, - NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch -) { + NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch) { const auto inputRowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType()); const auto inputRowArg = Arg(inputRowType); const auto partitionKeySelectorNode = getPartitionKeySelectorNode(inputRowArg); const auto partitionColumnTypes = AS_TYPE(TTupleType, partitionKeySelectorNode.GetStaticType())->GetElements(); - const auto rangeList = NewListType(NewStructType({ - {"From", NewDataType(NUdf::EDataSlot::Uint64)}, - {"To", NewDataType(NUdf::EDataSlot::Uint64)} - })); + const auto rangeList = NewListType(NewStructType({{"From", NewDataType(NUdf::EDataSlot::Uint64)}, + {"To", NewDataType(NUdf::EDataSlot::Uint64)}})); TStructTypeBuilder matchedVarsTypeBuilder(Env_); - for (const auto& var: GetPatternVars(pattern)) { + for (const auto& var : GetPatternVars(pattern)) { matchedVarsTypeBuilder.Add(var, rangeList); } const auto matchedVarsType = matchedVarsTypeBuilder.Build(); @@ -6323,22 +6599,21 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( measureInputDataRowTypeBuilder.Add(inputRowType->GetMemberName(i), inputRowType->GetMemberType(i)); } measureInputDataRowTypeBuilder.Add( - MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier), - NewDataType(NUdf::EDataSlot::Utf8) - ); + MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::Classifier), + NewDataType(NUdf::EDataSlot::Utf8)); measureInputDataRowTypeBuilder.Add( - MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber), - NewDataType(NUdf::EDataSlot::Uint64) - ); + MeasureInputDataSpecialColumnName(EMeasureInputDataSpecialColumns::MatchNumber), + NewDataType(NUdf::EDataSlot::Uint64)); const auto measureInputDataRowType = measureInputDataRowTypeBuilder.Build(); for (ui32 i = 0; i < measureInputDataRowType->GetMembersCount(); ++i) { - //assume a few, if grows, it's better to use a lookup table here + // assume a few, if grows, it's better to use a lookup table here static_assert(static_cast<size_t>(EMeasureInputDataSpecialColumns::Last) < 5); for (size_t j = 0; j != static_cast<size_t>(EMeasureInputDataSpecialColumns::Last); ++j) { if (measureInputDataRowType->GetMemberName(i) == - NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j))) + NYql::NMatchRecognize::MeasureInputDataSpecialColumnName(static_cast<EMeasureInputDataSpecialColumns>(j))) { specialColumnIndexesInMeasureInputDataRow[j] = NewDataLiteral(i); + } } } @@ -6358,20 +6633,20 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( outputRowTypeBuilder.Add(name, measures[i].GetStaticType()); } switch (rowsPerMatch) { - case NYql::NMatchRecognize::ERowsPerMatch::OneRow: - for (size_t i = 0; i < partitionColumnNames.size(); ++i) { - const auto name = partitionColumnNames[i]; - partitionColumnLookup.emplace(name, i); - outputRowTypeBuilder.Add(name, partitionColumnTypes[i]); - } - break; - case NYql::NMatchRecognize::ERowsPerMatch::AllRows: - for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) { - const auto name = inputRowType->GetMemberName(i); - otherColumnLookup.emplace(name, i); - outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i)); - } - break; + case NYql::NMatchRecognize::ERowsPerMatch::OneRow: + for (size_t i = 0; i < partitionColumnNames.size(); ++i) { + const auto name = partitionColumnNames[i]; + partitionColumnLookup.emplace(name, i); + outputRowTypeBuilder.Add(name, partitionColumnTypes[i]); + } + break; + case NYql::NMatchRecognize::ERowsPerMatch::AllRows: + for (size_t i = 0; i < inputRowType->GetMembersCount(); ++i) { + const auto name = inputRowType->GetMemberName(i); + otherColumnLookup.emplace(name, i); + outputRowTypeBuilder.Add(name, inputRowType->GetMemberType(i)); + } + break; } auto outputRowType = outputRowTypeBuilder.Build(); @@ -6388,14 +6663,14 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::PartitionKey))}, })); } else if (auto iter = measureColumnLookup.find(name); - iter != measureColumnLookup.end()) { + iter != measureColumnLookup.end()) { measureColumnIndexes[iter->second] = NewDataLiteral(i); outputColumnOrder.push_back(NewStruct({ std::pair{"Index", NewDataLiteral(iter->second)}, std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Measure))}, })); } else if (auto iter = otherColumnLookup.find(name); - iter != otherColumnLookup.end()) { + iter != otherColumnLookup.end()) { outputColumnOrder.push_back(NewStruct({ std::pair{"Index", NewDataLiteral(iter->second)}, std::pair{"SourceType", NewDataLiteral(static_cast<i32>(EOutputColumnSource::Other))}, @@ -6419,13 +6694,13 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( TVector<TRuntimeNode> defineNodes(patternVarLookup.size()); const auto inputDataArg = Arg(NewListType(inputRowType)); const auto currentRowIndexArg = Arg(NewDataType(NUdf::EDataSlot::Uint64)); - for (const auto& [v, i]: patternVarLookup) { + for (const auto& [v, i] : patternVarLookup) { defineNames[i] = NewDataLiteral<NUdf::EDataSlot::String>(v); if (auto iter = defineLookup.find(v); iter != defineLookup.end()) { defineNodes[i] = getDefines[iter->second](inputDataArg, matchedVarsArg, currentRowIndexArg); } else if ("$" == v || "^" == v) { - //DO nothing, //will be handled in a specific way + // DO nothing, //will be handled in a specific way } else { // a var without a predicate matches any row defineNodes[i] = NewDataLiteral(true); } @@ -6446,7 +6721,7 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( callableBuilder.Add(NewDataLiteral(inputRowType->GetMembersCount())); callableBuilder.Add(matchedVarsArg); callableBuilder.Add(NewList(indexType, measureColumnIndexes)); - for (const auto& m: measures) { + for (const auto& m : measures) { callableBuilder.Add(m); } @@ -6455,7 +6730,7 @@ TRuntimeNode TProgramBuilder::MatchRecognizeCore( callableBuilder.Add(currentRowIndexArg); callableBuilder.Add(inputDataArg); callableBuilder.Add(NewList(NewDataType(NUdf::EDataSlot::String), defineNames)); - for (const auto& d: defineNodes) { + for (const auto& d : defineNodes) { callableBuilder.Add(d); } callableBuilder.Add(NewDataLiteral(streamingMode)); @@ -6475,8 +6750,7 @@ TRuntimeNode TProgramBuilder::TimeOrderRecover( const TUnaryLambda& getTimeExtractor, TRuntimeNode delay, TRuntimeNode ahead, - TRuntimeNode rowLimit - ) + TRuntimeNode rowLimit) { auto& inputRowType = *static_cast<TStructType*>(AS_TYPE(TStructType, AS_TYPE(TFlowType, inputStream.GetStaticType())->GetItemType())); const auto inputRowArg = Arg(&inputRowType); @@ -6498,8 +6772,8 @@ TRuntimeNode TProgramBuilder::TimeOrderRecover( callableBuilder.Add(NewDataLiteral(inputRowColumnCount)); callableBuilder.Add(NewDataLiteral(outOfOrderColumnIndex)); callableBuilder.Add(delay), - callableBuilder.Add(ahead), - callableBuilder.Add(rowLimit); + callableBuilder.Add(ahead), + callableBuilder.Add(rowLimit); return TRuntimeNode(callableBuilder.Build(), false); } @@ -6513,119 +6787,119 @@ bool CanExportType(TType* type, const TTypeEnvironment& env) { bool canExport = true; for (auto& node : explorer.GetNodes()) { switch (static_cast<TType*>(node)->GetKind()) { - case TType::EKind::Void: - node->SetCookie(1); - break; - - case TType::EKind::Data: - node->SetCookie(1); - break; - - case TType::EKind::Pg: - node->SetCookie(1); - break; - - case TType::EKind::Optional: { - auto optionalType = static_cast<TOptionalType*>(node); - if (!optionalType->GetItemType()->GetCookie()) { - canExport = false; - } else { + case TType::EKind::Void: node->SetCookie(1); - } - - break; - } + break; - case TType::EKind::List: { - auto listType = static_cast<TListType*>(node); - if (!listType->GetItemType()->GetCookie()) { - canExport = false; - } else { + case TType::EKind::Data: node->SetCookie(1); - } + break; - break; - } + case TType::EKind::Pg: + node->SetCookie(1); + break; - case TType::EKind::Struct: { - auto structType = static_cast<TStructType*>(node); - for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { - if (!structType->GetMemberType(index)->GetCookie()) { + case TType::EKind::Optional: { + auto optionalType = static_cast<TOptionalType*>(node); + if (!optionalType->GetItemType()->GetCookie()) { canExport = false; - break; + } else { + node->SetCookie(1); } - } - if (canExport) { - node->SetCookie(1); + break; } - break; - } - - case TType::EKind::Tuple: { - auto tupleType = static_cast<TTupleType*>(node); - for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { - if (!tupleType->GetElementType(index)->GetCookie()) { + case TType::EKind::List: { + auto listType = static_cast<TListType*>(node); + if (!listType->GetItemType()->GetCookie()) { canExport = false; - break; + } else { + node->SetCookie(1); } - } - if (canExport) { - node->SetCookie(1); + break; } - break; - } - - case TType::EKind::Dict: { - auto dictType = static_cast<TDictType*>(node); - if (!dictType->GetKeyType()->GetCookie() || !dictType->GetPayloadType()->GetCookie()) { - canExport = false; - } else { - node->SetCookie(1); - } - - break; - } - - case TType::EKind::Variant: { - auto variantType = static_cast<TVariantType*>(node); - TType* innerType = variantType->GetUnderlyingType(); - - if (innerType->IsStruct()) { - auto structType = static_cast<TStructType*>(innerType); + case TType::EKind::Struct: { + auto structType = static_cast<TStructType*>(node); for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { if (!structType->GetMemberType(index)->GetCookie()) { canExport = false; break; } } + + if (canExport) { + node->SetCookie(1); + } + + break; } - if (innerType->IsTuple()) { - auto tupleType = static_cast<TTupleType*>(innerType); + case TType::EKind::Tuple: { + auto tupleType = static_cast<TTupleType*>(node); for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { if (!tupleType->GetElementType(index)->GetCookie()) { canExport = false; break; } } + + if (canExport) { + node->SetCookie(1); + } + + break; } - if (canExport) { - node->SetCookie(1); + case TType::EKind::Dict: { + auto dictType = static_cast<TDictType*>(node); + if (!dictType->GetKeyType()->GetCookie() || !dictType->GetPayloadType()->GetCookie()) { + canExport = false; + } else { + node->SetCookie(1); + } + + break; } - break; - } + case TType::EKind::Variant: { + auto variantType = static_cast<TVariantType*>(node); + TType* innerType = variantType->GetUnderlyingType(); + + if (innerType->IsStruct()) { + auto structType = static_cast<TStructType*>(innerType); + for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { + if (!structType->GetMemberType(index)->GetCookie()) { + canExport = false; + break; + } + } + } - case TType::EKind::Type: - break; + if (innerType->IsTuple()) { + auto tupleType = static_cast<TTupleType*>(innerType); + for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { + if (!tupleType->GetElementType(index)->GetCookie()) { + canExport = false; + break; + } + } + } - default: - canExport = false; + if (canExport) { + node->SetCookie(1); + } + + break; + } + + case TType::EKind::Type: + break; + + default: + canExport = false; } if (!canExport) { @@ -6640,5 +6914,5 @@ bool CanExportType(TType* type, const TTypeEnvironment& env) { return canExport; } -} -} +} // namespace NMiniKQL +} // namespace NKikimr |
