diff options
author | ninaiad <ninaiad@yandex-team.com> | 2025-04-24 13:42:39 +0300 |
---|---|---|
committer | ninaiad <ninaiad@yandex-team.com> | 2025-04-24 13:58:03 +0300 |
commit | 921a285daaf8009a0afd651fbb38356409582c15 (patch) | |
tree | ecb49fec63d0d45b1467f54463ad041b6d859186 | |
parent | cd04ddcf7f59905292eb61fa3d747d1697e9d00d (diff) | |
download | ydb-921a285daaf8009a0afd651fbb38356409582c15.tar.gz |
YQL TaggedType <-> Arrow
commit_hash:f775e1468b8181bc7f7f409dde6a1af284725e68
-rw-r--r-- | yql/essentials/minikql/mkql_type_builder.cpp | 11 | ||||
-rw-r--r-- | yql/essentials/minikql/mkql_type_builder_ut.cpp | 16 | ||||
-rw-r--r-- | yql/essentials/public/udf/arrow/dispatch_traits.h | 6 | ||||
-rw-r--r-- | yql/essentials/public/udf/arrow/ut/array_builder_ut.cpp | 51 | ||||
-rw-r--r-- | yql/essentials/public/udf/arrow/util.cpp | 10 | ||||
-rw-r--r-- | yql/essentials/public/udf/arrow/util.h | 4 |
6 files changed, 95 insertions, 3 deletions
diff --git a/yql/essentials/minikql/mkql_type_builder.cpp b/yql/essentials/minikql/mkql_type_builder.cpp index 59c00d17ebc..f6e529a4a37 100644 --- a/yql/essentials/minikql/mkql_type_builder.cpp +++ b/yql/essentials/minikql/mkql_type_builder.cpp @@ -1660,6 +1660,12 @@ bool ConvertArrowTypeImpl(TType* itemType, std::shared_ptr<arrow::DataType>& typ return true; } + if (itemType->IsTagged()) { + auto taggedType = AS_TYPE(TTaggedType, itemType); + auto baseType = taggedType->GetBaseType(); + return ConvertArrowTypeImpl(baseType, type, onFail, output); + } + if (IsSingularType(unpacked)) { type = arrow::null(); return true; @@ -2562,6 +2568,11 @@ size_t CalcMaxBlockItemSize(const TType* type) { return 0; } + if (type->IsTagged()) { + auto taggedType = AS_TYPE(TTaggedType, type); + return CalcMaxBlockItemSize(taggedType->GetBaseType()); + } + if (type->IsData()) { auto slot = *AS_TYPE(TDataType, type)->GetDataSlot(); switch (slot) { diff --git a/yql/essentials/minikql/mkql_type_builder_ut.cpp b/yql/essentials/minikql/mkql_type_builder_ut.cpp index 95d47795df7..f74808e4651 100644 --- a/yql/essentials/minikql/mkql_type_builder_ut.cpp +++ b/yql/essentials/minikql/mkql_type_builder_ut.cpp @@ -40,6 +40,7 @@ private: UNIT_TEST(TestDataTypeFormat); UNIT_TEST(TestBlockTypeFormat); UNIT_TEST(TestArrowType); + UNIT_TEST(TestArrowTaggedType); UNIT_TEST_SUITE_END(); TString FormatType(NUdf::TType* t) { @@ -148,8 +149,8 @@ private: void TestTaggedTypeFormat() { { - auto s = FormatType(FunctionTypeInfoBuilder.Tagged(FunctionTypeInfoBuilder.SimpleType<i8>(), "my_resource")); - UNIT_ASSERT_VALUES_EQUAL(s, "Tagged<Int8,'my_resource'>"); + auto s = FormatType(FunctionTypeInfoBuilder.Tagged(FunctionTypeInfoBuilder.SimpleType<i8>(), "my_tag")); + UNIT_ASSERT_VALUES_EQUAL(s, "Tagged<Int8,'my_tag'>"); } } @@ -348,6 +349,17 @@ private: auto atype2 = TypeInfoHelper->ImportArrowType(&s); UNIT_ASSERT_VALUES_EQUAL(static_cast<TArrowType*>(atype2.Get())->GetType()->ToString(), std::string("uint64")); } + + void TestArrowTaggedType() { + auto type = FunctionTypeInfoBuilder.Tagged(FunctionTypeInfoBuilder.SimpleType<ui64>(), "my_tag"); + auto atype1 = TypeInfoHelper->MakeArrowType(type); + UNIT_ASSERT(atype1); + UNIT_ASSERT_VALUES_EQUAL(static_cast<TArrowType*>(atype1.Get())->GetType()->ToString(), std::string("uint64")); + ArrowSchema s; + atype1->Export(&s); + auto atype2 = TypeInfoHelper->ImportArrowType(&s); + UNIT_ASSERT_VALUES_EQUAL(static_cast<TArrowType*>(atype2.Get())->GetType()->ToString(), std::string("uint64")); + } }; UNIT_TEST_SUITE_REGISTRATION(TMiniKQLTypeBuilderTest); diff --git a/yql/essentials/public/udf/arrow/dispatch_traits.h b/yql/essentials/public/udf/arrow/dispatch_traits.h index 87c25b93f56..93d7084296e 100644 --- a/yql/essentials/public/udf/arrow/dispatch_traits.h +++ b/yql/essentials/public/udf/arrow/dispatch_traits.h @@ -84,8 +84,9 @@ std::unique_ptr<typename TTraits::TResult> DispatchByArrowTraits(const ITypeInfo isOptional = true; } + unpacked = SkipTaggedType(typeInfoHelper, unpacked); + TOptionalTypeInspector unpackedOpt(typeInfoHelper, unpacked); - TPgTypeInspector unpackedPg(typeInfoHelper, unpacked); if (unpackedOpt || (typeOpt && NeedWrapWithExternalOptional(typeInfoHelper, unpacked))) { ui32 nestLevel = 0; auto currentType = type; @@ -97,6 +98,9 @@ std::unique_ptr<typename TTraits::TResult> DispatchByArrowTraits(const ITypeInfo types.push_back(currentType); TOptionalTypeInspector currentOpt(typeInfoHelper, currentType); currentType = currentOpt.GetItemType(); + + currentType = SkipTaggedType(typeInfoHelper, currentType); + TOptionalTypeInspector nexOpt(typeInfoHelper, currentType); if (!nexOpt) { break; diff --git a/yql/essentials/public/udf/arrow/ut/array_builder_ut.cpp b/yql/essentials/public/udf/arrow/ut/array_builder_ut.cpp index d0851c5e869..117ebad40c7 100644 --- a/yql/essentials/public/udf/arrow/ut/array_builder_ut.cpp +++ b/yql/essentials/public/udf/arrow/ut/array_builder_ut.cpp @@ -56,6 +56,57 @@ Y_UNIT_TEST_SUITE(TArrayBuilderTest) { "Expected equal values after building array"); } + Y_UNIT_TEST(TestTaggedTypeBuilder) { + TArrayBuilderTestData data; + const auto intType = data.PgmBuilder.NewDataType(NUdf::EDataSlot::Int32, false); + const auto taggedType = data.PgmBuilder.NewTaggedType(intType, "tag"); + + const auto arrayBuilder = MakeArrayBuilder(NMiniKQL::TTypeInfoHelper(), taggedType, *data.ArrowPool, MAX_BLOCK_SIZE, /*pgBuilder=*/nullptr); + + TUnboxedValue testData = TUnboxedValuePod(123); + + arrayBuilder->Add(testData); + + auto datum = arrayBuilder->Build(true); + + UNIT_ASSERT(datum.is_array()); + UNIT_ASSERT_VALUES_EQUAL(datum.length(), 1); + + auto value = datum.array()->buffers[1]; + + UNIT_ASSERT_VALUES_EQUAL(*reinterpret_cast<int32_t*>(value->address()), 123); + } + + Y_UNIT_TEST(TestTaggedTypeReader) { + TArrayBuilderTestData data; + const auto intType = data.PgmBuilder.NewDataType(NUdf::EDataSlot::Int32, false); + const auto taggedType = data.PgmBuilder.NewTaggedType(intType, "tag"); + + const auto arrayBuilder = MakeArrayBuilder(NMiniKQL::TTypeInfoHelper(), taggedType, *data.ArrowPool, MAX_BLOCK_SIZE, /*pgBuilder=*/nullptr); + + TUnboxedValue first = TUnboxedValuePod(123); + TUnboxedValue second = TUnboxedValuePod(456); + + arrayBuilder->Add(first); + arrayBuilder->Add(second); + + auto datum = arrayBuilder->Build(true); + + UNIT_ASSERT(datum.is_array()); + UNIT_ASSERT_VALUES_EQUAL(datum.length(), 2); + + const auto blockReader = MakeBlockReader(NMiniKQL::TTypeInfoHelper(), taggedType); + + const auto item1AfterRead = blockReader->GetItem(*datum.array(), 0); + const auto item2AfterRead = blockReader->GetItem(*datum.array(), 1); + + UNIT_ASSERT_C(item1AfterRead.HasValue(), "Expected not null"); + UNIT_ASSERT_C(item2AfterRead.HasValue(), "Expected not null"); + + UNIT_ASSERT_VALUES_EQUAL(item1AfterRead.Get<int>(), 123); + UNIT_ASSERT_VALUES_EQUAL(item2AfterRead.Get<int>(), 456); + } + extern const char ResourceName[] = "Resource.Name"; Y_UNIT_TEST(TestDtorCall) { TArrayBuilderTestData data; diff --git a/yql/essentials/public/udf/arrow/util.cpp b/yql/essentials/public/udf/arrow/util.cpp index 169a809d0b6..7fa2103bfd4 100644 --- a/yql/essentials/public/udf/arrow/util.cpp +++ b/yql/essentials/public/udf/arrow/util.cpp @@ -169,5 +169,15 @@ ui64 GetSizeOfArrowExecBatchInBytes(const arrow::compute::ExecBatch& batch) { return size; } + +const TType* SkipTaggedType(const ITypeInfoHelper& typeInfoHelper, const TType* type) { + TTaggedTypeInspector typeTagged(typeInfoHelper, type); + while (typeTagged) { + type = typeTagged.GetBaseType(); + typeTagged = TTaggedTypeInspector(typeInfoHelper, type); + } + + return type; +} } } diff --git a/yql/essentials/public/udf/arrow/util.h b/yql/essentials/public/udf/arrow/util.h index ea9033b35c6..a4e430aa4ad 100644 --- a/yql/essentials/public/udf/arrow/util.h +++ b/yql/essentials/public/udf/arrow/util.h @@ -248,7 +248,11 @@ inline bool IsSingularType(const ITypeInfoHelper& typeInfoHelper, const TType* t kind == ETypeKind::EmptyList; } +const TType* SkipTaggedType(const ITypeInfoHelper& typeInfoHelper, const TType* type); + inline bool NeedWrapWithExternalOptional(const ITypeInfoHelper& typeInfoHelper, const TType* type) { + type = SkipTaggedType(typeInfoHelper, type); + return TPgTypeInspector(typeInfoHelper, type) || IsSingularType(typeInfoHelper, type); } |