diff options
author | Daniil Timizhev <[email protected]> | 2025-09-15 12:54:53 +0300 |
---|---|---|
committer | GitHub <[email protected]> | 2025-09-15 12:54:53 +0300 |
commit | 0896bce9b4bddddd0a79a3283f6c5c5501a0f15b (patch) | |
tree | f04d7aabe6bee590a9018bee250f4e0e8d6da196 | |
parent | a42fcfde366e2a713faf9cd498edd973146d9d06 (diff) |
Support containers with Arrow format of result sets (#23831) oidc-1.2.6
-rw-r--r-- | ydb/core/formats/arrow/arrow_batch_builder.cpp | 49 | ||||
-rw-r--r-- | ydb/core/formats/arrow/arrow_batch_builder.h | 20 | ||||
-rw-r--r-- | ydb/core/formats/arrow/arrow_helpers.cpp | 6 | ||||
-rw-r--r-- | ydb/core/formats/arrow/arrow_helpers.h | 4 | ||||
-rw-r--r-- | ydb/core/formats/arrow/arrow_helpers_minikql.cpp | 42 | ||||
-rw-r--r-- | ydb/core/formats/arrow/arrow_helpers_minikql.h | 17 | ||||
-rw-r--r-- | ydb/core/formats/arrow/switch/switch_type.h | 1 | ||||
-rw-r--r-- | ydb/core/formats/arrow/ya.make | 3 | ||||
-rw-r--r-- | ydb/core/kqp/query_data/kqp_query_data.cpp | 21 | ||||
-rw-r--r-- | ydb/core/kqp/runtime/kqp_transport.cpp | 17 | ||||
-rw-r--r-- | ydb/core/kqp/ut/arrow/kqp_result_set_formats.cpp (renamed from ydb/core/kqp/ut/arrow/kqp_result_set_format_arrow.cpp) | 659 | ||||
-rw-r--r-- | ydb/core/kqp/ut/arrow/ya.make | 2 | ||||
-rw-r--r-- | ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h | 7 | ||||
-rw-r--r-- | ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp | 766 | ||||
-rw-r--r-- | ydb/library/yql/dq/runtime/dq_arrow_helpers.h | 42 | ||||
-rw-r--r-- | ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp | 799 |
16 files changed, 2054 insertions, 401 deletions
diff --git a/ydb/core/formats/arrow/arrow_batch_builder.cpp b/ydb/core/formats/arrow/arrow_batch_builder.cpp index 58e8b4ba909..9cfa08a278a 100644 --- a/ydb/core/formats/arrow/arrow_batch_builder.cpp +++ b/ydb/core/formats/arrow/arrow_batch_builder.cpp @@ -1,7 +1,13 @@ #include "arrow_batch_builder.h" -#include "switch/switch_type.h" + +#include <ydb/core/formats/arrow/arrow_helpers_minikql.h> +#include <ydb/core/formats/arrow/switch/switch_type.h> +#include <ydb/core/kqp/common/kqp_types.h> +#include <ydb/library/yql/dq/runtime/dq_arrow_helpers.h> + #include <contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h> + namespace NKikimr::NArrow { namespace { @@ -78,6 +84,15 @@ arrow::Status AppendCell(arrow::RecordBatchBuilder& builder, const TCell& cell, return result; } +arrow::Status AppendValue(arrow::RecordBatchBuilder& builder, const NUdf::TUnboxedValue& value, ui32 colNum, const NKikimr::NMiniKQL::TType* type) { + try { + NYql::NArrow::AppendElement(value, builder.GetField(colNum), type); + } catch (const std::exception& e) { + return arrow::Status::FromArgs(arrow::StatusCode::Invalid, e.what()); + } + return arrow::Status::OK(); +} + } NKikimr::NArrow::TRecordBatchConstructor::TRecordConstructor& TRecordBatchConstructor::TRecordConstructor::AddRecordValue( @@ -225,6 +240,20 @@ arrow::Status TArrowBatchBuilder::Start(const std::vector<std::pair<TString, NSc return arrow::Status::OK(); } +arrow::Status TArrowBatchBuilder::Start(const std::vector<std::pair<TString, NKikimr::NMiniKQL::TType*>> yqlColumns) { + YqlSchema = yqlColumns; + auto schema = MakeArrowSchema(yqlColumns, NotNullColumns); + if (!schema.ok()) { + return arrow::Status::FromArgs(schema.status().code(), "Cannot make arrow schema: ", schema.status().ToString()); + } + auto status = arrow::RecordBatchBuilder::Make(*schema, MemoryPool, RowsToReserve, &BatchBuilder); + NumRows = NumBytes = 0; + if (!status.ok()) { + return 
arrow::Status::FromArgs(schema.status().code(), "Cannot make arrow builder: ", status.ToString()); + } + return arrow::Status::OK(); +} + void TArrowBatchBuilder::AppendCell(const TCell& cell, ui32 colNum) { NumBytes += cell.Size(); auto ydbType = YdbSchema[colNum].second; @@ -232,6 +261,13 @@ void TArrowBatchBuilder::AppendCell(const TCell& cell, ui32 colNum) { Y_ABORT_UNLESS(status.ok(), "Failed to append cell: %s", status.ToString().c_str()); } +void TArrowBatchBuilder::AppendValue(const NUdf::TUnboxedValue& value, ui32 colNum) { + NumBytes += sizeof(NUdf::TUnboxedValue); // TODO: strings or containers sizes? + auto yqlType = YqlSchema[colNum].second; + auto status = NKikimr::NArrow::AppendValue(*BatchBuilder, value, colNum, yqlType); + Y_ENSURE(status.ok(), "Failed to append value: " << status.ToString()); +} + void TArrowBatchBuilder::AddRow(const TDbTupleRef& key, const TDbTupleRef& value) { ++NumRows; @@ -273,6 +309,17 @@ void TArrowBatchBuilder::AddRow(const TConstArrayRef<TCell>& row) { } } +void TArrowBatchBuilder::AddRow(const NUdf::TUnboxedValue& row, size_t membersCount, const TVector<ui32>* columnOrder) { + Y_ABORT_UNLESS(!YqlSchema.empty()); + + ++NumRows; + + for (size_t i = 0; i < membersCount; ++i) { + const auto& memberIndex = (!columnOrder || columnOrder->empty()) ? 
i : (*columnOrder)[i]; + AppendValue(row.GetElement(memberIndex), i); + } +} + void TArrowBatchBuilder::ReserveData(ui32 columnNo, size_t size) { if (!BatchBuilder || columnNo >= (ui32)BatchBuilder->num_fields()) { return; diff --git a/ydb/core/formats/arrow/arrow_batch_builder.h b/ydb/core/formats/arrow/arrow_batch_builder.h index d1341208a02..052457d992f 100644 --- a/ydb/core/formats/arrow/arrow_batch_builder.h +++ b/ydb/core/formats/arrow/arrow_batch_builder.h @@ -1,9 +1,18 @@ #pragma once -#include "arrow_helpers.h" + #include <ydb/core/formats/factory.h> +#include <ydb/core/formats/arrow/arrow_helpers.h> #include <ydb/core/scheme/scheme_tablecell.h> #include <ydb/library/conclusion/status.h> +namespace NYql::NUdf { +class TUnboxedValue; +} + +namespace NKikimr::NMiniKQL { +class TType; +} + namespace NKikimr::NArrow { class TRecordBatchReader { @@ -169,6 +178,7 @@ public: void AddRow(const NKikimr::TDbTupleRef& key, const NKikimr::TDbTupleRef& value) override; void AddRow(const TConstArrayRef<TCell>& key, const TConstArrayRef<TCell>& value); void AddRow(const TConstArrayRef<TCell>& row); + void AddRow(const NYql::NUdf::TUnboxedValue& row, size_t membersCount, const TVector<ui32>* columnOrder); // You have to call it before Start() void Reserve(size_t numRows) { @@ -183,19 +193,27 @@ public: } arrow::Status Start(const std::vector<std::pair<TString, NScheme::TTypeInfo>>& columns); + arrow::Status Start(const std::vector<std::pair<TString, NKikimr::NMiniKQL::TType*>> columns); + std::shared_ptr<arrow::RecordBatch> FlushBatch(bool reinitialize, bool flushEmpty = false); std::shared_ptr<arrow::RecordBatch> GetBatch() const { return Batch; } protected: void AppendCell(const TCell& cell, ui32 colNum); + void AppendValue(const NYql::NUdf::TUnboxedValue& value, ui32 colNum); const std::vector<std::pair<TString, NScheme::TTypeInfo>>& GetYdbSchema() const { return YdbSchema; } + const std::vector<std::pair<TString, NKikimr::NMiniKQL::TType*>> GetYqlSchema() const { + 
return YqlSchema; + } + private: arrow::ipc::IpcWriteOptions WriteOptions; std::vector<std::pair<TString, NScheme::TTypeInfo>> YdbSchema; + std::vector<std::pair<TString, NKikimr::NMiniKQL::TType*>> YqlSchema; std::unique_ptr<arrow::RecordBatchBuilder> BatchBuilder; std::shared_ptr<arrow::RecordBatch> Batch; size_t RowsToReserve{DEFAULT_ROWS_TO_RESERVE}; diff --git a/ydb/core/formats/arrow/arrow_helpers.cpp b/ydb/core/formats/arrow/arrow_helpers.cpp index 8f23ac19f28..54766f40615 100644 --- a/ydb/core/formats/arrow/arrow_helpers.cpp +++ b/ydb/core/formats/arrow/arrow_helpers.cpp @@ -89,11 +89,11 @@ arrow::Result<std::shared_ptr<arrow::DataType>> GetCSVArrowType(NScheme::TTypeIn } arrow::Result<arrow::FieldVector> MakeArrowFields( - const std::vector<std::pair<TString, NScheme::TTypeInfo>>& columns, const std::set<std::string>& notNullColumns) { + const std::vector<std::pair<TString, NScheme::TTypeInfo>>& ydbColumns, const std::set<std::string>& notNullColumns) { std::vector<std::shared_ptr<arrow::Field>> fields; - fields.reserve(columns.size()); + fields.reserve(ydbColumns.size()); TVector<TString> errors; - for (auto& [name, ydbType] : columns) { + for (auto& [name, ydbType] : ydbColumns) { std::string colName(name.data(), name.size()); auto arrowType = GetArrowType(ydbType); if (arrowType.ok()) { diff --git a/ydb/core/formats/arrow/arrow_helpers.h b/ydb/core/formats/arrow/arrow_helpers.h index 61a8c2355fd..952651282ed 100644 --- a/ydb/core/formats/arrow/arrow_helpers.h +++ b/ydb/core/formats/arrow/arrow_helpers.h @@ -48,8 +48,8 @@ std::optional<ui32> FindUpperOrEqualPosition(const TArray& arr, const TValue val arrow::Result<std::shared_ptr<arrow::DataType>> GetArrowType(NScheme::TTypeInfo typeInfo); arrow::Result<std::shared_ptr<arrow::DataType>> GetCSVArrowType(NScheme::TTypeInfo typeId); -arrow::Result<arrow::FieldVector> MakeArrowFields(const std::vector<std::pair<TString, NScheme::TTypeInfo>>& columns, const std::set<std::string>& notNullColumns = {}); 
-arrow::Result<std::shared_ptr<arrow::Schema>> MakeArrowSchema(const std::vector<std::pair<TString, NScheme::TTypeInfo>>& columns, const std::set<std::string>& notNullColumns = {}); +arrow::Result<arrow::FieldVector> MakeArrowFields(const std::vector<std::pair<TString, NScheme::TTypeInfo>>& ydbColumns, const std::set<std::string>& notNullColumns = {}); +arrow::Result<std::shared_ptr<arrow::Schema>> MakeArrowSchema(const std::vector<std::pair<TString, NScheme::TTypeInfo>>& ydbColumns, const std::set<std::string>& notNullColumns = {}); std::shared_ptr<arrow::Schema> DeserializeSchema(const TString& str); diff --git a/ydb/core/formats/arrow/arrow_helpers_minikql.cpp b/ydb/core/formats/arrow/arrow_helpers_minikql.cpp new file mode 100644 index 00000000000..273bb581725 --- /dev/null +++ b/ydb/core/formats/arrow/arrow_helpers_minikql.cpp @@ -0,0 +1,42 @@ +#include "arrow_helpers_minikql.h" + +#include <ydb/library/yql/dq/runtime/dq_arrow_helpers.h> +#include <util/string/join.h> + +namespace NKikimr::NArrow { + +arrow::Result<arrow::FieldVector> MakeArrowFields( + const std::vector<std::pair<TString, NKikimr::NMiniKQL::TType*>>& yqlColumns, const std::set<std::string>& notNullColumns) { + std::vector<std::shared_ptr<arrow::Field>> fields; + fields.reserve(yqlColumns.size()); + TVector<TString> errors; + for (auto& [name, mkqlType] : yqlColumns) { + std::string colName(name.data(), name.size()); + std::shared_ptr<arrow::DataType> arrowType; + + try { + arrowType = NYql::NArrow::GetArrowType(mkqlType); + } catch (const yexception& e) { + errors.emplace_back(colName + " error: " + e.what()); + } + + if (arrowType) { + fields.emplace_back(std::make_shared<arrow::Field>(colName, arrowType, !notNullColumns.contains(colName))); + } + } + if (errors.empty()) { + return fields; + } + return arrow::Status::TypeError(JoinSeq(", ", errors)); +} + +arrow::Result<std::shared_ptr<arrow::Schema>> MakeArrowSchema( + const std::vector<std::pair<TString, NKikimr::NMiniKQL::TType*>>& 
yqlColumns, const std::set<std::string>& notNullColumns) { + const auto fields = MakeArrowFields(yqlColumns, notNullColumns); + if (fields.ok()) { + return std::make_shared<arrow::Schema>(fields.ValueUnsafe()); + } + return fields.status(); +} + +} // namespace NKikimr::NArrow diff --git a/ydb/core/formats/arrow/arrow_helpers_minikql.h b/ydb/core/formats/arrow/arrow_helpers_minikql.h new file mode 100644 index 00000000000..ef3a4796346 --- /dev/null +++ b/ydb/core/formats/arrow/arrow_helpers_minikql.h @@ -0,0 +1,17 @@ +#pragma once + +#include <yql/essentials/minikql/mkql_node.h> + +#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> +#include <util/generic/string.h> + +namespace NKikimr::NMiniKQL { +class TType; +} + +namespace NKikimr::NArrow { + +arrow::Result<arrow::FieldVector> MakeArrowFields(const std::vector<std::pair<TString, NKikimr::NMiniKQL::TType*>>& yqlColumns, const std::set<std::string>& notNullColumns = {}); +arrow::Result<std::shared_ptr<arrow::Schema>> MakeArrowSchema(const std::vector<std::pair<TString, NKikimr::NMiniKQL::TType*>>& yqlColumns, const std::set<std::string>& notNullColumns = {}); + +} // namespace NKikimr::NArrow diff --git a/ydb/core/formats/arrow/switch/switch_type.h b/ydb/core/formats/arrow/switch/switch_type.h index fd8b167ef5e..b07fa8c039e 100644 --- a/ydb/core/formats/arrow/switch/switch_type.h +++ b/ydb/core/formats/arrow/switch/switch_type.h @@ -64,6 +64,7 @@ template <typename TFunc> case NScheme::NTypeIds::Interval: return callback(TTypeWrapper<arrow::DurationType>()); case NScheme::NTypeIds::Decimal: + case NScheme::NTypeIds::Uuid: return callback(TTypeWrapper<arrow::FixedSizeBinaryType>()); case NScheme::NTypeIds::Datetime64: diff --git a/ydb/core/formats/arrow/ya.make b/ydb/core/formats/arrow/ya.make index 6a49d427715..53e3636d68f 100644 --- a/ydb/core/formats/arrow/ya.make +++ b/ydb/core/formats/arrow/ya.make @@ -23,6 +23,8 @@ PEERDIR( ydb/library/formats/arrow ydb/library/services 
yql/essentials/core/arrow_kernels/request + yql/essentials/minikql + ydb/library/yql/dq/runtime ) YQL_LAST_ABI_VERSION() @@ -30,6 +32,7 @@ YQL_LAST_ABI_VERSION() SRCS( arrow_batch_builder.cpp arrow_helpers.cpp + arrow_helpers_minikql.cpp arrow_filter.cpp converter.cpp converter.h diff --git a/ydb/core/kqp/query_data/kqp_query_data.cpp b/ydb/core/kqp/query_data/kqp_query_data.cpp index fcd0de37879..90c5d1b2a7d 100644 --- a/ydb/core/kqp/query_data/kqp_query_data.cpp +++ b/ydb/core/kqp/query_data/kqp_query_data.cpp @@ -90,13 +90,10 @@ bool TKqpExecuterTxResult::HasTrailingResults() { void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFormatSettings& resultSetFormatSettings, bool fillSchema, TMaybe<ui64> rowsLimitPerWrite) { YQL_ENSURE(ydbResult); YQL_ENSURE(!Rows.IsWide()); - YQL_ENSURE(MkqlItemType->GetKind() == NKikimr::NMiniKQL::TType::EKind::Struct); + YQL_ENSURE(MkqlItemType->GetKind() == NMiniKQL::TType::EKind::Struct); const auto* mkqlSrcRowStructType = static_cast<const TStructType*>(MkqlItemType); - std::vector<std::pair<TString, NScheme::TTypeInfo>> arrowSchema; - std::set<std::string> arrowNotNullColumns; - if (fillSchema) { for (ui32 idx = 0; idx < mkqlSrcRowStructType->GetMembersCount(); ++idx) { auto* column = ydbResult->add_columns(); @@ -110,6 +107,9 @@ void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFo } } + std::vector<std::pair<TString, NMiniKQL::TType*>> arrowSchema; + std::set<std::string> arrowNotNullColumns; + if (resultSetFormatSettings.IsArrowFormat()) { for (ui32 idx = 0; idx < mkqlSrcRowStructType->GetMembersCount(); ++idx) { ui32 memberIndex = (!ColumnOrder || ColumnOrder->empty()) ? 
idx : (*ColumnOrder)[idx]; @@ -120,8 +120,7 @@ void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFo arrowNotNullColumns.insert(columnName); } - NScheme::TTypeInfo typeInfo = NScheme::TypeInfoFromMiniKQLType(columnType); - arrowSchema.emplace_back(std::move(columnName), std::move(typeInfo)); + arrowSchema.emplace_back(std::move(columnName), std::move(columnType)); } } @@ -145,7 +144,6 @@ void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFo batchBuilder.Reserve(Rows.RowCount()); YQL_ENSURE(batchBuilder.Start(arrowSchema).ok()); - TRowBuilder rowBuilder(arrowSchema.size()); Rows.ForEachRow([&](const NUdf::TUnboxedValue& row) -> bool { if (rowsLimitPerWrite) { if (*rowsLimitPerWrite == 0) { @@ -155,14 +153,7 @@ void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFo --(*rowsLimitPerWrite); } - for (size_t i = 0; i < arrowSchema.size(); ++i) { - ui32 memberIndex = (!ColumnOrder || ColumnOrder->empty()) ? i : (*ColumnOrder)[i]; - const auto& [name, type] = arrowSchema[i]; - rowBuilder.AddCell(i, type, row.GetElement(memberIndex), type.GetPgTypeMod(name)); - } - - auto cells = rowBuilder.BuildCells(); - batchBuilder.AddRow(cells); + batchBuilder.AddRow(row, arrowSchema.size(), ColumnOrder); return true; }); diff --git a/ydb/core/kqp/runtime/kqp_transport.cpp b/ydb/core/kqp/runtime/kqp_transport.cpp index 85077cf8e3a..9e3e37c23fe 100644 --- a/ydb/core/kqp/runtime/kqp_transport.cpp +++ b/ydb/core/kqp/runtime/kqp_transport.cpp @@ -63,7 +63,7 @@ void TKqpProtoBuilder::BuildYdbResultSet( TColumnOrder order = columnHints ? TColumnOrder(*columnHints) : TColumnOrder{}; - std::vector<std::pair<TString, NScheme::TTypeInfo>> arrowSchema; + std::vector<std::pair<TString, NMiniKQL::TType*>> arrowSchema; std::set<std::string> arrowNotNullColumns; if (fillSchema) { @@ -85,12 +85,11 @@ void TKqpProtoBuilder::BuildYdbResultSet( auto columnName = TString(columnHints && columnHints->size() ? 
order.at(idx).LogicalName : mkqlSrcRowStructType->GetMemberName(memberIndex)); auto* columnType = mkqlSrcRowStructType->GetMemberType(memberIndex); - if (columnType->GetKind() != TType::EKind::Optional) { + if (columnType->GetKind() != NMiniKQL::TType::EKind::Optional) { arrowNotNullColumns.insert(columnName); } - NScheme::TTypeInfo typeInfo = NScheme::TypeInfoFromMiniKQLType(columnType); - arrowSchema.emplace_back(std::move(columnName), std::move(typeInfo)); + arrowSchema.emplace_back(std::move(columnName), std::move(columnType)); } } @@ -140,7 +139,6 @@ void TKqpProtoBuilder::BuildYdbResultSet( batchBuilder.Reserve(arrowRowsCount); YQL_ENSURE(batchBuilder.Start(arrowSchema).ok()); - TRowBuilder rowBuilder(arrowSchema.size()); for (auto& part : data) { if (!part.ChunkCount()) { continue; @@ -150,14 +148,7 @@ void TKqpProtoBuilder::BuildYdbResultSet( dataSerializer.Deserialize(std::move(part), mkqlSrcRowType, rows); rows.ForEachRow([&](const NUdf::TUnboxedValue& value) { - for (size_t i = 0; i < arrowSchema.size(); ++i) { - ui32 memberIndex = (!columnOrder || columnOrder->empty()) ? 
i : (*columnOrder)[i]; - const auto& [name, type] = arrowSchema[i]; - rowBuilder.AddCell(i, type, value.GetElement(memberIndex), type.GetPgTypeMod(name)); - } - - auto cells = rowBuilder.BuildCells(); - batchBuilder.AddRow(cells); + batchBuilder.AddRow(value, arrowSchema.size(), columnOrder); }); } diff --git a/ydb/core/kqp/ut/arrow/kqp_result_set_format_arrow.cpp b/ydb/core/kqp/ut/arrow/kqp_result_set_formats.cpp index dbda1212cdf..92414804a2b 100644 --- a/ydb/core/kqp/ut/arrow/kqp_result_set_format_arrow.cpp +++ b/ydb/core/kqp/ut/arrow/kqp_result_set_formats.cpp @@ -58,6 +58,7 @@ void CreateAllTypesRowTable(TQueryClient& client) { YsonValue Yson, JsonDocumentValue JsonDocument, DyNumberValue DyNumber, + UuidValue Uuid, Int32NotNullValue Int32 NOT NULL, PRIMARY KEY (Key) ); @@ -65,8 +66,8 @@ void CreateAllTypesRowTable(TQueryClient& client) { UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); auto insertResult = client.ExecuteQuery(R"( - INSERT INTO `/Root/RowTable` (Key, BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue, Int32NotNullValue) VALUES - (42, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), 123); + INSERT INTO `/Root/RowTable` (Key, BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue, UuidValue, Int32NotNullValue) VALUES + (42, true, -1, 1, 
-2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe"), 123); )", TTxControl::BeginTx().CommitTx()).GetValueSync(); UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); } @@ -142,14 +143,14 @@ void AssertArrowValueResultsSize(const std::vector<TResultSet>& arrowResultSets, } } -std::vector<std::shared_ptr<arrow::RecordBatch>> ExecuteAndCombineBatches(TQueryClient& client, const TString& query, bool assertSize = false, ui64 minBatchesCount = 1) { +std::vector<std::shared_ptr<arrow::RecordBatch>> ExecuteAndCombineBatches(TQueryClient& client, const TString& query, bool assertSize = false, ui64 minBatchesCount = 1, TParams params = TParamsBuilder().Build()) { auto arrowSettings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); - auto arrowResponse = client.ExecuteQuery(query, TTxControl::BeginTx().CommitTx(), arrowSettings).GetValueSync(); + auto arrowResponse = client.ExecuteQuery(query, TTxControl::BeginTx().CommitTx(), params, arrowSettings).GetValueSync(); UNIT_ASSERT_C(arrowResponse.IsSuccess(), arrowResponse.GetIssues().ToString()); if (assertSize) { auto valueSettings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); - auto valueResponse = client.ExecuteQuery(query, TTxControl::BeginTx().CommitTx(), valueSettings).GetValueSync(); + auto valueResponse = client.ExecuteQuery(query, TTxControl::BeginTx().CommitTx(), params, valueSettings).GetValueSync(); UNIT_ASSERT_C(valueResponse.IsSuccess(), valueResponse.GetIssues().ToString()); AssertArrowValueResultsSize(arrowResponse.GetResultSets(), valueResponse.GetResultSets()); } @@ -255,6 +256,25 @@ void CompareCompressedAndDefaultBatches(TQueryClient& client, std::optional<TArr 
UNIT_ASSERT_VALUES_EQUAL(firstArrowBatch->ToString(), secondArrowBatch->ToString()); } +void ValidateOptionalColumn(const std::shared_ptr<arrow::Array>& array, int depth, bool isVariant) { + if (depth == 0 && isVariant) { + UNIT_ASSERT_C(array->type()->id() == arrow::Type::DENSE_UNION, "Column type must be arrow::Type::DENSE_UNION"); + return; + } + + if (depth == 1 && !isVariant) { + return; + } + + UNIT_ASSERT_C(array->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + auto structArray = static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT_C(structArray->num_fields() == 1, "Struct array must have 1 field"); + + auto innerArray = structArray->field(0); + ValidateOptionalColumn(innerArray, depth - 1, isVariant); +} + } // namespace Y_UNIT_TEST_SUITE(KqpResultSetFormats) { @@ -780,7 +800,7 @@ Y_UNIT_TEST_SUITE(KqpResultSetFormats) { UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("BoolValue", TTypeInfo(NTypeIds::Bool)), + std::make_pair("BoolValue", TTypeInfo(NTypeIds::Uint8)), std::make_pair("Int8Value", TTypeInfo(NTypeIds::Int8)), std::make_pair("Uint8Value", TTypeInfo(NTypeIds::Uint8)), std::make_pair("Int16Value", TTypeInfo(NTypeIds::Int16)), @@ -795,8 +815,8 @@ Y_UNIT_TEST_SUITE(KqpResultSetFormats) { })); builder.AddRow().AddNull().Add<i8>(-1).Add<ui8>(1).Add<i16>(-2).Add<ui16>(2).Add<i32>(-3).Add<ui32>(3).Add<i64>(-4).Add<ui64>(4).Add<float>(5.0).Add<double>(6.0).Add(TDecimalValue("7.77", 22, 2)); - builder.AddRow().Add<bool>(false).Add<i8>(-1).Add<ui8>(1).Add<i16>(-2).Add<ui16>(2).Add<i32>(-3).Add<ui32>(3).Add<i64>(-4).Add<ui64>(4).Add<float>(5.0).Add<double>(6.0).Add(TDecimalValue("7.77", 22, 2)); - builder.AddRow().Add<bool>(true).Add<i8>(-1).Add<ui8>(1).Add<i16>(-2).Add<ui16>(2).Add<i32>(-3).Add<ui32>(3).Add<i64>(-4).Add<ui64>(4).Add<float>(5.0).Add<double>(6.0).Add(TDecimalValue("7.77", 22, 2)); + 
builder.AddRow().Add<ui8>(false).Add<i8>(-1).Add<ui8>(1).Add<i16>(-2).Add<ui16>(2).Add<i32>(-3).Add<ui32>(3).Add<i64>(-4).Add<ui64>(4).Add<float>(5.0).Add<double>(6.0).Add(TDecimalValue("7.77", 22, 2)); + builder.AddRow().Add<ui8>(true).Add<i8>(-1).Add<ui8>(1).Add<i16>(-2).Add<ui16>(2).Add<i32>(-3).Add<ui32>(3).Add<i64>(-4).Add<ui64>(4).Add<float>(5.0).Add<double>(6.0).Add(TDecimalValue("7.77", 22, 2)); auto expected = builder.BuildArrow(); UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); @@ -871,10 +891,12 @@ Y_UNIT_TEST_SUITE(KqpResultSetFormats) { YsonValue Yson, DyNumberValue DyNumber, JsonDocumentValue JsonDocument, + UuidValue Uuid, StringNotNullValue String NOT NULL, YsonNotNullValue Yson NOT NULL, JsonDocumentNotNullValue JsonDocument NOT NULL, DyNumberNotNullValue DyNumber NOT NULL, + UuidNotNullValue Uuid NOT NULL, PRIMARY KEY (StringValue) ); )", TTxControl::NoTx()).GetValueSync(); @@ -882,17 +904,17 @@ Y_UNIT_TEST_SUITE(KqpResultSetFormats) { } { auto result = client.ExecuteQuery(R"( - INSERT INTO BinaryTypesTable (StringValue, YsonValue, DyNumberValue, JsonDocumentValue, StringNotNullValue, YsonNotNullValue, JsonDocumentNotNullValue, DyNumberNotNullValue) VALUES - ("John", "[1]", DyNumber("1.0"), JsonDocument("{\"a\": 1}"), "Mark", "[2]", JsonDocument("{\"b\": 2}"), DyNumber("4.0")), - (NULL, "[4]", NULL, NULL, "Maria", "[5]", JsonDocument("[6]"), DyNumber("7.0")), - ("Mark", NULL, NULL, NULL, "Michael", "[7]", JsonDocument("[8]"), DyNumber("9.0")), - ("Leo", "[10]", DyNumber("11.0"), JsonDocument("[12]"), "Maria", "[13]", JsonDocument("[14]"), DyNumber("15.0")); + INSERT INTO BinaryTypesTable (StringValue, YsonValue, DyNumberValue, JsonDocumentValue, UuidValue, StringNotNullValue, YsonNotNullValue, JsonDocumentNotNullValue, DyNumberNotNullValue, UuidNotNullValue) VALUES + ("John", "[1]", DyNumber("1.0"), JsonDocument("{\"a\": 1}"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe"), "Mark", "[2]", JsonDocument("{\"b\": 2}"), 
DyNumber("4.0"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")), + (NULL, "[4]", NULL, NULL, NULL, "Maria", "[5]", JsonDocument("[6]"), DyNumber("7.0"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")), + ("Mark", NULL, NULL, NULL, NULL, "Michael", "[7]", JsonDocument("[8]"), DyNumber("9.0"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")), + ("Leo", "[10]", DyNumber("11.0"), JsonDocument("[12]"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe"), "Maria", "[13]", JsonDocument("[14]"), DyNumber("15.0"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); )", TTxControl::BeginTx().CommitTx()).GetValueSync(); UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); } { auto batches = ExecuteAndCombineBatches(client, R"( - SELECT StringValue, YsonValue, DyNumberValue, JsonDocumentValue, StringNotNullValue, YsonNotNullValue, JsonDocumentNotNullValue, DyNumberNotNullValue + SELECT StringValue, YsonValue, DyNumberValue, JsonDocumentValue, UuidValue, StringNotNullValue, YsonNotNullValue, JsonDocumentNotNullValue, DyNumberNotNullValue, UuidNotNullValue FROM BinaryTypesTable ORDER BY StringValue; )", /* assertSize */ true); @@ -903,16 +925,18 @@ Y_UNIT_TEST_SUITE(KqpResultSetFormats) { std::make_pair("YsonValue", TTypeInfo(NTypeIds::Yson)), std::make_pair("DyNumberValue", TTypeInfo(NTypeIds::DyNumber)), std::make_pair("JsonDocumentValue", TTypeInfo(NTypeIds::JsonDocument)), + std::make_pair("UuidValue", TTypeInfo(NTypeIds::Uuid)), std::make_pair("StringNotNullValue", TTypeInfo(NTypeIds::String)), std::make_pair("YsonNotNullValue", TTypeInfo(NTypeIds::Yson)), std::make_pair("JsonDocumentNotNullValue", TTypeInfo(NTypeIds::JsonDocument)), - std::make_pair("DyNumberNotNullValue", TTypeInfo(NTypeIds::DyNumber)) + std::make_pair("DyNumberNotNullValue", TTypeInfo(NTypeIds::DyNumber)), + std::make_pair("UuidNotNullValue", TTypeInfo(NTypeIds::Uuid)) })); - 
builder.AddRow().AddNull().Add("[4]").AddNull().AddNull().Add("Maria").Add("[5]").Add(SerializeToBinaryJsonString("[6]")).Add(NDyNumber::ParseDyNumberString("7.0")->c_str()); - builder.AddRow().Add("John").Add("[1]").Add(NDyNumber::ParseDyNumberString("1.0")->c_str()).Add(SerializeToBinaryJsonString("{\"a\": 1}")).Add("Mark").Add("[2]").Add(SerializeToBinaryJsonString("{\"b\": 2}")).Add(NDyNumber::ParseDyNumberString("4.0")->c_str()); - builder.AddRow().Add("Leo").Add("[10]").Add(NDyNumber::ParseDyNumberString("11.0")->c_str()).Add(SerializeToBinaryJsonString("[12]")).Add("Maria").Add("[13]").Add(SerializeToBinaryJsonString("[14]")).Add(NDyNumber::ParseDyNumberString("15.0")->c_str()); - builder.AddRow().Add("Mark").AddNull().AddNull().AddNull().Add("Michael").Add("[7]").Add(SerializeToBinaryJsonString("[8]")).Add(NDyNumber::ParseDyNumberString("9.0")->c_str()); + builder.AddRow().AddNull().Add("[4]").AddNull().AddNull().AddNull().Add("Maria").Add("[5]").Add(SerializeToBinaryJsonString("[6]")).Add(NDyNumber::ParseDyNumberString("7.0")->c_str()).Add<NYdb::TUuidValue>(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); + builder.AddRow().Add("John").Add("[1]").Add(NDyNumber::ParseDyNumberString("1.0")->c_str()).Add(SerializeToBinaryJsonString("{\"a\": 1}")).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")).Add("Mark").Add("[2]").Add(SerializeToBinaryJsonString("{\"b\": 2}")).Add(NDyNumber::ParseDyNumberString("4.0")->c_str()).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); + builder.AddRow().Add("Leo").Add("[10]").Add(NDyNumber::ParseDyNumberString("11.0")->c_str()).Add(SerializeToBinaryJsonString("[12]")).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")).Add("Maria").Add("[13]").Add(SerializeToBinaryJsonString("[14]")).Add(NDyNumber::ParseDyNumberString("15.0")->c_str()).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); + 
builder.AddRow().Add("Mark").AddNull().AddNull().AddNull().AddNull().Add("Michael").Add("[7]").Add(SerializeToBinaryJsonString("[8]")).Add(NDyNumber::ParseDyNumberString("9.0")->c_str()).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); auto expected = builder.BuildArrow(); @@ -955,10 +979,10 @@ Y_UNIT_TEST_SUITE(KqpResultSetFormats) { UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("DateValue", TTypeInfo(NTypeIds::Date)), - std::make_pair("DatetimeValue", TTypeInfo(NTypeIds::Datetime)), - std::make_pair("TimestampValue", TTypeInfo(NTypeIds::Timestamp)), - std::make_pair("IntervalValue", TTypeInfo(NTypeIds::Interval)) + std::make_pair("DateValue", TTypeInfo(NTypeIds::Uint16)), + std::make_pair("DatetimeValue", TTypeInfo(NTypeIds::Uint32)), + std::make_pair("TimestampValue", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("IntervalValue", TTypeInfo(NTypeIds::Int64)) })); builder.AddRow().Add<ui16>(11323).Add<ui32>(1012615322).Add<ui64>(1046660583000000).Add<i64>(604800000000); @@ -1419,6 +1443,595 @@ Y_UNIT_TEST_SUITE(KqpResultSetFormats) { UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets for statement with ResultSetIndex = " << idx); } } + + /** + * More tests for different types with correctness and convertations between Arrow and UV : + * ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp + */ + + // Optional<T> + Y_UNIT_TEST(ArrowFormat_Types_Optional_1) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Key1, Name FROM Join2 + WHERE Key1 IN [104, 106, 108] + ORDER BY Key1; + )", /* assertSize */ false, 1); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 3); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 2); + 
+ ValidateOptionalColumn(batch->column(0), 1, false); + ValidateOptionalColumn(batch->column(1), 1, false); + + const TString expected = +R"(Key1: [ + 104, + 106, + 108 + ] +Name: [ + 4E616D6533, + 4E616D6533, + null + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Optional<Optional<T>> + Y_UNIT_TEST(ArrowFormat_Types_Optional_2) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Just(Key1), Just(Name) FROM Join2 + WHERE Key1 IN [104, 106, 108] + ORDER BY Key1; + )", /* assertSize */ false, 1); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 3); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 2); + + ValidateOptionalColumn(batch->column(0), 2, false); + ValidateOptionalColumn(batch->column(1), 2, false); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: uint32 + [ + 104, + 106, + 108 + ] +column1: -- is_valid: all not null + -- child 0 type: binary + [ + 4E616D6533, + 4E616D6533, + null + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Optional<Optional<Optional<Optional<T>>>> + Y_UNIT_TEST(ArrowFormat_Types_Optional_3) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Just(Just(Just(Key1))), Just(Just(Just(Name))) FROM Join2 + WHERE Key1 IN [104, 106, 108] + ORDER BY Key1; + )", /* assertSize */ false, 1); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 3); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 2); + + ValidateOptionalColumn(batch->column(0), 3, false); + ValidateOptionalColumn(batch->column(1), 
3, false); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: struct<opt: struct<opt: uint32 not null> not null> + -- is_valid: all not null + -- child 0 type: struct<opt: uint32 not null> + -- is_valid: all not null + -- child 0 type: uint32 + [ + 104, + 106, + 108 + ] +column1: -- is_valid: all not null + -- child 0 type: struct<opt: struct<opt: binary not null> not null> + -- is_valid: all not null + -- child 0 type: struct<opt: binary not null> + -- is_valid: all not null + -- child 0 type: binary + [ + 4E616D6533, + 4E616D6533, + null + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Optional<Variant<T, F>> + Y_UNIT_TEST(ArrowFormat_Types_Optional_4) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Just(Variant(1, "foo", Variant<foo: Int32, bar: Bool>)); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + + ValidateOptionalColumn(batch->column(0), 1, /* isVariant */ true); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: dense_union<bar: uint8 not null=0, foo: int32 not null=1> + -- is_valid: all not null + -- type_ids: [ + 1 + ] + -- value_offsets: [ + 0 + ] + -- child 0 type: uint8 + [] + -- child 1 type: int32 + [ + 1 + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Optional<Optional<Variant<T, F, G>>> + Y_UNIT_TEST(ArrowFormat_Types_Optional_5) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Just(Just(Variant(1, "foo", Variant<foo: Int32, bar: Bool, foobar: String>))); + )", /* 
assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + + ValidateOptionalColumn(batch->column(0), 2, /* isVariant */ true); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: struct<opt: dense_union<bar: uint8 not null=0, foo: int32 not null=1, foobar: binary not null=2> not null> + -- is_valid: all not null + -- child 0 type: dense_union<bar: uint8 not null=0, foo: int32 not null=1, foobar: binary not null=2> + -- is_valid: all not null + -- type_ids: [ + 1 + ] + -- value_offsets: [ + 0 + ] + -- child 0 type: uint8 + [] + -- child 1 type: int32 + [ + 1 + ] + -- child 2 type: binary + [] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // List<T> + Y_UNIT_TEST(ArrowFormat_Types_List_1) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST([1, 2, 3] AS List<Int32>); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::LIST, "Column type must be arrow::Type::LIST"); + + const TString expected = +R"(column0: [ + [ + 1, + 2, + 3 + ] + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // List<Optional<T>> + Y_UNIT_TEST(ArrowFormat_Types_List_2) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST([1, NULL, 3, null] AS List<Optional<Int32>>); + )", /* assertSize */ false); + + 
UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::LIST, "Column type must be arrow::Type::LIST"); + + const TString expected = +R"(column0: [ + [ + 1, + null, + 3, + null + ] + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // List<Optional<T>> of columns + Y_UNIT_TEST(ArrowFormat_Types_List_3) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT [App, Host] FROM Logs + ORDER BY App; + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 9); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::LIST, "Column type must be arrow::Type::LIST"); + + const TString expected = +R"(column0: [ + [ + "apache", + "front-42" + ], + [ + "kikimr-db", + "kikimr-db-10" + ], + [ + "kikimr-db", + "kikimr-db-21" + ], + [ + "kikimr-db", + "kikimr-db-21" + ], + [ + "kikimr-db", + "kikimr-db-53" + ], + [ + "nginx", + "nginx-10" + ], + [ + "nginx", + "nginx-23" + ], + [ + "nginx", + "nginx-23" + ], + [ + "ydb", + "ydb-1000" + ] + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // List<> + Y_UNIT_TEST(ArrowFormat_Types_EmptyList) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT []; + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + 
UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Tuple<T, F> + Y_UNIT_TEST(ArrowFormat_Types_Tuple) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST((1, 2.5) AS Tuple<Uint32, Double>); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: uint32 + [ + 1 + ] + -- child 1 type: double + [ + 2.5 + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Dict<K, V> + Y_UNIT_TEST(ArrowFormat_Types_Dict_1) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST({"a": 1, "b": 2, "c": 3} AS Dict<String,Int32>); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: map<binary, int32> + [ + keys: + [ + 61, + 63, + 62 + ] + values: + [ + 1, + 3, + 2 + ] + ] + -- child 1 
type: uint64 + [ + 0 + ] +)"; + + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Dict<Optional<K>, V> + Y_UNIT_TEST(ArrowFormat_Types_Dict_2) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST({"a": 1, "b": 2, NULL: 3} AS Dict<Optional<String>,Int32>); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: list<item: struct<key: binary, payload: int32 not null>> + [ + -- is_valid: all not null + -- child 0 type: binary + [ + 61, + 62, + null + ] + -- child 1 type: int32 + [ + 1, + 2, + 3 + ] + ] + -- child 1 type: uint64 + [ + 0 + ] +)"; + + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Dict<> + Y_UNIT_TEST(ArrowFormat_Types_EmptyDict) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT {}; + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Struct<first:T, second:F> + Y_UNIT_TEST(ArrowFormat_Types_Struct) { + auto kikimr = 
CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST(<|first: 1, second: "2"|> AS Struct<first:Int32,second:Utf8>); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: int32 + [ + 1 + ] + -- child 1 type: string + [ + "2" + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Variant<T, F> + Y_UNIT_TEST(ArrowFormat_Types_Variant) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Variant(1, "foo", Variant<foo: Int32, bar: Bool>); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::DENSE_UNION, "Column type must be arrow::Type::DENSE_UNION"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- type_ids: [ + 1 + ] + -- value_offsets: [ + 0 + ] + -- child 0 type: uint8 + [] + -- child 1 type: int32 + [ + 1 + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } } } // namespace NKikimr::NKqp diff --git a/ydb/core/kqp/ut/arrow/ya.make b/ydb/core/kqp/ut/arrow/ya.make index 666b7755881..891a4052de2 100644 --- a/ydb/core/kqp/ut/arrow/ya.make +++ b/ydb/core/kqp/ut/arrow/ya.make @@ -8,7 +8,7 @@ SIZE(MEDIUM) SRCS( kqp_arrow_in_channels_ut.cpp 
kqp_types_arrow_ut.cpp - kqp_result_set_format_arrow.cpp + kqp_result_set_formats.cpp ) PEERDIR( diff --git a/ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h b/ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h index 1ac17a5da7a..98738dcc1eb 100644 --- a/ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h +++ b/ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h @@ -547,6 +547,13 @@ public: } } + if constexpr (std::is_same<TData, NYdb::TUuidValue>::value) { + if constexpr (std::is_same<T, arrow::FixedSizeBinaryType>::value) { + Y_ABORT_UNLESS(typedBuilder.Append(data.Buf_.Bytes).ok()); + return true; + } + } + Y_ABORT("Unknown type combination"); return false; })); diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp b/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp index c55fddb0c13..a329c36f757 100644 --- a/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp @@ -1,6 +1,8 @@ #include "dq_arrow_helpers.h" #include <cstddef> + +#include <yql/essentials/public/udf/arrow/block_type_helper.h> #include <yql/essentials/minikql/arrow/arrow_util.h> #include <yql/essentials/minikql/computation/mkql_block_reader.h> #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> @@ -51,11 +53,10 @@ struct TTypeWrapper template <typename TFunc> bool SwitchMiniKQLDataTypeToArrowType(NUdf::EDataSlot type, TFunc&& callback) { switch (type) { - case NUdf::EDataSlot::Bool: - return callback(TTypeWrapper<arrow::BooleanType>()); case NUdf::EDataSlot::Int8: return callback(TTypeWrapper<arrow::Int8Type>()); case NUdf::EDataSlot::Uint8: + case NUdf::EDataSlot::Bool: return callback(TTypeWrapper<arrow::UInt8Type>()); case NUdf::EDataSlot::Int16: return callback(TTypeWrapper<arrow::Int16Type>()); @@ -69,40 +70,59 @@ bool SwitchMiniKQLDataTypeToArrowType(NUdf::EDataSlot type, TFunc&& callback) { case NUdf::EDataSlot::Uint32: return callback(TTypeWrapper<arrow::UInt32Type>()); case 
NUdf::EDataSlot::Int64: + case NUdf::EDataSlot::Interval: case NUdf::EDataSlot::Datetime64: case NUdf::EDataSlot::Timestamp64: case NUdf::EDataSlot::Interval64: return callback(TTypeWrapper<arrow::Int64Type>()); case NUdf::EDataSlot::Uint64: + case NUdf::EDataSlot::Timestamp: return callback(TTypeWrapper<arrow::UInt64Type>()); case NUdf::EDataSlot::Float: return callback(TTypeWrapper<arrow::FloatType>()); case NUdf::EDataSlot::Double: return callback(TTypeWrapper<arrow::DoubleType>()); - case NUdf::EDataSlot::Timestamp: - return callback(TTypeWrapper<arrow::TimestampType>()); - case NUdf::EDataSlot::Interval: - return callback(TTypeWrapper<arrow::DurationType>()); case NUdf::EDataSlot::Utf8: case NUdf::EDataSlot::Json: - case NUdf::EDataSlot::Yson: - case NUdf::EDataSlot::JsonDocument: return callback(TTypeWrapper<arrow::StringType>()); case NUdf::EDataSlot::String: - case NUdf::EDataSlot::Uuid: case NUdf::EDataSlot::DyNumber: + case NUdf::EDataSlot::Yson: + case NUdf::EDataSlot::JsonDocument: return callback(TTypeWrapper<arrow::BinaryType>()); case NUdf::EDataSlot::Decimal: - return callback(TTypeWrapper<arrow::Decimal128Type>()); - // TODO convert Tz-types to native arrow date and time types. 
+ case NUdf::EDataSlot::Uuid: + return callback(TTypeWrapper<arrow::FixedSizeBinaryType>()); case NUdf::EDataSlot::TzDate: case NUdf::EDataSlot::TzDatetime: case NUdf::EDataSlot::TzTimestamp: case NUdf::EDataSlot::TzDate32: case NUdf::EDataSlot::TzDatetime64: case NUdf::EDataSlot::TzTimestamp64: + return callback(TTypeWrapper<arrow::StructType>()); + } +} + +bool NeedWrapByExternalOptional(const TType* type) { + switch (type->GetKind()) { + case TType::EKind::Void: + case TType::EKind::Null: + case TType::EKind::Variant: + case TType::EKind::Optional: + return true; + case TType::EKind::EmptyList: + case TType::EKind::EmptyDict: + case TType::EKind::Data: + case TType::EKind::Struct: + case TType::EKind::Tuple: + case TType::EKind::List: + case TType::EKind::Dict: return false; + default: + YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr()); } + + return true; } template <typename TArrowType> @@ -112,7 +132,52 @@ NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr<arrow::Array> column, ui32 r return NUdf::TUnboxedValuePod(static_cast<typename TArrowType::c_type>(array->Value(row))); } -// The following 4 specialization are for darwin build (because of difference in long long) +template <> +NUdf::TUnboxedValue GetUnboxedValue<arrow::StructType>(std::shared_ptr<arrow::Array> column, ui32 row) { + auto array = std::static_pointer_cast<arrow::StructArray>(column); + YQL_ENSURE(array->num_fields() == 2, "StructArray of some TzDate type should have 2 fields"); + + auto datetimeArray = array->field(0); + auto timezoneArray = std::static_pointer_cast<arrow::UInt16Array>(array->field(1)); + + NUdf::TUnboxedValuePod value; + + switch (datetimeArray->type()->id()) { + // NUdf::EDataSlot::TzDate + case arrow::Type::UINT16: { + value = NUdf::TUnboxedValuePod(static_cast<ui16>(std::static_pointer_cast<arrow::UInt16Array>(datetimeArray)->Value(row))); + break; + } + // NUdf::EDataSlot::TzDatetime + case arrow::Type::UINT32: { + value = 
NUdf::TUnboxedValuePod(static_cast<ui32>(std::static_pointer_cast<arrow::UInt32Array>(datetimeArray)->Value(row))); + break; + } + // NUdf::EDataSlot::TzTimestamp + case arrow::Type::UINT64: { + value = NUdf::TUnboxedValuePod(static_cast<ui64>(std::static_pointer_cast<arrow::UInt64Array>(datetimeArray)->Value(row))); + break; + } + // NUdf::EDataSlot::TzDate32 + case arrow::Type::INT32: { + value = NUdf::TUnboxedValuePod(static_cast<i32>(std::static_pointer_cast<arrow::Int32Array>(datetimeArray)->Value(row))); + break; + } + // NUdf::EDataSlot::TzDatetime64, NUdf::EDataSlot::TzTimestamp64 + case arrow::Type::INT64: { + value = NUdf::TUnboxedValuePod(static_cast<i64>(std::static_pointer_cast<arrow::Int64Array>(datetimeArray)->Value(row))); + break; + } + default: + YQL_ENSURE(false, "Unexpected timezone datetime slot"); + return NUdf::TUnboxedValuePod(); + } + + value.SetTimezoneId(timezoneArray->Value(row)); + return value; +} + +// The following specializations are for darwin build (because of difference in long long) template <> // For darwin build NUdf::TUnboxedValue GetUnboxedValue<arrow::UInt64Type>(std::shared_ptr<arrow::Array> column, ui32 row) { @@ -126,20 +191,6 @@ NUdf::TUnboxedValue GetUnboxedValue<arrow::Int64Type>(std::shared_ptr<arrow::Arr return NUdf::TUnboxedValuePod(static_cast<i64>(array->Value(row))); } -template <> // For darwin build -NUdf::TUnboxedValue GetUnboxedValue<arrow::TimestampType>(std::shared_ptr<arrow::Array> column, ui32 row) { - using TArrayType = typename arrow::TypeTraits<arrow::TimestampType>::ArrayType; - auto array = std::static_pointer_cast<TArrayType>(column); - return NUdf::TUnboxedValuePod(static_cast<ui64>(array->Value(row))); -} - -template <> // For darwin build -NUdf::TUnboxedValue GetUnboxedValue<arrow::DurationType>(std::shared_ptr<arrow::Array> column, ui32 row) { - using TArrayType = typename arrow::TypeTraits<arrow::DurationType>::ArrayType; - auto array = std::static_pointer_cast<TArrayType>(column); - return 
NUdf::TUnboxedValuePod(static_cast<ui64>(array->Value(row))); -} - template <> NUdf::TUnboxedValue GetUnboxedValue<arrow::BinaryType>(std::shared_ptr<arrow::Array> column, ui32 row) { auto array = std::static_pointer_cast<arrow::BinaryArray>(column); @@ -155,46 +206,63 @@ NUdf::TUnboxedValue GetUnboxedValue<arrow::StringType>(std::shared_ptr<arrow::Ar } template <> -NUdf::TUnboxedValue GetUnboxedValue<arrow::Decimal128Type>(std::shared_ptr<arrow::Array> column, ui32 row) { - auto array = std::static_pointer_cast<arrow::Decimal128Array>(column); +NUdf::TUnboxedValue GetUnboxedValue<arrow::FixedSizeBinaryType>(std::shared_ptr<arrow::Array> column, ui32 row) { + auto array = std::static_pointer_cast<arrow::FixedSizeBinaryArray>(column); auto data = array->GetView(row); - // We check that Decimal(22,9) but it may not be true - // TODO Support other decimal precisions. - const auto& type = arrow::internal::checked_cast<const arrow::Decimal128Type&>(*array->type()); - Y_ABORT_UNLESS(type.precision() == NScheme::DECIMAL_PRECISION, "Unsupported Decimal precision."); - Y_ABORT_UNLESS(type.scale() == NScheme::DECIMAL_SCALE, "Unsupported Decimal scale."); - Y_ABORT_UNLESS(data.size() == sizeof(NYql::NDecimal::TInt128), "Wrong data size"); - NYql::NDecimal::TInt128 val; - std::memcpy(reinterpret_cast<char*>(&val), data.data(), data.size()); - return NUdf::TUnboxedValuePod(val); + return NMiniKQL::MakeString(NUdf::TStringRef(data.data(), data.size())); } template <typename TType> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl() { +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl(NUdf::EDataSlot slot) { + Y_UNUSED(slot); return std::make_shared<TType>(); } template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::Decimal128Type>() { - // TODO use non-fixed precision, derive it from data. 
- return arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE); +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::FixedSizeBinaryType>(NUdf::EDataSlot slot) { + Y_UNUSED(slot); + return arrow::fixed_size_binary(NScheme::FSB_SIZE); } template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::TimestampType>() { - return arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO); -} +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::StructType>(NUdf::EDataSlot slot) { + std::shared_ptr<arrow::DataType> type; + switch (slot) { + case NUdf::EDataSlot::TzDate: + type = NYql::NUdf::MakeTzLayoutArrowType<NUdf::EDataSlot::TzDate>(); + break; + case NUdf::EDataSlot::TzDatetime: + type = NYql::NUdf::MakeTzLayoutArrowType<NUdf::EDataSlot::TzDatetime>(); + break; + case NUdf::EDataSlot::TzTimestamp: + type = NYql::NUdf::MakeTzLayoutArrowType<NUdf::EDataSlot::TzTimestamp>(); + break; + case NUdf::EDataSlot::TzDate32: + type = NYql::NUdf::MakeTzLayoutArrowType<NUdf::EDataSlot::TzDate32>(); + break; + case NUdf::EDataSlot::TzDatetime64: + type = NYql::NUdf::MakeTzLayoutArrowType<NUdf::EDataSlot::TzDatetime64>(); + break; + case NUdf::EDataSlot::TzTimestamp64: + type = NYql::NUdf::MakeTzLayoutArrowType<NUdf::EDataSlot::TzTimestamp64>(); + break; + default: + YQL_ENSURE(false, "Unexpected timezone datetime slot"); + return std::make_shared<arrow::NullType>(); + } -template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::DurationType>() { - return arrow::duration(arrow::TimeUnit::TimeUnit::MICRO); + std::vector<std::shared_ptr<arrow::Field>> fields { + std::make_shared<arrow::Field>("datetime", type, false), + std::make_shared<arrow::Field>("timezoneId", arrow::uint16(), false), + }; + return arrow::struct_(fields); } std::shared_ptr<arrow::DataType> GetArrowType(const TDataType* dataType) { std::shared_ptr<arrow::DataType> result; bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename 
TType>(TTypeWrapper<TType> typeHolder) { Y_UNUSED(typeHolder); - result = CreateEmptyArrowImpl<TType>(); + result = CreateEmptyArrowImpl<TType>(*dataType->GetDataSlot().Get()); return true; }); if (success) { @@ -208,8 +276,10 @@ std::shared_ptr<arrow::DataType> GetArrowType(const TStructType* structType) { fields.reserve(structType->GetMembersCount()); for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { auto memberType = structType->GetMemberType(index); - fields.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(memberType))); + auto memberName = std::string(structType->GetMemberName(index)); + auto memberArrowType = NArrow::GetArrowType(memberType); + + fields.emplace_back(std::make_shared<arrow::Field>(memberName, memberArrowType, memberType->IsOptional())); } return arrow::struct_(fields); } @@ -218,27 +288,42 @@ std::shared_ptr<arrow::DataType> GetArrowType(const TTupleType* tupleType) { std::vector<std::shared_ptr<arrow::Field>> fields; fields.reserve(tupleType->GetElementsCount()); for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { + auto elementName = std::string("field" + ToString(index)); auto elementType = tupleType->GetElementType(index); - fields.push_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(elementType))); + auto elementArrowType = NArrow::GetArrowType(elementType); + + fields.push_back(std::make_shared<arrow::Field>(elementName, elementArrowType, elementType->IsOptional())); } return arrow::struct_(fields); } std::shared_ptr<arrow::DataType> GetArrowType(const TListType* listType) { auto itemType = listType->GetItemType(); - return arrow::list(NArrow::GetArrowType(itemType)); + auto itemArrowType = NArrow::GetArrowType(itemType); + auto field = std::make_shared<arrow::Field>("item", itemArrowType, itemType->IsOptional()); + return arrow::list(field); } std::shared_ptr<arrow::DataType> GetArrowType(const TDictType* dictType) { 
auto keyType = dictType->GetKeyType(); auto payloadType = dictType->GetPayloadType(); + + auto keyArrowType = NArrow::GetArrowType(keyType); + auto payloadArrowType = NArrow::GetArrowType(payloadType); + + auto custom = std::make_shared<arrow::Field>("custom", arrow::uint64(), false); + if (keyType->GetKind() == TType::EKind::Optional) { - std::vector<std::shared_ptr<arrow::Field>> fields; - fields.emplace_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(keyType))); - fields.emplace_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(payloadType))); - return arrow::list(arrow::struct_(fields)); + std::vector<std::shared_ptr<arrow::Field>> items; + items.emplace_back(std::make_shared<arrow::Field>("key", keyArrowType, true)); + items.emplace_back(std::make_shared<arrow::Field>("payload", payloadArrowType, payloadType->IsOptional())); + + auto fieldMap = std::make_shared<arrow::Field>("map", arrow::list(arrow::struct_(items)), false); + return arrow::struct_({fieldMap, custom}); } - return arrow::map(NArrow::GetArrowType(keyType), NArrow::GetArrowType(payloadType)); + + auto fieldMap = std::make_shared<arrow::Field>("map", arrow::map(keyArrowType, payloadArrowType), false); + return arrow::struct_({fieldMap, custom}); } std::shared_ptr<arrow::DataType> GetArrowType(const TVariantType* variantType) { @@ -246,52 +331,77 @@ std::shared_ptr<arrow::DataType> GetArrowType(const TVariantType* variantType) { arrow::FieldVector types; TStructType* structType = nullptr; TTupleType* tupleType = nullptr; + if (innerType->IsStruct()) { structType = static_cast<TStructType*>(innerType); } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + YQL_ENSURE(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); tupleType = static_cast<TTupleType*>(innerType); } + // Create Union of unions if there are more types then arrow::dense_union supports. 
if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { - // Create Union of unions if there are more types then arrow::dense_union supports. ui32 numberOfGroups = (variantType->GetAlternativesCount() - 1) / arrow::UnionType::kMaxTypeCode + 1; types.reserve(numberOfGroups); + for (ui32 groupIndex = 0; groupIndex < numberOfGroups; ++groupIndex) { ui32 beginIndex = groupIndex * arrow::UnionType::kMaxTypeCode; ui32 endIndex = std::min((groupIndex + 1) * arrow::UnionType::kMaxTypeCode, variantType->GetAlternativesCount()); + arrow::FieldVector groupTypes; groupTypes.reserve(endIndex - beginIndex); - if (structType == nullptr) { - for (ui32 index = beginIndex; index < endIndex; ++ index) { - groupTypes.emplace_back(std::make_shared<arrow::Field>("", - NArrow::GetArrowType(tupleType->GetElementType(index)))); - } - } else { - for (ui32 index = beginIndex; index < endIndex; ++ index) { - groupTypes.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(structType->GetMemberType(index)))); - } - } - types.emplace_back(std::make_shared<arrow::Field>("", arrow::dense_union(groupTypes))); - } - } else { - // Simply put all types in one arrow::dense_union - types.reserve(variantType->GetAlternativesCount()); - if (structType == nullptr) { - for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { - types.push_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(tupleType->GetElementType(index)))); - } - } else { - for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { - types.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(structType->GetMemberType(index)))); + + for (ui32 index = beginIndex; index < endIndex; ++index) { + auto itemName = (structType == nullptr) ? 
std::string("field" + ToString(index)) : std::string(structType->GetMemberName(index)); + auto itemType = (structType == nullptr) ? tupleType->GetElementType(index) : structType->GetMemberType(index); + auto itemArrowType = NArrow::GetArrowType(itemType); + + groupTypes.emplace_back(std::make_shared<arrow::Field>(itemName, itemArrowType, itemType->IsOptional())); } + + auto fieldName = std::string("field" + ToString(groupIndex)); + types.emplace_back(std::make_shared<arrow::Field>(fieldName, arrow::dense_union(groupTypes), false)); } + + return arrow::dense_union(types); } + + // Else put all types in one arrow::dense_union + types.reserve(variantType->GetAlternativesCount()); + for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { + auto itemName = (structType == nullptr) ? std::string("field" + ToString(index)) : std::string(structType->GetMemberName(index)); + auto itemType = (structType == nullptr) ? tupleType->GetElementType(index) : structType->GetMemberType(index); + auto itemArrowType = NArrow::GetArrowType(itemType); + + types.emplace_back(std::make_shared<arrow::Field>(itemName, itemArrowType, itemType->IsOptional())); + } + return arrow::dense_union(types); } +std::shared_ptr<arrow::DataType> GetArrowType(const TOptionalType* optionalType) { + auto currentType = optionalType->GetItemType(); + ui32 depth = 1; + + while (currentType->IsOptional()) { + currentType = static_cast<const TOptionalType*>(currentType)->GetItemType(); + ++depth; + } + + if (NeedWrapByExternalOptional(currentType)) { + ++depth; + } + + std::shared_ptr<arrow::DataType> innerArrowType = NArrow::GetArrowType(currentType); + + for (ui32 i = 1; i < depth; ++i) { + auto field = std::make_shared<arrow::Field>("opt", innerArrowType, false); + innerArrowType = std::make_shared<arrow::StructType>(std::vector<std::shared_ptr<arrow::Field>>{ field }); + } + + return innerArrowType; +} + template <typename TArrowType> void AppendDataValue(arrow::ArrayBuilder* builder, 
NUdf::TUnboxedValue value) { auto typedBuilder = reinterpret_cast<typename arrow::TypeTraits<TArrowType>::BuilderType*>(builder); @@ -301,12 +411,12 @@ void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { } else { status = typedBuilder->Append(value.Get<typename TArrowType::c_type>()); } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } template <> void AppendDataValue<arrow::UInt64Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::UINT64); + YQL_ENSURE(builder->type()->id() == arrow::Type::UINT64); auto typedBuilder = reinterpret_cast<arrow::UInt64Builder*>(builder); arrow::Status status; if (!value.HasValue()) { @@ -314,12 +424,12 @@ void AppendDataValue<arrow::UInt64Type>(arrow::ArrayBuilder* builder, NUdf::TUnb } else { status = typedBuilder->Append(value.Get<ui64>()); } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } template <> void AppendDataValue<arrow::Int64Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::INT64); + YQL_ENSURE(builder->type()->id() == arrow::Type::INT64); auto typedBuilder = reinterpret_cast<arrow::Int64Builder*>(builder); arrow::Status status; if (!value.HasValue()) { @@ -327,38 +437,12 @@ void AppendDataValue<arrow::Int64Type>(arrow::ArrayBuilder* builder, NUdf::TUnbo } else { status = typedBuilder->Append(value.Get<i64>()); } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::TimestampType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::TIMESTAMP); - auto typedBuilder = reinterpret_cast<arrow::TimestampBuilder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = 
typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get<ui64>()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::DurationType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::DURATION); - auto typedBuilder = reinterpret_cast<arrow::DurationBuilder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get<ui64>()); - } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } template <> void AppendDataValue<arrow::StringType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::STRING); + YQL_ENSURE(builder->type()->id() == arrow::Type::STRING); auto typedBuilder = reinterpret_cast<arrow::StringBuilder*>(builder); arrow::Status status; if (!value.HasValue()) { @@ -367,12 +451,12 @@ void AppendDataValue<arrow::StringType>(arrow::ArrayBuilder* builder, NUdf::TUnb auto data = value.AsStringRef(); status = typedBuilder->Append(data.Data(), data.Size()); } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } template <> void AppendDataValue<arrow::BinaryType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::BINARY); + YQL_ENSURE(builder->type()->id() == arrow::Type::BINARY); auto typedBuilder = reinterpret_cast<arrow::BinaryBuilder*>(builder); arrow::Status status; if (!value.HasValue()) { @@ -381,21 +465,86 @@ void AppendDataValue<arrow::BinaryType>(arrow::ArrayBuilder* builder, NUdf::TUnb auto data = value.AsStringRef(); status = typedBuilder->Append(data.Data(), data.Size()); } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), 
"Failed to append data value: " << status.ToString()); } +// Only for timezone datetime types template <> -void AppendDataValue<arrow::Decimal128Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::DECIMAL128); - auto typedBuilder = reinterpret_cast<arrow::Decimal128Builder*>(builder); +void AppendDataValue<arrow::StructType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT); + auto typedBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); + YQL_ENSURE(typedBuilder->num_fields() == 2, "StructBuilder of timezone datetime types should have 2 fields"); + + if (!value.HasValue()) { + auto status = typedBuilder->AppendNull(); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); + return; + } + + auto status = typedBuilder->Append(); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); + + auto datetimeArray = typedBuilder->field_builder(0); + auto timezoneArray = reinterpret_cast<arrow::UInt16Builder*>(typedBuilder->field_builder(1)); + + switch (datetimeArray->type()->id()) { + // NUdf::EDataSlot::TzDate + case arrow::Type::UINT16: { + status = reinterpret_cast<arrow::UInt16Builder*>(datetimeArray)->Append(value.Get<ui16>()); + break; + } + // NUdf::EDataSlot::TzDatetime + case arrow::Type::UINT32: { + status = reinterpret_cast<arrow::UInt32Builder*>(datetimeArray)->Append(value.Get<ui32>()); + break; + } + // NUdf::EDataSlot::TzTimestamp + case arrow::Type::UINT64: { + status = reinterpret_cast<arrow::UInt64Builder*>(datetimeArray)->Append(value.Get<ui64>()); + break; + } + // NUdf::EDataSlot::TzDate32 + case arrow::Type::INT32: { + status = reinterpret_cast<arrow::Int32Builder*>(datetimeArray)->Append(value.Get<i32>()); + break; + } + // NUdf::EDataSlot::TzDatetime64, NUdf::EDataSlot::TzTimestamp64 + case arrow::Type::INT64: { + status = 
reinterpret_cast<arrow::Int64Builder*>(datetimeArray)->Append(value.Get<i64>()); + break; + } + default: + YQL_ENSURE(false, "Unexpected timezone datetime slot"); + return; + } + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); + + status = timezoneArray->Append(value.GetTimezoneId()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); +} + +template <typename TArrowType> +void AppendFixedSizeDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value, NUdf::EDataSlot dataSlot) { + static_assert(std::is_same_v<TArrowType, arrow::FixedSizeBinaryType>, "This function is only for FixedSizeBinaryType"); + + YQL_ENSURE(builder->type()->id() == arrow::Type::FIXED_SIZE_BINARY); + auto typedBuilder = reinterpret_cast<arrow::FixedSizeBinaryBuilder*>(builder); arrow::Status status; + if (!value.HasValue()) { status = typedBuilder->AppendNull(); } else { - // Parse value from string - status = typedBuilder->Append(value.AsStringRef().Data()); + if (dataSlot == NUdf::EDataSlot::Uuid) { + auto data = value.AsStringRef(); + status = typedBuilder->Append(data.Data()); + } else if (dataSlot == NUdf::EDataSlot::Decimal) { + auto intVal = value.GetInt128(); + status = typedBuilder->Append(reinterpret_cast<const char*>(&intVal)); + } else { + YQL_ENSURE(false, "Unexpected data slot"); + } } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } } // namespace @@ -404,9 +553,10 @@ std::shared_ptr<arrow::DataType> GetArrowType(const TType* type) { switch (type->GetKind()) { case TType::EKind::Void: case TType::EKind::Null: + return arrow::null(); case TType::EKind::EmptyList: case TType::EKind::EmptyDict: - break; + return arrow::struct_({}); case TType::EKind::Data: { auto dataType = static_cast<const TDataType*>(type); return GetArrowType(dataType); @@ -421,17 +571,7 @@ std::shared_ptr<arrow::DataType> GetArrowType(const TType* type) { } case 
TType::EKind::Optional: { auto optionalType = static_cast<const TOptionalType*>(type); - auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { - std::vector<std::shared_ptr<arrow::Field>> fields; - fields.emplace_back(std::make_shared<arrow::Field>("", std::make_shared<arrow::UInt64Type>())); - while (innerOptionalType->GetKind() == TType::EKind::Optional) { - innerOptionalType = static_cast<const TOptionalType*>(innerOptionalType)->GetItemType(); - } - fields.emplace_back(std::make_shared<arrow::Field>("", GetArrowType(innerOptionalType))); - return arrow::struct_(fields); - } - return GetArrowType(innerOptionalType); + return GetArrowType(optionalType); } case TType::EKind::List: { auto listType = static_cast<const TListType*>(type); @@ -446,7 +586,7 @@ std::shared_ptr<arrow::DataType> GetArrowType(const TType* type) { return GetArrowType(variantType); } default: - THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); + YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr()); } return arrow::null(); } @@ -459,6 +599,7 @@ bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) { case TType::EKind::EmptyDict: case TType::EKind::Data: return true; + case TType::EKind::Struct: { auto structType = static_cast<const TStructType*>(type); bool isCompatible = true; @@ -468,6 +609,7 @@ bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) { } return isCompatible; } + case TType::EKind::Tuple: { auto tupleType = static_cast<const TTupleType*>(type); bool isCompatible = true; @@ -477,37 +619,33 @@ bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) { } return isCompatible; } + case TType::EKind::Optional: { auto optionalType = static_cast<const TOptionalType*>(type); auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { + if (NeedWrapByExternalOptional(innerOptionalType)) { return false; } return 
IsArrowCompatible(innerOptionalType); } + case TType::EKind::List: { auto listType = static_cast<const TListType*>(type); auto itemType = listType->GetItemType(); return IsArrowCompatible(itemType); } - case TType::EKind::Dict: { - auto dictType = static_cast<const TDictType*>(type); - auto keyType = dictType->GetKeyType(); - auto payloadType = dictType->GetPayloadType(); - if (keyType->GetKind() == TType::EKind::Optional) { - return false; - } - return IsArrowCompatible(keyType) && IsArrowCompatible(payloadType); - } + case TType::EKind::Variant: { auto variantType = static_cast<const TVariantType*>(type); if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { return false; } TType* innerType = variantType->GetUnderlyingType(); - Y_VERIFY_S(innerType->IsTuple() || innerType->IsStruct(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + YQL_ENSURE(innerType->IsTuple() || innerType->IsStruct(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); return IsArrowCompatible(innerType); } + + case TType::EKind::Dict: case TType::EKind::Block: case TType::EKind::Type: case TType::EKind::Stream: @@ -529,66 +667,84 @@ std::unique_ptr<arrow::ArrayBuilder> MakeArrowBuilder(const TType* type) { auto arrayType = GetArrowType(type); std::unique_ptr<arrow::ArrayBuilder> builder; auto status = arrow::MakeBuilder(arrow::default_memory_pool(), arrayType, &builder); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to make arrow builder: " << status.ToString()); return builder; } void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, const TType* type) { switch (type->GetKind()) { case TType::EKind::Void: - case TType::EKind::Null: + case TType::EKind::Null: { + YQL_ENSURE(builder->type()->id() == arrow::Type::NA, "Unexpected builder type"); + auto status = builder->AppendNull(); + YQL_ENSURE(status.ok(), "Failed to append null value: " << status.ToString()); + break; + } + 
case TType::EKind::EmptyList: case TType::EKind::EmptyDict: { - auto status = builder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); + auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); + auto status = structBuilder->Append(); + YQL_ENSURE(status.ok(), "Failed to append empty dict/list value: " << status.ToString()); break; } case TType::EKind::Data: { - // TODO for TzDate, TzDatetime, TzTimestamp pass timezone to arrow builder? auto dataType = static_cast<const TDataType*>(type); - bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { + auto slot = *dataType->GetDataSlot().Get(); + bool success = SwitchMiniKQLDataTypeToArrowType(slot, [&]<typename TType>(TTypeWrapper<TType> typeHolder) { Y_UNUSED(typeHolder); - AppendDataValue<TType>(builder, value); + if constexpr (std::is_same_v<TType, arrow::FixedSizeBinaryType>) { + AppendFixedSizeDataValue<TType>(builder, value, slot); + } else { + AppendDataValue<TType>(builder, value); + } return true; }); - Y_ABORT_UNLESS(success); + YQL_ENSURE(success, "Failed to append data value to arrow builder"); break; } case TType::EKind::Optional: { - auto optionalType = static_cast<const TOptionalType*>(type); - if (optionalType->GetItemType()->GetKind() != TType::EKind::Optional) { - if (value.HasValue()) { - AppendElement(value.GetOptionalValue(), builder, optionalType->GetItemType()); - } else { - auto status = builder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); + auto innerType = static_cast<const TOptionalType*>(type)->GetItemType(); + ui32 depth = 1; + + while (innerType->IsOptional()) { + innerType = static_cast<const TOptionalType*>(innerType)->GetItemType(); + ++depth; + } + + if (NeedWrapByExternalOptional(innerType)) { + ++depth; + } + + auto innerBuilder = builder; + auto innerValue = value; + + for (ui32 
i = 1; i < depth; ++i) { + YQL_ENSURE(innerBuilder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); + auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(innerBuilder); + YQL_ENSURE(structBuilder->num_fields() == 1, "Unexpected number of fields"); + + if (!innerValue) { + auto status = innerBuilder->AppendNull(); + YQL_ENSURE(status.ok(), "Failed to append null optional value: " << status.ToString()); + return; } - } else { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::STRUCT); - auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); - Y_DEBUG_ABORT_UNLESS(structBuilder->num_fields() == 2); - Y_DEBUG_ABORT_UNLESS(structBuilder->field_builder(0)->type()->id() == arrow::Type::UINT64); + auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - auto depthBuilder = reinterpret_cast<arrow::UInt64Builder*>(structBuilder->field_builder(0)); - auto valueBuilder = structBuilder->field_builder(1); - ui64 depth = 0; - TType* innerType = optionalType->GetItemType(); - while (innerType->GetKind() == TType::EKind::Optional && value.HasValue()) { - innerType = static_cast<const TOptionalType*>(innerType)->GetItemType(); - value = value.GetOptionalValue(); - ++depth; - } - status = depthBuilder->Append(depth); - Y_VERIFY_S(status.ok(), status.ToString()); - if (value.HasValue()) { - AppendElement(value, valueBuilder, innerType); - } else { - status = valueBuilder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); - } + YQL_ENSURE(status.ok(), "Failed to append optional value: " << status.ToString()); + + innerValue = innerValue.GetOptionalValue(); + innerBuilder = structBuilder->field_builder(0); + } + + if (innerValue) { + AppendElement(innerValue.GetOptionalValue(), innerBuilder, innerType); + } else { + auto status = innerBuilder->AppendNull(); + YQL_ENSURE(status.ok(), "Failed to append null optional value: " << status.ToString()); } break; } @@ -596,16 +752,19 @@ void 
AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons case TType::EKind::List: { auto listType = static_cast<const TListType*>(type); auto itemType = listType->GetItemType(); - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::LIST); + + YQL_ENSURE(builder->type()->id() == arrow::Type::LIST, "Unexpected builder type"); auto listBuilder = reinterpret_cast<arrow::ListBuilder*>(builder); + auto status = listBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append list value: " << status.ToString()); + auto innerBuilder = listBuilder->value_builder(); - if (auto p = value.GetElements()) { - auto len = value.GetListLength(); - while (len > 0) { - AppendElement(*p++, innerBuilder, itemType); - --len; + if (auto item = value.GetElements()) { + auto length = value.GetListLength(); + while (length > 0) { + AppendElement(*item++, innerBuilder, itemType); + --length; } } else { const auto iter = value.GetListIterator(); @@ -618,11 +777,14 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons case TType::EKind::Struct: { auto structType = static_cast<const TStructType*>(type); - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::STRUCT); + + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); + auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - Y_DEBUG_ABORT_UNLESS(static_cast<ui32>(structBuilder->num_fields()) == structType->GetMembersCount()); + YQL_ENSURE(status.ok(), "Failed to append struct value: " << status.ToString()); + + YQL_ENSURE(static_cast<ui32>(structBuilder->num_fields()) == structType->GetMembersCount(), "Unexpected number of fields"); for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { auto innerBuilder = structBuilder->field_builder(index); auto memberType = 
structType->GetMemberType(index); @@ -633,11 +795,14 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons case TType::EKind::Tuple: { auto tupleType = static_cast<const TTupleType*>(type); - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::STRUCT); + + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); + auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - Y_DEBUG_ABORT_UNLESS(static_cast<ui32>(structBuilder->num_fields()) == tupleType->GetElementsCount()); + YQL_ENSURE(status.ok(), "Failed to append tuple value: " << status.ToString()); + + YQL_ENSURE(static_cast<ui32>(structBuilder->num_fields()) == tupleType->GetElementsCount(), "Unexpected number of fields"); for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { auto innerBuilder = structBuilder->field_builder(index); auto elementType = tupleType->GetElementType(index); @@ -651,37 +816,53 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons auto keyType = dictType->GetKeyType(); auto payloadType = dictType->GetPayloadType(); - arrow::ArrayBuilder* keyBuilder; - arrow::ArrayBuilder* itemBuilder; + arrow::ArrayBuilder* keyBuilder = nullptr; + arrow::ArrayBuilder* itemBuilder = nullptr; arrow::StructBuilder* structBuilder = nullptr; + + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); + arrow::StructBuilder* wrapBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); + YQL_ENSURE(wrapBuilder->num_fields() == 2, "Unexpected number of fields"); + + auto status = wrapBuilder->Append(); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); + if (keyType->GetKind() == TType::EKind::Optional) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::LIST); - auto listBuilder = reinterpret_cast<arrow::ListBuilder*>(builder); - 
Y_DEBUG_ABORT_UNLESS(listBuilder->value_builder()->type()->id() == arrow::Type::STRUCT); - // Start a new list in ListArray of structs + YQL_ENSURE(wrapBuilder->field_builder(0)->type()->id() == arrow::Type::LIST, "Unexpected builder type"); + auto listBuilder = reinterpret_cast<arrow::ListBuilder*>(wrapBuilder->field_builder(0)); + auto status = listBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); + + YQL_ENSURE(listBuilder->value_builder()->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); structBuilder = reinterpret_cast<arrow::StructBuilder*>(listBuilder->value_builder()); - Y_DEBUG_ABORT_UNLESS(structBuilder->num_fields() == 2); + YQL_ENSURE(structBuilder->num_fields() == 2, "Unexpected number of fields"); + keyBuilder = structBuilder->field_builder(0); itemBuilder = structBuilder->field_builder(1); } else { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::MAP); - auto mapBuilder = reinterpret_cast<arrow::MapBuilder*>(builder); - // Start a new map in MapArray + YQL_ENSURE(wrapBuilder->field_builder(0)->type()->id() == arrow::Type::MAP, "Unexpected builder type"); + auto mapBuilder = reinterpret_cast<arrow::MapBuilder*>(wrapBuilder->field_builder(0)); + auto status = mapBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); + keyBuilder = mapBuilder->key_builder(); itemBuilder = mapBuilder->item_builder(); } - const auto iter = value.GetDictIterator(); + arrow::UInt64Builder* customBuilder = reinterpret_cast<arrow::UInt64Builder*>(wrapBuilder->field_builder(1)); + status = customBuilder->Append(0); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); + // We do not sort dictionary before appending it to builder. 
+ const auto iter = value.GetDictIterator(); for (NUdf::TUnboxedValue key, payload; iter.NextPair(key, payload);) { if (structBuilder != nullptr) { - auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); + status = structBuilder->Append(); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); } + AppendElement(key, keyBuilder, keyType); AppendElement(payload, itemBuilder, payloadType); } @@ -691,33 +872,42 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons case TType::EKind::Variant: { // TODO Need to properly convert variants containing more than 127*127 types? auto variantType = static_cast<const TVariantType*>(type); - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::DENSE_UNION); + + YQL_ENSURE(builder->type()->id() == arrow::Type::DENSE_UNION, "Unexpected builder type"); auto unionBuilder = reinterpret_cast<arrow::DenseUnionBuilder*>(builder); + ui32 variantIndex = value.GetVariantIndex(); TType* innerType = variantType->GetUnderlyingType(); + if (innerType->IsStruct()) { innerType = static_cast<TStructType*>(innerType)->GetMemberType(variantIndex); } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + YQL_ENSURE(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); innerType = static_cast<TTupleType*>(innerType)->GetElementType(variantIndex); } + if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { ui32 numberOfGroups = (variantType->GetAlternativesCount() - 1) / arrow::UnionType::kMaxTypeCode + 1; - Y_DEBUG_ABORT_UNLESS(static_cast<ui32>(unionBuilder->num_children()) == numberOfGroups); + YQL_ENSURE(static_cast<ui32>(unionBuilder->num_children()) == numberOfGroups); + ui32 groupIndex = variantIndex / arrow::UnionType::kMaxTypeCode; auto status = unionBuilder->Append(groupIndex); - Y_VERIFY_S(status.ok(), status.ToString()); + 
YQL_ENSURE(status.ok(), "Failed to append variant value: " << status.ToString()); + auto innerBuilder = unionBuilder->child_builder(groupIndex); - Y_DEBUG_ABORT_UNLESS(innerBuilder->type()->id() == arrow::Type::DENSE_UNION); + YQL_ENSURE(innerBuilder->type()->id() == arrow::Type::DENSE_UNION); auto innerUnionBuilder = reinterpret_cast<arrow::DenseUnionBuilder*>(innerBuilder.get()); + ui32 innerVariantIndex = variantIndex % arrow::UnionType::kMaxTypeCode; status = innerUnionBuilder->Append(innerVariantIndex); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append variant value: " << status.ToString()); + auto doubleInnerBuilder = innerUnionBuilder->child_builder(innerVariantIndex); AppendElement(value.GetVariantItem(), doubleInnerBuilder.get(), innerType); } else { auto status = unionBuilder->Append(variantIndex); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append variant value: " << status.ToString()); + auto innerBuilder = unionBuilder->child_builder(variantIndex); AppendElement(value.GetVariantItem(), innerBuilder.get(), innerType); } @@ -725,20 +915,20 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons } default: - THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); + YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr()); } } std::shared_ptr<arrow::Array> MakeArray(NMiniKQL::TUnboxedValueVector& values, const TType* itemType) { auto builder = MakeArrowBuilder(itemType); auto status = builder->Reserve(values.size()); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to reserve space for array: " << status.ToString()); for (auto& value: values) { AppendElement(value, builder.get(), itemType); } std::shared_ptr<arrow::Array> result; status = builder->Finish(&result); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to finish array: " << status.ToString()); return result; } 
@@ -746,13 +936,15 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& arr if (array->IsNull(row)) { return NUdf::TUnboxedValuePod(); } + switch(itemType->GetKind()) { case TType::EKind::Void: case TType::EKind::Null: case TType::EKind::EmptyList: case TType::EKind::EmptyDict: break; - case TType::EKind::Data: { // TODO TzDate need special care + + case TType::EKind::Data: { auto dataType = static_cast<const TDataType*>(itemType); NUdf::TUnboxedValue result; bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { @@ -760,70 +952,98 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& arr result = GetUnboxedValue<TType>(array, row); return true; }); - Y_DEBUG_ABORT_UNLESS(success); + Y_ENSURE(success, "Failed to extract unboxed value from arrow array"); return result; } + case TType::EKind::Struct: { auto structType = static_cast<const TStructType*>(itemType); - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::STRUCT); + + YQL_ENSURE(array->type_id() == arrow::Type::STRUCT); auto typedArray = static_pointer_cast<arrow::StructArray>(array); - Y_DEBUG_ABORT_UNLESS(static_cast<ui32>(typedArray->num_fields()) == structType->GetMembersCount()); + YQL_ENSURE(static_cast<ui32>(typedArray->num_fields()) == structType->GetMembersCount()); + NUdf::TUnboxedValue* itemsPtr = nullptr; auto result = holderFactory.CreateDirectArrayHolder(structType->GetMembersCount(), itemsPtr); + for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { auto memberType = structType->GetMemberType(index); itemsPtr[index] = ExtractUnboxedValue(typedArray->field(index), row, memberType, holderFactory); } return result; } + case TType::EKind::Tuple: { auto tupleType = static_cast<const TTupleType*>(itemType); - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::STRUCT); + + YQL_ENSURE(array->type_id() == arrow::Type::STRUCT); auto typedArray = 
static_pointer_cast<arrow::StructArray>(array); - Y_DEBUG_ABORT_UNLESS(static_cast<ui32>(typedArray->num_fields()) == tupleType->GetElementsCount()); + YQL_ENSURE(static_cast<ui32>(typedArray->num_fields()) == tupleType->GetElementsCount()); + NUdf::TUnboxedValue* itemsPtr = nullptr; auto result = holderFactory.CreateDirectArrayHolder(tupleType->GetElementsCount(), itemsPtr); + for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { auto elementType = tupleType->GetElementType(index); itemsPtr[index] = ExtractUnboxedValue(typedArray->field(index), row, elementType, holderFactory); } return result; } + case TType::EKind::Optional: { auto optionalType = static_cast<const TOptionalType*>(itemType); auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::STRUCT); - auto structArray = static_pointer_cast<arrow::StructArray>(array); - Y_DEBUG_ABORT_UNLESS(structArray->num_fields() == 2); - Y_DEBUG_ABORT_UNLESS(structArray->field(0)->type_id() == arrow::Type::UINT64); - auto depthArray = static_pointer_cast<arrow::UInt64Array>(structArray->field(0)); - auto valuesArray = structArray->field(1); - auto depth = depthArray->Value(row); + + if (NeedWrapByExternalOptional(innerOptionalType)) { + YQL_ENSURE(array->type_id() == arrow::Type::STRUCT); + + auto innerArray = array; + auto innerType = itemType; + NUdf::TUnboxedValue value; - if (valuesArray->IsNull(row)) { - value = NUdf::TUnboxedValuePod(); - } else { - while (innerOptionalType->GetKind() == TType::EKind::Optional) { - innerOptionalType = static_cast<const TOptionalType*>(innerOptionalType)->GetItemType(); + int depth = 0; + + while (innerArray->type_id() == arrow::Type::STRUCT) { + auto structArray = static_pointer_cast<arrow::StructArray>(innerArray); + YQL_ENSURE(structArray->num_fields() == 1); + + if (structArray->IsNull(row)) { + value = NUdf::TUnboxedValuePod(); + break; } - 
value = ExtractUnboxedValue(valuesArray, row, innerOptionalType, holderFactory); + + innerType = static_cast<const TOptionalType*>(innerType)->GetItemType(); + innerArray = structArray->field(0); + ++depth; } - for (ui64 i = 0; i < depth; ++i) { + + auto wrap = NeedWrapByExternalOptional(innerType); + if (wrap || !innerArray->IsNull(row)) { + value = ExtractUnboxedValue(innerArray, row, innerType, holderFactory); + if (wrap) { + --depth; + } + } + + for (int i = 0; i < depth; ++i) { value = value.MakeOptional(); } return value; - } else { - return ExtractUnboxedValue(array, row, innerOptionalType, holderFactory).Release().MakeOptional(); } + + return ExtractUnboxedValue(array, row, innerOptionalType, holderFactory).Release().MakeOptional(); } + case TType::EKind::List: { auto listType = static_cast<const TListType*>(itemType); - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::LIST); + + YQL_ENSURE(array->type_id() == arrow::Type::LIST); auto typedArray = static_pointer_cast<arrow::ListArray>(array); + auto arraySlice = typedArray->value_slice(row); auto itemType = listType->GetItemType(); const auto len = arraySlice->length(); + NUdf::TUnboxedValue *items = nullptr; auto list = holderFactory.CreateDirectArrayHolder(len, items); for (ui64 i = 0; i < static_cast<ui64>(len); ++i) { @@ -831,8 +1051,10 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& arr } return list; } + case TType::EKind::Dict: { auto dictType = static_cast<const TDictType*>(itemType); + auto keyType = dictType->GetKeyType(); auto payloadType = dictType->GetPayloadType(); auto dictBuilder = holderFactory.NewDict(dictType, NUdf::TDictFlags::EDictKind::Hashed); @@ -841,24 +1063,35 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& arr std::shared_ptr<arrow::Array> payloadArray = nullptr; ui64 dictLength = 0; ui64 offset = 0; + + YQL_ENSURE(array->type_id() == arrow::Type::STRUCT); + auto wrapArray = 
static_pointer_cast<arrow::StructArray>(array); + YQL_ENSURE(wrapArray->num_fields() == 2); + + auto dictSlice = wrapArray->field(0); + if (keyType->GetKind() == TType::EKind::Optional) { - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::LIST); - auto listArray = static_pointer_cast<arrow::ListArray>(array); + YQL_ENSURE(dictSlice->type_id() == arrow::Type::LIST); + auto listArray = static_pointer_cast<arrow::ListArray>(dictSlice); + auto arraySlice = listArray->value_slice(row); - Y_DEBUG_ABORT_UNLESS(arraySlice->type_id() == arrow::Type::STRUCT); + YQL_ENSURE(arraySlice->type_id() == arrow::Type::STRUCT); auto structArray = static_pointer_cast<arrow::StructArray>(arraySlice); - Y_DEBUG_ABORT_UNLESS(structArray->num_fields() == 2); + YQL_ENSURE(structArray->num_fields() == 2); + dictLength = arraySlice->length(); keyArray = structArray->field(0); payloadArray = structArray->field(1); } else { - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::MAP); - auto mapArray = static_pointer_cast<arrow::MapArray>(array); + YQL_ENSURE(dictSlice->type_id() == arrow::Type::MAP); + auto mapArray = static_pointer_cast<arrow::MapArray>(dictSlice); + dictLength = mapArray->value_length(row); offset = mapArray->value_offset(row); keyArray = mapArray->keys(); payloadArray = mapArray->items(); } + for (ui64 i = offset; i < offset + static_cast<ui64>(dictLength); ++i) { auto key = ExtractUnboxedValue(keyArray, i, keyType, holderFactory); auto payload = ExtractUnboxedValue(payloadArray, i, payloadType, holderFactory); @@ -866,35 +1099,42 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& arr } return dictBuilder->Build(); } + case TType::EKind::Variant: { // TODO Need to properly convert variants containing more than 127*127 types? 
auto variantType = static_cast<const TVariantType*>(itemType); - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::DENSE_UNION); + + YQL_ENSURE(array->type_id() == arrow::Type::DENSE_UNION); auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); + auto variantIndex = unionArray->child_id(row); auto rowInChild = unionArray->value_offset(row); std::shared_ptr<arrow::Array> valuesArray = unionArray->field(variantIndex); + if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { // Go one step deeper - Y_DEBUG_ABORT_UNLESS(valuesArray->type_id() == arrow::Type::DENSE_UNION); + YQL_ENSURE(valuesArray->type_id() == arrow::Type::DENSE_UNION); auto innerUnionArray = static_pointer_cast<arrow::DenseUnionArray>(valuesArray); auto innerVariantIndex = innerUnionArray->child_id(rowInChild); + rowInChild = innerUnionArray->value_offset(rowInChild); valuesArray = innerUnionArray->field(innerVariantIndex); variantIndex = variantIndex * arrow::UnionType::kMaxTypeCode + innerVariantIndex; } + TType* innerType = variantType->GetUnderlyingType(); if (innerType->IsStruct()) { innerType = static_cast<TStructType*>(innerType)->GetMemberType(variantIndex); } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + YQL_ENSURE(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); innerType = static_cast<TTupleType*>(innerType)->GetElementType(variantIndex); } + NUdf::TUnboxedValue value = ExtractUnboxedValue(valuesArray, rowInChild, innerType, holderFactory); return holderFactory.CreateVariantHolder(value.Release(), variantIndex); } default: - THROW yexception() << "Unsupported type: " << itemType->GetKindAsStr(); + YQL_ENSURE(false, "Unsupported type: " << itemType->GetKindAsStr()); } return NUdf::TUnboxedValuePod(); } @@ -915,20 +1155,20 @@ std::string SerializeArray(const std::shared_ptr<arrow::Array>& array) { writeOptions.use_threads = false; // TODO 
Decide which compression level will be default. Will it depend on the length of array? auto codecResult = arrow::util::Codec::Create(arrow::Compression::LZ4_FRAME); - Y_ABORT_UNLESS(codecResult.ok()); + YQL_ENSURE(codecResult.ok()); writeOptions.codec = std::move(codecResult.ValueOrDie()); int64_t size; auto status = GetRecordBatchSize(*batch, writeOptions, &size); - Y_ABORT_UNLESS(status.ok()); + YQL_ENSURE(status.ok()); std::string str; str.resize(size); auto writer = arrow::Buffer::GetWriter(arrow::MutableBuffer::Wrap(&str[0], size)); - Y_ABORT_UNLESS(writer.status().ok()); + YQL_ENSURE(writer.status().ok()); status = SerializeRecordBatch(*batch, writeOptions, (*writer).get()); - Y_ABORT_UNLESS(status.ok()); + YQL_ENSURE(status.ok()); return str; } @@ -941,7 +1181,7 @@ std::shared_ptr<arrow::Array> DeserializeArray(const std::string& blob, std::sha arrow::io::BufferReader reader(buffer); auto schema = std::make_shared<arrow::Schema>(std::vector<std::shared_ptr<arrow::Field>>{arrow::field("", type)}); auto batch = ReadRecordBatch(schema, &dictMemo, options, &reader); - Y_DEBUG_ABORT_UNLESS(batch.ok() && (*batch)->ValidateFull().ok(), "Failed to deserialize batch"); + YQL_ENSURE(batch.ok() && (*batch)->ValidateFull().ok(), "Failed to deserialize batch"); return (*batch)->column(0); } diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers.h b/ydb/library/yql/dq/runtime/dq_arrow_helpers.h index 769c62a9579..9270ece1f9b 100644 --- a/ydb/library/yql/dq/runtime/dq_arrow_helpers.h +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers.h @@ -14,9 +14,25 @@ namespace NArrow { /** * @brief Convert TType to the arrow::DataType object * - * The logic of this conversion is the following: + * The logic of this conversion is from YQL-15332: * - * Struct, tuple => StructArray + * Void, Null => NullType + * Bool => Uint8 + * Integral => Uint8..Uint64, Int8..Int64 + * Floats => Float, Double + * Date => Uint16 + * Datetime => Uint32 + * Timestamp => Uint64 + * Interval => Int64 + * 
Date32 => Int32 + * Interval64, Timestamp64, Datetime64 => Int64 + * Utf8, Json => String + * String, Yson, JsonDocument => Binary + * Decimal, UUID => FixedSizeBinary(16) + * Timezone datetime type => StructArray<type, Uint16> + * DyNumber => BinaryArray (it is not added to YQL-15332) + * + * Struct, Tuple, EmptyList, EmptyDict => StructArray * Names of fields constructed from tuple are just empty strings. * * List => ListArray @@ -26,20 +42,26 @@ namespace NArrow { * Variant => DenseUnionArray<DenseUnionArray> * TODO Implement convertion of data to DenseUnionArray<DenseUnionArray> and back * - * Optional(Optional ..(type)..) => StructArray<ui64, type> - * Here the integer value equals the number of calls of method GetOptionalValue(). - * If value is null at some depth, then the value in second field of Array is Null - * (and the integer equals this depth). If value is present, then it is contained in the - * second field (and the integer equals the number of Optional(...) levels). - * This information is sufficient to restore an UnboxedValue knowing its type. 
+ * Optional<T> => StructArray<T> if T is Variant + * Because DenseUnionArray does not have validity bitmap + * Optional<T> => T for other types + * By default, other types have a validity bitmap + * + * Optional<Optional<...<T>...>> => StructArray<StructArray<...StructArray<T>...>> + * For example: + * - Optional<Optional<Int32>> => StructArray<Int32> + * Int32 has validity bitmap, so we wrap it in StructArray N - 1 times, where N is the number of Optional levels + * - Optional<Optional<Variant<Int32, Int64>>> => StructArray<StructArray<DenseUnionArray<Int32, Int64>>> + * DenseUnionArray does not have validity bitmap, so we wrap it in StructArray N times, where N is the number of Optional levels * - * Dict<KeyType, ValueType> => MapArray<KeyArray, ValueArray> + * Dict<KeyType, ValueType> => StructArray<MapArray<KeyArray, ValueArray>, Uint64Array (on demand, default: 0)> * We do not use arrow::DictArray because it must be used for encoding not for mapping keys to values. * (https://arrow.apache.org/docs/cpp/api/array.html#classarrow_1_1_dictionary_array) * If the type of dict key is optional then we map - * Dict<Optional(KeyType), ValueType> => ListArray<StructArray<KeyArray, ValueArray>> + * Dict<Optional<KeyType>, ValueType> => StructArray<ListArray<StructArray<KeyArray, ValueArray>, Uint64Array (on demand, default: 0)> * because keys of MapArray can not be nullable * + * * @param type Yql type to parse * @return std::shared_ptr<arrow::DataType> arrow type of the same structure as type */ diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp b/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp index d1ca1e46bd5..86dc9be2428 100644 --- a/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp @@ -11,6 +11,7 @@ #include <library/cpp/testing/unittest/registar.h> #include <util/string/builder.h> +#include <util/string/hex.h> #include <util/system/yassert.h> #include 
<ydb/library/formats/arrow/arrow_helpers.h> @@ -24,6 +25,7 @@ #include <yql/essentials/minikql/computation/mkql_value_builder.h> #include <yql/essentials/minikql/mkql_node.h> #include <yql/essentials/minikql/mkql_string_util.h> +#include <yql/essentials/minikql/mkql_type_ops.h> #include <yql/essentials/public/udf/arrow/defs.h> #include <yql/essentials/public/udf/udf_data_type.h> #include <yql/essentials/public/udf/udf_string_ref.h> @@ -63,6 +65,36 @@ NUdf::TUnboxedValue GetValueOfBasicType(TType* type, ui64 value) { return NUdf::TUnboxedValuePod(static_cast<float>(value) / 1234); case NUdf::EDataSlot::Double: return NUdf::TUnboxedValuePod(static_cast<double>(value) / 12345); + case NUdf::EDataSlot::TzDate: { + auto ret = NUdf::TUnboxedValuePod(static_cast<ui16>(value % NUdf::MAX_DATE)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("Europe/Moscow")); + return ret; + } + case NUdf::EDataSlot::TzDatetime: { + auto ret = NUdf::TUnboxedValuePod(static_cast<ui32>(value % NUdf::MAX_DATETIME)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("Asia/Omsk")); + return ret; + } + case NUdf::EDataSlot::TzTimestamp: { + auto ret = NUdf::TUnboxedValuePod(static_cast<ui64>(value % NUdf::MAX_TIMESTAMP)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("Europe/Tallinn")); + return ret; + } + case NUdf::EDataSlot::TzDate32: { + auto ret = NUdf::TUnboxedValuePod(static_cast<i32>(value % NUdf::MAX_DATE32)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("US/Eastern")); + return ret; + } + case NUdf::EDataSlot::TzDatetime64: { + auto ret = NUdf::TUnboxedValuePod(static_cast<i64>(value % NUdf::MAX_DATETIME64)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("America/Nuuk")); + return ret; + } + case NUdf::EDataSlot::TzTimestamp64: { + auto ret = NUdf::TUnboxedValuePod(static_cast<i64>(value % NUdf::MAX_TIMESTAMP64)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("Atlantic/Faroe")); + return ret; + } default: Y_ABORT("Not implemented creation value for 
such type"); } @@ -110,7 +142,13 @@ struct TTestContext { TDataType::Create(NUdf::TDataType<i64>::Id, TypeEnv), TDataType::Create(NUdf::TDataType<ui64>::Id, TypeEnv), TDataType::Create(NUdf::TDataType<float>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<double>::Id, TypeEnv) + TDataType::Create(NUdf::TDataType<double>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<NUdf::TTzDate>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<NUdf::TTzDatetime>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<NUdf::TTzTimestamp>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<NUdf::TTzDate32>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<NUdf::TTzDatetime64>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<NUdf::TTzTimestamp64>::Id, TypeEnv) }; TTestContext() @@ -212,6 +250,32 @@ struct TTestContext { return values; } + TType* GetOptionalListOfOptional() { + TType* itemType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), TypeEnv); + return TOptionalType::Create(TListType::Create(itemType, TypeEnv), TypeEnv); + } + + TUnboxedValueVector CreateOptionalListOfOptional(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + if (value % 2 == 0) { + values.emplace_back(NUdf::TUnboxedValuePod()); + continue; + } + + TUnboxedValueVector items; + items.reserve(value); + for (ui64 i = 0; i < value; ++i) { + NUdf::TUnboxedValue item = ((value + i) % 2 == 0) ? 
NUdf::TUnboxedValuePod() : NUdf::TUnboxedValuePod(i); + items.push_back(std::move(item).MakeOptional()); + } + + auto listValue = Vb.NewList(items.data(), value); + values.emplace_back(std::move(listValue).MakeOptional()); + } + return values; + } + TType* GetVariantOverStructType() { TStructMember members[4] = { {"0_yson", TDataType::Create(NUdf::TDataType<NUdf::TYson>::Id, TypeEnv)}, @@ -235,7 +299,8 @@ struct TTestContext { std::string data = TStringBuilder() << "{value:" << value << "}"; item = MakeString(NUdf::TStringRef(data.data(), data.size())); } else if (typeIndex == 2) { - std::string data = TStringBuilder() << "id-QwErY-" << value; + std::string sample = "7856341212905634789012345678901"; + std::string data = TStringBuilder() << HexDecode(sample + static_cast<char>('0' + (value % 10))); item = MakeString(NUdf::TStringRef(data.data(), data.size())); } else if (typeIndex == 3) { item = NUdf::TUnboxedValuePod(static_cast<float>(value) / 4); @@ -246,6 +311,79 @@ struct TTestContext { return values; } + TType* GetOptionalVariantOverStructType() { + return TOptionalType::Create(GetVariantOverStructType(), TypeEnv); + } + + TUnboxedValueVector CreateOptionalVariantOverStruct(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto typeIndex = value % 4; + NUdf::TUnboxedValue item; + + if (value % 2 == 0) { + values.push_back(NUdf::TUnboxedValuePod()); + continue; + } + + if (typeIndex == 0) { + std::string data = TStringBuilder() << "{value=" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 1) { + std::string data = TStringBuilder() << "{value:" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 2) { + std::string sample = "7856341212905634789012345678901"; + std::string data = TStringBuilder() << HexDecode(sample + static_cast<char>('0' + (value % 10))); + item = 
MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast<float>(value) / 4); + } + auto wrapped = Vb.NewVariant(typeIndex, std::move(item)).MakeOptional(); + values.push_back(std::move(wrapped)); + } + return values; + } + + TType* GetDoubleOptionalVariantOverStructType() { + return TOptionalType::Create(GetOptionalVariantOverStructType(), TypeEnv); + } + + TUnboxedValueVector CreateDoubleOptionalVariantOverStruct(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto typeIndex = value % 4; + NUdf::TUnboxedValue item; + + if (value % 3 == 0) { + if (typeIndex == 0) { + std::string data = TStringBuilder() << "{value=" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 1) { + std::string data = TStringBuilder() << "{value:" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 2) { + std::string sample = "7856341212905634789012345678901"; + std::string data = TStringBuilder() << HexDecode(sample + static_cast<char>('0' + (value % 10))); + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast<float>(value) / 4); + } + + item = Vb.NewVariant(typeIndex, std::move(item)).MakeOptional(); + } else { + item = NUdf::TUnboxedValuePod(); + } + + if (value % 3 != 2) { + item = item.MakeOptional(); + } + + values.push_back(std::move(item)); + } + return values; + } + TType* GetVariantOverTupleWithOptionalsType() { TType* members[5] = { TDataType::Create(NUdf::TDataType<bool>::Id, TypeEnv), @@ -284,6 +422,83 @@ struct TTestContext { return values; } + TType* GetOptionalVariantOverTupleWithOptionalsType() { + return TOptionalType::Create(GetVariantOverTupleWithOptionalsType(), TypeEnv); + } + + TUnboxedValueVector CreateOptionalVariantOverTupleWithOptionals(ui32 
quantity) { + NKikimr::NMiniKQL::TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + + if (value % 2 == 0) { + values.push_back(NUdf::TUnboxedValuePod()); + continue; + } + + auto typeIndex = value % 5; + NUdf::TUnboxedValue item; + if (typeIndex == 0) { + item = NUdf::TUnboxedValuePod(value % 3 == 0); + } else if (typeIndex == 1) { + item = NUdf::TUnboxedValuePod(static_cast<i16>(-value)); + } else if (typeIndex == 2) { + item = NUdf::TUnboxedValuePod(static_cast<ui16>(value)); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast<i32>(-value)); + } else if (typeIndex == 4) { + NUdf::TUnboxedValue innerItem; + innerItem = value % 2 == 0 + ? NUdf::TUnboxedValuePod(static_cast<i32>(value)) + : NUdf::TUnboxedValuePod(); + item = innerItem.MakeOptional(); + } + auto wrapped = Vb.NewVariant(typeIndex, std::move(item)).MakeOptional(); + values.emplace_back(std::move(wrapped)); + } + return values; + } + + TType* GetDoubleOptionalVariantOverTupleWithOptionalsType() { + return TOptionalType::Create(GetOptionalVariantOverTupleWithOptionalsType(), TypeEnv); + } + + TUnboxedValueVector CreateDoubleOptionalVariantOverTupleWithOptionals(ui32 quantity) { + NKikimr::NMiniKQL::TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto typeIndex = value % 5; + NUdf::TUnboxedValue item; + + if (value % 3 == 0) { + if (typeIndex == 0) { + item = NUdf::TUnboxedValuePod(value % 3 == 0); + } else if (typeIndex == 1) { + item = NUdf::TUnboxedValuePod(static_cast<i16>(-value)); + } else if (typeIndex == 2) { + item = NUdf::TUnboxedValuePod(static_cast<ui16>(value)); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast<i32>(-value)); + } else if (typeIndex == 4) { + NUdf::TUnboxedValue innerItem; + innerItem = value % 2 == 0 + ? 
NUdf::TUnboxedValuePod(static_cast<i32>(value)) + : NUdf::TUnboxedValuePod(); + item = innerItem.MakeOptional(); + } + + item = Vb.NewVariant(typeIndex, std::move(item)); + } else { + item = NUdf::TUnboxedValuePod(); + } + + if (value % 3 != 2) { + item = item.MakeOptional(); + } + + values.emplace_back(std::move(item)); + } + return values; + } + TType* GetDictOptionalToTupleType() { TType* keyType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<double>::Id, TypeEnv), TypeEnv); TType* members[2] = { @@ -346,7 +561,7 @@ struct TTestContext { for (ui64 index = 0; index < variantSize; ++index) { TVector<TType*> selectedTypes; for (ui32 i = 0; i < BasicTypes.size(); ++i) { - if ((index >> i) % 2 == 1) { + if ((index ^ i) % 5 >= 2) { selectedTypes.push_back(BasicTypes[i]); } } @@ -363,7 +578,7 @@ struct TTestContext { auto typeIndex = index % VariantSize; TUnboxedValueVector tupleItems; for (ui64 i = 0; i < BasicTypes.size(); ++i) { - if ((typeIndex >> i) % 2 == 1) { + if ((typeIndex ^ i) % 5 >= 2) { tupleItems.push_back(GetValueOfBasicType(BasicTypes[i], i)); } } @@ -620,13 +835,13 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); auto structArray = static_pointer_cast<arrow::StructArray>(array); UNIT_ASSERT(structArray->num_fields() == 3); - UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::BOOL); + UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::UINT8); UNIT_ASSERT(structArray->field(1)->type_id() == arrow::Type::INT8); UNIT_ASSERT(structArray->field(2)->type_id() == arrow::Type::UINT8); UNIT_ASSERT(static_cast<ui64>(structArray->field(0)->length()) == values.size()); UNIT_ASSERT(static_cast<ui64>(structArray->field(1)->length()) == values.size()); UNIT_ASSERT(static_cast<ui64>(structArray->field(2)->length()) == values.size()); - auto boolArray = static_pointer_cast<arrow::BooleanArray>(structArray->field(0)); + auto boolArray = 
static_pointer_cast<arrow::UInt8Array>(structArray->field(0)); auto int8Array = static_pointer_cast<arrow::Int8Array>(structArray->field(1)); auto uint8Array = static_pointer_cast<arrow::UInt8Array>(structArray->field(2)); auto index = 0; @@ -646,38 +861,6 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { } } - Y_UNIT_TEST(DictUtf8ToInterval) { - TTestContext context; - - auto dictType = context.GetDictUtf8ToIntervalType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(dictType)); - - auto values = context.CreateDictUtf8ToInterval(100); - auto array = NArrow::MakeArray(values, dictType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); - UNIT_ASSERT(array->type_id() == arrow::Type::MAP); - auto mapArray = static_pointer_cast<arrow::MapArray>(array); - - UNIT_ASSERT(mapArray->num_fields() == 1); - UNIT_ASSERT(mapArray->keys()->type_id() == arrow::Type::STRING); - UNIT_ASSERT(mapArray->items()->type_id() == arrow::Type::DURATION); - auto utf8Array = static_pointer_cast<arrow::StringArray>(mapArray->keys()); - auto intervalArray = static_pointer_cast<arrow::NumericArray<arrow::DurationType>>(mapArray->items()); - ui64 index = 0; - for (const auto& value: values) { - UNIT_ASSERT(value.GetDictLength() == static_cast<ui64>(mapArray->value_length(index))); - for (auto subindex = mapArray->value_offset(index); subindex < mapArray->value_offset(index + 1); ++subindex) { - auto keyArrow = utf8Array->GetView(subindex); - NUdf::TUnboxedValue key = MakeString(NUdf::TStringRef(keyArrow.data(), keyArrow.size())); - UNIT_ASSERT(value.Contains(key)); - NUdf::TUnboxedValue payloadValue = value.Lookup(key); - UNIT_ASSERT(intervalArray->Value(subindex) == payloadValue.Get<i64>()); - } - ++index; - } - } - Y_UNIT_TEST(ListOfJsons) { TTestContext context; @@ -711,6 +894,48 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { } } + Y_UNIT_TEST(OptionalListOfOptional) { + TTestContext context; + + auto listType = 
context.GetOptionalListOfOptional(); + Y_ABORT_UNLESS(NArrow::IsArrowCompatible(listType)); + + auto values = context.CreateOptionalListOfOptional(100); + auto array = NArrow::MakeArray(values, listType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::LIST); + + auto listArray = static_pointer_cast<arrow::ListArray>(array); + UNIT_ASSERT(listArray->num_fields() == 1); + UNIT_ASSERT(listArray->value_type()->id() == arrow::Type::INT32); + + auto i32Array = static_pointer_cast<arrow::Int32Array>(listArray->values()); + auto index = 0; + auto innerIndex = 0; + for (const auto& value: values) { + if (!value.HasValue()) { + UNIT_ASSERT(listArray->IsNull(index)); + ++index; + continue; + } + + auto listValue = value.GetOptionalValue(); + + UNIT_ASSERT_VALUES_EQUAL(listValue.GetListLength(), static_cast<ui64>(listArray->value_length(index))); + const auto iter = listValue.GetListIterator(); + for (NUdf::TUnboxedValue item; iter.Next(item);) { + if (!item.HasValue()) { + UNIT_ASSERT(i32Array->IsNull(innerIndex)); + } else { + UNIT_ASSERT(i32Array->Value(innerIndex) == item.GetOptionalValue().Get<i32>()); + } + ++innerIndex; + } + ++index; + } + } + Y_UNIT_TEST(VariantOverStruct) { TTestContext context; @@ -725,14 +950,16 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); UNIT_ASSERT(unionArray->num_fields() == 4); - UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::STRING); - UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::STRING); - UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::FIXED_SIZE_BINARY); UNIT_ASSERT(unionArray->field(3)->type_id() 
== arrow::Type::FLOAT); - auto ysonArray = static_pointer_cast<arrow::StringArray>(unionArray->field(0)); - auto jsonDocArray = static_pointer_cast<arrow::StringArray>(unionArray->field(1)); - auto uuidArray = static_pointer_cast<arrow::BinaryArray>(unionArray->field(2)); + + auto ysonArray = static_pointer_cast<arrow::BinaryArray>(unionArray->field(0)); + auto jsonDocArray = static_pointer_cast<arrow::BinaryArray>(unionArray->field(1)); + auto uuidArray = static_pointer_cast<arrow::FixedSizeBinaryArray>(unionArray->field(2)); auto floatArray = static_pointer_cast<arrow::FloatArray>(unionArray->field(3)); + for (ui64 index = 0; index < values.size(); ++index) { auto value = values[index]; UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); @@ -759,13 +986,148 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { } } + Y_UNIT_TEST(OptionalVariantOverStruct) { + TTestContext context; + + auto variantType = context.GetOptionalVariantOverStructType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateOptionalVariantOverStruct(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + + auto structArray = static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT(structArray->num_fields() == 1); + UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::DENSE_UNION); + + auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(structArray->field(0)); + + UNIT_ASSERT(unionArray->num_fields() == 4); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::FIXED_SIZE_BINARY); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::FLOAT); + + auto ysonArray = 
static_pointer_cast<arrow::BinaryArray>(unionArray->field(0)); + auto jsonDocArray = static_pointer_cast<arrow::BinaryArray>(unionArray->field(1)); + auto uuidArray = static_pointer_cast<arrow::FixedSizeBinaryArray>(unionArray->field(2)); + auto floatArray = static_pointer_cast<arrow::FloatArray>(unionArray->field(3)); + + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + if (!value.HasValue()) { + // NULL + UNIT_ASSERT(structArray->IsNull(index)); + continue; + } + + UNIT_ASSERT(!structArray->IsNull(index)); + + UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 3) { + auto valueArrow = floatArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<float>(); + UNIT_ASSERT(valueArrow == valueInner); + } else { + arrow::util::string_view viewArrow; + if (value.GetVariantIndex() == 0) { + viewArrow = ysonArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 1) { + viewArrow = jsonDocArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 2) { + viewArrow = uuidArray->GetView(fieldIndex); + } + std::string valueArrow(viewArrow.data(), viewArrow.size()); + auto innerItem = value.GetVariantItem(); + auto refInner = innerItem.AsStringRef(); + std::string valueInner(refInner.Data(), refInner.Size()); + UNIT_ASSERT(valueArrow == valueInner); + } + } + } + + Y_UNIT_TEST(DoubleOptionalVariantOverStruct) { + TTestContext context; + + auto variantType = context.GetDoubleOptionalVariantOverStructType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateDoubleOptionalVariantOverStruct(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + + auto firstStructArray = 
static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT(firstStructArray->num_fields() == 1); + UNIT_ASSERT(firstStructArray->field(0)->type_id() == arrow::Type::STRUCT); + + auto secondStructArray = static_pointer_cast<arrow::StructArray>(firstStructArray->field(0)); + UNIT_ASSERT(secondStructArray->num_fields() == 1); + UNIT_ASSERT(secondStructArray->field(0)->type_id() == arrow::Type::DENSE_UNION); + + auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(secondStructArray->field(0)); + + UNIT_ASSERT(unionArray->num_fields() == 4); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::FIXED_SIZE_BINARY); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::FLOAT); + + auto ysonArray = static_pointer_cast<arrow::BinaryArray>(unionArray->field(0)); + auto jsonDocArray = static_pointer_cast<arrow::BinaryArray>(unionArray->field(1)); + auto uuidArray = static_pointer_cast<arrow::FixedSizeBinaryArray>(unionArray->field(2)); + auto floatArray = static_pointer_cast<arrow::FloatArray>(unionArray->field(3)); + + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + if (!value.HasValue()) { + if (value) { + // Optional(NULL) + UNIT_ASSERT(secondStructArray->IsNull(index)); + } else { + // NULL + UNIT_ASSERT(firstStructArray->IsNull(index)); + } + continue; + } + + UNIT_ASSERT(!firstStructArray->IsNull(index) && !secondStructArray->IsNull(index)); + + UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 3) { + auto valueArrow = floatArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<float>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else { + arrow::util::string_view viewArrow; + if (value.GetVariantIndex() == 0) { + 
viewArrow = ysonArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 1) { + viewArrow = jsonDocArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 2) { + viewArrow = uuidArray->GetView(fieldIndex); + } + std::string valueArrow(viewArrow.data(), viewArrow.size()); + auto innerItem = value.GetVariantItem(); + auto refInner = innerItem.AsStringRef(); + std::string valueInner(refInner.Data(), refInner.Size()); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } + } + } + Y_UNIT_TEST(VariantOverTupleWithOptionals) { TTestContext context; auto variantType = context.GetVariantOverTupleWithOptionalsType(); UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); - auto values = context.CreateVariantOverStruct(100); + auto values = context.CreateVariantOverTupleWithOptionals(100); auto array = NArrow::MakeArray(values, variantType); UNIT_ASSERT(array->ValidateFull().ok()); UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); @@ -773,12 +1135,12 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); UNIT_ASSERT(unionArray->num_fields() == 5); - UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BOOL); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::UINT8); UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::INT16); UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::UINT16); UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::INT32); UNIT_ASSERT(unionArray->field(4)->type_id() == arrow::Type::UINT32); - auto boolArray = static_pointer_cast<arrow::BooleanArray>(unionArray->field(0)); + auto boolArray = static_pointer_cast<arrow::UInt8Array>(unionArray->field(0)); auto i16Array = static_pointer_cast<arrow::Int16Array>(unionArray->field(1)); auto ui16Array = static_pointer_cast<arrow::UInt16Array>(unionArray->field(2)); auto i32Array = static_pointer_cast<arrow::Int32Array>(unionArray->field(3)); @@ -790,26 
+1152,175 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { if (value.GetVariantIndex() == 0) { bool valueArrow = boolArray->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get<bool>(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } else if (value.GetVariantIndex() == 1) { auto valueArrow = i16Array->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get<i16>(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } else if (value.GetVariantIndex() == 2) { auto valueArrow = ui16Array->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get<ui16>(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } else if (value.GetVariantIndex() == 3) { auto valueArrow = i32Array->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get<i32>(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 4) { + if (!value.GetVariantItem().HasValue()) { + UNIT_ASSERT(ui32Array->IsNull(fieldIndex)); + } else { + auto valueArrow = ui32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<ui32>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } + } + } + } + + Y_UNIT_TEST(OptionalVariantOverTupleWithOptionals) { + // DenseUnionArray does not support NULL values, so we wrap it in a StructArray + + TTestContext context; + + auto variantType = context.GetOptionalVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateOptionalVariantOverTupleWithOptionals(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + + auto structArray = 
static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT(structArray->num_fields() == 1); + UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::DENSE_UNION); + + auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(structArray->field(0)); + UNIT_ASSERT(unionArray->num_fields() == 5); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::UINT8); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::INT16); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::UINT16); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::INT32); + UNIT_ASSERT(unionArray->field(4)->type_id() == arrow::Type::UINT32); + auto boolArray = static_pointer_cast<arrow::UInt8Array>(unionArray->field(0)); + auto i16Array = static_pointer_cast<arrow::Int16Array>(unionArray->field(1)); + auto ui16Array = static_pointer_cast<arrow::UInt16Array>(unionArray->field(2)); + auto i32Array = static_pointer_cast<arrow::Int32Array>(unionArray->field(3)); + auto ui32Array = static_pointer_cast<arrow::UInt32Array>(unionArray->field(4)); + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + if (!value) { + // NULL + UNIT_ASSERT(structArray->IsNull(index)); + continue; + } + + UNIT_ASSERT(!structArray->IsNull(index)); + + UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 0) { + bool valueArrow = boolArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<bool>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 1) { + auto valueArrow = i16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<i16>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 2) { + auto valueArrow = ui16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<ui16>(); + 
UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 3) { + auto valueArrow = i32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<i32>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } else if (value.GetVariantIndex() == 4) { if (!value.GetVariantItem().HasValue()) { UNIT_ASSERT(ui32Array->IsNull(fieldIndex)); } else { auto valueArrow = ui32Array->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get<ui32>(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } + } + } + } + + Y_UNIT_TEST(DoubleOptionalVariantOverTupleWithOptionals) { + // DenseUnionArray does not support NULL values, so we wrap it in a StructArray + + TTestContext context; + + auto variantType = context.GetDoubleOptionalVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateDoubleOptionalVariantOverTupleWithOptionals(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + + auto firstStructArray = static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT(firstStructArray->num_fields() == 1); + UNIT_ASSERT(firstStructArray->field(0)->type_id() == arrow::Type::STRUCT); + + auto secondStructArray = static_pointer_cast<arrow::StructArray>(firstStructArray->field(0)); + UNIT_ASSERT(secondStructArray->num_fields() == 1); + UNIT_ASSERT(secondStructArray->field(0)->type_id() == arrow::Type::DENSE_UNION); + + auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(secondStructArray->field(0)); + UNIT_ASSERT(unionArray->num_fields() == 5); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::UINT8); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::INT16); + UNIT_ASSERT(unionArray->field(2)->type_id() == 
arrow::Type::UINT16); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::INT32); + UNIT_ASSERT(unionArray->field(4)->type_id() == arrow::Type::UINT32); + auto boolArray = static_pointer_cast<arrow::UInt8Array>(unionArray->field(0)); + auto i16Array = static_pointer_cast<arrow::Int16Array>(unionArray->field(1)); + auto ui16Array = static_pointer_cast<arrow::UInt16Array>(unionArray->field(2)); + auto i32Array = static_pointer_cast<arrow::Int32Array>(unionArray->field(3)); + auto ui32Array = static_pointer_cast<arrow::UInt32Array>(unionArray->field(4)); + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + if (!value.HasValue()) { + if (value && !value.GetOptionalValue()) { + // Optional(NULL) + UNIT_ASSERT(secondStructArray->IsNull(index)); + } else if (!value) { + // NULL + UNIT_ASSERT(firstStructArray->IsNull(index)); + } + continue; + } + + UNIT_ASSERT(!firstStructArray->IsNull(index) && !secondStructArray->IsNull(index)); + + UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 0) { + bool valueArrow = boolArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<bool>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 1) { + auto valueArrow = i16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<i16>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 2) { + auto valueArrow = ui16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<ui16>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 3) { + auto valueArrow = i32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<i32>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 4) { + if 
(!value.GetVariantItem().HasValue()) { + UNIT_ASSERT(ui32Array->IsNull(fieldIndex)); + } else { + auto valueArrow = ui32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<ui32>(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } } } @@ -817,6 +1328,51 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { } Y_UNIT_TEST_SUITE(DqUnboxedValueDoNotFitToArrow) { + Y_UNIT_TEST(DictUtf8ToInterval) { + TTestContext context; + + auto dictType = context.GetDictUtf8ToIntervalType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(dictType)); + + auto values = context.CreateDictUtf8ToInterval(100); + auto array = NArrow::MakeArray(values, dictType); + UNIT_ASSERT(array->ValidateFull().ok()); + + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + auto wrapArray = static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT_VALUES_EQUAL(wrapArray->num_fields(), 2); + UNIT_ASSERT_VALUES_EQUAL(static_cast<ui64>(wrapArray->length()), values.size()); + + UNIT_ASSERT(wrapArray->field(0)->type_id() == arrow::Type::MAP); + auto mapArray = static_pointer_cast<arrow::MapArray>(wrapArray->field(0)); + UNIT_ASSERT_VALUES_EQUAL(static_cast<ui64>(mapArray->length()), values.size()); + + UNIT_ASSERT(wrapArray->field(1)->type_id() == arrow::Type::UINT64); + auto customArray = static_pointer_cast<arrow::UInt64Array>(wrapArray->field(1)); + UNIT_ASSERT_VALUES_EQUAL(static_cast<ui64>(customArray->length()), values.size()); + + UNIT_ASSERT_VALUES_EQUAL(mapArray->num_fields(), 1); + + UNIT_ASSERT(mapArray->keys()->type_id() == arrow::Type::STRING); + auto utf8Array = static_pointer_cast<arrow::StringArray>(mapArray->keys()); + + UNIT_ASSERT(mapArray->items()->type_id() == arrow::Type::INT64); + auto intervalArray = static_pointer_cast<arrow::Int64Array>(mapArray->items()); + + ui64 index = 0; + for (const auto& value: values) { + UNIT_ASSERT_VALUES_EQUAL(value.GetDictLength(), static_cast<ui64>(mapArray->value_length(index))); + for (auto subindex = 
mapArray->value_offset(index); subindex < mapArray->value_offset(index + 1); ++subindex) { + auto keyArrow = utf8Array->GetView(subindex); + NUdf::TUnboxedValue key = MakeString(NUdf::TStringRef(keyArrow.data(), keyArrow.size())); + UNIT_ASSERT(value.Contains(key)); + NUdf::TUnboxedValue payloadValue = value.Lookup(key); + UNIT_ASSERT_VALUES_EQUAL(intervalArray->Value(subindex), payloadValue.Get<i64>()); + } + ++index; + } + } + Y_UNIT_TEST(DictOptionalToTuple) { TTestContext context; @@ -827,8 +1383,20 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueDoNotFitToArrow) { auto array = NArrow::MakeArray(values, dictType); UNIT_ASSERT(array->ValidateFull().ok()); UNIT_ASSERT_EQUAL(static_cast<ui64>(array->length()), values.size()); - UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::LIST); - auto listArray = static_pointer_cast<arrow::ListArray>(array); + UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::STRUCT); + + auto wrapArray = static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT_EQUAL(wrapArray->num_fields(), 2); + UNIT_ASSERT_EQUAL(wrapArray->field(0)->type_id(), arrow::Type::LIST); + + UNIT_ASSERT_EQUAL(wrapArray->field(1)->type_id(), arrow::Type::UINT64); + auto listArray = static_pointer_cast<arrow::ListArray>(wrapArray->field(0)); + UNIT_ASSERT_EQUAL(static_cast<ui64>(listArray->length()), values.size()); + + UNIT_ASSERT_EQUAL(wrapArray->field(1)->type_id(), arrow::Type::UINT64); + auto customArray = static_pointer_cast<arrow::UInt64Array>(wrapArray->field(1)); + UNIT_ASSERT_EQUAL(static_cast<ui64>(customArray->length()), values.size()); + UNIT_ASSERT_EQUAL(listArray->value_type()->id(), arrow::Type::STRUCT); auto structArray = static_pointer_cast<arrow::StructArray>(listArray->values()); @@ -870,27 +1438,45 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueDoNotFitToArrow) { auto array = NArrow::MakeArray(values, doubleOptionalType); UNIT_ASSERT(array->ValidateFull().ok()); UNIT_ASSERT_EQUAL(static_cast<ui64>(array->length()), values.size()); - 
UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::STRUCT); - auto structArray = static_pointer_cast<arrow::StructArray>(array); - UNIT_ASSERT_EQUAL(structArray->num_fields(), 2); - UNIT_ASSERT_EQUAL(structArray->field(0)->type_id(), arrow::Type::UINT64); - UNIT_ASSERT_EQUAL(structArray->field(1)->type_id(), arrow::Type::INT32); - auto depthArray = static_pointer_cast<arrow::UInt64Array>(structArray->field(0)); - auto i32Array = static_pointer_cast<arrow::Int32Array>(structArray->field(1)); auto index = 0; for (auto value: values) { - auto depth = depthArray->Value(index); - while (depth > 0) { - UNIT_ASSERT(value.HasValue()); + std::shared_ptr<arrow::Array> currentArray = array; + int depth = 0; + + while (currentArray->type()->id() == arrow::Type::STRUCT) { + auto structArray = static_pointer_cast<arrow::StructArray>(currentArray); + UNIT_ASSERT_EQUAL(structArray->num_fields(), 1); + + if (structArray->IsNull(index)) { + break; + } + + ++depth; + + auto childArray = structArray->field(0); + if (childArray->type()->id() == arrow::Type::DENSE_UNION) { + break; + } + + currentArray = childArray; + } + + while (depth--) { + UNIT_ASSERT(value); value = value.GetOptionalValue(); - --depth; } + if (value.HasValue()) { - UNIT_ASSERT_EQUAL(value.Get<i32>(), i32Array->Value(index)); + if (currentArray->type()->id() == arrow::Type::INT32) { + UNIT_ASSERT_EQUAL(value.Get<i32>(), static_pointer_cast<arrow::Int32Array>(currentArray)->Value(index)); + } else { + UNIT_ASSERT(!currentArray->IsNull(index)); + } } else { - UNIT_ASSERT(i32Array->IsNull(index)); + UNIT_ASSERT(currentArray->IsNull(index)); } + ++index; } } @@ -954,7 +1540,7 @@ Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ TTestContext context; auto dictType = context.GetDictUtf8ToIntervalType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(dictType)); + UNIT_ASSERT(!NArrow::IsArrowCompatible(dictType)); auto values = context.CreateDictUtf8ToInterval(100); auto array = NArrow::MakeArray(values, dictType); @@ -980,6 
+1566,21 @@ Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ } } + Y_UNIT_TEST(OptionalListOfOptional) { + TTestContext context; + + auto listType = context.GetOptionalListOfOptional(); + Y_ABORT_UNLESS(NArrow::IsArrowCompatible(listType)); + + auto values = context.CreateOptionalListOfOptional(100); + auto array = NArrow::MakeArray(values, listType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, listType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], listType); + } + } + Y_UNIT_TEST(VariantOverStruct) { TTestContext context; @@ -995,13 +1596,43 @@ Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ } } + Y_UNIT_TEST(OptionalVariantOverStruct) { + TTestContext context; + + auto optionalVariantType = context.GetOptionalVariantOverStructType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(optionalVariantType)); + + auto values = context.CreateOptionalVariantOverStruct(100); + auto array = NArrow::MakeArray(values, optionalVariantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, optionalVariantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], optionalVariantType); + } + } + + Y_UNIT_TEST(DoubleOptionalVariantOverStruct) { + TTestContext context; + + auto doubleOptionalVariantType = context.GetDoubleOptionalVariantOverStructType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(doubleOptionalVariantType)); + + auto values = context.CreateDoubleOptionalVariantOverStruct(100); + auto array = NArrow::MakeArray(values, doubleOptionalVariantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, doubleOptionalVariantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 
index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], doubleOptionalVariantType); + } + } + Y_UNIT_TEST(VariantOverTupleWithOptionals) { TTestContext context; auto variantType = context.GetVariantOverTupleWithOptionalsType(); UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); - auto values = context.CreateVariantOverStruct(100); + auto values = context.CreateVariantOverTupleWithOptionals(100); auto array = NArrow::MakeArray(values, variantType); auto restoredValues = NArrow::ExtractUnboxedValues(array, variantType, context.HolderFactory); UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); @@ -1010,6 +1641,36 @@ Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ } } + Y_UNIT_TEST(OptionalVariantOverTupleWithOptionals) { + TTestContext context; + + auto optionalVariantType = context.GetOptionalVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(optionalVariantType)); + + auto values = context.CreateOptionalVariantOverTupleWithOptionals(100); + auto array = NArrow::MakeArray(values, optionalVariantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, optionalVariantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], optionalVariantType); + } + } + + Y_UNIT_TEST(DoubleOptionalVariantOverTupleWithOptionals) { + TTestContext context; + + auto doubleOptionalVariantType = context.GetDoubleOptionalVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(doubleOptionalVariantType)); + + auto values = context.CreateDoubleOptionalVariantOverTupleWithOptionals(100); + auto array = NArrow::MakeArray(values, doubleOptionalVariantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, doubleOptionalVariantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), 
restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], doubleOptionalVariantType); + } + } + Y_UNIT_TEST(DictOptionalToTuple) { TTestContext context; |