From 0896bce9b4bddddd0a79a3283f6c5c5501a0f15b Mon Sep 17 00:00:00 2001 From: Daniil Timizhev Date: Mon, 15 Sep 2025 12:54:53 +0300 Subject: Support containers with Arrow format of result sets (#23831) --- ydb/core/formats/arrow/arrow_batch_builder.cpp | 49 +- ydb/core/formats/arrow/arrow_batch_builder.h | 20 +- ydb/core/formats/arrow/arrow_helpers.cpp | 6 +- ydb/core/formats/arrow/arrow_helpers.h | 4 +- ydb/core/formats/arrow/arrow_helpers_minikql.cpp | 42 + ydb/core/formats/arrow/arrow_helpers_minikql.h | 17 + ydb/core/formats/arrow/switch/switch_type.h | 1 + ydb/core/formats/arrow/ya.make | 3 + ydb/core/kqp/query_data/kqp_query_data.cpp | 21 +- ydb/core/kqp/runtime/kqp_transport.cpp | 17 +- .../kqp/ut/arrow/kqp_result_set_format_arrow.cpp | 1424 -------------- ydb/core/kqp/ut/arrow/kqp_result_set_formats.cpp | 2037 ++++++++++++++++++++ ydb/core/kqp/ut/arrow/ya.make | 2 +- .../test_helper/columnshard_ut_common.h | 7 + ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp | 766 +++++--- ydb/library/yql/dq/runtime/dq_arrow_helpers.h | 42 +- ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp | 799 +++++++- 17 files changed, 3455 insertions(+), 1802 deletions(-) create mode 100644 ydb/core/formats/arrow/arrow_helpers_minikql.cpp create mode 100644 ydb/core/formats/arrow/arrow_helpers_minikql.h delete mode 100644 ydb/core/kqp/ut/arrow/kqp_result_set_format_arrow.cpp create mode 100644 ydb/core/kqp/ut/arrow/kqp_result_set_formats.cpp diff --git a/ydb/core/formats/arrow/arrow_batch_builder.cpp b/ydb/core/formats/arrow/arrow_batch_builder.cpp index 58e8b4ba909..9cfa08a278a 100644 --- a/ydb/core/formats/arrow/arrow_batch_builder.cpp +++ b/ydb/core/formats/arrow/arrow_batch_builder.cpp @@ -1,7 +1,13 @@ #include "arrow_batch_builder.h" -#include "switch/switch_type.h" + +#include +#include +#include +#include + #include #include + namespace NKikimr::NArrow { namespace { @@ -78,6 +84,15 @@ arrow::Status AppendCell(arrow::RecordBatchBuilder& builder, const TCell& cell, return 
result; } +arrow::Status AppendValue(arrow::RecordBatchBuilder& builder, const NUdf::TUnboxedValue& value, ui32 colNum, const NKikimr::NMiniKQL::TType* type) { + try { + NYql::NArrow::AppendElement(value, builder.GetField(colNum), type); + } catch (const std::exception& e) { + return arrow::Status::FromArgs(arrow::StatusCode::Invalid, e.what()); + } + return arrow::Status::OK(); +} + } NKikimr::NArrow::TRecordBatchConstructor::TRecordConstructor& TRecordBatchConstructor::TRecordConstructor::AddRecordValue( @@ -225,6 +240,20 @@ arrow::Status TArrowBatchBuilder::Start(const std::vector> yqlColumns) { + YqlSchema = yqlColumns; + auto schema = MakeArrowSchema(yqlColumns, NotNullColumns); + if (!schema.ok()) { + return arrow::Status::FromArgs(schema.status().code(), "Cannot make arrow schema: ", schema.status().ToString()); + } + auto status = arrow::RecordBatchBuilder::Make(*schema, MemoryPool, RowsToReserve, &BatchBuilder); + NumRows = NumBytes = 0; + if (!status.ok()) { + return arrow::Status::FromArgs(schema.status().code(), "Cannot make arrow builder: ", status.ToString()); + } + return arrow::Status::OK(); +} + void TArrowBatchBuilder::AppendCell(const TCell& cell, ui32 colNum) { NumBytes += cell.Size(); auto ydbType = YdbSchema[colNum].second; @@ -232,6 +261,13 @@ void TArrowBatchBuilder::AppendCell(const TCell& cell, ui32 colNum) { Y_ABORT_UNLESS(status.ok(), "Failed to append cell: %s", status.ToString().c_str()); } +void TArrowBatchBuilder::AppendValue(const NUdf::TUnboxedValue& value, ui32 colNum) { + NumBytes += sizeof(NUdf::TUnboxedValue); // TODO: strings or containers sizes? 
+ /* Take the column's type from YqlSchema and append the value; a non-OK status is reported through Y_ENSURE. */ auto yqlType = YqlSchema[colNum].second; + auto status = NKikimr::NArrow::AppendValue(*BatchBuilder, value, colNum, yqlType); + Y_ENSURE(status.ok(), "Failed to append value: " << status.ToString()); +} + void TArrowBatchBuilder::AddRow(const TDbTupleRef& key, const TDbTupleRef& value) { ++NumRows; @@ -273,6 +309,17 @@ void TArrowBatchBuilder::AddRow(const TConstArrayRef& row) { } } +void TArrowBatchBuilder::AddRow(const NUdf::TUnboxedValue& row, size_t membersCount, const TVector* columnOrder) { /* Appends one row: for each i in [0, membersCount) the element row.GetElement(memberIndex) is written to column i, where memberIndex is (*columnOrder)[i] when columnOrder is non-null and non-empty, and i otherwise. Requires a non-empty YqlSchema; NumRows is incremented before the values are appended. */ + Y_ABORT_UNLESS(!YqlSchema.empty()); + + ++NumRows; + + for (size_t i = 0; i < membersCount; ++i) { + const auto& memberIndex = (!columnOrder || columnOrder->empty()) ? i : (*columnOrder)[i]; + AppendValue(row.GetElement(memberIndex), i); + } +} + void TArrowBatchBuilder::ReserveData(ui32 columnNo, size_t size) { if (!BatchBuilder || columnNo >= (ui32)BatchBuilder->num_fields()) { return; diff --git a/ydb/core/formats/arrow/arrow_batch_builder.h b/ydb/core/formats/arrow/arrow_batch_builder.h index d1341208a02..052457d992f 100644 --- a/ydb/core/formats/arrow/arrow_batch_builder.h +++ b/ydb/core/formats/arrow/arrow_batch_builder.h @@ -1,9 +1,18 @@ #pragma once -#include "arrow_helpers.h" + #include +#include #include #include +namespace NYql::NUdf { +class TUnboxedValue; +} + +namespace NKikimr::NMiniKQL { +class TType; +} + namespace NKikimr::NArrow { class TRecordBatchReader { @@ -169,6 +178,7 @@ public: void AddRow(const NKikimr::TDbTupleRef& key, const NKikimr::TDbTupleRef& value) override; void AddRow(const TConstArrayRef& key, const TConstArrayRef& value); void AddRow(const TConstArrayRef& row); + void AddRow(const NYql::NUdf::TUnboxedValue& row, size_t membersCount, const TVector* columnOrder); // You have to call it before Start() void Reserve(size_t numRows) { @@ -183,19 +193,27 @@ public: } arrow::Status Start(const std::vector>& columns); + arrow::Status Start(const std::vector> columns); + std::shared_ptr FlushBatch(bool reinitialize, bool flushEmpty = false); 
std::shared_ptr GetBatch() const { return Batch; } protected: void AppendCell(const TCell& cell, ui32 colNum); + void AppendValue(const NYql::NUdf::TUnboxedValue& value, ui32 colNum); const std::vector>& GetYdbSchema() const { return YdbSchema; } + const std::vector>& GetYqlSchema() const { /* Hands out the stored YQL schema by reference, like GetYdbSchema(), so callers do not copy the column vector on every call. */ + return YqlSchema; + } + private: arrow::ipc::IpcWriteOptions WriteOptions; std::vector> YdbSchema; + std::vector> YqlSchema; std::unique_ptr BatchBuilder; std::shared_ptr Batch; size_t RowsToReserve{DEFAULT_ROWS_TO_RESERVE}; diff --git a/ydb/core/formats/arrow/arrow_helpers.cpp b/ydb/core/formats/arrow/arrow_helpers.cpp index 8f23ac19f28..54766f40615 100644 --- a/ydb/core/formats/arrow/arrow_helpers.cpp +++ b/ydb/core/formats/arrow/arrow_helpers.cpp @@ -89,11 +89,11 @@ arrow::Result> GetCSVArrowType(NScheme::TTypeIn } arrow::Result MakeArrowFields( - const std::vector>& columns, const std::set& notNullColumns) { + const std::vector>& ydbColumns, const std::set& notNullColumns) { std::vector> fields; - fields.reserve(columns.size()); + fields.reserve(ydbColumns.size()); TVector errors; - for (auto& [name, ydbType] : columns) { + for (auto& [name, ydbType] : ydbColumns) { std::string colName(name.data(), name.size()); auto arrowType = GetArrowType(ydbType); if (arrowType.ok()) { diff --git a/ydb/core/formats/arrow/arrow_helpers.h b/ydb/core/formats/arrow/arrow_helpers.h index 61a8c2355fd..952651282ed 100644 --- a/ydb/core/formats/arrow/arrow_helpers.h +++ b/ydb/core/formats/arrow/arrow_helpers.h @@ -48,8 +48,8 @@ std::optional FindUpperOrEqualPosition(const TArray& arr, const TValue val arrow::Result> GetArrowType(NScheme::TTypeInfo typeInfo); arrow::Result> GetCSVArrowType(NScheme::TTypeInfo typeId); -arrow::Result MakeArrowFields(const std::vector>& columns, const std::set& notNullColumns = {}); -arrow::Result> MakeArrowSchema(const std::vector>& columns, const std::set& notNullColumns = {}); +arrow::Result MakeArrowFields(const std::vector>& ydbColumns, const std::set& 
notNullColumns = {}); +arrow::Result> MakeArrowSchema(const std::vector>& ydbColumns, const std::set& notNullColumns = {}); std::shared_ptr DeserializeSchema(const TString& str); diff --git a/ydb/core/formats/arrow/arrow_helpers_minikql.cpp b/ydb/core/formats/arrow/arrow_helpers_minikql.cpp new file mode 100644 index 00000000000..273bb581725 --- /dev/null +++ b/ydb/core/formats/arrow/arrow_helpers_minikql.cpp @@ -0,0 +1,42 @@ +#include "arrow_helpers_minikql.h" + +#include +#include + +namespace NKikimr::NArrow { + +arrow::Result MakeArrowFields( + const std::vector>& yqlColumns, const std::set& notNullColumns) { /* Converts each (name, MiniKQL type) column into an Arrow field. A yexception thrown by GetArrowType is recorded per column; if any were recorded the result is a TypeError listing them all (joined with ", "), otherwise the collected fields are returned. */ + std::vector> fields; + fields.reserve(yqlColumns.size()); + TVector errors; + for (auto& [name, mkqlType] : yqlColumns) { + std::string colName(name.data(), name.size()); + std::shared_ptr arrowType; + + try { + arrowType = NYql::NArrow::GetArrowType(mkqlType); + } catch (const yexception& e) { + errors.emplace_back(colName + " error: " + e.what()); + } + + if (arrowType) { /* columns named in notNullColumns get `false` as the third argument — presumably the field's nullable flag; confirm against the constructed type */ + fields.emplace_back(std::make_shared(colName, arrowType, !notNullColumns.contains(colName))); + } + } + if (errors.empty()) { + return fields; + } + return arrow::Status::TypeError(JoinSeq(", ", errors)); +} + +arrow::Result> MakeArrowSchema( + const std::vector>& yqlColumns, const std::set& notNullColumns) { /* Wraps the fields produced by MakeArrowFields into a shared schema object, or forwards its failure status unchanged. */ + const auto fields = MakeArrowFields(yqlColumns, notNullColumns); + if (fields.ok()) { + return std::make_shared(fields.ValueUnsafe()); + } + return fields.status(); +} + +} // namespace NKikimr::NArrow diff --git a/ydb/core/formats/arrow/arrow_helpers_minikql.h b/ydb/core/formats/arrow/arrow_helpers_minikql.h new file mode 100644 index 00000000000..ef3a4796346 --- /dev/null +++ b/ydb/core/formats/arrow/arrow_helpers_minikql.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +#include +#include + +namespace NKikimr::NMiniKQL { +class TType; +} + +namespace NKikimr::NArrow { + +arrow::Result MakeArrowFields(const std::vector>& yqlColumns, const std::set& notNullColumns = {}); 
+arrow::Result> MakeArrowSchema(const std::vector>& yqlColumns, const std::set& notNullColumns = {}); + +} // namespace NKikimr::NArrow diff --git a/ydb/core/formats/arrow/switch/switch_type.h b/ydb/core/formats/arrow/switch/switch_type.h index fd8b167ef5e..b07fa8c039e 100644 --- a/ydb/core/formats/arrow/switch/switch_type.h +++ b/ydb/core/formats/arrow/switch/switch_type.h @@ -64,6 +64,7 @@ template case NScheme::NTypeIds::Interval: return callback(TTypeWrapper()); case NScheme::NTypeIds::Decimal: + case NScheme::NTypeIds::Uuid: return callback(TTypeWrapper()); case NScheme::NTypeIds::Datetime64: diff --git a/ydb/core/formats/arrow/ya.make b/ydb/core/formats/arrow/ya.make index 6a49d427715..53e3636d68f 100644 --- a/ydb/core/formats/arrow/ya.make +++ b/ydb/core/formats/arrow/ya.make @@ -23,6 +23,8 @@ PEERDIR( ydb/library/formats/arrow ydb/library/services yql/essentials/core/arrow_kernels/request + yql/essentials/minikql + ydb/library/yql/dq/runtime ) YQL_LAST_ABI_VERSION() @@ -30,6 +32,7 @@ YQL_LAST_ABI_VERSION() SRCS( arrow_batch_builder.cpp arrow_helpers.cpp + arrow_helpers_minikql.cpp arrow_filter.cpp converter.cpp converter.h diff --git a/ydb/core/kqp/query_data/kqp_query_data.cpp b/ydb/core/kqp/query_data/kqp_query_data.cpp index fcd0de37879..90c5d1b2a7d 100644 --- a/ydb/core/kqp/query_data/kqp_query_data.cpp +++ b/ydb/core/kqp/query_data/kqp_query_data.cpp @@ -90,13 +90,10 @@ bool TKqpExecuterTxResult::HasTrailingResults() { void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFormatSettings& resultSetFormatSettings, bool fillSchema, TMaybe rowsLimitPerWrite) { YQL_ENSURE(ydbResult); YQL_ENSURE(!Rows.IsWide()); - YQL_ENSURE(MkqlItemType->GetKind() == NKikimr::NMiniKQL::TType::EKind::Struct); + YQL_ENSURE(MkqlItemType->GetKind() == NMiniKQL::TType::EKind::Struct); const auto* mkqlSrcRowStructType = static_cast(MkqlItemType); - std::vector> arrowSchema; - std::set arrowNotNullColumns; - if (fillSchema) { for (ui32 idx = 0; idx < 
mkqlSrcRowStructType->GetMembersCount(); ++idx) { auto* column = ydbResult->add_columns(); @@ -110,6 +107,9 @@ void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFo } } + std::vector> arrowSchema; + std::set arrowNotNullColumns; + if (resultSetFormatSettings.IsArrowFormat()) { for (ui32 idx = 0; idx < mkqlSrcRowStructType->GetMembersCount(); ++idx) { ui32 memberIndex = (!ColumnOrder || ColumnOrder->empty()) ? idx : (*ColumnOrder)[idx]; @@ -120,8 +120,7 @@ void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFo arrowNotNullColumns.insert(columnName); } - NScheme::TTypeInfo typeInfo = NScheme::TypeInfoFromMiniKQLType(columnType); - arrowSchema.emplace_back(std::move(columnName), std::move(typeInfo)); + arrowSchema.emplace_back(std::move(columnName), std::move(columnType)); } } @@ -145,7 +144,6 @@ void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFo batchBuilder.Reserve(Rows.RowCount()); YQL_ENSURE(batchBuilder.Start(arrowSchema).ok()); - TRowBuilder rowBuilder(arrowSchema.size()); Rows.ForEachRow([&](const NUdf::TUnboxedValue& row) -> bool { if (rowsLimitPerWrite) { if (*rowsLimitPerWrite == 0) { @@ -155,14 +153,7 @@ void TKqpExecuterTxResult::FillYdb(Ydb::ResultSet* ydbResult, const TResultSetFo --(*rowsLimitPerWrite); } - for (size_t i = 0; i < arrowSchema.size(); ++i) { - ui32 memberIndex = (!ColumnOrder || ColumnOrder->empty()) ? 
i : (*ColumnOrder)[i]; - const auto& [name, type] = arrowSchema[i]; - rowBuilder.AddCell(i, type, row.GetElement(memberIndex), type.GetPgTypeMod(name)); - } - - auto cells = rowBuilder.BuildCells(); - batchBuilder.AddRow(cells); + batchBuilder.AddRow(row, arrowSchema.size(), ColumnOrder); return true; }); diff --git a/ydb/core/kqp/runtime/kqp_transport.cpp b/ydb/core/kqp/runtime/kqp_transport.cpp index 85077cf8e3a..9e3e37c23fe 100644 --- a/ydb/core/kqp/runtime/kqp_transport.cpp +++ b/ydb/core/kqp/runtime/kqp_transport.cpp @@ -63,7 +63,7 @@ void TKqpProtoBuilder::BuildYdbResultSet( TColumnOrder order = columnHints ? TColumnOrder(*columnHints) : TColumnOrder{}; - std::vector> arrowSchema; + std::vector> arrowSchema; std::set arrowNotNullColumns; if (fillSchema) { @@ -85,12 +85,11 @@ void TKqpProtoBuilder::BuildYdbResultSet( auto columnName = TString(columnHints && columnHints->size() ? order.at(idx).LogicalName : mkqlSrcRowStructType->GetMemberName(memberIndex)); auto* columnType = mkqlSrcRowStructType->GetMemberType(memberIndex); - if (columnType->GetKind() != TType::EKind::Optional) { + if (columnType->GetKind() != NMiniKQL::TType::EKind::Optional) { arrowNotNullColumns.insert(columnName); } - NScheme::TTypeInfo typeInfo = NScheme::TypeInfoFromMiniKQLType(columnType); - arrowSchema.emplace_back(std::move(columnName), std::move(typeInfo)); + arrowSchema.emplace_back(std::move(columnName), std::move(columnType)); } } @@ -140,7 +139,6 @@ void TKqpProtoBuilder::BuildYdbResultSet( batchBuilder.Reserve(arrowRowsCount); YQL_ENSURE(batchBuilder.Start(arrowSchema).ok()); - TRowBuilder rowBuilder(arrowSchema.size()); for (auto& part : data) { if (!part.ChunkCount()) { continue; @@ -150,14 +148,7 @@ void TKqpProtoBuilder::BuildYdbResultSet( dataSerializer.Deserialize(std::move(part), mkqlSrcRowType, rows); rows.ForEachRow([&](const NUdf::TUnboxedValue& value) { - for (size_t i = 0; i < arrowSchema.size(); ++i) { - ui32 memberIndex = (!columnOrder || columnOrder->empty()) ? 
i : (*columnOrder)[i]; - const auto& [name, type] = arrowSchema[i]; - rowBuilder.AddCell(i, type, value.GetElement(memberIndex), type.GetPgTypeMod(name)); - } - - auto cells = rowBuilder.BuildCells(); - batchBuilder.AddRow(cells); + batchBuilder.AddRow(value, arrowSchema.size(), columnOrder); }); } diff --git a/ydb/core/kqp/ut/arrow/kqp_result_set_format_arrow.cpp b/ydb/core/kqp/ut/arrow/kqp_result_set_format_arrow.cpp deleted file mode 100644 index dbda1212cdf..00000000000 --- a/ydb/core/kqp/ut/arrow/kqp_result_set_format_arrow.cpp +++ /dev/null @@ -1,1424 +0,0 @@ -#include - -#include -#include -#include - -#include - -#include - -#include -#include - -namespace NKikimr::NKqp { - -using namespace NYdb; -using namespace NYdb::NQuery; -using TTypeInfo = NScheme::TTypeInfo; -namespace NTypeIds = NScheme::NTypeIds; - -namespace { - -TKikimrRunner CreateKikimrRunner(bool withSampleTables, ui64 channelBufferSize = 8_MB) { - NKikimrConfig::TFeatureFlags featureFlags; - featureFlags.SetEnableArrowResultSetFormat(true); - - NKikimrConfig::TAppConfig appConfig; - appConfig.MutableTableServiceConfig()->SetEnableOlapSink(true); - appConfig.MutableTableServiceConfig()->MutableResourceManager()->SetChannelBufferSize(channelBufferSize); - - auto settings = TKikimrSettings(appConfig).SetFeatureFlags(featureFlags).SetWithSampleTables(withSampleTables); - return TKikimrRunner(settings); -} - -void CreateAllTypesRowTable(TQueryClient& client) { - auto createResult = client.ExecuteQuery(R"( - CREATE TABLE `/Root/RowTable` ( - Key Uint64, - BoolValue Bool, - Int8Value Int8, - Uint8Value Uint8, - Int16Value Int16, - Uint16Value Uint16, - Int32Value Int32, - Uint32Value Uint32, - Int64Value Int64, - Uint64Value Uint64, - FloatValue Float, - DoubleValue Double, - StringValue String, - Utf8Value Utf8, - DateValue Date, - DatetimeValue Datetime, - TimestampValue Timestamp, - IntervalValue Interval, - DecimalValue Decimal(22,9), - JsonValue Json, - YsonValue Yson, - JsonDocumentValue 
JsonDocument, - DyNumberValue DyNumber, - Int32NotNullValue Int32 NOT NULL, - PRIMARY KEY (Key) - ); - )", TTxControl::NoTx()).GetValueSync(); - UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); - - auto insertResult = client.ExecuteQuery(R"( - INSERT INTO `/Root/RowTable` (Key, BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue, Int32NotNullValue) VALUES - (42, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), 123); - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); -} - -void CreateAllTypesColumnTable(TQueryClient& client) { - auto createResult = client.ExecuteQuery(R"( - CREATE TABLE `/Root/ColumnTable` ( - Key Uint64 NOT NULL, - Int8Value Int8, - Uint8Value Uint8, - Int16Value Int16, - Uint16Value Uint16, - Int32Value Int32, - Uint32Value Uint32, - Int64Value Int64, - Uint64Value Uint64, - FloatValue Float, - DoubleValue Double, - StringValue String, - Utf8Value Utf8, - DateValue Date, - DatetimeValue Datetime, - TimestampValue Timestamp, - JsonValue Json, - YsonValue Yson, - JsonDocumentValue JsonDocument, - PRIMARY KEY (Key) - ) WITH ( - STORE = COLUMN - ); - )", TTxControl::NoTx()).GetValueSync(); - UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); - - auto insertResult = client.ExecuteQuery(R"( - INSERT INTO `/Root/ColumnTable` (Key, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, 
StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, JsonValue, YsonValue, JsonDocumentValue) VALUES - (42, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), "[12]", "[13]", JsonDocument("[14]")); - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); -} - -void AssertArrowValueResultsSize(const std::vector& arrowResultSets, const std::vector& valueResultSets) { - UNIT_ASSERT_VALUES_EQUAL_C(arrowResultSets.size(), valueResultSets.size(), "Result sets count mismatch"); - - for (size_t i = 0; i < arrowResultSets.size(); ++i) { - const auto& arrowResultSet = arrowResultSets[i]; - const auto& valueResultSet = valueResultSets[i]; - - UNIT_ASSERT_VALUES_EQUAL_C(TArrowAccessor::Format(arrowResultSet), TResultSet::EFormat::Arrow, "Result set format mismatch"); - UNIT_ASSERT_VALUES_EQUAL_C(TArrowAccessor::Format(valueResultSet), TResultSet::EFormat::Value, "Result set format mismatch"); - - UNIT_ASSERT_VALUES_EQUAL_C(arrowResultSet.RowsCount(), 0, "Rows must be empty for Arrow format of the result set"); - - size_t arrowRowsCount = 0; - - const auto& schema = TArrowAccessor::GetArrowSchema(arrowResultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(arrowResultSet); - - UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty"); - - std::shared_ptr arrowSchema = NArrow::DeserializeSchema(TString(schema)); - - for (const auto& batch : batches) { - auto arrowBatch = NArrow::DeserializeBatch(TString(batch), arrowSchema); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - - arrowRowsCount += arrowBatch->num_rows(); - - UNIT_ASSERT_VALUES_EQUAL_C(arrowBatch->num_columns(), valueResultSet.ColumnsCount(), "Columns count mismatch"); - } - - 
UNIT_ASSERT_VALUES_EQUAL_C(arrowRowsCount, valueResultSet.RowsCount(), "Rows count mismatch"); - } -} - -std::vector> ExecuteAndCombineBatches(TQueryClient& client, const TString& query, bool assertSize = false, ui64 minBatchesCount = 1) { - auto arrowSettings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); - auto arrowResponse = client.ExecuteQuery(query, TTxControl::BeginTx().CommitTx(), arrowSettings).GetValueSync(); - UNIT_ASSERT_C(arrowResponse.IsSuccess(), arrowResponse.GetIssues().ToString()); - - if (assertSize) { - auto valueSettings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); - auto valueResponse = client.ExecuteQuery(query, TTxControl::BeginTx().CommitTx(), valueSettings).GetValueSync(); - UNIT_ASSERT_C(valueResponse.IsSuccess(), valueResponse.GetIssues().ToString()); - AssertArrowValueResultsSize(arrowResponse.GetResultSets(), valueResponse.GetResultSets()); - } - - std::vector> resultBatches; - - for (const auto& resultSet : arrowResponse.GetResultSets()) { - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty"); - UNIT_ASSERT_GE_C(batches.size(), minBatchesCount, "Batches count must be greater than or equal to " + ToString(minBatchesCount)); - - std::vector> arrowBatches; - auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); - - for (const auto& batch : batches) { - auto arrowBatch = NArrow::DeserializeBatch(TString(batch), arrowSchema); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - - arrowBatches.push_back(std::move(arrowBatch)); - } - - auto resultBatch = NArrow::CombineBatches(arrowBatches); - UNIT_ASSERT_C(resultBatch->ValidateFull().ok(), "Batch combine validation failed"); - - resultBatches.push_back(std::move(resultBatch)); - } - - return resultBatches; -} - -std::string 
SerializeToBinaryJsonString(const TStringBuf json) { - const auto binaryJson = std::get(NBinaryJson::SerializeToBinaryJson(json)); - const TStringBuf buffer(binaryJson.Data(), binaryJson.Size()); - return TString(buffer); -} - -void CompareCompressedAndDefaultBatches(TQueryClient& client, std::optional codec, bool assertEqual = false) { - std::shared_ptr schemaCompressedBatch; - TString compressedBatch; - - std::shared_ptr schemaDefaultBatch; - TString defaultBatch; - - { - auto settings = TExecuteQuerySettings() - .Format(TResultSet::EFormat::Arrow) - .ArrowFormatSettings(TArrowFormatSettings() - .CompressionCodec(std::move(codec))); - - auto result = client.ExecuteQuery(R"( - SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - - const auto& schema = TArrowAccessor::GetArrowSchema(result.GetResultSet(0)); - const auto& batches = TArrowAccessor::GetArrowBatches(result.GetResultSet(0)); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - schemaCompressedBatch = NArrow::DeserializeSchema(TString(schema)); - compressedBatch = std::move(batches[0]); - } - { - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); - - auto result = client.ExecuteQuery(R"( - SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - - const auto& schema = TArrowAccessor::GetArrowSchema(result.GetResultSet(0)); - const auto& batches = TArrowAccessor::GetArrowBatches(result.GetResultSet(0)); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - schemaDefaultBatch = NArrow::DeserializeSchema(TString(schema)); - defaultBatch = std::move(batches[0]); - } - - UNIT_ASSERT_VALUES_EQUAL(schemaCompressedBatch->ToString(), schemaDefaultBatch->ToString()); - - // TODO 
[ditimizhev@]: Assert arrow::Codec compression types instead of strings - if (assertEqual) { - UNIT_ASSERT_VALUES_EQUAL(compressedBatch, defaultBatch); - } else { - UNIT_ASSERT_VALUES_UNEQUAL(compressedBatch, defaultBatch); - } - - auto firstArrowBatch = NArrow::DeserializeBatch(compressedBatch, schemaCompressedBatch); - auto secondArrowBatch = NArrow::DeserializeBatch(defaultBatch, schemaDefaultBatch); - - UNIT_ASSERT_C(firstArrowBatch, "First arrow batch must be deserialized"); - UNIT_ASSERT_C(secondArrowBatch, "Second arrow batch must be deserialized"); - - UNIT_ASSERT_C(firstArrowBatch->num_rows() > 0, "Arrow batch must not be empty"); - - UNIT_ASSERT_C(firstArrowBatch->ValidateFull().ok(), "Batch validation failed"); - UNIT_ASSERT_C(secondArrowBatch->ValidateFull().ok(), "Batch validation failed"); - - UNIT_ASSERT_VALUES_EQUAL(firstArrowBatch->ToString(), secondArrowBatch->ToString()); -} - -} // namespace - -Y_UNIT_TEST_SUITE(KqpResultSetFormats) { - /** - * By default, unspecified format is Value for compatibility with previous versions. - */ - Y_UNIT_TEST(DefaultFormat) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - auto result = client.ExecuteQuery(R"( - SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - - auto resultSet = result.GetResultSet(0); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); - - CompareYson(R"([ - [["None"];[7200u];["Tony"]]; - [["None"];[3500u];["Anna"]]; - [["None"];[300u];["Paul"]] - ])", FormatResultSetYson(resultSet)); - } - - /** - * Set Value format explicitly in TExecuteQuerySettings. 
- */ - Y_UNIT_TEST(ValueFormat_Simple) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); - - auto result = client.ExecuteQuery(R"( - SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - - auto resultSet = result.GetResultSet(0); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); - - CompareYson(R"([ - [["None"];[7200u];["Tony"]]; - [["None"];[3500u];["Anna"]]; - [["None"];[300u];["Paul"]] - ])", FormatResultSetYson(resultSet)); - } - - /** - * Small channel buffer size, rows from many ExecuteQueryResponePart parts are filled into a single ResultSet. - */ - Y_UNIT_TEST(ValueFormat_SmallChannelBufferSize) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 100, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); - - auto result = client.ExecuteQuery(R"( - SELECT * FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - - auto resultSet = result.GetResultSet(0); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); - UNIT_ASSERT_VALUES_EQUAL(resultSet.RowsCount(), 200); - UNIT_ASSERT_VALUES_EQUAL(resultSet.ColumnsCount(), 4); - } - - /** - * By default, SchemaInclusionMode is ALWAYS for Value format. 
- */ - Y_UNIT_TEST(ValueFormat_SchemaInclusionMode_Unspecified) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 100, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - size_t count = 0; - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); - UNIT_ASSERT_VALUES_UNEQUAL(resultSet.ColumnsCount(), 0); - - ++count; - } - } - - UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); - } - - /** - * Set SchemaInclusionMode ALWAYS for Value format explicitly in TExecuteQuerySettings. 
- */ - Y_UNIT_TEST(ValueFormat_SchemaInclusionMode_Always) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 100, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value).SchemaInclusionMode(ESchemaInclusionMode::Always); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - size_t count = 0; - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); - UNIT_ASSERT_VALUES_UNEQUAL(resultSet.ColumnsCount(), 0); - - ++count; - } - } - - UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); - } - - /** - * Set SchemaInclusionMode FIRST_ONLY for Value format explicitly in TExecuteQuerySettings. 
- */ - Y_UNIT_TEST(ValueFormat_SchemaInclusionMode_FirstOnly) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 100, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value).SchemaInclusionMode(ESchemaInclusionMode::FirstOnly); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - size_t count = 0; - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); - - if (count == 0) { - UNIT_ASSERT_VALUES_UNEQUAL(resultSet.ColumnsCount(), 0); - } else { - UNIT_ASSERT_VALUES_EQUAL(resultSet.ColumnsCount(), 0); - } - - ++count; - } - } - - UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); - } - - /** - * For Value format, FirstOnly schema inclusion mode is supported for multistatement queries. 
- */ - Y_UNIT_TEST(ValueFormat_SchemaInclusionMode_FirstOnly_Multistatement) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 200, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value).SchemaInclusionMode(ESchemaInclusionMode::FirstOnly); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; - SELECT Key, Data FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - std::unordered_map counts; - - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); - - auto idx = part.GetResultSetIndex(); - - if (counts.find(idx) == counts.end()) { - UNIT_ASSERT_VALUES_UNEQUAL(resultSet.ColumnsCount(), 0); - } else { - UNIT_ASSERT_VALUES_EQUAL(resultSet.ColumnsCount(), 0); - } - - ++counts[idx]; - } - } - - UNIT_ASSERT_C(counts.size() == 2, "Expected 2 result set indexes"); - - for (const auto& [idx, count] : counts) { - UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets for statement with ResultSetIndex = " << idx); - } - } - - /** - * Set Arrow format explicitly in TExecuteQuerySettings. 
- */ - Y_UNIT_TEST(ArrowFormat_Simple) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); - - auto result = client.ExecuteQuery(R"( - SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - - auto resultSet = result.GetResultSet(0); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); - auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); - - UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("Comment", TTypeInfo(NTypeIds::String)), - std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), - std::make_pair("Name", TTypeInfo(NTypeIds::String)) - })); - - builder.AddRow().Add("None").Add(7200).Add("Tony"); - builder.AddRow().Add("None").Add(3500).Add("Anna"); - builder.AddRow().Add("None").Add(300).Add("Paul"); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(arrowBatch->ToString(), expected->ToString()); - } - - Y_UNIT_TEST(ArrowFormat_EmptyBatch) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); - - auto result = client.ExecuteQuery(R"( - SELECT Comment, Amount, Name FROM Test WHERE 
Amount >= 999999; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - - auto resultSet = result.GetResultSet(0); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Expected at least one empty batch"); - - auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); - auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); - - UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("Comment", TTypeInfo(NTypeIds::String)), - std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), - std::make_pair("Name", TTypeInfo(NTypeIds::String)) - })); - - UNIT_ASSERT_C(arrowBatch->num_rows() == 0, "Batch must have 0 rows"); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(arrowBatch->ToString(), expected->ToString()); - } - - /** - * Arrow format is supported for all types of columns. - */ - Y_UNIT_TEST_TWIN(ArrowFormat_AllTypes, isOlap) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); - auto client = kikimr.GetQueryClient(); - - if (isOlap) { - CreateAllTypesColumnTable(client); - } else { - CreateAllTypesRowTable(client); - } - - const TString query = Sprintf(R"( - SELECT * FROM `/Root/%s`; - )", (isOlap) ? "ColumnTable" : "RowTable"); - - Y_UNUSED(ExecuteAndCombineBatches(client, query, /* assertSize */ true)); - } - - /** - * Arrow format is supported for large batches. 
- */ - Y_UNIT_TEST(ArrowFormat_LargeTable) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 10000, 4, 10, 5000, 10); - - const TString query = Sprintf(R"( - SELECT * FROM `/Root/LargeTable`; - )"); - - Y_UNUSED(ExecuteAndCombineBatches(client, query, /* assertSize */ true)); - } - - /** - * Arrow format is supported for large batches with LIMIT. - */ - Y_UNIT_TEST(ArrowFormat_LargeTable_Limit) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 10000, 4, 10, 5000, 10); - - const TString query = Sprintf(R"( - SELECT * FROM `/Root/LargeTable` LIMIT 70000; - )"); - - Y_UNUSED(ExecuteAndCombineBatches(client, query, /* assertSize */ true)); - } - - /** - * Arrow format is supported for returning. - */ - Y_UNIT_TEST_TWIN(ArrowFormat_Returning, isOlap) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); - auto client = kikimr.GetQueryClient(); - - TString query; - - if (isOlap) { - CreateAllTypesColumnTable(client); - query = R"( - UPSERT INTO `/Root/ColumnTable` (Key, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, JsonValue, YsonValue, JsonDocumentValue) VALUES - (43, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), "[12]", "[13]", JsonDocument("[14]")), - (44, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), "[12]", "[13]", JsonDocument("[14]")), - (45, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), 
Timestamp("2009-09-09T09:09:09.09Z"), "[12]", "[13]", JsonDocument("[14]")) - RETURNING *; - )"; - } else { - CreateAllTypesRowTable(client); - query = R"( - UPSERT INTO `/Root/RowTable` (Key, BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue, Int32NotNullValue) VALUES - (43, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), 123), - (44, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), 123), - (45, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), 123) - RETURNING *; - )"; - } - - Y_UNUSED(ExecuteAndCombineBatches(client, query, /* assertSize */ true)); - } - - /** - * Check different orders of columns in SELECT with Arrow format. 
- */ - Y_UNIT_TEST(ArrowFormat_ColumnOrder) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - { - auto batches = ExecuteAndCombineBatches(client, R"( - SELECT Name, Amount FROM Test WHERE Group = 2; - )", /* assertSize */ true); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("Name", TTypeInfo(NTypeIds::String)), - std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)) - })); - - builder.AddRow().Add("Tony").Add(7200); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); - } - { - auto batches = ExecuteAndCombineBatches(client, R"( - SELECT Amount, Name FROM Test WHERE Group = 2; - )", /* assertSize */ true); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), - std::make_pair("Name", TTypeInfo(NTypeIds::String)) - })); - - builder.AddRow().Add(7200).Add("Tony"); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); - } - { - auto batches = ExecuteAndCombineBatches(client, R"( - SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; - )", /* assertSize */ true); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("Comment", TTypeInfo(NTypeIds::String)), - std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), - std::make_pair("Name", TTypeInfo(NTypeIds::String)) - })); - - builder.AddRow().Add("None").Add(7200).Add("Tony"); - builder.AddRow().Add("None").Add(3500).Add("Anna"); - builder.AddRow().Add("None").Add(300).Add("Paul"); - - auto expected = builder.BuildArrow(); - 
UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); - } - } - - /** - * Small channel buffer size, data bytes and schema from many ExecuteQueryResponePart parts are filled into a single ResultSet as a std::vector with a single schema. - */ - Y_UNIT_TEST(ArrowFormat_SmallChannelBufferSize) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 100, 2, 2, 10, 2); - - auto batches = ExecuteAndCombineBatches(client, R"( - SELECT * FROM LargeTable; - )", /* assertSize */ true, /* minBatchesCount */ 2); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - UNIT_ASSERT_VALUES_EQUAL(batches.front()->num_rows(), 200); - UNIT_ASSERT_VALUES_EQUAL(batches.front()->num_columns(), 4); - } - - /** - * These YQL types are supported for Arrow format as arithmetic types: - */ - Y_UNIT_TEST(ArrowFormat_Types_Arithmetic) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); - auto client = kikimr.GetQueryClient(); - - { - auto result = client.ExecuteQuery(R"( - CREATE TABLE ArithmeticTypesTable ( - BoolValue Bool, - Int8Value Int8, - Uint8Value Uint8 NOT NULL, - Int16Value Int16, - Uint16Value Uint16 NOT NULL, - Int32Value Int32, - Uint32Value Uint32 NOT NULL, - Int64Value Int64, - Uint64Value Uint64 NOT NULL, - FloatValue Float, - DoubleValue Double NOT NULL, - DecimalValue Decimal(22, 2), - PRIMARY KEY (BoolValue) - ); - )", TTxControl::NoTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - } - { - auto result = client.ExecuteQuery(R"( - INSERT INTO ArithmeticTypesTable (BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, DecimalValue) VALUES - (NULL, -1, 1, -2, 2, -3, 3, -4, 4, CAST(5.0 AS Float), 6.0, CAST("7.77" AS Decimal(22, 2))), - (true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(5.0 AS Float), 6.0, CAST("7.77" AS Decimal(22, 
2))), - (false, -1, 1, -2, 2, -3, 3, -4, 4, CAST(5.0 AS Float), 6.0, CAST("7.77" AS Decimal(22, 2))); - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - } - { - auto batches = ExecuteAndCombineBatches(client, R"( - SELECT BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, DecimalValue - FROM ArithmeticTypesTable ORDER BY BoolValue; - )", /* assertSize */ true); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("BoolValue", TTypeInfo(NTypeIds::Bool)), - std::make_pair("Int8Value", TTypeInfo(NTypeIds::Int8)), - std::make_pair("Uint8Value", TTypeInfo(NTypeIds::Uint8)), - std::make_pair("Int16Value", TTypeInfo(NTypeIds::Int16)), - std::make_pair("Uint16Value", TTypeInfo(NTypeIds::Uint16)), - std::make_pair("Int32Value", TTypeInfo(NTypeIds::Int32)), - std::make_pair("Uint32Value", TTypeInfo(NTypeIds::Uint32)), - std::make_pair("Int64Value", TTypeInfo(NTypeIds::Int64)), - std::make_pair("Uint64Value", TTypeInfo(NTypeIds::Uint64)), - std::make_pair("FloatValue", TTypeInfo(NTypeIds::Float)), - std::make_pair("DoubleValue", TTypeInfo(NTypeIds::Double)), - std::make_pair("DecimalValue", TTypeInfo(NScheme::TDecimalType(22, 2))) - })); - - builder.AddRow().AddNull().Add(-1).Add(1).Add(-2).Add(2).Add(-3).Add(3).Add(-4).Add(4).Add(5.0).Add(6.0).Add(TDecimalValue("7.77", 22, 2)); - builder.AddRow().Add(false).Add(-1).Add(1).Add(-2).Add(2).Add(-3).Add(3).Add(-4).Add(4).Add(5.0).Add(6.0).Add(TDecimalValue("7.77", 22, 2)); - builder.AddRow().Add(true).Add(-1).Add(1).Add(-2).Add(2).Add(-3).Add(3).Add(-4).Add(4).Add(5.0).Add(6.0).Add(TDecimalValue("7.77", 22, 2)); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); - } - } - - /** - * These YQL types are 
supported for Arrow format as string types: - */ - Y_UNIT_TEST(ArrowFormat_Types_String) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); - auto client = kikimr.GetQueryClient(); - - { - auto result = client.ExecuteQuery(R"( - CREATE TABLE StringTypesTable ( - Utf8Value Utf8, - JsonValue Json, - Utf8NotNullValue Utf8 NOT NULL, - JsonNotNullValue Json NOT NULL, - PRIMARY KEY (Utf8Value) - ); - )", TTxControl::NoTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - } - { - auto result = client.ExecuteQuery(R"( - INSERT INTO StringTypesTable (Utf8Value, JsonValue, Utf8NotNullValue, JsonNotNullValue) VALUES - ("John", "[1]", "John", "[2]"), - (NULL, "[]", "Maria", "[3]"), - ("Leo", NULL, "Leo", "[4]"), - ("Michael", "[5]", "Michael", "[6]"); - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - } - { - auto batches = ExecuteAndCombineBatches(client, R"( - SELECT Utf8Value, JsonValue, Utf8NotNullValue, JsonNotNullValue - FROM StringTypesTable ORDER BY Utf8Value; - )", /* assertSize */ true); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("Utf8Value", TTypeInfo(NTypeIds::Utf8)), - std::make_pair("JsonValue", TTypeInfo(NTypeIds::Json)), - std::make_pair("Utf8NotNullValue", TTypeInfo(NTypeIds::Utf8)), - std::make_pair("JsonNotNullValue", TTypeInfo(NTypeIds::Json)) - })); - - builder.AddRow().AddNull().Add("[]").Add("Maria").Add("[3]"); - builder.AddRow().Add("John").Add("[1]").Add("John").Add("[2]"); - builder.AddRow().Add("Leo").AddNull().Add("Leo").Add("[4]"); - builder.AddRow().Add("Michael").Add("[5]").Add("Michael").Add("[6]"); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); - } - } - - /** - * These YQL types are supported for Arrow format as binary types: - */ 
- Y_UNIT_TEST(ArrowFormat_Types_Binary) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); - auto client = kikimr.GetQueryClient(); - - { - auto result = client.ExecuteQuery(R"( - CREATE TABLE BinaryTypesTable ( - StringValue String, - YsonValue Yson, - DyNumberValue DyNumber, - JsonDocumentValue JsonDocument, - StringNotNullValue String NOT NULL, - YsonNotNullValue Yson NOT NULL, - JsonDocumentNotNullValue JsonDocument NOT NULL, - DyNumberNotNullValue DyNumber NOT NULL, - PRIMARY KEY (StringValue) - ); - )", TTxControl::NoTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - } - { - auto result = client.ExecuteQuery(R"( - INSERT INTO BinaryTypesTable (StringValue, YsonValue, DyNumberValue, JsonDocumentValue, StringNotNullValue, YsonNotNullValue, JsonDocumentNotNullValue, DyNumberNotNullValue) VALUES - ("John", "[1]", DyNumber("1.0"), JsonDocument("{\"a\": 1}"), "Mark", "[2]", JsonDocument("{\"b\": 2}"), DyNumber("4.0")), - (NULL, "[4]", NULL, NULL, "Maria", "[5]", JsonDocument("[6]"), DyNumber("7.0")), - ("Mark", NULL, NULL, NULL, "Michael", "[7]", JsonDocument("[8]"), DyNumber("9.0")), - ("Leo", "[10]", DyNumber("11.0"), JsonDocument("[12]"), "Maria", "[13]", JsonDocument("[14]"), DyNumber("15.0")); - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - } - { - auto batches = ExecuteAndCombineBatches(client, R"( - SELECT StringValue, YsonValue, DyNumberValue, JsonDocumentValue, StringNotNullValue, YsonNotNullValue, JsonDocumentNotNullValue, DyNumberNotNullValue - FROM BinaryTypesTable ORDER BY StringValue; - )", /* assertSize */ true); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("StringValue", TTypeInfo(NTypeIds::String)), - std::make_pair("YsonValue", TTypeInfo(NTypeIds::Yson)), - std::make_pair("DyNumberValue", 
TTypeInfo(NTypeIds::DyNumber)), - std::make_pair("JsonDocumentValue", TTypeInfo(NTypeIds::JsonDocument)), - std::make_pair("StringNotNullValue", TTypeInfo(NTypeIds::String)), - std::make_pair("YsonNotNullValue", TTypeInfo(NTypeIds::Yson)), - std::make_pair("JsonDocumentNotNullValue", TTypeInfo(NTypeIds::JsonDocument)), - std::make_pair("DyNumberNotNullValue", TTypeInfo(NTypeIds::DyNumber)) - })); - - builder.AddRow().AddNull().Add("[4]").AddNull().AddNull().Add("Maria").Add("[5]").Add(SerializeToBinaryJsonString("[6]")).Add(NDyNumber::ParseDyNumberString("7.0")->c_str()); - builder.AddRow().Add("John").Add("[1]").Add(NDyNumber::ParseDyNumberString("1.0")->c_str()).Add(SerializeToBinaryJsonString("{\"a\": 1}")).Add("Mark").Add("[2]").Add(SerializeToBinaryJsonString("{\"b\": 2}")).Add(NDyNumber::ParseDyNumberString("4.0")->c_str()); - builder.AddRow().Add("Leo").Add("[10]").Add(NDyNumber::ParseDyNumberString("11.0")->c_str()).Add(SerializeToBinaryJsonString("[12]")).Add("Maria").Add("[13]").Add(SerializeToBinaryJsonString("[14]")).Add(NDyNumber::ParseDyNumberString("15.0")->c_str()); - builder.AddRow().Add("Mark").AddNull().AddNull().AddNull().Add("Michael").Add("[7]").Add(SerializeToBinaryJsonString("[8]")).Add(NDyNumber::ParseDyNumberString("9.0")->c_str()); - - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); - } - } - - /** - * These YQL types are supported for Arrow format as integer and time types: - */ - Y_UNIT_TEST(ArrowFormat_Types_Time) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); - auto client = kikimr.GetQueryClient(); - - { - auto result = client.ExecuteQuery(R"( - CREATE TABLE TimeTypesTable ( - DateValue Date, - DatetimeValue Datetime, - TimestampValue Timestamp, - IntervalValue Interval, - PRIMARY KEY (DateValue) - ); - )", TTxControl::NoTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - } - { - auto result = 
client.ExecuteQuery(R"( - INSERT INTO TimeTypesTable (DateValue, DatetimeValue, TimestampValue, IntervalValue) VALUES - (Date("2001-01-01"), Datetime("2002-02-02T02:02:02Z"), Timestamp("2003-03-03T03:03:03Z"), Interval("P7D")); - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - } - { - auto batches = ExecuteAndCombineBatches(client, R"( - SELECT DateValue, DatetimeValue, TimestampValue, IntervalValue - FROM TimeTypesTable ORDER BY DateValue; - )", /* assertSize */ true); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("DateValue", TTypeInfo(NTypeIds::Date)), - std::make_pair("DatetimeValue", TTypeInfo(NTypeIds::Datetime)), - std::make_pair("TimestampValue", TTypeInfo(NTypeIds::Timestamp)), - std::make_pair("IntervalValue", TTypeInfo(NTypeIds::Interval)) - })); - - builder.AddRow().Add(11323).Add(1012615322).Add(1046660583000000).Add(604800000000); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); - } - } - - /** - * Arrow format is supported for compression. - * By default, unspecified compression codec is None (without compression). - */ - Y_UNIT_TEST(ArrowFormat_Compression_None) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - CompareCompressedAndDefaultBatches(client, std::nullopt, /* assertEqual */ true); - - CompareCompressedAndDefaultBatches(client, TArrowFormatSettings::TCompressionCodec().Type(TArrowFormatSettings::TCompressionCodec::EType::None), /* assertEqual */ true); - } - - /** - * Arrow format is supported for compression by ZSTD codec. - * Compression level is supported. 
- */ - Y_UNIT_TEST(ArrowFormat_Compression_ZSTD) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - CompareCompressedAndDefaultBatches(client, TArrowFormatSettings::TCompressionCodec().Type(TArrowFormatSettings::TCompressionCodec::EType::Zstd).Level(12), /* assertEqual */ false); - } - - /** - * Arrow format is supported for compression by LZ4_FRAME codec. - * Compression level is not supported. - */ - Y_UNIT_TEST(ArrowFormat_Compression_LZ4_FRAME) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - CompareCompressedAndDefaultBatches(client, TArrowFormatSettings::TCompressionCodec().Type(TArrowFormatSettings::TCompressionCodec::EType::Lz4Frame), /* assertEqual */ false); - - { - auto settings = TExecuteQuerySettings() - .Format(TResultSet::EFormat::Arrow) - .ArrowFormatSettings(TArrowFormatSettings() - .CompressionCodec(TArrowFormatSettings::TCompressionCodec() - .Type(TArrowFormatSettings::TCompressionCodec::EType::Lz4Frame) - .Level(12))); - - auto result = client.ExecuteQuery(R"( - SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - - UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::INTERNAL_ERROR); - UNIT_ASSERT_STRING_CONTAINS(result.GetIssues().ToString(), "Codec 'lz4' doesn't support setting a compression level"); - } - } - - /** - * Arrow batches are returned for different result set indexes. 
- */ - Y_UNIT_TEST(ArrowFormat_Multistatement) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); - auto client = kikimr.GetQueryClient(); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); - - auto result = client.ExecuteQuery(R"( - SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; - SELECT Key, Value FROM KeyValue ORDER BY Key; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); - - UNIT_ASSERT_VALUES_EQUAL(result.GetResultSets().size(), 2); - - { - auto resultSet = result.GetResultSet(0); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); - auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); - - UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("Comment", TTypeInfo(NTypeIds::String)), - std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), - std::make_pair("Name", TTypeInfo(NTypeIds::String)) - })); - - builder.AddRow().Add("None").Add(7200).Add("Tony"); - builder.AddRow().Add("None").Add(3500).Add("Anna"); - builder.AddRow().Add("None").Add(300).Add("Paul"); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(arrowBatch->ToString(), expected->ToString()); - } - { - auto resultSet = result.GetResultSet(1); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = 
TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); - auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); - - UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - - NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ - std::make_pair("Key", TTypeInfo(NTypeIds::Uint64)), - std::make_pair("Value", TTypeInfo(NTypeIds::String)) - })); - - builder.AddRow().Add(1).Add("One"); - builder.AddRow().Add(2).Add("Two"); - - auto expected = builder.BuildArrow(); - UNIT_ASSERT_VALUES_EQUAL(arrowBatch->ToString(), expected->ToString()); - } - } - - /** - * By default, SchemaInclusionMode is ALWAYS for Arrow format. - */ - Y_UNIT_TEST(ArrowFormat_SchemaInclusionMode_Unspecified) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 100, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - std::shared_ptr arrowSchema; - - size_t count = 0; - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = 
TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - // With Arrow format, the result set contains a YQL schema and an Arrow record batch schema - UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty"); - UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set"); - - auto curSchema = NArrow::DeserializeSchema(TString(schema)); - UNIT_ASSERT_C(curSchema, "Schema must be deserialized"); - - if (arrowSchema) { - UNIT_ASSERT_VALUES_EQUAL(arrowSchema->ToString(), curSchema->ToString()); - } else { - arrowSchema = curSchema; - } - - auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); - - ++count; - } - } - - UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); - } - - /** - * Set SchemaInclusionMode ALWAYS for Arrow format explicitly in TExecuteQuerySettings. 
- */ - Y_UNIT_TEST(ArrowFormat_SchemaInclusionMode_Always) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 100, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow).SchemaInclusionMode(ESchemaInclusionMode::Always); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - std::shared_ptr arrowSchema; - - size_t count = 0; - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - // With Arrow format, the result set contains a YQL schema and an Arrow record batch schema - UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty"); - UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set"); - - auto curSchema = NArrow::DeserializeSchema(TString(schema)); - UNIT_ASSERT_C(curSchema, "Schema must be deserialized"); - - if (arrowSchema) { - UNIT_ASSERT_VALUES_EQUAL(arrowSchema->ToString(), curSchema->ToString()); - } else { - arrowSchema = curSchema; - } - - auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); - - ++count; - } - } - - 
UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); - } - - /** - * Set SchemaInclusionMode FIRST_ONLY for Arrow format explicitly in TExecuteQuerySettings. - */ - Y_UNIT_TEST(ArrowFormat_SchemaInclusionMode_FirstOnly) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 100, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow).SchemaInclusionMode(ESchemaInclusionMode::FirstOnly); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - std::shared_ptr arrowSchema; - - size_t count = 0; - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - if (count == 0) { - // With Arrow format, the result set contains a YQL schema and an Arrow record batch schema - UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set"); - UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty for the first result set"); - - arrowSchema = NArrow::DeserializeSchema(TString(schema)); - - UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); - } else { - UNIT_ASSERT_VALUES_EQUAL_C(resultSet.ColumnsCount(), 0, "Columns count must be empty for the rest result sets"); - UNIT_ASSERT_C(schema.empty(), "Schema must be empty for the rest result sets"); - } - - auto arrowBatch = 
NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); - - ++count; - } - } - - UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); - } - - /** - * For Arrow format, FirstOnly schema inclusion mode is supported for multistatement queries. - */ - Y_UNIT_TEST(ArrowFormat_SchemaInclusionMode_FirstOnly_Multistatement) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 200, 2, 2, 10, 2); - - auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow).SchemaInclusionMode(ESchemaInclusionMode::FirstOnly); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; - SELECT Key, Data FROM LargeTable; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - std::unordered_map> arrowSchemas; - std::unordered_map counts; - - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - auto idx = part.GetResultSetIndex(); - - if (arrowSchemas.find(idx) == arrowSchemas.end()) { - // The first result set of each statement contains schemas. 
- UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set of the statement"); - UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty for the first result set of the statement"); - - arrowSchemas[idx] = NArrow::DeserializeSchema(TString(schema)); - - UNIT_ASSERT_C(arrowSchemas[idx], "Schema must be deserialized"); - } else { - UNIT_ASSERT_VALUES_EQUAL_C(resultSet.ColumnsCount(), 0, "Columns count must be empty for the rest result sets of the statement"); - UNIT_ASSERT_C(schema.empty(), "Schema must be empty for the rest result sets of the statement"); - } - - auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchemas[idx]); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); - - ++counts[idx]; - } - } - - UNIT_ASSERT_C(counts.size() == 2, "Expected 2 result set indexes"); - - for (const auto& [idx, count] : counts) { - UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets for statement with ResultSetIndex = " << idx); - } - } - - /** - * Small stress test for Arrow format: - * - 3 statements - * - 10 shards, 1000 rows per shard - * - ZSTD compression with level 10 - * - SchemaInclusionMode is FIRST_ONLY - * - ChannelBufferSize is 1KB - */ - Y_UNIT_TEST(ArrowFormat_Stress) { - auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); - auto client = kikimr.GetQueryClient(); - - CreateLargeTable(kikimr, 1000, 4, 4, 100, 10); - - auto settings = TExecuteQuerySettings() - .Format(TResultSet::EFormat::Arrow) - .SchemaInclusionMode(ESchemaInclusionMode::FirstOnly) - .ArrowFormatSettings(TArrowFormatSettings() - .CompressionCodec(TArrowFormatSettings::TCompressionCodec() - .Type(TArrowFormatSettings::TCompressionCodec::EType::Zstd) - .Level(10))); - - auto it = client.StreamExecuteQuery(R"( - SELECT * FROM LargeTable; 
- UPDATE LargeTable SET Data = Data + 1 WHERE Key % 2 = 1 RETURNING Data; - SELECT DataText FROM LargeTable WHERE Key % 2 = 0; - )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - std::unordered_map> arrowSchemas; - std::unordered_map counts; - - for (;;) { - auto part = it.ReadNext().GetValueSync(); - if (!part.IsSuccess()) { - UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); - break; - } - - if (part.HasResultSet()) { - auto resultSet = part.ExtractResultSet(); - UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); - - const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); - const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); - - UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); - - auto idx = part.GetResultSetIndex(); - - if (arrowSchemas.find(idx) == arrowSchemas.end()) { - // The first result set of each statement contains schemas. 
- UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set of the statement"); - UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty for the first result set of the statement"); - - arrowSchemas[idx] = NArrow::DeserializeSchema(TString(schema)); - - UNIT_ASSERT_C(arrowSchemas[idx], "Schema must be deserialized"); - } else { - UNIT_ASSERT_VALUES_EQUAL_C(resultSet.ColumnsCount(), 0, "Columns count must be empty for the rest result sets of the statement"); - UNIT_ASSERT_C(schema.empty(), "Schema must be empty for the rest result sets of the statement"); - } - - auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchemas[idx]); - UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); - UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); - UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); - - ++counts[idx]; - } - } - - UNIT_ASSERT_C(counts.size() == 3, "Expected 3 result set indexes"); - - for (const auto& [idx, count] : counts) { - UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets for statement with ResultSetIndex = " << idx); - } - } -} - -} // namespace NKikimr::NKqp diff --git a/ydb/core/kqp/ut/arrow/kqp_result_set_formats.cpp b/ydb/core/kqp/ut/arrow/kqp_result_set_formats.cpp new file mode 100644 index 00000000000..92414804a2b --- /dev/null +++ b/ydb/core/kqp/ut/arrow/kqp_result_set_formats.cpp @@ -0,0 +1,2037 @@ +#include + +#include +#include +#include + +#include + +#include + +#include +#include + +namespace NKikimr::NKqp { + +using namespace NYdb; +using namespace NYdb::NQuery; +using TTypeInfo = NScheme::TTypeInfo; +namespace NTypeIds = NScheme::NTypeIds; + +namespace { + +TKikimrRunner CreateKikimrRunner(bool withSampleTables, ui64 channelBufferSize = 8_MB) { + NKikimrConfig::TFeatureFlags featureFlags; + featureFlags.SetEnableArrowResultSetFormat(true); + + NKikimrConfig::TAppConfig appConfig; + 
appConfig.MutableTableServiceConfig()->SetEnableOlapSink(true); + appConfig.MutableTableServiceConfig()->MutableResourceManager()->SetChannelBufferSize(channelBufferSize); + + auto settings = TKikimrSettings(appConfig).SetFeatureFlags(featureFlags).SetWithSampleTables(withSampleTables); + return TKikimrRunner(settings); +} + +void CreateAllTypesRowTable(TQueryClient& client) { + auto createResult = client.ExecuteQuery(R"( + CREATE TABLE `/Root/RowTable` ( + Key Uint64, + BoolValue Bool, + Int8Value Int8, + Uint8Value Uint8, + Int16Value Int16, + Uint16Value Uint16, + Int32Value Int32, + Uint32Value Uint32, + Int64Value Int64, + Uint64Value Uint64, + FloatValue Float, + DoubleValue Double, + StringValue String, + Utf8Value Utf8, + DateValue Date, + DatetimeValue Datetime, + TimestampValue Timestamp, + IntervalValue Interval, + DecimalValue Decimal(22,9), + JsonValue Json, + YsonValue Yson, + JsonDocumentValue JsonDocument, + DyNumberValue DyNumber, + UuidValue Uuid, + Int32NotNullValue Int32 NOT NULL, + PRIMARY KEY (Key) + ); + )", TTxControl::NoTx()).GetValueSync(); + UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); + + auto insertResult = client.ExecuteQuery(R"( + INSERT INTO `/Root/RowTable` (Key, BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue, UuidValue, Int32NotNullValue) VALUES + (42, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe"), 123); + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + 
UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); +} + +void CreateAllTypesColumnTable(TQueryClient& client) { + auto createResult = client.ExecuteQuery(R"( + CREATE TABLE `/Root/ColumnTable` ( + Key Uint64 NOT NULL, + Int8Value Int8, + Uint8Value Uint8, + Int16Value Int16, + Uint16Value Uint16, + Int32Value Int32, + Uint32Value Uint32, + Int64Value Int64, + Uint64Value Uint64, + FloatValue Float, + DoubleValue Double, + StringValue String, + Utf8Value Utf8, + DateValue Date, + DatetimeValue Datetime, + TimestampValue Timestamp, + JsonValue Json, + YsonValue Yson, + JsonDocumentValue JsonDocument, + PRIMARY KEY (Key) + ) WITH ( + STORE = COLUMN + ); + )", TTxControl::NoTx()).GetValueSync(); + UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); + + auto insertResult = client.ExecuteQuery(R"( + INSERT INTO `/Root/ColumnTable` (Key, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, JsonValue, YsonValue, JsonDocumentValue) VALUES + (42, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), "[12]", "[13]", JsonDocument("[14]")); + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); +} + +void AssertArrowValueResultsSize(const std::vector& arrowResultSets, const std::vector& valueResultSets) { + UNIT_ASSERT_VALUES_EQUAL_C(arrowResultSets.size(), valueResultSets.size(), "Result sets count mismatch"); + + for (size_t i = 0; i < arrowResultSets.size(); ++i) { + const auto& arrowResultSet = arrowResultSets[i]; + const auto& valueResultSet = valueResultSets[i]; + + UNIT_ASSERT_VALUES_EQUAL_C(TArrowAccessor::Format(arrowResultSet), TResultSet::EFormat::Arrow, "Result set format mismatch"); + 
UNIT_ASSERT_VALUES_EQUAL_C(TArrowAccessor::Format(valueResultSet), TResultSet::EFormat::Value, "Result set format mismatch"); + + UNIT_ASSERT_VALUES_EQUAL_C(arrowResultSet.RowsCount(), 0, "Rows must be empty for Arrow format of the result set"); + + size_t arrowRowsCount = 0; + + const auto& schema = TArrowAccessor::GetArrowSchema(arrowResultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(arrowResultSet); + + UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty"); + + std::shared_ptr arrowSchema = NArrow::DeserializeSchema(TString(schema)); + + for (const auto& batch : batches) { + auto arrowBatch = NArrow::DeserializeBatch(TString(batch), arrowSchema); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + + arrowRowsCount += arrowBatch->num_rows(); + + UNIT_ASSERT_VALUES_EQUAL_C(arrowBatch->num_columns(), valueResultSet.ColumnsCount(), "Columns count mismatch"); + } + + UNIT_ASSERT_VALUES_EQUAL_C(arrowRowsCount, valueResultSet.RowsCount(), "Rows count mismatch"); + } +} + +std::vector> ExecuteAndCombineBatches(TQueryClient& client, const TString& query, bool assertSize = false, ui64 minBatchesCount = 1, TParams params = TParamsBuilder().Build()) { + auto arrowSettings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); + auto arrowResponse = client.ExecuteQuery(query, TTxControl::BeginTx().CommitTx(), params, arrowSettings).GetValueSync(); + UNIT_ASSERT_C(arrowResponse.IsSuccess(), arrowResponse.GetIssues().ToString()); + + if (assertSize) { + auto valueSettings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); + auto valueResponse = client.ExecuteQuery(query, TTxControl::BeginTx().CommitTx(), params, valueSettings).GetValueSync(); + UNIT_ASSERT_C(valueResponse.IsSuccess(), valueResponse.GetIssues().ToString()); + AssertArrowValueResultsSize(arrowResponse.GetResultSets(), valueResponse.GetResultSets()); + } + + std::vector> resultBatches; + + 
for (const auto& resultSet : arrowResponse.GetResultSets()) { + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty"); + UNIT_ASSERT_GE_C(batches.size(), minBatchesCount, "Batches count must be greater than or equal to " + ToString(minBatchesCount)); + + std::vector> arrowBatches; + auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); + + for (const auto& batch : batches) { + auto arrowBatch = NArrow::DeserializeBatch(TString(batch), arrowSchema); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + + arrowBatches.push_back(std::move(arrowBatch)); + } + + auto resultBatch = NArrow::CombineBatches(arrowBatches); + UNIT_ASSERT_C(resultBatch->ValidateFull().ok(), "Batch combine validation failed"); + + resultBatches.push_back(std::move(resultBatch)); + } + + return resultBatches; +} + +std::string SerializeToBinaryJsonString(const TStringBuf json) { + const auto binaryJson = std::get(NBinaryJson::SerializeToBinaryJson(json)); + const TStringBuf buffer(binaryJson.Data(), binaryJson.Size()); + return TString(buffer); +} + +void CompareCompressedAndDefaultBatches(TQueryClient& client, std::optional codec, bool assertEqual = false) { + std::shared_ptr schemaCompressedBatch; + TString compressedBatch; + + std::shared_ptr schemaDefaultBatch; + TString defaultBatch; + + { + auto settings = TExecuteQuerySettings() + .Format(TResultSet::EFormat::Arrow) + .ArrowFormatSettings(TArrowFormatSettings() + .CompressionCodec(std::move(codec))); + + auto result = client.ExecuteQuery(R"( + SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + const auto& schema = 
TArrowAccessor::GetArrowSchema(result.GetResultSet(0)); + const auto& batches = TArrowAccessor::GetArrowBatches(result.GetResultSet(0)); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + schemaCompressedBatch = NArrow::DeserializeSchema(TString(schema)); + compressedBatch = std::move(batches[0]); + } + { + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); + + auto result = client.ExecuteQuery(R"( + SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + const auto& schema = TArrowAccessor::GetArrowSchema(result.GetResultSet(0)); + const auto& batches = TArrowAccessor::GetArrowBatches(result.GetResultSet(0)); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + schemaDefaultBatch = NArrow::DeserializeSchema(TString(schema)); + defaultBatch = std::move(batches[0]); + } + + UNIT_ASSERT_VALUES_EQUAL(schemaCompressedBatch->ToString(), schemaDefaultBatch->ToString()); + + // TODO [ditimizhev@]: Assert arrow::Codec compression types instead of strings + if (assertEqual) { + UNIT_ASSERT_VALUES_EQUAL(compressedBatch, defaultBatch); + } else { + UNIT_ASSERT_VALUES_UNEQUAL(compressedBatch, defaultBatch); + } + + auto firstArrowBatch = NArrow::DeserializeBatch(compressedBatch, schemaCompressedBatch); + auto secondArrowBatch = NArrow::DeserializeBatch(defaultBatch, schemaDefaultBatch); + + UNIT_ASSERT_C(firstArrowBatch, "First arrow batch must be deserialized"); + UNIT_ASSERT_C(secondArrowBatch, "Second arrow batch must be deserialized"); + + UNIT_ASSERT_C(firstArrowBatch->num_rows() > 0, "Arrow batch must not be empty"); + + UNIT_ASSERT_C(firstArrowBatch->ValidateFull().ok(), "Batch validation failed"); + UNIT_ASSERT_C(secondArrowBatch->ValidateFull().ok(), "Batch validation failed"); + + UNIT_ASSERT_VALUES_EQUAL(firstArrowBatch->ToString(), secondArrowBatch->ToString()); +} 
+ +void ValidateOptionalColumn(const std::shared_ptr& array, int depth, bool isVariant) { + if (depth == 0 && isVariant) { + UNIT_ASSERT_C(array->type()->id() == arrow::Type::DENSE_UNION, "Column type must be arrow::Type::DENSE_UNION"); + return; + } + + if (depth == 1 && !isVariant) { + return; + } + + UNIT_ASSERT_C(array->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + auto structArray = static_pointer_cast(array); + UNIT_ASSERT_C(structArray->num_fields() == 1, "Struct array must have 1 field"); + + auto innerArray = structArray->field(0); + ValidateOptionalColumn(innerArray, depth - 1, isVariant); +} + +} // namespace + +Y_UNIT_TEST_SUITE(KqpResultSetFormats) { + /** + * By default, unspecified format is Value for compatibility with previous versions. + */ + Y_UNIT_TEST(DefaultFormat) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + auto result = client.ExecuteQuery(R"( + SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + auto resultSet = result.GetResultSet(0); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); + + CompareYson(R"([ + [["None"];[7200u];["Tony"]]; + [["None"];[3500u];["Anna"]]; + [["None"];[300u];["Paul"]] + ])", FormatResultSetYson(resultSet)); + } + + /** + * Set Value format explicitly in TExecuteQuerySettings. 
+ */ + Y_UNIT_TEST(ValueFormat_Simple) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); + + auto result = client.ExecuteQuery(R"( + SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + auto resultSet = result.GetResultSet(0); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); + + CompareYson(R"([ + [["None"];[7200u];["Tony"]]; + [["None"];[3500u];["Anna"]]; + [["None"];[300u];["Paul"]] + ])", FormatResultSetYson(resultSet)); + } + + /** + * Small channel buffer size, rows from many ExecuteQueryResponePart parts are filled into a single ResultSet. + */ + Y_UNIT_TEST(ValueFormat_SmallChannelBufferSize) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 100, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); + + auto result = client.ExecuteQuery(R"( + SELECT * FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + auto resultSet = result.GetResultSet(0); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); + UNIT_ASSERT_VALUES_EQUAL(resultSet.RowsCount(), 200); + UNIT_ASSERT_VALUES_EQUAL(resultSet.ColumnsCount(), 4); + } + + /** + * By default, SchemaInclusionMode is ALWAYS for Value format. 
+ */ + Y_UNIT_TEST(ValueFormat_SchemaInclusionMode_Unspecified) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 100, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + size_t count = 0; + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); + UNIT_ASSERT_VALUES_UNEQUAL(resultSet.ColumnsCount(), 0); + + ++count; + } + } + + UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); + } + + /** + * Set SchemaInclusionMode ALWAYS for Value format explicitly in TExecuteQuerySettings. 
+ */ + Y_UNIT_TEST(ValueFormat_SchemaInclusionMode_Always) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 100, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value).SchemaInclusionMode(ESchemaInclusionMode::Always); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + size_t count = 0; + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); + UNIT_ASSERT_VALUES_UNEQUAL(resultSet.ColumnsCount(), 0); + + ++count; + } + } + + UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); + } + + /** + * Set SchemaInclusionMode FIRST_ONLY for Value format explicitly in TExecuteQuerySettings. 
+ */ + Y_UNIT_TEST(ValueFormat_SchemaInclusionMode_FirstOnly) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 100, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value).SchemaInclusionMode(ESchemaInclusionMode::FirstOnly); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + size_t count = 0; + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); + + if (count == 0) { + UNIT_ASSERT_VALUES_UNEQUAL(resultSet.ColumnsCount(), 0); + } else { + UNIT_ASSERT_VALUES_EQUAL(resultSet.ColumnsCount(), 0); + } + + ++count; + } + } + + UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); + } + + /** + * For Value format, FirstOnly schema inclusion mode is supported for multistatement queries. 
+ */ + Y_UNIT_TEST(ValueFormat_SchemaInclusionMode_FirstOnly_Multistatement) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 200, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Value).SchemaInclusionMode(ESchemaInclusionMode::FirstOnly); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; + SELECT Key, Data FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + std::unordered_map counts; + + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Value); + + auto idx = part.GetResultSetIndex(); + + if (counts.find(idx) == counts.end()) { + UNIT_ASSERT_VALUES_UNEQUAL(resultSet.ColumnsCount(), 0); + } else { + UNIT_ASSERT_VALUES_EQUAL(resultSet.ColumnsCount(), 0); + } + + ++counts[idx]; + } + } + + UNIT_ASSERT_C(counts.size() == 2, "Expected 2 result set indexes"); + + for (const auto& [idx, count] : counts) { + UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets for statement with ResultSetIndex = " << idx); + } + } + + /** + * Set Arrow format explicitly in TExecuteQuerySettings. 
+ */ + Y_UNIT_TEST(ArrowFormat_Simple) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); + + auto result = client.ExecuteQuery(R"( + SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + auto resultSet = result.GetResultSet(0); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); + auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); + + UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("Comment", TTypeInfo(NTypeIds::String)), + std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("Name", TTypeInfo(NTypeIds::String)) + })); + + builder.AddRow().Add("None").Add(7200).Add("Tony"); + builder.AddRow().Add("None").Add(3500).Add("Anna"); + builder.AddRow().Add("None").Add(300).Add("Paul"); + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(arrowBatch->ToString(), expected->ToString()); + } + + Y_UNIT_TEST(ArrowFormat_EmptyBatch) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); + + auto result = client.ExecuteQuery(R"( + SELECT Comment, Amount, Name FROM Test WHERE 
Amount >= 999999; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + auto resultSet = result.GetResultSet(0); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Expected at least one empty batch"); + + auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); + auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); + + UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("Comment", TTypeInfo(NTypeIds::String)), + std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("Name", TTypeInfo(NTypeIds::String)) + })); + + UNIT_ASSERT_C(arrowBatch->num_rows() == 0, "Batch must have 0 rows"); + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(arrowBatch->ToString(), expected->ToString()); + } + + /** + * Arrow format is supported for all types of columns. + */ + Y_UNIT_TEST_TWIN(ArrowFormat_AllTypes, isOlap) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + if (isOlap) { + CreateAllTypesColumnTable(client); + } else { + CreateAllTypesRowTable(client); + } + + const TString query = Sprintf(R"( + SELECT * FROM `/Root/%s`; + )", (isOlap) ? "ColumnTable" : "RowTable"); + + Y_UNUSED(ExecuteAndCombineBatches(client, query, /* assertSize */ true)); + } + + /** + * Arrow format is supported for large batches. 
+ */ + Y_UNIT_TEST(ArrowFormat_LargeTable) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 10000, 4, 10, 5000, 10); + + const TString query = Sprintf(R"( + SELECT * FROM `/Root/LargeTable`; + )"); + + Y_UNUSED(ExecuteAndCombineBatches(client, query, /* assertSize */ true)); + } + + /** + * Arrow format is supported for large batches with LIMIT. + */ + Y_UNIT_TEST(ArrowFormat_LargeTable_Limit) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 10000, 4, 10, 5000, 10); + + const TString query = Sprintf(R"( + SELECT * FROM `/Root/LargeTable` LIMIT 70000; + )"); + + Y_UNUSED(ExecuteAndCombineBatches(client, query, /* assertSize */ true)); + } + + /** + * Arrow format is supported for returning. + */ + Y_UNIT_TEST_TWIN(ArrowFormat_Returning, isOlap) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + TString query; + + if (isOlap) { + CreateAllTypesColumnTable(client); + query = R"( + UPSERT INTO `/Root/ColumnTable` (Key, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, JsonValue, YsonValue, JsonDocumentValue) VALUES + (43, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), "[12]", "[13]", JsonDocument("[14]")), + (44, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), "[12]", "[13]", JsonDocument("[14]")), + (45, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), 
Timestamp("2009-09-09T09:09:09.09Z"), "[12]", "[13]", JsonDocument("[14]")) + RETURNING *; + )"; + } else { + CreateAllTypesRowTable(client); + query = R"( + UPSERT INTO `/Root/RowTable` (Key, BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue, Int32NotNullValue) VALUES + (43, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), 123), + (44, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), 123), + (45, true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15"), 123) + RETURNING *; + )"; + } + + Y_UNUSED(ExecuteAndCombineBatches(client, query, /* assertSize */ true)); + } + + /** + * Check different orders of columns in SELECT with Arrow format. 
+ */ + Y_UNIT_TEST(ArrowFormat_ColumnOrder) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Name, Amount FROM Test WHERE Group = 2; + )", /* assertSize */ true); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("Name", TTypeInfo(NTypeIds::String)), + std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)) + })); + + builder.AddRow().Add("Tony").Add(7200); + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); + } + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Amount, Name FROM Test WHERE Group = 2; + )", /* assertSize */ true); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("Name", TTypeInfo(NTypeIds::String)) + })); + + builder.AddRow().Add(7200).Add("Tony"); + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); + } + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; + )", /* assertSize */ true); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("Comment", TTypeInfo(NTypeIds::String)), + std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("Name", TTypeInfo(NTypeIds::String)) + })); + + builder.AddRow().Add("None").Add(7200).Add("Tony"); + builder.AddRow().Add("None").Add(3500).Add("Anna"); + builder.AddRow().Add("None").Add(300).Add("Paul"); + + auto expected = builder.BuildArrow(); + 
UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); + } + } + + /** + * With a small channel buffer size, data bytes and schema from many ExecuteQueryResponsePart parts are filled into a single ResultSet as a std::vector with a single schema. + */ + Y_UNIT_TEST(ArrowFormat_SmallChannelBufferSize) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 100, 2, 2, 10, 2); + + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT * FROM LargeTable; + )", /* assertSize */ true, /* minBatchesCount */ 2); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + UNIT_ASSERT_VALUES_EQUAL(batches.front()->num_rows(), 200); + UNIT_ASSERT_VALUES_EQUAL(batches.front()->num_columns(), 4); + } + + /** + * These YQL types are supported for Arrow format as arithmetic types: + */ + Y_UNIT_TEST(ArrowFormat_Types_Arithmetic) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto result = client.ExecuteQuery(R"( + CREATE TABLE ArithmeticTypesTable ( + BoolValue Bool, + Int8Value Int8, + Uint8Value Uint8 NOT NULL, + Int16Value Int16, + Uint16Value Uint16 NOT NULL, + Int32Value Int32, + Uint32Value Uint32 NOT NULL, + Int64Value Int64, + Uint64Value Uint64 NOT NULL, + FloatValue Float, + DoubleValue Double NOT NULL, + DecimalValue Decimal(22, 2), + PRIMARY KEY (BoolValue) + ); + )", TTxControl::NoTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto result = client.ExecuteQuery(R"( + INSERT INTO ArithmeticTypesTable (BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, DecimalValue) VALUES + (NULL, -1, 1, -2, 2, -3, 3, -4, 4, CAST(5.0 AS Float), 6.0, CAST("7.77" AS Decimal(22, 2))), + (true, -1, 1, -2, 2, -3, 3, -4, 4, CAST(5.0 AS Float), 6.0, CAST("7.77" AS Decimal(22, 
2))), + (false, -1, 1, -2, 2, -3, 3, -4, 4, CAST(5.0 AS Float), 6.0, CAST("7.77" AS Decimal(22, 2))); + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT BoolValue, Int8Value, Uint8Value, Int16Value, Uint16Value, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, DecimalValue + FROM ArithmeticTypesTable ORDER BY BoolValue; + )", /* assertSize */ true); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("BoolValue", TTypeInfo(NTypeIds::Uint8)), + std::make_pair("Int8Value", TTypeInfo(NTypeIds::Int8)), + std::make_pair("Uint8Value", TTypeInfo(NTypeIds::Uint8)), + std::make_pair("Int16Value", TTypeInfo(NTypeIds::Int16)), + std::make_pair("Uint16Value", TTypeInfo(NTypeIds::Uint16)), + std::make_pair("Int32Value", TTypeInfo(NTypeIds::Int32)), + std::make_pair("Uint32Value", TTypeInfo(NTypeIds::Uint32)), + std::make_pair("Int64Value", TTypeInfo(NTypeIds::Int64)), + std::make_pair("Uint64Value", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("FloatValue", TTypeInfo(NTypeIds::Float)), + std::make_pair("DoubleValue", TTypeInfo(NTypeIds::Double)), + std::make_pair("DecimalValue", TTypeInfo(NScheme::TDecimalType(22, 2))) + })); + + builder.AddRow().AddNull().Add(-1).Add(1).Add(-2).Add(2).Add(-3).Add(3).Add(-4).Add(4).Add(5.0).Add(6.0).Add(TDecimalValue("7.77", 22, 2)); + builder.AddRow().Add(false).Add(-1).Add(1).Add(-2).Add(2).Add(-3).Add(3).Add(-4).Add(4).Add(5.0).Add(6.0).Add(TDecimalValue("7.77", 22, 2)); + builder.AddRow().Add(true).Add(-1).Add(1).Add(-2).Add(2).Add(-3).Add(3).Add(-4).Add(4).Add(5.0).Add(6.0).Add(TDecimalValue("7.77", 22, 2)); + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); + } + } + + /** + * These YQL types are 
supported for Arrow format as string types: + */ + Y_UNIT_TEST(ArrowFormat_Types_String) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto result = client.ExecuteQuery(R"( + CREATE TABLE StringTypesTable ( + Utf8Value Utf8, + JsonValue Json, + Utf8NotNullValue Utf8 NOT NULL, + JsonNotNullValue Json NOT NULL, + PRIMARY KEY (Utf8Value) + ); + )", TTxControl::NoTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto result = client.ExecuteQuery(R"( + INSERT INTO StringTypesTable (Utf8Value, JsonValue, Utf8NotNullValue, JsonNotNullValue) VALUES + ("John", "[1]", "John", "[2]"), + (NULL, "[]", "Maria", "[3]"), + ("Leo", NULL, "Leo", "[4]"), + ("Michael", "[5]", "Michael", "[6]"); + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Utf8Value, JsonValue, Utf8NotNullValue, JsonNotNullValue + FROM StringTypesTable ORDER BY Utf8Value; + )", /* assertSize */ true); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("Utf8Value", TTypeInfo(NTypeIds::Utf8)), + std::make_pair("JsonValue", TTypeInfo(NTypeIds::Json)), + std::make_pair("Utf8NotNullValue", TTypeInfo(NTypeIds::Utf8)), + std::make_pair("JsonNotNullValue", TTypeInfo(NTypeIds::Json)) + })); + + builder.AddRow().AddNull().Add("[]").Add("Maria").Add("[3]"); + builder.AddRow().Add("John").Add("[1]").Add("John").Add("[2]"); + builder.AddRow().Add("Leo").AddNull().Add("Leo").Add("[4]"); + builder.AddRow().Add("Michael").Add("[5]").Add("Michael").Add("[6]"); + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); + } + } + + /** + * These YQL types are supported for Arrow format as binary types: + */ 
+ Y_UNIT_TEST(ArrowFormat_Types_Binary) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto result = client.ExecuteQuery(R"( + CREATE TABLE BinaryTypesTable ( + StringValue String, + YsonValue Yson, + DyNumberValue DyNumber, + JsonDocumentValue JsonDocument, + UuidValue Uuid, + StringNotNullValue String NOT NULL, + YsonNotNullValue Yson NOT NULL, + JsonDocumentNotNullValue JsonDocument NOT NULL, + DyNumberNotNullValue DyNumber NOT NULL, + UuidNotNullValue Uuid NOT NULL, + PRIMARY KEY (StringValue) + ); + )", TTxControl::NoTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto result = client.ExecuteQuery(R"( + INSERT INTO BinaryTypesTable (StringValue, YsonValue, DyNumberValue, JsonDocumentValue, UuidValue, StringNotNullValue, YsonNotNullValue, JsonDocumentNotNullValue, DyNumberNotNullValue, UuidNotNullValue) VALUES + ("John", "[1]", DyNumber("1.0"), JsonDocument("{\"a\": 1}"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe"), "Mark", "[2]", JsonDocument("{\"b\": 2}"), DyNumber("4.0"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")), + (NULL, "[4]", NULL, NULL, NULL, "Maria", "[5]", JsonDocument("[6]"), DyNumber("7.0"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")), + ("Mark", NULL, NULL, NULL, NULL, "Michael", "[7]", JsonDocument("[8]"), DyNumber("9.0"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")), + ("Leo", "[10]", DyNumber("11.0"), JsonDocument("[12]"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe"), "Maria", "[13]", JsonDocument("[14]"), DyNumber("15.0"), Uuid("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT StringValue, YsonValue, DyNumberValue, JsonDocumentValue, UuidValue, StringNotNullValue, YsonNotNullValue, JsonDocumentNotNullValue, DyNumberNotNullValue, 
UuidNotNullValue + FROM BinaryTypesTable ORDER BY StringValue; + )", /* assertSize */ true); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("StringValue", TTypeInfo(NTypeIds::String)), + std::make_pair("YsonValue", TTypeInfo(NTypeIds::Yson)), + std::make_pair("DyNumberValue", TTypeInfo(NTypeIds::DyNumber)), + std::make_pair("JsonDocumentValue", TTypeInfo(NTypeIds::JsonDocument)), + std::make_pair("UuidValue", TTypeInfo(NTypeIds::Uuid)), + std::make_pair("StringNotNullValue", TTypeInfo(NTypeIds::String)), + std::make_pair("YsonNotNullValue", TTypeInfo(NTypeIds::Yson)), + std::make_pair("JsonDocumentNotNullValue", TTypeInfo(NTypeIds::JsonDocument)), + std::make_pair("DyNumberNotNullValue", TTypeInfo(NTypeIds::DyNumber)), + std::make_pair("UuidNotNullValue", TTypeInfo(NTypeIds::Uuid)) + })); + + builder.AddRow().AddNull().Add("[4]").AddNull().AddNull().AddNull().Add("Maria").Add("[5]").Add(SerializeToBinaryJsonString("[6]")).Add(NDyNumber::ParseDyNumberString("7.0")->c_str()).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); + builder.AddRow().Add("John").Add("[1]").Add(NDyNumber::ParseDyNumberString("1.0")->c_str()).Add(SerializeToBinaryJsonString("{\"a\": 1}")).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")).Add("Mark").Add("[2]").Add(SerializeToBinaryJsonString("{\"b\": 2}")).Add(NDyNumber::ParseDyNumberString("4.0")->c_str()).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); + builder.AddRow().Add("Leo").Add("[10]").Add(NDyNumber::ParseDyNumberString("11.0")->c_str()).Add(SerializeToBinaryJsonString("[12]")).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")).Add("Maria").Add("[13]").Add(SerializeToBinaryJsonString("[14]")).Add(NDyNumber::ParseDyNumberString("15.0")->c_str()).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); + 
builder.AddRow().Add("Mark").AddNull().AddNull().AddNull().AddNull().Add("Michael").Add("[7]").Add(SerializeToBinaryJsonString("[8]")).Add(NDyNumber::ParseDyNumberString("9.0")->c_str()).Add(NYdb::TUuidValue("5b99a330-04ef-4f1a-9b64-ba6d5f44eafe")); + + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); + } + } + + /** + * These YQL types are supported for Arrow format as integer and time types: + */ + Y_UNIT_TEST(ArrowFormat_Types_Time) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto result = client.ExecuteQuery(R"( + CREATE TABLE TimeTypesTable ( + DateValue Date, + DatetimeValue Datetime, + TimestampValue Timestamp, + IntervalValue Interval, + PRIMARY KEY (DateValue) + ); + )", TTxControl::NoTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto result = client.ExecuteQuery(R"( + INSERT INTO TimeTypesTable (DateValue, DatetimeValue, TimestampValue, IntervalValue) VALUES + (Date("2001-01-01"), Datetime("2002-02-02T02:02:02Z"), Timestamp("2003-03-03T03:03:03Z"), Interval("P7D")); + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT DateValue, DatetimeValue, TimestampValue, IntervalValue + FROM TimeTypesTable ORDER BY DateValue; + )", /* assertSize */ true); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("DateValue", TTypeInfo(NTypeIds::Uint16)), + std::make_pair("DatetimeValue", TTypeInfo(NTypeIds::Uint32)), + std::make_pair("TimestampValue", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("IntervalValue", TTypeInfo(NTypeIds::Int64)) + })); + + builder.AddRow().Add(11323).Add(1012615322).Add(1046660583000000).Add(604800000000); + 
+ auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(batches.front()->ToString(), expected->ToString()); + } + } + + /** + * Arrow format is supported for compression. + * By default, unspecified compression codec is None (without compression). + */ + Y_UNIT_TEST(ArrowFormat_Compression_None) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + CompareCompressedAndDefaultBatches(client, std::nullopt, /* assertEqual */ true); + + CompareCompressedAndDefaultBatches(client, TArrowFormatSettings::TCompressionCodec().Type(TArrowFormatSettings::TCompressionCodec::EType::None), /* assertEqual */ true); + } + + /** + * Arrow format is supported for compression by ZSTD codec. + * Compression level is supported. + */ + Y_UNIT_TEST(ArrowFormat_Compression_ZSTD) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + CompareCompressedAndDefaultBatches(client, TArrowFormatSettings::TCompressionCodec().Type(TArrowFormatSettings::TCompressionCodec::EType::Zstd).Level(12), /* assertEqual */ false); + } + + /** + * Arrow format is supported for compression by LZ4_FRAME codec. + * Compression level is not supported. 
+ */ + Y_UNIT_TEST(ArrowFormat_Compression_LZ4_FRAME) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + CompareCompressedAndDefaultBatches(client, TArrowFormatSettings::TCompressionCodec().Type(TArrowFormatSettings::TCompressionCodec::EType::Lz4Frame), /* assertEqual */ false); + + { + auto settings = TExecuteQuerySettings() + .Format(TResultSet::EFormat::Arrow) + .ArrowFormatSettings(TArrowFormatSettings() + .CompressionCodec(TArrowFormatSettings::TCompressionCodec() + .Type(TArrowFormatSettings::TCompressionCodec::EType::Lz4Frame) + .Level(12))); + + auto result = client.ExecuteQuery(R"( + SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), EStatus::INTERNAL_ERROR); + UNIT_ASSERT_STRING_CONTAINS(result.GetIssues().ToString(), "Codec 'lz4' doesn't support setting a compression level"); + } + } + + /** + * Arrow batches are returned for different result set indexes. 
+ */ + Y_UNIT_TEST(ArrowFormat_Multistatement) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); + + auto result = client.ExecuteQuery(R"( + SELECT Comment, Amount, Name FROM Test ORDER BY Amount DESC; + SELECT Key, Value FROM KeyValue ORDER BY Key; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + UNIT_ASSERT_VALUES_EQUAL(result.GetResultSets().size(), 2); + + { + auto resultSet = result.GetResultSet(0); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); + auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); + + UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("Comment", TTypeInfo(NTypeIds::String)), + std::make_pair("Amount", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("Name", TTypeInfo(NTypeIds::String)) + })); + + builder.AddRow().Add("None").Add(7200).Add("Tony"); + builder.AddRow().Add("None").Add(3500).Add("Anna"); + builder.AddRow().Add("None").Add(300).Add("Paul"); + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(arrowBatch->ToString(), expected->ToString()); + } + { + auto resultSet = result.GetResultSet(1); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = 
TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + auto arrowSchema = NArrow::DeserializeSchema(TString(schema)); + auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); + + UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + + NColumnShard::TTableUpdatesBuilder builder(NArrow::MakeArrowSchema({ + std::make_pair("Key", TTypeInfo(NTypeIds::Uint64)), + std::make_pair("Value", TTypeInfo(NTypeIds::String)) + })); + + builder.AddRow().Add(1).Add("One"); + builder.AddRow().Add(2).Add("Two"); + + auto expected = builder.BuildArrow(); + UNIT_ASSERT_VALUES_EQUAL(arrowBatch->ToString(), expected->ToString()); + } + } + + /** + * By default, SchemaInclusionMode is ALWAYS for Arrow format. + */ + Y_UNIT_TEST(ArrowFormat_SchemaInclusionMode_Unspecified) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 100, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + std::shared_ptr arrowSchema; + + size_t count = 0; + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = 
TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + // With Arrow format, the result set contains a YQL schema and an Arrow record batch schema + UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty"); + UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set"); + + auto curSchema = NArrow::DeserializeSchema(TString(schema)); + UNIT_ASSERT_C(curSchema, "Schema must be deserialized"); + + if (arrowSchema) { + UNIT_ASSERT_VALUES_EQUAL(arrowSchema->ToString(), curSchema->ToString()); + } else { + arrowSchema = curSchema; + } + + auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); + + ++count; + } + } + + UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); + } + + /** + * Set SchemaInclusionMode ALWAYS for Arrow format explicitly in TExecuteQuerySettings. 
+ */ + Y_UNIT_TEST(ArrowFormat_SchemaInclusionMode_Always) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 100, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow).SchemaInclusionMode(ESchemaInclusionMode::Always); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + std::shared_ptr arrowSchema; + + size_t count = 0; + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + // With Arrow format, the result set contains a YQL schema and an Arrow record batch schema + UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty"); + UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set"); + + auto curSchema = NArrow::DeserializeSchema(TString(schema)); + UNIT_ASSERT_C(curSchema, "Schema must be deserialized"); + + if (arrowSchema) { + UNIT_ASSERT_VALUES_EQUAL(arrowSchema->ToString(), curSchema->ToString()); + } else { + arrowSchema = curSchema; + } + + auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); + + ++count; + } + } + + 
UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); + } + + /** + * Set SchemaInclusionMode FIRST_ONLY for Arrow format explicitly in TExecuteQuerySettings. + */ + Y_UNIT_TEST(ArrowFormat_SchemaInclusionMode_FirstOnly) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 100, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow).SchemaInclusionMode(ESchemaInclusionMode::FirstOnly); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + std::shared_ptr arrowSchema; + + size_t count = 0; + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + if (count == 0) { + // With Arrow format, the result set contains a YQL schema and an Arrow record batch schema + UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set"); + UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty for the first result set"); + + arrowSchema = NArrow::DeserializeSchema(TString(schema)); + + UNIT_ASSERT_C(arrowSchema, "Schema must be deserialized"); + } else { + UNIT_ASSERT_VALUES_EQUAL_C(resultSet.ColumnsCount(), 0, "Columns count must be empty for the rest result sets"); + UNIT_ASSERT_C(schema.empty(), "Schema must be empty for the rest result sets"); + } + + auto arrowBatch = 
NArrow::DeserializeBatch(TString(batches[0]), arrowSchema); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); + + ++count; + } + } + + UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets"); + } + + /** + * For Arrow format, FirstOnly schema inclusion mode is supported for multistatement queries. + */ + Y_UNIT_TEST(ArrowFormat_SchemaInclusionMode_FirstOnly_Multistatement) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 200, 2, 2, 10, 2); + + auto settings = TExecuteQuerySettings().Format(TResultSet::EFormat::Arrow).SchemaInclusionMode(ESchemaInclusionMode::FirstOnly); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; + SELECT Key, Data FROM LargeTable; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + std::unordered_map> arrowSchemas; + std::unordered_map counts; + + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + auto idx = part.GetResultSetIndex(); + + if (arrowSchemas.find(idx) == arrowSchemas.end()) { + // The first result set of each statement contains schemas. 
+ UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set of the statement"); + UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty for the first result set of the statement"); + + arrowSchemas[idx] = NArrow::DeserializeSchema(TString(schema)); + + UNIT_ASSERT_C(arrowSchemas[idx], "Schema must be deserialized"); + } else { + UNIT_ASSERT_VALUES_EQUAL_C(resultSet.ColumnsCount(), 0, "Columns count must be empty for the rest result sets of the statement"); + UNIT_ASSERT_C(schema.empty(), "Schema must be empty for the rest result sets of the statement"); + } + + auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchemas[idx]); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); + + ++counts[idx]; + } + } + + UNIT_ASSERT_C(counts.size() == 2, "Expected 2 result set indexes"); + + for (const auto& [idx, count] : counts) { + UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets for statement with ResultSetIndex = " << idx); + } + } + + /** + * Small stress test for Arrow format: + * - 3 statements + * - 10 shards, 1000 rows per shard + * - ZSTD compression with level 10 + * - SchemaInclusionMode is FIRST_ONLY + * - ChannelBufferSize is 1KB + */ + Y_UNIT_TEST(ArrowFormat_Stress) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false, 1_KB); + auto client = kikimr.GetQueryClient(); + + CreateLargeTable(kikimr, 1000, 4, 4, 100, 10); + + auto settings = TExecuteQuerySettings() + .Format(TResultSet::EFormat::Arrow) + .SchemaInclusionMode(ESchemaInclusionMode::FirstOnly) + .ArrowFormatSettings(TArrowFormatSettings() + .CompressionCodec(TArrowFormatSettings::TCompressionCodec() + .Type(TArrowFormatSettings::TCompressionCodec::EType::Zstd) + .Level(10))); + + auto it = client.StreamExecuteQuery(R"( + SELECT * FROM LargeTable; 
+ UPDATE LargeTable SET Data = Data + 1 WHERE Key % 2 = 1 RETURNING Data; + SELECT DataText FROM LargeTable WHERE Key % 2 = 0; + )", TTxControl::BeginTx().CommitTx(), settings).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + std::unordered_map> arrowSchemas; + std::unordered_map counts; + + for (;;) { + auto part = it.ReadNext().GetValueSync(); + if (!part.IsSuccess()) { + UNIT_ASSERT_C(part.EOS(), part.GetIssues().ToString()); + break; + } + + if (part.HasResultSet()) { + auto resultSet = part.ExtractResultSet(); + UNIT_ASSERT_VALUES_EQUAL(TArrowAccessor::Format(resultSet), TResultSet::EFormat::Arrow); + + const auto& schema = TArrowAccessor::GetArrowSchema(resultSet); + const auto& batches = TArrowAccessor::GetArrowBatches(resultSet); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + auto idx = part.GetResultSetIndex(); + + if (arrowSchemas.find(idx) == arrowSchemas.end()) { + // The first result set of each statement contains schemas. 
+ UNIT_ASSERT_VALUES_UNEQUAL_C(resultSet.ColumnsCount(), 0, "Columns must not be empty for the first result set of the statement"); + UNIT_ASSERT_C(!schema.empty(), "Schema must not be empty for the first result set of the statement"); + + arrowSchemas[idx] = NArrow::DeserializeSchema(TString(schema)); + + UNIT_ASSERT_C(arrowSchemas[idx], "Schema must be deserialized"); + } else { + UNIT_ASSERT_VALUES_EQUAL_C(resultSet.ColumnsCount(), 0, "Columns count must be empty for the rest result sets of the statement"); + UNIT_ASSERT_C(schema.empty(), "Schema must be empty for the rest result sets of the statement"); + } + + auto arrowBatch = NArrow::DeserializeBatch(TString(batches[0]), arrowSchemas[idx]); + UNIT_ASSERT_C(arrowBatch, "Batch must be deserialized"); + UNIT_ASSERT_C(arrowBatch->ValidateFull().ok(), "Batch validation failed"); + UNIT_ASSERT_GT_C(arrowBatch->num_rows(), 0, "Batch must have at least 1 row"); + + ++counts[idx]; + } + } + + UNIT_ASSERT_C(counts.size() == 3, "Expected 3 result set indexes"); + + for (const auto& [idx, count] : counts) { + UNIT_ASSERT_GT_C(count, 1, "Expected at least 2 result sets for statement with ResultSetIndex = " << idx); + } + } + + /** + * More tests for different types, covering correctness and conversions between Arrow and UV: + * ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp + */ + + // Optional + Y_UNIT_TEST(ArrowFormat_Types_Optional_1) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Key1, Name FROM Join2 + WHERE Key1 IN [104, 106, 108] + ORDER BY Key1; + )", /* assertSize */ false, 1); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 3); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 2); + + ValidateOptionalColumn(batch->column(0), 1, false); + ValidateOptionalColumn(batch->column(1), 
1, false); + + const TString expected = +R"(Key1: [ + 104, + 106, + 108 + ] +Name: [ + 4E616D6533, + 4E616D6533, + null + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Optional> + Y_UNIT_TEST(ArrowFormat_Types_Optional_2) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Just(Key1), Just(Name) FROM Join2 + WHERE Key1 IN [104, 106, 108] + ORDER BY Key1; + )", /* assertSize */ false, 1); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 3); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 2); + + ValidateOptionalColumn(batch->column(0), 2, false); + ValidateOptionalColumn(batch->column(1), 2, false); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: uint32 + [ + 104, + 106, + 108 + ] +column1: -- is_valid: all not null + -- child 0 type: binary + [ + 4E616D6533, + 4E616D6533, + null + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Optional>>> + Y_UNIT_TEST(ArrowFormat_Types_Optional_3) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Just(Just(Just(Key1))), Just(Just(Just(Name))) FROM Join2 + WHERE Key1 IN [104, 106, 108] + ORDER BY Key1; + )", /* assertSize */ false, 1); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 3); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 2); + + ValidateOptionalColumn(batch->column(0), 3, false); + ValidateOptionalColumn(batch->column(1), 3, false); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: struct not null> + -- is_valid: all not 
null + -- child 0 type: struct + -- is_valid: all not null + -- child 0 type: uint32 + [ + 104, + 106, + 108 + ] +column1: -- is_valid: all not null + -- child 0 type: struct not null> + -- is_valid: all not null + -- child 0 type: struct + -- is_valid: all not null + -- child 0 type: binary + [ + 4E616D6533, + 4E616D6533, + null + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Optional> + Y_UNIT_TEST(ArrowFormat_Types_Optional_4) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Just(Variant(1, "foo", Variant)); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + + ValidateOptionalColumn(batch->column(0), 1, /* isVariant */ true); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: dense_union + -- is_valid: all not null + -- type_ids: [ + 1 + ] + -- value_offsets: [ + 0 + ] + -- child 0 type: uint8 + [] + -- child 1 type: int32 + [ + 1 + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Optional>> + Y_UNIT_TEST(ArrowFormat_Types_Optional_5) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Just(Just(Variant(1, "foo", Variant))); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + + ValidateOptionalColumn(batch->column(0), 2, /* isVariant */ true); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: struct 
not null> + -- is_valid: all not null + -- child 0 type: dense_union + -- is_valid: all not null + -- type_ids: [ + 1 + ] + -- value_offsets: [ + 0 + ] + -- child 0 type: uint8 + [] + -- child 1 type: int32 + [ + 1 + ] + -- child 2 type: binary + [] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // List + Y_UNIT_TEST(ArrowFormat_Types_List_1) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST([1, 2, 3] AS List); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::LIST, "Column type must be arrow::Type::LIST"); + + const TString expected = +R"(column0: [ + [ + 1, + 2, + 3 + ] + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // List> + Y_UNIT_TEST(ArrowFormat_Types_List_2) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST([1, NULL, 3, null] AS List>); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::LIST, "Column type must be arrow::Type::LIST"); + + const TString expected = +R"(column0: [ + [ + 1, + null, + 3, + null + ] + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // List> of columns + Y_UNIT_TEST(ArrowFormat_Types_List_3) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ true); + auto client = 
kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT [App, Host] FROM Logs + ORDER BY App; + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 9); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::LIST, "Column type must be arrow::Type::LIST"); + + const TString expected = +R"(column0: [ + [ + "apache", + "front-42" + ], + [ + "kikimr-db", + "kikimr-db-10" + ], + [ + "kikimr-db", + "kikimr-db-21" + ], + [ + "kikimr-db", + "kikimr-db-21" + ], + [ + "kikimr-db", + "kikimr-db-53" + ], + [ + "nginx", + "nginx-10" + ], + [ + "nginx", + "nginx-23" + ], + [ + "nginx", + "nginx-23" + ], + [ + "ydb", + "ydb-1000" + ] + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // List<> + Y_UNIT_TEST(ArrowFormat_Types_EmptyList) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT []; + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Tuple + Y_UNIT_TEST(ArrowFormat_Types_Tuple) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST((1, 2.5) AS Tuple); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches 
must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: uint32 + [ + 1 + ] + -- child 1 type: double + [ + 2.5 + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Dict + Y_UNIT_TEST(ArrowFormat_Types_Dict_1) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST({"a": 1, "b": 2, "c": 3} AS Dict); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: map + [ + keys: + [ + 61, + 63, + 62 + ] + values: + [ + 1, + 3, + 2 + ] + ] + -- child 1 type: uint64 + [ + 0 + ] +)"; + + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Dict, V> + Y_UNIT_TEST(ArrowFormat_Types_Dict_2) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST({"a": 1, "b": 2, NULL: 3} AS Dict,Int32>); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, 
"Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: list> + [ + -- is_valid: all not null + -- child 0 type: binary + [ + 61, + 62, + null + ] + -- child 1 type: int32 + [ + 1, + 2, + 3 + ] + ] + -- child 1 type: uint64 + [ + 0 + ] +)"; + + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Dict<> + Y_UNIT_TEST(ArrowFormat_Types_EmptyDict) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT {}; + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Struct + Y_UNIT_TEST(ArrowFormat_Types_Struct) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT CAST(<|first: 1, second: "2"|> AS Struct); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::STRUCT, "Column type must be arrow::Type::STRUCT"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- child 0 type: int32 + [ + 1 + ] + -- child 1 type: string + [ + "2" + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } + + // Variant + 
Y_UNIT_TEST(ArrowFormat_Types_Variant) { + auto kikimr = CreateKikimrRunner(/* withSampleTables */ false); + auto client = kikimr.GetQueryClient(); + + { + auto batches = ExecuteAndCombineBatches(client, R"( + SELECT Variant(1, "foo", Variant); + )", /* assertSize */ false); + + UNIT_ASSERT_C(!batches.empty(), "Batches must not be empty"); + + const auto& batch = batches.front(); + + UNIT_ASSERT_VALUES_EQUAL(batch->num_rows(), 1); + UNIT_ASSERT_VALUES_EQUAL(batch->num_columns(), 1); + UNIT_ASSERT_C(batch->column(0)->type()->id() == arrow::Type::DENSE_UNION, "Column type must be arrow::Type::DENSE_UNION"); + + const TString expected = +R"(column0: -- is_valid: all not null + -- type_ids: [ + 1 + ] + -- value_offsets: [ + 0 + ] + -- child 0 type: uint8 + [] + -- child 1 type: int32 + [ + 1 + ] +)"; + UNIT_ASSERT_VALUES_EQUAL(batch->ToString(), expected); + } + } +} + +} // namespace NKikimr::NKqp diff --git a/ydb/core/kqp/ut/arrow/ya.make b/ydb/core/kqp/ut/arrow/ya.make index 666b7755881..891a4052de2 100644 --- a/ydb/core/kqp/ut/arrow/ya.make +++ b/ydb/core/kqp/ut/arrow/ya.make @@ -8,7 +8,7 @@ SIZE(MEDIUM) SRCS( kqp_arrow_in_channels_ut.cpp kqp_types_arrow_ut.cpp - kqp_result_set_format_arrow.cpp + kqp_result_set_formats.cpp ) PEERDIR( diff --git a/ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h b/ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h index 1ac17a5da7a..98738dcc1eb 100644 --- a/ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h +++ b/ydb/core/tx/columnshard/test_helper/columnshard_ut_common.h @@ -547,6 +547,13 @@ public: } } + if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + Y_ABORT_UNLESS(typedBuilder.Append(data.Buf_.Bytes).ok()); + return true; + } + } + Y_ABORT("Unknown type combination"); return false; })); diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp b/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp index c55fddb0c13..a329c36f757 100644 --- 
a/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp @@ -1,6 +1,8 @@ #include "dq_arrow_helpers.h" #include + +#include #include #include #include @@ -51,11 +53,10 @@ struct TTypeWrapper template bool SwitchMiniKQLDataTypeToArrowType(NUdf::EDataSlot type, TFunc&& callback) { switch (type) { - case NUdf::EDataSlot::Bool: - return callback(TTypeWrapper()); case NUdf::EDataSlot::Int8: return callback(TTypeWrapper()); case NUdf::EDataSlot::Uint8: + case NUdf::EDataSlot::Bool: return callback(TTypeWrapper()); case NUdf::EDataSlot::Int16: return callback(TTypeWrapper()); @@ -69,40 +70,59 @@ bool SwitchMiniKQLDataTypeToArrowType(NUdf::EDataSlot type, TFunc&& callback) { case NUdf::EDataSlot::Uint32: return callback(TTypeWrapper()); case NUdf::EDataSlot::Int64: + case NUdf::EDataSlot::Interval: case NUdf::EDataSlot::Datetime64: case NUdf::EDataSlot::Timestamp64: case NUdf::EDataSlot::Interval64: return callback(TTypeWrapper()); case NUdf::EDataSlot::Uint64: + case NUdf::EDataSlot::Timestamp: return callback(TTypeWrapper()); case NUdf::EDataSlot::Float: return callback(TTypeWrapper()); case NUdf::EDataSlot::Double: return callback(TTypeWrapper()); - case NUdf::EDataSlot::Timestamp: - return callback(TTypeWrapper()); - case NUdf::EDataSlot::Interval: - return callback(TTypeWrapper()); case NUdf::EDataSlot::Utf8: case NUdf::EDataSlot::Json: - case NUdf::EDataSlot::Yson: - case NUdf::EDataSlot::JsonDocument: return callback(TTypeWrapper()); case NUdf::EDataSlot::String: - case NUdf::EDataSlot::Uuid: case NUdf::EDataSlot::DyNumber: + case NUdf::EDataSlot::Yson: + case NUdf::EDataSlot::JsonDocument: return callback(TTypeWrapper()); case NUdf::EDataSlot::Decimal: - return callback(TTypeWrapper()); - // TODO convert Tz-types to native arrow date and time types. 
+ case NUdf::EDataSlot::Uuid: + return callback(TTypeWrapper()); case NUdf::EDataSlot::TzDate: case NUdf::EDataSlot::TzDatetime: case NUdf::EDataSlot::TzTimestamp: case NUdf::EDataSlot::TzDate32: case NUdf::EDataSlot::TzDatetime64: case NUdf::EDataSlot::TzTimestamp64: + return callback(TTypeWrapper()); + } +} + +bool NeedWrapByExternalOptional(const TType* type) { + switch (type->GetKind()) { + case TType::EKind::Void: + case TType::EKind::Null: + case TType::EKind::Variant: + case TType::EKind::Optional: + return true; + case TType::EKind::EmptyList: + case TType::EKind::EmptyDict: + case TType::EKind::Data: + case TType::EKind::Struct: + case TType::EKind::Tuple: + case TType::EKind::List: + case TType::EKind::Dict: return false; + default: + YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr()); } + + return true; } template @@ -112,7 +132,52 @@ NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr column, ui32 r return NUdf::TUnboxedValuePod(static_cast(array->Value(row))); } -// The following 4 specialization are for darwin build (because of difference in long long) +template <> +NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr column, ui32 row) { + auto array = std::static_pointer_cast(column); + YQL_ENSURE(array->num_fields() == 2, "StructArray of some TzDate type should have 2 fields"); + + auto datetimeArray = array->field(0); + auto timezoneArray = std::static_pointer_cast(array->field(1)); + + NUdf::TUnboxedValuePod value; + + switch (datetimeArray->type()->id()) { + // NUdf::EDataSlot::TzDate + case arrow::Type::UINT16: { + value = NUdf::TUnboxedValuePod(static_cast(std::static_pointer_cast(datetimeArray)->Value(row))); + break; + } + // NUdf::EDataSlot::TzDatetime + case arrow::Type::UINT32: { + value = NUdf::TUnboxedValuePod(static_cast(std::static_pointer_cast(datetimeArray)->Value(row))); + break; + } + // NUdf::EDataSlot::TzTimestamp + case arrow::Type::UINT64: { + value = 
NUdf::TUnboxedValuePod(static_cast(std::static_pointer_cast(datetimeArray)->Value(row))); + break; + } + // NUdf::EDataSlot::TzDate32 + case arrow::Type::INT32: { + value = NUdf::TUnboxedValuePod(static_cast(std::static_pointer_cast(datetimeArray)->Value(row))); + break; + } + // NUdf::EDataSlot::TzDatetime64, NUdf::EDataSlot::TzTimestamp64 + case arrow::Type::INT64: { + value = NUdf::TUnboxedValuePod(static_cast(std::static_pointer_cast(datetimeArray)->Value(row))); + break; + } + default: + YQL_ENSURE(false, "Unexpected timezone datetime slot"); + return NUdf::TUnboxedValuePod(); + } + + value.SetTimezoneId(timezoneArray->Value(row)); + return value; +} + +// The following specializations are for darwin build (because of difference in long long) template <> // For darwin build NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr column, ui32 row) { @@ -126,20 +191,6 @@ NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr(array->Value(row))); } -template <> // For darwin build -NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr column, ui32 row) { - using TArrayType = typename arrow::TypeTraits::ArrayType; - auto array = std::static_pointer_cast(column); - return NUdf::TUnboxedValuePod(static_cast(array->Value(row))); -} - -template <> // For darwin build -NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr column, ui32 row) { - using TArrayType = typename arrow::TypeTraits::ArrayType; - auto array = std::static_pointer_cast(column); - return NUdf::TUnboxedValuePod(static_cast(array->Value(row))); -} - template <> NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr column, ui32 row) { auto array = std::static_pointer_cast(column); @@ -155,46 +206,63 @@ NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr -NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr column, ui32 row) { - auto array = std::static_pointer_cast(column); +NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr column, ui32 row) { + auto array = std::static_pointer_cast(column); auto data = 
array->GetView(row); - // We check that Decimal(22,9) but it may not be true - // TODO Support other decimal precisions. - const auto& type = arrow::internal::checked_cast(*array->type()); - Y_ABORT_UNLESS(type.precision() == NScheme::DECIMAL_PRECISION, "Unsupported Decimal precision."); - Y_ABORT_UNLESS(type.scale() == NScheme::DECIMAL_SCALE, "Unsupported Decimal scale."); - Y_ABORT_UNLESS(data.size() == sizeof(NYql::NDecimal::TInt128), "Wrong data size"); - NYql::NDecimal::TInt128 val; - std::memcpy(reinterpret_cast(&val), data.data(), data.size()); - return NUdf::TUnboxedValuePod(val); + return NMiniKQL::MakeString(NUdf::TStringRef(data.data(), data.size())); } template -std::shared_ptr CreateEmptyArrowImpl() { +std::shared_ptr CreateEmptyArrowImpl(NUdf::EDataSlot slot) { + Y_UNUSED(slot); return std::make_shared(); } template <> -std::shared_ptr CreateEmptyArrowImpl() { - // TODO use non-fixed precision, derive it from data. - return arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE); +std::shared_ptr CreateEmptyArrowImpl(NUdf::EDataSlot slot) { + Y_UNUSED(slot); + return arrow::fixed_size_binary(NScheme::FSB_SIZE); } template <> -std::shared_ptr CreateEmptyArrowImpl() { - return arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO); -} +std::shared_ptr CreateEmptyArrowImpl(NUdf::EDataSlot slot) { + std::shared_ptr type; + switch (slot) { + case NUdf::EDataSlot::TzDate: + type = NYql::NUdf::MakeTzLayoutArrowType(); + break; + case NUdf::EDataSlot::TzDatetime: + type = NYql::NUdf::MakeTzLayoutArrowType(); + break; + case NUdf::EDataSlot::TzTimestamp: + type = NYql::NUdf::MakeTzLayoutArrowType(); + break; + case NUdf::EDataSlot::TzDate32: + type = NYql::NUdf::MakeTzLayoutArrowType(); + break; + case NUdf::EDataSlot::TzDatetime64: + type = NYql::NUdf::MakeTzLayoutArrowType(); + break; + case NUdf::EDataSlot::TzTimestamp64: + type = NYql::NUdf::MakeTzLayoutArrowType(); + break; + default: + YQL_ENSURE(false, "Unexpected timezone datetime slot"); + 
return std::make_shared(); + } -template <> -std::shared_ptr CreateEmptyArrowImpl() { - return arrow::duration(arrow::TimeUnit::TimeUnit::MICRO); + std::vector> fields { + std::make_shared("datetime", type, false), + std::make_shared("timezoneId", arrow::uint16(), false), + }; + return arrow::struct_(fields); } std::shared_ptr GetArrowType(const TDataType* dataType) { std::shared_ptr result; bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&](TTypeWrapper typeHolder) { Y_UNUSED(typeHolder); - result = CreateEmptyArrowImpl(); + result = CreateEmptyArrowImpl(*dataType->GetDataSlot().Get()); return true; }); if (success) { @@ -208,8 +276,10 @@ std::shared_ptr GetArrowType(const TStructType* structType) { fields.reserve(structType->GetMembersCount()); for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { auto memberType = structType->GetMemberType(index); - fields.emplace_back(std::make_shared(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(memberType))); + auto memberName = std::string(structType->GetMemberName(index)); + auto memberArrowType = NArrow::GetArrowType(memberType); + + fields.emplace_back(std::make_shared(memberName, memberArrowType, memberType->IsOptional())); } return arrow::struct_(fields); } @@ -218,27 +288,42 @@ std::shared_ptr GetArrowType(const TTupleType* tupleType) { std::vector> fields; fields.reserve(tupleType->GetElementsCount()); for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { + auto elementName = std::string("field" + ToString(index)); auto elementType = tupleType->GetElementType(index); - fields.push_back(std::make_shared("", NArrow::GetArrowType(elementType))); + auto elementArrowType = NArrow::GetArrowType(elementType); + + fields.push_back(std::make_shared(elementName, elementArrowType, elementType->IsOptional())); } return arrow::struct_(fields); } std::shared_ptr GetArrowType(const TListType* listType) { auto itemType = 
listType->GetItemType(); - return arrow::list(NArrow::GetArrowType(itemType)); + auto itemArrowType = NArrow::GetArrowType(itemType); + auto field = std::make_shared("item", itemArrowType, itemType->IsOptional()); + return arrow::list(field); } std::shared_ptr GetArrowType(const TDictType* dictType) { auto keyType = dictType->GetKeyType(); auto payloadType = dictType->GetPayloadType(); + + auto keyArrowType = NArrow::GetArrowType(keyType); + auto payloadArrowType = NArrow::GetArrowType(payloadType); + + auto custom = std::make_shared("custom", arrow::uint64(), false); + if (keyType->GetKind() == TType::EKind::Optional) { - std::vector> fields; - fields.emplace_back(std::make_shared("", NArrow::GetArrowType(keyType))); - fields.emplace_back(std::make_shared("", NArrow::GetArrowType(payloadType))); - return arrow::list(arrow::struct_(fields)); + std::vector> items; + items.emplace_back(std::make_shared("key", keyArrowType, true)); + items.emplace_back(std::make_shared("payload", payloadArrowType, payloadType->IsOptional())); + + auto fieldMap = std::make_shared("map", arrow::list(arrow::struct_(items)), false); + return arrow::struct_({fieldMap, custom}); } - return arrow::map(NArrow::GetArrowType(keyType), NArrow::GetArrowType(payloadType)); + + auto fieldMap = std::make_shared("map", arrow::map(keyArrowType, payloadArrowType), false); + return arrow::struct_({fieldMap, custom}); } std::shared_ptr GetArrowType(const TVariantType* variantType) { @@ -246,52 +331,77 @@ std::shared_ptr GetArrowType(const TVariantType* variantType) { arrow::FieldVector types; TStructType* structType = nullptr; TTupleType* tupleType = nullptr; + if (innerType->IsStruct()) { structType = static_cast(innerType); } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + YQL_ENSURE(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); tupleType = static_cast(innerType); } + // Create Union of 
unions if there are more types than arrow::dense_union supports. if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { - // Create Union of unions if there are more types then arrow::dense_union supports. ui32 numberOfGroups = (variantType->GetAlternativesCount() - 1) / arrow::UnionType::kMaxTypeCode + 1; types.reserve(numberOfGroups); + for (ui32 groupIndex = 0; groupIndex < numberOfGroups; ++groupIndex) { ui32 beginIndex = groupIndex * arrow::UnionType::kMaxTypeCode; ui32 endIndex = std::min((groupIndex + 1) * arrow::UnionType::kMaxTypeCode, variantType->GetAlternativesCount()); + arrow::FieldVector groupTypes; groupTypes.reserve(endIndex - beginIndex); - if (structType == nullptr) { - for (ui32 index = beginIndex; index < endIndex; ++ index) { - groupTypes.emplace_back(std::make_shared("", - NArrow::GetArrowType(tupleType->GetElementType(index)))); - } - } else { - for (ui32 index = beginIndex; index < endIndex; ++ index) { - groupTypes.emplace_back(std::make_shared(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(structType->GetMemberType(index)))); - } - } - types.emplace_back(std::make_shared("", arrow::dense_union(groupTypes))); - } - } else { - // Simply put all types in one arrow::dense_union - types.reserve(variantType->GetAlternativesCount()); - if (structType == nullptr) { - for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { - types.push_back(std::make_shared("", NArrow::GetArrowType(tupleType->GetElementType(index)))); - } - } else { - for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { - types.emplace_back(std::make_shared(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(structType->GetMemberType(index)))); + + for (ui32 index = beginIndex; index < endIndex; ++index) { + auto itemName = (structType == nullptr) ? 
std::string("field" + ToString(index)) : std::string(structType->GetMemberName(index)); + auto itemType = (structType == nullptr) ? tupleType->GetElementType(index) : structType->GetMemberType(index); + auto itemArrowType = NArrow::GetArrowType(itemType); + + groupTypes.emplace_back(std::make_shared(itemName, itemArrowType, itemType->IsOptional())); } + + auto fieldName = std::string("field" + ToString(groupIndex)); + types.emplace_back(std::make_shared(fieldName, arrow::dense_union(groupTypes), false)); } + + return arrow::dense_union(types); } + + // Else put all types in one arrow::dense_union + types.reserve(variantType->GetAlternativesCount()); + for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { + auto itemName = (structType == nullptr) ? std::string("field" + ToString(index)) : std::string(structType->GetMemberName(index)); + auto itemType = (structType == nullptr) ? tupleType->GetElementType(index) : structType->GetMemberType(index); + auto itemArrowType = NArrow::GetArrowType(itemType); + + types.emplace_back(std::make_shared(itemName, itemArrowType, itemType->IsOptional())); + } + return arrow::dense_union(types); } +std::shared_ptr GetArrowType(const TOptionalType* optionalType) { + auto currentType = optionalType->GetItemType(); + ui32 depth = 1; + + while (currentType->IsOptional()) { + currentType = static_cast(currentType)->GetItemType(); + ++depth; + } + + if (NeedWrapByExternalOptional(currentType)) { + ++depth; + } + + std::shared_ptr innerArrowType = NArrow::GetArrowType(currentType); + + for (ui32 i = 1; i < depth; ++i) { + auto field = std::make_shared("opt", innerArrowType, false); + innerArrowType = std::make_shared(std::vector>{ field }); + } + + return innerArrowType; +} + template void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { auto typedBuilder = reinterpret_cast::BuilderType*>(builder); @@ -301,12 +411,12 @@ void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue 
value) { } else { status = typedBuilder->Append(value.Get()); } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } template <> void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::UINT64); + YQL_ENSURE(builder->type()->id() == arrow::Type::UINT64); auto typedBuilder = reinterpret_cast(builder); arrow::Status status; if (!value.HasValue()) { @@ -314,12 +424,12 @@ void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnb } else { status = typedBuilder->Append(value.Get()); } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } template <> void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::INT64); + YQL_ENSURE(builder->type()->id() == arrow::Type::INT64); auto typedBuilder = reinterpret_cast(builder); arrow::Status status; if (!value.HasValue()) { @@ -327,38 +437,12 @@ void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnbo } else { status = typedBuilder->Append(value.Get()); } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::TIMESTAMP); - auto typedBuilder = reinterpret_cast(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::DURATION); - auto typedBuilder = reinterpret_cast(builder); - arrow::Status status; - if (!value.HasValue()) { - status = 
typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get()); - } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } template <> void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::STRING); + YQL_ENSURE(builder->type()->id() == arrow::Type::STRING); auto typedBuilder = reinterpret_cast(builder); arrow::Status status; if (!value.HasValue()) { @@ -367,12 +451,12 @@ void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnb auto data = value.AsStringRef(); status = typedBuilder->Append(data.Data(), data.Size()); } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } template <> void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::BINARY); + YQL_ENSURE(builder->type()->id() == arrow::Type::BINARY); auto typedBuilder = reinterpret_cast(builder); arrow::Status status; if (!value.HasValue()) { @@ -381,21 +465,86 @@ void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnb auto data = value.AsStringRef(); status = typedBuilder->Append(data.Data(), data.Size()); } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } +// Only for timezone datetime types template <> -void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::DECIMAL128); - auto typedBuilder = reinterpret_cast(builder); +void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT); + auto typedBuilder = reinterpret_cast(builder); + YQL_ENSURE(typedBuilder->num_fields() == 2, "StructBuilder of timezone datetime types should 
have 2 fields"); + + if (!value.HasValue()) { + auto status = typedBuilder->AppendNull(); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); + return; + } + + auto status = typedBuilder->Append(); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); + + auto datetimeArray = typedBuilder->field_builder(0); + auto timezoneArray = reinterpret_cast(typedBuilder->field_builder(1)); + + switch (datetimeArray->type()->id()) { + // NUdf::EDataSlot::TzDate + case arrow::Type::UINT16: { + status = reinterpret_cast(datetimeArray)->Append(value.Get()); + break; + } + // NUdf::EDataSlot::TzDatetime + case arrow::Type::UINT32: { + status = reinterpret_cast(datetimeArray)->Append(value.Get()); + break; + } + // NUdf::EDataSlot::TzTimestamp + case arrow::Type::UINT64: { + status = reinterpret_cast(datetimeArray)->Append(value.Get()); + break; + } + // NUdf::EDataSlot::TzDate32 + case arrow::Type::INT32: { + status = reinterpret_cast(datetimeArray)->Append(value.Get()); + break; + } + // NUdf::EDataSlot::TzDatetime64, NUdf::EDataSlot::TzTimestamp64 + case arrow::Type::INT64: { + status = reinterpret_cast(datetimeArray)->Append(value.Get()); + break; + } + default: + YQL_ENSURE(false, "Unexpected timezone datetime slot"); + return; + } + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); + + status = timezoneArray->Append(value.GetTimezoneId()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); +} + +template +void AppendFixedSizeDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value, NUdf::EDataSlot dataSlot) { + static_assert(std::is_same_v, "This function is only for FixedSizeBinaryType"); + + YQL_ENSURE(builder->type()->id() == arrow::Type::FIXED_SIZE_BINARY); + auto typedBuilder = reinterpret_cast(builder); arrow::Status status; + if (!value.HasValue()) { status = typedBuilder->AppendNull(); } else { - // Parse value from string - status = 
typedBuilder->Append(value.AsStringRef().Data()); + if (dataSlot == NUdf::EDataSlot::Uuid) { + auto data = value.AsStringRef(); + status = typedBuilder->Append(data.Data()); + } else if (dataSlot == NUdf::EDataSlot::Decimal) { + auto intVal = value.GetInt128(); + status = typedBuilder->Append(reinterpret_cast(&intVal)); + } else { + YQL_ENSURE(false, "Unexpected data slot"); + } } - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append data value: " << status.ToString()); } } // namespace @@ -404,9 +553,10 @@ std::shared_ptr GetArrowType(const TType* type) { switch (type->GetKind()) { case TType::EKind::Void: case TType::EKind::Null: + return arrow::null(); case TType::EKind::EmptyList: case TType::EKind::EmptyDict: - break; + return arrow::struct_({}); case TType::EKind::Data: { auto dataType = static_cast(type); return GetArrowType(dataType); @@ -421,17 +571,7 @@ std::shared_ptr GetArrowType(const TType* type) { } case TType::EKind::Optional: { auto optionalType = static_cast(type); - auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { - std::vector> fields; - fields.emplace_back(std::make_shared("", std::make_shared())); - while (innerOptionalType->GetKind() == TType::EKind::Optional) { - innerOptionalType = static_cast(innerOptionalType)->GetItemType(); - } - fields.emplace_back(std::make_shared("", GetArrowType(innerOptionalType))); - return arrow::struct_(fields); - } - return GetArrowType(innerOptionalType); + return GetArrowType(optionalType); } case TType::EKind::List: { auto listType = static_cast(type); @@ -446,7 +586,7 @@ std::shared_ptr GetArrowType(const TType* type) { return GetArrowType(variantType); } default: - THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); + YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr()); } return arrow::null(); } @@ -459,6 +599,7 @@ bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) 
{ case TType::EKind::EmptyDict: case TType::EKind::Data: return true; + case TType::EKind::Struct: { auto structType = static_cast(type); bool isCompatible = true; @@ -468,6 +609,7 @@ bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) { } return isCompatible; } + case TType::EKind::Tuple: { auto tupleType = static_cast(type); bool isCompatible = true; @@ -477,37 +619,33 @@ bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) { } return isCompatible; } + case TType::EKind::Optional: { auto optionalType = static_cast(type); auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { + if (NeedWrapByExternalOptional(innerOptionalType)) { return false; } return IsArrowCompatible(innerOptionalType); } + case TType::EKind::List: { auto listType = static_cast(type); auto itemType = listType->GetItemType(); return IsArrowCompatible(itemType); } - case TType::EKind::Dict: { - auto dictType = static_cast(type); - auto keyType = dictType->GetKeyType(); - auto payloadType = dictType->GetPayloadType(); - if (keyType->GetKind() == TType::EKind::Optional) { - return false; - } - return IsArrowCompatible(keyType) && IsArrowCompatible(payloadType); - } + case TType::EKind::Variant: { auto variantType = static_cast(type); if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { return false; } TType* innerType = variantType->GetUnderlyingType(); - Y_VERIFY_S(innerType->IsTuple() || innerType->IsStruct(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + YQL_ENSURE(innerType->IsTuple() || innerType->IsStruct(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); return IsArrowCompatible(innerType); } + + case TType::EKind::Dict: case TType::EKind::Block: case TType::EKind::Type: case TType::EKind::Stream: @@ -529,66 +667,84 @@ std::unique_ptr MakeArrowBuilder(const TType* type) { auto arrayType = GetArrowType(type); std::unique_ptr builder; auto 
status = arrow::MakeBuilder(arrow::default_memory_pool(), arrayType, &builder); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to make arrow builder: " << status.ToString()); return builder; } void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, const TType* type) { switch (type->GetKind()) { case TType::EKind::Void: - case TType::EKind::Null: + case TType::EKind::Null: { + YQL_ENSURE(builder->type()->id() == arrow::Type::NA, "Unexpected builder type"); + auto status = builder->AppendNull(); + YQL_ENSURE(status.ok(), "Failed to append null value: " << status.ToString()); + break; + } + case TType::EKind::EmptyList: case TType::EKind::EmptyDict: { - auto status = builder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); + auto structBuilder = reinterpret_cast(builder); + auto status = structBuilder->Append(); + YQL_ENSURE(status.ok(), "Failed to append empty dict/list value: " << status.ToString()); break; } case TType::EKind::Data: { - // TODO for TzDate, TzDatetime, TzTimestamp pass timezone to arrow builder? 
auto dataType = static_cast(type); - bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&](TTypeWrapper typeHolder) { + auto slot = *dataType->GetDataSlot().Get(); + bool success = SwitchMiniKQLDataTypeToArrowType(slot, [&](TTypeWrapper typeHolder) { Y_UNUSED(typeHolder); - AppendDataValue(builder, value); + if constexpr (std::is_same_v) { + AppendFixedSizeDataValue(builder, value, slot); + } else { + AppendDataValue(builder, value); + } return true; }); - Y_ABORT_UNLESS(success); + YQL_ENSURE(success, "Failed to append data value to arrow builder"); break; } case TType::EKind::Optional: { - auto optionalType = static_cast(type); - if (optionalType->GetItemType()->GetKind() != TType::EKind::Optional) { - if (value.HasValue()) { - AppendElement(value.GetOptionalValue(), builder, optionalType->GetItemType()); - } else { - auto status = builder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); + auto innerType = static_cast(type)->GetItemType(); + ui32 depth = 1; + + while (innerType->IsOptional()) { + innerType = static_cast(innerType)->GetItemType(); + ++depth; + } + + if (NeedWrapByExternalOptional(innerType)) { + ++depth; + } + + auto innerBuilder = builder; + auto innerValue = value; + + for (ui32 i = 1; i < depth; ++i) { + YQL_ENSURE(innerBuilder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); + auto structBuilder = reinterpret_cast(innerBuilder); + YQL_ENSURE(structBuilder->num_fields() == 1, "Unexpected number of fields"); + + if (!innerValue) { + auto status = innerBuilder->AppendNull(); + YQL_ENSURE(status.ok(), "Failed to append null optional value: " << status.ToString()); + return; } - } else { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::STRUCT); - auto structBuilder = reinterpret_cast(builder); - Y_DEBUG_ABORT_UNLESS(structBuilder->num_fields() == 2); - Y_DEBUG_ABORT_UNLESS(structBuilder->field_builder(0)->type()->id() == arrow::Type::UINT64); + auto status = 
structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - auto depthBuilder = reinterpret_cast(structBuilder->field_builder(0)); - auto valueBuilder = structBuilder->field_builder(1); - ui64 depth = 0; - TType* innerType = optionalType->GetItemType(); - while (innerType->GetKind() == TType::EKind::Optional && value.HasValue()) { - innerType = static_cast(innerType)->GetItemType(); - value = value.GetOptionalValue(); - ++depth; - } - status = depthBuilder->Append(depth); - Y_VERIFY_S(status.ok(), status.ToString()); - if (value.HasValue()) { - AppendElement(value, valueBuilder, innerType); - } else { - status = valueBuilder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); - } + YQL_ENSURE(status.ok(), "Failed to append optional value: " << status.ToString()); + + innerValue = innerValue.GetOptionalValue(); + innerBuilder = structBuilder->field_builder(0); + } + + if (innerValue) { + AppendElement(innerValue.GetOptionalValue(), innerBuilder, innerType); + } else { + auto status = innerBuilder->AppendNull(); + YQL_ENSURE(status.ok(), "Failed to append null optional value: " << status.ToString()); } break; } @@ -596,16 +752,19 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons case TType::EKind::List: { auto listType = static_cast(type); auto itemType = listType->GetItemType(); - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::LIST); + + YQL_ENSURE(builder->type()->id() == arrow::Type::LIST, "Unexpected builder type"); auto listBuilder = reinterpret_cast(builder); + auto status = listBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append list value: " << status.ToString()); + auto innerBuilder = listBuilder->value_builder(); - if (auto p = value.GetElements()) { - auto len = value.GetListLength(); - while (len > 0) { - AppendElement(*p++, innerBuilder, itemType); - --len; + if (auto item = value.GetElements()) { + auto length = 
value.GetListLength(); + while (length > 0) { + AppendElement(*item++, innerBuilder, itemType); + --length; } } else { const auto iter = value.GetListIterator(); @@ -618,11 +777,14 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons case TType::EKind::Struct: { auto structType = static_cast(type); - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::STRUCT); + + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); auto structBuilder = reinterpret_cast(builder); + auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - Y_DEBUG_ABORT_UNLESS(static_cast(structBuilder->num_fields()) == structType->GetMembersCount()); + YQL_ENSURE(status.ok(), "Failed to append struct value: " << status.ToString()); + + YQL_ENSURE(static_cast(structBuilder->num_fields()) == structType->GetMembersCount(), "Unexpected number of fields"); for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { auto innerBuilder = structBuilder->field_builder(index); auto memberType = structType->GetMemberType(index); @@ -633,11 +795,14 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons case TType::EKind::Tuple: { auto tupleType = static_cast(type); - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::STRUCT); + + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); auto structBuilder = reinterpret_cast(builder); + auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - Y_DEBUG_ABORT_UNLESS(static_cast(structBuilder->num_fields()) == tupleType->GetElementsCount()); + YQL_ENSURE(status.ok(), "Failed to append tuple value: " << status.ToString()); + + YQL_ENSURE(static_cast(structBuilder->num_fields()) == tupleType->GetElementsCount(), "Unexpected number of fields"); for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { auto innerBuilder = 
structBuilder->field_builder(index); auto elementType = tupleType->GetElementType(index); @@ -651,37 +816,53 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons auto keyType = dictType->GetKeyType(); auto payloadType = dictType->GetPayloadType(); - arrow::ArrayBuilder* keyBuilder; - arrow::ArrayBuilder* itemBuilder; + arrow::ArrayBuilder* keyBuilder = nullptr; + arrow::ArrayBuilder* itemBuilder = nullptr; arrow::StructBuilder* structBuilder = nullptr; + + YQL_ENSURE(builder->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); + arrow::StructBuilder* wrapBuilder = reinterpret_cast(builder); + YQL_ENSURE(wrapBuilder->num_fields() == 2, "Unexpected number of fields"); + + auto status = wrapBuilder->Append(); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); + if (keyType->GetKind() == TType::EKind::Optional) { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::LIST); - auto listBuilder = reinterpret_cast(builder); - Y_DEBUG_ABORT_UNLESS(listBuilder->value_builder()->type()->id() == arrow::Type::STRUCT); - // Start a new list in ListArray of structs + YQL_ENSURE(wrapBuilder->field_builder(0)->type()->id() == arrow::Type::LIST, "Unexpected builder type"); + auto listBuilder = reinterpret_cast(wrapBuilder->field_builder(0)); + auto status = listBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); + + YQL_ENSURE(listBuilder->value_builder()->type()->id() == arrow::Type::STRUCT, "Unexpected builder type"); structBuilder = reinterpret_cast(listBuilder->value_builder()); - Y_DEBUG_ABORT_UNLESS(structBuilder->num_fields() == 2); + YQL_ENSURE(structBuilder->num_fields() == 2, "Unexpected number of fields"); + keyBuilder = structBuilder->field_builder(0); itemBuilder = structBuilder->field_builder(1); } else { - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::MAP); - auto mapBuilder = 
reinterpret_cast(builder); - // Start a new map in MapArray + YQL_ENSURE(wrapBuilder->field_builder(0)->type()->id() == arrow::Type::MAP, "Unexpected builder type"); + auto mapBuilder = reinterpret_cast(wrapBuilder->field_builder(0)); + auto status = mapBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); + keyBuilder = mapBuilder->key_builder(); itemBuilder = mapBuilder->item_builder(); } - const auto iter = value.GetDictIterator(); + arrow::UInt64Builder* customBuilder = reinterpret_cast(wrapBuilder->field_builder(1)); + status = customBuilder->Append(0); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); + // We do not sort dictionary before appending it to builder. + const auto iter = value.GetDictIterator(); for (NUdf::TUnboxedValue key, payload; iter.NextPair(key, payload);) { if (structBuilder != nullptr) { - auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); + status = structBuilder->Append(); + YQL_ENSURE(status.ok(), "Failed to append dict value: " << status.ToString()); } + AppendElement(key, keyBuilder, keyType); AppendElement(payload, itemBuilder, payloadType); } @@ -691,33 +872,42 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons case TType::EKind::Variant: { // TODO Need to properly convert variants containing more than 127*127 types? 
auto variantType = static_cast(type); - Y_DEBUG_ABORT_UNLESS(builder->type()->id() == arrow::Type::DENSE_UNION); + + YQL_ENSURE(builder->type()->id() == arrow::Type::DENSE_UNION, "Unexpected builder type"); auto unionBuilder = reinterpret_cast(builder); + ui32 variantIndex = value.GetVariantIndex(); TType* innerType = variantType->GetUnderlyingType(); + if (innerType->IsStruct()) { innerType = static_cast(innerType)->GetMemberType(variantIndex); } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + YQL_ENSURE(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); innerType = static_cast(innerType)->GetElementType(variantIndex); } + if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { ui32 numberOfGroups = (variantType->GetAlternativesCount() - 1) / arrow::UnionType::kMaxTypeCode + 1; - Y_DEBUG_ABORT_UNLESS(static_cast(unionBuilder->num_children()) == numberOfGroups); + YQL_ENSURE(static_cast(unionBuilder->num_children()) == numberOfGroups); + ui32 groupIndex = variantIndex / arrow::UnionType::kMaxTypeCode; auto status = unionBuilder->Append(groupIndex); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append variant value: " << status.ToString()); + auto innerBuilder = unionBuilder->child_builder(groupIndex); - Y_DEBUG_ABORT_UNLESS(innerBuilder->type()->id() == arrow::Type::DENSE_UNION); + YQL_ENSURE(innerBuilder->type()->id() == arrow::Type::DENSE_UNION); auto innerUnionBuilder = reinterpret_cast(innerBuilder.get()); + ui32 innerVariantIndex = variantIndex % arrow::UnionType::kMaxTypeCode; status = innerUnionBuilder->Append(innerVariantIndex); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append variant value: " << status.ToString()); + auto doubleInnerBuilder = innerUnionBuilder->child_builder(innerVariantIndex); AppendElement(value.GetVariantItem(), 
doubleInnerBuilder.get(), innerType); } else { auto status = unionBuilder->Append(variantIndex); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to append variant value: " << status.ToString()); + auto innerBuilder = unionBuilder->child_builder(variantIndex); AppendElement(value.GetVariantItem(), innerBuilder.get(), innerType); } @@ -725,20 +915,20 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons } default: - THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); + YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr()); } } std::shared_ptr MakeArray(NMiniKQL::TUnboxedValueVector& values, const TType* itemType) { auto builder = MakeArrowBuilder(itemType); auto status = builder->Reserve(values.size()); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to reserve space for array: " << status.ToString()); for (auto& value: values) { AppendElement(value, builder.get(), itemType); } std::shared_ptr result; status = builder->Finish(&result); - Y_VERIFY_S(status.ok(), status.ToString()); + YQL_ENSURE(status.ok(), "Failed to finish array: " << status.ToString()); return result; } @@ -746,13 +936,15 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr& arr if (array->IsNull(row)) { return NUdf::TUnboxedValuePod(); } + switch(itemType->GetKind()) { case TType::EKind::Void: case TType::EKind::Null: case TType::EKind::EmptyList: case TType::EKind::EmptyDict: break; - case TType::EKind::Data: { // TODO TzDate need special care + + case TType::EKind::Data: { auto dataType = static_cast(itemType); NUdf::TUnboxedValue result; bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&](TTypeWrapper typeHolder) { @@ -760,70 +952,98 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr& arr result = GetUnboxedValue(array, row); return true; }); - Y_DEBUG_ABORT_UNLESS(success); + Y_ENSURE(success, "Failed to extract 
unboxed value from arrow array"); return result; } + case TType::EKind::Struct: { auto structType = static_cast(itemType); - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::STRUCT); + + YQL_ENSURE(array->type_id() == arrow::Type::STRUCT); auto typedArray = static_pointer_cast(array); - Y_DEBUG_ABORT_UNLESS(static_cast(typedArray->num_fields()) == structType->GetMembersCount()); + YQL_ENSURE(static_cast(typedArray->num_fields()) == structType->GetMembersCount()); + NUdf::TUnboxedValue* itemsPtr = nullptr; auto result = holderFactory.CreateDirectArrayHolder(structType->GetMembersCount(), itemsPtr); + for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { auto memberType = structType->GetMemberType(index); itemsPtr[index] = ExtractUnboxedValue(typedArray->field(index), row, memberType, holderFactory); } return result; } + case TType::EKind::Tuple: { auto tupleType = static_cast(itemType); - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::STRUCT); + + YQL_ENSURE(array->type_id() == arrow::Type::STRUCT); auto typedArray = static_pointer_cast(array); - Y_DEBUG_ABORT_UNLESS(static_cast(typedArray->num_fields()) == tupleType->GetElementsCount()); + YQL_ENSURE(static_cast(typedArray->num_fields()) == tupleType->GetElementsCount()); + NUdf::TUnboxedValue* itemsPtr = nullptr; auto result = holderFactory.CreateDirectArrayHolder(tupleType->GetElementsCount(), itemsPtr); + for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { auto elementType = tupleType->GetElementType(index); itemsPtr[index] = ExtractUnboxedValue(typedArray->field(index), row, elementType, holderFactory); } return result; } + case TType::EKind::Optional: { auto optionalType = static_cast(itemType); auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::STRUCT); - auto structArray = static_pointer_cast(array); - 
Y_DEBUG_ABORT_UNLESS(structArray->num_fields() == 2); - Y_DEBUG_ABORT_UNLESS(structArray->field(0)->type_id() == arrow::Type::UINT64); - auto depthArray = static_pointer_cast(structArray->field(0)); - auto valuesArray = structArray->field(1); - auto depth = depthArray->Value(row); + + if (NeedWrapByExternalOptional(innerOptionalType)) { + YQL_ENSURE(array->type_id() == arrow::Type::STRUCT); + + auto innerArray = array; + auto innerType = itemType; + NUdf::TUnboxedValue value; - if (valuesArray->IsNull(row)) { - value = NUdf::TUnboxedValuePod(); - } else { - while (innerOptionalType->GetKind() == TType::EKind::Optional) { - innerOptionalType = static_cast(innerOptionalType)->GetItemType(); + int depth = 0; + + while (innerArray->type_id() == arrow::Type::STRUCT) { + auto structArray = static_pointer_cast(innerArray); + YQL_ENSURE(structArray->num_fields() == 1); + + if (structArray->IsNull(row)) { + value = NUdf::TUnboxedValuePod(); + break; } - value = ExtractUnboxedValue(valuesArray, row, innerOptionalType, holderFactory); + + innerType = static_cast(innerType)->GetItemType(); + innerArray = structArray->field(0); + ++depth; } - for (ui64 i = 0; i < depth; ++i) { + + auto wrap = NeedWrapByExternalOptional(innerType); + if (wrap || !innerArray->IsNull(row)) { + value = ExtractUnboxedValue(innerArray, row, innerType, holderFactory); + if (wrap) { + --depth; + } + } + + for (int i = 0; i < depth; ++i) { value = value.MakeOptional(); } return value; - } else { - return ExtractUnboxedValue(array, row, innerOptionalType, holderFactory).Release().MakeOptional(); } + + return ExtractUnboxedValue(array, row, innerOptionalType, holderFactory).Release().MakeOptional(); } + case TType::EKind::List: { auto listType = static_cast(itemType); - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::LIST); + + YQL_ENSURE(array->type_id() == arrow::Type::LIST); auto typedArray = static_pointer_cast(array); + auto arraySlice = typedArray->value_slice(row); auto itemType = 
listType->GetItemType(); const auto len = arraySlice->length(); + NUdf::TUnboxedValue *items = nullptr; auto list = holderFactory.CreateDirectArrayHolder(len, items); for (ui64 i = 0; i < static_cast(len); ++i) { @@ -831,8 +1051,10 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr& arr } return list; } + case TType::EKind::Dict: { auto dictType = static_cast(itemType); + auto keyType = dictType->GetKeyType(); auto payloadType = dictType->GetPayloadType(); auto dictBuilder = holderFactory.NewDict(dictType, NUdf::TDictFlags::EDictKind::Hashed); @@ -841,24 +1063,35 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr& arr std::shared_ptr payloadArray = nullptr; ui64 dictLength = 0; ui64 offset = 0; + + YQL_ENSURE(array->type_id() == arrow::Type::STRUCT); + auto wrapArray = static_pointer_cast(array); + YQL_ENSURE(wrapArray->num_fields() == 2); + + auto dictSlice = wrapArray->field(0); + if (keyType->GetKind() == TType::EKind::Optional) { - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::LIST); - auto listArray = static_pointer_cast(array); + YQL_ENSURE(dictSlice->type_id() == arrow::Type::LIST); + auto listArray = static_pointer_cast(dictSlice); + auto arraySlice = listArray->value_slice(row); - Y_DEBUG_ABORT_UNLESS(arraySlice->type_id() == arrow::Type::STRUCT); + YQL_ENSURE(arraySlice->type_id() == arrow::Type::STRUCT); auto structArray = static_pointer_cast(arraySlice); - Y_DEBUG_ABORT_UNLESS(structArray->num_fields() == 2); + YQL_ENSURE(structArray->num_fields() == 2); + dictLength = arraySlice->length(); keyArray = structArray->field(0); payloadArray = structArray->field(1); } else { - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::MAP); - auto mapArray = static_pointer_cast(array); + YQL_ENSURE(dictSlice->type_id() == arrow::Type::MAP); + auto mapArray = static_pointer_cast(dictSlice); + dictLength = mapArray->value_length(row); offset = mapArray->value_offset(row); keyArray = mapArray->keys(); payloadArray = 
mapArray->items(); } + for (ui64 i = offset; i < offset + static_cast(dictLength); ++i) { auto key = ExtractUnboxedValue(keyArray, i, keyType, holderFactory); auto payload = ExtractUnboxedValue(payloadArray, i, payloadType, holderFactory); @@ -866,35 +1099,42 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr& arr } return dictBuilder->Build(); } + case TType::EKind::Variant: { // TODO Need to properly convert variants containing more than 127*127 types? auto variantType = static_cast(itemType); - Y_DEBUG_ABORT_UNLESS(array->type_id() == arrow::Type::DENSE_UNION); + + YQL_ENSURE(array->type_id() == arrow::Type::DENSE_UNION); auto unionArray = static_pointer_cast(array); + auto variantIndex = unionArray->child_id(row); auto rowInChild = unionArray->value_offset(row); std::shared_ptr valuesArray = unionArray->field(variantIndex); + if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { // Go one step deeper - Y_DEBUG_ABORT_UNLESS(valuesArray->type_id() == arrow::Type::DENSE_UNION); + YQL_ENSURE(valuesArray->type_id() == arrow::Type::DENSE_UNION); auto innerUnionArray = static_pointer_cast(valuesArray); auto innerVariantIndex = innerUnionArray->child_id(rowInChild); + rowInChild = innerUnionArray->value_offset(rowInChild); valuesArray = innerUnionArray->field(innerVariantIndex); variantIndex = variantIndex * arrow::UnionType::kMaxTypeCode + innerVariantIndex; } + TType* innerType = variantType->GetUnderlyingType(); if (innerType->IsStruct()) { innerType = static_cast(innerType)->GetMemberType(variantIndex); } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + YQL_ENSURE(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); innerType = static_cast(innerType)->GetElementType(variantIndex); } + NUdf::TUnboxedValue value = ExtractUnboxedValue(valuesArray, rowInChild, innerType, holderFactory); return 
holderFactory.CreateVariantHolder(value.Release(), variantIndex); } default: - THROW yexception() << "Unsupported type: " << itemType->GetKindAsStr(); + YQL_ENSURE(false, "Unsupported type: " << itemType->GetKindAsStr()); } return NUdf::TUnboxedValuePod(); } @@ -915,20 +1155,20 @@ std::string SerializeArray(const std::shared_ptr& array) { writeOptions.use_threads = false; // TODO Decide which compression level will be default. Will it depend on the length of array? auto codecResult = arrow::util::Codec::Create(arrow::Compression::LZ4_FRAME); - Y_ABORT_UNLESS(codecResult.ok()); + YQL_ENSURE(codecResult.ok()); writeOptions.codec = std::move(codecResult.ValueOrDie()); int64_t size; auto status = GetRecordBatchSize(*batch, writeOptions, &size); - Y_ABORT_UNLESS(status.ok()); + YQL_ENSURE(status.ok()); std::string str; str.resize(size); auto writer = arrow::Buffer::GetWriter(arrow::MutableBuffer::Wrap(&str[0], size)); - Y_ABORT_UNLESS(writer.status().ok()); + YQL_ENSURE(writer.status().ok()); status = SerializeRecordBatch(*batch, writeOptions, (*writer).get()); - Y_ABORT_UNLESS(status.ok()); + YQL_ENSURE(status.ok()); return str; } @@ -941,7 +1181,7 @@ std::shared_ptr DeserializeArray(const std::string& blob, std::sha arrow::io::BufferReader reader(buffer); auto schema = std::make_shared(std::vector>{arrow::field("", type)}); auto batch = ReadRecordBatch(schema, &dictMemo, options, &reader); - Y_DEBUG_ABORT_UNLESS(batch.ok() && (*batch)->ValidateFull().ok(), "Failed to deserialize batch"); + YQL_ENSURE(batch.ok() && (*batch)->ValidateFull().ok(), "Failed to deserialize batch"); return (*batch)->column(0); } diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers.h b/ydb/library/yql/dq/runtime/dq_arrow_helpers.h index 769c62a9579..9270ece1f9b 100644 --- a/ydb/library/yql/dq/runtime/dq_arrow_helpers.h +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers.h @@ -14,9 +14,25 @@ namespace NArrow { /** * @brief Convert TType to the arrow::DataType object * - * The logic of this 
conversion is the following: + * The logic of this conversion is from YQL-15332: * - * Struct, tuple => StructArray + * Void, Null => NullType + * Bool => Uint8 + * Integral => Uint8..Uint64, Int8..Int64 + * Floats => Float, Double + * Date => Uint16 + * Datetime => Uint32 + * Timestamp => Uint64 + * Interval => Int64 + * Date32 => Int32 + * Interval64, Timestamp64, Datetime64 => Int64 + * Utf8, Json => String + * String, Yson, JsonDocument => Binary + * Decimal, UUID => FixedSizeBinary(16) + * Timezone datetime type => StructArray + * DyNumber => BinaryArray (it is not added to YQL-15332) + * + * Struct, Tuple, EmptyList, EmptyDict => StructArray * Names of fields constructed from tuple are just empty strings. * * List => ListArray @@ -26,20 +42,26 @@ namespace NArrow { * Variant => DenseUnionArray * TODO Implement convertion of data to DenseUnionArray and back * - * Optional(Optional ..(type)..) => StructArray - * Here the integer value equals the number of calls of method GetOptionalValue(). - * If value is null at some depth, then the value in second field of Array is Null - * (and the integer equals this depth). If value is present, then it is contained in the - * second field (and the integer equals the number of Optional(...) levels). - * This information is sufficient to restore an UnboxedValue knowing its type. 
+ * Optional => StructArray if T is Variant + * Because DenseUnionArray does not have validity bitmap + * Optional => T for other types + * By default, other types have a validity bitmap + * + * Optional...>> => StructArray...>> + * For example: + * - Optional> => StructArray + * Int32 has validity bitmap, so we wrap it in StructArray N - 1 times, where N is the number of Optional levels + * - Optional>> => StructArray>> + * DenseUnionArray does not have validity bitmap, so we wrap it in StructArray N times, where N is the number of Optional levels * - * Dict => MapArray + * Dict => StructArray, Uint64Array (on demand, default: 0)> * We do not use arrow::DictArray because it must be used for encoding not for mapping keys to values. * (https://arrow.apache.org/docs/cpp/api/array.html#classarrow_1_1_dictionary_array) * If the type of dict key is optional then we map - * Dict => ListArray> + * Dict, ValueType> => StructArray, Uint64Array (on demand, default: 0)> * because keys of MapArray can not be nullable * + * * @param type Yql type to parse * @return std::shared_ptr arrow type of the same structure as type */ diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp b/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp index d1ca1e46bd5..86dc9be2428 100644 --- a/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -24,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -63,6 +65,36 @@ NUdf::TUnboxedValue GetValueOfBasicType(TType* type, ui64 value) { return NUdf::TUnboxedValuePod(static_cast(value) / 1234); case NUdf::EDataSlot::Double: return NUdf::TUnboxedValuePod(static_cast(value) / 12345); + case NUdf::EDataSlot::TzDate: { + auto ret = NUdf::TUnboxedValuePod(static_cast(value % NUdf::MAX_DATE)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("Europe/Moscow")); + return ret; + } + case 
NUdf::EDataSlot::TzDatetime: { + auto ret = NUdf::TUnboxedValuePod(static_cast(value % NUdf::MAX_DATETIME)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("Asia/Omsk")); + return ret; + } + case NUdf::EDataSlot::TzTimestamp: { + auto ret = NUdf::TUnboxedValuePod(static_cast(value % NUdf::MAX_TIMESTAMP)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("Europe/Tallinn")); + return ret; + } + case NUdf::EDataSlot::TzDate32: { + auto ret = NUdf::TUnboxedValuePod(static_cast(value % NUdf::MAX_DATE32)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("US/Eastern")); + return ret; + } + case NUdf::EDataSlot::TzDatetime64: { + auto ret = NUdf::TUnboxedValuePod(static_cast(value % NUdf::MAX_DATETIME64)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("America/Nuuk")); + return ret; + } + case NUdf::EDataSlot::TzTimestamp64: { + auto ret = NUdf::TUnboxedValuePod(static_cast(value % NUdf::MAX_TIMESTAMP64)); + ret.SetTimezoneId(NKikimr::NMiniKQL::GetTimezoneId("Atlantic/Faroe")); + return ret; + } default: Y_ABORT("Not implemented creation value for such type"); } @@ -110,7 +142,13 @@ struct TTestContext { TDataType::Create(NUdf::TDataType::Id, TypeEnv), TDataType::Create(NUdf::TDataType::Id, TypeEnv), TDataType::Create(NUdf::TDataType::Id, TypeEnv), - TDataType::Create(NUdf::TDataType::Id, TypeEnv) + TDataType::Create(NUdf::TDataType::Id, TypeEnv), + TDataType::Create(NUdf::TDataType::Id, TypeEnv), + TDataType::Create(NUdf::TDataType::Id, TypeEnv), + TDataType::Create(NUdf::TDataType::Id, TypeEnv), + TDataType::Create(NUdf::TDataType::Id, TypeEnv), + TDataType::Create(NUdf::TDataType::Id, TypeEnv), + TDataType::Create(NUdf::TDataType::Id, TypeEnv) }; TTestContext() @@ -212,6 +250,32 @@ struct TTestContext { return values; } + TType* GetOptionalListOfOptional() { + TType* itemType = TOptionalType::Create(TDataType::Create(NUdf::TDataType::Id, TypeEnv), TypeEnv); + return TOptionalType::Create(TListType::Create(itemType, TypeEnv), TypeEnv); + } + + 
TUnboxedValueVector CreateOptionalListOfOptional(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + if (value % 2 == 0) { + values.emplace_back(NUdf::TUnboxedValuePod()); + continue; + } + + TUnboxedValueVector items; + items.reserve(value); + for (ui64 i = 0; i < value; ++i) { + NUdf::TUnboxedValue item = ((value + i) % 2 == 0) ? NUdf::TUnboxedValuePod() : NUdf::TUnboxedValuePod(i); + items.push_back(std::move(item).MakeOptional()); + } + + auto listValue = Vb.NewList(items.data(), value); + values.emplace_back(std::move(listValue).MakeOptional()); + } + return values; + } + TType* GetVariantOverStructType() { TStructMember members[4] = { {"0_yson", TDataType::Create(NUdf::TDataType::Id, TypeEnv)}, @@ -235,7 +299,8 @@ struct TTestContext { std::string data = TStringBuilder() << "{value:" << value << "}"; item = MakeString(NUdf::TStringRef(data.data(), data.size())); } else if (typeIndex == 2) { - std::string data = TStringBuilder() << "id-QwErY-" << value; + std::string sample = "7856341212905634789012345678901"; + std::string data = TStringBuilder() << HexDecode(sample + static_cast('0' + (value % 10))); item = MakeString(NUdf::TStringRef(data.data(), data.size())); } else if (typeIndex == 3) { item = NUdf::TUnboxedValuePod(static_cast(value) / 4); @@ -246,6 +311,79 @@ struct TTestContext { return values; } + TType* GetOptionalVariantOverStructType() { + return TOptionalType::Create(GetVariantOverStructType(), TypeEnv); + } + + TUnboxedValueVector CreateOptionalVariantOverStruct(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto typeIndex = value % 4; + NUdf::TUnboxedValue item; + + if (value % 2 == 0) { + values.push_back(NUdf::TUnboxedValuePod()); + continue; + } + + if (typeIndex == 0) { + std::string data = TStringBuilder() << "{value=" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 1) { + 
std::string data = TStringBuilder() << "{value:" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 2) { + std::string sample = "7856341212905634789012345678901"; + std::string data = TStringBuilder() << HexDecode(sample + static_cast('0' + (value % 10))); + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast(value) / 4); + } + auto wrapped = Vb.NewVariant(typeIndex, std::move(item)).MakeOptional(); + values.push_back(std::move(wrapped)); + } + return values; + } + + TType* GetDoubleOptionalVariantOverStructType() { + return TOptionalType::Create(GetOptionalVariantOverStructType(), TypeEnv); + } + + TUnboxedValueVector CreateDoubleOptionalVariantOverStruct(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto typeIndex = value % 4; + NUdf::TUnboxedValue item; + + if (value % 3 == 0) { + if (typeIndex == 0) { + std::string data = TStringBuilder() << "{value=" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 1) { + std::string data = TStringBuilder() << "{value:" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 2) { + std::string sample = "7856341212905634789012345678901"; + std::string data = TStringBuilder() << HexDecode(sample + static_cast('0' + (value % 10))); + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast(value) / 4); + } + + item = Vb.NewVariant(typeIndex, std::move(item)).MakeOptional(); + } else { + item = NUdf::TUnboxedValuePod(); + } + + if (value % 3 != 2) { + item = item.MakeOptional(); + } + + values.push_back(std::move(item)); + } + return values; + } + TType* GetVariantOverTupleWithOptionalsType() { TType* members[5] = { 
TDataType::Create(NUdf::TDataType::Id, TypeEnv), @@ -284,6 +422,83 @@ struct TTestContext { return values; } + TType* GetOptionalVariantOverTupleWithOptionalsType() { + return TOptionalType::Create(GetVariantOverTupleWithOptionalsType(), TypeEnv); + } + + TUnboxedValueVector CreateOptionalVariantOverTupleWithOptionals(ui32 quantity) { + NKikimr::NMiniKQL::TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + + if (value % 2 == 0) { + values.push_back(NUdf::TUnboxedValuePod()); + continue; + } + + auto typeIndex = value % 5; + NUdf::TUnboxedValue item; + if (typeIndex == 0) { + item = NUdf::TUnboxedValuePod(value % 3 == 0); + } else if (typeIndex == 1) { + item = NUdf::TUnboxedValuePod(static_cast(-value)); + } else if (typeIndex == 2) { + item = NUdf::TUnboxedValuePod(static_cast(value)); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast(-value)); + } else if (typeIndex == 4) { + NUdf::TUnboxedValue innerItem; + innerItem = value % 2 == 0 + ? 
NUdf::TUnboxedValuePod(static_cast(value)) + : NUdf::TUnboxedValuePod(); + item = innerItem.MakeOptional(); + } + auto wrapped = Vb.NewVariant(typeIndex, std::move(item)).MakeOptional(); + values.emplace_back(std::move(wrapped)); + } + return values; + } + + TType* GetDoubleOptionalVariantOverTupleWithOptionalsType() { + return TOptionalType::Create(GetOptionalVariantOverTupleWithOptionalsType(), TypeEnv); + } + + TUnboxedValueVector CreateDoubleOptionalVariantOverTupleWithOptionals(ui32 quantity) { + NKikimr::NMiniKQL::TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto typeIndex = value % 5; + NUdf::TUnboxedValue item; + + if (value % 3 == 0) { + if (typeIndex == 0) { + item = NUdf::TUnboxedValuePod(value % 3 == 0); + } else if (typeIndex == 1) { + item = NUdf::TUnboxedValuePod(static_cast(-value)); + } else if (typeIndex == 2) { + item = NUdf::TUnboxedValuePod(static_cast(value)); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast(-value)); + } else if (typeIndex == 4) { + NUdf::TUnboxedValue innerItem; + innerItem = value % 2 == 0 + ? 
NUdf::TUnboxedValuePod(static_cast(value)) + : NUdf::TUnboxedValuePod(); + item = innerItem.MakeOptional(); + } + + item = Vb.NewVariant(typeIndex, std::move(item)); + } else { + item = NUdf::TUnboxedValuePod(); + } + + if (value % 3 != 2) { + item = item.MakeOptional(); + } + + values.emplace_back(std::move(item)); + } + return values; + } + TType* GetDictOptionalToTupleType() { TType* keyType = TOptionalType::Create(TDataType::Create(NUdf::TDataType::Id, TypeEnv), TypeEnv); TType* members[2] = { @@ -346,7 +561,7 @@ struct TTestContext { for (ui64 index = 0; index < variantSize; ++index) { TVector selectedTypes; for (ui32 i = 0; i < BasicTypes.size(); ++i) { - if ((index >> i) % 2 == 1) { + if ((index ^ i) % 5 >= 2) { selectedTypes.push_back(BasicTypes[i]); } } @@ -363,7 +578,7 @@ struct TTestContext { auto typeIndex = index % VariantSize; TUnboxedValueVector tupleItems; for (ui64 i = 0; i < BasicTypes.size(); ++i) { - if ((typeIndex >> i) % 2 == 1) { + if ((typeIndex ^ i) % 5 >= 2) { tupleItems.push_back(GetValueOfBasicType(BasicTypes[i], i)); } } @@ -620,13 +835,13 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); auto structArray = static_pointer_cast(array); UNIT_ASSERT(structArray->num_fields() == 3); - UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::BOOL); + UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::UINT8); UNIT_ASSERT(structArray->field(1)->type_id() == arrow::Type::INT8); UNIT_ASSERT(structArray->field(2)->type_id() == arrow::Type::UINT8); UNIT_ASSERT(static_cast(structArray->field(0)->length()) == values.size()); UNIT_ASSERT(static_cast(structArray->field(1)->length()) == values.size()); UNIT_ASSERT(static_cast(structArray->field(2)->length()) == values.size()); - auto boolArray = static_pointer_cast(structArray->field(0)); + auto boolArray = static_pointer_cast(structArray->field(0)); auto int8Array = static_pointer_cast(structArray->field(1)); auto 
uint8Array = static_pointer_cast(structArray->field(2)); auto index = 0; @@ -646,38 +861,6 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { } } - Y_UNIT_TEST(DictUtf8ToInterval) { - TTestContext context; - - auto dictType = context.GetDictUtf8ToIntervalType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(dictType)); - - auto values = context.CreateDictUtf8ToInterval(100); - auto array = NArrow::MakeArray(values, dictType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT(static_cast(array->length()) == values.size()); - UNIT_ASSERT(array->type_id() == arrow::Type::MAP); - auto mapArray = static_pointer_cast(array); - - UNIT_ASSERT(mapArray->num_fields() == 1); - UNIT_ASSERT(mapArray->keys()->type_id() == arrow::Type::STRING); - UNIT_ASSERT(mapArray->items()->type_id() == arrow::Type::DURATION); - auto utf8Array = static_pointer_cast(mapArray->keys()); - auto intervalArray = static_pointer_cast>(mapArray->items()); - ui64 index = 0; - for (const auto& value: values) { - UNIT_ASSERT(value.GetDictLength() == static_cast(mapArray->value_length(index))); - for (auto subindex = mapArray->value_offset(index); subindex < mapArray->value_offset(index + 1); ++subindex) { - auto keyArrow = utf8Array->GetView(subindex); - NUdf::TUnboxedValue key = MakeString(NUdf::TStringRef(keyArrow.data(), keyArrow.size())); - UNIT_ASSERT(value.Contains(key)); - NUdf::TUnboxedValue payloadValue = value.Lookup(key); - UNIT_ASSERT(intervalArray->Value(subindex) == payloadValue.Get()); - } - ++index; - } - } - Y_UNIT_TEST(ListOfJsons) { TTestContext context; @@ -711,6 +894,48 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { } } + Y_UNIT_TEST(OptionalListOfOptional) { + TTestContext context; + + auto listType = context.GetOptionalListOfOptional(); + Y_ABORT_UNLESS(NArrow::IsArrowCompatible(listType)); + + auto values = context.CreateOptionalListOfOptional(100); + auto array = NArrow::MakeArray(values, listType); + UNIT_ASSERT(array->ValidateFull().ok()); + 
UNIT_ASSERT(static_cast(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::LIST); + + auto listArray = static_pointer_cast(array); + UNIT_ASSERT(listArray->num_fields() == 1); + UNIT_ASSERT(listArray->value_type()->id() == arrow::Type::INT32); + + auto i32Array = static_pointer_cast(listArray->values()); + auto index = 0; + auto innerIndex = 0; + for (const auto& value: values) { + if (!value.HasValue()) { + UNIT_ASSERT(listArray->IsNull(index)); + ++index; + continue; + } + + auto listValue = value.GetOptionalValue(); + + UNIT_ASSERT_VALUES_EQUAL(listValue.GetListLength(), static_cast(listArray->value_length(index))); + const auto iter = listValue.GetListIterator(); + for (NUdf::TUnboxedValue item; iter.Next(item);) { + if (!item.HasValue()) { + UNIT_ASSERT(i32Array->IsNull(innerIndex)); + } else { + UNIT_ASSERT(i32Array->Value(innerIndex) == item.GetOptionalValue().Get()); + } + ++innerIndex; + } + ++index; + } + } + Y_UNIT_TEST(VariantOverStruct) { TTestContext context; @@ -725,14 +950,16 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { auto unionArray = static_pointer_cast(array); UNIT_ASSERT(unionArray->num_fields() == 4); - UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::STRING); - UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::STRING); - UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::FIXED_SIZE_BINARY); UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::FLOAT); - auto ysonArray = static_pointer_cast(unionArray->field(0)); - auto jsonDocArray = static_pointer_cast(unionArray->field(1)); - auto uuidArray = static_pointer_cast(unionArray->field(2)); + + auto ysonArray = static_pointer_cast(unionArray->field(0)); + auto jsonDocArray = 
static_pointer_cast(unionArray->field(1)); + auto uuidArray = static_pointer_cast(unionArray->field(2)); auto floatArray = static_pointer_cast(unionArray->field(3)); + for (ui64 index = 0; index < values.size(); ++index) { auto value = values[index]; UNIT_ASSERT(value.GetVariantIndex() == static_cast(unionArray->child_id(index))); @@ -759,13 +986,148 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { } } + Y_UNIT_TEST(OptionalVariantOverStruct) { + TTestContext context; + + auto variantType = context.GetOptionalVariantOverStructType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateOptionalVariantOverStruct(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + + auto structArray = static_pointer_cast(array); + UNIT_ASSERT(structArray->num_fields() == 1); + UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::DENSE_UNION); + + auto unionArray = static_pointer_cast(structArray->field(0)); + + UNIT_ASSERT(unionArray->num_fields() == 4); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::FIXED_SIZE_BINARY); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::FLOAT); + + auto ysonArray = static_pointer_cast(unionArray->field(0)); + auto jsonDocArray = static_pointer_cast(unionArray->field(1)); + auto uuidArray = static_pointer_cast(unionArray->field(2)); + auto floatArray = static_pointer_cast(unionArray->field(3)); + + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + if (!value.HasValue()) { + // NULL + UNIT_ASSERT(structArray->IsNull(index)); + continue; + } + + UNIT_ASSERT(!structArray->IsNull(index)); + + UNIT_ASSERT(value.GetVariantIndex() == 
static_cast(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 3) { + auto valueArrow = floatArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT(valueArrow == valueInner); + } else { + arrow::util::string_view viewArrow; + if (value.GetVariantIndex() == 0) { + viewArrow = ysonArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 1) { + viewArrow = jsonDocArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 2) { + viewArrow = uuidArray->GetView(fieldIndex); + } + std::string valueArrow(viewArrow.data(), viewArrow.size()); + auto innerItem = value.GetVariantItem(); + auto refInner = innerItem.AsStringRef(); + std::string valueInner(refInner.Data(), refInner.Size()); + UNIT_ASSERT(valueArrow == valueInner); + } + } + } + + Y_UNIT_TEST(DoubleOptionalVariantOverStruct) { + TTestContext context; + + auto variantType = context.GetDoubleOptionalVariantOverStructType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateDoubleOptionalVariantOverStruct(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + + auto firstStructArray = static_pointer_cast(array); + UNIT_ASSERT(firstStructArray->num_fields() == 1); + UNIT_ASSERT(firstStructArray->field(0)->type_id() == arrow::Type::STRUCT); + + auto secondStructArray = static_pointer_cast(firstStructArray->field(0)); + UNIT_ASSERT(secondStructArray->num_fields() == 1); + UNIT_ASSERT(secondStructArray->field(0)->type_id() == arrow::Type::DENSE_UNION); + + auto unionArray = static_pointer_cast(secondStructArray->field(0)); + + UNIT_ASSERT(unionArray->num_fields() == 4); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(1)->type_id() == 
arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::FIXED_SIZE_BINARY); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::FLOAT); + + auto ysonArray = static_pointer_cast(unionArray->field(0)); + auto jsonDocArray = static_pointer_cast(unionArray->field(1)); + auto uuidArray = static_pointer_cast(unionArray->field(2)); + auto floatArray = static_pointer_cast(unionArray->field(3)); + + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + if (!value.HasValue()) { + if (value) { + // Optional(NULL) + UNIT_ASSERT(secondStructArray->IsNull(index)); + } else { + // NULL + UNIT_ASSERT(firstStructArray->IsNull(index)); + } + continue; + } + + UNIT_ASSERT(!firstStructArray->IsNull(index) && !secondStructArray->IsNull(index)); + + UNIT_ASSERT(value.GetVariantIndex() == static_cast(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 3) { + auto valueArrow = floatArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else { + arrow::util::string_view viewArrow; + if (value.GetVariantIndex() == 0) { + viewArrow = ysonArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 1) { + viewArrow = jsonDocArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 2) { + viewArrow = uuidArray->GetView(fieldIndex); + } + std::string valueArrow(viewArrow.data(), viewArrow.size()); + auto innerItem = value.GetVariantItem(); + auto refInner = innerItem.AsStringRef(); + std::string valueInner(refInner.Data(), refInner.Size()); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } + } + } + Y_UNIT_TEST(VariantOverTupleWithOptionals) { TTestContext context; auto variantType = context.GetVariantOverTupleWithOptionalsType(); UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); - auto values = context.CreateVariantOverStruct(100); + auto values = 
context.CreateVariantOverTupleWithOptionals(100); auto array = NArrow::MakeArray(values, variantType); UNIT_ASSERT(array->ValidateFull().ok()); UNIT_ASSERT(static_cast(array->length()) == values.size()); @@ -773,12 +1135,12 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { auto unionArray = static_pointer_cast(array); UNIT_ASSERT(unionArray->num_fields() == 5); - UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BOOL); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::UINT8); UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::INT16); UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::UINT16); UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::INT32); UNIT_ASSERT(unionArray->field(4)->type_id() == arrow::Type::UINT32); - auto boolArray = static_pointer_cast(unionArray->field(0)); + auto boolArray = static_pointer_cast(unionArray->field(0)); auto i16Array = static_pointer_cast(unionArray->field(1)); auto ui16Array = static_pointer_cast(unionArray->field(2)); auto i32Array = static_pointer_cast(unionArray->field(3)); @@ -790,26 +1152,175 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { if (value.GetVariantIndex() == 0) { bool valueArrow = boolArray->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } else if (value.GetVariantIndex() == 1) { auto valueArrow = i16Array->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } else if (value.GetVariantIndex() == 2) { auto valueArrow = ui16Array->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } else if (value.GetVariantIndex() == 3) { auto valueArrow = i32Array->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get(); 
- UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 4) { + if (!value.GetVariantItem().HasValue()) { + UNIT_ASSERT(ui32Array->IsNull(fieldIndex)); + } else { + auto valueArrow = ui32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } + } + } + } + + Y_UNIT_TEST(OptionalVariantOverTupleWithOptionals) { + // DenseUnionArray does not support NULL values, so we wrap it in a StructArray + + TTestContext context; + + auto variantType = context.GetOptionalVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateOptionalVariantOverTupleWithOptionals(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + + auto structArray = static_pointer_cast(array); + UNIT_ASSERT(structArray->num_fields() == 1); + UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::DENSE_UNION); + + auto unionArray = static_pointer_cast(structArray->field(0)); + UNIT_ASSERT(unionArray->num_fields() == 5); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::UINT8); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::INT16); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::UINT16); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::INT32); + UNIT_ASSERT(unionArray->field(4)->type_id() == arrow::Type::UINT32); + auto boolArray = static_pointer_cast(unionArray->field(0)); + auto i16Array = static_pointer_cast(unionArray->field(1)); + auto ui16Array = static_pointer_cast(unionArray->field(2)); + auto i32Array = static_pointer_cast(unionArray->field(3)); + auto ui32Array = static_pointer_cast(unionArray->field(4)); + for (ui64 index = 0; index < values.size(); 
++index) { + auto value = values[index]; + if (!value) { + // NULL + UNIT_ASSERT(structArray->IsNull(index)); + continue; + } + + UNIT_ASSERT(!structArray->IsNull(index)); + + UNIT_ASSERT(value.GetVariantIndex() == static_cast(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 0) { + bool valueArrow = boolArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 1) { + auto valueArrow = i16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 2) { + auto valueArrow = ui16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 3) { + auto valueArrow = i32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } else if (value.GetVariantIndex() == 4) { if (!value.GetVariantItem().HasValue()) { UNIT_ASSERT(ui32Array->IsNull(fieldIndex)); } else { auto valueArrow = ui32Array->Value(fieldIndex); auto valueInner = value.GetVariantItem().Get(); - UNIT_ASSERT(valueArrow == valueInner); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } + } + } + } + + Y_UNIT_TEST(DoubleOptionalVariantOverTupleWithOptionals) { + // DenseUnionArray does not support NULL values, so we wrap it in a StructArray + + TTestContext context; + + auto variantType = context.GetDoubleOptionalVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateDoubleOptionalVariantOverTupleWithOptionals(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast(array->length()) == values.size()); + 
UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + + auto firstStructArray = static_pointer_cast(array); + UNIT_ASSERT(firstStructArray->num_fields() == 1); + UNIT_ASSERT(firstStructArray->field(0)->type_id() == arrow::Type::STRUCT); + + auto secondStructArray = static_pointer_cast(firstStructArray->field(0)); + UNIT_ASSERT(secondStructArray->num_fields() == 1); + UNIT_ASSERT(secondStructArray->field(0)->type_id() == arrow::Type::DENSE_UNION); + + auto unionArray = static_pointer_cast(secondStructArray->field(0)); + UNIT_ASSERT(unionArray->num_fields() == 5); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::UINT8); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::INT16); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::UINT16); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::INT32); + UNIT_ASSERT(unionArray->field(4)->type_id() == arrow::Type::UINT32); + auto boolArray = static_pointer_cast(unionArray->field(0)); + auto i16Array = static_pointer_cast(unionArray->field(1)); + auto ui16Array = static_pointer_cast(unionArray->field(2)); + auto i32Array = static_pointer_cast(unionArray->field(3)); + auto ui32Array = static_pointer_cast(unionArray->field(4)); + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + if (!value.HasValue()) { + if (value && !value.GetOptionalValue()) { + // Optional(NULL) + UNIT_ASSERT(secondStructArray->IsNull(index)); + } else if (!value) { + // NULL + UNIT_ASSERT(firstStructArray->IsNull(index)); + } + continue; + } + + UNIT_ASSERT(!firstStructArray->IsNull(index) && !secondStructArray->IsNull(index)); + + UNIT_ASSERT(value.GetVariantIndex() == static_cast(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 0) { + bool valueArrow = boolArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if 
(value.GetVariantIndex() == 1) { + auto valueArrow = i16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 2) { + auto valueArrow = ui16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 3) { + auto valueArrow = i32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); + } else if (value.GetVariantIndex() == 4) { + if (!value.GetVariantItem().HasValue()) { + UNIT_ASSERT(ui32Array->IsNull(fieldIndex)); + } else { + auto valueArrow = ui32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get(); + UNIT_ASSERT_VALUES_EQUAL(valueArrow, valueInner); } } } @@ -817,6 +1328,51 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { } Y_UNIT_TEST_SUITE(DqUnboxedValueDoNotFitToArrow) { + Y_UNIT_TEST(DictUtf8ToInterval) { + TTestContext context; + + auto dictType = context.GetDictUtf8ToIntervalType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(dictType)); + + auto values = context.CreateDictUtf8ToInterval(100); + auto array = NArrow::MakeArray(values, dictType); + UNIT_ASSERT(array->ValidateFull().ok()); + + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + auto wrapArray = static_pointer_cast(array); + UNIT_ASSERT_VALUES_EQUAL(wrapArray->num_fields(), 2); + UNIT_ASSERT_VALUES_EQUAL(static_cast(wrapArray->length()), values.size()); + + UNIT_ASSERT(wrapArray->field(0)->type_id() == arrow::Type::MAP); + auto mapArray = static_pointer_cast(wrapArray->field(0)); + UNIT_ASSERT_VALUES_EQUAL(static_cast(mapArray->length()), values.size()); + + UNIT_ASSERT(wrapArray->field(1)->type_id() == arrow::Type::UINT64); + auto customArray = static_pointer_cast(wrapArray->field(1)); + UNIT_ASSERT_VALUES_EQUAL(static_cast(customArray->length()), values.size()); + + 
UNIT_ASSERT_VALUES_EQUAL(mapArray->num_fields(), 1); + + UNIT_ASSERT(mapArray->keys()->type_id() == arrow::Type::STRING); + auto utf8Array = static_pointer_cast(mapArray->keys()); + + UNIT_ASSERT(mapArray->items()->type_id() == arrow::Type::INT64); + auto intervalArray = static_pointer_cast(mapArray->items()); + + ui64 index = 0; + for (const auto& value: values) { + UNIT_ASSERT_VALUES_EQUAL(value.GetDictLength(), static_cast(mapArray->value_length(index))); + for (auto subindex = mapArray->value_offset(index); subindex < mapArray->value_offset(index + 1); ++subindex) { + auto keyArrow = utf8Array->GetView(subindex); + NUdf::TUnboxedValue key = MakeString(NUdf::TStringRef(keyArrow.data(), keyArrow.size())); + UNIT_ASSERT(value.Contains(key)); + NUdf::TUnboxedValue payloadValue = value.Lookup(key); + UNIT_ASSERT_VALUES_EQUAL(intervalArray->Value(subindex), payloadValue.Get()); + } + ++index; + } + } + Y_UNIT_TEST(DictOptionalToTuple) { TTestContext context; @@ -827,8 +1383,20 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueDoNotFitToArrow) { auto array = NArrow::MakeArray(values, dictType); UNIT_ASSERT(array->ValidateFull().ok()); UNIT_ASSERT_EQUAL(static_cast(array->length()), values.size()); - UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::LIST); - auto listArray = static_pointer_cast(array); + UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::STRUCT); + + auto wrapArray = static_pointer_cast(array); + UNIT_ASSERT_EQUAL(wrapArray->num_fields(), 2); + UNIT_ASSERT_EQUAL(wrapArray->field(0)->type_id(), arrow::Type::LIST); + + UNIT_ASSERT_EQUAL(wrapArray->field(1)->type_id(), arrow::Type::UINT64); + auto listArray = static_pointer_cast(wrapArray->field(0)); + UNIT_ASSERT_EQUAL(static_cast(listArray->length()), values.size()); + + UNIT_ASSERT_EQUAL(wrapArray->field(1)->type_id(), arrow::Type::UINT64); + auto customArray = static_pointer_cast(wrapArray->field(1)); + UNIT_ASSERT_EQUAL(static_cast(customArray->length()), values.size()); + 
UNIT_ASSERT_EQUAL(listArray->value_type()->id(), arrow::Type::STRUCT); auto structArray = static_pointer_cast(listArray->values()); @@ -870,27 +1438,45 @@ Y_UNIT_TEST_SUITE(DqUnboxedValueDoNotFitToArrow) { auto array = NArrow::MakeArray(values, doubleOptionalType); UNIT_ASSERT(array->ValidateFull().ok()); UNIT_ASSERT_EQUAL(static_cast(array->length()), values.size()); - UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::STRUCT); - auto structArray = static_pointer_cast(array); - UNIT_ASSERT_EQUAL(structArray->num_fields(), 2); - UNIT_ASSERT_EQUAL(structArray->field(0)->type_id(), arrow::Type::UINT64); - UNIT_ASSERT_EQUAL(structArray->field(1)->type_id(), arrow::Type::INT32); - auto depthArray = static_pointer_cast(structArray->field(0)); - auto i32Array = static_pointer_cast(structArray->field(1)); auto index = 0; for (auto value: values) { - auto depth = depthArray->Value(index); - while (depth > 0) { - UNIT_ASSERT(value.HasValue()); + std::shared_ptr currentArray = array; + int depth = 0; + + while (currentArray->type()->id() == arrow::Type::STRUCT) { + auto structArray = static_pointer_cast(currentArray); + UNIT_ASSERT_EQUAL(structArray->num_fields(), 1); + + if (structArray->IsNull(index)) { + break; + } + + ++depth; + + auto childArray = structArray->field(0); + if (childArray->type()->id() == arrow::Type::DENSE_UNION) { + break; + } + + currentArray = childArray; + } + + while (depth--) { + UNIT_ASSERT(value); value = value.GetOptionalValue(); - --depth; } + if (value.HasValue()) { - UNIT_ASSERT_EQUAL(value.Get(), i32Array->Value(index)); + if (currentArray->type()->id() == arrow::Type::INT32) { + UNIT_ASSERT_EQUAL(value.Get(), static_pointer_cast(currentArray)->Value(index)); + } else { + UNIT_ASSERT(!currentArray->IsNull(index)); + } } else { - UNIT_ASSERT(i32Array->IsNull(index)); + UNIT_ASSERT(currentArray->IsNull(index)); } + ++index; } } @@ -954,7 +1540,7 @@ Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ TTestContext context; auto dictType = 
context.GetDictUtf8ToIntervalType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(dictType)); + UNIT_ASSERT(!NArrow::IsArrowCompatible(dictType)); auto values = context.CreateDictUtf8ToInterval(100); auto array = NArrow::MakeArray(values, dictType); @@ -980,6 +1566,21 @@ Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ } } + Y_UNIT_TEST(OptionalListOfOptional) { + TTestContext context; + + auto listType = context.GetOptionalListOfOptional(); + Y_ABORT_UNLESS(NArrow::IsArrowCompatible(listType)); + + auto values = context.CreateOptionalListOfOptional(100); + auto array = NArrow::MakeArray(values, listType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, listType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], listType); + } + } + Y_UNIT_TEST(VariantOverStruct) { TTestContext context; @@ -995,13 +1596,43 @@ Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ } } + Y_UNIT_TEST(OptionalVariantOverStruct) { + TTestContext context; + + auto optionalVariantType = context.GetOptionalVariantOverStructType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(optionalVariantType)); + + auto values = context.CreateOptionalVariantOverStruct(100); + auto array = NArrow::MakeArray(values, optionalVariantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, optionalVariantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], optionalVariantType); + } + } + + Y_UNIT_TEST(DoubleOptionalVariantOverStruct) { + TTestContext context; + + auto doubleOptionalVariantType = context.GetDoubleOptionalVariantOverStructType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(doubleOptionalVariantType)); + + auto values = 
context.CreateDoubleOptionalVariantOverStruct(100); + auto array = NArrow::MakeArray(values, doubleOptionalVariantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, doubleOptionalVariantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], doubleOptionalVariantType); + } + } + Y_UNIT_TEST(VariantOverTupleWithOptionals) { TTestContext context; auto variantType = context.GetVariantOverTupleWithOptionalsType(); UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); - auto values = context.CreateVariantOverStruct(100); + auto values = context.CreateVariantOverTupleWithOptionals(100); auto array = NArrow::MakeArray(values, variantType); auto restoredValues = NArrow::ExtractUnboxedValues(array, variantType, context.HolderFactory); UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); @@ -1010,6 +1641,36 @@ Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ } } + Y_UNIT_TEST(OptionalVariantOverTupleWithOptionals) { + TTestContext context; + + auto optionalVariantType = context.GetOptionalVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(optionalVariantType)); + + auto values = context.CreateOptionalVariantOverTupleWithOptionals(100); + auto array = NArrow::MakeArray(values, optionalVariantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, optionalVariantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], optionalVariantType); + } + } + + Y_UNIT_TEST(DoubleOptionalVariantOverTupleWithOptionals) { + TTestContext context; + + auto doubleOptionalVariantType = context.GetDoubleOptionalVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(doubleOptionalVariantType)); + + auto 
values = context.CreateDoubleOptionalVariantOverTupleWithOptionals(100); + auto array = NArrow::MakeArray(values, doubleOptionalVariantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, doubleOptionalVariantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], doubleOptionalVariantType); + } + } + Y_UNIT_TEST(DictOptionalToTuple) { TTestContext context; -- cgit v1.3