diff options
author | fedorov-misha <fedorov-misha@yandex-team.ru> | 2022-02-10 16:52:13 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:52:13 +0300 |
commit | 9a9addaa33d15f500b54cbe8d78c17d6abd69257 (patch) | |
tree | ab7fbbf3253d4c0e2793218f09378908beb025fb | |
parent | e623fb462572164ba2188eb5a00a14f7eb9bdc38 (diff) | |
download | ydb-9a9addaa33d15f500b54cbe8d78c17d6abd69257.tar.gz |
Restoring authorship annotation for <fedorov-misha@yandex-team.ru>. Commit 2 of 2.
42 files changed, 4018 insertions, 4018 deletions
diff --git a/ydb/core/formats/arrow_batch_builder.cpp b/ydb/core/formats/arrow_batch_builder.cpp index 09414dc10da..8ecf97b9c78 100644 --- a/ydb/core/formats/arrow_batch_builder.cpp +++ b/ydb/core/formats/arrow_batch_builder.cpp @@ -40,7 +40,7 @@ arrow::Status AppendCell(arrow::Decimal128Builder& builder, const TCell& cell) { return builder.AppendNull(); } - /// @warning There's no conversion for special YQL Decimal valies here, + /// @warning There's no conversion for special YQL Decimal valies here, /// so we could convert them to Arrow and back but cannot calculate anything on them. /// We need separate Arrow.Decimal, YQL.Decimal, CH.Decimal and YDB.Decimal in future. return builder.Append(cell.Data()); @@ -52,18 +52,18 @@ arrow::Status AppendCell(arrow::RecordBatchBuilder& builder, const TCell& cell, return AppendCell(*builder.GetFieldAs<TBuilderType>(colNum), cell); } -arrow::Status AppendCell(arrow::RecordBatchBuilder& builder, const TCell& cell, ui32 colNum, NScheme::TTypeId type) { - arrow::Status result; - auto callback = [&]<typename TType>(TTypeWrapper<TType> typeHolder) { - Y_UNUSED(typeHolder); - result = AppendCell<TType>(builder, cell, colNum); - return true; - }; - auto success = SwitchYqlTypeToArrowType(type, std::move(callback)); - if(!success) { - return arrow::Status::TypeError("Unsupported type"); +arrow::Status AppendCell(arrow::RecordBatchBuilder& builder, const TCell& cell, ui32 colNum, NScheme::TTypeId type) { + arrow::Status result; + auto callback = [&]<typename TType>(TTypeWrapper<TType> typeHolder) { + Y_UNUSED(typeHolder); + result = AppendCell<TType>(builder, cell, colNum); + return true; + }; + auto success = SwitchYqlTypeToArrowType(type, std::move(callback)); + if(!success) { + return arrow::Status::TypeError("Unsupported type"); } - return result; + return result; } } @@ -166,18 +166,18 @@ TString TArrowBatchBuilder::Finish() { return str; } -std::shared_ptr<arrow::RecordBatch> CreateNoColumnsBatch(ui64 rowsCount) { - auto field = std::make_shared<arrow::Field>("", std::make_shared<arrow::NullType>()); - std::shared_ptr<arrow::Schema> schema = std::make_shared<arrow::Schema>(std::vector<std::shared_ptr<arrow::Field>>({field})); - std::unique_ptr<arrow::RecordBatchBuilder> batchBuilder; - auto status = arrow::RecordBatchBuilder::Make(schema, arrow::default_memory_pool(), &batchBuilder); - Y_VERIFY_DEBUG(status.ok(), "Failed to create BatchBuilder"); - status = batchBuilder->GetFieldAs<arrow::NullBuilder>(0)->AppendNulls(rowsCount); - Y_VERIFY_DEBUG(status.ok(), "Failed to Append nulls"); - std::shared_ptr<arrow::RecordBatch> batch; - status = batchBuilder->Flush(&batch); - Y_VERIFY_DEBUG(status.ok(), "Failed to Flush Batch"); - return batch; -} - +std::shared_ptr<arrow::RecordBatch> CreateNoColumnsBatch(ui64 rowsCount) { + auto field = std::make_shared<arrow::Field>("", std::make_shared<arrow::NullType>()); + std::shared_ptr<arrow::Schema> schema = std::make_shared<arrow::Schema>(std::vector<std::shared_ptr<arrow::Field>>({field})); + std::unique_ptr<arrow::RecordBatchBuilder> batchBuilder; + auto status = arrow::RecordBatchBuilder::Make(schema, arrow::default_memory_pool(), &batchBuilder); + Y_VERIFY_DEBUG(status.ok(), "Failed to create BatchBuilder"); + status = batchBuilder->GetFieldAs<arrow::NullBuilder>(0)->AppendNulls(rowsCount); + Y_VERIFY_DEBUG(status.ok(), "Failed to Append nulls"); + std::shared_ptr<arrow::RecordBatch> batch; + status = batchBuilder->Flush(&batch); + Y_VERIFY_DEBUG(status.ok(), "Failed to Flush Batch"); + return batch; +} + } diff --git a/ydb/core/formats/arrow_batch_builder.h b/ydb/core/formats/arrow_batch_builder.h index 0a7bf495e9c..d52a94ed20c 100644 --- a/ydb/core/formats/arrow_batch_builder.h +++ b/ydb/core/formats/arrow_batch_builder.h @@ -15,7 +15,7 @@ public: TArrowBatchBuilder(arrow::Compression::type codec = arrow::Compression::UNCOMPRESSED); ~TArrowBatchBuilder() = default; - bool Start(const TVector<std::pair<TString, NScheme::TTypeId>>& columns, + bool Start(const TVector<std::pair<TString, NScheme::TTypeId>>& columns, ui64 maxRowsInBlock, ui64 maxBytesInBlock, TString& err) override { Y_UNUSED(maxRowsInBlock); Y_UNUSED(maxBytesInBlock); @@ -66,8 +66,8 @@ private: } }; -// Creates a batch with single column of type NullType and with num_rows equal rowsCount. All values are null. We need -// this function, because batch can not have zero columns. And NullType conusumes the least place in memory. -std::shared_ptr<arrow::RecordBatch> CreateNoColumnsBatch(ui64 rowsCount); - +// Creates a batch with single column of type NullType and with num_rows equal rowsCount. All values are null. We need +// this function, because batch can not have zero columns. And NullType conusumes the least place in memory. +std::shared_ptr<arrow::RecordBatch> CreateNoColumnsBatch(ui64 rowsCount); + } diff --git a/ydb/core/formats/arrow_helpers.cpp b/ydb/core/formats/arrow_helpers.cpp index 7bf43e16d86..3e1e1b0444a 100644 --- a/ydb/core/formats/arrow_helpers.cpp +++ b/ydb/core/formats/arrow_helpers.cpp @@ -181,39 +181,39 @@ std::shared_ptr<arrow::UInt64Array> SortPermutation(const std::shared_ptr<arrow: #endif } -template <typename TType> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl() { - return std::make_shared<TType>(); -} - -template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::Decimal128Type>() { - return arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE); -} - -template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::TimestampType>() { - return arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO); -} - -template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::DurationType>() { - return arrow::duration(arrow::TimeUnit::TimeUnit::MICRO); -} - -std::shared_ptr<arrow::DataType> GetArrowType(NScheme::TTypeId typeId) { - std::shared_ptr<arrow::DataType> result; - bool success = SwitchYqlTypeToArrowType(typeId, [&]<typename TType>(TTypeWrapper<TType> typeHolder) { - Y_UNUSED(typeHolder); - result = CreateEmptyArrowImpl<TType>(); - return true; - }); - if (success) { - return result; - } - return std::make_shared<arrow::NullType>(); -} - +template <typename TType> +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl() { + return std::make_shared<TType>(); +} + +template <> +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::Decimal128Type>() { + return arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE); +} + +template <> +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::TimestampType>() { + return arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO); +} + +template <> +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::DurationType>() { + return arrow::duration(arrow::TimeUnit::TimeUnit::MICRO); +} + +std::shared_ptr<arrow::DataType> GetArrowType(NScheme::TTypeId typeId) { + std::shared_ptr<arrow::DataType> result; + bool success = SwitchYqlTypeToArrowType(typeId, [&]<typename TType>(TTypeWrapper<TType> typeHolder) { + Y_UNUSED(typeHolder); + result = CreateEmptyArrowImpl<TType>(); + return true; + }); + if (success) { + return result; + } + return std::make_shared<arrow::NullType>(); +} + std::vector<std::shared_ptr<arrow::Field>> MakeArrowFields(const TVector<std::pair<TString, NScheme::TTypeId>>& columns) { std::vector<std::shared_ptr<arrow::Field>> fields; fields.reserve(columns.size()); @@ -224,7 +224,7 @@ std::vector<std::shared_ptr<arrow::Field>> MakeArrowFields(const TVector<std::pa return fields; } -std::shared_ptr<arrow::Schema> MakeArrowSchema(const TVector<std::pair<TString, NScheme::TTypeId>>& ydbColumns) { +std::shared_ptr<arrow::Schema> MakeArrowSchema(const TVector<std::pair<TString, NScheme::TTypeId>>& ydbColumns) { return std::make_shared<arrow::Schema>(MakeArrowFields(ydbColumns)); } @@ -604,70 +604,70 @@ TVector<TString> ColumnNames(const std::shared_ptr<arrow::Schema>& schema) { return out; } -ui64 GetBatchDataSize(const std::shared_ptr<arrow::RecordBatch>& batch) { +ui64 GetBatchDataSize(const std::shared_ptr<arrow::RecordBatch>& batch) { if (!batch) { - return 0; - } - ui64 bytes = 0; + return 0; + } + ui64 bytes = 0; for (auto& column : batch->columns()) { // TODO: use column_data() instead of columns() bytes += GetArrayDataSize(column); - } - return bytes; -} - -template <typename TType> -ui64 GetArrayDataSizeImpl(const std::shared_ptr<arrow::Array>& column) { - return sizeof(typename TType::c_type) * column->length(); -} - -template <> -ui64 GetArrayDataSizeImpl<arrow::NullType>(const std::shared_ptr<arrow::Array>& column) { - return column->length() * 8; // Special value for empty lines -} - -template <> -ui64 GetArrayDataSizeImpl<arrow::StringType>(const std::shared_ptr<arrow::Array>& column) { - auto typedColumn = std::static_pointer_cast<arrow::StringArray>(column); - return typedColumn->total_values_length(); -} - -template <> -ui64 GetArrayDataSizeImpl<arrow::LargeStringType>(const std::shared_ptr<arrow::Array>& column) { - auto typedColumn = std::static_pointer_cast<arrow::StringArray>(column); - return typedColumn->total_values_length(); -} - -template <> -ui64 GetArrayDataSizeImpl<arrow::BinaryType>(const std::shared_ptr<arrow::Array>& column) { - auto typedColumn = std::static_pointer_cast<arrow::BinaryArray>(column); - return typedColumn->total_values_length(); -} - -template <> -ui64 GetArrayDataSizeImpl<arrow::LargeBinaryType>(const std::shared_ptr<arrow::Array>& column) { - auto typedColumn = std::static_pointer_cast<arrow::BinaryArray>(column); - return typedColumn->total_values_length(); -} - -template <> -ui64 GetArrayDataSizeImpl<arrow::FixedSizeBinaryType>(const std::shared_ptr<arrow::Array>& column) { - auto typedColumn = std::static_pointer_cast<arrow::FixedSizeBinaryArray>(column); - return typedColumn->byte_width() * typedColumn->length(); -} - -template <> -ui64 GetArrayDataSizeImpl<arrow::Decimal128Type>(const std::shared_ptr<arrow::Array>& column) { - return sizeof(ui64) * 2 * column->length(); -} - + } + return bytes; +} + +template <typename TType> +ui64 GetArrayDataSizeImpl(const std::shared_ptr<arrow::Array>& column) { + return sizeof(typename TType::c_type) * column->length(); +} + +template <> +ui64 GetArrayDataSizeImpl<arrow::NullType>(const std::shared_ptr<arrow::Array>& column) { + return column->length() * 8; // Special value for empty lines +} + +template <> +ui64 GetArrayDataSizeImpl<arrow::StringType>(const std::shared_ptr<arrow::Array>& column) { + auto typedColumn = std::static_pointer_cast<arrow::StringArray>(column); + return typedColumn->total_values_length(); +} + +template <> +ui64 GetArrayDataSizeImpl<arrow::LargeStringType>(const std::shared_ptr<arrow::Array>& column) { + auto typedColumn = std::static_pointer_cast<arrow::StringArray>(column); + return typedColumn->total_values_length(); +} + +template <> +ui64 GetArrayDataSizeImpl<arrow::BinaryType>(const std::shared_ptr<arrow::Array>& column) { + auto typedColumn = std::static_pointer_cast<arrow::BinaryArray>(column); + return typedColumn->total_values_length(); +} + +template <> +ui64 GetArrayDataSizeImpl<arrow::LargeBinaryType>(const std::shared_ptr<arrow::Array>& column) { + auto typedColumn = std::static_pointer_cast<arrow::BinaryArray>(column); + return typedColumn->total_values_length(); +} + +template <> +ui64 GetArrayDataSizeImpl<arrow::FixedSizeBinaryType>(const std::shared_ptr<arrow::Array>& column) { + auto typedColumn = std::static_pointer_cast<arrow::FixedSizeBinaryArray>(column); + return typedColumn->byte_width() * typedColumn->length(); +} + +template <> +ui64 GetArrayDataSizeImpl<arrow::Decimal128Type>(const std::shared_ptr<arrow::Array>& column) { + return sizeof(ui64) * 2 * column->length(); +} + ui64 GetArrayDataSize(const std::shared_ptr<arrow::Array>& column) { auto type = column->type(); - ui64 bytes = 0; - bool success = SwitchTypeWithNull(type->id(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { - Y_UNUSED(typeHolder); - bytes = GetArrayDataSizeImpl<TType>(column); - return true; - }); + ui64 bytes = 0; + bool success = SwitchTypeWithNull(type->id(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { + Y_UNUSED(typeHolder); + bytes = GetArrayDataSizeImpl<TType>(column); + return true; + }); // Add null bit mask overhead if any. if (HasNulls(column)) { @@ -675,9 +675,9 @@ ui64 GetArrayDataSize(const std::shared_ptr<arrow::Array>& column) { } Y_VERIFY_DEBUG(success, "Unsupported arrow type %s", type->ToString().data()); - return bytes; -} - + return bytes; +} + std::shared_ptr<arrow::UInt64Array> MakeUI64Array(ui64 value, i64 size) { auto res = arrow::MakeArrayFromScalar(arrow::UInt64Scalar(value), size); Y_VERIFY(res.ok()); @@ -924,7 +924,7 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err return false; } allColumns.emplace_back(std::move(column)); - } + } std::vector<TSmallVec<TCell>> cells; i64 row = 0; @@ -982,12 +982,12 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err } bool success = SwitchYqlTypeToArrowType(colType, [&]<typename TType>(TTypeWrapper<TType> typeHolder) { - Y_UNUSED(typeHolder); + Y_UNUSED(typeHolder); cells[0][col] = MakeCell<typename arrow::TypeTraits<TType>::ArrayType>(column, row); - return true; - }); + return true; + }); - if (!success) { + if (!success) { errorMessage = TStringBuilder() << "No arrow conversion for type Yql::" << NScheme::TypeName(colType) << " at column '" << colName << "'"; return false; diff --git a/ydb/core/formats/arrow_helpers.h b/ydb/core/formats/arrow_helpers.h index d96ad566691..cd3ec9f865f 100644 --- a/ydb/core/formats/arrow_helpers.h +++ b/ydb/core/formats/arrow_helpers.h @@ -1,5 +1,5 @@ #pragma once -#include "switch_type.h" +#include "switch_type.h" #include <ydb/core/formats/factory.h> #include <ydb/core/scheme/scheme_tablecell.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> @@ -23,7 +23,7 @@ public: } }; -std::shared_ptr<arrow::DataType> GetArrowType(NScheme::TTypeId typeId); +std::shared_ptr<arrow::DataType> GetArrowType(NScheme::TTypeId typeId); template <typename T> inline bool ArrayEqualValue(const std::shared_ptr<arrow::Array>& x, const std::shared_ptr<arrow::Array>& y) { @@ -90,9 +90,9 @@ std::shared_ptr<arrow::UInt64Array> MakePermutation(int size, bool reverse = fal std::shared_ptr<arrow::BooleanArray> MakeFilter(const std::vector<bool>& bits); std::vector<bool> CombineFilters(std::vector<bool>&& f1, std::vector<bool>&& f2); TVector<TString> ColumnNames(const std::shared_ptr<arrow::Schema>& schema); -// Return size in bytes including size of bitmap mask -ui64 GetBatchDataSize(const std::shared_ptr<arrow::RecordBatch>& batch); -// Return size in bytes *not* including size of bitmap mask +// Return size in bytes including size of bitmap mask +ui64 GetBatchDataSize(const std::shared_ptr<arrow::RecordBatch>& batch); +// Return size in bytes *not* including size of bitmap mask ui64 GetArrayDataSize(const std::shared_ptr<arrow::Array>& column); enum class ECompareType { @@ -160,26 +160,26 @@ private: return TCell(data.data(), data.size()); } - template <typename TArrayType> + template <typename TArrayType> TCell MakeCell(const std::shared_ptr<arrow::Array>& column, i64 row) { return MakeCellFromValue<TArrayType>(column, row); - } - - template <> + } + + template <> TCell MakeCell<arrow::BinaryArray>(const std::shared_ptr<arrow::Array>& column, i64 row) { return MakeCellFromView<arrow::BinaryArray>(column, row); - } - - template <> + } + + template <> TCell MakeCell<arrow::StringArray>(const std::shared_ptr<arrow::Array>& column, i64 row) { return MakeCellFromView<arrow::StringArray>(column, row); - } - - template <> + } + + template <> TCell MakeCell<arrow::Decimal128Array>(const std::shared_ptr<arrow::Array>& column, i64 row) { return MakeCellFromView<arrow::Decimal128Array>(column, row); - } - + } + public: TArrowToYdbConverter(const TVector<std::pair<TString, NScheme::TTypeId>>& ydbSchema, IRowWriter& rowWriter) : YdbSchema(ydbSchema) diff --git a/ydb/core/formats/switch_type.h b/ydb/core/formats/switch_type.h index 7873cecc029..5cc7924389d 100644 --- a/ydb/core/formats/switch_type.h +++ b/ydb/core/formats/switch_type.h @@ -10,15 +10,15 @@ struct TTypeWrapper using T = TType; }; -template <typename TFunc, bool EnableNull = false> +template <typename TFunc, bool EnableNull = false> bool SwitchType(arrow::Type::type typeId, TFunc&& f) { switch (typeId) { - case arrow::Type::NA: { - if constexpr (EnableNull) { - return f(TTypeWrapper<arrow::NullType>()); - } + case arrow::Type::NA: { + if constexpr (EnableNull) { + return f(TTypeWrapper<arrow::NullType>()); + } break; - } + } case arrow::Type::BOOL: return f(TTypeWrapper<arrow::BooleanType>()); case arrow::Type::UINT8: @@ -48,7 +48,7 @@ bool SwitchType(arrow::Type::type typeId, TFunc&& f) { case arrow::Type::BINARY: return f(TTypeWrapper<arrow::BinaryType>()); case arrow::Type::FIXED_SIZE_BINARY: - return f(TTypeWrapper<arrow::FixedSizeBinaryType>()); + return f(TTypeWrapper<arrow::FixedSizeBinaryType>()); case arrow::Type::DATE32: return f(TTypeWrapper<arrow::Date32Type>()); case arrow::Type::DATE64: @@ -61,8 +61,8 @@ bool SwitchType(arrow::Type::type typeId, TFunc&& f) { return f(TTypeWrapper<arrow::Time64Type>()); case arrow::Type::INTERVAL_MONTHS: return f(TTypeWrapper<arrow::MonthIntervalType>()); - case arrow::Type::DECIMAL: - return f(TTypeWrapper<arrow::Decimal128Type>()); + case arrow::Type::DECIMAL: + return f(TTypeWrapper<arrow::Decimal128Type>()); case arrow::Type::DURATION: return f(TTypeWrapper<arrow::DurationType>()); case arrow::Type::LARGE_STRING: @@ -88,77 +88,77 @@ bool SwitchType(arrow::Type::type typeId, TFunc&& f) { } template <typename TFunc> -bool SwitchTypeWithNull(arrow::Type::type typeId, TFunc&& f) { - return SwitchType<TFunc, true>(typeId, std::move(f)); -} - -template <typename TFunc> +bool SwitchTypeWithNull(arrow::Type::type typeId, TFunc&& f) { + return SwitchType<TFunc, true>(typeId, std::move(f)); +} + +template <typename TFunc> bool SwitchArrayType(const arrow::Datum& column, TFunc&& f) { auto type = column.type(); Y_VERIFY(type); return SwitchType(type->id(), std::forward<TFunc>(f)); } -/** - * @brief Function to switch yql type correctly and uniformly converting it to arrow type using callback - * - * @tparam TFunc Callback type - * @param typeId Type of data callback work with. - * @param callback Template function of signature (TTypeWrapper) -> bool - * @return Result of execution of callback or false if the type typeId is not supported. - */ -template <typename TFunc> -bool SwitchYqlTypeToArrowType(NScheme::TTypeId typeId, TFunc&& callback) { - switch (typeId) { - case NScheme::NTypeIds::Bool: - return callback(TTypeWrapper<arrow::BooleanType>()); - case NScheme::NTypeIds::Int8: - return callback(TTypeWrapper<arrow::Int8Type>()); - case NScheme::NTypeIds::Uint8: - return callback(TTypeWrapper<arrow::UInt8Type>()); - case NScheme::NTypeIds::Int16: - return callback(TTypeWrapper<arrow::Int16Type>()); - case NScheme::NTypeIds::Date: - case NScheme::NTypeIds::Uint16: - return callback(TTypeWrapper<arrow::UInt16Type>()); - case NScheme::NTypeIds::Int32: - return callback(TTypeWrapper<arrow::Int32Type>()); - case NScheme::NTypeIds::Datetime: - case NScheme::NTypeIds::Uint32: - return callback(TTypeWrapper<arrow::UInt32Type>()); - case NScheme::NTypeIds::Int64: - return callback(TTypeWrapper<arrow::Int64Type>()); - case NScheme::NTypeIds::Uint64: - return callback(TTypeWrapper<arrow::UInt64Type>()); - case NScheme::NTypeIds::Float: - return callback(TTypeWrapper<arrow::FloatType>()); - case NScheme::NTypeIds::Double: - return callback(TTypeWrapper<arrow::DoubleType>()); - case NScheme::NTypeIds::Utf8: - return callback(TTypeWrapper<arrow::StringType>()); - case NScheme::NTypeIds::String: - case NScheme::NTypeIds::String4k: - case NScheme::NTypeIds::String2m: - case NScheme::NTypeIds::Yson: - case NScheme::NTypeIds::Json: - case NScheme::NTypeIds::DyNumber: - case NScheme::NTypeIds::JsonDocument: - return callback(TTypeWrapper<arrow::BinaryType>()); - case NScheme::NTypeIds::Timestamp: - return callback(TTypeWrapper<arrow::TimestampType>()); - case NScheme::NTypeIds::Interval: - return callback(TTypeWrapper<arrow::DurationType>()); - case NScheme::NTypeIds::Decimal: - return callback(TTypeWrapper<arrow::Decimal128Type>()); - - case NScheme::NTypeIds::PairUi64Ui64: - case NScheme::NTypeIds::ActorId: - case NScheme::NTypeIds::StepOrderId: - break; // Deprecated types - } - return false; -} - +/** + * @brief Function to switch yql type correctly and uniformly converting it to arrow type using callback + * + * @tparam TFunc Callback type + * @param typeId Type of data callback work with. + * @param callback Template function of signature (TTypeWrapper) -> bool + * @return Result of execution of callback or false if the type typeId is not supported. + */ +template <typename TFunc> +bool SwitchYqlTypeToArrowType(NScheme::TTypeId typeId, TFunc&& callback) { + switch (typeId) { + case NScheme::NTypeIds::Bool: + return callback(TTypeWrapper<arrow::BooleanType>()); + case NScheme::NTypeIds::Int8: + return callback(TTypeWrapper<arrow::Int8Type>()); + case NScheme::NTypeIds::Uint8: + return callback(TTypeWrapper<arrow::UInt8Type>()); + case NScheme::NTypeIds::Int16: + return callback(TTypeWrapper<arrow::Int16Type>()); + case NScheme::NTypeIds::Date: + case NScheme::NTypeIds::Uint16: + return callback(TTypeWrapper<arrow::UInt16Type>()); + case NScheme::NTypeIds::Int32: + return callback(TTypeWrapper<arrow::Int32Type>()); + case NScheme::NTypeIds::Datetime: + case NScheme::NTypeIds::Uint32: + return callback(TTypeWrapper<arrow::UInt32Type>()); + case NScheme::NTypeIds::Int64: + return callback(TTypeWrapper<arrow::Int64Type>()); + case NScheme::NTypeIds::Uint64: + return callback(TTypeWrapper<arrow::UInt64Type>()); + case NScheme::NTypeIds::Float: + return callback(TTypeWrapper<arrow::FloatType>()); + case NScheme::NTypeIds::Double: + return callback(TTypeWrapper<arrow::DoubleType>()); + case NScheme::NTypeIds::Utf8: + return callback(TTypeWrapper<arrow::StringType>()); + case NScheme::NTypeIds::String: + case NScheme::NTypeIds::String4k: + case NScheme::NTypeIds::String2m: + case NScheme::NTypeIds::Yson: + case NScheme::NTypeIds::Json: + case NScheme::NTypeIds::DyNumber: + case NScheme::NTypeIds::JsonDocument: + return callback(TTypeWrapper<arrow::BinaryType>()); + case NScheme::NTypeIds::Timestamp: + return callback(TTypeWrapper<arrow::TimestampType>()); + case NScheme::NTypeIds::Interval: + return callback(TTypeWrapper<arrow::DurationType>()); + case NScheme::NTypeIds::Decimal: + return callback(TTypeWrapper<arrow::Decimal128Type>()); + + case NScheme::NTypeIds::PairUi64Ui64: + case NScheme::NTypeIds::ActorId: + case NScheme::NTypeIds::StepOrderId: + break; // Deprecated types + } + return false; +} + template <typename T> bool Append(arrow::ArrayBuilder& builder, const typename T::c_type& value) { using TBuilder = typename arrow::TypeTraits<T>::BuilderType; diff --git a/ydb/core/formats/ut_arrow.cpp b/ydb/core/formats/ut_arrow.cpp index 73501111f1d..344bbe7cb8e 100644 --- a/ydb/core/formats/ut_arrow.cpp +++ b/ydb/core/formats/ut_arrow.cpp @@ -54,8 +54,8 @@ struct TDataRow { std::string Utf8; std::string Json; std::string Yson; - ui16 Date; - ui32 Datetime; + ui16 Date; + ui32 Datetime; i64 Timestamp; i64 Interval; ui64 Decimal[2]; @@ -96,12 +96,12 @@ struct TDataRow { arrow::field("ui64", arrow::uint64()), arrow::field("f32", arrow::float32()), arrow::field("f64", arrow::float64()), - arrow::field("string", arrow::binary()), + arrow::field("string", arrow::binary()), arrow::field("utf8", arrow::utf8()), arrow::field("json", arrow::binary()), arrow::field("yson", arrow::binary()), - arrow::field("date", arrow::uint16()), - arrow::field("datetime", arrow::uint32()), + arrow::field("date", arrow::uint16()), + arrow::field("datetime", arrow::uint32()), arrow::field("ts", arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO)), arrow::field("ival", arrow::duration(arrow::TimeUnit::TimeUnit::MICRO)), arrow::field("dec", arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE)), @@ -110,8 +110,8 @@ struct TDataRow { return std::make_shared<arrow::Schema>(fields); } - static TVector<std::pair<TString, TTypeId>> MakeYdbSchema() { - TVector<std::pair<TString, TTypeId>> columns = { + static TVector<std::pair<TString, TTypeId>> MakeYdbSchema() { + TVector<std::pair<TString, TTypeId>> columns = { {"bool", NTypeIds::Bool }, {"i8", NTypeIds::Int8 }, {"i16", NTypeIds::Int16 }, @@ -153,8 +153,8 @@ struct TDataRow { Cells[12] = TCell(Utf8.data(), Utf8.size()); Cells[13] = TCell(Json.data(), Json.size()); Cells[14] = TCell(Yson.data(), Yson.size()); - Cells[15] = TCell::Make<ui16>(Date); - Cells[16] = TCell::Make<ui32>(Datetime); + Cells[15] = TCell::Make<ui16>(Date); + Cells[16] = TCell::Make<ui32>(Datetime); Cells[17] = TCell::Make<i64>(Timestamp); Cells[18] = TCell::Make<i64>(Interval); Cells[19] = TCell((const char *)&Decimal[0], 16); @@ -193,8 +193,8 @@ std::vector<TDataRow> ToVector(const std::shared_ptr<T>& table) { auto arj = std::static_pointer_cast<arrow::BinaryArray>(GetColumn(*table, 13)); auto ary = std::static_pointer_cast<arrow::BinaryArray>(GetColumn(*table, 14)); - auto ard = std::static_pointer_cast<arrow::UInt16Array>(GetColumn(*table, 15)); - auto ardt = std::static_pointer_cast<arrow::UInt32Array>(GetColumn(*table, 16)); + auto ard = std::static_pointer_cast<arrow::UInt16Array>(GetColumn(*table, 15)); + auto ardt = std::static_pointer_cast<arrow::UInt32Array>(GetColumn(*table, 16)); auto arts = std::static_pointer_cast<arrow::TimestampArray>(GetColumn(*table, 17)); auto arival = std::static_pointer_cast<arrow::DurationArray>(GetColumn(*table, 18)); @@ -270,8 +270,8 @@ public: std::shared_ptr<arrow::BinaryArray> arj; std::shared_ptr<arrow::BinaryArray> ary; - std::shared_ptr<arrow::UInt16Array> ard; - std::shared_ptr<arrow::UInt32Array> ardt; + std::shared_ptr<arrow::UInt16Array> ard; + std::shared_ptr<arrow::UInt32Array> ardt; std::shared_ptr<arrow::TimestampArray> arts; std::shared_ptr<arrow::DurationArray> arival; @@ -337,8 +337,8 @@ private: arrow::StringBuilder Butf; arrow::BinaryBuilder Bj; arrow::BinaryBuilder By; - arrow::UInt16Builder Bd; - arrow::UInt32Builder Bdt; + arrow::UInt16Builder Bd; + arrow::UInt32Builder Bdt; arrow::TimestampBuilder Bts; arrow::DurationBuilder Bival; arrow::Decimal128Builder Bdec; @@ -566,7 +566,7 @@ Y_UNIT_TEST_SUITE(ArrowTest) { } } rowWriter; - NArrow::TArrowToYdbConverter toYdbConverter(TDataRow::MakeYdbSchema(), rowWriter); + NArrow::TArrowToYdbConverter toYdbConverter(TDataRow::MakeYdbSchema(), rowWriter); TString errStr; bool ok = toYdbConverter.Process(*batch, errStr); UNIT_ASSERT(ok); diff --git a/ydb/core/grpc_services/rpc_read_columns.cpp b/ydb/core/grpc_services/rpc_read_columns.cpp index dc46e88efbe..4baae606a97 100644 --- a/ydb/core/grpc_services/rpc_read_columns.cpp +++ b/ydb/core/grpc_services/rpc_read_columns.cpp @@ -369,11 +369,11 @@ private: TString lastKey; size_t rowsExtracted = 0; bool skippedBeforeMinKey = false; - - if (ev->Get()->GetDataFormat() == NKikimrTxDataShard::ARROW) { - return ReplyWithError(Ydb::StatusIds::INTERNAL_ERROR, "Arrow format not supported yet", ctx); - } - + + if (ev->Get()->GetDataFormat() == NKikimrTxDataShard::ARROW) { + return ReplyWithError(Ydb::StatusIds::INTERNAL_ERROR, "Arrow format not supported yet", ctx); + } + for (auto&& row : ev->Get()->Rows) { ++rowsExtracted; if (row.size() != keyColumnCount + ValueColumnTypes.size()) { diff --git a/ydb/core/kqp/compute_actor/kqp_pure_compute_actor.cpp b/ydb/core/kqp/compute_actor/kqp_pure_compute_actor.cpp index 5145bbc893d..750207da8e4 100644 --- a/ydb/core/kqp/compute_actor/kqp_pure_compute_actor.cpp +++ b/ydb/core/kqp/compute_actor/kqp_pure_compute_actor.cpp @@ -98,7 +98,7 @@ public: ScanData = &ComputeCtx.GetTableScan(0); columns.reserve(Meta->ColumnsSize()); - for (const auto& column : ScanData->GetColumns()) { + for (const auto& column : ScanData->GetColumns()) { NMiniKQL::TKqpScanComputeContext::TColumn c; c.Tag = column.Tag; c.Type = column.Type; @@ -242,35 +242,35 @@ private: auto& msg = *ev->Get(); - ui64 bytes = 0; - ui64 rowsCount = 0; - { - auto guard = TaskRunner->BindAllocator(); - switch (msg.GetDataFormat()) { - case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: - case NKikimrTxDataShard::EScanDataFormat::CELLVEC: { - if (!msg.Rows.empty()) { - bytes = ScanData->AddRows(msg.Rows, {}, TaskRunner->GetHolderFactory()); - rowsCount = msg.Rows.size(); - } - break; - } - case NKikimrTxDataShard::EScanDataFormat::ARROW: { - if(msg.ArrowBatch != nullptr) { - bytes = ScanData->AddRows(*msg.ArrowBatch, {}, TaskRunner->GetHolderFactory()); - rowsCount = msg.ArrowBatch->num_rows(); - } - break; - } - } - } - - CA_LOG_D("Got sysview scandata, rows: " << rowsCount << ", bytes: " << bytes + ui64 bytes = 0; + ui64 rowsCount = 0; + { + auto guard = TaskRunner->BindAllocator(); + switch (msg.GetDataFormat()) { + case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: + case NKikimrTxDataShard::EScanDataFormat::CELLVEC: { + if (!msg.Rows.empty()) { + bytes = ScanData->AddRows(msg.Rows, {}, TaskRunner->GetHolderFactory()); + rowsCount = msg.Rows.size(); + } + break; + } + case NKikimrTxDataShard::EScanDataFormat::ARROW: { + if(msg.ArrowBatch != nullptr) { + bytes = ScanData->AddRows(*msg.ArrowBatch, {}, TaskRunner->GetHolderFactory()); + rowsCount = msg.ArrowBatch->num_rows(); + } + break; + } + } + } + + CA_LOG_D("Got sysview scandata, rows: " << rowsCount << ", bytes: " << bytes << ", finished: " << msg.Finished << ", from: " << SysViewActorId); if (msg.Finished) { CA_LOG_D("Finishing rows buffer"); - ScanData->Finish(); + ScanData->Finish(); } if (Y_UNLIKELY(ScanData->ProfileStats)) { @@ -283,8 +283,8 @@ private: } } - ui64 freeSpace = GetMemoryLimits().ScanBufferSize > ScanData->GetStoredBytes() - ? GetMemoryLimits().ScanBufferSize - ScanData->GetStoredBytes() + ui64 freeSpace = GetMemoryLimits().ScanBufferSize > ScanData->GetStoredBytes() + ? GetMemoryLimits().ScanBufferSize - ScanData->GetStoredBytes() : 0; if (freeSpace > 0) { diff --git a/ydb/core/kqp/compute_actor/kqp_scan_compute_actor.cpp b/ydb/core/kqp/compute_actor/kqp_scan_compute_actor.cpp index 0de5cc63fa7..f8e3cd0c772 100644 --- a/ydb/core/kqp/compute_actor/kqp_scan_compute_actor.cpp +++ b/ydb/core/kqp/compute_actor/kqp_scan_compute_actor.cpp @@ -376,29 +376,29 @@ private: LastKey = std::move(msg.LastKey); - ui64 bytes = 0; - ui64 rowsCount = 0; - { - auto guard = TaskRunner->BindAllocator(); - switch (msg.GetDataFormat()) { - case NKikimrTxDataShard::EScanDataFormat::CELLVEC: - case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: { - if (!msg.Rows.empty()) { - bytes = ScanData->AddRows(msg.Rows, state.TabletId, TaskRunner->GetHolderFactory()); - rowsCount = msg.Rows.size(); - } - break; - } - case NKikimrTxDataShard::EScanDataFormat::ARROW: { + ui64 bytes = 0; + ui64 rowsCount = 0; + { + auto guard = TaskRunner->BindAllocator(); + switch (msg.GetDataFormat()) { + case NKikimrTxDataShard::EScanDataFormat::CELLVEC: + case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: { + if (!msg.Rows.empty()) { + bytes = ScanData->AddRows(msg.Rows, state.TabletId, TaskRunner->GetHolderFactory()); + rowsCount = msg.Rows.size(); + } + break; + } + case NKikimrTxDataShard::EScanDataFormat::ARROW: { if (msg.ArrowBatch != nullptr) { - bytes = ScanData->AddRows(*msg.ArrowBatch, state.TabletId, TaskRunner->GetHolderFactory()); - rowsCount = msg.ArrowBatch->num_rows(); - } - break; - } - } - } - + bytes = ScanData->AddRows(*msg.ArrowBatch, state.TabletId, TaskRunner->GetHolderFactory()); + rowsCount = msg.ArrowBatch->num_rows(); + } + break; + } + } + } + CA_LOG_D("Got EvScanData, rows: " << rowsCount << ", bytes: " << bytes << ", finished: " << msg.Finished << ", from: " << ev->Sender << ", shards remain: " << Shards.size() << ", delayed for: " << latency.SecondsFloat() << " seconds by ratelimitter"); @@ -421,7 +421,7 @@ private: StartTableScan(); } else { CA_LOG_D("Finish scans"); - ScanData->Finish(); + ScanData->Finish(); if (ScanData->BasicStats) { ScanData->BasicStats->AffectedShards = AffectedShards.size(); @@ -809,7 +809,7 @@ private: auto ev = MakeHolder<TEvDataShard::TEvKqpScan>(); ev->Record.SetLocalPathId(ScanData->TableId.PathId.LocalPathId); - for (auto& column: ScanData->GetColumns()) { + for (auto& column: ScanData->GetColumns()) { ev->Record.AddColumnTags(column.Tag); ev->Record.AddColumnTypes(column.Type); } @@ -849,8 +849,8 @@ private: ); } - ev->Record.SetDataFormat(Meta.GetDataFormat()); - + ev->Record.SetDataFormat(Meta.GetDataFormat()); + bool subscribed = std::exchange(state.SubscribedOnTablet, true); CA_LOG_D("Send EvKqpScan to shardId: " << state.TabletId << ", tablePath: " << ScanData->TablePath @@ -930,8 +930,8 @@ private: state.Ranges.back().To.GetCells(), state.Ranges.back().ToInclusive); TVector<TKeyDesc::TColumnOp> columns; - columns.reserve(ScanData->GetColumns().size()); - for (const auto& column : ScanData->GetColumns()) { + columns.reserve(ScanData->GetColumns().size()); + for (const auto& column : ScanData->GetColumns()) { TKeyDesc::TColumnOp op; op.Column = column.Tag; op.Operation = TKeyDesc::EColumnOperation::Read; @@ -963,8 +963,8 @@ private: THolder<IDestructable> GetSourcesState() override { if (ScanData) { auto state = MakeHolder<TScanFreeSpace>(); - state->FreeSpace = GetMemoryLimits().ScanBufferSize > ScanData->GetStoredBytes() - ? GetMemoryLimits().ScanBufferSize - ScanData->GetStoredBytes() + state->FreeSpace = GetMemoryLimits().ScanBufferSize > ScanData->GetStoredBytes() + ? GetMemoryLimits().ScanBufferSize - ScanData->GetStoredBytes() : 0ul; return state; } @@ -978,12 +978,12 @@ private: auto& state = Shards.front(); - ui64 freeSpace = GetMemoryLimits().ScanBufferSize > ScanData->GetStoredBytes() - ? GetMemoryLimits().ScanBufferSize - ScanData->GetStoredBytes() + ui64 freeSpace = GetMemoryLimits().ScanBufferSize > ScanData->GetStoredBytes() + ? GetMemoryLimits().ScanBufferSize - ScanData->GetStoredBytes() : 0ul; ui64 prevFreeSpace = static_cast<TScanFreeSpace*>(prev.Get())->FreeSpace; - CA_LOG_T("Scan over tablet " << state.TabletId << " finished: " << ScanData->IsFinished() + CA_LOG_T("Scan over tablet " << state.TabletId << " finished: " << ScanData->IsFinished() << ", prevFreeSpace: " << prevFreeSpace << ", freeSpace: " << freeSpace << ", peer: " << state.ActorId); if (!ScanData->IsFinished() && state.State != EShardState::PostRunning diff --git a/ydb/core/kqp/executer/kqp_partition_helper.cpp b/ydb/core/kqp/executer/kqp_partition_helper.cpp index 53ab2df6891..ac786f6c24d 100644 --- a/ydb/core/kqp/executer/kqp_partition_helper.cpp +++ b/ydb/core/kqp/executer/kqp_partition_helper.cpp @@ -75,7 +75,7 @@ THashMap<ui64, TShardParamValuesAndRanges> PartitionParamByKey(const NDq::TMkqlV shardData.ParamType = itemType; } - NDq::TDqDataSerializer dataSerializer{typeEnv, holderFactory, NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0}; + NDq::TDqDataSerializer dataSerializer{typeEnv, holderFactory, NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0}; for (auto& [shardId, data] : ret) { ret[shardId].ParamValues = dataSerializer.Serialize(shardParamValues[shardId], itemType); } @@ -159,7 +159,7 @@ THashMap<ui64, TShardParamValuesAndRanges> PartitionParamByKeyPrefix(const NDq:: } } - NDq::TDqDataSerializer dataSerializer(typeEnv, holderFactory, NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0); + NDq::TDqDataSerializer dataSerializer(typeEnv, holderFactory, NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0); for (auto& [shardId, data] : ret) { data.ParamValues = dataSerializer.Serialize(shardParamValues[shardId], itemType); } diff --git a/ydb/core/kqp/executer/kqp_scan_executer.cpp b/ydb/core/kqp/executer/kqp_scan_executer.cpp index ee3d24d50f3..07c29f64e21 100644 --- a/ydb/core/kqp/executer/kqp_scan_executer.cpp +++ b/ydb/core/kqp/executer/kqp_scan_executer.cpp @@ -794,26 +794,26 @@ private: protoTaskMeta.AddKeyColumnTypes(keyColumn.Type); } - switch (tableInfo.TableKind) { - case ETableKind::Unknown: - case ETableKind::SysView: { - protoTaskMeta.SetDataFormat(NKikimrTxDataShard::EScanDataFormat::CELLVEC); - break; - } - case ETableKind::Datashard: { - if (AppData()->FeatureFlags.GetEnableArrowFormatAtDatashard()) { - protoTaskMeta.SetDataFormat(NKikimrTxDataShard::EScanDataFormat::ARROW); - } else { - protoTaskMeta.SetDataFormat(NKikimrTxDataShard::EScanDataFormat::CELLVEC); - } - break; - } - case ETableKind::Olap: { - protoTaskMeta.SetDataFormat(NKikimrTxDataShard::EScanDataFormat::ARROW); - break; - } - } - + switch (tableInfo.TableKind) { + case ETableKind::Unknown: + case ETableKind::SysView: { + protoTaskMeta.SetDataFormat(NKikimrTxDataShard::EScanDataFormat::CELLVEC); + break; + } + case ETableKind::Datashard: { + if (AppData()->FeatureFlags.GetEnableArrowFormatAtDatashard()) { + protoTaskMeta.SetDataFormat(NKikimrTxDataShard::EScanDataFormat::ARROW); + } else { + protoTaskMeta.SetDataFormat(NKikimrTxDataShard::EScanDataFormat::CELLVEC); + } + break; + } + case ETableKind::Olap: { + protoTaskMeta.SetDataFormat(NKikimrTxDataShard::EScanDataFormat::ARROW); + break; + } + } + for (bool skipNullKey : stageInfo.Meta.SkipNullKeys) { protoTaskMeta.AddSkipNullKeys(skipNullKey); } diff --git a/ydb/core/kqp/kqp_compute.h b/ydb/core/kqp/kqp_compute.h index bfc244cd7a2..943f5adac39 100644 --- a/ydb/core/kqp/kqp_compute.h +++ b/ydb/core/kqp/kqp_compute.h @@ -1,194 +1,194 @@ -#pragma once - -#include "kqp.h" - +#pragma once + +#include "kqp.h" + #include <ydb/core/formats/arrow_helpers.h> #include <ydb/core/protos/tx_datashard.pb.h> - + #include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> - -namespace NKikimr::NKqp { - -struct TEvKqpCompute { - struct TEvRemoteScanData : public TEventPB<TEvRemoteScanData, NKikimrKqp::TEvRemoteScanData, - TKqpComputeEvents::EvRemoteScanData> {}; - - /* - * Scan communications. - * - * TEvScanData is intentionally preserved as a local event for performance reasons: leaf compute - * actors are communicating with shard scans using this message, so big amount of unfiltered data - * is expected. However, it is possible that after query planning datashard would migrate to other - * node. To support scans in this case we provide serialization routines. For now such remote scan - * is considered as rare event and not worth of some fast serialization, so we just use protobuf. - * - * TEvScanDataAck follows the same pattern mostly for symmetry reasons. - */ - struct TEvScanData : public NActors::TEventLocal<TEvScanData, TKqpComputeEvents::EvScanData> { - TEvScanData(ui32 scanId, ui32 generation = 0) - : ScanId(scanId) - , Generation(generation) - , Finished(false) {} - - ui32 ScanId; - ui32 Generation; - TVector<TOwnedCellVec> Rows; - std::shared_ptr<arrow::RecordBatch> ArrowBatch; - TOwnedCellVec LastKey; + +namespace NKikimr::NKqp { + +struct TEvKqpCompute { + struct TEvRemoteScanData : public TEventPB<TEvRemoteScanData, NKikimrKqp::TEvRemoteScanData, + TKqpComputeEvents::EvRemoteScanData> {}; + + /* + * Scan communications. + * + * TEvScanData is intentionally preserved as a local event for performance reasons: leaf compute + * actors are communicating with shard scans using this message, so big amount of unfiltered data + * is expected. However, it is possible that after query planning datashard would migrate to other + * node. To support scans in this case we provide serialization routines. For now such remote scan + * is considered as rare event and not worth of some fast serialization, so we just use protobuf. + * + * TEvScanDataAck follows the same pattern mostly for symmetry reasons. + */ + struct TEvScanData : public NActors::TEventLocal<TEvScanData, TKqpComputeEvents::EvScanData> { + TEvScanData(ui32 scanId, ui32 generation = 0) + : ScanId(scanId) + , Generation(generation) + , Finished(false) {} + + ui32 ScanId; + ui32 Generation; + TVector<TOwnedCellVec> Rows; + std::shared_ptr<arrow::RecordBatch> ArrowBatch; + TOwnedCellVec LastKey; TDuration CpuTime; - TDuration WaitTime; - ui32 PageFaults = 0; // number of page faults occurred when filling in this message - bool Finished = false; - bool PageFault = false; // page fault was the reason for sending this message - mutable THolder<TEvRemoteScanData> Remote; - - bool IsSerializable() const override { - return true; - } - - ui32 CalculateSerializedSize() const override { - InitRemote(); - return Remote->CalculateSerializedSizeCached(); - } - - bool SerializeToArcadiaStream(NActors::TChunkSerializer* chunker) const override { - InitRemote(); - return Remote->SerializeToArcadiaStream(chunker); - } - - NKikimrTxDataShard::EScanDataFormat GetDataFormat() const { - if (ArrowBatch != nullptr) { - return NKikimrTxDataShard::EScanDataFormat::ARROW; - } - return NKikimrTxDataShard::EScanDataFormat::CELLVEC; - } - - - static NActors::IEventBase* Load(TEventSerializedData* data) { - auto pbEv = THolder<TEvRemoteScanData>(static_cast<TEvRemoteScanData *>(TEvRemoteScanData::Load(data))); - auto ev = MakeHolder<TEvScanData>(pbEv->Record.GetScanId()); - - ev->Generation = pbEv->Record.GetGeneration(); + TDuration WaitTime; + ui32 PageFaults = 0; // number of page faults occurred when filling in this message + bool Finished = false; + bool PageFault = false; // page fault was the reason for sending this message + mutable THolder<TEvRemoteScanData> Remote; + + bool IsSerializable() const override { + return true; + } + + ui32 CalculateSerializedSize() const override { + InitRemote(); + return Remote->CalculateSerializedSizeCached(); + } + + bool SerializeToArcadiaStream(NActors::TChunkSerializer* chunker) const override { + InitRemote(); + return Remote->SerializeToArcadiaStream(chunker); + } + + NKikimrTxDataShard::EScanDataFormat GetDataFormat() const { + if (ArrowBatch != nullptr) { + return NKikimrTxDataShard::EScanDataFormat::ARROW; + } + return NKikimrTxDataShard::EScanDataFormat::CELLVEC; + } + + + static NActors::IEventBase* Load(TEventSerializedData* data) { + auto pbEv = THolder<TEvRemoteScanData>(static_cast<TEvRemoteScanData *>(TEvRemoteScanData::Load(data))); + auto ev = MakeHolder<TEvScanData>(pbEv->Record.GetScanId()); + + ev->Generation = pbEv->Record.GetGeneration(); ev->CpuTime = TDuration::MicroSeconds(pbEv->Record.GetCpuTimeUs()); - ev->WaitTime = TDuration::MilliSeconds(pbEv->Record.GetWaitTimeMs()); - ev->PageFault = pbEv->Record.GetPageFault(); - ev->PageFaults = pbEv->Record.GetPageFaults(); - ev->Finished = pbEv->Record.GetFinished(); - ev->LastKey = TOwnedCellVec(TSerializedCellVec(pbEv->Record.GetLastKey()).GetCells()); - - auto rows = pbEv->Record.GetRows(); - ev->Rows.reserve(rows.size()); - for (const auto& row: rows) { - ev->Rows.emplace_back(TSerializedCellVec(row).GetCells()); - } - - if (pbEv->Record.HasArrowBatch()) { - auto batch = pbEv->Record.GetArrowBatch(); - auto schema = NArrow::DeserializeSchema(batch.GetSchema()); - ev->ArrowBatch = NArrow::DeserializeBatch(batch.GetBatch(), schema); - } - return ev.Release(); - } - - private: - void InitRemote() const { - if (!Remote) { - Remote = MakeHolder<TEvRemoteScanData>(); - - Remote->Record.SetScanId(ScanId); - Remote->Record.SetGeneration(Generation); + ev->WaitTime = TDuration::MilliSeconds(pbEv->Record.GetWaitTimeMs()); + ev->PageFault = pbEv->Record.GetPageFault(); + ev->PageFaults = pbEv->Record.GetPageFaults(); + ev->Finished = pbEv->Record.GetFinished(); + ev->LastKey = TOwnedCellVec(TSerializedCellVec(pbEv->Record.GetLastKey()).GetCells()); + + auto rows = pbEv->Record.GetRows(); + ev->Rows.reserve(rows.size()); + for (const auto& row: rows) { + ev->Rows.emplace_back(TSerializedCellVec(row).GetCells()); + } + + if (pbEv->Record.HasArrowBatch()) { + auto batch = pbEv->Record.GetArrowBatch(); + auto schema = NArrow::DeserializeSchema(batch.GetSchema()); + ev->ArrowBatch = NArrow::DeserializeBatch(batch.GetBatch(), schema); + } + return ev.Release(); + } + + private: + void InitRemote() const { + if (!Remote) { + Remote = MakeHolder<TEvRemoteScanData>(); + + Remote->Record.SetScanId(ScanId); + Remote->Record.SetGeneration(Generation); Remote->Record.SetCpuTimeUs(CpuTime.MicroSeconds()); - Remote->Record.SetWaitTimeMs(WaitTime.MilliSeconds()); - Remote->Record.SetPageFaults(PageFaults); - Remote->Record.SetFinished(Finished); - Remote->Record.SetPageFaults(PageFaults); - Remote->Record.SetPageFault(PageFault); - Remote->Record.SetLastKey(TSerializedCellVec::Serialize(LastKey)); - - switch (GetDataFormat()) { - case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: - case NKikimrTxDataShard::EScanDataFormat::CELLVEC: { - Remote->Record.MutableRows()->Reserve(Rows.size()); - for (const auto& row: Rows) { - Remote->Record.AddRows(TSerializedCellVec::Serialize(row)); - } - break; - } - case NKikimrTxDataShard::EScanDataFormat::ARROW: { - Y_VERIFY_DEBUG(ArrowBatch != nullptr); - auto* protoArrowBatch = Remote->Record.MutableArrowBatch(); - protoArrowBatch->SetSchema(NArrow::SerializeSchema(*ArrowBatch->schema())); + Remote->Record.SetWaitTimeMs(WaitTime.MilliSeconds()); + Remote->Record.SetPageFaults(PageFaults); + Remote->Record.SetFinished(Finished); + Remote->Record.SetPageFaults(PageFaults); + Remote->Record.SetPageFault(PageFault); + Remote->Record.SetLastKey(TSerializedCellVec::Serialize(LastKey)); + + switch (GetDataFormat()) { + case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: + case NKikimrTxDataShard::EScanDataFormat::CELLVEC: { + Remote->Record.MutableRows()->Reserve(Rows.size()); + for (const auto& row: Rows) { + Remote->Record.AddRows(TSerializedCellVec::Serialize(row)); + } + break; + } + case NKikimrTxDataShard::EScanDataFormat::ARROW: { + Y_VERIFY_DEBUG(ArrowBatch != nullptr); + auto* protoArrowBatch = Remote->Record.MutableArrowBatch(); + protoArrowBatch->SetSchema(NArrow::SerializeSchema(*ArrowBatch->schema())); protoArrowBatch->SetBatch(NArrow::SerializeBatchNoCompression(ArrowBatch)); - break; - } - } - } - } - }; - - struct TEvRemoteScanDataAck : public NActors::TEventPB<TEvRemoteScanDataAck, NKikimrKqp::TEvRemoteScanDataAck, - TKqpComputeEvents::EvRemoteScanDataAck> {}; - - struct TEvScanDataAck : public NActors::TEventLocal<TEvScanDataAck, TKqpComputeEvents::EvScanDataAck> { - explicit TEvScanDataAck(ui64 freeSpace, ui32 generation = 0) - : FreeSpace(freeSpace) - , Generation(generation) {} - - const ui64 FreeSpace; - const ui32 Generation; - mutable THolder<TEvRemoteScanDataAck> Remote; - - bool IsSerializable() const override { - return true; - } - - ui32 CalculateSerializedSize() const override { - InitRemote(); - return Remote->CalculateSerializedSizeCached(); - } - - bool SerializeToArcadiaStream(NActors::TChunkSerializer* chunker) const override { - InitRemote(); - return Remote->SerializeToArcadiaStream(chunker); - } - - static NActors::IEventBase* Load(TEventSerializedData* data) { - auto pbEv = THolder<TEvRemoteScanDataAck>(static_cast<TEvRemoteScanDataAck *>(TEvRemoteScanDataAck::Load(data))); - return new TEvScanDataAck(pbEv->Record.GetFreeSpace(), pbEv->Record.GetGeneration()); - } - - private: - void InitRemote() const { - if (!Remote) { - Remote.Reset(new TEvRemoteScanDataAck); - Remote->Record.SetFreeSpace(FreeSpace); - Remote->Record.SetGeneration(Generation); - } - } - }; - - struct TEvScanError : public NActors::TEventPB<TEvScanError, NKikimrKqp::TEvScanError, - TKqpComputeEvents::EvScanError> - { - TEvScanError(ui32 generation = 0) { - Record.SetGeneration(generation); - } - }; - - struct TEvScanInitActor : public NActors::TEventPB<TEvScanInitActor, NKikimrKqp::TEvScanInitActor, - TKqpComputeEvents::EvScanInitActor> - { - TEvScanInitActor() {} - - TEvScanInitActor(ui64 scanId, const NActors::TActorId& scanActor, ui32 generation = 0) { - Record.SetScanId(scanId); - ActorIdToProto(scanActor, Record.MutableScanActorId()); - Record.SetGeneration(generation); - } - }; + break; + } + } + } + } + }; + + struct TEvRemoteScanDataAck : public NActors::TEventPB<TEvRemoteScanDataAck, NKikimrKqp::TEvRemoteScanDataAck, + TKqpComputeEvents::EvRemoteScanDataAck> {}; + + struct TEvScanDataAck : public NActors::TEventLocal<TEvScanDataAck, TKqpComputeEvents::EvScanDataAck> { + explicit TEvScanDataAck(ui64 freeSpace, ui32 generation = 0) + : FreeSpace(freeSpace) + , Generation(generation) {} + + const ui64 FreeSpace; + const ui32 Generation; + mutable THolder<TEvRemoteScanDataAck> Remote; + + bool IsSerializable() const override { + return true; + } + + ui32 CalculateSerializedSize() const override { + InitRemote(); + return Remote->CalculateSerializedSizeCached(); + } + + bool SerializeToArcadiaStream(NActors::TChunkSerializer* chunker) const override { + InitRemote(); + return Remote->SerializeToArcadiaStream(chunker); + } + + static NActors::IEventBase* Load(TEventSerializedData* data) { + auto pbEv = THolder<TEvRemoteScanDataAck>(static_cast<TEvRemoteScanDataAck *>(TEvRemoteScanDataAck::Load(data))); + return new TEvScanDataAck(pbEv->Record.GetFreeSpace(), pbEv->Record.GetGeneration()); + } + + private: + void InitRemote() const { + if (!Remote) { + Remote.Reset(new TEvRemoteScanDataAck); + Remote->Record.SetFreeSpace(FreeSpace); + Remote->Record.SetGeneration(Generation); + } + } + }; + + struct TEvScanError : public NActors::TEventPB<TEvScanError, NKikimrKqp::TEvScanError, + TKqpComputeEvents::EvScanError> + { + TEvScanError(ui32 generation = 0) { + Record.SetGeneration(generation); + } + }; + + struct TEvScanInitActor : public NActors::TEventPB<TEvScanInitActor, NKikimrKqp::TEvScanInitActor, + TKqpComputeEvents::EvScanInitActor> + { + TEvScanInitActor() {} + + TEvScanInitActor(ui64 scanId, const NActors::TActorId& scanActor, ui32 generation = 0) { + Record.SetScanId(scanId); + ActorIdToProto(scanActor, Record.MutableScanActorId()); + Record.SetGeneration(generation); + } + }; struct TEvKillScanTablet : public NActors::TEventPB<TEvKillScanTablet, NKikimrKqp::TEvKillScanTablet, TKqpComputeEvents::EvKillScanTablet> {}; -}; - -} // namespace NKikimr::NKqp +}; + +} // namespace NKikimr::NKqp diff --git a/ydb/core/kqp/runtime/kqp_compute.h b/ydb/core/kqp/runtime/kqp_compute.h index 436cc75cb97..c55a4e2a21b 100644 --- a/ydb/core/kqp/runtime/kqp_compute.h +++ b/ydb/core/kqp/runtime/kqp_compute.h @@ -5,8 +5,8 @@ #include <ydb/core/scheme/scheme_tabledefs.h> #include <ydb/core/tablet_flat/flat_row_eggs.h> -// TODO rename file to runtime_compute_context.h - +// TODO rename file to runtime_compute_context.h + namespace NKikimr { namespace NMiniKQL { diff --git a/ydb/core/kqp/runtime/kqp_scan_data.cpp b/ydb/core/kqp/runtime/kqp_scan_data.cpp index 5fff1fe8633..c4d8e241966 100644 --- a/ydb/core/kqp/runtime/kqp_scan_data.cpp +++ b/ydb/core/kqp/runtime/kqp_scan_data.cpp @@ -5,254 +5,254 @@ #include <ydb/library/yql/minikql/mkql_string_util.h> #include <ydb/library/yql/utils/yql_panic.h> - + namespace NKikimr { namespace NMiniKQL { -namespace { - -struct TBytesStatistics { - ui64 AllocatedBytes = 0; - ui64 DataBytes = 0; - - void AddStatistics(const TBytesStatistics& other) { - AllocatedBytes += other.AllocatedBytes; - DataBytes += other.DataBytes; - } - -}; - -TBytesStatistics GetUnboxedValueSize(const NUdf::TUnboxedValue& value, NScheme::TTypeId type) { - namespace NTypeIds = NScheme::NTypeIds; - if (!value) { - return {sizeof(NUdf::TUnboxedValue), 8}; // Special value for NULL elements - } - switch (type) { - case NTypeIds::Bool: - case NTypeIds::Int8: - case NTypeIds::Uint8: - - case NTypeIds::Int16: - case NTypeIds::Uint16: - - case NTypeIds::Int32: - case NTypeIds::Uint32: - case NTypeIds::Float: - case NTypeIds::Date: - - case NTypeIds::Int64: - case NTypeIds::Uint64: - case NTypeIds::Double: - case NTypeIds::Datetime: - case NTypeIds::Timestamp: - case NTypeIds::Interval: - case NTypeIds::ActorId: - case NTypeIds::StepOrderId: { - YQL_ENSURE(value.IsEmbedded(), "Passed wrong type: " << NScheme::TypeName(type)); - return {sizeof(NUdf::TUnboxedValue), sizeof(i64)}; - } - case NTypeIds::Decimal: - { - YQL_ENSURE(value.IsEmbedded(), "Passed wrong type: " << NScheme::TypeName(type)); - return {sizeof(NUdf::TUnboxedValue), sizeof(NYql::NDecimal::TInt128)}; - } - case NTypeIds::String: - case NTypeIds::Utf8: - case NTypeIds::Json: - case NTypeIds::Yson: - case NTypeIds::JsonDocument: - case NTypeIds::DyNumber: - case NTypeIds::PairUi64Ui64: { - if (value.IsEmbedded()) { - return {sizeof(NUdf::TUnboxedValue), std::max((ui32) 8, value.AsStringRef().Size())}; - } else { - Y_VERIFY_DEBUG_S(8 < value.AsStringRef().Size(), "Small string of size " << value.AsStringRef().Size() << " is not embedded."); - return {sizeof(NUdf::TUnboxedValue) + value.AsStringRef().Size(), value.AsStringRef().Size()}; - } - } - - default: - Y_VERIFY_DEBUG_S(false, "Unsupported type " << NScheme::TypeName(type)); - if (value.IsEmbedded()) { - return {sizeof(NUdf::TUnboxedValue), sizeof(NUdf::TUnboxedValue)}; - } else { - return {sizeof(NUdf::TUnboxedValue) + value.AsStringRef().Size(), value.AsStringRef().Size()}; - } - } -} - -TBytesStatistics GetRowSize(const NUdf::TUnboxedValue* row, const TSmallVec<TKqpComputeContextBase::TColumn>& columns, - const TSmallVec<TKqpComputeContextBase::TColumn>& systemColumns) -{ - TBytesStatistics rowStats{systemColumns.size() * sizeof(NUdf::TUnboxedValue), 0}; - for (size_t columnIndex = 0; columnIndex < columns.size(); ++columnIndex) { - rowStats.AddStatistics(GetUnboxedValueSize(row[columnIndex], columns[columnIndex].Type)); - } - if (columns.empty()) { - rowStats.AddStatistics({sizeof(ui64), sizeof(ui64)}); - } - return rowStats; -} - -void FillSystemColumns(NUdf::TUnboxedValue* rowItems, TMaybe<ui64> shardId, const TSmallVec<TKqpComputeContextBase::TColumn>& systemColumns) { - for (ui32 i = 0; i < systemColumns.size(); ++i) { - YQL_ENSURE(systemColumns[i].Tag == TKeyDesc::EColumnIdDataShard, "Unknown system column tag: " << systemColumns[i].Tag); - - if (shardId) { - rowItems[i] = NUdf::TUnboxedValuePod(*shardId); - } else { - rowItems[i] = NUdf::TUnboxedValue(); - } - } -} - -template <typename TArrayType, typename TValueType = typename TArrayType::value_type> +namespace { + +struct TBytesStatistics { + ui64 AllocatedBytes = 0; + ui64 DataBytes = 0; + + void AddStatistics(const TBytesStatistics& other) { + AllocatedBytes += other.AllocatedBytes; + DataBytes += other.DataBytes; + } + +}; + +TBytesStatistics GetUnboxedValueSize(const NUdf::TUnboxedValue& value, NScheme::TTypeId type) { + namespace NTypeIds = NScheme::NTypeIds; + if (!value) { + return {sizeof(NUdf::TUnboxedValue), 8}; // Special value for NULL elements + } + switch (type) { + case NTypeIds::Bool: + case NTypeIds::Int8: + case NTypeIds::Uint8: + + case NTypeIds::Int16: + case NTypeIds::Uint16: + + case NTypeIds::Int32: + case NTypeIds::Uint32: + case NTypeIds::Float: + case NTypeIds::Date: + + case NTypeIds::Int64: + case NTypeIds::Uint64: + case NTypeIds::Double: + case NTypeIds::Datetime: + case NTypeIds::Timestamp: + case NTypeIds::Interval: + case NTypeIds::ActorId: + case NTypeIds::StepOrderId: { + YQL_ENSURE(value.IsEmbedded(), "Passed wrong type: " << NScheme::TypeName(type)); + return {sizeof(NUdf::TUnboxedValue), sizeof(i64)}; + } + case NTypeIds::Decimal: + { + YQL_ENSURE(value.IsEmbedded(), "Passed wrong type: " << NScheme::TypeName(type)); + return {sizeof(NUdf::TUnboxedValue), sizeof(NYql::NDecimal::TInt128)}; + } + case NTypeIds::String: + case NTypeIds::Utf8: + case NTypeIds::Json: + case NTypeIds::Yson: + case NTypeIds::JsonDocument: + case NTypeIds::DyNumber: + case NTypeIds::PairUi64Ui64: { + if (value.IsEmbedded()) { + return {sizeof(NUdf::TUnboxedValue), std::max((ui32) 8, value.AsStringRef().Size())}; + } else { + Y_VERIFY_DEBUG_S(8 < value.AsStringRef().Size(), "Small string of size " << value.AsStringRef().Size() << " is not embedded."); + return {sizeof(NUdf::TUnboxedValue) + value.AsStringRef().Size(), value.AsStringRef().Size()}; + } + } + + default: + Y_VERIFY_DEBUG_S(false, "Unsupported type " << NScheme::TypeName(type)); + if (value.IsEmbedded()) { + return {sizeof(NUdf::TUnboxedValue), sizeof(NUdf::TUnboxedValue)}; + } else { + return {sizeof(NUdf::TUnboxedValue) + value.AsStringRef().Size(), value.AsStringRef().Size()}; + } + } +} + +TBytesStatistics GetRowSize(const NUdf::TUnboxedValue* row, const TSmallVec<TKqpComputeContextBase::TColumn>& columns, + const TSmallVec<TKqpComputeContextBase::TColumn>& systemColumns) +{ + TBytesStatistics rowStats{systemColumns.size() * sizeof(NUdf::TUnboxedValue), 0}; + for (size_t columnIndex = 0; columnIndex < columns.size(); ++columnIndex) { + rowStats.AddStatistics(GetUnboxedValueSize(row[columnIndex], columns[columnIndex].Type)); + } + if (columns.empty()) { + rowStats.AddStatistics({sizeof(ui64), sizeof(ui64)}); + } + return rowStats; +} + +void FillSystemColumns(NUdf::TUnboxedValue* rowItems, TMaybe<ui64> shardId, const TSmallVec<TKqpComputeContextBase::TColumn>& systemColumns) { + for (ui32 i = 0; i < systemColumns.size(); ++i) { + YQL_ENSURE(systemColumns[i].Tag == TKeyDesc::EColumnIdDataShard, "Unknown system column tag: " << systemColumns[i].Tag); + + if (shardId) { + rowItems[i] = NUdf::TUnboxedValuePod(*shardId); + } else { + rowItems[i] = NUdf::TUnboxedValue(); + } + } +} + +template <typename TArrayType, typename TValueType = typename TArrayType::value_type> NUdf::TUnboxedValue MakeUnboxedValue(arrow::Array* column, ui32 row) { auto array = reinterpret_cast<TArrayType*>(column); - return NUdf::TUnboxedValuePod(static_cast<TValueType>(array->Value(row))); -} - + return NUdf::TUnboxedValuePod(static_cast<TValueType>(array->Value(row))); +} + NUdf::TUnboxedValue MakeUnboxedValueFromBinaryData(arrow::Array* column, ui32 row) { auto array = reinterpret_cast<arrow::BinaryArray*>(column); - auto data = array->GetView(row); - return MakeString(NUdf::TStringRef(data.data(), data.size())); -} - + auto data = array->GetView(row); + return MakeString(NUdf::TStringRef(data.data(), data.size())); +} + NUdf::TUnboxedValue MakeUnboxedValueFromFixedSizeBinaryData(arrow::Array* column, ui32 row) { auto array = reinterpret_cast<arrow::FixedSizeBinaryArray*>(column); - auto data = array->GetView(row); - return MakeString(NUdf::TStringRef(data.data(), data.size()-1)); -} - + auto data = array->GetView(row); + return MakeString(NUdf::TStringRef(data.data(), data.size()-1)); +} + NUdf::TUnboxedValue MakeUnboxedValueFromDecimal128Array(arrow::Array* column, ui32 row) { auto array = reinterpret_cast<arrow::Decimal128Array*>(column); - auto data = array->GetView(row); - // It's known that Decimal params are always Decimal(22,9), - // so we verify Decimal type here before store it in UnboxedValue. - const auto& type = arrow::internal::checked_cast<const arrow::Decimal128Type&>(*array->type()); - YQL_ENSURE(type.precision() == NScheme::DECIMAL_PRECISION, "Unsupported Decimal precision."); - YQL_ENSURE(type.scale() == NScheme::DECIMAL_SCALE, "Unsupported Decimal scale."); - YQL_ENSURE(data.size() == sizeof(NYql::NDecimal::TInt128), "Wrong data size"); - NYql::NDecimal::TInt128 val; - std::memcpy(reinterpret_cast<char*>(&val), data.data(), data.size()); - return NUdf::TUnboxedValuePod(val); -} - + auto data = array->GetView(row); + // It's known that Decimal params are always Decimal(22,9), + // so we verify Decimal type here before store it in UnboxedValue. + const auto& type = arrow::internal::checked_cast<const arrow::Decimal128Type&>(*array->type()); + YQL_ENSURE(type.precision() == NScheme::DECIMAL_PRECISION, "Unsupported Decimal precision."); + YQL_ENSURE(type.scale() == NScheme::DECIMAL_SCALE, "Unsupported Decimal scale."); + YQL_ENSURE(data.size() == sizeof(NYql::NDecimal::TInt128), "Wrong data size"); + NYql::NDecimal::TInt128 val; + std::memcpy(reinterpret_cast<char*>(&val), data.data(), data.size()); + return NUdf::TUnboxedValuePod(val); +} + TBytesStatistics WriteColumnValuesFromArrow(const TVector<NUdf::TUnboxedValue*>& editAccessors, const arrow::RecordBatch& batch, i64 columnIndex, NScheme::TTypeId columnType) { - TBytesStatistics columnStats; + TBytesStatistics columnStats; // Hold pointer to column until function end std::shared_ptr<arrow::Array> columnSharedPtr = batch.column(columnIndex); arrow::Array* columnPtr = columnSharedPtr.get(); - namespace NTypeIds = NScheme::NTypeIds; - for (i64 rowIndex = 0; rowIndex < batch.num_rows(); ++rowIndex) { - auto& rowItem = editAccessors[rowIndex][columnIndex]; + namespace NTypeIds = NScheme::NTypeIds; + for (i64 rowIndex = 0; rowIndex < batch.num_rows(); ++rowIndex) { + auto& rowItem = editAccessors[rowIndex][columnIndex]; if (columnPtr->IsNull(rowIndex)) { - rowItem = NUdf::TUnboxedValue(); - } else { - switch(columnType) { - case NTypeIds::Bool: { + rowItem = NUdf::TUnboxedValue(); + } else { + switch(columnType) { + case NTypeIds::Bool: { rowItem = MakeUnboxedValue<arrow::BooleanArray, bool>(columnPtr, rowIndex); - break; - } - case NTypeIds::Int8: { + break; + } + case NTypeIds::Int8: { rowItem = MakeUnboxedValue<arrow::Int8Array>(columnPtr, rowIndex); - break; - } - case NTypeIds::Int16: { + break; + } + case NTypeIds::Int16: { rowItem = MakeUnboxedValue<arrow::Int16Array>(columnPtr, rowIndex); - break; - } - case NTypeIds::Int32: { + break; + } + case NTypeIds::Int32: { rowItem = MakeUnboxedValue<arrow::Int32Array>(columnPtr, rowIndex); - break; - } - case NTypeIds::Int64: { + break; + } + case NTypeIds::Int64: { rowItem = MakeUnboxedValue<arrow::Int64Array, i64>(columnPtr, rowIndex); - break; - } - case NTypeIds::Uint8: { + break; + } + case NTypeIds::Uint8: { rowItem = MakeUnboxedValue<arrow::UInt8Array>(columnPtr, rowIndex); - break; - } - case NTypeIds::Uint16: { + break; + } + case NTypeIds::Uint16: { rowItem = MakeUnboxedValue<arrow::UInt16Array>(columnPtr, rowIndex); - break; - } - case NTypeIds::Uint32: { + break; + } + case NTypeIds::Uint32: { rowItem = MakeUnboxedValue<arrow::UInt32Array>(columnPtr, rowIndex); - break; - } - case NTypeIds::Uint64: { + break; + } + case NTypeIds::Uint64: { rowItem = MakeUnboxedValue<arrow::UInt64Array, ui64>(columnPtr, rowIndex); - break; - } - case NTypeIds::Float: { + break; + } + case NTypeIds::Float: { rowItem = MakeUnboxedValue<arrow::FloatArray>(columnPtr, rowIndex); - break; - } - case NTypeIds::Double: { + break; + } + case NTypeIds::Double: { rowItem = MakeUnboxedValue<arrow::DoubleArray>(columnPtr, rowIndex); - break; - } - case NTypeIds::String: - case NTypeIds::Utf8: - case NTypeIds::Json: - case NTypeIds::Yson: - case NTypeIds::JsonDocument: - case NTypeIds::DyNumber: { + break; + } + case NTypeIds::String: + case NTypeIds::Utf8: + case NTypeIds::Json: + case NTypeIds::Yson: + case NTypeIds::JsonDocument: + case NTypeIds::DyNumber: { rowItem = MakeUnboxedValueFromBinaryData(columnPtr, rowIndex); - break; - } - case NTypeIds::Date: { + break; + } + case NTypeIds::Date: { rowItem = MakeUnboxedValue<arrow::UInt16Array>(columnPtr, rowIndex); - break; - } - case NTypeIds::Datetime: { + break; + } + case NTypeIds::Datetime: { rowItem = MakeUnboxedValue<arrow::UInt32Array>(columnPtr, rowIndex); - break; - } - case NTypeIds::Timestamp: { + break; + } + case NTypeIds::Timestamp: { rowItem = MakeUnboxedValue<arrow::TimestampArray, ui64>(columnPtr, rowIndex); - break; - } - case NTypeIds::Interval: { + break; + } + case NTypeIds::Interval: { rowItem = MakeUnboxedValue<arrow::DurationArray, ui64>(columnPtr, rowIndex); - break; - } - case NTypeIds::Decimal: { + break; + } + case NTypeIds::Decimal: { rowItem = MakeUnboxedValueFromDecimal128Array(columnPtr, rowIndex); - break; - } - case NTypeIds::PairUi64Ui64: - case NTypeIds::ActorId: - case NTypeIds::StepOrderId: { - Y_VERIFY_DEBUG_S(false, "Unsupported (deprecated) type: " << NScheme::TypeName(columnType)); + break; + } + case NTypeIds::PairUi64Ui64: + case NTypeIds::ActorId: + case NTypeIds::StepOrderId: { + Y_VERIFY_DEBUG_S(false, "Unsupported (deprecated) type: " << NScheme::TypeName(columnType)); rowItem = MakeUnboxedValueFromFixedSizeBinaryData(columnPtr, rowIndex); - break; - } - default: - YQL_ENSURE(false, "Unsupported type: " << NScheme::TypeName(columnType) << " at column " << columnIndex); - } - } - columnStats.AddStatistics(GetUnboxedValueSize(rowItem, columnType)); - } - return columnStats; -} - -} // namespace - -std::pair<ui64, ui64> GetUnboxedValueSizeForTests(const NUdf::TUnboxedValue& value, NScheme::TTypeId type) { - auto sizes = GetUnboxedValueSize(value, type); - return {sizes.AllocatedBytes, sizes.DataBytes}; -} - + break; + } + default: + YQL_ENSURE(false, "Unsupported type: " << NScheme::TypeName(columnType) << " at column " << columnIndex); + } + } + columnStats.AddStatistics(GetUnboxedValueSize(rowItem, columnType)); + } + return columnStats; +} + +} // namespace + +std::pair<ui64, ui64> GetUnboxedValueSizeForTests(const NUdf::TUnboxedValue& value, NScheme::TTypeId type) { + auto sizes = GetUnboxedValueSize(value, type); + return {sizes.AllocatedBytes, sizes.DataBytes}; +} + TKqpScanComputeContext::TScanData::TScanData(const TTableId& tableId, const TTableRange& range, const TSmallVec<TColumn>& columns, const TSmallVec<TColumn>& systemColumns, const TSmallVec<bool>& skipNullKeys) : TableId(tableId) , Range(range) - , SkipNullKeys(skipNullKeys) + , SkipNullKeys(skipNullKeys) , Columns(columns) , SystemColumns(systemColumns) {} @@ -288,67 +288,67 @@ TKqpScanComputeContext::TScanData::TScanData(const NKikimrTxDataShard::TKqpTrans } } - -ui64 TKqpScanComputeContext::TScanData::AddRows(const TVector<TOwnedCellVec>& batch, TMaybe<ui64> shardId, const THolderFactory& holderFactory) { - if (Finished || batch.empty()) { - return 0; - } - + +ui64 TKqpScanComputeContext::TScanData::AddRows(const TVector<TOwnedCellVec>& batch, TMaybe<ui64> shardId, const THolderFactory& holderFactory) { + if (Finished || batch.empty()) { + return 0; + } + TBytesStatistics stats; - TVector<ui64> bytesList; - bytesList.reserve(batch.size()); - - TUnboxedValueVector rows; - rows.reserve(batch.size()); - - for (size_t rowIndex = 0; rowIndex < batch.size(); ++rowIndex) { - auto& row = batch[rowIndex]; - - // Convert row into an UnboxedValue - NUdf::TUnboxedValue* rowItems = nullptr; - rows.emplace_back(holderFactory.CreateDirectArrayHolder(Columns.size() + SystemColumns.size(), rowItems)); - for (ui32 i = 0; i < Columns.size(); ++i) { - rowItems[i] = GetCellValue(row[i], Columns[i].Type); - } - FillSystemColumns(&rowItems[Columns.size()], shardId, SystemColumns); - - stats.AddStatistics(GetRowSize(rowItems, Columns, SystemColumns)); - } - RowBatches.emplace(RowBatch{std::move(rows), shardId}); - - StoredBytes += stats.AllocatedBytes; - if (BasicStats) { - BasicStats->Rows += batch.size(); - BasicStats->Bytes += stats.DataBytes; - } - - return stats.AllocatedBytes; -} - + TVector<ui64> bytesList; + bytesList.reserve(batch.size()); + + TUnboxedValueVector rows; + rows.reserve(batch.size()); + + for (size_t rowIndex = 0; rowIndex < batch.size(); ++rowIndex) { + auto& row = batch[rowIndex]; + + // Convert row into an UnboxedValue + NUdf::TUnboxedValue* rowItems = nullptr; + rows.emplace_back(holderFactory.CreateDirectArrayHolder(Columns.size() + SystemColumns.size(), rowItems)); + for (ui32 i = 0; i < Columns.size(); ++i) { + rowItems[i] = GetCellValue(row[i], Columns[i].Type); + } + FillSystemColumns(&rowItems[Columns.size()], shardId, SystemColumns); + + stats.AddStatistics(GetRowSize(rowItems, Columns, SystemColumns)); + } + RowBatches.emplace(RowBatch{std::move(rows), shardId}); + + StoredBytes += stats.AllocatedBytes; + if (BasicStats) { + BasicStats->Rows += batch.size(); + BasicStats->Bytes += stats.DataBytes; + } + + return stats.AllocatedBytes; +} + ui64 TKqpScanComputeContext::TScanData::AddRows(const arrow::RecordBatch& batch, TMaybe<ui64> shardId, const THolderFactory& holderFactory) { - // RecordBatch hasn't empty method so check the number of rows - if (Finished || batch.num_rows() == 0) { - return 0; - } - - TBytesStatistics stats; - TUnboxedValueVector rows; - + // RecordBatch hasn't empty method so check the number of rows + if (Finished || batch.num_rows() == 0) { + return 0; + } + + TBytesStatistics stats; + TUnboxedValueVector rows; + if (Columns.empty() && SystemColumns.empty()) { rows.resize(batch.num_rows(), holderFactory.GetEmptyContainer()); } else { TVector<NUdf::TUnboxedValue*> editAccessors(batch.num_rows()); rows.reserve(batch.num_rows()); - for (i64 rowIndex = 0; rowIndex < batch.num_rows(); ++rowIndex) { + for (i64 rowIndex = 0; rowIndex < batch.num_rows(); ++rowIndex) { rows.emplace_back(holderFactory.CreateDirectArrayHolder( Columns.size() + SystemColumns.size(), editAccessors[rowIndex]) ); - } + } for (size_t columnIndex = 0; columnIndex < Columns.size(); ++columnIndex) { stats.AddStatistics( @@ -363,37 +363,37 @@ ui64 TKqpScanComputeContext::TScanData::AddRows(const arrow::RecordBatch& batch, stats.AllocatedBytes += batch.num_rows() * SystemColumns.size() * sizeof(NUdf::TUnboxedValue); } - } - - if (Columns.empty()) { - stats.AddStatistics({sizeof(ui64) * batch.num_rows(), sizeof(ui64) * batch.num_rows()}); - } - - RowBatches.emplace(RowBatch{std::move(rows), shardId}); - - StoredBytes += stats.AllocatedBytes; - if (BasicStats) { - BasicStats->Rows += batch.num_rows(); - BasicStats->Bytes += stats.DataBytes; - } - - return stats.AllocatedBytes; + } + + if (Columns.empty()) { + stats.AddStatistics({sizeof(ui64) * batch.num_rows(), sizeof(ui64) * batch.num_rows()}); + } + + RowBatches.emplace(RowBatch{std::move(rows), shardId}); + + StoredBytes += stats.AllocatedBytes; + if (BasicStats) { + BasicStats->Rows += batch.num_rows(); + BasicStats->Bytes += stats.DataBytes; + } + + return stats.AllocatedBytes; +} + +NUdf::TUnboxedValue TKqpScanComputeContext::TScanData::TakeRow() { + YQL_ENSURE(!RowBatches.empty()); + auto& batch = RowBatches.front(); + auto row = std::move(batch.Batch[batch.CurrentRow++]); + auto rowStats = GetRowSize(row.GetElements(), Columns, SystemColumns); + + StoredBytes -= rowStats.AllocatedBytes; + if (batch.CurrentRow == batch.Batch.size()) { + RowBatches.pop(); + } + YQL_ENSURE(RowBatches.empty() == (StoredBytes == 0), "StoredBytes miscalculated!"); + return row; } -NUdf::TUnboxedValue TKqpScanComputeContext::TScanData::TakeRow() { - YQL_ENSURE(!RowBatches.empty()); - auto& batch = RowBatches.front(); - auto row = std::move(batch.Batch[batch.CurrentRow++]); - auto rowStats = GetRowSize(row.GetElements(), Columns, SystemColumns); - - StoredBytes -= rowStats.AllocatedBytes; - if (batch.CurrentRow == batch.Batch.size()) { - RowBatches.pop(); - } - YQL_ENSURE(RowBatches.empty() == (StoredBytes == 0), "StoredBytes miscalculated!"); - return row; -} - void TKqpScanComputeContext::AddTableScan(ui32, const TTableId& tableId, const TTableRange& range, const TSmallVec<TColumn>& columns, const TSmallVec<TColumn>& systemColumns, const TSmallVec<bool>& skipNullKeys) { @@ -449,30 +449,30 @@ public: : ScanData(scanData) {} - NUdf::EFetchStatus Next(NUdf::TUnboxedValue& result) override { - if (ScanData.IsEmpty()) { - if (ScanData.IsFinished()) { + NUdf::EFetchStatus Next(NUdf::TUnboxedValue& result) override { + if (ScanData.IsEmpty()) { + if (ScanData.IsFinished()) { return NUdf::EFetchStatus::Finish; } return NUdf::EFetchStatus::Yield; } - result = std::move(ScanData.TakeRow()); + result = std::move(ScanData.TakeRow()); return NUdf::EFetchStatus::Ok; } EFetchResult Next(NUdf::TUnboxedValue* const* result) override { - if (ScanData.IsEmpty()) { - if (ScanData.IsFinished()) { + if (ScanData.IsEmpty()) { + if (ScanData.IsFinished()) { return EFetchResult::Finish; } - return EFetchResult::Yield; + return EFetchResult::Yield; } - auto row = ScanData.TakeRow(); - for (ui32 i = 0; i < ScanData.GetColumns().size() + ScanData.GetSystemColumns().size(); ++i) { + auto row = ScanData.TakeRow(); + for (ui32 i = 0; i < ScanData.GetColumns().size() + ScanData.GetSystemColumns().size(); ++i) { if (result[i]) { - *result[i] = std::move(row.GetElement(i)); + *result[i] = std::move(row.GetElement(i)); } } diff --git a/ydb/core/kqp/runtime/kqp_scan_data.h b/ydb/core/kqp/runtime/kqp_scan_data.h index 49dc97d141b..c8e7d82b328 100644 --- a/ydb/core/kqp/runtime/kqp_scan_data.h +++ b/ydb/core/kqp/runtime/kqp_scan_data.h @@ -11,10 +11,10 @@ #include <ydb/library/yql/dq/actors/protos/dq_stats.pb.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> -#include <library/cpp/actors/core/log.h> - +#include <library/cpp/actors/core/log.h> + #include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> - + namespace NKikimrTxDataShard { class TKqpTransaction_TScanTaskMeta; } @@ -22,70 +22,70 @@ namespace NKikimrTxDataShard { namespace NKikimr { namespace NMiniKQL { -std::pair<ui64, ui64> GetUnboxedValueSizeForTests(const NUdf::TUnboxedValue& value, NScheme::TTypeId type); - -class IKqpTableReader : public TSimpleRefCount<IKqpTableReader> { +std::pair<ui64, ui64> GetUnboxedValueSizeForTests(const NUdf::TUnboxedValue& value, NScheme::TTypeId type); + +class IKqpTableReader : public TSimpleRefCount<IKqpTableReader> { public: - virtual ~IKqpTableReader() = default; + virtual ~IKqpTableReader() = default; - virtual NUdf::EFetchStatus Next(NUdf::TUnboxedValue& result) = 0; - virtual EFetchResult Next(NUdf::TUnboxedValue* const* output) = 0; -}; + virtual NUdf::EFetchStatus Next(NUdf::TUnboxedValue& result) = 0; + virtual EFetchResult Next(NUdf::TUnboxedValue* const* output) = 0; +}; -class TKqpScanComputeContext : public TKqpComputeContextBase { -public: - class TScanData { - public: - TScanData(TScanData&&) = default; // needed to create TMap<ui32, TScanData> Scans - TScanData(const TTableId& tableId, const TTableRange& range, const TSmallVec<TColumn>& columns, - const TSmallVec<TColumn>& systemColumns, const TSmallVec<bool>& skipNullKeys); +class TKqpScanComputeContext : public TKqpComputeContextBase { +public: + class TScanData { + public: + TScanData(TScanData&&) = default; // needed to create TMap<ui32, TScanData> Scans + TScanData(const TTableId& tableId, const TTableRange& range, const TSmallVec<TColumn>& columns, + const TSmallVec<TColumn>& systemColumns, const TSmallVec<bool>& skipNullKeys); TScanData(const NKikimrTxDataShard::TKqpTransaction_TScanTaskMeta& meta, NYql::NDqProto::EDqStatsMode statsMode); - ~TScanData() { - TString msg = TStringBuilder() << "Buffer in TScanData was not cleared, data is leaking: " - << "Queue of UnboxedValues must be emptied under allocator using Clear method, but has " << RowBatches.size() << " elements!"; - if (!RowBatches.empty()) { - LOG_CRIT_S(*NActors::TlsActivationContext, NKikimrServices::KQP_COMPUTE, msg); - } - Y_VERIFY_DEBUG_S(RowBatches.empty(), msg); - } - - const TSmallVec<TColumn>& GetColumns() const { - return Columns; - } - - const TSmallVec<TColumn>& GetSystemColumns() const { - return SystemColumns; + ~TScanData() { + TString msg = TStringBuilder() << "Buffer in TScanData was not cleared, data is leaking: " + << "Queue of UnboxedValues must be emptied under allocator using Clear method, but has " << RowBatches.size() << " elements!"; + if (!RowBatches.empty()) { + LOG_CRIT_S(*NActors::TlsActivationContext, NKikimrServices::KQP_COMPUTE, msg); + } + Y_VERIFY_DEBUG_S(RowBatches.empty(), msg); + } + + const TSmallVec<TColumn>& GetColumns() const { + return Columns; + } + + const TSmallVec<TColumn>& GetSystemColumns() const { + return SystemColumns; } - ui64 AddRows(const TVector<TOwnedCellVec>& batch, TMaybe<ui64> shardId, const THolderFactory& holderFactory); + ui64 AddRows(const TVector<TOwnedCellVec>& batch, TMaybe<ui64> shardId, const THolderFactory& holderFactory); - ui64 AddRows(const arrow::RecordBatch& batch, TMaybe<ui64> shardId, const THolderFactory& holderFactory); - - NUdf::TUnboxedValue TakeRow(); + ui64 AddRows(const arrow::RecordBatch& batch, TMaybe<ui64> shardId, const THolderFactory& holderFactory); - bool IsEmpty() const { - return RowBatches.empty(); - } + NUdf::TUnboxedValue TakeRow(); - ui64 GetStoredBytes() const { - return StoredBytes; - } + bool IsEmpty() const { + return RowBatches.empty(); + } - void Finish() { - Finished = true; - } + ui64 GetStoredBytes() const { + return StoredBytes; + } - bool IsFinished() const { - return Finished; - } + void Finish() { + Finished = true; + } - void Clear() { - RowBatches.clear(); - } + bool IsFinished() const { + return Finished; + } - public: + void Clear() { + RowBatches.clear(); + } + + public: ui64 TaskId = 0; TTableId TableId; TString TablePath; @@ -114,18 +114,18 @@ public: std::unique_ptr<TBasicStats> BasicStats; std::unique_ptr<TProfileStats> ProfileStats; - private: - struct RowBatch { - TUnboxedValueVector Batch; - TMaybe<ui64> ShardId; - ui64 CurrentRow = 0; - }; - - TSmallVec<TColumn> Columns; - TSmallVec<TColumn> SystemColumns; - TQueue<RowBatch> RowBatches; - ui64 StoredBytes = 0; - bool Finished = false; + private: + struct RowBatch { + TUnboxedValueVector Batch; + TMaybe<ui64> ShardId; + ui64 CurrentRow = 0; + }; + + TSmallVec<TColumn> Columns; + TSmallVec<TColumn> SystemColumns; + TQueue<RowBatch> RowBatches; + ui64 StoredBytes = 0; + bool Finished = false; }; public: @@ -144,12 +144,12 @@ public: TMap<ui32, TScanData>& GetTableScans(); const TMap<ui32, TScanData>& GetTableScans() const; - void Clear() { - for (auto& scan: Scans) { - scan.second.Clear(); - } - Scans.clear(); - } + void Clear() { + for (auto& scan: Scans) { + scan.second.Clear(); + } + Scans.clear(); + } private: const NYql::NDqProto::EDqStatsMode StatsMode; diff --git a/ydb/core/kqp/runtime/kqp_scan_data_ut.cpp b/ydb/core/kqp/runtime/kqp_scan_data_ut.cpp index bdf3eb3f459..daf32d84f3a 100644 --- a/ydb/core/kqp/runtime/kqp_scan_data_ut.cpp +++ b/ydb/core/kqp/runtime/kqp_scan_data_ut.cpp @@ -1,302 +1,302 @@ -#include "kqp_scan_data.h" - +#include "kqp_scan_data.h" + #include <ydb/library/yql/public/udf/udf_ut_helpers.h> #include <ydb/library/yql/minikql/mkql_alloc.h> -#include <library/cpp/testing/unittest/registar.h> - -namespace NKikimr::NMiniKQL { - -namespace { -namespace NTypeIds = NScheme::NTypeIds; -using TTypeId = NScheme::TTypeId; - -struct TDataRow { - TSmallVec<TKqpComputeContextBase::TColumn> Columns() { - return { - {0, NTypeIds::Bool}, - {1, NTypeIds::Int8}, - {2, NTypeIds::Int16}, - {3, NTypeIds::Int32}, - {4, NTypeIds::Int64}, - {5, NTypeIds::Uint8}, - {6, NTypeIds::Uint16}, - {7, NTypeIds::Uint32}, - {8, NTypeIds::Uint64}, - {9, NTypeIds::Float}, - {10, NTypeIds::Double}, - {11, NTypeIds::String}, - {12, NTypeIds::Utf8}, - {13, NTypeIds::Json}, - {14, NTypeIds::Yson}, - {15, NTypeIds::Date}, - {16, NTypeIds::Datetime}, - {17, NTypeIds::Timestamp}, - {18, NTypeIds::Interval}, - {19, NTypeIds::Decimal}, - }; - } - - bool Bool; - i8 Int8; - i16 Int16; - i32 Int32; - i64 Int64; - ui8 UInt8; - ui16 UInt16; - ui32 UInt32; - ui64 UInt64; - float Float32; - double Float64; - TString String; - TString Utf8; - TString Json; - TString Yson; - i32 Date; - i64 Datetime; - i64 Timestamp; - i64 Interval; - NYql::NDecimal::TInt128 Decimal; - - static std::shared_ptr<arrow::Schema> MakeArrowSchema() { - std::vector<std::shared_ptr<arrow::Field>> fields = { - arrow::field("bool", arrow::boolean()), - arrow::field("i8", arrow::int8()), - arrow::field("i16", arrow::int16()), - arrow::field("i32", arrow::int32()), - arrow::field("i64", arrow::int64()), - arrow::field("ui8", arrow::uint8()), - arrow::field("ui16", arrow::uint16()), - arrow::field("ui32", arrow::uint32()), - arrow::field("ui64", arrow::uint64()), - arrow::field("f32", arrow::float32()), - arrow::field("f64", arrow::float64()), - arrow::field("string", arrow::utf8()), - arrow::field("utf8", arrow::utf8()), - arrow::field("json", arrow::binary()), - arrow::field("yson", arrow::binary()), - arrow::field("date", arrow::date32()), - arrow::field("datetime", arrow::timestamp(arrow::TimeUnit::TimeUnit::SECOND)), - arrow::field("ts", arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO)), - arrow::field("ival", arrow::duration(arrow::TimeUnit::TimeUnit::MICRO)), - arrow::field("dec", arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE)), - }; - - return std::make_shared<arrow::Schema>(fields); - } -}; - -std::shared_ptr<arrow::RecordBatch> VectorToBatch(const std::vector<struct TDataRow>& rows) { - TString err; - std::unique_ptr<arrow::RecordBatchBuilder> batchBuilder = nullptr; - std::shared_ptr<arrow::RecordBatch> batch; - auto result = arrow::RecordBatchBuilder::Make(rows.front().MakeArrowSchema(), arrow::default_memory_pool(), &batchBuilder); - UNIT_ASSERT(result.ok()); - - for (const TDataRow& row : rows) { - auto result0 = batchBuilder->GetFieldAs<arrow::BooleanBuilder >(0 )->Append(row.Bool ); - UNIT_ASSERT(result.ok()); - auto result1 = batchBuilder->GetFieldAs<arrow::Int8Builder >(1 )->Append(row.Int8 ); - UNIT_ASSERT(result.ok()); - auto result2 = batchBuilder->GetFieldAs<arrow::Int16Builder >(2 )->Append(row.Int16 ); - UNIT_ASSERT(result.ok()); - auto result3 = batchBuilder->GetFieldAs<arrow::Int32Builder >(3 )->Append(row.Int32 ); - UNIT_ASSERT(result.ok()); - auto result4 = batchBuilder->GetFieldAs<arrow::Int64Builder >(4 )->Append(row.Int64 ); - UNIT_ASSERT(result.ok()); - auto result5 = batchBuilder->GetFieldAs<arrow::UInt8Builder >(5 )->Append(row.UInt8 ); - UNIT_ASSERT(result.ok()); - auto result6 = batchBuilder->GetFieldAs<arrow::UInt16Builder >(6 )->Append(row.UInt16 ); - UNIT_ASSERT(result.ok()); - auto result7 = batchBuilder->GetFieldAs<arrow::UInt32Builder >(7 )->Append(row.UInt32 ); - UNIT_ASSERT(result.ok()); - auto result8 = batchBuilder->GetFieldAs<arrow::UInt64Builder >(8 )->Append(row.UInt64 ); - UNIT_ASSERT(result.ok()); - auto result9 = batchBuilder->GetFieldAs<arrow::FloatBuilder >(9 )->Append(row.Float32); - UNIT_ASSERT(result.ok()); - auto result10 = batchBuilder->GetFieldAs<arrow::DoubleBuilder >(10)->Append(row.Float64); - UNIT_ASSERT(result.ok()); - auto result11 = batchBuilder->GetFieldAs<arrow::StringBuilder >(11)->Append(row.String.data(), row.String.size()); - UNIT_ASSERT(result.ok()); - auto result12 = batchBuilder->GetFieldAs<arrow::StringBuilder >(12)->Append(row.Utf8.data(), row.Utf8.size()); - UNIT_ASSERT(result.ok()); - auto result13 = batchBuilder->GetFieldAs<arrow::BinaryBuilder >(13)->Append(row.Json.data(), row.Json.size()); - UNIT_ASSERT(result.ok()); - auto result14 = batchBuilder->GetFieldAs<arrow::BinaryBuilder >(14)->Append(row.Yson.data(), row.Yson.size()); - UNIT_ASSERT(result.ok()); - auto result15 = batchBuilder->GetFieldAs<arrow::Date32Builder >(15)->Append(row.Date); - UNIT_ASSERT(result.ok()); - auto result16 = batchBuilder->GetFieldAs<arrow::TimestampBuilder >(16)->Append(row.Datetime); - UNIT_ASSERT(result.ok()); - auto result17 = batchBuilder->GetFieldAs<arrow::TimestampBuilder >(17)->Append(row.Timestamp); - UNIT_ASSERT(result.ok()); - auto result18 = batchBuilder->GetFieldAs<arrow::DurationBuilder >(18)->Append(row.Interval); - UNIT_ASSERT(result.ok()); - auto result19 = batchBuilder->GetFieldAs<arrow::Decimal128Builder>(19)->Append(reinterpret_cast<const char*>(&row.Decimal)); - UNIT_ASSERT(result.ok()); - } - - auto resultFlush = batchBuilder->Flush(&batch); - UNIT_ASSERT(resultFlush.ok()); - return batch; -} - -TVector<TDataRow> TestRows() { - TVector<TDataRow> rows = { - {false, -1, -1, -1, -1, 1, 1, 1, 1, -1.0f, -1.0, "s1" , "u1" , "{j:1}", "{y:1}", 0, 0, 0, 0, 111}, - {false, 2, 2, 2, 2, 2, 2, 2, 2, 2.0f, 2.0, "s2" , "u2" , "{j:2}", "{y:2}", 0, 0, 0, 0, 222}, - {false, -3, -3, -3, -3, 3, 3, 3, 3, -3.0f, -3.0, "s3" , "u3" , "{j:3}", "{y:3}", 0, 0, 0, 0, 333}, - {false, -4, -4, -4, -4, 4, 4, 4, 4, 4.0f, 4.0, "s4" , "u4" , "{j:4}", "{y:4}", 0, 0, 0, 0, 444}, - {false, -5, -5, -5, -5, 5, 5, 5, 5, 5.0f, 5.0, "long5long5long5long5long5", "utflong5utflong5utflong5", "{j:5}", "{y:5}", 0, 0, 0, 0, 555}, - }; - return rows; -} - -} - -Y_UNIT_TEST_SUITE(TKqpScanData) { - - Y_UNIT_TEST(UnboxedValueSize) { - NKikimr::NMiniKQL::TScopedAlloc alloc; - namespace NTypeIds = NScheme::NTypeIds; - struct TTestCase { - NUdf::TUnboxedValue Value; - NScheme::TTypeId Type; - std::pair<ui64, ui64> ExpectedSizes; - }; - TString pattern = "This string has 26 symbols"; - NUdf::TStringValue str(pattern.size()); - std::memcpy(str.Data(), pattern.data(), pattern.size()); - NUdf::TUnboxedValue containsLongString(NUdf::TUnboxedValuePod(std::move(str))); - NYql::NDecimal::TInt128 decimalVal = 123456789012; - TVector<TTestCase> cases = { - {NUdf::TUnboxedValuePod( ), NTypeIds::Bool , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Int32 , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Uint32 , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Int64 , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Uint64 , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Double , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Float , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::String , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Utf8 , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Yson , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Json , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Decimal , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Date , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Datetime , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Timestamp , {16, 8 } }, - {NUdf::TUnboxedValuePod( ), NTypeIds::Interval , {16, 8 } }, - {NUdf::TUnboxedValuePod(true ), NTypeIds::Bool , {16, 8 } }, - {NUdf::TUnboxedValuePod((i8) 1 ), NTypeIds::Int8 , {16, 8 } }, - {NUdf::TUnboxedValuePod((i16) 2 ), NTypeIds::Int16 , {16, 8 } }, - {NUdf::TUnboxedValuePod((i32) 3 ), NTypeIds::Int32 , {16, 8 } }, - {NUdf::TUnboxedValuePod((i64) 4 ), NTypeIds::Int64 , {16, 8 } }, - {NUdf::TUnboxedValuePod((ui8) 5 ), NTypeIds::Uint8 , {16, 8 } }, - {NUdf::TUnboxedValuePod((ui16) 6 ), NTypeIds::Uint16 , {16, 8 } }, - {NUdf::TUnboxedValuePod((ui32) 7 ), NTypeIds::Uint32 , {16, 8 } }, - {NUdf::TUnboxedValuePod((ui64) 8 ), NTypeIds::Uint64 , {16, 8 } }, - {NUdf::TUnboxedValuePod((float) 1.2 ), NTypeIds::Float , {16, 8 } }, - {NUdf::TUnboxedValuePod(123456789012), NTypeIds::Date , {16, 8 } }, - {NUdf::TUnboxedValuePod((double) 3.4), NTypeIds::Double , {16, 8 } }, - {NUdf::TUnboxedValuePod(123456789012), NTypeIds::Datetime , {16, 8 } }, - {NUdf::TUnboxedValuePod(123456789012), NTypeIds::Timestamp , {16, 8 } }, - {NUdf::TUnboxedValuePod(123456789012), NTypeIds::Interval , {16, 8 } }, - {NUdf::TUnboxedValuePod(decimalVal ), NTypeIds::Decimal , {16, 16} }, - {NUdf::TUnboxedValuePod::Embedded("12charecters"), NTypeIds::String , {16, 12 } }, - {NUdf::TUnboxedValuePod::Embedded("foooo"), NTypeIds::String , {16, 8 } }, - {NUdf::TUnboxedValuePod::Embedded("FOOD!"), NTypeIds::Utf8 , {16, 8 } }, - {NUdf::TUnboxedValuePod::Embedded("{j:0}"), NTypeIds::Json , {16, 8 } }, - {NUdf::TUnboxedValuePod::Embedded("{y:0}"), NTypeIds::Yson , {16, 8 } }, - {containsLongString , NTypeIds::String, {16 + pattern.size(), pattern.size()}} - }; - - for (auto& testCase: cases) { - auto sizes = GetUnboxedValueSizeForTests(testCase.Value, testCase.Type); - UNIT_ASSERT_EQUAL_C(sizes, testCase.ExpectedSizes, "Wrong size for type " << NScheme::TypeName(testCase.Type)); - } - } - - Y_UNIT_TEST(ArrowToUnboxedValueConverter) { - TVector<TDataRow> rows = TestRows(); - std::shared_ptr<arrow::RecordBatch> batch = VectorToBatch(rows); - NKikimr::NMiniKQL::TScopedAlloc alloc; - TMemoryUsageInfo memInfo(""); - THolderFactory factory(alloc.Ref(), memInfo); - - TKqpScanComputeContext::TScanData scanData({}, TTableRange({}), rows.front().Columns(), {}, {}); - - scanData.AddRows(*batch, {}, factory); - - for (auto& row: rows) { - auto result_row = scanData.TakeRow(); - UNIT_ASSERT_EQUAL(result_row.GetElement(0 ).Get<bool >(), row.Bool ); - UNIT_ASSERT_EQUAL(result_row.GetElement(1 ).Get<i8 >(), row.Int8 ); - UNIT_ASSERT_EQUAL(result_row.GetElement(2 ).Get<i16 >(), row.Int16 ); - UNIT_ASSERT_EQUAL(result_row.GetElement(3 ).Get<i32 >(), row.Int32 ); - UNIT_ASSERT_EQUAL(result_row.GetElement(4 ).Get<i64 >(), row.Int64 ); - UNIT_ASSERT_EQUAL(result_row.GetElement(5 ).Get<ui8 >(), row.UInt8 ); - UNIT_ASSERT_EQUAL(result_row.GetElement(6 ).Get<ui16 >(), row.UInt16 ); - UNIT_ASSERT_EQUAL(result_row.GetElement(7 ).Get<ui32 >(), row.UInt32 ); - UNIT_ASSERT_EQUAL(result_row.GetElement(8 ).Get<ui64 >(), row.UInt64 ); - UNIT_ASSERT_EQUAL(result_row.GetElement(9 ).Get<float >(), row.Float32); - UNIT_ASSERT_EQUAL(result_row.GetElement(10).Get<double>(), row.Float64); - auto tmpString = result_row.GetElement(11); - UNIT_ASSERT_EQUAL(TString(tmpString.AsStringRef().Data()), row.String); - auto tmpUtf8 = result_row.GetElement(12); - UNIT_ASSERT_EQUAL(TString(tmpUtf8.AsStringRef().Data()), row.Utf8); - auto tmpJson = result_row.GetElement(13); - UNIT_ASSERT_EQUAL(TString(tmpJson.AsStringRef().Data()), row.Json); - auto tmpYson = result_row.GetElement(14); - UNIT_ASSERT_EQUAL(TString(tmpYson.AsStringRef().Data()), row.Yson); - UNIT_ASSERT_EQUAL(result_row.GetElement(15).Get<i32 >(), row.Date ); - UNIT_ASSERT_EQUAL(result_row.GetElement(16).Get<i64 >(), row.Datetime ); - UNIT_ASSERT_EQUAL(result_row.GetElement(17).Get<i64 >(), row.Timestamp); - UNIT_ASSERT_EQUAL(result_row.GetElement(18).Get<i64 >(), row.Interval ); - UNIT_ASSERT_EQUAL(result_row.GetElement(19).GetInt128(), row.Decimal ); - } - - UNIT_ASSERT(scanData.IsEmpty()); - - scanData.Clear(); - } - - Y_UNIT_TEST(EmptyColumns) { - NKikimr::NMiniKQL::TScopedAlloc alloc; - TMemoryUsageInfo memInfo(""); - THolderFactory factory(alloc.Ref(), memInfo); - - TKqpScanComputeContext::TScanData scanData({}, TTableRange({}), {}, {}, {}); - TVector<TOwnedCellVec> emptyBatch(1000); - auto bytes = scanData.AddRows(emptyBatch, {}, factory); - UNIT_ASSERT(bytes > 0); - - for (const auto& row: emptyBatch) { - Y_UNUSED(row); - UNIT_ASSERT(!scanData.IsEmpty()); - auto item = scanData.TakeRow(); - UNIT_ASSERT(item.GetListLength() == 0); - } - UNIT_ASSERT(scanData.IsEmpty()); - } - - Y_UNIT_TEST(EmptyColumnsAndNonEmptyArrowBatch) { -NKikimr::NMiniKQL::TScopedAlloc alloc; - TMemoryUsageInfo memInfo(""); - THolderFactory factory(alloc.Ref(), memInfo); - TKqpScanComputeContext::TScanData scanData({}, TTableRange({}), {}, {}, {}); - - TVector<TDataRow> rows = TestRows(); - std::shared_ptr<arrow::RecordBatch> anotherEmptyBatch = VectorToBatch(rows); - - auto bytes = scanData.AddRows(*anotherEmptyBatch, {}, factory); - UNIT_ASSERT(bytes > 0); - for (const auto& row: rows) { - Y_UNUSED(row); - UNIT_ASSERT(!scanData.IsEmpty()); - auto item = scanData.TakeRow(); - UNIT_ASSERT(item.GetListLength() == 0); - } - UNIT_ASSERT(scanData.IsEmpty()); - } -} - +#include <library/cpp/testing/unittest/registar.h> + +namespace NKikimr::NMiniKQL { + +namespace { +namespace NTypeIds = NScheme::NTypeIds; +using TTypeId = NScheme::TTypeId; + +struct TDataRow { + TSmallVec<TKqpComputeContextBase::TColumn> Columns() { + return { + {0, NTypeIds::Bool}, + {1, NTypeIds::Int8}, + {2, NTypeIds::Int16}, + {3, NTypeIds::Int32}, + {4, NTypeIds::Int64}, + {5, NTypeIds::Uint8}, + {6, NTypeIds::Uint16}, + {7, NTypeIds::Uint32}, + {8, NTypeIds::Uint64}, + {9, NTypeIds::Float}, + {10, NTypeIds::Double}, + {11, NTypeIds::String}, + {12, NTypeIds::Utf8}, + {13, NTypeIds::Json}, + {14, NTypeIds::Yson}, + {15, NTypeIds::Date}, + {16, NTypeIds::Datetime}, + {17, NTypeIds::Timestamp}, + {18, NTypeIds::Interval}, + {19, NTypeIds::Decimal}, + }; + } + + bool Bool; + i8 Int8; + i16 Int16; + i32 Int32; + i64 Int64; + ui8 UInt8; + ui16 UInt16; + ui32 UInt32; + ui64 UInt64; + float Float32; + double Float64; + TString String; + TString Utf8; + TString Json; + TString Yson; + i32 Date; + i64 Datetime; + i64 Timestamp; + i64 Interval; + NYql::NDecimal::TInt128 Decimal; + + static std::shared_ptr<arrow::Schema> MakeArrowSchema() { + std::vector<std::shared_ptr<arrow::Field>> fields = { + arrow::field("bool", arrow::boolean()), + arrow::field("i8", arrow::int8()), + arrow::field("i16", arrow::int16()), + arrow::field("i32", arrow::int32()), + arrow::field("i64", arrow::int64()), + arrow::field("ui8", arrow::uint8()), + arrow::field("ui16", arrow::uint16()), + arrow::field("ui32", arrow::uint32()), + arrow::field("ui64", arrow::uint64()), + arrow::field("f32", arrow::float32()), + arrow::field("f64", arrow::float64()), + arrow::field("string", arrow::utf8()), + arrow::field("utf8", arrow::utf8()), + arrow::field("json", arrow::binary()), + arrow::field("yson", arrow::binary()), + arrow::field("date", arrow::date32()), + arrow::field("datetime", arrow::timestamp(arrow::TimeUnit::TimeUnit::SECOND)), + arrow::field("ts", arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO)), + arrow::field("ival", arrow::duration(arrow::TimeUnit::TimeUnit::MICRO)), + arrow::field("dec", arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE)), + }; + + return std::make_shared<arrow::Schema>(fields); + } +}; + +std::shared_ptr<arrow::RecordBatch> VectorToBatch(const std::vector<struct TDataRow>& rows) { + TString err; + std::unique_ptr<arrow::RecordBatchBuilder> batchBuilder = nullptr; + std::shared_ptr<arrow::RecordBatch> batch; + auto result = arrow::RecordBatchBuilder::Make(rows.front().MakeArrowSchema(), arrow::default_memory_pool(), &batchBuilder); + UNIT_ASSERT(result.ok()); + + for (const TDataRow& row : rows) { + auto result0 = batchBuilder->GetFieldAs<arrow::BooleanBuilder >(0 )->Append(row.Bool ); + UNIT_ASSERT(result.ok()); + auto result1 = batchBuilder->GetFieldAs<arrow::Int8Builder >(1 )->Append(row.Int8 ); + UNIT_ASSERT(result.ok()); + auto result2 = batchBuilder->GetFieldAs<arrow::Int16Builder >(2 )->Append(row.Int16 ); + UNIT_ASSERT(result.ok()); + auto result3 = batchBuilder->GetFieldAs<arrow::Int32Builder >(3 )->Append(row.Int32 ); + UNIT_ASSERT(result.ok()); + auto result4 = batchBuilder->GetFieldAs<arrow::Int64Builder >(4 )->Append(row.Int64 ); + UNIT_ASSERT(result.ok()); + auto result5 = batchBuilder->GetFieldAs<arrow::UInt8Builder >(5 )->Append(row.UInt8 ); + UNIT_ASSERT(result.ok()); + auto result6 = batchBuilder->GetFieldAs<arrow::UInt16Builder >(6 )->Append(row.UInt16 ); + UNIT_ASSERT(result.ok()); + auto result7 = batchBuilder->GetFieldAs<arrow::UInt32Builder >(7 )->Append(row.UInt32 ); + UNIT_ASSERT(result.ok()); + auto result8 = batchBuilder->GetFieldAs<arrow::UInt64Builder >(8 )->Append(row.UInt64 ); + UNIT_ASSERT(result.ok()); + auto result9 = batchBuilder->GetFieldAs<arrow::FloatBuilder >(9 )->Append(row.Float32); + UNIT_ASSERT(result.ok()); + auto result10 = batchBuilder->GetFieldAs<arrow::DoubleBuilder >(10)->Append(row.Float64); + UNIT_ASSERT(result.ok()); + auto result11 = batchBuilder->GetFieldAs<arrow::StringBuilder >(11)->Append(row.String.data(), row.String.size()); + UNIT_ASSERT(result.ok()); + auto result12 = batchBuilder->GetFieldAs<arrow::StringBuilder >(12)->Append(row.Utf8.data(), row.Utf8.size()); + UNIT_ASSERT(result.ok()); + auto result13 = batchBuilder->GetFieldAs<arrow::BinaryBuilder >(13)->Append(row.Json.data(), row.Json.size()); + UNIT_ASSERT(result.ok()); + auto result14 = batchBuilder->GetFieldAs<arrow::BinaryBuilder >(14)->Append(row.Yson.data(), row.Yson.size()); + UNIT_ASSERT(result.ok()); + auto result15 = batchBuilder->GetFieldAs<arrow::Date32Builder >(15)->Append(row.Date); + UNIT_ASSERT(result.ok()); + auto result16 = batchBuilder->GetFieldAs<arrow::TimestampBuilder >(16)->Append(row.Datetime); + UNIT_ASSERT(result.ok()); + auto result17 = batchBuilder->GetFieldAs<arrow::TimestampBuilder >(17)->Append(row.Timestamp); + UNIT_ASSERT(result.ok()); + auto result18 = batchBuilder->GetFieldAs<arrow::DurationBuilder >(18)->Append(row.Interval); + UNIT_ASSERT(result.ok()); + auto result19 = batchBuilder->GetFieldAs<arrow::Decimal128Builder>(19)->Append(reinterpret_cast<const char*>(&row.Decimal)); + UNIT_ASSERT(result.ok()); + } + + auto resultFlush = batchBuilder->Flush(&batch); + UNIT_ASSERT(resultFlush.ok()); + return batch; +} + +TVector<TDataRow> TestRows() { + TVector<TDataRow> rows = { + {false, -1, -1, -1, -1, 1, 1, 1, 1, -1.0f, -1.0, "s1" , "u1" , "{j:1}", "{y:1}", 0, 0, 0, 0, 111}, + {false, 2, 2, 2, 2, 2, 2, 2, 2, 2.0f, 2.0, "s2" , "u2" , "{j:2}", "{y:2}", 0, 0, 0, 0, 222}, + {false, -3, -3, -3, -3, 3, 3, 3, 3, -3.0f, -3.0, "s3" , "u3" , "{j:3}", "{y:3}", 0, 0, 0, 0, 333}, + {false, -4, -4, -4, -4, 4, 4, 4, 4, 4.0f, 4.0, "s4" , "u4" , "{j:4}", "{y:4}", 0, 0, 0, 0, 444}, + {false, -5, -5, -5, -5, 5, 5, 5, 5, 5.0f, 5.0, "long5long5long5long5long5", "utflong5utflong5utflong5", "{j:5}", "{y:5}", 0, 0, 0, 0, 555}, + }; + return rows; +} + +} + +Y_UNIT_TEST_SUITE(TKqpScanData) { + + Y_UNIT_TEST(UnboxedValueSize) { + NKikimr::NMiniKQL::TScopedAlloc alloc; + namespace NTypeIds = NScheme::NTypeIds; + struct TTestCase { + NUdf::TUnboxedValue Value; + NScheme::TTypeId Type; + std::pair<ui64, ui64> ExpectedSizes; + }; + TString pattern = "This string has 26 symbols"; + NUdf::TStringValue str(pattern.size()); + std::memcpy(str.Data(), pattern.data(), pattern.size()); + NUdf::TUnboxedValue containsLongString(NUdf::TUnboxedValuePod(std::move(str))); + NYql::NDecimal::TInt128 decimalVal = 123456789012; + TVector<TTestCase> cases = { + {NUdf::TUnboxedValuePod( ), NTypeIds::Bool , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Int32 , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Uint32 , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Int64 , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Uint64 , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Double , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Float , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::String , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Utf8 , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Yson , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Json , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Decimal , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Date , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Datetime , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Timestamp , {16, 8 } }, + {NUdf::TUnboxedValuePod( ), NTypeIds::Interval , {16, 8 } }, + {NUdf::TUnboxedValuePod(true ), NTypeIds::Bool , {16, 8 } }, + {NUdf::TUnboxedValuePod((i8) 1 ), NTypeIds::Int8 , {16, 8 } }, + {NUdf::TUnboxedValuePod((i16) 2 ), NTypeIds::Int16 , {16, 8 } }, + {NUdf::TUnboxedValuePod((i32) 3 ), NTypeIds::Int32 , {16, 8 } }, + {NUdf::TUnboxedValuePod((i64) 4 ), NTypeIds::Int64 , {16, 8 } }, + {NUdf::TUnboxedValuePod((ui8) 5 ), NTypeIds::Uint8 , {16, 8 } }, + {NUdf::TUnboxedValuePod((ui16) 6 ), NTypeIds::Uint16 , {16, 8 } }, + {NUdf::TUnboxedValuePod((ui32) 7 ), NTypeIds::Uint32 , {16, 8 } }, + {NUdf::TUnboxedValuePod((ui64) 8 ), NTypeIds::Uint64 , {16, 8 } }, + {NUdf::TUnboxedValuePod((float) 1.2 ), NTypeIds::Float , {16, 8 } }, + {NUdf::TUnboxedValuePod(123456789012), NTypeIds::Date , {16, 8 } }, + {NUdf::TUnboxedValuePod((double) 3.4), NTypeIds::Double , {16, 8 } }, + {NUdf::TUnboxedValuePod(123456789012), NTypeIds::Datetime , {16, 8 } }, + {NUdf::TUnboxedValuePod(123456789012), NTypeIds::Timestamp , {16, 8 } }, + {NUdf::TUnboxedValuePod(123456789012), NTypeIds::Interval , {16, 8 } }, + {NUdf::TUnboxedValuePod(decimalVal ), NTypeIds::Decimal , {16, 16} }, + {NUdf::TUnboxedValuePod::Embedded("12charecters"), NTypeIds::String , {16, 12 } }, + {NUdf::TUnboxedValuePod::Embedded("foooo"), NTypeIds::String , {16, 8 } }, + {NUdf::TUnboxedValuePod::Embedded("FOOD!"), NTypeIds::Utf8 , {16, 8 } }, + {NUdf::TUnboxedValuePod::Embedded("{j:0}"), NTypeIds::Json , {16, 8 } }, + {NUdf::TUnboxedValuePod::Embedded("{y:0}"), NTypeIds::Yson , {16, 8 } }, + {containsLongString , NTypeIds::String, {16 + pattern.size(), pattern.size()}} + }; + + for (auto& testCase: cases) { + auto sizes = GetUnboxedValueSizeForTests(testCase.Value, testCase.Type); + UNIT_ASSERT_EQUAL_C(sizes, testCase.ExpectedSizes, "Wrong size for type " << NScheme::TypeName(testCase.Type)); + } + } + + Y_UNIT_TEST(ArrowToUnboxedValueConverter) { + TVector<TDataRow> rows = TestRows(); + std::shared_ptr<arrow::RecordBatch> batch = VectorToBatch(rows); + NKikimr::NMiniKQL::TScopedAlloc alloc; + TMemoryUsageInfo memInfo(""); + THolderFactory factory(alloc.Ref(), memInfo); + + TKqpScanComputeContext::TScanData scanData({}, TTableRange({}), rows.front().Columns(), {}, {}); + + scanData.AddRows(*batch, {}, factory); + + for (auto& row: rows) { + auto result_row = scanData.TakeRow(); + UNIT_ASSERT_EQUAL(result_row.GetElement(0 ).Get<bool >(), row.Bool ); + UNIT_ASSERT_EQUAL(result_row.GetElement(1 ).Get<i8 >(), row.Int8 ); + UNIT_ASSERT_EQUAL(result_row.GetElement(2 ).Get<i16 >(), row.Int16 ); + UNIT_ASSERT_EQUAL(result_row.GetElement(3 ).Get<i32 >(), row.Int32 ); + UNIT_ASSERT_EQUAL(result_row.GetElement(4 ).Get<i64 >(), row.Int64 ); + UNIT_ASSERT_EQUAL(result_row.GetElement(5 ).Get<ui8 >(), row.UInt8 ); + UNIT_ASSERT_EQUAL(result_row.GetElement(6 ).Get<ui16 >(), row.UInt16 ); + UNIT_ASSERT_EQUAL(result_row.GetElement(7 ).Get<ui32 >(), row.UInt32 ); + UNIT_ASSERT_EQUAL(result_row.GetElement(8 ).Get<ui64 >(), row.UInt64 ); + UNIT_ASSERT_EQUAL(result_row.GetElement(9 ).Get<float >(), row.Float32); + UNIT_ASSERT_EQUAL(result_row.GetElement(10).Get<double>(), row.Float64); + auto tmpString = result_row.GetElement(11); + UNIT_ASSERT_EQUAL(TString(tmpString.AsStringRef().Data()), row.String); + auto tmpUtf8 = result_row.GetElement(12); + UNIT_ASSERT_EQUAL(TString(tmpUtf8.AsStringRef().Data()), row.Utf8); + auto tmpJson = result_row.GetElement(13); + UNIT_ASSERT_EQUAL(TString(tmpJson.AsStringRef().Data()), row.Json); + auto tmpYson = result_row.GetElement(14); + UNIT_ASSERT_EQUAL(TString(tmpYson.AsStringRef().Data()), row.Yson); + UNIT_ASSERT_EQUAL(result_row.GetElement(15).Get<i32 >(), row.Date ); + UNIT_ASSERT_EQUAL(result_row.GetElement(16).Get<i64 >(), row.Datetime ); + UNIT_ASSERT_EQUAL(result_row.GetElement(17).Get<i64 >(), row.Timestamp); + UNIT_ASSERT_EQUAL(result_row.GetElement(18).Get<i64 >(), row.Interval ); + UNIT_ASSERT_EQUAL(result_row.GetElement(19).GetInt128(), row.Decimal ); + } + + UNIT_ASSERT(scanData.IsEmpty()); + + scanData.Clear(); + } + + Y_UNIT_TEST(EmptyColumns) { + NKikimr::NMiniKQL::TScopedAlloc alloc; + TMemoryUsageInfo memInfo(""); + THolderFactory factory(alloc.Ref(), memInfo); + + TKqpScanComputeContext::TScanData scanData({}, TTableRange({}), {}, {}, {}); + TVector<TOwnedCellVec> emptyBatch(1000); + auto bytes = scanData.AddRows(emptyBatch, {}, factory); + UNIT_ASSERT(bytes > 0); + + for (const auto& row: emptyBatch) { + Y_UNUSED(row); + UNIT_ASSERT(!scanData.IsEmpty()); + auto item = scanData.TakeRow(); + UNIT_ASSERT(item.GetListLength() == 0); + } + UNIT_ASSERT(scanData.IsEmpty()); + } + + Y_UNIT_TEST(EmptyColumnsAndNonEmptyArrowBatch) { +NKikimr::NMiniKQL::TScopedAlloc alloc; + TMemoryUsageInfo memInfo(""); + THolderFactory factory(alloc.Ref(), memInfo); + TKqpScanComputeContext::TScanData scanData({}, TTableRange({}), {}, {}, {}); + + TVector<TDataRow> rows = TestRows(); + std::shared_ptr<arrow::RecordBatch> anotherEmptyBatch = VectorToBatch(rows); + + auto bytes = scanData.AddRows(*anotherEmptyBatch, {}, factory); + UNIT_ASSERT(bytes > 0); + for (const auto& row: rows) { + Y_UNUSED(row); + UNIT_ASSERT(!scanData.IsEmpty()); + auto item = scanData.TakeRow(); + UNIT_ASSERT(item.GetListLength() == 0); + } + UNIT_ASSERT(scanData.IsEmpty()); + } +} + } // namespace NKikimr::NMiniKQL diff --git a/ydb/core/kqp/runtime/kqp_transport.cpp b/ydb/core/kqp/runtime/kqp_transport.cpp index 1b593ef546c..8c1b3262d0d 100644 --- a/ydb/core/kqp/runtime/kqp_transport.cpp +++ b/ydb/core/kqp/runtime/kqp_transport.cpp @@ -86,26 +86,26 @@ void TKqpProtoBuilder::BuildValue(const TVector<NDqProto::TData>& data, const NK auto mkqlType = ImportTypeFromProto(result->GetType(), *TypeEnv); TUnboxedValueVector buffer; - auto transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; - if (!data.empty()) { - switch (data.front().GetTransportVersion()) { - case 10000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_YSON_1_0; - break; - } - case 20000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0; - break; - } - case 30000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_ARROW_1_0; - break; - } - default: - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; - } - } - NDq::TDqDataSerializer dataSerializer(*TypeEnv, *HolderFactory, transportVersion); + auto transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; + if (!data.empty()) { + switch (data.front().GetTransportVersion()) { + case 10000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_YSON_1_0; + break; + } + case 20000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0; + break; + } + case 30000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_ARROW_1_0; + break; + } + default: + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; + } + } + NDq::TDqDataSerializer dataSerializer(*TypeEnv, *HolderFactory, transportVersion); for (auto& part : data) { dataSerializer.Deserialize(part, mkqlType, buffer); } @@ -143,27 +143,27 @@ void TKqpProtoBuilder::BuildStream(const TVector<NDqProto::TData>& data, const N guard = MakeHolder<TGuard<TScopedAlloc>>(*Alloc); } - auto transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; - if (!data.empty()) { - switch (data.front().GetTransportVersion()) { - case 10000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_YSON_1_0; - break; - } - case 20000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0; - break; - } - case 30000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_ARROW_1_0; - break; - } - default: - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; - } - } - NDq::TDqDataSerializer dataSerializer(*TypeEnv, *HolderFactory, transportVersion); - + auto transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; + if (!data.empty()) { + switch (data.front().GetTransportVersion()) { + case 10000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_YSON_1_0; + break; + } + case 20000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0; + break; + } + case 30000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_ARROW_1_0; + break; + } + default: + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; + } + } + NDq::TDqDataSerializer dataSerializer(*TypeEnv, *HolderFactory, transportVersion); + if (dstRowType) { YQL_ENSURE(dstRowType->GetKind() == NKikimrMiniKQL::Struct); newRowType->CopyFrom(*dstRowType); @@ -281,26 +281,26 @@ Ydb::ResultSet TKqpProtoBuilder::BuildYdbResultSet(const TVector<NDqProto::TData guard = MakeHolder<TGuard<TScopedAlloc>>(*Alloc); } - auto transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; - if (!data.empty()) { - switch (data.front().GetTransportVersion()) { - case 10000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_YSON_1_0; - break; - } - case 20000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0; - break; - } - case 30000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_ARROW_1_0; - break; - } - default: - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; - } - } - NDq::TDqDataSerializer dataSerializer(*TypeEnv, *HolderFactory, transportVersion); + auto transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; + if (!data.empty()) { + switch (data.front().GetTransportVersion()) { + case 10000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_YSON_1_0; + break; + } + case 20000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0; + break; + } + case 30000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_ARROW_1_0; + break; + } + default: + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; + } + } + NDq::TDqDataSerializer dataSerializer(*TypeEnv, *HolderFactory, transportVersion); for (auto& part : data) { if (part.GetRows()) { diff --git a/ydb/core/kqp/runtime/ut/ya.make b/ydb/core/kqp/runtime/ut/ya.make index e2ebaee9b1e..ea3f10523c2 100644 --- a/ydb/core/kqp/runtime/ut/ya.make +++ b/ydb/core/kqp/runtime/ut/ya.make @@ -9,12 +9,12 @@ IF (SANITIZER_TYPE OR WITH_VALGRIND) ENDIF() SRCS( - # kqp_spilling_file_ut.cpp - kqp_scan_data_ut.cpp + # kqp_spilling_file_ut.cpp + kqp_scan_data_ut.cpp ) YQL_LAST_ABI_VERSION() - + PEERDIR( library/cpp/testing/unittest ydb/core/testlib/basics diff --git a/ydb/core/kqp/ut/kqp_arrow_in_channels_ut.cpp b/ydb/core/kqp/ut/kqp_arrow_in_channels_ut.cpp index ed50b29eb0b..4209f59e66c 100644 --- a/ydb/core/kqp/ut/kqp_arrow_in_channels_ut.cpp +++ b/ydb/core/kqp/ut/kqp_arrow_in_channels_ut.cpp @@ -1,233 +1,233 @@ #include <ydb/core/kqp/ut/common/kqp_ut_common.h> - -#include <util/generic/size_literals.h> - -namespace NKikimr { -namespace NKqp { - -using namespace NYdb; -using namespace NYdb::NTable; - -namespace { - -TKikimrRunner RunnerWithArrowInChannels() { - NKikimrConfig::TFeatureFlags featureFlags; - featureFlags.SetEnableArrowFormatInChannels(true); - - return TKikimrRunner{featureFlags}; -} - -void InsertAllColumnsAndCheckSelectAll(TKikimrRunner* runner) { - auto db = runner->GetTableClient(); - auto session = db.CreateSession().GetValueSync().GetSession(); - - auto createResult = session.ExecuteSchemeQuery(R"( - --!syntax_v1 - CREATE TABLE `/Root/Tmp` ( - Key Uint64, - BoolValue Bool, - Int32Value Int32, - Uint32Value Uint32, - Int64Value Int64, - Uint64Value Uint64, - FloatValue Float, - DoubleValue Double, - StringValue String, - Utf8Value Utf8, - DateValue Date, - DatetimeValue Datetime, - TimestampValue Timestamp, - IntervalValue Interval, - DecimalValue Decimal(22,9), - JsonValue Json, - YsonValue Yson, - JsonDocumentValue JsonDocument, - DyNumberValue DyNumber, - PRIMARY KEY (Key) - ); - )").GetValueSync(); - UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); - - auto insertResult = session.ExecuteDataQuery(R"( - --!syntax_v1 - INSERT INTO `/Root/Tmp` (Key, BoolValue, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue) VALUES - (42, true, -1, 1, -2, 2, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15")); - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); - - auto it = db.StreamExecuteScanQuery("SELECT Key, BoolValue, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue FROM `/Root/Tmp`").GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - auto streamPart = it.ReadNext().GetValueSync(); - UNIT_ASSERT_C(streamPart.IsSuccess(), streamPart.GetIssues().ToString()); - auto resultSet = streamPart.ExtractResultSet(); - auto columns = resultSet.GetColumnsMeta(); - UNIT_ASSERT_C(columns.size() == 19, "Wrong columns count"); - NYdb::TResultSetParser parser(resultSet); - UNIT_ASSERT_C(parser.TryNextRow(), "Row is missing"); - UNIT_ASSERT(*parser.ColumnParser(0).GetOptionalUint64().Get() == 42); - UNIT_ASSERT(*parser.ColumnParser(1).GetOptionalBool().Get() == true); - UNIT_ASSERT(*parser.ColumnParser(2).GetOptionalInt32().Get() == -1); - UNIT_ASSERT(*parser.ColumnParser(3).GetOptionalUint32().Get() == 1); - UNIT_ASSERT(*parser.ColumnParser(4).GetOptionalInt64().Get() == -2); - UNIT_ASSERT(*parser.ColumnParser(5).GetOptionalUint64().Get() == 2); - UNIT_ASSERT(*parser.ColumnParser(6).GetOptionalFloat().Get() == 3.0); - UNIT_ASSERT(*parser.ColumnParser(7).GetOptionalDouble().Get() == 4.0); - UNIT_ASSERT(*parser.ColumnParser(8).GetOptionalString().Get() == TString("five")); - UNIT_ASSERT(*parser.ColumnParser(9).GetOptionalUtf8().Get() == TString("six")); - UNIT_ASSERT(*parser.ColumnParser(10).GetOptionalDate().Get() == TInstant::ParseIso8601("2007-07-07")); - UNIT_ASSERT(*parser.ColumnParser(11).GetOptionalDatetime().Get() == TInstant::ParseIso8601("2008-08-08T08:08:08Z")); - UNIT_ASSERT(*parser.ColumnParser(12).GetOptionalTimestamp().Get() == TInstant::ParseIso8601("2009-09-09T09:09:09.09Z")); - Cerr << TInstant::Days(10).MicroSeconds() << Endl; - UNIT_ASSERT(*parser.ColumnParser(13).GetOptionalInterval().Get() == TInstant::Days(10).MicroSeconds()); - UNIT_ASSERT(parser.ColumnParser(14).GetOptionalDecimal().Get()->ToString() == TString("11.11")); - UNIT_ASSERT(*parser.ColumnParser(15).GetOptionalJson().Get() == TString("[12]")); - UNIT_ASSERT(*parser.ColumnParser(16).GetOptionalYson().Get() == TString("[13]")); - UNIT_ASSERT(*parser.ColumnParser(17).GetOptionalJsonDocument().Get() == TString("[14]")); - UNIT_ASSERT(*parser.ColumnParser(18).GetOptionalDyNumber().Get() == TString(".1515e2")); - streamPart = it.ReadNext().GetValueSync(); - UNIT_ASSERT_C(streamPart.EOS(), streamPart.GetIssues().ToString()); -} - -} - -Y_UNIT_TEST_SUITE(KqpScanArrowInChanels) { - Y_UNIT_TEST(AggregateCountStar) { - auto kikimr = RunnerWithArrowInChannels(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery("SELECT COUNT(*) FROM `/Root/EightShard`").GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - CompareYson(R"([[24u]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AllTypesColumns) { - auto kikimr = RunnerWithArrowInChannels(); - InsertAllColumnsAndCheckSelectAll(&kikimr); - } - - Y_UNIT_TEST(SingleKey) { - auto kikimr = RunnerWithArrowInChannels(); - auto db = kikimr.GetTableClient(); - - auto params = db.GetParamsBuilder() - .AddParam("$key") - .Uint64(202) - .Build() - .Build(); - - auto it = db.StreamExecuteScanQuery(R"( - DECLARE $key AS Uint64; - - SELECT * FROM `/Root/EightShard` WHERE Key = $key; - )", params).GetValueSync(); - - UNIT_ASSERT(it.IsSuccess()); - - CompareYson(R"([ - [[1];[202u];["Value2"]] - ])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateByColumn) { - auto kikimr = RunnerWithArrowInChannels(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery(R"( - SELECT Text, SUM(Key) AS Total FROM `/Root/EightShard` - GROUP BY Text - ORDER BY Total DESC; - )").GetValueSync(); - - UNIT_ASSERT(it.IsSuccess()); - - CompareYson(R"([ - [["Value3"];[3624u]]; - [["Value2"];[3616u]]; - [["Value1"];[3608u]] - ])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateNoColumn) { - auto kikimr = RunnerWithArrowInChannels(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery(R"( - SELECT SUM(Data), AVG(Data), COUNT(*), MAX(Data), MIN(Data), SUM(Data * 3 + Key * 2) as foo - FROM `/Root/EightShard` - WHERE Key > 300 - )").GetValueSync(); - - UNIT_ASSERT(it.IsSuccess()); - - CompareYson(R"([[[36];[2.];18u;[3];[1];[19980u]]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateNoColumnNoRemaps) { - auto kikimr = RunnerWithArrowInChannels(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery(R"( - SELECT SUM(Data), AVG(Data), COUNT(*) - FROM `/Root/EightShard` - WHERE Key > 300 - )").GetValueSync(); - - UNIT_ASSERT(it.IsSuccess()); - - CompareYson(R"([[[36];[2.];18u]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateWithFunction) { - auto kikimr = RunnerWithArrowInChannels(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery(R"( - SELECT (SUM(Data) * 100) / (MIN(Data) + 10) - FROM `/Root/EightShard` - )").GetValueSync(); - - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - CompareYson(R"([[[436]]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateEmptySum) { - auto kikimr = RunnerWithArrowInChannels(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery("SELECT SUM(Data) FROM `/Root/EightShard` WHERE Key < 10").GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - CompareYson(R"([[#]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(JoinWithParams) { - auto kikimr = RunnerWithArrowInChannels(); - auto db = kikimr.GetTableClient(); - auto params = TParamsBuilder().AddParam("$in") - .BeginList() - .AddListItem().BeginStruct().AddMember("key").Uint64(1).EndStruct() - .EndList() - .Build().Build(); - // table join params - auto query1 = R"( - declare $in as List<Struct<key: UInt64>>; - select l.Key, l.Value - from `/Root/KeyValue` as l join AS_TABLE($in) as r on l.Key = r.key - )"; - // params join table - auto query2 = R"( - declare $in as List<Struct<key: UInt64>>; - select r.Key, r.Value - from AS_TABLE($in) as l join `/Root/KeyValue` as r on l.key = r.Key - )"; - for (auto& query : {query1, query2}) { - auto it = db.StreamExecuteScanQuery(query, params).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - CompareYson(R"([[[1u];["One"]]])", StreamResultToYson(it)); - } - } -} // Test suite - -} // NKqp + +#include <util/generic/size_literals.h> + +namespace NKikimr { +namespace NKqp { + +using namespace NYdb; +using namespace NYdb::NTable; + +namespace { + +TKikimrRunner RunnerWithArrowInChannels() { + NKikimrConfig::TFeatureFlags featureFlags; + featureFlags.SetEnableArrowFormatInChannels(true); + + return TKikimrRunner{featureFlags}; +} + +void InsertAllColumnsAndCheckSelectAll(TKikimrRunner* runner) { + auto db = runner->GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + auto createResult = session.ExecuteSchemeQuery(R"( + --!syntax_v1 + CREATE TABLE `/Root/Tmp` ( + Key Uint64, + BoolValue Bool, + Int32Value Int32, + Uint32Value Uint32, + Int64Value Int64, + Uint64Value Uint64, + FloatValue Float, + DoubleValue Double, + StringValue String, + Utf8Value Utf8, + DateValue Date, + DatetimeValue Datetime, + TimestampValue Timestamp, + IntervalValue Interval, + DecimalValue Decimal(22,9), + JsonValue Json, + YsonValue Yson, + JsonDocumentValue JsonDocument, + DyNumberValue DyNumber, + PRIMARY KEY (Key) + ); + )").GetValueSync(); + UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); + + auto insertResult = session.ExecuteDataQuery(R"( + --!syntax_v1 + INSERT INTO `/Root/Tmp` (Key, BoolValue, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue) VALUES + (42, true, -1, 1, -2, 2, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15")); + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); + + auto it = db.StreamExecuteScanQuery("SELECT Key, BoolValue, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue FROM `/Root/Tmp`").GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + auto streamPart = it.ReadNext().GetValueSync(); + UNIT_ASSERT_C(streamPart.IsSuccess(), streamPart.GetIssues().ToString()); + auto resultSet = streamPart.ExtractResultSet(); + auto columns = resultSet.GetColumnsMeta(); + UNIT_ASSERT_C(columns.size() == 19, "Wrong columns count"); + NYdb::TResultSetParser parser(resultSet); + UNIT_ASSERT_C(parser.TryNextRow(), "Row is missing"); + UNIT_ASSERT(*parser.ColumnParser(0).GetOptionalUint64().Get() == 42); + UNIT_ASSERT(*parser.ColumnParser(1).GetOptionalBool().Get() == true); + UNIT_ASSERT(*parser.ColumnParser(2).GetOptionalInt32().Get() == -1); + UNIT_ASSERT(*parser.ColumnParser(3).GetOptionalUint32().Get() == 1); + UNIT_ASSERT(*parser.ColumnParser(4).GetOptionalInt64().Get() == -2); + UNIT_ASSERT(*parser.ColumnParser(5).GetOptionalUint64().Get() == 2); + UNIT_ASSERT(*parser.ColumnParser(6).GetOptionalFloat().Get() == 3.0); + UNIT_ASSERT(*parser.ColumnParser(7).GetOptionalDouble().Get() == 4.0); + UNIT_ASSERT(*parser.ColumnParser(8).GetOptionalString().Get() == TString("five")); + UNIT_ASSERT(*parser.ColumnParser(9).GetOptionalUtf8().Get() == TString("six")); + UNIT_ASSERT(*parser.ColumnParser(10).GetOptionalDate().Get() == TInstant::ParseIso8601("2007-07-07")); + UNIT_ASSERT(*parser.ColumnParser(11).GetOptionalDatetime().Get() == TInstant::ParseIso8601("2008-08-08T08:08:08Z")); + UNIT_ASSERT(*parser.ColumnParser(12).GetOptionalTimestamp().Get() == TInstant::ParseIso8601("2009-09-09T09:09:09.09Z")); + Cerr << TInstant::Days(10).MicroSeconds() << Endl; + UNIT_ASSERT(*parser.ColumnParser(13).GetOptionalInterval().Get() == TInstant::Days(10).MicroSeconds()); + UNIT_ASSERT(parser.ColumnParser(14).GetOptionalDecimal().Get()->ToString() == TString("11.11")); + UNIT_ASSERT(*parser.ColumnParser(15).GetOptionalJson().Get() == TString("[12]")); + UNIT_ASSERT(*parser.ColumnParser(16).GetOptionalYson().Get() == TString("[13]")); + UNIT_ASSERT(*parser.ColumnParser(17).GetOptionalJsonDocument().Get() == TString("[14]")); + UNIT_ASSERT(*parser.ColumnParser(18).GetOptionalDyNumber().Get() == TString(".1515e2")); + streamPart = it.ReadNext().GetValueSync(); + UNIT_ASSERT_C(streamPart.EOS(), streamPart.GetIssues().ToString()); +} + +} + +Y_UNIT_TEST_SUITE(KqpScanArrowInChanels) { + Y_UNIT_TEST(AggregateCountStar) { + auto kikimr = RunnerWithArrowInChannels(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery("SELECT COUNT(*) FROM `/Root/EightShard`").GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + CompareYson(R"([[24u]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AllTypesColumns) { + auto kikimr = RunnerWithArrowInChannels(); + InsertAllColumnsAndCheckSelectAll(&kikimr); + } + + Y_UNIT_TEST(SingleKey) { + auto kikimr = RunnerWithArrowInChannels(); + auto db = kikimr.GetTableClient(); + + auto params = db.GetParamsBuilder() + .AddParam("$key") + .Uint64(202) + .Build() + .Build(); + + auto it = db.StreamExecuteScanQuery(R"( + DECLARE $key AS Uint64; + + SELECT * FROM `/Root/EightShard` WHERE Key = $key; + )", params).GetValueSync(); + + UNIT_ASSERT(it.IsSuccess()); + + CompareYson(R"([ + [[1];[202u];["Value2"]] + ])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateByColumn) { + auto kikimr = RunnerWithArrowInChannels(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery(R"( + SELECT Text, SUM(Key) AS Total FROM `/Root/EightShard` + GROUP BY Text + ORDER BY Total DESC; + )").GetValueSync(); + + UNIT_ASSERT(it.IsSuccess()); + + CompareYson(R"([ + [["Value3"];[3624u]]; + [["Value2"];[3616u]]; + [["Value1"];[3608u]] + ])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateNoColumn) { + auto kikimr = RunnerWithArrowInChannels(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery(R"( + SELECT SUM(Data), AVG(Data), COUNT(*), MAX(Data), MIN(Data), SUM(Data * 3 + Key * 2) as foo + FROM `/Root/EightShard` + WHERE Key > 300 + )").GetValueSync(); + + UNIT_ASSERT(it.IsSuccess()); + + CompareYson(R"([[[36];[2.];18u;[3];[1];[19980u]]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateNoColumnNoRemaps) { + auto kikimr = RunnerWithArrowInChannels(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery(R"( + SELECT SUM(Data), AVG(Data), COUNT(*) + FROM `/Root/EightShard` + WHERE Key > 300 + )").GetValueSync(); + + UNIT_ASSERT(it.IsSuccess()); + + CompareYson(R"([[[36];[2.];18u]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateWithFunction) { + auto kikimr = RunnerWithArrowInChannels(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery(R"( + SELECT (SUM(Data) * 100) / (MIN(Data) + 10) + FROM `/Root/EightShard` + )").GetValueSync(); + + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + CompareYson(R"([[[436]]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateEmptySum) { + auto kikimr = RunnerWithArrowInChannels(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery("SELECT SUM(Data) FROM `/Root/EightShard` WHERE Key < 10").GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + CompareYson(R"([[#]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(JoinWithParams) { + auto kikimr = RunnerWithArrowInChannels(); + auto db = kikimr.GetTableClient(); + auto params = TParamsBuilder().AddParam("$in") + .BeginList() + .AddListItem().BeginStruct().AddMember("key").Uint64(1).EndStruct() + .EndList() + .Build().Build(); + // table join params + auto query1 = R"( + declare $in as List<Struct<key: UInt64>>; + select l.Key, l.Value + from `/Root/KeyValue` as l join AS_TABLE($in) as r on l.Key = r.key + )"; + // params join table + auto query2 = R"( + declare $in as List<Struct<key: UInt64>>; + select r.Key, r.Value + from AS_TABLE($in) as l join `/Root/KeyValue` as r on l.key = r.Key + )"; + for (auto& query : {query1, query2}) { + auto it = db.StreamExecuteScanQuery(query, params).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + CompareYson(R"([[[1u];["One"]]])", StreamResultToYson(it)); + } + } +} // Test suite + +} // NKqp } // NKikimr diff --git a/ydb/core/kqp/ut/kqp_types_arrow_ut.cpp b/ydb/core/kqp/ut/kqp_types_arrow_ut.cpp index 664e8b783d7..54d1c525963 100644 --- a/ydb/core/kqp/ut/kqp_types_arrow_ut.cpp +++ b/ydb/core/kqp/ut/kqp_types_arrow_ut.cpp @@ -1,239 +1,239 @@ #include <ydb/core/kqp/ut/common/kqp_ut_common.h> - -#include <util/generic/size_literals.h> - -namespace NKikimr { -namespace NKqp { - -using namespace NYdb; -using namespace NYdb::NTable; - -namespace { - -TKikimrRunner RunnerWithArrowFormatEnabled() { - NKikimrConfig::TFeatureFlags featureFlags; - featureFlags.SetEnableArrowFormatAtDatashard(true); - - return TKikimrRunner{featureFlags}; -} - -void InsertAllColumnsAndCheckSelectAll(TKikimrRunner* runner) { - auto db = runner->GetTableClient(); - auto session = db.CreateSession().GetValueSync().GetSession(); - - auto createResult = session.ExecuteSchemeQuery(R"( - --!syntax_v1 - CREATE TABLE `/Root/Tmp` ( - Key Uint64, - BoolValue Bool, - Int32Value Int32, - Uint32Value Uint32, - Int64Value Int64, - Uint64Value Uint64, - FloatValue Float, - DoubleValue Double, - StringValue String, - Utf8Value Utf8, - DateValue Date, - DatetimeValue Datetime, - TimestampValue Timestamp, - IntervalValue Interval, - DecimalValue Decimal(22,9), - JsonValue Json, - YsonValue Yson, - JsonDocumentValue JsonDocument, - DyNumberValue DyNumber, - PRIMARY KEY (Key) - ); - )").GetValueSync(); - UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); - - auto insertResult = session.ExecuteDataQuery(R"( - --!syntax_v1 - INSERT INTO `/Root/Tmp` (Key, BoolValue, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue) VALUES - (42, true, -1, 1, -2, 2, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15")); - )", TTxControl::BeginTx().CommitTx()).GetValueSync(); - UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); - - auto it = db.StreamExecuteScanQuery("SELECT Key, BoolValue, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue FROM `/Root/Tmp`").GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - auto streamPart = it.ReadNext().GetValueSync(); - UNIT_ASSERT_C(streamPart.IsSuccess(), streamPart.GetIssues().ToString()); - auto resultSet = streamPart.ExtractResultSet(); - auto columns = resultSet.GetColumnsMeta(); - UNIT_ASSERT_C(columns.size() == 19, "Wrong columns count"); - NYdb::TResultSetParser parser(resultSet); - UNIT_ASSERT_C(parser.TryNextRow(), "Row is missing"); - UNIT_ASSERT(*parser.ColumnParser(0).GetOptionalUint64().Get() == 42); - UNIT_ASSERT(*parser.ColumnParser(1).GetOptionalBool().Get() == true); - UNIT_ASSERT(*parser.ColumnParser(2).GetOptionalInt32().Get() == -1); - UNIT_ASSERT(*parser.ColumnParser(3).GetOptionalUint32().Get() == 1); - UNIT_ASSERT(*parser.ColumnParser(4).GetOptionalInt64().Get() == -2); - UNIT_ASSERT(*parser.ColumnParser(5).GetOptionalUint64().Get() == 2); - UNIT_ASSERT(*parser.ColumnParser(6).GetOptionalFloat().Get() == 3.0); - UNIT_ASSERT(*parser.ColumnParser(7).GetOptionalDouble().Get() == 4.0); - UNIT_ASSERT(*parser.ColumnParser(8).GetOptionalString().Get() == TString("five")); - UNIT_ASSERT(*parser.ColumnParser(9).GetOptionalUtf8().Get() == TString("six")); - UNIT_ASSERT(*parser.ColumnParser(10).GetOptionalDate().Get() == TInstant::ParseIso8601("2007-07-07")); - UNIT_ASSERT(*parser.ColumnParser(11).GetOptionalDatetime().Get() == TInstant::ParseIso8601("2008-08-08T08:08:08Z")); - UNIT_ASSERT(*parser.ColumnParser(12).GetOptionalTimestamp().Get() == TInstant::ParseIso8601("2009-09-09T09:09:09.09Z")); - Cerr << TInstant::Days(10).MicroSeconds() << Endl; - UNIT_ASSERT(*parser.ColumnParser(13).GetOptionalInterval().Get() == TInstant::Days(10).MicroSeconds()); - UNIT_ASSERT(parser.ColumnParser(14).GetOptionalDecimal().Get()->ToString() == TString("11.11")); - UNIT_ASSERT(*parser.ColumnParser(15).GetOptionalJson().Get() == TString("[12]")); - UNIT_ASSERT(*parser.ColumnParser(16).GetOptionalYson().Get() == TString("[13]")); - UNIT_ASSERT(*parser.ColumnParser(17).GetOptionalJsonDocument().Get() == TString("[14]")); - UNIT_ASSERT(*parser.ColumnParser(18).GetOptionalDyNumber().Get() == TString(".1515e2")); - streamPart = it.ReadNext().GetValueSync(); - UNIT_ASSERT_C(streamPart.EOS(), streamPart.GetIssues().ToString()); -} - -} - -Y_UNIT_TEST_SUITE(KqpScanArrowFormat) { - Y_UNIT_TEST(AggregateCountStar) { - auto kikimr = RunnerWithArrowFormatEnabled(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery("SELECT COUNT(*) FROM `/Root/EightShard`").GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - CompareYson(R"([[24u]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AllTypesColumns) { - auto kikimr = RunnerWithArrowFormatEnabled(); - - InsertAllColumnsAndCheckSelectAll(&kikimr); - } - - Y_UNIT_TEST(AllTypesColumnsCellvec) { - TKikimrRunner kikimr; - InsertAllColumnsAndCheckSelectAll(&kikimr); - } - - Y_UNIT_TEST(SingleKey) { - auto kikimr = RunnerWithArrowFormatEnabled(); - auto db = kikimr.GetTableClient(); - - auto params = db.GetParamsBuilder() - .AddParam("$key") - .Uint64(202) - .Build() - .Build(); - - auto it = db.StreamExecuteScanQuery(R"( - DECLARE $key AS Uint64; - - SELECT * FROM `/Root/EightShard` WHERE Key = $key; - )", params).GetValueSync(); - - UNIT_ASSERT(it.IsSuccess()); - - CompareYson(R"([ - [[1];[202u];["Value2"]] - ])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateByColumn) { - auto kikimr = RunnerWithArrowFormatEnabled(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery(R"( - SELECT Text, SUM(Key) AS Total FROM `/Root/EightShard` - GROUP BY Text - ORDER BY Total DESC; - )").GetValueSync(); - - UNIT_ASSERT(it.IsSuccess()); - - CompareYson(R"([ - [["Value3"];[3624u]]; - [["Value2"];[3616u]]; - [["Value1"];[3608u]] - ])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateNoColumn) { - auto kikimr = RunnerWithArrowFormatEnabled(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery(R"( - SELECT SUM(Data), AVG(Data), COUNT(*), MAX(Data), MIN(Data), SUM(Data * 3 + Key * 2) as foo - FROM `/Root/EightShard` - WHERE Key > 300 - )").GetValueSync(); - - UNIT_ASSERT(it.IsSuccess()); - - CompareYson(R"([[[36];[2.];18u;[3];[1];[19980u]]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateNoColumnNoRemaps) { - auto kikimr = RunnerWithArrowFormatEnabled(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery(R"( - SELECT SUM(Data), AVG(Data), COUNT(*) - FROM `/Root/EightShard` - WHERE Key > 300 - )").GetValueSync(); - - UNIT_ASSERT(it.IsSuccess()); - - CompareYson(R"([[[36];[2.];18u]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateWithFunction) { - auto kikimr = RunnerWithArrowFormatEnabled(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery(R"( - SELECT (SUM(Data) * 100) / (MIN(Data) + 10) - FROM `/Root/EightShard` - )").GetValueSync(); - - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - - CompareYson(R"([[[436]]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(AggregateEmptySum) { - auto kikimr = RunnerWithArrowFormatEnabled(); - auto db = kikimr.GetTableClient(); - - auto it = db.StreamExecuteScanQuery("SELECT SUM(Data) FROM `/Root/EightShard` WHERE Key < 10").GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - CompareYson(R"([[#]])", StreamResultToYson(it)); - } - - Y_UNIT_TEST(JoinWithParams) { - auto kikimr = RunnerWithArrowFormatEnabled(); - auto db = kikimr.GetTableClient(); - auto params = TParamsBuilder().AddParam("$in") - .BeginList() - .AddListItem().BeginStruct().AddMember("key").Uint64(1).EndStruct() - .EndList() - .Build().Build(); - // table join params - auto query1 = R"( - declare $in as List<Struct<key: UInt64>>; - select l.Key, l.Value - from `/Root/KeyValue` as l join AS_TABLE($in) as r on l.Key = r.key - )"; - // params join table - auto query2 = R"( - declare $in as List<Struct<key: UInt64>>; - select r.Key, r.Value - from AS_TABLE($in) as l join `/Root/KeyValue` as r on l.key = r.Key - )"; - for (auto& query : {query1, query2}) { - auto it = db.StreamExecuteScanQuery(query, params).GetValueSync(); - UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); - CompareYson(R"([[[1u];["One"]]])", StreamResultToYson(it)); - } - } -} // Test suite - -} // NKqp + +#include <util/generic/size_literals.h> + +namespace NKikimr { +namespace NKqp { + +using namespace NYdb; +using namespace NYdb::NTable; + +namespace { + +TKikimrRunner RunnerWithArrowFormatEnabled() { + NKikimrConfig::TFeatureFlags featureFlags; + featureFlags.SetEnableArrowFormatAtDatashard(true); + + return TKikimrRunner{featureFlags}; +} + +void InsertAllColumnsAndCheckSelectAll(TKikimrRunner* runner) { + auto db = runner->GetTableClient(); + auto session = db.CreateSession().GetValueSync().GetSession(); + + auto createResult = session.ExecuteSchemeQuery(R"( + --!syntax_v1 + CREATE TABLE `/Root/Tmp` ( + Key Uint64, + BoolValue Bool, + Int32Value Int32, + Uint32Value Uint32, + Int64Value Int64, + Uint64Value Uint64, + FloatValue Float, + DoubleValue Double, + StringValue String, + Utf8Value Utf8, + DateValue Date, + DatetimeValue Datetime, + TimestampValue Timestamp, + IntervalValue Interval, + DecimalValue Decimal(22,9), + JsonValue Json, + YsonValue Yson, + JsonDocumentValue JsonDocument, + DyNumberValue DyNumber, + PRIMARY KEY (Key) + ); + )").GetValueSync(); + UNIT_ASSERT_C(createResult.IsSuccess(), createResult.GetIssues().ToString()); + + auto insertResult = session.ExecuteDataQuery(R"( + --!syntax_v1 + INSERT INTO `/Root/Tmp` (Key, BoolValue, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue) VALUES + (42, true, -1, 1, -2, 2, CAST(3.0 AS Float), 4.0, "five", Utf8("six"), Date("2007-07-07"), Datetime("2008-08-08T08:08:08Z"), Timestamp("2009-09-09T09:09:09.09Z"), Interval("P10D"), CAST("11.11" AS Decimal(22, 9)), "[12]", "[13]", JsonDocument("[14]"), DyNumber("15.15")); + )", TTxControl::BeginTx().CommitTx()).GetValueSync(); + UNIT_ASSERT_C(insertResult.IsSuccess(), insertResult.GetIssues().ToString()); + + auto it = db.StreamExecuteScanQuery("SELECT Key, BoolValue, Int32Value, Uint32Value, Int64Value, Uint64Value, FloatValue, DoubleValue, StringValue, Utf8Value, DateValue, DatetimeValue, TimestampValue, IntervalValue, DecimalValue, JsonValue, YsonValue, JsonDocumentValue, DyNumberValue FROM `/Root/Tmp`").GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + auto streamPart = it.ReadNext().GetValueSync(); + UNIT_ASSERT_C(streamPart.IsSuccess(), streamPart.GetIssues().ToString()); + auto resultSet = streamPart.ExtractResultSet(); + auto columns = resultSet.GetColumnsMeta(); + UNIT_ASSERT_C(columns.size() == 19, "Wrong columns count"); + NYdb::TResultSetParser parser(resultSet); + UNIT_ASSERT_C(parser.TryNextRow(), "Row is missing"); + UNIT_ASSERT(*parser.ColumnParser(0).GetOptionalUint64().Get() == 42); + UNIT_ASSERT(*parser.ColumnParser(1).GetOptionalBool().Get() == true); + UNIT_ASSERT(*parser.ColumnParser(2).GetOptionalInt32().Get() == -1); + UNIT_ASSERT(*parser.ColumnParser(3).GetOptionalUint32().Get() == 1); + UNIT_ASSERT(*parser.ColumnParser(4).GetOptionalInt64().Get() == -2); + UNIT_ASSERT(*parser.ColumnParser(5).GetOptionalUint64().Get() == 2); + UNIT_ASSERT(*parser.ColumnParser(6).GetOptionalFloat().Get() == 3.0); + UNIT_ASSERT(*parser.ColumnParser(7).GetOptionalDouble().Get() == 4.0); + UNIT_ASSERT(*parser.ColumnParser(8).GetOptionalString().Get() == TString("five")); + UNIT_ASSERT(*parser.ColumnParser(9).GetOptionalUtf8().Get() == TString("six")); + UNIT_ASSERT(*parser.ColumnParser(10).GetOptionalDate().Get() == TInstant::ParseIso8601("2007-07-07")); + UNIT_ASSERT(*parser.ColumnParser(11).GetOptionalDatetime().Get() == TInstant::ParseIso8601("2008-08-08T08:08:08Z")); + UNIT_ASSERT(*parser.ColumnParser(12).GetOptionalTimestamp().Get() == TInstant::ParseIso8601("2009-09-09T09:09:09.09Z")); + Cerr << TInstant::Days(10).MicroSeconds() << Endl; + UNIT_ASSERT(*parser.ColumnParser(13).GetOptionalInterval().Get() == TInstant::Days(10).MicroSeconds()); + UNIT_ASSERT(parser.ColumnParser(14).GetOptionalDecimal().Get()->ToString() == TString("11.11")); + UNIT_ASSERT(*parser.ColumnParser(15).GetOptionalJson().Get() == TString("[12]")); + UNIT_ASSERT(*parser.ColumnParser(16).GetOptionalYson().Get() == TString("[13]")); + UNIT_ASSERT(*parser.ColumnParser(17).GetOptionalJsonDocument().Get() == TString("[14]")); + UNIT_ASSERT(*parser.ColumnParser(18).GetOptionalDyNumber().Get() == TString(".1515e2")); + streamPart = it.ReadNext().GetValueSync(); + UNIT_ASSERT_C(streamPart.EOS(), streamPart.GetIssues().ToString()); +} + +} + +Y_UNIT_TEST_SUITE(KqpScanArrowFormat) { + Y_UNIT_TEST(AggregateCountStar) { + auto kikimr = RunnerWithArrowFormatEnabled(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery("SELECT COUNT(*) FROM `/Root/EightShard`").GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + CompareYson(R"([[24u]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AllTypesColumns) { + auto kikimr = RunnerWithArrowFormatEnabled(); + + InsertAllColumnsAndCheckSelectAll(&kikimr); + } + + Y_UNIT_TEST(AllTypesColumnsCellvec) { + TKikimrRunner kikimr; + InsertAllColumnsAndCheckSelectAll(&kikimr); + } + + Y_UNIT_TEST(SingleKey) { + auto kikimr = RunnerWithArrowFormatEnabled(); + auto db = kikimr.GetTableClient(); + + auto params = db.GetParamsBuilder() + .AddParam("$key") + .Uint64(202) + .Build() + .Build(); + + auto it = db.StreamExecuteScanQuery(R"( + DECLARE $key AS Uint64; + + SELECT * FROM `/Root/EightShard` WHERE Key = $key; + )", params).GetValueSync(); + + UNIT_ASSERT(it.IsSuccess()); + + CompareYson(R"([ + [[1];[202u];["Value2"]] + ])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateByColumn) { + auto kikimr = RunnerWithArrowFormatEnabled(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery(R"( + SELECT Text, SUM(Key) AS Total FROM `/Root/EightShard` + GROUP BY Text + ORDER BY Total DESC; + )").GetValueSync(); + + UNIT_ASSERT(it.IsSuccess()); + + CompareYson(R"([ + [["Value3"];[3624u]]; + [["Value2"];[3616u]]; + [["Value1"];[3608u]] + ])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateNoColumn) { + auto kikimr = RunnerWithArrowFormatEnabled(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery(R"( + SELECT SUM(Data), AVG(Data), COUNT(*), MAX(Data), MIN(Data), SUM(Data * 3 + Key * 2) as foo + FROM `/Root/EightShard` + WHERE Key > 300 + )").GetValueSync(); + + UNIT_ASSERT(it.IsSuccess()); + + CompareYson(R"([[[36];[2.];18u;[3];[1];[19980u]]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateNoColumnNoRemaps) { + auto kikimr = RunnerWithArrowFormatEnabled(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery(R"( + SELECT SUM(Data), AVG(Data), COUNT(*) + FROM `/Root/EightShard` + WHERE Key > 300 + )").GetValueSync(); + + UNIT_ASSERT(it.IsSuccess()); + + CompareYson(R"([[[36];[2.];18u]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateWithFunction) { + auto kikimr = RunnerWithArrowFormatEnabled(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery(R"( + SELECT (SUM(Data) * 100) / (MIN(Data) + 10) + FROM `/Root/EightShard` + )").GetValueSync(); + + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + + CompareYson(R"([[[436]]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(AggregateEmptySum) { + auto kikimr = RunnerWithArrowFormatEnabled(); + auto db = kikimr.GetTableClient(); + + auto it = db.StreamExecuteScanQuery("SELECT SUM(Data) FROM `/Root/EightShard` WHERE Key < 10").GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + CompareYson(R"([[#]])", StreamResultToYson(it)); + } + + Y_UNIT_TEST(JoinWithParams) { + auto kikimr = RunnerWithArrowFormatEnabled(); + auto db = kikimr.GetTableClient(); + auto params = TParamsBuilder().AddParam("$in") + .BeginList() + .AddListItem().BeginStruct().AddMember("key").Uint64(1).EndStruct() + .EndList() + .Build().Build(); + // table join params + auto query1 = R"( + declare $in as List<Struct<key: UInt64>>; + select l.Key, l.Value + from `/Root/KeyValue` as l join AS_TABLE($in) as r on l.Key = r.key + )"; + // params join table + auto query2 = R"( + declare $in as List<Struct<key: UInt64>>; + select r.Key, r.Value + from AS_TABLE($in) as l join `/Root/KeyValue` as r on l.key = r.Key + )"; + for (auto& query : {query1, query2}) { + auto it = db.StreamExecuteScanQuery(query, params).GetValueSync(); + UNIT_ASSERT_C(it.IsSuccess(), it.GetIssues().ToString()); + CompareYson(R"([[[1u];["One"]]])", StreamResultToYson(it)); + } + } +} // Test suite + +} // NKqp } // NKikimr diff --git a/ydb/core/kqp/ut/ya.make b/ydb/core/kqp/ut/ya.make index 1b8d37a4466..609ec3237f8 100644 --- a/ydb/core/kqp/ut/ya.make +++ b/ydb/core/kqp/ut/ya.make @@ -20,7 +20,7 @@ ENDIF() SRCS( kqp_acl_ut.cpp - kqp_arrow_in_channels_ut.cpp + kqp_arrow_in_channels_ut.cpp kqp_document_api_ut.cpp kqp_effects_perf_ut.cpp kqp_explain_ut.cpp @@ -53,7 +53,7 @@ SRCS( kqp_sys_col_ut.cpp kqp_table_predicate_ut.cpp kqp_tx_ut.cpp - kqp_types_arrow_ut.cpp + kqp_types_arrow_ut.cpp kqp_write_ut.cpp kqp_yql_ut.cpp ) @@ -80,5 +80,5 @@ RECURSE( fat ../rm/ut ../proxy/ut - ../runtime/ut + ../runtime/ut ) diff --git a/ydb/core/protos/config.proto b/ydb/core/protos/config.proto index 81d5c65f275..d64169d4fc0 100644 --- a/ydb/core/protos/config.proto +++ b/ydb/core/protos/config.proto @@ -661,10 +661,10 @@ message TFeatureFlags { optional bool EnableMvccSnapshotReads = 46 [default = false]; optional Tribool EnableMvcc = 47 [default = VALUE_TRUE]; optional bool EnableSchemeTransactionsAtSchemeShard = 48 [default = false]; - optional bool EnableArrowFormatAtDatashard = 49 [default = false]; + optional bool EnableArrowFormatAtDatashard = 49 [default = false]; optional bool Enable3x3RequestsForMirror3DCMinLatencyPut = 50 [default = false]; optional bool EnableBackgroundCompaction = 51 [default = true]; - optional bool EnableArrowFormatInChannels = 52 [default = false]; + optional bool EnableArrowFormatInChannels = 52 [default = false]; optional bool EnableBackgroundCompactionServerless = 53 [default = false]; optional bool EnableNotNullColumns = 54 [default = false]; optional bool EnableTtlOnAsyncIndexedTables = 55 [default = false]; diff --git a/ydb/core/protos/kqp.proto b/ydb/core/protos/kqp.proto index a29bf481d18..9b46b720e11 100644 --- a/ydb/core/protos/kqp.proto +++ b/ydb/core/protos/kqp.proto @@ -547,18 +547,18 @@ message TEvRemoteScanData { optional bool PageFault = 5; optional bool Finished = 6; optional bytes LastKey = 7; - optional uint32 Generation = 9; - - message TArrowBatch { - optional bytes Schema = 1; - optional bytes Batch = 2; - } - - // Only one of the fields Rows and ArrowBatch must be filled. However, we can not use oneof feature because Rows - // field is repeated. Moving it into oneof is impossible. We may wrap it into a message but this would break - // backwards comparability. + optional uint32 Generation = 9; + + message TArrowBatch { + optional bytes Schema = 1; + optional bytes Batch = 2; + } + + // Only one of the fields Rows and ArrowBatch must be filled. However, we can not use oneof feature because Rows + // field is repeated. Moving it into oneof is impossible. We may wrap it into a message but this would break + // backwards comparability. repeated bytes Rows = 8; - optional TArrowBatch ArrowBatch = 10; + optional TArrowBatch ArrowBatch = 10; } message TEvRemoteScanDataAck { diff --git a/ydb/core/protos/tx_datashard.proto b/ydb/core/protos/tx_datashard.proto index c2a11016e83..89863b25400 100644 --- a/ydb/core/protos/tx_datashard.proto +++ b/ydb/core/protos/tx_datashard.proto @@ -137,12 +137,12 @@ enum EKqpTransactionType { KQP_TX_TYPE_SCAN = 2; } -enum EScanDataFormat { - UNSPECIFIED = 0; - CELLVEC = 1; - ARROW = 2; -} - +enum EScanDataFormat { + UNSPECIFIED = 0; + CELLVEC = 1; + ARROW = 2; +} + message TKqpTransaction { message TColumnMeta { optional uint32 Id = 1; @@ -211,7 +211,7 @@ message TKqpTransaction { optional uint64 ItemsLimit = 6; optional bool Reverse = 7; reserved 8; // optional bytes ProcessProgram = 8; - optional EScanDataFormat DataFormat = 9; + optional EScanDataFormat DataFormat = 9; optional NKikimrSSA.TOlapProgram OlapProgram = 10; // Currently only for OLAP tables } @@ -1392,7 +1392,7 @@ message TEvKqpScan { reserved 14; optional uint64 ItemsLimit = 15; optional bool Reverse = 16; - optional EScanDataFormat DataFormat = 17; + optional EScanDataFormat DataFormat = 17; optional NYql.NDqProto.EDqStatsMode StatsMode = 18; optional bytes OlapProgram = 19; optional NKikimrSchemeOp.EOlapProgramType OlapProgramType = 20; diff --git a/ydb/core/tx/columnshard/columnshard__scan.cpp b/ydb/core/tx/columnshard/columnshard__scan.cpp index bb70d92e1df..e2f7ba5a515 100644 --- a/ydb/core/tx/columnshard/columnshard__scan.cpp +++ b/ydb/core/tx/columnshard/columnshard__scan.cpp @@ -32,7 +32,7 @@ public: TColumnShardScan(const TActorId& columnShardActorId, const TActorId& scanComputeActorId, ui32 scanId, ui64 txId, ui32 scanGen, ui64 requestCookie, const TString& table, TDuration timeout, TVector<TTxScan::TReadMetadataPtr>&& readMetadataList, - NKikimrTxDataShard::EScanDataFormat dataFormat) + NKikimrTxDataShard::EScanDataFormat dataFormat) : ColumnShardActorId(columnShardActorId) , ScanComputeActorId(scanComputeActorId) , BlobCacheActorId(NBlobCache::MakeBlobCacheServiceId()) @@ -40,7 +40,7 @@ public: , TxId(txId) , ScanGen(scanGen) , RequestCookie(requestCookie) - , DataFormat(dataFormat) + , DataFormat(dataFormat) , TablePath(table) , ReadMetadataRanges(std::move(readMetadataList)) , ReadMetadataIndex(0) @@ -151,7 +151,7 @@ private: auto result = ScanIterator->GetBatch(); if (ResultYqlSchema.empty() && DataFormat != NKikimrTxDataShard::EScanDataFormat::ARROW) { ResultYqlSchema = ReadMetadataRanges[ReadMetadataIndex]->GetResultYqlSchema(); - } + } if (!result.ResultBatch) { // No data is ready yet return false; @@ -303,28 +303,28 @@ private: if (reserveRows) { Y_VERIFY(DataFormat != NKikimrTxDataShard::EScanDataFormat::ARROW); Result->Rows.reserve(reserveRows); - } + } } } - void NextReadMetadata() { + void NextReadMetadata() { ScanIterator.reset(); - ++ReadMetadataIndex; - + ++ReadMetadataIndex; + if (ReadMetadataIndex == ReadMetadataRanges.size()) { - // Send empty batch with "finished" flag + // Send empty batch with "finished" flag MakeResult(); - SendResult(false, true); - return Finish(); - } - + SendResult(false, true); + return Finish(); + } + ScanIterator = ReadMetadataRanges[ReadMetadataIndex]->StartScan(); - - // Used in TArrowToYdbConverter + + // Used in TArrowToYdbConverter ResultYqlSchema.clear(); - } - + } + void AddRow(const TConstArrayRef<TCell>& row) override { Result->Rows.emplace_back(TOwnedCellVec::Make(row)); ++Rows; @@ -430,7 +430,7 @@ private: const ui32 ScanGen; const ui64 RequestCookie; const i64 MaxReadAheadBytes = DEFAULT_READ_AHEAD_BYTES; - const NKikimrTxDataShard::EScanDataFormat DataFormat; + const NKikimrTxDataShard::EScanDataFormat DataFormat; const TString TablePath; @@ -669,7 +669,7 @@ void TTxScan::Complete(const TActorContext& ctx) { const ui64 txId = request.GetTxId(); const ui32 scanGen = request.GetGeneration(); TString table = request.GetTablePath(); - auto dataFormat = request.GetDataFormat(); + auto dataFormat = request.GetDataFormat(); TDuration timeout = TDuration::MilliSeconds(request.GetTimeoutMs()); if (scanGen > 1) { diff --git a/ydb/core/tx/columnshard/engines/indexed_read_data.h b/ydb/core/tx/columnshard/engines/indexed_read_data.h index a77799a25a1..938c141fbc0 100644 --- a/ydb/core/tx/columnshard/engines/indexed_read_data.h +++ b/ydb/core/tx/columnshard/engines/indexed_read_data.h @@ -103,15 +103,15 @@ struct TReadMetadata : public TReadMetadataBase, public std::enable_shared_from_ } TVector<std::pair<TString, NScheme::TTypeId>> GetResultYqlSchema() const override { - TVector<NTable::TTag> columnIds; - columnIds.reserve(ResultSchema->num_fields()); - for (const auto& field: ResultSchema->fields()) { - TString name = TStringBuilder() << field->name(); - columnIds.emplace_back(IndexInfo.GetColumnId(name)); - } - return IndexInfo.GetColumns(columnIds); - } - + TVector<NTable::TTag> columnIds; + columnIds.reserve(ResultSchema->num_fields()); + for (const auto& field: ResultSchema->fields()) { + TString name = TStringBuilder() << field->name(); + columnIds.emplace_back(IndexInfo.GetColumnId(name)); + } + return IndexInfo.GetColumns(columnIds); + } + TVector<std::pair<TString, NScheme::TTypeId>> GetKeyYqlSchema() const override { return IndexInfo.GetPK(); } diff --git a/ydb/core/tx/datashard/datashard__kqp_scan.cpp b/ydb/core/tx/datashard/datashard__kqp_scan.cpp index 04c3391af09..b6b6b7e337d 100644 --- a/ydb/core/tx/datashard/datashard__kqp_scan.cpp +++ b/ydb/core/tx/datashard/datashard__kqp_scan.cpp @@ -38,7 +38,7 @@ public: NDataShard::TUserTable::TCPtr tableInfo, const TSmallVec<TSerializedTableRange>&& tableRanges, const TSmallVec<NTable::TTag>&& columnTags, const TSmallVec<bool>&& skipNullKeys, const NYql::NDqProto::EDqStatsMode& statsMode, ui64 timeoutMs, ui32 generation, - NKikimrTxDataShard::EScanDataFormat dataFormat) + NKikimrTxDataShard::EScanDataFormat dataFormat) : TActor(&TKqpScan::StateScan) , ComputeActorId(computeActorId) , DatashardActorId(datashardActorId) @@ -52,33 +52,33 @@ public: , StatsMode(statsMode) , Deadline(TInstant::Now() + (timeoutMs ? TDuration::MilliSeconds(timeoutMs) + SCAN_HARD_TIMEOUT_GAP : SCAN_HARD_TIMEOUT)) , Generation(generation) - , DataFormat(dataFormat) + , DataFormat(dataFormat) , PeerFreeSpace(0) , Sleep(true) , IsLocal(computeActorId.NodeId() == datashardActorId.NodeId()) - { - if (DataFormat == NKikimrTxDataShard::EScanDataFormat::ARROW) { - BatchBuilder = MakeHolder<NArrow::TArrowBatchBuilder>(); - TVector<std::pair<TString, NScheme::TTypeId>> schema; - if (!Tags.empty()) { - Types.reserve(Tags.size()); - schema.reserve(Tags.size()); - for (const auto tag: Tags) { - const auto& column = TableInfo->Columns.at(tag); - Types.emplace_back(column.Type); - schema.emplace_back(column.Name, column.Type); - } + { + if (DataFormat == NKikimrTxDataShard::EScanDataFormat::ARROW) { + BatchBuilder = MakeHolder<NArrow::TArrowBatchBuilder>(); + TVector<std::pair<TString, NScheme::TTypeId>> schema; + if (!Tags.empty()) { + Types.reserve(Tags.size()); + schema.reserve(Tags.size()); + for (const auto tag: Tags) { + const auto& column = TableInfo->Columns.at(tag); + Types.emplace_back(column.Type); + schema.emplace_back(column.Name, column.Type); + } BatchBuilder->Reserve(INIT_BATCH_ROWS); bool started = BatchBuilder->Start(schema); - YQL_ENSURE(started, "Failed to start BatchBuilder"); - } - } + YQL_ENSURE(started, "Failed to start BatchBuilder"); + } + } for (auto& range : TableRanges) { LOG_TRACE_S(*TlsActivationContext, NKikimrServices::TX_DATASHARD, "--> Scan range: " << DebugPrintRange(TableInfo->KeyColumnTypes, range.ToTableRange(), *AppData()->TypeRegistry)); } - } + } private: STATEFN(StateScan) { @@ -266,7 +266,7 @@ private: } } - MakeResult(); + MakeResult(); if (Y_UNLIKELY(IsProfile())) { Result->WaitTime += TInstant::Now() - StartWaitTime; @@ -278,7 +278,7 @@ private: } }; - AddRow(row); + AddRow(row); auto sent = SendResult(/* pageFault */ false); @@ -359,8 +359,8 @@ private: } else { Result = MakeHolder<TEvKqpCompute::TEvScanData>(ScanId, Generation); } - auto send = SendResult(Result->PageFault, true); - Y_VERIFY_DEBUG(send); + auto send = SendResult(Result->PageFault, true); + Y_VERIFY_DEBUG(send); } Driver = nullptr; @@ -376,80 +376,80 @@ private: out << "TExecuteKqpScanTxUnit, TKqpScan"; } - void MakeResult() { - if (!Result) { - Result = MakeHolder<TEvKqpCompute::TEvScanData>(ScanId, Generation); - switch (DataFormat) { - case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: - case NKikimrTxDataShard::EScanDataFormat::CELLVEC: { - Result->Rows.reserve(INIT_BATCH_ROWS); - break; - } - case NKikimrTxDataShard::EScanDataFormat::ARROW: { - } - } - } - } - - void AddRow(const TRow& row) { - ++Rows; - // NOTE: Some per-row overhead to deal with the case when no columns were requested - if (Tags.empty()) { - CellvecBytes += 8; - } - for (auto& cell: *row) { - CellvecBytes += std::max((ui64)8, (ui64)cell.Size()); - } - switch (DataFormat) { - case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: - case NKikimrTxDataShard::EScanDataFormat::CELLVEC: { - Result->Rows.emplace_back(TOwnedCellVec::Make(*row)); - break; - } - case NKikimrTxDataShard::EScanDataFormat::ARROW: { - NKikimr::TDbTupleRef key; - Y_VERIFY_DEBUG((*row).size() == Types.size()); - NKikimr::TDbTupleRef value = NKikimr::TDbTupleRef(Types.data(), (*row).data(), Types.size()); - BatchBuilder->AddRow(key, value); - break; - } - } - } - - bool SendResult(bool pageFault, bool finish = false) noexcept { - if (Rows >= MAX_BATCH_ROWS || CellvecBytes >= PeerFreeSpace || - (pageFault && (Rows >= MIN_BATCH_ROWS_ON_PAGEFAULT || CellvecBytes >= MIN_BATCH_SIZE_ON_PAGEFAULT)) || finish) + void MakeResult() { + if (!Result) { + Result = MakeHolder<TEvKqpCompute::TEvScanData>(ScanId, Generation); + switch (DataFormat) { + case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: + case NKikimrTxDataShard::EScanDataFormat::CELLVEC: { + Result->Rows.reserve(INIT_BATCH_ROWS); + break; + } + case NKikimrTxDataShard::EScanDataFormat::ARROW: { + } + } + } + } + + void AddRow(const TRow& row) { + ++Rows; + // NOTE: Some per-row overhead to deal with the case when no columns were requested + if (Tags.empty()) { + CellvecBytes += 8; + } + for (auto& cell: *row) { + CellvecBytes += std::max((ui64)8, (ui64)cell.Size()); + } + switch (DataFormat) { + case NKikimrTxDataShard::EScanDataFormat::UNSPECIFIED: + case NKikimrTxDataShard::EScanDataFormat::CELLVEC: { + Result->Rows.emplace_back(TOwnedCellVec::Make(*row)); + break; + } + case NKikimrTxDataShard::EScanDataFormat::ARROW: { + NKikimr::TDbTupleRef key; + Y_VERIFY_DEBUG((*row).size() == Types.size()); + NKikimr::TDbTupleRef value = NKikimr::TDbTupleRef(Types.data(), (*row).data(), Types.size()); + BatchBuilder->AddRow(key, value); + break; + } + } + } + + bool SendResult(bool pageFault, bool finish = false) noexcept { + if (Rows >= MAX_BATCH_ROWS || CellvecBytes >= PeerFreeSpace || + (pageFault && (Rows >= MIN_BATCH_ROWS_ON_PAGEFAULT || CellvecBytes >= MIN_BATCH_SIZE_ON_PAGEFAULT)) || finish) { Result->PageFault = pageFault; Result->PageFaults = PageFaults; - if (finish) { - Result->Finished = true; - } else { - Result->LastKey = LastKey; - } - auto sendBytes = CellvecBytes; - - if (DataFormat == NKikimrTxDataShard::EScanDataFormat::ARROW) { - FlushBatchToResult(); - sendBytes = NArrow::GetBatchDataSize(Result->ArrowBatch); - // Batch is stored inside BatchBuilder until we flush it into Result. So we verify number of rows here. - YQL_ENSURE(Rows == 0 && Result->ArrowBatch == nullptr || Result->ArrowBatch->num_rows() == (i64) Rows); - } else { - YQL_ENSURE(Result->Rows.size() == Rows); - } - + if (finish) { + Result->Finished = true; + } else { + Result->LastKey = LastKey; + } + auto sendBytes = CellvecBytes; + + if (DataFormat == NKikimrTxDataShard::EScanDataFormat::ARROW) { + FlushBatchToResult(); + sendBytes = NArrow::GetBatchDataSize(Result->ArrowBatch); + // Batch is stored inside BatchBuilder until we flush it into Result. So we verify number of rows here. + YQL_ENSURE(Rows == 0 && Result->ArrowBatch == nullptr || Result->ArrowBatch->num_rows() == (i64) Rows); + } else { + YQL_ENSURE(Result->Rows.size() == Rows); + } + PageFaults = 0; LOG_DEBUG_S(*TlsActivationContext, NKikimrServices::TX_DATASHARD, "Send ScanData" << ", from: " << ScanActorId << ", to: " << ComputeActorId << ", scanId: " << ScanId << ", table: " << TablePath - << ", bytes: " << sendBytes << ", rows: " << Rows << ", page faults: " << Result->PageFaults + << ", bytes: " << sendBytes << ", rows: " << Rows << ", page faults: " << Result->PageFaults << ", finished: " << Result->Finished << ", pageFault: " << Result->PageFault); - if (PeerFreeSpace < sendBytes) { + if (PeerFreeSpace < sendBytes) { PeerFreeSpace = 0; } else { - PeerFreeSpace -= sendBytes; + PeerFreeSpace -= sendBytes; } if (sendBytes >= 48_MB) { @@ -479,20 +479,20 @@ private: return false; } - // Call only after MakeResult method. - void FlushBatchToResult() { - // FlushBatch reset Batch pointer in BatchBuilder only if some rows were added after. So we if we have already - // send a batch and try to send an empty batch again without adding rows, then a copy of the batch will be send - // instead. So we check Rows here. - if (Rows != 0) { + // Call only after MakeResult method. + void FlushBatchToResult() { + // FlushBatch reset Batch pointer in BatchBuilder only if some rows were added after. So we if we have already + // send a batch and try to send an empty batch again without adding rows, then a copy of the batch will be send + // instead. So we check Rows here. + if (Rows != 0) { Result->ArrowBatch = Tags.empty() ? NArrow::CreateNoColumnsBatch(Rows) : BatchBuilder->FlushBatch(true); - } - } - + } + } + void ReportDatashardStats() { Send(DatashardActorId, new TDataShard::TEvPrivate::TEvScanStats(Rows, CellvecBytes)); Rows = 0; - CellvecBytes = 0; + CellvecBytes = 0; } bool IsProfile() const { @@ -508,12 +508,12 @@ private: const TSmallVec<TSerializedTableRange> TableRanges; ui32 CurrentRange; const TSmallVec<NTable::TTag> Tags; - TSmallVec<NScheme::TTypeId> Types; + TSmallVec<NScheme::TTypeId> Types; const TSmallVec<bool> SkipNullKeys; const NYql::NDqProto::EDqStatsMode StatsMode; const TInstant Deadline; const ui32 Generation; - const NKikimrTxDataShard::EScanDataFormat DataFormat; + const NKikimrTxDataShard::EScanDataFormat DataFormat; ui64 PeerFreeSpace = 0; bool Sleep; const bool IsLocal; @@ -523,10 +523,10 @@ private: TActorId TimeoutActorId; TAutoPtr<TEvKqp::TEvAbortExecution> AbortEvent; - THolder<NArrow::TArrowBatchBuilder> BatchBuilder; + THolder<NArrow::TArrowBatchBuilder> BatchBuilder; THolder<TEvKqpCompute::TEvScanData> Result; ui64 Rows = 0; - ui64 CellvecBytes = 0; + ui64 CellvecBytes = 0; ui32 PageFaults = 0; TInstant StartWaitTime; @@ -631,8 +631,8 @@ void TDataShard::Handle(TEvDataShard::TEvKqpScan::TPtr& ev, const TActorContext& std::move(TSmallVec<bool>(request.GetSkipNullKeys().begin(), request.GetSkipNullKeys().end())), request.GetStatsMode(), request.GetTimeoutMs(), - generation, - request.GetDataFormat() + generation, + request.GetDataFormat() ); auto scanOptions = TScanOptions() diff --git a/ydb/library/yql/dq/proto/dq_transport.proto b/ydb/library/yql/dq/proto/dq_transport.proto index 2605861e8e0..d78fa7806c8 100644 --- a/ydb/library/yql/dq/proto/dq_transport.proto +++ b/ydb/library/yql/dq/proto/dq_transport.proto @@ -7,7 +7,7 @@ enum EDataTransportVersion { DATA_TRANSPORT_VERSION_UNSPECIFIED = 0; DATA_TRANSPORT_YSON_1_0 = 10000; DATA_TRANSPORT_UV_PICKLE_1_0 = 20000; - DATA_TRANSPORT_ARROW_1_0 = 30000; + DATA_TRANSPORT_ARROW_1_0 = 30000; } message TData { diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp b/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp index 5514eed5da0..6d63c9873ef 100644 --- a/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp @@ -1,938 +1,938 @@ -#include "dq_arrow_helpers.h" - -#include <cstddef> +#include "dq_arrow_helpers.h" + +#include <cstddef> #include <ydb/library/yql/public/udf/udf_value.h> #include <ydb/library/yql/minikql/defs.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> #include <ydb/library/yql/minikql/mkql_node.h> - + #include <ydb/core/util/yverify_stream.h> #include <ydb/public/lib/scheme_types/scheme_type_id.h> - -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_base.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/buffer.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/writer.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/type.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/util/compression.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/util/type_fwd.h> - -#include <util/system/compiler.h> -#include <util/system/yassert.h> - -namespace NYql { -namespace NArrow { - -using namespace NKikimr; -using namespace NMiniKQL; - -namespace { - -template <typename TArrowType> -struct TTypeWrapper -{ - using T = TArrowType; -}; - -/** - * @brief Function to switch MiniKQL DataType correctly and uniformly converting it to arrow type using callback - * - * @tparam TFunc Callback type - * @param typeId Type callback work with. - * @param callback Template function of signature (TTypeWrapper) -> bool - * @return Result of execution of callback or false if the type typeId is not supported. - */ -template <typename TFunc> -bool SwitchMiniKQLDataTypeToArrowType(NUdf::EDataSlot type, TFunc&& callback) { - switch (type) { - case NUdf::EDataSlot::Bool: - return callback(TTypeWrapper<arrow::BooleanType>()); - case NUdf::EDataSlot::Int8: - return callback(TTypeWrapper<arrow::Int8Type>()); - case NUdf::EDataSlot::Uint8: - return callback(TTypeWrapper<arrow::UInt8Type>()); - case NUdf::EDataSlot::Int16: - return callback(TTypeWrapper<arrow::Int16Type>()); - case NUdf::EDataSlot::Date: - case NUdf::EDataSlot::Uint16: - return callback(TTypeWrapper<arrow::UInt16Type>()); - case NUdf::EDataSlot::Int32: - return callback(TTypeWrapper<arrow::Int32Type>()); - case NUdf::EDataSlot::Datetime: - case NUdf::EDataSlot::Uint32: - return callback(TTypeWrapper<arrow::UInt32Type>()); - case NUdf::EDataSlot::Int64: - return callback(TTypeWrapper<arrow::Int64Type>()); - case NUdf::EDataSlot::Uint64: - return callback(TTypeWrapper<arrow::UInt64Type>()); - case NUdf::EDataSlot::Float: - return callback(TTypeWrapper<arrow::FloatType>()); - case NUdf::EDataSlot::Double: - return callback(TTypeWrapper<arrow::DoubleType>()); - case NUdf::EDataSlot::Timestamp: - return callback(TTypeWrapper<arrow::TimestampType>()); - case NUdf::EDataSlot::Interval: - return callback(TTypeWrapper<arrow::DurationType>()); - case NUdf::EDataSlot::Utf8: - case NUdf::EDataSlot::Json: - case NUdf::EDataSlot::Yson: - case NUdf::EDataSlot::JsonDocument: - return callback(TTypeWrapper<arrow::StringType>()); - case NUdf::EDataSlot::String: - case NUdf::EDataSlot::Uuid: - case NUdf::EDataSlot::DyNumber: - return callback(TTypeWrapper<arrow::BinaryType>()); - case NUdf::EDataSlot::Decimal: - return callback(TTypeWrapper<arrow::Decimal128Type>()); - // TODO convert Tz-types to native arrow date and time types. - case NUdf::EDataSlot::TzDate: - case NUdf::EDataSlot::TzDatetime: - case NUdf::EDataSlot::TzTimestamp: - return false; - } -} - -template <typename TArrowType> -NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr<arrow::Array> column, ui32 row) { - using TArrayType = typename arrow::TypeTraits<TArrowType>::ArrayType; - auto array = std::static_pointer_cast<TArrayType>(column); - return NUdf::TUnboxedValuePod(static_cast<typename TArrowType::c_type>(array->Value(row))); -} - -// The following 4 specialization are for darwin build (because of difference in long long) - -template <> // For darwin build -NUdf::TUnboxedValue GetUnboxedValue<arrow::UInt64Type>(std::shared_ptr<arrow::Array> column, ui32 row) { - auto array = std::static_pointer_cast<arrow::UInt64Array>(column); - return NUdf::TUnboxedValuePod(static_cast<ui64>(array->Value(row))); -} - -template <> // For darwin build -NUdf::TUnboxedValue GetUnboxedValue<arrow::Int64Type>(std::shared_ptr<arrow::Array> column, ui32 row) { - auto array = std::static_pointer_cast<arrow::Int64Array>(column); - return NUdf::TUnboxedValuePod(static_cast<i64>(array->Value(row))); -} - -template <> // For darwin build -NUdf::TUnboxedValue GetUnboxedValue<arrow::TimestampType>(std::shared_ptr<arrow::Array> column, ui32 row) { - using TArrayType = typename arrow::TypeTraits<arrow::TimestampType>::ArrayType; - auto array = std::static_pointer_cast<TArrayType>(column); - return NUdf::TUnboxedValuePod(static_cast<ui64>(array->Value(row))); -} - -template <> // For darwin build -NUdf::TUnboxedValue GetUnboxedValue<arrow::DurationType>(std::shared_ptr<arrow::Array> column, ui32 row) { - using TArrayType = typename arrow::TypeTraits<arrow::DurationType>::ArrayType; - auto array = std::static_pointer_cast<TArrayType>(column); - return NUdf::TUnboxedValuePod(static_cast<ui64>(array->Value(row))); -} - -template <> -NUdf::TUnboxedValue GetUnboxedValue<arrow::BinaryType>(std::shared_ptr<arrow::Array> column, ui32 row) { - auto array = std::static_pointer_cast<arrow::BinaryArray>(column); - auto data = array->GetView(row); - return NMiniKQL::MakeString(NUdf::TStringRef(data.data(), data.size())); -} - -template <> -NUdf::TUnboxedValue GetUnboxedValue<arrow::StringType>(std::shared_ptr<arrow::Array> column, ui32 row) { - auto array = std::static_pointer_cast<arrow::StringArray>(column); - auto data = array->GetView(row); - return NMiniKQL::MakeString(NUdf::TStringRef(data.data(), data.size())); -} - -template <> -NUdf::TUnboxedValue GetUnboxedValue<arrow::Decimal128Type>(std::shared_ptr<arrow::Array> column, ui32 row) { - auto array = std::static_pointer_cast<arrow::Decimal128Array>(column); - auto data = array->GetView(row); - // We check that Decimal(22,9) but it may not be true - // TODO Support other decimal precisions. - const auto& type = arrow::internal::checked_cast<const arrow::Decimal128Type&>(*array->type()); - Y_VERIFY(type.precision() == NScheme::DECIMAL_PRECISION, "Unsupported Decimal precision."); - Y_VERIFY(type.scale() == NScheme::DECIMAL_SCALE, "Unsupported Decimal scale."); - Y_VERIFY(data.size() == sizeof(NYql::NDecimal::TInt128), "Wrong data size"); - NYql::NDecimal::TInt128 val; - std::memcpy(reinterpret_cast<char*>(&val), data.data(), data.size()); - return NUdf::TUnboxedValuePod(val); -} - -template <typename TType> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl() { - return std::make_shared<TType>(); -} - -template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::Decimal128Type>() { - // TODO use non-fixed precision, derive it from data. - return arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE); -} - -template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::TimestampType>() { - return arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO); -} - -template <> -std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::DurationType>() { - return arrow::duration(arrow::TimeUnit::TimeUnit::MICRO); -} - -std::shared_ptr<arrow::DataType> GetArrowType(const TDataType* dataType) { - std::shared_ptr<arrow::DataType> result; - bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { - Y_UNUSED(typeHolder); - result = CreateEmptyArrowImpl<TType>(); - return true; - }); - if (success) { - return result; - } - return std::make_shared<arrow::NullType>(); -} - -std::shared_ptr<arrow::DataType> GetArrowType(const TStructType* structType) { - std::vector<std::shared_ptr<arrow::Field>> fields; - fields.reserve(structType->GetMembersCount()); - for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { - auto memberType = structType->GetMemberType(index); - fields.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(memberType))); - } - return arrow::struct_(fields); -} - -std::shared_ptr<arrow::DataType> GetArrowType(const TTupleType* tupleType) { - std::vector<std::shared_ptr<arrow::Field>> fields; - fields.reserve(tupleType->GetElementsCount()); - for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { - auto elementType = tupleType->GetElementType(index); - fields.push_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(elementType))); - } - return arrow::struct_(fields); -} - -std::shared_ptr<arrow::DataType> GetArrowType(const TListType* listType) { - auto itemType = listType->GetItemType(); - return arrow::list(NArrow::GetArrowType(itemType)); -} - -std::shared_ptr<arrow::DataType> GetArrowType(const TDictType* dictType) { - auto keyType = dictType->GetKeyType(); - auto payloadType = dictType->GetPayloadType(); - if (keyType->GetKind() == TType::EKind::Optional) { - std::vector<std::shared_ptr<arrow::Field>> fields; - fields.emplace_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(keyType))); - fields.emplace_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(payloadType))); - return arrow::list(arrow::struct_(fields)); - } - return arrow::map(NArrow::GetArrowType(keyType), NArrow::GetArrowType(payloadType)); -} - -std::shared_ptr<arrow::DataType> GetArrowType(const TVariantType* variantType) { - TType* innerType = variantType->GetUnderlyingType(); - arrow::FieldVector types; - TStructType* structType = nullptr; - TTupleType* tupleType = nullptr; - if (innerType->IsStruct()) { - structType = static_cast<TStructType*>(innerType); - } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); - tupleType = static_cast<TTupleType*>(innerType); - } - - if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { - // Create Union of unions if there are more types then arrow::dense_union supports. - ui32 numberOfGroups = (variantType->GetAlternativesCount() - 1) / arrow::UnionType::kMaxTypeCode + 1; - types.reserve(numberOfGroups); - for (ui32 groupIndex = 0; groupIndex < numberOfGroups; ++groupIndex) { - ui32 beginIndex = groupIndex * arrow::UnionType::kMaxTypeCode; - ui32 endIndex = std::min((groupIndex + 1) * arrow::UnionType::kMaxTypeCode, variantType->GetAlternativesCount()); - arrow::FieldVector groupTypes; - groupTypes.reserve(endIndex - beginIndex); - if (structType == nullptr) { - for (ui32 index = beginIndex; index < endIndex; ++ index) { - groupTypes.emplace_back(std::make_shared<arrow::Field>("", - NArrow::GetArrowType(tupleType->GetElementType(index)))); - } - } else { - for (ui32 index = beginIndex; index < endIndex; ++ index) { - groupTypes.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(structType->GetMemberType(index)))); - } - } - types.emplace_back(std::make_shared<arrow::Field>("", arrow::dense_union(groupTypes))); - } - } else { - // Simply put all types in one arrow::dense_union - types.reserve(variantType->GetAlternativesCount()); - if (structType == nullptr) { - for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { - types.push_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(tupleType->GetElementType(index)))); - } - } else { - for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { - types.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), - NArrow::GetArrowType(structType->GetMemberType(index)))); - } - } - } - return arrow::dense_union(types); -} - -template <typename TArrowType> -void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - auto typedBuilder = reinterpret_cast<typename arrow::TypeTraits<TArrowType>::BuilderType*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get<typename TArrowType::c_type>()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::UInt64Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::UINT64); - auto typedBuilder = reinterpret_cast<arrow::UInt64Builder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get<ui64>()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::Int64Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::INT64); - auto typedBuilder = reinterpret_cast<arrow::Int64Builder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get<i64>()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::TimestampType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::TIMESTAMP); - auto typedBuilder = reinterpret_cast<arrow::TimestampBuilder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get<ui64>()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::DurationType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::DURATION); - auto typedBuilder = reinterpret_cast<arrow::DurationBuilder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - status = typedBuilder->Append(value.Get<ui64>()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::StringType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::STRING); - auto typedBuilder = reinterpret_cast<arrow::StringBuilder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - auto data = value.AsStringRef(); - status = typedBuilder->Append(data.Data(), data.Size()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::BinaryType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::BINARY); - auto typedBuilder = reinterpret_cast<arrow::BinaryBuilder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - auto data = value.AsStringRef(); - status = typedBuilder->Append(data.Data(), data.Size()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -template <> -void AppendDataValue<arrow::Decimal128Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::DECIMAL128); - auto typedBuilder = reinterpret_cast<arrow::Decimal128Builder*>(builder); - arrow::Status status; - if (!value.HasValue()) { - status = typedBuilder->AppendNull(); - } else { - // Parse value from string - status = typedBuilder->Append(value.AsStringRef().Data()); - } - Y_VERIFY_S(status.ok(), status.ToString()); -} - -} // namespace - -std::shared_ptr<arrow::DataType> GetArrowType(const TType* type) { - switch (type->GetKind()) { - case TType::EKind::Void: - case TType::EKind::Null: - case TType::EKind::EmptyList: - case TType::EKind::EmptyDict: - break; - case TType::EKind::Data: { - auto dataType = static_cast<const TDataType*>(type); - return GetArrowType(dataType); - } - case TType::EKind::Struct: { - auto structType = static_cast<const TStructType*>(type); - return GetArrowType(structType); - } - case TType::EKind::Tuple: { - auto tupleType = static_cast<const TTupleType*>(type); - return GetArrowType(tupleType); - } - case TType::EKind::Optional: { - auto optionalType = static_cast<const TOptionalType*>(type); - auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { - std::vector<std::shared_ptr<arrow::Field>> fields; - fields.emplace_back(std::make_shared<arrow::Field>("", std::make_shared<arrow::UInt64Type>())); - while (innerOptionalType->GetKind() == TType::EKind::Optional) { - innerOptionalType = static_cast<const TOptionalType*>(innerOptionalType)->GetItemType(); - } - fields.emplace_back(std::make_shared<arrow::Field>("", GetArrowType(innerOptionalType))); - return arrow::struct_(fields); - } - return GetArrowType(innerOptionalType); - } - case TType::EKind::List: { - auto listType = static_cast<const TListType*>(type); - return GetArrowType(listType); - } - case TType::EKind::Dict: { - auto dictType = static_cast<const TDictType*>(type); - return GetArrowType(dictType); - } - case TType::EKind::Variant: { - auto variantType = static_cast<const TVariantType*>(type); - return GetArrowType(variantType); - } - default: - THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); - } - return arrow::null(); -} - -bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) { - switch (type->GetKind()) { - case TType::EKind::Void: - case TType::EKind::Null: - case TType::EKind::EmptyList: - case TType::EKind::EmptyDict: - case TType::EKind::Data: - return true; - case TType::EKind::Struct: { - auto structType = static_cast<const TStructType*>(type); - bool isCompatible = true; - for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { - auto memberType = structType->GetMemberType(index); - isCompatible = isCompatible && IsArrowCompatible(memberType); - } - return isCompatible; - } - case TType::EKind::Tuple: { - auto tupleType = static_cast<const TTupleType*>(type); - bool isCompatible = true; - for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { - auto elementType = tupleType->GetElementType(index); - isCompatible = isCompatible && IsArrowCompatible(elementType); - } - return isCompatible; - } - case TType::EKind::Optional: { - auto optionalType = static_cast<const TOptionalType*>(type); - auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { - return false; - } - return IsArrowCompatible(innerOptionalType); - } - case TType::EKind::List: { - auto listType = static_cast<const TListType*>(type); - auto itemType = listType->GetItemType(); - return IsArrowCompatible(itemType); - } - case TType::EKind::Dict: { - auto dictType = static_cast<const TDictType*>(type); - auto keyType = dictType->GetKeyType(); - auto payloadType = dictType->GetPayloadType(); - if (keyType->GetKind() == TType::EKind::Optional) { - return false; - } - return IsArrowCompatible(keyType) && IsArrowCompatible(payloadType); - } - case TType::EKind::Variant: { - auto variantType = static_cast<const TVariantType*>(type); - if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { - return false; - } - TType* innerType = variantType->GetUnderlyingType(); - Y_VERIFY_S(innerType->IsTuple() || innerType->IsStruct(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); - return IsArrowCompatible(innerType); - } - case TType::EKind::Block: - case TType::EKind::Type: - case TType::EKind::Stream: - case TType::EKind::Callable: - case TType::EKind::Any: - case TType::EKind::Resource: - case TType::EKind::ReservedKind: - case TType::EKind::Flow: - case TType::EKind::Tagged: - return false; - } - return false; -} - -std::unique_ptr<arrow::ArrayBuilder> MakeArrowBuilder(const TType* type) { - auto arrayType = GetArrowType(type); - std::unique_ptr<arrow::ArrayBuilder> builder; - auto status = arrow::MakeBuilder(arrow::default_memory_pool(), arrayType, &builder); - Y_VERIFY_S(status.ok(), status.ToString()); - return builder; -} - -void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, const TType* type) { - switch (type->GetKind()) { - case TType::EKind::Void: - case TType::EKind::Null: - case TType::EKind::EmptyList: - case TType::EKind::EmptyDict: { - auto status = builder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); - break; - } - - case TType::EKind::Data: { - // TODO for TzDate, TzDatetime, TzTimestamp pass timezone to arrow builder? - auto dataType = static_cast<const TDataType*>(type); - bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { - Y_UNUSED(typeHolder); - AppendDataValue<TType>(builder, value); - return true; - }); - Y_VERIFY(success); - break; - } - - case TType::EKind::Optional: { - auto optionalType = static_cast<const TOptionalType*>(type); - if (optionalType->GetItemType()->GetKind() != TType::EKind::Optional) { - if (value.HasValue()) { - AppendElement(value.GetOptionalValue(), builder, optionalType->GetItemType()); - } else { - auto status = builder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); - } - } else { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::STRUCT); - auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); - Y_VERIFY_DEBUG(structBuilder->num_fields() == 2); - Y_VERIFY_DEBUG(structBuilder->field_builder(0)->type()->id() == arrow::Type::UINT64); - auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - auto depthBuilder = reinterpret_cast<arrow::UInt64Builder*>(structBuilder->field_builder(0)); - auto valueBuilder = structBuilder->field_builder(1); - ui64 depth = 0; - TType* innerType = optionalType->GetItemType(); - while (innerType->GetKind() == TType::EKind::Optional && value.HasValue()) { - innerType = static_cast<const TOptionalType*>(innerType)->GetItemType(); - value = value.GetOptionalValue(); - ++depth; - } - status = depthBuilder->Append(depth); - Y_VERIFY_S(status.ok(), status.ToString()); - if (value.HasValue()) { - AppendElement(value, valueBuilder, innerType); - } else { - status = valueBuilder->AppendNull(); - Y_VERIFY_S(status.ok(), status.ToString()); - } - } - break; - } - - case TType::EKind::List: { - auto listType = static_cast<const TListType*>(type); - auto itemType = listType->GetItemType(); - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::LIST); - auto listBuilder = reinterpret_cast<arrow::ListBuilder*>(builder); - auto status = listBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - auto innerBuilder = listBuilder->value_builder(); - if (auto p = value.GetElements()) { - auto len = value.GetListLength(); - while (len > 0) { - AppendElement(*p++, innerBuilder, itemType); - --len; - } - } else { - const auto iter = value.GetListIterator(); - for (NUdf::TUnboxedValue item; iter.Next(item);) { - AppendElement(item, innerBuilder, itemType); - } - } - break; - } - - case TType::EKind::Struct: { - auto structType = static_cast<const TStructType*>(type); - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::STRUCT); - auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); - auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - Y_VERIFY_DEBUG(static_cast<ui32>(structBuilder->num_fields()) == structType->GetMembersCount()); - for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { - auto innerBuilder = structBuilder->field_builder(index); - auto memberType = structType->GetMemberType(index); - AppendElement(value.GetElement(index), innerBuilder, memberType); - } - break; - } - - case TType::EKind::Tuple: { - auto tupleType = static_cast<const TTupleType*>(type); - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::STRUCT); - auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); - auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - Y_VERIFY_DEBUG(static_cast<ui32>(structBuilder->num_fields()) == tupleType->GetElementsCount()); - for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { - auto innerBuilder = structBuilder->field_builder(index); - auto elementType = tupleType->GetElementType(index); - AppendElement(value.GetElement(index), innerBuilder, elementType); - } - break; - } - - case TType::EKind::Dict: { - auto dictType = static_cast<const TDictType*>(type); - auto keyType = dictType->GetKeyType(); - auto payloadType = dictType->GetPayloadType(); - - arrow::ArrayBuilder* keyBuilder; - arrow::ArrayBuilder* itemBuilder; - arrow::StructBuilder* structBuilder = nullptr; - if (keyType->GetKind() == TType::EKind::Optional) { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::LIST); - auto listBuilder = reinterpret_cast<arrow::ListBuilder*>(builder); - Y_VERIFY_DEBUG(listBuilder->value_builder()->type()->id() == arrow::Type::STRUCT); - // Start a new list in ListArray of structs - auto status = listBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - structBuilder = reinterpret_cast<arrow::StructBuilder*>(listBuilder->value_builder()); - Y_VERIFY_DEBUG(structBuilder->num_fields() == 2); - keyBuilder = structBuilder->field_builder(0); - itemBuilder = structBuilder->field_builder(1); - } else { - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::MAP); - auto mapBuilder = reinterpret_cast<arrow::MapBuilder*>(builder); - // Start a new map in MapArray - auto status = mapBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - keyBuilder = mapBuilder->key_builder(); - itemBuilder = mapBuilder->item_builder(); - } - - const auto iter = value.GetDictIterator(); - // We do not sort dictionary before appending it to builder. - for (NUdf::TUnboxedValue key, payload; iter.NextPair(key, payload);) { - if (structBuilder != nullptr) { - auto status = structBuilder->Append(); - Y_VERIFY_S(status.ok(), status.ToString()); - } - AppendElement(key, keyBuilder, keyType); - AppendElement(payload, itemBuilder, payloadType); - } - break; - } - - case TType::EKind::Variant: { - // TODO Need to properly convert variants containing more than 127*127 types? - auto variantType = static_cast<const TVariantType*>(type); - Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::DENSE_UNION); - auto unionBuilder = reinterpret_cast<arrow::DenseUnionBuilder*>(builder); - ui32 variantIndex = value.GetVariantIndex(); - TType* innerType = variantType->GetUnderlyingType(); - if (innerType->IsStruct()) { - innerType = static_cast<TStructType*>(innerType)->GetMemberType(variantIndex); - } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); - innerType = static_cast<TTupleType*>(innerType)->GetElementType(variantIndex); - } - if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { - ui32 numberOfGroups = (variantType->GetAlternativesCount() - 1) / arrow::UnionType::kMaxTypeCode + 1; - Y_VERIFY_DEBUG(static_cast<ui32>(unionBuilder->num_children()) == numberOfGroups); - ui32 groupIndex = variantIndex / arrow::UnionType::kMaxTypeCode; - auto status = unionBuilder->Append(groupIndex); - Y_VERIFY_S(status.ok(), status.ToString()); - auto innerBuilder = unionBuilder->child_builder(groupIndex); - Y_VERIFY_DEBUG(innerBuilder->type()->id() == arrow::Type::DENSE_UNION); - auto innerUnionBuilder = reinterpret_cast<arrow::DenseUnionBuilder*>(innerBuilder.get()); - ui32 innerVariantIndex = variantIndex % arrow::UnionType::kMaxTypeCode; - status = innerUnionBuilder->Append(innerVariantIndex); - Y_VERIFY_S(status.ok(), status.ToString()); - auto doubleInnerBuilder = innerUnionBuilder->child_builder(innerVariantIndex); - AppendElement(value.GetVariantItem(), doubleInnerBuilder.get(), innerType); - } else { - auto status = unionBuilder->Append(variantIndex); - Y_VERIFY_S(status.ok(), status.ToString()); - auto innerBuilder = unionBuilder->child_builder(variantIndex); - AppendElement(value.GetVariantItem(), innerBuilder.get(), innerType); - } - break; - } - - default: - THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); - } -} - -std::shared_ptr<arrow::Array> MakeArray(NMiniKQL::TUnboxedValueVector& values, const TType* itemType) { - auto builder = MakeArrowBuilder(itemType); - auto status = builder->Reserve(values.size()); - Y_VERIFY_S(status.ok(), status.ToString()); - for (auto& value: values) { - AppendElement(value, builder.get(), itemType); - } - std::shared_ptr<arrow::Array> result; - status = builder->Finish(&result); - Y_VERIFY_S(status.ok(), status.ToString()); - return result; -} - -NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& array, ui64 row, const TType* itemType, const NMiniKQL::THolderFactory& holderFactory) { - if (array->IsNull(row)) { - return NUdf::TUnboxedValuePod(); - } - switch(itemType->GetKind()) { - case TType::EKind::Void: - case TType::EKind::Null: - case TType::EKind::EmptyList: - case TType::EKind::EmptyDict: - break; - case TType::EKind::Data: { // TODO TzDate need special care - auto dataType = static_cast<const TDataType*>(itemType); - NUdf::TUnboxedValue result; - bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { - Y_UNUSED(typeHolder); - result = GetUnboxedValue<TType>(array, row); - return true; - }); - Y_VERIFY_DEBUG(success); - return result; - } - case TType::EKind::Struct: { - auto structType = static_cast<const TStructType*>(itemType); - Y_VERIFY_DEBUG(array->type_id() == arrow::Type::STRUCT); - auto typedArray = static_pointer_cast<arrow::StructArray>(array); - Y_VERIFY_DEBUG(static_cast<ui32>(typedArray->num_fields()) == structType->GetMembersCount()); - NUdf::TUnboxedValue* itemsPtr = nullptr; - auto result = holderFactory.CreateDirectArrayHolder(structType->GetMembersCount(), itemsPtr); - for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { - auto memberType = structType->GetMemberType(index); - itemsPtr[index] = ExtractUnboxedValue(typedArray->field(index), row, memberType, holderFactory); - } - return result; - } - case TType::EKind::Tuple: { - auto tupleType = static_cast<const TTupleType*>(itemType); - Y_VERIFY_DEBUG(array->type_id() == arrow::Type::STRUCT); - auto typedArray = static_pointer_cast<arrow::StructArray>(array); - Y_VERIFY_DEBUG(static_cast<ui32>(typedArray->num_fields()) == tupleType->GetElementsCount()); - NUdf::TUnboxedValue* itemsPtr = nullptr; - auto result = holderFactory.CreateDirectArrayHolder(tupleType->GetElementsCount(), itemsPtr); - for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { - auto elementType = tupleType->GetElementType(index); - itemsPtr[index] = ExtractUnboxedValue(typedArray->field(index), row, elementType, holderFactory); - } - return result; - } - case TType::EKind::Optional: { - auto optionalType = static_cast<const TOptionalType*>(itemType); - auto innerOptionalType = optionalType->GetItemType(); - if (innerOptionalType->GetKind() == TType::EKind::Optional) { - Y_VERIFY_DEBUG(array->type_id() == arrow::Type::STRUCT); - auto structArray = static_pointer_cast<arrow::StructArray>(array); - Y_VERIFY_DEBUG(structArray->num_fields() == 2); - Y_VERIFY_DEBUG(structArray->field(0)->type_id() == arrow::Type::UINT64); - auto depthArray = static_pointer_cast<arrow::UInt64Array>(structArray->field(0)); - auto valuesArray = structArray->field(1); - auto depth = depthArray->Value(row); - NUdf::TUnboxedValue value; - if (valuesArray->IsNull(row)) { - value = NUdf::TUnboxedValuePod(); - } else { - while (innerOptionalType->GetKind() == TType::EKind::Optional) { - innerOptionalType = static_cast<const TOptionalType*>(innerOptionalType)->GetItemType(); - } - value = ExtractUnboxedValue(valuesArray, row, innerOptionalType, holderFactory); - } - for (ui64 i = 0; i < depth; ++i) { - value = value.MakeOptional(); - } - return value; - } else { - return ExtractUnboxedValue(array, row, innerOptionalType, holderFactory).Release().MakeOptional(); - } - } - case TType::EKind::List: { - auto listType = static_cast<const TListType*>(itemType); - Y_VERIFY_DEBUG(array->type_id() == arrow::Type::LIST); - auto typedArray = static_pointer_cast<arrow::ListArray>(array); - auto arraySlice = typedArray->value_slice(row); - auto itemType = listType->GetItemType(); - const auto len = arraySlice->length(); - NUdf::TUnboxedValue *items = nullptr; - auto list = holderFactory.CreateDirectArrayHolder(len, items); - for (ui64 i = 0; i < static_cast<ui64>(len); ++i) { - *items++ = ExtractUnboxedValue(arraySlice, i, itemType, holderFactory); - } - return list; - } - case TType::EKind::Dict: { - auto dictType = static_cast<const TDictType*>(itemType); - auto keyType = dictType->GetKeyType(); - auto payloadType = dictType->GetPayloadType(); - auto dictBuilder = holderFactory.NewDict(dictType, NUdf::TDictFlags::EDictKind::Hashed); - - std::shared_ptr<arrow::Array> keyArray = nullptr; - std::shared_ptr<arrow::Array> payloadArray = nullptr; - ui64 dictLength = 0; - ui64 offset = 0; - if (keyType->GetKind() == TType::EKind::Optional) { - Y_VERIFY_DEBUG(array->type_id() == arrow::Type::LIST); - auto listArray = static_pointer_cast<arrow::ListArray>(array); - auto arraySlice = listArray->value_slice(row); - Y_VERIFY_DEBUG(arraySlice->type_id() == arrow::Type::STRUCT); - auto structArray = static_pointer_cast<arrow::StructArray>(arraySlice); - Y_VERIFY_DEBUG(structArray->num_fields() == 2); - dictLength = arraySlice->length(); - keyArray = structArray->field(0); - payloadArray = structArray->field(1); - } else { - Y_VERIFY_DEBUG(array->type_id() == arrow::Type::MAP); - auto mapArray = static_pointer_cast<arrow::MapArray>(array); - dictLength = mapArray->value_length(row); - offset = mapArray->value_offset(row); - keyArray = mapArray->keys(); - payloadArray = mapArray->items(); - } - for (ui64 i = offset; i < offset + static_cast<ui64>(dictLength); ++i) { - auto key = ExtractUnboxedValue(keyArray, i, keyType, holderFactory); - auto payload = ExtractUnboxedValue(payloadArray, i, payloadType, holderFactory); - dictBuilder->Add(std::move(key), std::move(payload)); - } - return dictBuilder->Build(); - } - case TType::EKind::Variant: { - // TODO Need to properly convert variants containing more than 127*127 types? - auto variantType = static_cast<const TVariantType*>(itemType); - Y_VERIFY_DEBUG(array->type_id() == arrow::Type::DENSE_UNION); - auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); - auto variantIndex = unionArray->child_id(row); - auto rowInChild = unionArray->value_offset(row); - std::shared_ptr<arrow::Array> valuesArray = unionArray->field(variantIndex); - if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { - // Go one step deeper - Y_VERIFY_DEBUG(valuesArray->type_id() == arrow::Type::DENSE_UNION); - auto innerUnionArray = static_pointer_cast<arrow::DenseUnionArray>(valuesArray); - auto innerVariantIndex = innerUnionArray->child_id(rowInChild); - rowInChild = innerUnionArray->value_offset(rowInChild); - valuesArray = innerUnionArray->field(innerVariantIndex); - variantIndex = variantIndex * arrow::UnionType::kMaxTypeCode + innerVariantIndex; - } - TType* innerType = variantType->GetUnderlyingType(); - if (innerType->IsStruct()) { - innerType = static_cast<TStructType*>(innerType)->GetMemberType(variantIndex); - } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); - innerType = static_cast<TTupleType*>(innerType)->GetElementType(variantIndex); - } - NUdf::TUnboxedValue value = ExtractUnboxedValue(valuesArray, rowInChild, innerType, holderFactory); - return holderFactory.CreateVariantHolder(value.Release(), variantIndex); - } - default: - THROW yexception() << "Unsupported type: " << itemType->GetKindAsStr(); - } - return NUdf::TUnboxedValuePod(); -} - -NMiniKQL::TUnboxedValueVector ExtractUnboxedValues(const std::shared_ptr<arrow::Array>& array, const TType* itemType, const NMiniKQL::THolderFactory& holderFactory) { - NMiniKQL::TUnboxedValueVector values; - values.reserve(array->length()); - for (auto i = 0 ; i < array->length(); ++i) { - values.push_back(ExtractUnboxedValue(array, i, itemType, holderFactory)); - } - return values; -} - -std::string SerializeArray(const std::shared_ptr<arrow::Array>& array) { - auto schema = std::make_shared<arrow::Schema>(std::vector<std::shared_ptr<arrow::Field>>{arrow::field("", array->type())}); - auto batch = arrow::RecordBatch::Make(schema, array->length(), {array}); - auto writeOptions = arrow::ipc::IpcWriteOptions::Defaults(); // no compression set - writeOptions.use_threads = false; - // TODO Decide which compression level will be default. Will it depend on the length of array? - auto codecResult = arrow::util::Codec::Create(arrow::Compression::LZ4_FRAME); - Y_VERIFY(codecResult.ok()); - writeOptions.codec = std::move(codecResult.ValueOrDie()); - int64_t size; - auto status = GetRecordBatchSize(*batch, writeOptions, &size); - Y_VERIFY(status.ok()); - - std::string str; - str.resize(size); - - auto writer = arrow::Buffer::GetWriter(arrow::MutableBuffer::Wrap(&str[0], size)); - Y_VERIFY(writer.status().ok()); - - status = SerializeRecordBatch(*batch, writeOptions, (*writer).get()); - Y_VERIFY(status.ok()); - return str; -} - -std::shared_ptr<arrow::Array> DeserializeArray(const std::string& blob, std::shared_ptr<arrow::DataType> type) { - arrow::ipc::DictionaryMemo dictMemo; - auto options = arrow::ipc::IpcReadOptions::Defaults(); - options.use_threads = false; - - std::shared_ptr<arrow::Buffer> buffer = arrow::Buffer::FromString(blob); - arrow::io::BufferReader reader(buffer); - auto schema = std::make_shared<arrow::Schema>(std::vector<std::shared_ptr<arrow::Field>>{arrow::field("", type)}); - auto batch = ReadRecordBatch(schema, &dictMemo, options, &reader); - Y_VERIFY_DEBUG(batch.ok() && (*batch)->ValidateFull().ok(), "Failed to deserialize batch"); - return (*batch)->column(0); -} - -} // namespace NArrow + +#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_base.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/buffer.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/writer.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/type.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/util/compression.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/util/type_fwd.h> + +#include <util/system/compiler.h> +#include <util/system/yassert.h> + +namespace NYql { +namespace NArrow { + +using namespace NKikimr; +using namespace NMiniKQL; + +namespace { + +template <typename TArrowType> +struct TTypeWrapper +{ + using T = TArrowType; +}; + +/** + * @brief Function to switch MiniKQL DataType correctly and uniformly converting it to arrow type using callback + * + * @tparam TFunc Callback type + * @param typeId Type callback work with. + * @param callback Template function of signature (TTypeWrapper) -> bool + * @return Result of execution of callback or false if the type typeId is not supported. + */ +template <typename TFunc> +bool SwitchMiniKQLDataTypeToArrowType(NUdf::EDataSlot type, TFunc&& callback) { + switch (type) { + case NUdf::EDataSlot::Bool: + return callback(TTypeWrapper<arrow::BooleanType>()); + case NUdf::EDataSlot::Int8: + return callback(TTypeWrapper<arrow::Int8Type>()); + case NUdf::EDataSlot::Uint8: + return callback(TTypeWrapper<arrow::UInt8Type>()); + case NUdf::EDataSlot::Int16: + return callback(TTypeWrapper<arrow::Int16Type>()); + case NUdf::EDataSlot::Date: + case NUdf::EDataSlot::Uint16: + return callback(TTypeWrapper<arrow::UInt16Type>()); + case NUdf::EDataSlot::Int32: + return callback(TTypeWrapper<arrow::Int32Type>()); + case NUdf::EDataSlot::Datetime: + case NUdf::EDataSlot::Uint32: + return callback(TTypeWrapper<arrow::UInt32Type>()); + case NUdf::EDataSlot::Int64: + return callback(TTypeWrapper<arrow::Int64Type>()); + case NUdf::EDataSlot::Uint64: + return callback(TTypeWrapper<arrow::UInt64Type>()); + case NUdf::EDataSlot::Float: + return callback(TTypeWrapper<arrow::FloatType>()); + case NUdf::EDataSlot::Double: + return callback(TTypeWrapper<arrow::DoubleType>()); + case NUdf::EDataSlot::Timestamp: + return callback(TTypeWrapper<arrow::TimestampType>()); + case NUdf::EDataSlot::Interval: + return callback(TTypeWrapper<arrow::DurationType>()); + case NUdf::EDataSlot::Utf8: + case NUdf::EDataSlot::Json: + case NUdf::EDataSlot::Yson: + case NUdf::EDataSlot::JsonDocument: + return callback(TTypeWrapper<arrow::StringType>()); + case NUdf::EDataSlot::String: + case NUdf::EDataSlot::Uuid: + case NUdf::EDataSlot::DyNumber: + return callback(TTypeWrapper<arrow::BinaryType>()); + case NUdf::EDataSlot::Decimal: + return callback(TTypeWrapper<arrow::Decimal128Type>()); + // TODO convert Tz-types to native arrow date and time types. + case NUdf::EDataSlot::TzDate: + case NUdf::EDataSlot::TzDatetime: + case NUdf::EDataSlot::TzTimestamp: + return false; + } +} + +template <typename TArrowType> +NUdf::TUnboxedValue GetUnboxedValue(std::shared_ptr<arrow::Array> column, ui32 row) { + using TArrayType = typename arrow::TypeTraits<TArrowType>::ArrayType; + auto array = std::static_pointer_cast<TArrayType>(column); + return NUdf::TUnboxedValuePod(static_cast<typename TArrowType::c_type>(array->Value(row))); +} + +// The following 4 specialization are for darwin build (because of difference in long long) + +template <> // For darwin build +NUdf::TUnboxedValue GetUnboxedValue<arrow::UInt64Type>(std::shared_ptr<arrow::Array> column, ui32 row) { + auto array = std::static_pointer_cast<arrow::UInt64Array>(column); + return NUdf::TUnboxedValuePod(static_cast<ui64>(array->Value(row))); +} + +template <> // For darwin build +NUdf::TUnboxedValue GetUnboxedValue<arrow::Int64Type>(std::shared_ptr<arrow::Array> column, ui32 row) { + auto array = std::static_pointer_cast<arrow::Int64Array>(column); + return NUdf::TUnboxedValuePod(static_cast<i64>(array->Value(row))); +} + +template <> // For darwin build +NUdf::TUnboxedValue GetUnboxedValue<arrow::TimestampType>(std::shared_ptr<arrow::Array> column, ui32 row) { + using TArrayType = typename arrow::TypeTraits<arrow::TimestampType>::ArrayType; + auto array = std::static_pointer_cast<TArrayType>(column); + return NUdf::TUnboxedValuePod(static_cast<ui64>(array->Value(row))); +} + +template <> // For darwin build +NUdf::TUnboxedValue GetUnboxedValue<arrow::DurationType>(std::shared_ptr<arrow::Array> column, ui32 row) { + using TArrayType = typename arrow::TypeTraits<arrow::DurationType>::ArrayType; + auto array = std::static_pointer_cast<TArrayType>(column); + return NUdf::TUnboxedValuePod(static_cast<ui64>(array->Value(row))); +} + +template <> +NUdf::TUnboxedValue GetUnboxedValue<arrow::BinaryType>(std::shared_ptr<arrow::Array> column, ui32 row) { + auto array = std::static_pointer_cast<arrow::BinaryArray>(column); + auto data = array->GetView(row); + return NMiniKQL::MakeString(NUdf::TStringRef(data.data(), data.size())); +} + +template <> +NUdf::TUnboxedValue GetUnboxedValue<arrow::StringType>(std::shared_ptr<arrow::Array> column, ui32 row) { + auto array = std::static_pointer_cast<arrow::StringArray>(column); + auto data = array->GetView(row); + return NMiniKQL::MakeString(NUdf::TStringRef(data.data(), data.size())); +} + +template <> +NUdf::TUnboxedValue GetUnboxedValue<arrow::Decimal128Type>(std::shared_ptr<arrow::Array> column, ui32 row) { + auto array = std::static_pointer_cast<arrow::Decimal128Array>(column); + auto data = array->GetView(row); + // We check that Decimal(22,9) but it may not be true + // TODO Support other decimal precisions. + const auto& type = arrow::internal::checked_cast<const arrow::Decimal128Type&>(*array->type()); + Y_VERIFY(type.precision() == NScheme::DECIMAL_PRECISION, "Unsupported Decimal precision."); + Y_VERIFY(type.scale() == NScheme::DECIMAL_SCALE, "Unsupported Decimal scale."); + Y_VERIFY(data.size() == sizeof(NYql::NDecimal::TInt128), "Wrong data size"); + NYql::NDecimal::TInt128 val; + std::memcpy(reinterpret_cast<char*>(&val), data.data(), data.size()); + return NUdf::TUnboxedValuePod(val); +} + +template <typename TType> +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl() { + return std::make_shared<TType>(); +} + +template <> +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::Decimal128Type>() { + // TODO use non-fixed precision, derive it from data. + return arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE); +} + +template <> +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::TimestampType>() { + return arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO); +} + +template <> +std::shared_ptr<arrow::DataType> CreateEmptyArrowImpl<arrow::DurationType>() { + return arrow::duration(arrow::TimeUnit::TimeUnit::MICRO); +} + +std::shared_ptr<arrow::DataType> GetArrowType(const TDataType* dataType) { + std::shared_ptr<arrow::DataType> result; + bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { + Y_UNUSED(typeHolder); + result = CreateEmptyArrowImpl<TType>(); + return true; + }); + if (success) { + return result; + } + return std::make_shared<arrow::NullType>(); +} + +std::shared_ptr<arrow::DataType> GetArrowType(const TStructType* structType) { + std::vector<std::shared_ptr<arrow::Field>> fields; + fields.reserve(structType->GetMembersCount()); + for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { + auto memberType = structType->GetMemberType(index); + fields.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), + NArrow::GetArrowType(memberType))); + } + return arrow::struct_(fields); +} + +std::shared_ptr<arrow::DataType> GetArrowType(const TTupleType* tupleType) { + std::vector<std::shared_ptr<arrow::Field>> fields; + fields.reserve(tupleType->GetElementsCount()); + for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { + auto elementType = tupleType->GetElementType(index); + fields.push_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(elementType))); + } + return arrow::struct_(fields); +} + +std::shared_ptr<arrow::DataType> GetArrowType(const TListType* listType) { + auto itemType = listType->GetItemType(); + return arrow::list(NArrow::GetArrowType(itemType)); +} + +std::shared_ptr<arrow::DataType> GetArrowType(const TDictType* dictType) { + auto keyType = dictType->GetKeyType(); + auto payloadType = dictType->GetPayloadType(); + if (keyType->GetKind() == TType::EKind::Optional) { + std::vector<std::shared_ptr<arrow::Field>> fields; + fields.emplace_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(keyType))); + fields.emplace_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(payloadType))); + return arrow::list(arrow::struct_(fields)); + } + return arrow::map(NArrow::GetArrowType(keyType), NArrow::GetArrowType(payloadType)); +} + +std::shared_ptr<arrow::DataType> GetArrowType(const TVariantType* variantType) { + TType* innerType = variantType->GetUnderlyingType(); + arrow::FieldVector types; + TStructType* structType = nullptr; + TTupleType* tupleType = nullptr; + if (innerType->IsStruct()) { + structType = static_cast<TStructType*>(innerType); + } else { + Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + tupleType = static_cast<TTupleType*>(innerType); + } + + if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { + // Create Union of unions if there are more types then arrow::dense_union supports. + ui32 numberOfGroups = (variantType->GetAlternativesCount() - 1) / arrow::UnionType::kMaxTypeCode + 1; + types.reserve(numberOfGroups); + for (ui32 groupIndex = 0; groupIndex < numberOfGroups; ++groupIndex) { + ui32 beginIndex = groupIndex * arrow::UnionType::kMaxTypeCode; + ui32 endIndex = std::min((groupIndex + 1) * arrow::UnionType::kMaxTypeCode, variantType->GetAlternativesCount()); + arrow::FieldVector groupTypes; + groupTypes.reserve(endIndex - beginIndex); + if (structType == nullptr) { + for (ui32 index = beginIndex; index < endIndex; ++ index) { + groupTypes.emplace_back(std::make_shared<arrow::Field>("", + NArrow::GetArrowType(tupleType->GetElementType(index)))); + } + } else { + for (ui32 index = beginIndex; index < endIndex; ++ index) { + groupTypes.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), + NArrow::GetArrowType(structType->GetMemberType(index)))); + } + } + types.emplace_back(std::make_shared<arrow::Field>("", arrow::dense_union(groupTypes))); + } + } else { + // Simply put all types in one arrow::dense_union + types.reserve(variantType->GetAlternativesCount()); + if (structType == nullptr) { + for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { + types.push_back(std::make_shared<arrow::Field>("", NArrow::GetArrowType(tupleType->GetElementType(index)))); + } + } else { + for (ui32 index = 0; index < variantType->GetAlternativesCount(); ++index) { + types.emplace_back(std::make_shared<arrow::Field>(std::string(structType->GetMemberName(index)), + NArrow::GetArrowType(structType->GetMemberType(index)))); + } + } + } + return arrow::dense_union(types); +} + +template <typename TArrowType> +void AppendDataValue(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + auto typedBuilder = reinterpret_cast<typename arrow::TypeTraits<TArrowType>::BuilderType*>(builder); + arrow::Status status; + if (!value.HasValue()) { + status = typedBuilder->AppendNull(); + } else { + status = typedBuilder->Append(value.Get<typename TArrowType::c_type>()); + } + Y_VERIFY_S(status.ok(), status.ToString()); +} + +template <> +void AppendDataValue<arrow::UInt64Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::UINT64); + auto typedBuilder = reinterpret_cast<arrow::UInt64Builder*>(builder); + arrow::Status status; + if (!value.HasValue()) { + status = typedBuilder->AppendNull(); + } else { + status = typedBuilder->Append(value.Get<ui64>()); + } + Y_VERIFY_S(status.ok(), status.ToString()); +} + +template <> +void AppendDataValue<arrow::Int64Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::INT64); + auto typedBuilder = reinterpret_cast<arrow::Int64Builder*>(builder); + arrow::Status status; + if (!value.HasValue()) { + status = typedBuilder->AppendNull(); + } else { + status = typedBuilder->Append(value.Get<i64>()); + } + Y_VERIFY_S(status.ok(), status.ToString()); +} + +template <> +void AppendDataValue<arrow::TimestampType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::TIMESTAMP); + auto typedBuilder = reinterpret_cast<arrow::TimestampBuilder*>(builder); + arrow::Status status; + if (!value.HasValue()) { + status = typedBuilder->AppendNull(); + } else { + status = typedBuilder->Append(value.Get<ui64>()); + } + Y_VERIFY_S(status.ok(), status.ToString()); +} + +template <> +void AppendDataValue<arrow::DurationType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::DURATION); + auto typedBuilder = reinterpret_cast<arrow::DurationBuilder*>(builder); + arrow::Status status; + if (!value.HasValue()) { + status = typedBuilder->AppendNull(); + } else { + status = typedBuilder->Append(value.Get<ui64>()); + } + Y_VERIFY_S(status.ok(), status.ToString()); +} + +template <> +void AppendDataValue<arrow::StringType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::STRING); + auto typedBuilder = reinterpret_cast<arrow::StringBuilder*>(builder); + arrow::Status status; + if (!value.HasValue()) { + status = typedBuilder->AppendNull(); + } else { + auto data = value.AsStringRef(); + status = typedBuilder->Append(data.Data(), data.Size()); + } + Y_VERIFY_S(status.ok(), status.ToString()); +} + +template <> +void AppendDataValue<arrow::BinaryType>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::BINARY); + auto typedBuilder = reinterpret_cast<arrow::BinaryBuilder*>(builder); + arrow::Status status; + if (!value.HasValue()) { + status = typedBuilder->AppendNull(); + } else { + auto data = value.AsStringRef(); + status = typedBuilder->Append(data.Data(), data.Size()); + } + Y_VERIFY_S(status.ok(), status.ToString()); +} + +template <> +void AppendDataValue<arrow::Decimal128Type>(arrow::ArrayBuilder* builder, NUdf::TUnboxedValue value) { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::DECIMAL128); + auto typedBuilder = reinterpret_cast<arrow::Decimal128Builder*>(builder); + arrow::Status status; + if (!value.HasValue()) { + status = typedBuilder->AppendNull(); + } else { + // Parse value from string + status = typedBuilder->Append(value.AsStringRef().Data()); + } + Y_VERIFY_S(status.ok(), status.ToString()); +} + +} // namespace + +std::shared_ptr<arrow::DataType> GetArrowType(const TType* type) { + switch (type->GetKind()) { + case TType::EKind::Void: + case TType::EKind::Null: + case TType::EKind::EmptyList: + case TType::EKind::EmptyDict: + break; + case TType::EKind::Data: { + auto dataType = static_cast<const TDataType*>(type); + return GetArrowType(dataType); + } + case TType::EKind::Struct: { + auto structType = static_cast<const TStructType*>(type); + return GetArrowType(structType); + } + case TType::EKind::Tuple: { + auto tupleType = static_cast<const TTupleType*>(type); + return GetArrowType(tupleType); + } + case TType::EKind::Optional: { + auto optionalType = static_cast<const TOptionalType*>(type); + auto innerOptionalType = optionalType->GetItemType(); + if (innerOptionalType->GetKind() == TType::EKind::Optional) { + std::vector<std::shared_ptr<arrow::Field>> fields; + fields.emplace_back(std::make_shared<arrow::Field>("", std::make_shared<arrow::UInt64Type>())); + while (innerOptionalType->GetKind() == TType::EKind::Optional) { + innerOptionalType = static_cast<const TOptionalType*>(innerOptionalType)->GetItemType(); + } + fields.emplace_back(std::make_shared<arrow::Field>("", GetArrowType(innerOptionalType))); + return arrow::struct_(fields); + } + return GetArrowType(innerOptionalType); + } + case TType::EKind::List: { + auto listType = static_cast<const TListType*>(type); + return GetArrowType(listType); + } + case TType::EKind::Dict: { + auto dictType = static_cast<const TDictType*>(type); + return GetArrowType(dictType); + } + case TType::EKind::Variant: { + auto variantType = static_cast<const TVariantType*>(type); + return GetArrowType(variantType); + } + default: + THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); + } + return arrow::null(); +} + +bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) { + switch (type->GetKind()) { + case TType::EKind::Void: + case TType::EKind::Null: + case TType::EKind::EmptyList: + case TType::EKind::EmptyDict: + case TType::EKind::Data: + return true; + case TType::EKind::Struct: { + auto structType = static_cast<const TStructType*>(type); + bool isCompatible = true; + for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { + auto memberType = structType->GetMemberType(index); + isCompatible = isCompatible && IsArrowCompatible(memberType); + } + return isCompatible; + } + case TType::EKind::Tuple: { + auto tupleType = static_cast<const TTupleType*>(type); + bool isCompatible = true; + for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { + auto elementType = tupleType->GetElementType(index); + isCompatible = isCompatible && IsArrowCompatible(elementType); + } + return isCompatible; + } + case TType::EKind::Optional: { + auto optionalType = static_cast<const TOptionalType*>(type); + auto innerOptionalType = optionalType->GetItemType(); + if (innerOptionalType->GetKind() == TType::EKind::Optional) { + return false; + } + return IsArrowCompatible(innerOptionalType); + } + case TType::EKind::List: { + auto listType = static_cast<const TListType*>(type); + auto itemType = listType->GetItemType(); + return IsArrowCompatible(itemType); + } + case TType::EKind::Dict: { + auto dictType = static_cast<const TDictType*>(type); + auto keyType = dictType->GetKeyType(); + auto payloadType = dictType->GetPayloadType(); + if (keyType->GetKind() == TType::EKind::Optional) { + return false; + } + return IsArrowCompatible(keyType) && IsArrowCompatible(payloadType); + } + case TType::EKind::Variant: { + auto variantType = static_cast<const TVariantType*>(type); + if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { + return false; + } + TType* innerType = variantType->GetUnderlyingType(); + Y_VERIFY_S(innerType->IsTuple() || innerType->IsStruct(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + return IsArrowCompatible(innerType); + } + case TType::EKind::Block: + case TType::EKind::Type: + case TType::EKind::Stream: + case TType::EKind::Callable: + case TType::EKind::Any: + case TType::EKind::Resource: + case TType::EKind::ReservedKind: + case TType::EKind::Flow: + case TType::EKind::Tagged: + return false; + } + return false; +} + +std::unique_ptr<arrow::ArrayBuilder> MakeArrowBuilder(const TType* type) { + auto arrayType = GetArrowType(type); + std::unique_ptr<arrow::ArrayBuilder> builder; + auto status = arrow::MakeBuilder(arrow::default_memory_pool(), arrayType, &builder); + Y_VERIFY_S(status.ok(), status.ToString()); + return builder; +} + +void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, const TType* type) { + switch (type->GetKind()) { + case TType::EKind::Void: + case TType::EKind::Null: + case TType::EKind::EmptyList: + case TType::EKind::EmptyDict: { + auto status = builder->AppendNull(); + Y_VERIFY_S(status.ok(), status.ToString()); + break; + } + + case TType::EKind::Data: { + // TODO for TzDate, TzDatetime, TzTimestamp pass timezone to arrow builder? + auto dataType = static_cast<const TDataType*>(type); + bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { + Y_UNUSED(typeHolder); + AppendDataValue<TType>(builder, value); + return true; + }); + Y_VERIFY(success); + break; + } + + case TType::EKind::Optional: { + auto optionalType = static_cast<const TOptionalType*>(type); + if (optionalType->GetItemType()->GetKind() != TType::EKind::Optional) { + if (value.HasValue()) { + AppendElement(value.GetOptionalValue(), builder, optionalType->GetItemType()); + } else { + auto status = builder->AppendNull(); + Y_VERIFY_S(status.ok(), status.ToString()); + } + } else { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::STRUCT); + auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); + Y_VERIFY_DEBUG(structBuilder->num_fields() == 2); + Y_VERIFY_DEBUG(structBuilder->field_builder(0)->type()->id() == arrow::Type::UINT64); + auto status = structBuilder->Append(); + Y_VERIFY_S(status.ok(), status.ToString()); + auto depthBuilder = reinterpret_cast<arrow::UInt64Builder*>(structBuilder->field_builder(0)); + auto valueBuilder = structBuilder->field_builder(1); + ui64 depth = 0; + TType* innerType = optionalType->GetItemType(); + while (innerType->GetKind() == TType::EKind::Optional && value.HasValue()) { + innerType = static_cast<const TOptionalType*>(innerType)->GetItemType(); + value = value.GetOptionalValue(); + ++depth; + } + status = depthBuilder->Append(depth); + Y_VERIFY_S(status.ok(), status.ToString()); + if (value.HasValue()) { + AppendElement(value, valueBuilder, innerType); + } else { + status = valueBuilder->AppendNull(); + Y_VERIFY_S(status.ok(), status.ToString()); + } + } + break; + } + + case TType::EKind::List: { + auto listType = static_cast<const TListType*>(type); + auto itemType = listType->GetItemType(); + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::LIST); + auto listBuilder = reinterpret_cast<arrow::ListBuilder*>(builder); + auto status = listBuilder->Append(); + Y_VERIFY_S(status.ok(), status.ToString()); + auto innerBuilder = listBuilder->value_builder(); + if (auto p = value.GetElements()) { + auto len = value.GetListLength(); + while (len > 0) { + AppendElement(*p++, innerBuilder, itemType); + --len; + } + } else { + const auto iter = value.GetListIterator(); + for (NUdf::TUnboxedValue item; iter.Next(item);) { + AppendElement(item, innerBuilder, itemType); + } + } + break; + } + + case TType::EKind::Struct: { + auto structType = static_cast<const TStructType*>(type); + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::STRUCT); + auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); + auto status = structBuilder->Append(); + Y_VERIFY_S(status.ok(), status.ToString()); + Y_VERIFY_DEBUG(static_cast<ui32>(structBuilder->num_fields()) == structType->GetMembersCount()); + for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { + auto innerBuilder = structBuilder->field_builder(index); + auto memberType = structType->GetMemberType(index); + AppendElement(value.GetElement(index), innerBuilder, memberType); + } + break; + } + + case TType::EKind::Tuple: { + auto tupleType = static_cast<const TTupleType*>(type); + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::STRUCT); + auto structBuilder = reinterpret_cast<arrow::StructBuilder*>(builder); + auto status = structBuilder->Append(); + Y_VERIFY_S(status.ok(), status.ToString()); + Y_VERIFY_DEBUG(static_cast<ui32>(structBuilder->num_fields()) == tupleType->GetElementsCount()); + for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { + auto innerBuilder = structBuilder->field_builder(index); + auto elementType = tupleType->GetElementType(index); + AppendElement(value.GetElement(index), innerBuilder, elementType); + } + break; + } + + case TType::EKind::Dict: { + auto dictType = static_cast<const TDictType*>(type); + auto keyType = dictType->GetKeyType(); + auto payloadType = dictType->GetPayloadType(); + + arrow::ArrayBuilder* keyBuilder; + arrow::ArrayBuilder* itemBuilder; + arrow::StructBuilder* structBuilder = nullptr; + if (keyType->GetKind() == TType::EKind::Optional) { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::LIST); + auto listBuilder = reinterpret_cast<arrow::ListBuilder*>(builder); + Y_VERIFY_DEBUG(listBuilder->value_builder()->type()->id() == arrow::Type::STRUCT); + // Start a new list in ListArray of structs + auto status = listBuilder->Append(); + Y_VERIFY_S(status.ok(), status.ToString()); + structBuilder = reinterpret_cast<arrow::StructBuilder*>(listBuilder->value_builder()); + Y_VERIFY_DEBUG(structBuilder->num_fields() == 2); + keyBuilder = structBuilder->field_builder(0); + itemBuilder = structBuilder->field_builder(1); + } else { + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::MAP); + auto mapBuilder = reinterpret_cast<arrow::MapBuilder*>(builder); + // Start a new map in MapArray + auto status = mapBuilder->Append(); + Y_VERIFY_S(status.ok(), status.ToString()); + keyBuilder = mapBuilder->key_builder(); + itemBuilder = mapBuilder->item_builder(); + } + + const auto iter = value.GetDictIterator(); + // We do not sort dictionary before appending it to builder. + for (NUdf::TUnboxedValue key, payload; iter.NextPair(key, payload);) { + if (structBuilder != nullptr) { + auto status = structBuilder->Append(); + Y_VERIFY_S(status.ok(), status.ToString()); + } + AppendElement(key, keyBuilder, keyType); + AppendElement(payload, itemBuilder, payloadType); + } + break; + } + + case TType::EKind::Variant: { + // TODO Need to properly convert variants containing more than 127*127 types? + auto variantType = static_cast<const TVariantType*>(type); + Y_VERIFY_DEBUG(builder->type()->id() == arrow::Type::DENSE_UNION); + auto unionBuilder = reinterpret_cast<arrow::DenseUnionBuilder*>(builder); + ui32 variantIndex = value.GetVariantIndex(); + TType* innerType = variantType->GetUnderlyingType(); + if (innerType->IsStruct()) { + innerType = static_cast<TStructType*>(innerType)->GetMemberType(variantIndex); + } else { + Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + innerType = static_cast<TTupleType*>(innerType)->GetElementType(variantIndex); + } + if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { + ui32 numberOfGroups = (variantType->GetAlternativesCount() - 1) / arrow::UnionType::kMaxTypeCode + 1; + Y_VERIFY_DEBUG(static_cast<ui32>(unionBuilder->num_children()) == numberOfGroups); + ui32 groupIndex = variantIndex / arrow::UnionType::kMaxTypeCode; + auto status = unionBuilder->Append(groupIndex); + Y_VERIFY_S(status.ok(), status.ToString()); + auto innerBuilder = unionBuilder->child_builder(groupIndex); + Y_VERIFY_DEBUG(innerBuilder->type()->id() == arrow::Type::DENSE_UNION); + auto innerUnionBuilder = reinterpret_cast<arrow::DenseUnionBuilder*>(innerBuilder.get()); + ui32 innerVariantIndex = variantIndex % arrow::UnionType::kMaxTypeCode; + status = innerUnionBuilder->Append(innerVariantIndex); + Y_VERIFY_S(status.ok(), status.ToString()); + auto doubleInnerBuilder = innerUnionBuilder->child_builder(innerVariantIndex); + AppendElement(value.GetVariantItem(), doubleInnerBuilder.get(), innerType); + } else { + auto status = unionBuilder->Append(variantIndex); + Y_VERIFY_S(status.ok(), status.ToString()); + auto innerBuilder = unionBuilder->child_builder(variantIndex); + AppendElement(value.GetVariantItem(), innerBuilder.get(), innerType); + } + break; + } + + default: + THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); + } +} + +std::shared_ptr<arrow::Array> MakeArray(NMiniKQL::TUnboxedValueVector& values, const TType* itemType) { + auto builder = MakeArrowBuilder(itemType); + auto status = builder->Reserve(values.size()); + Y_VERIFY_S(status.ok(), status.ToString()); + for (auto& value: values) { + AppendElement(value, builder.get(), itemType); + } + std::shared_ptr<arrow::Array> result; + status = builder->Finish(&result); + Y_VERIFY_S(status.ok(), status.ToString()); + return result; +} + +NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& array, ui64 row, const TType* itemType, const NMiniKQL::THolderFactory& holderFactory) { + if (array->IsNull(row)) { + return NUdf::TUnboxedValuePod(); + } + switch(itemType->GetKind()) { + case TType::EKind::Void: + case TType::EKind::Null: + case TType::EKind::EmptyList: + case TType::EKind::EmptyDict: + break; + case TType::EKind::Data: { // TODO TzDate need special care + auto dataType = static_cast<const TDataType*>(itemType); + NUdf::TUnboxedValue result; + bool success = SwitchMiniKQLDataTypeToArrowType(*dataType->GetDataSlot().Get(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) { + Y_UNUSED(typeHolder); + result = GetUnboxedValue<TType>(array, row); + return true; + }); + Y_VERIFY_DEBUG(success); + return result; + } + case TType::EKind::Struct: { + auto structType = static_cast<const TStructType*>(itemType); + Y_VERIFY_DEBUG(array->type_id() == arrow::Type::STRUCT); + auto typedArray = static_pointer_cast<arrow::StructArray>(array); + Y_VERIFY_DEBUG(static_cast<ui32>(typedArray->num_fields()) == structType->GetMembersCount()); + NUdf::TUnboxedValue* itemsPtr = nullptr; + auto result = holderFactory.CreateDirectArrayHolder(structType->GetMembersCount(), itemsPtr); + for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { + auto memberType = structType->GetMemberType(index); + itemsPtr[index] = ExtractUnboxedValue(typedArray->field(index), row, memberType, holderFactory); + } + return result; + } + case TType::EKind::Tuple: { + auto tupleType = static_cast<const TTupleType*>(itemType); + Y_VERIFY_DEBUG(array->type_id() == arrow::Type::STRUCT); + auto typedArray = static_pointer_cast<arrow::StructArray>(array); + Y_VERIFY_DEBUG(static_cast<ui32>(typedArray->num_fields()) == tupleType->GetElementsCount()); + NUdf::TUnboxedValue* itemsPtr = nullptr; + auto result = holderFactory.CreateDirectArrayHolder(tupleType->GetElementsCount(), itemsPtr); + for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { + auto elementType = tupleType->GetElementType(index); + itemsPtr[index] = ExtractUnboxedValue(typedArray->field(index), row, elementType, holderFactory); + } + return result; + } + case TType::EKind::Optional: { + auto optionalType = static_cast<const TOptionalType*>(itemType); + auto innerOptionalType = optionalType->GetItemType(); + if (innerOptionalType->GetKind() == TType::EKind::Optional) { + Y_VERIFY_DEBUG(array->type_id() == arrow::Type::STRUCT); + auto structArray = static_pointer_cast<arrow::StructArray>(array); + Y_VERIFY_DEBUG(structArray->num_fields() == 2); + Y_VERIFY_DEBUG(structArray->field(0)->type_id() == arrow::Type::UINT64); + auto depthArray = static_pointer_cast<arrow::UInt64Array>(structArray->field(0)); + auto valuesArray = structArray->field(1); + auto depth = depthArray->Value(row); + NUdf::TUnboxedValue value; + if (valuesArray->IsNull(row)) { + value = NUdf::TUnboxedValuePod(); + } else { + while (innerOptionalType->GetKind() == TType::EKind::Optional) { + innerOptionalType = static_cast<const TOptionalType*>(innerOptionalType)->GetItemType(); + } + value = ExtractUnboxedValue(valuesArray, row, innerOptionalType, holderFactory); + } + for (ui64 i = 0; i < depth; ++i) { + value = value.MakeOptional(); + } + return value; + } else { + return ExtractUnboxedValue(array, row, innerOptionalType, holderFactory).Release().MakeOptional(); + } + } + case TType::EKind::List: { + auto listType = static_cast<const TListType*>(itemType); + Y_VERIFY_DEBUG(array->type_id() == arrow::Type::LIST); + auto typedArray = static_pointer_cast<arrow::ListArray>(array); + auto arraySlice = typedArray->value_slice(row); + auto itemType = listType->GetItemType(); + const auto len = arraySlice->length(); + NUdf::TUnboxedValue *items = nullptr; + auto list = holderFactory.CreateDirectArrayHolder(len, items); + for (ui64 i = 0; i < static_cast<ui64>(len); ++i) { + *items++ = ExtractUnboxedValue(arraySlice, i, itemType, holderFactory); + } + return list; + } + case TType::EKind::Dict: { + auto dictType = static_cast<const TDictType*>(itemType); + auto keyType = dictType->GetKeyType(); + auto payloadType = dictType->GetPayloadType(); + auto dictBuilder = holderFactory.NewDict(dictType, NUdf::TDictFlags::EDictKind::Hashed); + + std::shared_ptr<arrow::Array> keyArray = nullptr; + std::shared_ptr<arrow::Array> payloadArray = nullptr; + ui64 dictLength = 0; + ui64 offset = 0; + if (keyType->GetKind() == TType::EKind::Optional) { + Y_VERIFY_DEBUG(array->type_id() == arrow::Type::LIST); + auto listArray = static_pointer_cast<arrow::ListArray>(array); + auto arraySlice = listArray->value_slice(row); + Y_VERIFY_DEBUG(arraySlice->type_id() == arrow::Type::STRUCT); + auto structArray = static_pointer_cast<arrow::StructArray>(arraySlice); + Y_VERIFY_DEBUG(structArray->num_fields() == 2); + dictLength = arraySlice->length(); + keyArray = structArray->field(0); + payloadArray = structArray->field(1); + } else { + Y_VERIFY_DEBUG(array->type_id() == arrow::Type::MAP); + auto mapArray = static_pointer_cast<arrow::MapArray>(array); + dictLength = mapArray->value_length(row); + offset = mapArray->value_offset(row); + keyArray = mapArray->keys(); + payloadArray = mapArray->items(); + } + for (ui64 i = offset; i < offset + static_cast<ui64>(dictLength); ++i) { + auto key = ExtractUnboxedValue(keyArray, i, keyType, holderFactory); + auto payload = ExtractUnboxedValue(payloadArray, i, payloadType, holderFactory); + dictBuilder->Add(std::move(key), std::move(payload)); + } + return dictBuilder->Build(); + } + case TType::EKind::Variant: { + // TODO Need to properly convert variants containing more than 127*127 types? + auto variantType = static_cast<const TVariantType*>(itemType); + Y_VERIFY_DEBUG(array->type_id() == arrow::Type::DENSE_UNION); + auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); + auto variantIndex = unionArray->child_id(row); + auto rowInChild = unionArray->value_offset(row); + std::shared_ptr<arrow::Array> valuesArray = unionArray->field(variantIndex); + if (variantType->GetAlternativesCount() > arrow::UnionType::kMaxTypeCode) { + // Go one step deeper + Y_VERIFY_DEBUG(valuesArray->type_id() == arrow::Type::DENSE_UNION); + auto innerUnionArray = static_pointer_cast<arrow::DenseUnionArray>(valuesArray); + auto innerVariantIndex = innerUnionArray->child_id(rowInChild); + rowInChild = innerUnionArray->value_offset(rowInChild); + valuesArray = innerUnionArray->field(innerVariantIndex); + variantIndex = variantIndex * arrow::UnionType::kMaxTypeCode + innerVariantIndex; + } + TType* innerType = variantType->GetUnderlyingType(); + if (innerType->IsStruct()) { + innerType = static_cast<TStructType*>(innerType)->GetMemberType(variantIndex); + } else { + Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + innerType = static_cast<TTupleType*>(innerType)->GetElementType(variantIndex); + } + NUdf::TUnboxedValue value = ExtractUnboxedValue(valuesArray, rowInChild, innerType, holderFactory); + return holderFactory.CreateVariantHolder(value.Release(), variantIndex); + } + default: + THROW yexception() << "Unsupported type: " << itemType->GetKindAsStr(); + } + return NUdf::TUnboxedValuePod(); +} + +NMiniKQL::TUnboxedValueVector ExtractUnboxedValues(const std::shared_ptr<arrow::Array>& array, const TType* itemType, const NMiniKQL::THolderFactory& holderFactory) { + NMiniKQL::TUnboxedValueVector values; + values.reserve(array->length()); + for (auto i = 0 ; i < array->length(); ++i) { + values.push_back(ExtractUnboxedValue(array, i, itemType, holderFactory)); + } + return values; +} + +std::string SerializeArray(const std::shared_ptr<arrow::Array>& array) { + auto schema = std::make_shared<arrow::Schema>(std::vector<std::shared_ptr<arrow::Field>>{arrow::field("", array->type())}); + auto batch = arrow::RecordBatch::Make(schema, array->length(), {array}); + auto writeOptions = arrow::ipc::IpcWriteOptions::Defaults(); // no compression set + writeOptions.use_threads = false; + // TODO Decide which compression level will be default. Will it depend on the length of array? + auto codecResult = arrow::util::Codec::Create(arrow::Compression::LZ4_FRAME); + Y_VERIFY(codecResult.ok()); + writeOptions.codec = std::move(codecResult.ValueOrDie()); + int64_t size; + auto status = GetRecordBatchSize(*batch, writeOptions, &size); + Y_VERIFY(status.ok()); + + std::string str; + str.resize(size); + + auto writer = arrow::Buffer::GetWriter(arrow::MutableBuffer::Wrap(&str[0], size)); + Y_VERIFY(writer.status().ok()); + + status = SerializeRecordBatch(*batch, writeOptions, (*writer).get()); + Y_VERIFY(status.ok()); + return str; +} + +std::shared_ptr<arrow::Array> DeserializeArray(const std::string& blob, std::shared_ptr<arrow::DataType> type) { + arrow::ipc::DictionaryMemo dictMemo; + auto options = arrow::ipc::IpcReadOptions::Defaults(); + options.use_threads = false; + + std::shared_ptr<arrow::Buffer> buffer = arrow::Buffer::FromString(blob); + arrow::io::BufferReader reader(buffer); + auto schema = std::make_shared<arrow::Schema>(std::vector<std::shared_ptr<arrow::Field>>{arrow::field("", type)}); + auto batch = ReadRecordBatch(schema, &dictMemo, options, &reader); + Y_VERIFY_DEBUG(batch.ok() && (*batch)->ValidateFull().ok(), "Failed to deserialize batch"); + return (*batch)->column(0); +} + +} // namespace NArrow } // namespace NYql diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers.h b/ydb/library/yql/dq/runtime/dq_arrow_helpers.h index 3937a889f7b..31c5248afcb 100644 --- a/ydb/library/yql/dq/runtime/dq_arrow_helpers.h +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers.h @@ -1,79 +1,79 @@ #pragma once - + #include <ydb/library/yql/minikql/mkql_node.h> #include <ydb/library/yql/minikql/mkql_string_util.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_base.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_base.h> - -namespace NYql { -namespace NArrow { - -/** - * @brief Convert TType to the arrow::DataType object - * - * The logic of this conversion is the following: - * - * Struct, tuple => StructArray - * Names of fields constructed from tuple are just empty strings. - * - * List => ListArray - * - * Variant => DenseUnionArray - * If variant contains more than 127 items then we map - * Variant => DenseUnionArray<DenseUnionArray> - * TODO Implement convertion of data to DenseUnionArray<DenseUnionArray> and back - * - * Optional(Optional ..(type)..) => StructArray<ui64, type> - * Here the integer value equals the number of calls of method GetOptionalValue(). - * If value is null at some depth, then the value in second field of Array is Null - * (and the integer equals this depth). If value is present, then it is contained in the - * second field (and the integer equals the number of Optional(...) levels). - * This information is sufficient to restore an UnboxedValue knowing its type. - * - * Dict<KeyType, ValueType> => MapArray<KeyArray, ValueArray> - * We do not use arrow::DictArray because it must be used for encoding not for mapping keys to values. - * (https://arrow.apache.org/docs/cpp/api/array.html#classarrow_1_1_dictionary_array) - * If the type of dict key is optional then we map - * Dict<Optional(KeyType), ValueType> => ListArray<StructArray<KeyArray, ValueArray>> - * because keys of MapArray can not be nullable - * - * @param type Yql type to parse - * @return std::shared_ptr<arrow::DataType> arrow type of the same structure as type - */ -std::shared_ptr<arrow::DataType> GetArrowType(const NKikimr::NMiniKQL::TType* type); - -/** - * @brief Check if type can be converted to arrow format using only native arrow classes. - * - * @param type Type of UnboxedValue to check. - * @return true if type does not contain neither nested Optional, nor Dicts with Optional keys, nor Variants - * between more than 255 types. - * @return false otherwise - */ -bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type); - -std::unique_ptr<arrow::ArrayBuilder> MakeArrowBuilder(const NKikimr::NMiniKQL::TType* type); - -/** - * @brief Convert UnboxedValue-s to arrow Array - * - * @param values elements of future array - * @param itemType type of each element to parse it and to construct corresponding arrow type - * @return std::shared_ptr<arrow::Array> data in arrow format - */ +#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_base.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_base.h> + +namespace NYql { +namespace NArrow { + +/** + * @brief Convert TType to the arrow::DataType object + * + * The logic of this conversion is the following: + * + * Struct, tuple => StructArray + * Names of fields constructed from tuple are just empty strings. + * + * List => ListArray + * + * Variant => DenseUnionArray + * If variant contains more than 127 items then we map + * Variant => DenseUnionArray<DenseUnionArray> + * TODO Implement convertion of data to DenseUnionArray<DenseUnionArray> and back + * + * Optional(Optional ..(type)..) => StructArray<ui64, type> + * Here the integer value equals the number of calls of method GetOptionalValue(). + * If value is null at some depth, then the value in second field of Array is Null + * (and the integer equals this depth). If value is present, then it is contained in the + * second field (and the integer equals the number of Optional(...) levels). + * This information is sufficient to restore an UnboxedValue knowing its type. + * + * Dict<KeyType, ValueType> => MapArray<KeyArray, ValueArray> + * We do not use arrow::DictArray because it must be used for encoding not for mapping keys to values. + * (https://arrow.apache.org/docs/cpp/api/array.html#classarrow_1_1_dictionary_array) + * If the type of dict key is optional then we map + * Dict<Optional(KeyType), ValueType> => ListArray<StructArray<KeyArray, ValueArray>> + * because keys of MapArray can not be nullable + * + * @param type Yql type to parse + * @return std::shared_ptr<arrow::DataType> arrow type of the same structure as type + */ +std::shared_ptr<arrow::DataType> GetArrowType(const NKikimr::NMiniKQL::TType* type); + +/** + * @brief Check if type can be converted to arrow format using only native arrow classes. + * + * @param type Type of UnboxedValue to check. + * @return true if type does not contain neither nested Optional, nor Dicts with Optional keys, nor Variants + * between more than 255 types. + * @return false otherwise + */ +bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type); + +std::unique_ptr<arrow::ArrayBuilder> MakeArrowBuilder(const NKikimr::NMiniKQL::TType* type); + +/** + * @brief Convert UnboxedValue-s to arrow Array + * + * @param values elements of future array + * @param itemType type of each element to parse it and to construct corresponding arrow type + * @return std::shared_ptr<arrow::Array> data in arrow format + */ std::shared_ptr<arrow::Array> MakeArray(NKikimr::NMiniKQL::TUnboxedValueVector& values, const NKikimr::NMiniKQL::TType* itemType); - + NKikimr::NMiniKQL::TUnboxedValueVector ExtractUnboxedValues(const std::shared_ptr<arrow::Array>& array, const NKikimr::NMiniKQL::TType* itemType, const NKikimr::NMiniKQL::THolderFactory& holderFactory); - -std::string SerializeArray(const std::shared_ptr<arrow::Array>& array); - -std::shared_ptr<arrow::Array> DeserializeArray(const std::string& blob, std::shared_ptr<arrow::DataType> type); - + +std::string SerializeArray(const std::shared_ptr<arrow::Array>& array); + +std::shared_ptr<arrow::Array> DeserializeArray(const std::string& blob, std::shared_ptr<arrow::DataType> type); + /** * @brief Append UnboxedValue to arrow Array via arrow Builder * @@ -83,7 +83,7 @@ std::shared_ptr<arrow::Array> DeserializeArray(const std::string& blob, std::sha * @return std::shared_ptr<arrow::Array> data in arrow format */ void AppendElement(NYql::NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, const NKikimr::NMiniKQL::TType* type); - -} // NArrow + +} // NArrow } // NYql diff --git a/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp b/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp index 821e157c781..b6762ee51fa 100644 --- a/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp +++ b/ydb/library/yql/dq/runtime/dq_arrow_helpers_ut.cpp @@ -1,8 +1,8 @@ -#include "dq_arrow_helpers.h" - +#include "dq_arrow_helpers.h" + #include <ydb/core/util/yverify_stream.h> - -#include <memory> + +#include <memory> #include <ydb/library/yql/public/udf/udf_data_type.h> #include <ydb/library/yql/public/udf/udf_string_ref.h> #include <ydb/library/yql/public/udf/udf_type_ops.h> @@ -11,959 +11,959 @@ #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> #include <ydb/library/yql/minikql/computation/mkql_value_builder.h> #include <ydb/library/yql/minikql/mkql_string_util.h> - -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_binary.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_nested.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/type.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h> - -#include <util/string/builder.h> -#include <util/system/yassert.h> - -#include <library/cpp/testing/unittest/registar.h> - -using namespace NKikimr; -using namespace NKikimr::NMiniKQL; -using namespace NYql; - -namespace { -NUdf::TUnboxedValue GetValueOfBasicType(TType* type, ui64 value) { - Y_VERIFY(type->GetKind() == TType::EKind::Data); - auto dataType = static_cast<const TDataType*>(type); - auto slot = *dataType->GetDataSlot().Get(); - switch(slot) { - case NUdf::EDataSlot::Bool: - return NUdf::TUnboxedValuePod(static_cast<bool>(value % 2 == 0)); - case NUdf::EDataSlot::Int8: - return NUdf::TUnboxedValuePod(static_cast<i8>(-(value % 126))); - case NUdf::EDataSlot::Uint8: - return NUdf::TUnboxedValuePod(static_cast<ui8>(value % 255)); - case NUdf::EDataSlot::Int16: - return NUdf::TUnboxedValuePod(static_cast<i16>(-(value % ((1 << 15) - 1)))); - case NUdf::EDataSlot::Uint16: - return NUdf::TUnboxedValuePod(static_cast<ui16>(value % (1 << 16))); - case NUdf::EDataSlot::Int32: - return NUdf::TUnboxedValuePod(static_cast<i32>(-(value % ((1 << 31) - 1)))); - case NUdf::EDataSlot::Uint32: - return NUdf::TUnboxedValuePod(static_cast<ui32>(value % (1 << 31))); - case NUdf::EDataSlot::Int64: - return NUdf::TUnboxedValuePod(static_cast<i64>(- (value / 2))); - case NUdf::EDataSlot::Uint64: - return NUdf::TUnboxedValuePod(static_cast<ui64>(value)); - case NUdf::EDataSlot::Float: - return NUdf::TUnboxedValuePod(static_cast<float>(value) / 1234); - case NUdf::EDataSlot::Double: - return NUdf::TUnboxedValuePod(static_cast<double>(value) / 12345); - default: - Y_FAIL("Not implemented creation value for such type"); - } -} - -struct TTestContext { - TScopedAlloc Alloc; - TTypeEnvironment TypeEnv; - TMemoryUsageInfo MemInfo; - THolderFactory HolderFactory; - TDefaultValueBuilder Vb; - ui16 VariantSize = 0; - - // Used to create LargeVariantType - TVector<TType*> BasicTypes = { - TDataType::Create(NUdf::TDataType<bool>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<i8>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<ui8>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<i16>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<ui16>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<ui32>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<i64>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<ui64>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<float>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<double>::Id, TypeEnv) - }; - - TTestContext() - : Alloc() - , TypeEnv(Alloc) - , MemInfo("TestMem") - , HolderFactory(Alloc.Ref(), MemInfo) - , Vb(HolderFactory) - { - } - - TType* GetStructType() { - TStructMember members[3] = { - {"s", TDataType::Create(NUdf::TDataType<char*>::Id, TypeEnv)}, - {"x", TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv)}, - {"y", TDataType::Create(NUdf::TDataType<ui64>::Id, TypeEnv)} - }; - return TStructType::Create(3, members, TypeEnv); - } - - TUnboxedValueVector CreateStructs(ui32 quantity) { - TUnboxedValueVector values; - for (ui32 value = 0; value < quantity; ++value) { - NUdf::TUnboxedValue* items; - auto structValue = Vb.NewArray(3, items); - std::string string = TStringBuilder() << value; - items[0] = MakeString(NUdf::TStringRef(string.data(), string.size())); - items[1] = NUdf::TUnboxedValuePod(static_cast<i32>(-value)); - items[2] = NUdf::TUnboxedValuePod((ui64) (value * value)); - values.emplace_back(std::move(structValue)); - } - return values; - } - - TType* GetTupleType() { - TType* members[3] = { - TDataType::Create(NUdf::TDataType<bool>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<i8>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<ui8>::Id, TypeEnv) - }; - return TTupleType::Create(3, members, TypeEnv); - } - - TUnboxedValueVector CreateTuples(ui32 quantity) { - NKikimr::NMiniKQL::TUnboxedValueVector values; - for (ui32 value = 0; value < quantity; ++value) { - auto array = NArrow::MakeArray(values, GetTupleType()); - auto str = NArrow::SerializeArray(array); - NUdf::TUnboxedValue* items; - auto tupleValue = Vb.NewArray(3, items); - items[0] = NUdf::TUnboxedValuePod(value % 3 == 0); - items[1] = NUdf::TUnboxedValuePod(static_cast<i8>(-value)); - items[2] = NUdf::TUnboxedValuePod(static_cast<ui8>(value)); - values.push_back(std::move(tupleValue)); - } - return values; - } - - TType* GetDictUtf8ToIntervalType() { - TType* keyType = TDataType::Create(NUdf::TDataType<NUdf::TUtf8>::Id, TypeEnv); - TType* payloadType = TDataType::Create(NUdf::TDataType<NUdf::TInterval>::Id, TypeEnv); - return TDictType::Create(keyType, payloadType, TypeEnv); - } - - TUnboxedValueVector CreateDictUtf8ToInterval(ui32 quantity) { - NKikimr::NMiniKQL::TUnboxedValueVector values; - auto dictType = GetDictUtf8ToIntervalType(); - for (ui32 value = 0; value < quantity; ++value) { - auto dictBuilder = Vb.NewDict(dictType, 0); - for (ui32 i = 0; i < value * value; ++i) { - std::string string = TStringBuilder() << "This is a long string #" << i; - NUdf::TUnboxedValue key = MakeString(NUdf::TStringRef(string.data(), string.size())); - NUdf::TUnboxedValue payload = NUdf::TUnboxedValuePod(static_cast<i64>(value * i)); - dictBuilder->Add(std::move(key), std::move(payload)); - } - auto dictValue = dictBuilder->Build(); - values.emplace_back(std::move(dictValue)); - } - return values; - } - - TType* GetListOfJsonsType() { - TType* itemType = TDataType::Create(NUdf::TDataType<NUdf::TJson>::Id, TypeEnv); - return TListType::Create(itemType, TypeEnv); - } - - TUnboxedValueVector CreateListOfJsons(ui32 quantity) { - TUnboxedValueVector values; - for (ui64 value = 0; value < quantity; ++value) { - TUnboxedValueVector items; - items.reserve(value); - for (ui64 i = 0; i < value; ++i) { - std::string json = TStringBuilder() << "{'item':" << i << "}"; - items.push_back(MakeString(NUdf::TStringRef(json.data(), json.size()))); - } - auto listValue = Vb.NewList(items.data(), value); - values.emplace_back(std::move(listValue)); - } - return values; - } - - TType* GetVariantOverStructType() { - TStructMember members[4] = { - {"0_yson", TDataType::Create(NUdf::TDataType<NUdf::TYson>::Id, TypeEnv)}, - {"1_json-document", TDataType::Create(NUdf::TDataType<NUdf::TJsonDocument>::Id, TypeEnv)}, - {"2_uuid", TDataType::Create(NUdf::TDataType<NUdf::TUuid>::Id, TypeEnv)}, - {"3_float", TDataType::Create(NUdf::TDataType<float>::Id, TypeEnv)} - }; - auto structType = TStructType::Create(4, members, TypeEnv); - return TVariantType::Create(structType, TypeEnv); - } - - TUnboxedValueVector CreateVariantOverStruct(ui32 quantity) { - TUnboxedValueVector values; - for (ui64 value = 0; value < quantity; ++value) { - auto typeIndex = value % 4; - NUdf::TUnboxedValue item; - if (typeIndex == 0) { - std::string data = TStringBuilder() << "{value=" << value << "}"; - item = MakeString(NUdf::TStringRef(data.data(), data.size())); - } else if (typeIndex == 1) { - std::string data = TStringBuilder() << "{value:" << value << "}"; - item = MakeString(NUdf::TStringRef(data.data(), data.size())); - } else if (typeIndex == 2) { - std::string data = TStringBuilder() << "id-QwErY-" << value; - item = MakeString(NUdf::TStringRef(data.data(), data.size())); - } else if (typeIndex == 3) { - item = NUdf::TUnboxedValuePod(static_cast<float>(value) / 4); - } - auto wrapped = Vb.NewVariant(typeIndex, std::move(item)); - values.push_back(std::move(wrapped)); - } - return values; - } - - TType* GetVariantOverTupleWithOptionalsType() { - TType* members[5] = { - TDataType::Create(NUdf::TDataType<bool>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<i16>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<ui16>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), - TOptionalType::Create(TDataType::Create(NUdf::TDataType<ui32>::Id, TypeEnv), TypeEnv) - }; - auto tupleType = TTupleType::Create(5, members, TypeEnv); - return TVariantType::Create(tupleType, TypeEnv); - } - - TUnboxedValueVector CreateVariantOverTupleWithOptionals(ui32 quantity) { - NKikimr::NMiniKQL::TUnboxedValueVector values; - for (ui64 value = 0; value < quantity; ++value) { - auto typeIndex = value % 5; - NUdf::TUnboxedValue item; - if (typeIndex == 0) { - item = NUdf::TUnboxedValuePod(value % 3 == 0); - } else if (typeIndex == 1) { - item = NUdf::TUnboxedValuePod(static_cast<i16>(-value)); - } else if (typeIndex == 2) { - item = NUdf::TUnboxedValuePod(static_cast<ui16>(value)); - } else if (typeIndex == 3) { - item = NUdf::TUnboxedValuePod(static_cast<i32>(-value)); - } else if (typeIndex == 4) { - NUdf::TUnboxedValue innerItem; - innerItem = value % 2 == 0 - ? NUdf::TUnboxedValuePod(static_cast<i32>(value)) - : NUdf::TUnboxedValuePod(); - item = innerItem.MakeOptional(); - } - auto wrapped = Vb.NewVariant(typeIndex, std::move(item)); - values.emplace_back(std::move(wrapped)); - } - return values; - } - - TType* GetDictOptionalToTupleType() { - TType* keyType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<double>::Id, TypeEnv), TypeEnv); - TType* members[2] = { - TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), - TDataType::Create(NUdf::TDataType<ui32>::Id, TypeEnv), - }; - TType* payloadType = TTupleType::Create(2, members, TypeEnv); - return TDictType::Create(keyType, payloadType, TypeEnv); - } - - TUnboxedValueVector CreateDictOptionalToTuple(ui32 quantity) { - NKikimr::NMiniKQL::TUnboxedValueVector values; - for (ui64 value = 0; value < quantity; ++value) { - auto dictBuilder = Vb.NewDict(GetDictOptionalToTupleType(), 0); - for (ui64 i = 0; i < value * value; ++i) { - NUdf::TUnboxedValue key; - if (i == 0) { - key = NUdf::TUnboxedValuePod(); - } else { - key = NUdf::TUnboxedValuePod(value / 4).MakeOptional(); - } - NUdf::TUnboxedValue* items; - auto payload = Vb.NewArray(2, items); - items[0] = NUdf::TUnboxedValuePod(static_cast<i32>(-value)); - items[1] = NUdf::TUnboxedValuePod(static_cast<ui32>(value)); - dictBuilder->Add(std::move(key), std::move(payload)); - } - auto dictValue = dictBuilder->Build(); - values.emplace_back(std::move(dictValue)); - } - return values; - } - - TType* GetOptionalOfOptionalType() { - return TOptionalType::Create( - TOptionalType::Create( - TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), - TypeEnv), - TypeEnv); - } - - TUnboxedValueVector CreateOptionalOfOptional(ui32 quantity) { - TUnboxedValueVector values; - for (ui64 value = 0; value < quantity; ++value) { - NUdf::TUnboxedValue element = value % 3 == 0 - ? NUdf::TUnboxedValuePod(value).MakeOptional() - : NUdf::TUnboxedValuePod(); - if (value % 3 != 2) { - element = element.MakeOptional(); - } - values.emplace_back(std::move(element)); - } - return values; - } - - TType* GetLargeVariantType(const ui16 variantSize) { - VariantSize = variantSize; - TVector<TType*> tupleTypes; - tupleTypes.reserve(variantSize); - for (ui64 index = 0; index < variantSize; ++index) { - TVector<TType*> selectedTypes; - for (ui32 i = 0; i < BasicTypes.size(); ++i) { - if ((index >> i) % 2 == 1) { - selectedTypes.push_back(BasicTypes[i]); - } - } - tupleTypes.push_back(TTupleType::Create(selectedTypes.size(), selectedTypes.data(), TypeEnv)); - } - auto tupleOfTuplesType = TTupleType::Create(variantSize, tupleTypes.data(), TypeEnv); - return TVariantType::Create(tupleOfTuplesType, TypeEnv); - } - - TUnboxedValueVector CreateLargeVariant(ui32 quantity) { - TUnboxedValueVector values; - for (ui64 index = 0; index < quantity; ++index) { - NUdf::TUnboxedValue item; - auto typeIndex = index % VariantSize; - TUnboxedValueVector tupleItems; - for (ui64 i = 0; i < BasicTypes.size(); ++i) { - if ((typeIndex >> i) % 2 == 1) { - tupleItems.push_back(GetValueOfBasicType(BasicTypes[i], i)); - } - } - auto wrapped = Vb.NewVariant(typeIndex, HolderFactory.VectorAsArray(tupleItems)); - values.emplace_back(std::move(wrapped)); - } - return values; - } -}; - -// Note this equality check is not fully valid. But it is sufficient for UnboxedValues used in tests. -void AssertUnboxedValuesAreEqual(NUdf::TUnboxedValue& left, NUdf::TUnboxedValue& right, TType* type) { - switch (type->GetKind()) { - case TType::EKind::Void: - case TType::EKind::Null: - case TType::EKind::EmptyList: - case TType::EKind::EmptyDict: { - UNIT_ASSERT(!left.HasValue()); - UNIT_ASSERT(!right.HasValue()); - break; - } - - case TType::EKind::Data: { - auto dataType = static_cast<const TDataType*>(type); - auto dataSlot = *dataType->GetDataSlot().Get(); - // Json-like type are not comparable so just skip them - if (dataSlot != NUdf::EDataSlot::Json && dataSlot != NUdf::EDataSlot::Yson && dataSlot != NUdf::EDataSlot::JsonDocument) { - UNIT_ASSERT(NUdf::EquateValues(dataSlot, left, right)); - } - break; - } - - case TType::EKind::Optional: { - UNIT_ASSERT_EQUAL(left.HasValue(), right.HasValue()); - if (left.HasValue()) { - auto innerType = static_cast<const TOptionalType*>(type)->GetItemType(); - NUdf::TUnboxedValue leftInner = left.GetOptionalValue(); - NUdf::TUnboxedValue rightInner = right.GetOptionalValue(); - AssertUnboxedValuesAreEqual(leftInner, rightInner, innerType); - } - break; - } - - case TType::EKind::List: { - auto listType = static_cast<const TListType*>(type); - auto itemType = listType->GetItemType(); - auto leftPtr = left.GetElements(); - auto rightPtr = right.GetElements(); - UNIT_ASSERT_EQUAL(leftPtr != nullptr, rightPtr != nullptr); - if (leftPtr != nullptr) { - auto leftLen = left.GetListLength(); - auto rightLen = right.GetListLength(); - UNIT_ASSERT_EQUAL(leftLen, rightLen); - while (leftLen > 0) { - NUdf::TUnboxedValue leftItem = *leftPtr++; - NUdf::TUnboxedValue rightItem = *rightPtr++; - AssertUnboxedValuesAreEqual(leftItem, rightItem, itemType); - --leftLen; - } - } else { - const auto leftIter = left.GetListIterator(); - const auto rightIter = right.GetListIterator(); - NUdf::TUnboxedValue leftItem; - NUdf::TUnboxedValue rightItem; - bool leftHasValue = leftIter.Next(leftItem); - bool rightHasValue = rightIter.Next(leftItem); - while (leftHasValue && rightHasValue) { - AssertUnboxedValuesAreEqual(leftItem, rightItem, itemType); - leftHasValue = leftIter.Next(leftItem); - rightHasValue = rightIter.Next(leftItem); - } - UNIT_ASSERT_EQUAL(leftHasValue, rightHasValue); - } - break; - } - - case TType::EKind::Struct: { - auto structType = static_cast<const TStructType*>(type); - UNIT_ASSERT_EQUAL(left.GetListLength(), structType->GetMembersCount()); - UNIT_ASSERT_EQUAL(right.GetListLength(), structType->GetMembersCount()); - for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { - auto memberType = structType->GetMemberType(index); - NUdf::TUnboxedValue leftMember = left.GetElement(index); - NUdf::TUnboxedValue rightMember = right.GetElement(index); - AssertUnboxedValuesAreEqual(leftMember, rightMember, memberType); - } - break; - } - - case TType::EKind::Tuple: { - auto tupleType = static_cast<const TTupleType*>(type); - UNIT_ASSERT_EQUAL(left.GetListLength(), tupleType->GetElementsCount()); - UNIT_ASSERT_EQUAL(right.GetListLength(), tupleType->GetElementsCount()); - for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { - auto elementType = tupleType->GetElementType(index); - NUdf::TUnboxedValue leftMember = left.GetElement(index); - NUdf::TUnboxedValue rightMember = right.GetElement(index); - AssertUnboxedValuesAreEqual(leftMember, rightMember, elementType); - } - break; - } - - case TType::EKind::Dict: { - auto dictType = static_cast<const TDictType*>(type); - auto payloadType = dictType->GetPayloadType(); - - UNIT_ASSERT_EQUAL(left.GetDictLength(), right.GetDictLength()); - const auto leftIter = left.GetDictIterator(); - for (NUdf::TUnboxedValue key, leftPayload; leftIter.NextPair(key, leftPayload);) { - UNIT_ASSERT(right.Contains(key)); - NUdf::TUnboxedValue rightPayload = right.Lookup(key); - AssertUnboxedValuesAreEqual(leftPayload, rightPayload, payloadType); - } - break; - } - - case TType::EKind::Variant: { - auto variantType = static_cast<const TVariantType*>(type); - UNIT_ASSERT_EQUAL(left.GetVariantIndex(), right.GetVariantIndex()); - ui32 variantIndex = left.GetVariantIndex(); - TType* innerType = variantType->GetUnderlyingType(); - if (innerType->IsStruct()) { - innerType = static_cast<TStructType*>(innerType)->GetMemberType(variantIndex); - } else { - Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); - innerType = static_cast<TTupleType*>(innerType)->GetElementType(variantIndex); - } - NUdf::TUnboxedValue leftValue = left.GetVariantItem(); - NUdf::TUnboxedValue rightValue = right.GetVariantItem(); - AssertUnboxedValuesAreEqual(leftValue, rightValue, innerType); - break; - } - - default: - THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); - } -} -} - - -Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { - Y_UNIT_TEST(Struct) { - TTestContext context; - - auto structType = context.GetStructType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(structType)); - - auto values = context.CreateStructs(100); - auto array = NArrow::MakeArray(values, structType); - - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT(array->length() == static_cast<i64>(values.size())); - UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); - auto structArray = static_pointer_cast<arrow::StructArray>(array); - UNIT_ASSERT(structArray->num_fields() == 3); - UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::BINARY); - UNIT_ASSERT(structArray->field(1)->type_id() == arrow::Type::INT32); - UNIT_ASSERT(structArray->field(2)->type_id() == arrow::Type::UINT64); - UNIT_ASSERT(static_cast<ui64>(structArray->field(0)->length()) == values.size()); - UNIT_ASSERT(static_cast<ui64>(structArray->field(1)->length()) == values.size()); - UNIT_ASSERT(static_cast<ui64>(structArray->field(2)->length()) == values.size()); - auto binaryArray = static_pointer_cast<arrow::BinaryArray>(structArray->field(0)); - auto int32Array = static_pointer_cast<arrow::Int32Array>(structArray->field(1)); - auto uint64Array = static_pointer_cast<arrow::UInt64Array>(structArray->field(2)); - auto index = 0; - for (const auto& value: values) { - auto stringValue = value.GetElement(0); - auto stringRef = stringValue.AsStringRef(); - auto stringView = binaryArray->GetView(index); - UNIT_ASSERT_EQUAL(std::string(stringRef.Data(), stringRef.Size()), std::string(stringView)); - - auto intValue = value.GetElement(1).Get<i32>(); - auto intArrow = int32Array->Value(index); - UNIT_ASSERT_EQUAL(intValue, intArrow); - - auto uIntValue = value.GetElement(2).Get<ui64>(); - auto uIntArrow = uint64Array->Value(index); - UNIT_ASSERT_EQUAL(uIntValue, uIntArrow); - ++index; - } - } - - Y_UNIT_TEST(Tuple) { - TTestContext context; - - auto tupleType = context.GetTupleType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(tupleType)); - - auto values = context.CreateTuples(100); - auto array = NArrow::MakeArray(values, tupleType); - UNIT_ASSERT(array->ValidateFull().ok()); - - UNIT_ASSERT(array->length() == static_cast<i64>(values.size())); - UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); - auto structArray = static_pointer_cast<arrow::StructArray>(array); - UNIT_ASSERT(structArray->num_fields() == 3); - UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::BOOL); - UNIT_ASSERT(structArray->field(1)->type_id() == arrow::Type::INT8); - UNIT_ASSERT(structArray->field(2)->type_id() == arrow::Type::UINT8); - UNIT_ASSERT(static_cast<ui64>(structArray->field(0)->length()) == values.size()); - UNIT_ASSERT(static_cast<ui64>(structArray->field(1)->length()) == values.size()); - UNIT_ASSERT(static_cast<ui64>(structArray->field(2)->length()) == values.size()); - auto boolArray = static_pointer_cast<arrow::BooleanArray>(structArray->field(0)); - auto int8Array = static_pointer_cast<arrow::Int8Array>(structArray->field(1)); - auto uint8Array = static_pointer_cast<arrow::UInt8Array>(structArray->field(2)); - auto index = 0; - for (const auto& value: values) { - auto boolValue = value.GetElement(0).Get<bool>(); - auto boolArrow = boolArray->Value(index); - UNIT_ASSERT(boolValue == boolArrow); - - auto intValue = value.GetElement(1).Get<i8>(); - auto intArrow = int8Array->Value(index); - UNIT_ASSERT(intValue == intArrow); - - auto uIntValue = value.GetElement(2).Get<ui8>(); - auto uIntArrow = uint8Array->Value(index); - UNIT_ASSERT(uIntValue == uIntArrow); - ++index; - } - } - - Y_UNIT_TEST(DictUtf8ToInterval) { - TTestContext context; - - auto dictType = context.GetDictUtf8ToIntervalType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(dictType)); - - auto values = context.CreateDictUtf8ToInterval(100); - auto array = NArrow::MakeArray(values, dictType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); - UNIT_ASSERT(array->type_id() == arrow::Type::MAP); - auto mapArray = static_pointer_cast<arrow::MapArray>(array); - - UNIT_ASSERT(mapArray->num_fields() == 1); - UNIT_ASSERT(mapArray->keys()->type_id() == arrow::Type::STRING); - UNIT_ASSERT(mapArray->items()->type_id() == arrow::Type::DURATION); - auto utf8Array = static_pointer_cast<arrow::StringArray>(mapArray->keys()); - auto intervalArray = static_pointer_cast<arrow::NumericArray<arrow::DurationType>>(mapArray->items()); - ui64 index = 0; - for (const auto& value: values) { - UNIT_ASSERT(value.GetDictLength() == static_cast<ui64>(mapArray->value_length(index))); - for (auto subindex = mapArray->value_offset(index); subindex < mapArray->value_offset(index + 1); ++subindex) { - auto keyArrow = utf8Array->GetView(subindex); - NUdf::TUnboxedValue key = MakeString(NUdf::TStringRef(keyArrow.data(), keyArrow.size())); - UNIT_ASSERT(value.Contains(key)); - NUdf::TUnboxedValue payloadValue = value.Lookup(key); - UNIT_ASSERT(intervalArray->Value(subindex) == payloadValue.Get<i64>()); - } - ++index; - } - } - - Y_UNIT_TEST(ListOfJsons) { - TTestContext context; - - auto listType = context.GetListOfJsonsType(); - Y_VERIFY(NArrow::IsArrowCompatible(listType)); - - auto values = context.CreateListOfJsons(100); - auto array = NArrow::MakeArray(values, listType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); - UNIT_ASSERT(array->type_id() == arrow::Type::LIST); - auto listArray = static_pointer_cast<arrow::ListArray>(array); - - UNIT_ASSERT(listArray->num_fields() == 1); - UNIT_ASSERT(listArray->value_type()->id() == arrow::Type::STRING); - auto jsonArray = static_pointer_cast<arrow::StringArray>(listArray->values()); - auto index = 0; - auto innerIndex = 0; - for (const auto& value: values) { - UNIT_ASSERT(value.GetListLength() == static_cast<ui64>(listArray->value_length(index))); - const auto iter = value.GetListIterator(); - for (NUdf::TUnboxedValue item; iter.Next(item);) { - auto view = jsonArray->GetView(innerIndex); - std::string itemArrow(view.data(), view.size()); - auto stringRef = item.AsStringRef(); - std::string itemList(stringRef.Data(), stringRef.Size()); - UNIT_ASSERT(itemList == itemArrow); - ++innerIndex; - } - ++index; - } - } - - Y_UNIT_TEST(VariantOverStruct) { - TTestContext context; - - auto variantType = context.GetVariantOverStructType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); - - auto values = context.CreateVariantOverStruct(100); - auto array = NArrow::MakeArray(values, variantType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); - UNIT_ASSERT(array->type_id() == arrow::Type::DENSE_UNION); - auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); - - UNIT_ASSERT(unionArray->num_fields() == 4); - UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::STRING); - UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::STRING); - UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::BINARY); - UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::FLOAT); - auto ysonArray = static_pointer_cast<arrow::StringArray>(unionArray->field(0)); - auto jsonDocArray = static_pointer_cast<arrow::StringArray>(unionArray->field(1)); - auto uuidArray = static_pointer_cast<arrow::BinaryArray>(unionArray->field(2)); - auto floatArray = static_pointer_cast<arrow::FloatArray>(unionArray->field(3)); - for (ui64 index = 0; index < values.size(); ++index) { - auto value = values[index]; - UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); - auto fieldIndex = unionArray->value_offset(index); - if (value.GetVariantIndex() == 3) { - auto valueArrow = floatArray->Value(fieldIndex); - auto valueInner = value.GetVariantItem().Get<float>(); - UNIT_ASSERT(valueArrow == valueInner); - } else { - arrow::util::string_view viewArrow; - if (value.GetVariantIndex() == 0) { - viewArrow = ysonArray->GetView(fieldIndex); - } else if (value.GetVariantIndex() == 1) { - viewArrow = jsonDocArray->GetView(fieldIndex); - } else if (value.GetVariantIndex() == 2) { - viewArrow = uuidArray->GetView(fieldIndex); - } - std::string valueArrow(viewArrow.data(), viewArrow.size()); - auto innerItem = value.GetVariantItem(); - auto refInner = innerItem.AsStringRef(); - std::string valueInner(refInner.Data(), refInner.Size()); - UNIT_ASSERT(valueArrow == valueInner); - } - } - } - - Y_UNIT_TEST(VariantOverTupleWithOptionals) { - TTestContext context; - - auto variantType = context.GetVariantOverTupleWithOptionalsType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); - - auto values = context.CreateVariantOverStruct(100); - auto array = NArrow::MakeArray(values, variantType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); - UNIT_ASSERT(array->type_id() == arrow::Type::DENSE_UNION); - auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); - - UNIT_ASSERT(unionArray->num_fields() == 5); - UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BOOL); - UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::INT16); - UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::UINT16); - UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::INT32); - UNIT_ASSERT(unionArray->field(4)->type_id() == arrow::Type::UINT32); - auto boolArray = static_pointer_cast<arrow::BooleanArray>(unionArray->field(0)); - auto i16Array = static_pointer_cast<arrow::Int16Array>(unionArray->field(1)); - auto ui16Array = static_pointer_cast<arrow::UInt16Array>(unionArray->field(2)); - auto i32Array = static_pointer_cast<arrow::Int32Array>(unionArray->field(3)); - auto ui32Array = static_pointer_cast<arrow::UInt32Array>(unionArray->field(4)); - for (ui64 index = 0; index < values.size(); ++index) { - auto value = values[index]; - UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); - auto fieldIndex = unionArray->value_offset(index); - if (value.GetVariantIndex() == 0) { - bool valueArrow = boolArray->Value(fieldIndex); - auto valueInner = value.GetVariantItem().Get<bool>(); - UNIT_ASSERT(valueArrow == valueInner); - } else if (value.GetVariantIndex() == 1) { - auto valueArrow = i16Array->Value(fieldIndex); - auto valueInner = value.GetVariantItem().Get<i16>(); - UNIT_ASSERT(valueArrow == valueInner); - } else if (value.GetVariantIndex() == 2) { - auto valueArrow = ui16Array->Value(fieldIndex); - auto valueInner = value.GetVariantItem().Get<ui16>(); - UNIT_ASSERT(valueArrow == valueInner); - } else if (value.GetVariantIndex() == 3) { - auto valueArrow = i32Array->Value(fieldIndex); - auto valueInner = value.GetVariantItem().Get<i32>(); - UNIT_ASSERT(valueArrow == valueInner); - } else if (value.GetVariantIndex() == 4) { - if (!value.GetVariantItem().HasValue()) { - UNIT_ASSERT(ui32Array->IsNull(fieldIndex)); - } else { - auto valueArrow = ui32Array->Value(fieldIndex); - auto valueInner = value.GetVariantItem().Get<ui32>(); - UNIT_ASSERT(valueArrow == valueInner); - } - } - } - } -} - -Y_UNIT_TEST_SUITE(DqUnboxedValueDoNotFitToArrow) { - Y_UNIT_TEST(DictOptionalToTuple) { - TTestContext context; - - auto dictType = context.GetDictOptionalToTupleType(); - UNIT_ASSERT(!NArrow::IsArrowCompatible(dictType)); - - auto values = context.CreateDictOptionalToTuple(100); - auto array = NArrow::MakeArray(values, dictType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT_EQUAL(static_cast<ui64>(array->length()), values.size()); - UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::LIST); - auto listArray = static_pointer_cast<arrow::ListArray>(array); - UNIT_ASSERT_EQUAL(listArray->value_type()->id(), arrow::Type::STRUCT); - auto structArray = static_pointer_cast<arrow::StructArray>(listArray->values()); - - UNIT_ASSERT_EQUAL(listArray->num_fields(), 1); - UNIT_ASSERT_EQUAL(structArray->num_fields(), 2); - UNIT_ASSERT_EQUAL(structArray->field(0)->type_id(), arrow::Type::DOUBLE); - UNIT_ASSERT_EQUAL(structArray->field(1)->type_id(), arrow::Type::STRUCT); - auto keysArray = static_pointer_cast<arrow::DoubleArray>(structArray->field(0)); - auto itemsArray = static_pointer_cast<arrow::StructArray>(structArray->field(1)); - UNIT_ASSERT_EQUAL(itemsArray->num_fields(), 2); - UNIT_ASSERT_EQUAL(itemsArray->field(0)->type_id(), arrow::Type::INT32); - UNIT_ASSERT_EQUAL(itemsArray->field(1)->type_id(), arrow::Type::UINT32); - auto i32Array = static_pointer_cast<arrow::Int32Array>(itemsArray->field(0)); - auto ui32Array = static_pointer_cast<arrow::UInt32Array>(itemsArray->field(1)); - - ui64 index = 0; - for (const auto& value: values) { - UNIT_ASSERT(value.GetDictLength() == static_cast<ui64>(listArray->value_length(index))); - for (auto subindex = listArray->value_offset(index); subindex < listArray->value_offset(index + 1); ++subindex) { - NUdf::TUnboxedValue key = keysArray->IsNull(subindex) - ? NUdf::TUnboxedValuePod() - : NUdf::TUnboxedValuePod(keysArray->Value(subindex)); - UNIT_ASSERT(value.Contains(key)); - NUdf::TUnboxedValue payloadValue = value.Lookup(key); - UNIT_ASSERT_EQUAL(payloadValue.GetElement(0).Get<i32>(), i32Array->Value(subindex)); - UNIT_ASSERT_EQUAL(payloadValue.GetElement(1).Get<ui32>(), ui32Array->Value(subindex)); - } - ++index; - } - } - - Y_UNIT_TEST(OptionalOfOptional) { - TTestContext context; - - auto doubleOptionalType = context.GetOptionalOfOptionalType(); - UNIT_ASSERT(!NArrow::IsArrowCompatible(doubleOptionalType)); - - auto values = context.CreateOptionalOfOptional(100); - auto array = NArrow::MakeArray(values, doubleOptionalType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT_EQUAL(static_cast<ui64>(array->length()), values.size()); - UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::STRUCT); - auto structArray = static_pointer_cast<arrow::StructArray>(array); - UNIT_ASSERT_EQUAL(structArray->num_fields(), 2); - UNIT_ASSERT_EQUAL(structArray->field(0)->type_id(), arrow::Type::UINT64); - UNIT_ASSERT_EQUAL(structArray->field(1)->type_id(), arrow::Type::INT32); - auto depthArray = static_pointer_cast<arrow::UInt64Array>(structArray->field(0)); - auto i32Array = static_pointer_cast<arrow::Int32Array>(structArray->field(1)); - - auto index = 0; - for (auto value: values) { - auto depth = depthArray->Value(index); - while (depth > 0) { - UNIT_ASSERT(value.HasValue()); - value = value.GetOptionalValue(); - --depth; - } - if (value.HasValue()) { - UNIT_ASSERT_EQUAL(value.Get<i32>(), i32Array->Value(index)); - } else { - UNIT_ASSERT(i32Array->IsNull(index)); - } - ++index; - } - } - - Y_UNIT_TEST(LargeVariant) { - TTestContext context; - - ui32 numberOfTypes = 500; - auto variantType = context.GetLargeVariantType(numberOfTypes); - bool isCompatible = NArrow::IsArrowCompatible(variantType); - UNIT_ASSERT(!isCompatible); - - auto values = context.CreateLargeVariant(1000); - auto array = NArrow::MakeArray(values, variantType); - UNIT_ASSERT(array->ValidateFull().ok()); - UNIT_ASSERT_EQUAL(static_cast<ui64>(array->length()), values.size()); - UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::DENSE_UNION); - auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); - ui32 numberOfGroups = (numberOfTypes - 1) / arrow::UnionType::kMaxTypeCode + 1; - UNIT_ASSERT_EQUAL(numberOfGroups, static_cast<ui32>(unionArray->num_fields())); - ui32 typesInArrow = 0; - for (auto i = 0 ; i < unionArray->num_fields(); ++i) { - UNIT_ASSERT_EQUAL(unionArray->field(i)->type_id(), arrow::Type::DENSE_UNION); - typesInArrow += unionArray->field(i)->num_fields(); - } - UNIT_ASSERT_EQUAL(numberOfTypes, typesInArrow); - // TODO Check array content. - } -} - -Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ - Y_UNIT_TEST(Struct) { - TTestContext context; - - auto structType = context.GetStructType(); - auto values = context.CreateStructs(100); - auto array = NArrow::MakeArray(values, structType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, structType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], structType); - } - } - - Y_UNIT_TEST(Tuple) { - TTestContext context; - - auto tupleType = context.GetTupleType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(tupleType)); - - auto values = context.CreateTuples(100); - auto array = NArrow::MakeArray(values, tupleType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, tupleType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], tupleType); - } - } - - Y_UNIT_TEST(DictUtf8ToInterval) { - TTestContext context; - - auto dictType = context.GetDictUtf8ToIntervalType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(dictType)); - - auto values = context.CreateDictUtf8ToInterval(100); - auto array = NArrow::MakeArray(values, dictType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, dictType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], dictType); - } - } - - Y_UNIT_TEST(ListOfJsons) { - TTestContext context; - - auto listType = context.GetListOfJsonsType(); - Y_VERIFY(NArrow::IsArrowCompatible(listType)); - - auto values = context.CreateListOfJsons(100); - auto array = NArrow::MakeArray(values, listType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, listType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], listType); - } - } - - Y_UNIT_TEST(VariantOverStruct) { - TTestContext context; - - auto variantType = context.GetVariantOverStructType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); - - auto values = context.CreateVariantOverStruct(100); - auto array = NArrow::MakeArray(values, variantType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, variantType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], variantType); - } - } - - Y_UNIT_TEST(VariantOverTupleWithOptionals) { - TTestContext context; - - auto variantType = context.GetVariantOverTupleWithOptionalsType(); - UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); - - auto values = context.CreateVariantOverStruct(100); - auto array = NArrow::MakeArray(values, variantType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, variantType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], variantType); - } - } - - Y_UNIT_TEST(DictOptionalToTuple) { - TTestContext context; - - auto dictType = context.GetDictOptionalToTupleType(); - UNIT_ASSERT(!NArrow::IsArrowCompatible(dictType)); - - auto values = context.CreateDictOptionalToTuple(100); - auto array = NArrow::MakeArray(values, dictType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, dictType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], dictType); - } - } - - Y_UNIT_TEST(OptionalOfOptional) { - TTestContext context; - - auto doubleOptionalType = context.GetOptionalOfOptionalType(); - UNIT_ASSERT(!NArrow::IsArrowCompatible(doubleOptionalType)); - - auto values = context.CreateOptionalOfOptional(100); - auto array = NArrow::MakeArray(values, doubleOptionalType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, doubleOptionalType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], doubleOptionalType); - } - } - - Y_UNIT_TEST(LargeVariant) { - TTestContext context; - - auto variantType = context.GetLargeVariantType(500); - bool isCompatible = NArrow::IsArrowCompatible(variantType); - UNIT_ASSERT(!isCompatible); - - auto values = context.CreateLargeVariant(1000); - auto array = NArrow::MakeArray(values, variantType); - auto restoredValues = NArrow::ExtractUnboxedValues(array, variantType, context.HolderFactory); - UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); - for (ui64 index = 0; index < values.size(); ++index) { - AssertUnboxedValuesAreEqual(values[index], restoredValues[index], variantType); - } - } + +#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_binary.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_nested.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/type.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h> + +#include <util/string/builder.h> +#include <util/system/yassert.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NKikimr; +using namespace NKikimr::NMiniKQL; +using namespace NYql; + +namespace { +NUdf::TUnboxedValue GetValueOfBasicType(TType* type, ui64 value) { + Y_VERIFY(type->GetKind() == TType::EKind::Data); + auto dataType = static_cast<const TDataType*>(type); + auto slot = *dataType->GetDataSlot().Get(); + switch(slot) { + case NUdf::EDataSlot::Bool: + return NUdf::TUnboxedValuePod(static_cast<bool>(value % 2 == 0)); + case NUdf::EDataSlot::Int8: + return NUdf::TUnboxedValuePod(static_cast<i8>(-(value % 126))); + case NUdf::EDataSlot::Uint8: + return NUdf::TUnboxedValuePod(static_cast<ui8>(value % 255)); + case NUdf::EDataSlot::Int16: + return NUdf::TUnboxedValuePod(static_cast<i16>(-(value % ((1 << 15) - 1)))); + case NUdf::EDataSlot::Uint16: + return NUdf::TUnboxedValuePod(static_cast<ui16>(value % (1 << 16))); + case NUdf::EDataSlot::Int32: + return NUdf::TUnboxedValuePod(static_cast<i32>(-(value % ((1 << 31) - 1)))); + case NUdf::EDataSlot::Uint32: + return NUdf::TUnboxedValuePod(static_cast<ui32>(value % (1 << 31))); + case NUdf::EDataSlot::Int64: + return NUdf::TUnboxedValuePod(static_cast<i64>(- (value / 2))); + case NUdf::EDataSlot::Uint64: + return NUdf::TUnboxedValuePod(static_cast<ui64>(value)); + case NUdf::EDataSlot::Float: + return NUdf::TUnboxedValuePod(static_cast<float>(value) / 1234); + case NUdf::EDataSlot::Double: + return NUdf::TUnboxedValuePod(static_cast<double>(value) / 12345); + default: + Y_FAIL("Not implemented creation value for such type"); + } +} + +struct TTestContext { + TScopedAlloc Alloc; + TTypeEnvironment TypeEnv; + TMemoryUsageInfo MemInfo; + THolderFactory HolderFactory; + TDefaultValueBuilder Vb; + ui16 VariantSize = 0; + + // Used to create LargeVariantType + TVector<TType*> BasicTypes = { + TDataType::Create(NUdf::TDataType<bool>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<i8>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<ui8>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<i16>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<ui16>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<ui32>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<i64>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<ui64>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<float>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<double>::Id, TypeEnv) + }; + + TTestContext() + : Alloc() + , TypeEnv(Alloc) + , MemInfo("TestMem") + , HolderFactory(Alloc.Ref(), MemInfo) + , Vb(HolderFactory) + { + } + + TType* GetStructType() { + TStructMember members[3] = { + {"s", TDataType::Create(NUdf::TDataType<char*>::Id, TypeEnv)}, + {"x", TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv)}, + {"y", TDataType::Create(NUdf::TDataType<ui64>::Id, TypeEnv)} + }; + return TStructType::Create(3, members, TypeEnv); + } + + TUnboxedValueVector CreateStructs(ui32 quantity) { + TUnboxedValueVector values; + for (ui32 value = 0; value < quantity; ++value) { + NUdf::TUnboxedValue* items; + auto structValue = Vb.NewArray(3, items); + std::string string = TStringBuilder() << value; + items[0] = MakeString(NUdf::TStringRef(string.data(), string.size())); + items[1] = NUdf::TUnboxedValuePod(static_cast<i32>(-value)); + items[2] = NUdf::TUnboxedValuePod((ui64) (value * value)); + values.emplace_back(std::move(structValue)); + } + return values; + } + + TType* GetTupleType() { + TType* members[3] = { + TDataType::Create(NUdf::TDataType<bool>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<i8>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<ui8>::Id, TypeEnv) + }; + return TTupleType::Create(3, members, TypeEnv); + } + + TUnboxedValueVector CreateTuples(ui32 quantity) { + NKikimr::NMiniKQL::TUnboxedValueVector values; + for (ui32 value = 0; value < quantity; ++value) { + auto array = NArrow::MakeArray(values, GetTupleType()); + auto str = NArrow::SerializeArray(array); + NUdf::TUnboxedValue* items; + auto tupleValue = Vb.NewArray(3, items); + items[0] = NUdf::TUnboxedValuePod(value % 3 == 0); + items[1] = NUdf::TUnboxedValuePod(static_cast<i8>(-value)); + items[2] = NUdf::TUnboxedValuePod(static_cast<ui8>(value)); + values.push_back(std::move(tupleValue)); + } + return values; + } + + TType* GetDictUtf8ToIntervalType() { + TType* keyType = TDataType::Create(NUdf::TDataType<NUdf::TUtf8>::Id, TypeEnv); + TType* payloadType = TDataType::Create(NUdf::TDataType<NUdf::TInterval>::Id, TypeEnv); + return TDictType::Create(keyType, payloadType, TypeEnv); + } + + TUnboxedValueVector CreateDictUtf8ToInterval(ui32 quantity) { + NKikimr::NMiniKQL::TUnboxedValueVector values; + auto dictType = GetDictUtf8ToIntervalType(); + for (ui32 value = 0; value < quantity; ++value) { + auto dictBuilder = Vb.NewDict(dictType, 0); + for (ui32 i = 0; i < value * value; ++i) { + std::string string = TStringBuilder() << "This is a long string #" << i; + NUdf::TUnboxedValue key = MakeString(NUdf::TStringRef(string.data(), string.size())); + NUdf::TUnboxedValue payload = NUdf::TUnboxedValuePod(static_cast<i64>(value * i)); + dictBuilder->Add(std::move(key), std::move(payload)); + } + auto dictValue = dictBuilder->Build(); + values.emplace_back(std::move(dictValue)); + } + return values; + } + + TType* GetListOfJsonsType() { + TType* itemType = TDataType::Create(NUdf::TDataType<NUdf::TJson>::Id, TypeEnv); + return TListType::Create(itemType, TypeEnv); + } + + TUnboxedValueVector CreateListOfJsons(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + TUnboxedValueVector items; + items.reserve(value); + for (ui64 i = 0; i < value; ++i) { + std::string json = TStringBuilder() << "{'item':" << i << "}"; + items.push_back(MakeString(NUdf::TStringRef(json.data(), json.size()))); + } + auto listValue = Vb.NewList(items.data(), value); + values.emplace_back(std::move(listValue)); + } + return values; + } + + TType* GetVariantOverStructType() { + TStructMember members[4] = { + {"0_yson", TDataType::Create(NUdf::TDataType<NUdf::TYson>::Id, TypeEnv)}, + {"1_json-document", TDataType::Create(NUdf::TDataType<NUdf::TJsonDocument>::Id, TypeEnv)}, + {"2_uuid", TDataType::Create(NUdf::TDataType<NUdf::TUuid>::Id, TypeEnv)}, + {"3_float", TDataType::Create(NUdf::TDataType<float>::Id, TypeEnv)} + }; + auto structType = TStructType::Create(4, members, TypeEnv); + return TVariantType::Create(structType, TypeEnv); + } + + TUnboxedValueVector CreateVariantOverStruct(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto typeIndex = value % 4; + NUdf::TUnboxedValue item; + if (typeIndex == 0) { + std::string data = TStringBuilder() << "{value=" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 1) { + std::string data = TStringBuilder() << "{value:" << value << "}"; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 2) { + std::string data = TStringBuilder() << "id-QwErY-" << value; + item = MakeString(NUdf::TStringRef(data.data(), data.size())); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast<float>(value) / 4); + } + auto wrapped = Vb.NewVariant(typeIndex, std::move(item)); + values.push_back(std::move(wrapped)); + } + return values; + } + + TType* GetVariantOverTupleWithOptionalsType() { + TType* members[5] = { + TDataType::Create(NUdf::TDataType<bool>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<i16>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<ui16>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), + TOptionalType::Create(TDataType::Create(NUdf::TDataType<ui32>::Id, TypeEnv), TypeEnv) + }; + auto tupleType = TTupleType::Create(5, members, TypeEnv); + return TVariantType::Create(tupleType, TypeEnv); + } + + TUnboxedValueVector CreateVariantOverTupleWithOptionals(ui32 quantity) { + NKikimr::NMiniKQL::TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto typeIndex = value % 5; + NUdf::TUnboxedValue item; + if (typeIndex == 0) { + item = NUdf::TUnboxedValuePod(value % 3 == 0); + } else if (typeIndex == 1) { + item = NUdf::TUnboxedValuePod(static_cast<i16>(-value)); + } else if (typeIndex == 2) { + item = NUdf::TUnboxedValuePod(static_cast<ui16>(value)); + } else if (typeIndex == 3) { + item = NUdf::TUnboxedValuePod(static_cast<i32>(-value)); + } else if (typeIndex == 4) { + NUdf::TUnboxedValue innerItem; + innerItem = value % 2 == 0 + ? NUdf::TUnboxedValuePod(static_cast<i32>(value)) + : NUdf::TUnboxedValuePod(); + item = innerItem.MakeOptional(); + } + auto wrapped = Vb.NewVariant(typeIndex, std::move(item)); + values.emplace_back(std::move(wrapped)); + } + return values; + } + + TType* GetDictOptionalToTupleType() { + TType* keyType = TOptionalType::Create(TDataType::Create(NUdf::TDataType<double>::Id, TypeEnv), TypeEnv); + TType* members[2] = { + TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), + TDataType::Create(NUdf::TDataType<ui32>::Id, TypeEnv), + }; + TType* payloadType = TTupleType::Create(2, members, TypeEnv); + return TDictType::Create(keyType, payloadType, TypeEnv); + } + + TUnboxedValueVector CreateDictOptionalToTuple(ui32 quantity) { + NKikimr::NMiniKQL::TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + auto dictBuilder = Vb.NewDict(GetDictOptionalToTupleType(), 0); + for (ui64 i = 0; i < value * value; ++i) { + NUdf::TUnboxedValue key; + if (i == 0) { + key = NUdf::TUnboxedValuePod(); + } else { + key = NUdf::TUnboxedValuePod(value / 4).MakeOptional(); + } + NUdf::TUnboxedValue* items; + auto payload = Vb.NewArray(2, items); + items[0] = NUdf::TUnboxedValuePod(static_cast<i32>(-value)); + items[1] = NUdf::TUnboxedValuePod(static_cast<ui32>(value)); + dictBuilder->Add(std::move(key), std::move(payload)); + } + auto dictValue = dictBuilder->Build(); + values.emplace_back(std::move(dictValue)); + } + return values; + } + + TType* GetOptionalOfOptionalType() { + return TOptionalType::Create( + TOptionalType::Create( + TDataType::Create(NUdf::TDataType<i32>::Id, TypeEnv), + TypeEnv), + TypeEnv); + } + + TUnboxedValueVector CreateOptionalOfOptional(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 value = 0; value < quantity; ++value) { + NUdf::TUnboxedValue element = value % 3 == 0 + ? NUdf::TUnboxedValuePod(value).MakeOptional() + : NUdf::TUnboxedValuePod(); + if (value % 3 != 2) { + element = element.MakeOptional(); + } + values.emplace_back(std::move(element)); + } + return values; + } + + TType* GetLargeVariantType(const ui16 variantSize) { + VariantSize = variantSize; + TVector<TType*> tupleTypes; + tupleTypes.reserve(variantSize); + for (ui64 index = 0; index < variantSize; ++index) { + TVector<TType*> selectedTypes; + for (ui32 i = 0; i < BasicTypes.size(); ++i) { + if ((index >> i) % 2 == 1) { + selectedTypes.push_back(BasicTypes[i]); + } + } + tupleTypes.push_back(TTupleType::Create(selectedTypes.size(), selectedTypes.data(), TypeEnv)); + } + auto tupleOfTuplesType = TTupleType::Create(variantSize, tupleTypes.data(), TypeEnv); + return TVariantType::Create(tupleOfTuplesType, TypeEnv); + } + + TUnboxedValueVector CreateLargeVariant(ui32 quantity) { + TUnboxedValueVector values; + for (ui64 index = 0; index < quantity; ++index) { + NUdf::TUnboxedValue item; + auto typeIndex = index % VariantSize; + TUnboxedValueVector tupleItems; + for (ui64 i = 0; i < BasicTypes.size(); ++i) { + if ((typeIndex >> i) % 2 == 1) { + tupleItems.push_back(GetValueOfBasicType(BasicTypes[i], i)); + } + } + auto wrapped = Vb.NewVariant(typeIndex, HolderFactory.VectorAsArray(tupleItems)); + values.emplace_back(std::move(wrapped)); + } + return values; + } +}; + +// Note this equality check is not fully valid. But it is sufficient for UnboxedValues used in tests. +void AssertUnboxedValuesAreEqual(NUdf::TUnboxedValue& left, NUdf::TUnboxedValue& right, TType* type) { + switch (type->GetKind()) { + case TType::EKind::Void: + case TType::EKind::Null: + case TType::EKind::EmptyList: + case TType::EKind::EmptyDict: { + UNIT_ASSERT(!left.HasValue()); + UNIT_ASSERT(!right.HasValue()); + break; + } + + case TType::EKind::Data: { + auto dataType = static_cast<const TDataType*>(type); + auto dataSlot = *dataType->GetDataSlot().Get(); + // Json-like type are not comparable so just skip them + if (dataSlot != NUdf::EDataSlot::Json && dataSlot != NUdf::EDataSlot::Yson && dataSlot != NUdf::EDataSlot::JsonDocument) { + UNIT_ASSERT(NUdf::EquateValues(dataSlot, left, right)); + } + break; + } + + case TType::EKind::Optional: { + UNIT_ASSERT_EQUAL(left.HasValue(), right.HasValue()); + if (left.HasValue()) { + auto innerType = static_cast<const TOptionalType*>(type)->GetItemType(); + NUdf::TUnboxedValue leftInner = left.GetOptionalValue(); + NUdf::TUnboxedValue rightInner = right.GetOptionalValue(); + AssertUnboxedValuesAreEqual(leftInner, rightInner, innerType); + } + break; + } + + case TType::EKind::List: { + auto listType = static_cast<const TListType*>(type); + auto itemType = listType->GetItemType(); + auto leftPtr = left.GetElements(); + auto rightPtr = right.GetElements(); + UNIT_ASSERT_EQUAL(leftPtr != nullptr, rightPtr != nullptr); + if (leftPtr != nullptr) { + auto leftLen = left.GetListLength(); + auto rightLen = right.GetListLength(); + UNIT_ASSERT_EQUAL(leftLen, rightLen); + while (leftLen > 0) { + NUdf::TUnboxedValue leftItem = *leftPtr++; + NUdf::TUnboxedValue rightItem = *rightPtr++; + AssertUnboxedValuesAreEqual(leftItem, rightItem, itemType); + --leftLen; + } + } else { + const auto leftIter = left.GetListIterator(); + const auto rightIter = right.GetListIterator(); + NUdf::TUnboxedValue leftItem; + NUdf::TUnboxedValue rightItem; + bool leftHasValue = leftIter.Next(leftItem); + bool rightHasValue = rightIter.Next(leftItem); + while (leftHasValue && rightHasValue) { + AssertUnboxedValuesAreEqual(leftItem, rightItem, itemType); + leftHasValue = leftIter.Next(leftItem); + rightHasValue = rightIter.Next(leftItem); + } + UNIT_ASSERT_EQUAL(leftHasValue, rightHasValue); + } + break; + } + + case TType::EKind::Struct: { + auto structType = static_cast<const TStructType*>(type); + UNIT_ASSERT_EQUAL(left.GetListLength(), structType->GetMembersCount()); + UNIT_ASSERT_EQUAL(right.GetListLength(), structType->GetMembersCount()); + for (ui32 index = 0; index < structType->GetMembersCount(); ++index) { + auto memberType = structType->GetMemberType(index); + NUdf::TUnboxedValue leftMember = left.GetElement(index); + NUdf::TUnboxedValue rightMember = right.GetElement(index); + AssertUnboxedValuesAreEqual(leftMember, rightMember, memberType); + } + break; + } + + case TType::EKind::Tuple: { + auto tupleType = static_cast<const TTupleType*>(type); + UNIT_ASSERT_EQUAL(left.GetListLength(), tupleType->GetElementsCount()); + UNIT_ASSERT_EQUAL(right.GetListLength(), tupleType->GetElementsCount()); + for (ui32 index = 0; index < tupleType->GetElementsCount(); ++index) { + auto elementType = tupleType->GetElementType(index); + NUdf::TUnboxedValue leftMember = left.GetElement(index); + NUdf::TUnboxedValue rightMember = right.GetElement(index); + AssertUnboxedValuesAreEqual(leftMember, rightMember, elementType); + } + break; + } + + case TType::EKind::Dict: { + auto dictType = static_cast<const TDictType*>(type); + auto payloadType = dictType->GetPayloadType(); + + UNIT_ASSERT_EQUAL(left.GetDictLength(), right.GetDictLength()); + const auto leftIter = left.GetDictIterator(); + for (NUdf::TUnboxedValue key, leftPayload; leftIter.NextPair(key, leftPayload);) { + UNIT_ASSERT(right.Contains(key)); + NUdf::TUnboxedValue rightPayload = right.Lookup(key); + AssertUnboxedValuesAreEqual(leftPayload, rightPayload, payloadType); + } + break; + } + + case TType::EKind::Variant: { + auto variantType = static_cast<const TVariantType*>(type); + UNIT_ASSERT_EQUAL(left.GetVariantIndex(), right.GetVariantIndex()); + ui32 variantIndex = left.GetVariantIndex(); + TType* innerType = variantType->GetUnderlyingType(); + if (innerType->IsStruct()) { + innerType = static_cast<TStructType*>(innerType)->GetMemberType(variantIndex); + } else { + Y_VERIFY_S(innerType->IsTuple(), "Unexpected underlying variant type: " << innerType->GetKindAsStr()); + innerType = static_cast<TTupleType*>(innerType)->GetElementType(variantIndex); + } + NUdf::TUnboxedValue leftValue = left.GetVariantItem(); + NUdf::TUnboxedValue rightValue = right.GetVariantItem(); + AssertUnboxedValuesAreEqual(leftValue, rightValue, innerType); + break; + } + + default: + THROW yexception() << "Unsupported type: " << type->GetKindAsStr(); + } +} +} + + +Y_UNIT_TEST_SUITE(DqUnboxedValueToNativeArrowConversion) { + Y_UNIT_TEST(Struct) { + TTestContext context; + + auto structType = context.GetStructType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(structType)); + + auto values = context.CreateStructs(100); + auto array = NArrow::MakeArray(values, structType); + + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(array->length() == static_cast<i64>(values.size())); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + auto structArray = static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT(structArray->num_fields() == 3); + UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(structArray->field(1)->type_id() == arrow::Type::INT32); + UNIT_ASSERT(structArray->field(2)->type_id() == arrow::Type::UINT64); + UNIT_ASSERT(static_cast<ui64>(structArray->field(0)->length()) == values.size()); + UNIT_ASSERT(static_cast<ui64>(structArray->field(1)->length()) == values.size()); + UNIT_ASSERT(static_cast<ui64>(structArray->field(2)->length()) == values.size()); + auto binaryArray = static_pointer_cast<arrow::BinaryArray>(structArray->field(0)); + auto int32Array = static_pointer_cast<arrow::Int32Array>(structArray->field(1)); + auto uint64Array = static_pointer_cast<arrow::UInt64Array>(structArray->field(2)); + auto index = 0; + for (const auto& value: values) { + auto stringValue = value.GetElement(0); + auto stringRef = stringValue.AsStringRef(); + auto stringView = binaryArray->GetView(index); + UNIT_ASSERT_EQUAL(std::string(stringRef.Data(), stringRef.Size()), std::string(stringView)); + + auto intValue = value.GetElement(1).Get<i32>(); + auto intArrow = int32Array->Value(index); + UNIT_ASSERT_EQUAL(intValue, intArrow); + + auto uIntValue = value.GetElement(2).Get<ui64>(); + auto uIntArrow = uint64Array->Value(index); + UNIT_ASSERT_EQUAL(uIntValue, uIntArrow); + ++index; + } + } + + Y_UNIT_TEST(Tuple) { + TTestContext context; + + auto tupleType = context.GetTupleType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(tupleType)); + + auto values = context.CreateTuples(100); + auto array = NArrow::MakeArray(values, tupleType); + UNIT_ASSERT(array->ValidateFull().ok()); + + UNIT_ASSERT(array->length() == static_cast<i64>(values.size())); + UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT); + auto structArray = static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT(structArray->num_fields() == 3); + UNIT_ASSERT(structArray->field(0)->type_id() == arrow::Type::BOOL); + UNIT_ASSERT(structArray->field(1)->type_id() == arrow::Type::INT8); + UNIT_ASSERT(structArray->field(2)->type_id() == arrow::Type::UINT8); + UNIT_ASSERT(static_cast<ui64>(structArray->field(0)->length()) == values.size()); + UNIT_ASSERT(static_cast<ui64>(structArray->field(1)->length()) == values.size()); + UNIT_ASSERT(static_cast<ui64>(structArray->field(2)->length()) == values.size()); + auto boolArray = static_pointer_cast<arrow::BooleanArray>(structArray->field(0)); + auto int8Array = static_pointer_cast<arrow::Int8Array>(structArray->field(1)); + auto uint8Array = static_pointer_cast<arrow::UInt8Array>(structArray->field(2)); + auto index = 0; + for (const auto& value: values) { + auto boolValue = value.GetElement(0).Get<bool>(); + auto boolArrow = boolArray->Value(index); + UNIT_ASSERT(boolValue == boolArrow); + + auto intValue = value.GetElement(1).Get<i8>(); + auto intArrow = int8Array->Value(index); + UNIT_ASSERT(intValue == intArrow); + + auto uIntValue = value.GetElement(2).Get<ui8>(); + auto uIntArrow = uint8Array->Value(index); + UNIT_ASSERT(uIntValue == uIntArrow); + ++index; + } + } + + Y_UNIT_TEST(DictUtf8ToInterval) { + TTestContext context; + + auto dictType = context.GetDictUtf8ToIntervalType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(dictType)); + + auto values = context.CreateDictUtf8ToInterval(100); + auto array = NArrow::MakeArray(values, dictType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::MAP); + auto mapArray = static_pointer_cast<arrow::MapArray>(array); + + UNIT_ASSERT(mapArray->num_fields() == 1); + UNIT_ASSERT(mapArray->keys()->type_id() == arrow::Type::STRING); + UNIT_ASSERT(mapArray->items()->type_id() == arrow::Type::DURATION); + auto utf8Array = static_pointer_cast<arrow::StringArray>(mapArray->keys()); + auto intervalArray = static_pointer_cast<arrow::NumericArray<arrow::DurationType>>(mapArray->items()); + ui64 index = 0; + for (const auto& value: values) { + UNIT_ASSERT(value.GetDictLength() == static_cast<ui64>(mapArray->value_length(index))); + for (auto subindex = mapArray->value_offset(index); subindex < mapArray->value_offset(index + 1); ++subindex) { + auto keyArrow = utf8Array->GetView(subindex); + NUdf::TUnboxedValue key = MakeString(NUdf::TStringRef(keyArrow.data(), keyArrow.size())); + UNIT_ASSERT(value.Contains(key)); + NUdf::TUnboxedValue payloadValue = value.Lookup(key); + UNIT_ASSERT(intervalArray->Value(subindex) == payloadValue.Get<i64>()); + } + ++index; + } + } + + Y_UNIT_TEST(ListOfJsons) { + TTestContext context; + + auto listType = context.GetListOfJsonsType(); + Y_VERIFY(NArrow::IsArrowCompatible(listType)); + + auto values = context.CreateListOfJsons(100); + auto array = NArrow::MakeArray(values, listType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::LIST); + auto listArray = static_pointer_cast<arrow::ListArray>(array); + + UNIT_ASSERT(listArray->num_fields() == 1); + UNIT_ASSERT(listArray->value_type()->id() == arrow::Type::STRING); + auto jsonArray = static_pointer_cast<arrow::StringArray>(listArray->values()); + auto index = 0; + auto innerIndex = 0; + for (const auto& value: values) { + UNIT_ASSERT(value.GetListLength() == static_cast<ui64>(listArray->value_length(index))); + const auto iter = value.GetListIterator(); + for (NUdf::TUnboxedValue item; iter.Next(item);) { + auto view = jsonArray->GetView(innerIndex); + std::string itemArrow(view.data(), view.size()); + auto stringRef = item.AsStringRef(); + std::string itemList(stringRef.Data(), stringRef.Size()); + UNIT_ASSERT(itemList == itemArrow); + ++innerIndex; + } + ++index; + } + } + + Y_UNIT_TEST(VariantOverStruct) { + TTestContext context; + + auto variantType = context.GetVariantOverStructType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateVariantOverStruct(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::DENSE_UNION); + auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); + + UNIT_ASSERT(unionArray->num_fields() == 4); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::STRING); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::STRING); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::BINARY); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::FLOAT); + auto ysonArray = static_pointer_cast<arrow::StringArray>(unionArray->field(0)); + auto jsonDocArray = static_pointer_cast<arrow::StringArray>(unionArray->field(1)); + auto uuidArray = static_pointer_cast<arrow::BinaryArray>(unionArray->field(2)); + auto floatArray = static_pointer_cast<arrow::FloatArray>(unionArray->field(3)); + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 3) { + auto valueArrow = floatArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<float>(); + UNIT_ASSERT(valueArrow == valueInner); + } else { + arrow::util::string_view viewArrow; + if (value.GetVariantIndex() == 0) { + viewArrow = ysonArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 1) { + viewArrow = jsonDocArray->GetView(fieldIndex); + } else if (value.GetVariantIndex() == 2) { + viewArrow = uuidArray->GetView(fieldIndex); + } + std::string valueArrow(viewArrow.data(), viewArrow.size()); + auto innerItem = value.GetVariantItem(); + auto refInner = innerItem.AsStringRef(); + std::string valueInner(refInner.Data(), refInner.Size()); + UNIT_ASSERT(valueArrow == valueInner); + } + } + } + + Y_UNIT_TEST(VariantOverTupleWithOptionals) { + TTestContext context; + + auto variantType = context.GetVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateVariantOverStruct(100); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT(static_cast<ui64>(array->length()) == values.size()); + UNIT_ASSERT(array->type_id() == arrow::Type::DENSE_UNION); + auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); + + UNIT_ASSERT(unionArray->num_fields() == 5); + UNIT_ASSERT(unionArray->field(0)->type_id() == arrow::Type::BOOL); + UNIT_ASSERT(unionArray->field(1)->type_id() == arrow::Type::INT16); + UNIT_ASSERT(unionArray->field(2)->type_id() == arrow::Type::UINT16); + UNIT_ASSERT(unionArray->field(3)->type_id() == arrow::Type::INT32); + UNIT_ASSERT(unionArray->field(4)->type_id() == arrow::Type::UINT32); + auto boolArray = static_pointer_cast<arrow::BooleanArray>(unionArray->field(0)); + auto i16Array = static_pointer_cast<arrow::Int16Array>(unionArray->field(1)); + auto ui16Array = static_pointer_cast<arrow::UInt16Array>(unionArray->field(2)); + auto i32Array = static_pointer_cast<arrow::Int32Array>(unionArray->field(3)); + auto ui32Array = static_pointer_cast<arrow::UInt32Array>(unionArray->field(4)); + for (ui64 index = 0; index < values.size(); ++index) { + auto value = values[index]; + UNIT_ASSERT(value.GetVariantIndex() == static_cast<ui32>(unionArray->child_id(index))); + auto fieldIndex = unionArray->value_offset(index); + if (value.GetVariantIndex() == 0) { + bool valueArrow = boolArray->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<bool>(); + UNIT_ASSERT(valueArrow == valueInner); + } else if (value.GetVariantIndex() == 1) { + auto valueArrow = i16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<i16>(); + UNIT_ASSERT(valueArrow == valueInner); + } else if (value.GetVariantIndex() == 2) { + auto valueArrow = ui16Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<ui16>(); + UNIT_ASSERT(valueArrow == valueInner); + } else if (value.GetVariantIndex() == 3) { + auto valueArrow = i32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<i32>(); + UNIT_ASSERT(valueArrow == valueInner); + } else if (value.GetVariantIndex() == 4) { + if (!value.GetVariantItem().HasValue()) { + UNIT_ASSERT(ui32Array->IsNull(fieldIndex)); + } else { + auto valueArrow = ui32Array->Value(fieldIndex); + auto valueInner = value.GetVariantItem().Get<ui32>(); + UNIT_ASSERT(valueArrow == valueInner); + } + } + } + } +} + +Y_UNIT_TEST_SUITE(DqUnboxedValueDoNotFitToArrow) { + Y_UNIT_TEST(DictOptionalToTuple) { + TTestContext context; + + auto dictType = context.GetDictOptionalToTupleType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(dictType)); + + auto values = context.CreateDictOptionalToTuple(100); + auto array = NArrow::MakeArray(values, dictType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT_EQUAL(static_cast<ui64>(array->length()), values.size()); + UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::LIST); + auto listArray = static_pointer_cast<arrow::ListArray>(array); + UNIT_ASSERT_EQUAL(listArray->value_type()->id(), arrow::Type::STRUCT); + auto structArray = static_pointer_cast<arrow::StructArray>(listArray->values()); + + UNIT_ASSERT_EQUAL(listArray->num_fields(), 1); + UNIT_ASSERT_EQUAL(structArray->num_fields(), 2); + UNIT_ASSERT_EQUAL(structArray->field(0)->type_id(), arrow::Type::DOUBLE); + UNIT_ASSERT_EQUAL(structArray->field(1)->type_id(), arrow::Type::STRUCT); + auto keysArray = static_pointer_cast<arrow::DoubleArray>(structArray->field(0)); + auto itemsArray = static_pointer_cast<arrow::StructArray>(structArray->field(1)); + UNIT_ASSERT_EQUAL(itemsArray->num_fields(), 2); + UNIT_ASSERT_EQUAL(itemsArray->field(0)->type_id(), arrow::Type::INT32); + UNIT_ASSERT_EQUAL(itemsArray->field(1)->type_id(), arrow::Type::UINT32); + auto i32Array = static_pointer_cast<arrow::Int32Array>(itemsArray->field(0)); + auto ui32Array = static_pointer_cast<arrow::UInt32Array>(itemsArray->field(1)); + + ui64 index = 0; + for (const auto& value: values) { + UNIT_ASSERT(value.GetDictLength() == static_cast<ui64>(listArray->value_length(index))); + for (auto subindex = listArray->value_offset(index); subindex < listArray->value_offset(index + 1); ++subindex) { + NUdf::TUnboxedValue key = keysArray->IsNull(subindex) + ? NUdf::TUnboxedValuePod() + : NUdf::TUnboxedValuePod(keysArray->Value(subindex)); + UNIT_ASSERT(value.Contains(key)); + NUdf::TUnboxedValue payloadValue = value.Lookup(key); + UNIT_ASSERT_EQUAL(payloadValue.GetElement(0).Get<i32>(), i32Array->Value(subindex)); + UNIT_ASSERT_EQUAL(payloadValue.GetElement(1).Get<ui32>(), ui32Array->Value(subindex)); + } + ++index; + } + } + + Y_UNIT_TEST(OptionalOfOptional) { + TTestContext context; + + auto doubleOptionalType = context.GetOptionalOfOptionalType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(doubleOptionalType)); + + auto values = context.CreateOptionalOfOptional(100); + auto array = NArrow::MakeArray(values, doubleOptionalType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT_EQUAL(static_cast<ui64>(array->length()), values.size()); + UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::STRUCT); + auto structArray = static_pointer_cast<arrow::StructArray>(array); + UNIT_ASSERT_EQUAL(structArray->num_fields(), 2); + UNIT_ASSERT_EQUAL(structArray->field(0)->type_id(), arrow::Type::UINT64); + UNIT_ASSERT_EQUAL(structArray->field(1)->type_id(), arrow::Type::INT32); + auto depthArray = static_pointer_cast<arrow::UInt64Array>(structArray->field(0)); + auto i32Array = static_pointer_cast<arrow::Int32Array>(structArray->field(1)); + + auto index = 0; + for (auto value: values) { + auto depth = depthArray->Value(index); + while (depth > 0) { + UNIT_ASSERT(value.HasValue()); + value = value.GetOptionalValue(); + --depth; + } + if (value.HasValue()) { + UNIT_ASSERT_EQUAL(value.Get<i32>(), i32Array->Value(index)); + } else { + UNIT_ASSERT(i32Array->IsNull(index)); + } + ++index; + } + } + + Y_UNIT_TEST(LargeVariant) { + TTestContext context; + + ui32 numberOfTypes = 500; + auto variantType = context.GetLargeVariantType(numberOfTypes); + bool isCompatible = NArrow::IsArrowCompatible(variantType); + UNIT_ASSERT(!isCompatible); + + auto values = context.CreateLargeVariant(1000); + auto array = NArrow::MakeArray(values, variantType); + UNIT_ASSERT(array->ValidateFull().ok()); + UNIT_ASSERT_EQUAL(static_cast<ui64>(array->length()), values.size()); + UNIT_ASSERT_EQUAL(array->type_id(), arrow::Type::DENSE_UNION); + auto unionArray = static_pointer_cast<arrow::DenseUnionArray>(array); + ui32 numberOfGroups = (numberOfTypes - 1) / arrow::UnionType::kMaxTypeCode + 1; + UNIT_ASSERT_EQUAL(numberOfGroups, static_cast<ui32>(unionArray->num_fields())); + ui32 typesInArrow = 0; + for (auto i = 0 ; i < unionArray->num_fields(); ++i) { + UNIT_ASSERT_EQUAL(unionArray->field(i)->type_id(), arrow::Type::DENSE_UNION); + typesInArrow += unionArray->field(i)->num_fields(); + } + UNIT_ASSERT_EQUAL(numberOfTypes, typesInArrow); + // TODO Check array content. + } +} + +Y_UNIT_TEST_SUITE(ConvertUnboxedValueToArrowAndBack){ + Y_UNIT_TEST(Struct) { + TTestContext context; + + auto structType = context.GetStructType(); + auto values = context.CreateStructs(100); + auto array = NArrow::MakeArray(values, structType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, structType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], structType); + } + } + + Y_UNIT_TEST(Tuple) { + TTestContext context; + + auto tupleType = context.GetTupleType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(tupleType)); + + auto values = context.CreateTuples(100); + auto array = NArrow::MakeArray(values, tupleType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, tupleType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], tupleType); + } + } + + Y_UNIT_TEST(DictUtf8ToInterval) { + TTestContext context; + + auto dictType = context.GetDictUtf8ToIntervalType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(dictType)); + + auto values = context.CreateDictUtf8ToInterval(100); + auto array = NArrow::MakeArray(values, dictType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, dictType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], dictType); + } + } + + Y_UNIT_TEST(ListOfJsons) { + TTestContext context; + + auto listType = context.GetListOfJsonsType(); + Y_VERIFY(NArrow::IsArrowCompatible(listType)); + + auto values = context.CreateListOfJsons(100); + auto array = NArrow::MakeArray(values, listType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, listType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], listType); + } + } + + Y_UNIT_TEST(VariantOverStruct) { + TTestContext context; + + auto variantType = context.GetVariantOverStructType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateVariantOverStruct(100); + auto array = NArrow::MakeArray(values, variantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, variantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], variantType); + } + } + + Y_UNIT_TEST(VariantOverTupleWithOptionals) { + TTestContext context; + + auto variantType = context.GetVariantOverTupleWithOptionalsType(); + UNIT_ASSERT(NArrow::IsArrowCompatible(variantType)); + + auto values = context.CreateVariantOverStruct(100); + auto array = NArrow::MakeArray(values, variantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, variantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], variantType); + } + } + + Y_UNIT_TEST(DictOptionalToTuple) { + TTestContext context; + + auto dictType = context.GetDictOptionalToTupleType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(dictType)); + + auto values = context.CreateDictOptionalToTuple(100); + auto array = NArrow::MakeArray(values, dictType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, dictType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], dictType); + } + } + + Y_UNIT_TEST(OptionalOfOptional) { + TTestContext context; + + auto doubleOptionalType = context.GetOptionalOfOptionalType(); + UNIT_ASSERT(!NArrow::IsArrowCompatible(doubleOptionalType)); + + auto values = context.CreateOptionalOfOptional(100); + auto array = NArrow::MakeArray(values, doubleOptionalType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, doubleOptionalType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], doubleOptionalType); + } + } + + Y_UNIT_TEST(LargeVariant) { + TTestContext context; + + auto variantType = context.GetLargeVariantType(500); + bool isCompatible = NArrow::IsArrowCompatible(variantType); + UNIT_ASSERT(!isCompatible); + + auto values = context.CreateLargeVariant(1000); + auto array = NArrow::MakeArray(values, variantType); + auto restoredValues = NArrow::ExtractUnboxedValues(array, variantType, context.HolderFactory); + UNIT_ASSERT_EQUAL(values.size(), restoredValues.size()); + for (ui64 index = 0; index < values.size(); ++index) { + AssertUnboxedValuesAreEqual(values[index], restoredValues[index], variantType); + } + } } diff --git a/ydb/library/yql/dq/runtime/dq_input_channel.cpp b/ydb/library/yql/dq/runtime/dq_input_channel.cpp index 56e3ad2a15f..bd0a541e377 100644 --- a/ydb/library/yql/dq/runtime/dq_input_channel.cpp +++ b/ydb/library/yql/dq/runtime/dq_input_channel.cpp @@ -8,8 +8,8 @@ class TDqInputChannel : public TDqInputImpl<TDqInputChannel, IDqInputChannel> { friend TBaseImpl; public: TDqInputChannel(ui64 channelId, NKikimr::NMiniKQL::TType* inputType, ui64 maxBufferBytes, bool collectProfileStats, - const NKikimr::NMiniKQL::TTypeEnvironment& typeEnv, const NKikimr::NMiniKQL::THolderFactory& holderFactory, - NDqProto::EDataTransportVersion transportVersion) + const NKikimr::NMiniKQL::TTypeEnvironment& typeEnv, const NKikimr::NMiniKQL::THolderFactory& holderFactory, + NDqProto::EDataTransportVersion transportVersion) : TBaseImpl(inputType, maxBufferBytes) , ChannelId(channelId) , BasicStats(ChannelId) diff --git a/ydb/library/yql/dq/runtime/dq_output_channel.cpp b/ydb/library/yql/dq/runtime/dq_output_channel.cpp index aba417ede31..211d4ac39a3 100644 --- a/ydb/library/yql/dq/runtime/dq_output_channel.cpp +++ b/ydb/library/yql/dq/runtime/dq_output_channel.cpp @@ -73,7 +73,7 @@ public: , OutputType(outputType) , BasicStats(ChannelId) , ProfileStats(collectProfileStats ? &BasicStats : nullptr) - , DataSerializer(typeEnv, holderFactory, transportVersion) + , DataSerializer(typeEnv, holderFactory, transportVersion) , MaxStoredBytes(maxStoredBytes) , MaxChunkBytes(maxChunkBytes) , LogFunc(logFunc) {} diff --git a/ydb/library/yql/dq/runtime/dq_output_channel_ut.cpp b/ydb/library/yql/dq/runtime/dq_output_channel_ut.cpp index aa40a1dcc58..2c3c1ad100d 100644 --- a/ydb/library/yql/dq/runtime/dq_output_channel_ut.cpp +++ b/ydb/library/yql/dq/runtime/dq_output_channel_ut.cpp @@ -32,7 +32,7 @@ struct TTestContext { TMemoryUsageInfo MemInfo; THolderFactory HolderFactory; TDefaultValueBuilder Vb; - NDqProto::EDataTransportVersion TransportVersion; + NDqProto::EDataTransportVersion TransportVersion; TDqDataSerializer Ds; TStructType* OutputType = nullptr; @@ -42,8 +42,8 @@ struct TTestContext { , MemInfo("Mem") , HolderFactory(Alloc.Ref(), MemInfo) , Vb(HolderFactory) - , TransportVersion(transportVersion) - , Ds(TypeEnv, HolderFactory, TransportVersion) + , TransportVersion(transportVersion) + , Ds(TypeEnv, HolderFactory, TransportVersion) { if (bigRows) { TStructMember members[3] = { @@ -296,7 +296,7 @@ void TestBigRow(TTestContext& ctx, bool quantum) { } -void TestSpillWithMockStorage(TTestContext& ctx) { +void TestSpillWithMockStorage(TTestContext& ctx) { TDqOutputChannelSettings settings; settings.MaxStoredBytes = 100; settings.MaxChunkBytes = 10; @@ -414,7 +414,7 @@ void TestSpillWithMockStorage(TTestContext& ctx) { UNIT_ASSERT_VALUES_EQUAL(0, ch->GetValuesCount(/* inMemoryOnly */ false)); } -void TestOverflowWithMockStorage(TTestContext& ctx) { +void TestOverflowWithMockStorage(TTestContext& ctx) { TDqOutputChannelSettings settings; settings.MaxStoredBytes = 100; settings.MaxChunkBytes = 10; @@ -446,40 +446,40 @@ void TestOverflowWithMockStorage(TTestContext& ctx) { } } -} // anonymous namespace - -Y_UNIT_TEST_SUITE(DqOutputChannelNoStorageTests) { - -Y_UNIT_TEST(SingleRead) { - TTestContext ctx; +} // anonymous namespace + +Y_UNIT_TEST_SUITE(DqOutputChannelNoStorageTests) { + +Y_UNIT_TEST(SingleRead) { + TTestContext ctx; TestSingleRead(ctx, false); } - + Y_UNIT_TEST(SingleReadQ) { TTestContext ctx; TestSingleRead(ctx, true); } -Y_UNIT_TEST(SingleReadWithArrow) { - TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); +Y_UNIT_TEST(SingleReadWithArrow) { + TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); TestSingleRead(ctx, false); -} - +} + Y_UNIT_TEST(SingleReadWithArrowQ) { TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); TestSingleRead(ctx, true); } -Y_UNIT_TEST(PartialRead) { - TTestContext ctx; +Y_UNIT_TEST(PartialRead) { + TTestContext ctx; TestPartialRead(ctx, false); -} - +} + Y_UNIT_TEST(PartialReadQ) { TTestContext ctx; TestPartialRead(ctx, true); -} - +} + // too heavy messages... //Y_UNIT_TEST(PartialReadWithArrow) { // TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); @@ -491,46 +491,46 @@ Y_UNIT_TEST(PartialReadQ) { // TestPartialRead(ctx, true); //} -Y_UNIT_TEST(Overflow) { - TTestContext ctx; +Y_UNIT_TEST(Overflow) { + TTestContext ctx; TestOverflow(ctx, false); -} - +} + Y_UNIT_TEST(OverflowQ) { TTestContext ctx; TestOverflow(ctx, true); } -Y_UNIT_TEST(OverflowWithArrow) { - TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); +Y_UNIT_TEST(OverflowWithArrow) { + TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); TestOverflow(ctx, false); -} - +} + Y_UNIT_TEST(OverflowWithArrowQ) { TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); TestOverflow(ctx, true); } -Y_UNIT_TEST(PopAll) { - TTestContext ctx; +Y_UNIT_TEST(PopAll) { + TTestContext ctx; TestPopAll(ctx, false); -} - +} + Y_UNIT_TEST(PopAllQ) { TTestContext ctx; TestPopAll(ctx, true); } -Y_UNIT_TEST(PopAllWithArrow) { - TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); +Y_UNIT_TEST(PopAllWithArrow) { + TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); TestPopAll(ctx, false); -} +} Y_UNIT_TEST(PopAllWithArrowQ) { TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); TestPopAll(ctx, true); -} - +} + Y_UNIT_TEST(BigRow) { TTestContext ctx(NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0, true); TestBigRow(ctx, false); @@ -543,27 +543,27 @@ Y_UNIT_TEST(BigRowQ) { } -Y_UNIT_TEST_SUITE(DqOutputChannelWithStorageTests) { - -Y_UNIT_TEST(Spill) { - TTestContext ctx; - TestSpillWithMockStorage(ctx); -} - -// Fail because arrow serialization has a big overhead -// Y_UNIT_TEST(SpillWithArrow) { -// TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); -// TestSpillWithMockStorage(ctx); -// } - -Y_UNIT_TEST(Overflow) { - TTestContext ctx; - TestOverflowWithMockStorage(ctx); -} - -// Fail because arrow serialization has a big overhead -// Y_UNIT_TEST(OverflowWithArrow) { -// TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); -// TestOverflowWithMockStorage(ctx); -// } -} +Y_UNIT_TEST_SUITE(DqOutputChannelWithStorageTests) { + +Y_UNIT_TEST(Spill) { + TTestContext ctx; + TestSpillWithMockStorage(ctx); +} + +// Fail because arrow serialization has a big overhead +// Y_UNIT_TEST(SpillWithArrow) { +// TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); +// TestSpillWithMockStorage(ctx); +// } + +Y_UNIT_TEST(Overflow) { + TTestContext ctx; + TestOverflowWithMockStorage(ctx); +} + +// Fail because arrow serialization has a big overhead +// Y_UNIT_TEST(OverflowWithArrow) { +// TTestContext ctx(NDqProto::DATA_TRANSPORT_ARROW_1_0); +// TestOverflowWithMockStorage(ctx); +// } +} diff --git a/ydb/library/yql/dq/runtime/dq_tasks_runner.cpp b/ydb/library/yql/dq/runtime/dq_tasks_runner.cpp index 767155a7ef3..10900352c3b 100644 --- a/ydb/library/yql/dq/runtime/dq_tasks_runner.cpp +++ b/ydb/library/yql/dq/runtime/dq_tasks_runner.cpp @@ -120,7 +120,7 @@ void ValidateParamValue(std::string_view paramName, const TType* type, const NUd #define LOG(...) do { if (Y_UNLIKELY(LogFunc)) { LogFunc(__VA_ARGS__); } } while (0) -NUdf::TUnboxedValue DqBuildInputValue(const NDqProto::TTaskInput& inputDesc, const NKikimr::NMiniKQL::TType* type, +NUdf::TUnboxedValue DqBuildInputValue(const NDqProto::TTaskInput& inputDesc, const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs, const THolderFactory& holderFactory) { switch (inputDesc.GetTypeCase()) { @@ -420,8 +420,8 @@ public: for (auto& inputChannelDesc : inputDesc.GetChannels()) { ui64 channelId = inputChannelDesc.GetId(); auto inputChannel = CreateDqInputChannel(channelId, ProgramParsed.InputItemTypes[i], - memoryLimits.ChannelBufferSize, Settings.CollectProfileStats, typeEnv, holderFactory, - inputChannelDesc.GetTransportVersion()); + memoryLimits.ChannelBufferSize, Settings.CollectProfileStats, typeEnv, holderFactory, + inputChannelDesc.GetTransportVersion()); auto ret = InputChannels.emplace(channelId, inputChannel); YQL_ENSURE(ret.second, "task: " << TaskId << ", duplicated input channelId: " << channelId); inputs.emplace_back(inputChannel); diff --git a/ydb/library/yql/dq/runtime/dq_transport.cpp b/ydb/library/yql/dq/runtime/dq_transport.cpp index d66c0b4079e..93b85b0c6f8 100644 --- a/ydb/library/yql/dq/runtime/dq_transport.cpp +++ b/ydb/library/yql/dq/runtime/dq_transport.cpp @@ -1,5 +1,5 @@ #include "dq_transport.h" -#include "dq_arrow_helpers.h" +#include "dq_arrow_helpers.h" #include <ydb/library/mkql_proto/mkql_proto.h> #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> @@ -7,8 +7,8 @@ #include <ydb/library/yql/providers/common/mkql/yql_type_mkql.h> #include <ydb/library/yql/utils/yql_panic.h> -#include <util/system/yassert.h> - +#include <util/system/yassert.h> + namespace NYql::NDq { using namespace NKikimr; @@ -18,10 +18,10 @@ using namespace NYql; namespace { NDqProto::TData SerializeBufferArrowV1(TUnboxedValueVector& buffer, const TType* itemType); - + void DeserializeBufferArrowV1(const NDqProto::TData& data, const TType* itemType, - const THolderFactory& holderFactory, TUnboxedValueVector& buffer); - + const THolderFactory& holderFactory, TUnboxedValueVector& buffer); + NDqProto::TData SerializeValuePickleV1(const TType* type, const NUdf::TUnboxedValuePod& value) { TValuePacker packer(/* stable */ false, type); TStringBuf packResult = packer.Pack(value); @@ -35,15 +35,15 @@ NDqProto::TData SerializeValuePickleV1(const TType* type, const NUdf::TUnboxedVa } NDqProto::TData SerializeValueArrowV1(const TType* type, const NUdf::TUnboxedValuePod& value) { - TUnboxedValueVector buffer; - buffer.push_back(value); - return SerializeBufferArrowV1(buffer, type); -} - + TUnboxedValueVector buffer; + buffer.push_back(value); + return SerializeBufferArrowV1(buffer, type); +} + void DeserializeValuePickleV1(const TType* type, const NDqProto::TData& data, NUdf::TUnboxedValue& value, const THolderFactory& holderFactory) { - YQL_ENSURE(data.GetTransportVersion() == (ui32) NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0); + YQL_ENSURE(data.GetTransportVersion() == (ui32) NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0); TValuePacker packer(/* stable */ false, type); value = packer.Unpack(data.GetRaw(), holderFactory); } @@ -51,41 +51,41 @@ void DeserializeValuePickleV1(const TType* type, const NDqProto::TData& data, NU void DeserializeValueArrowV1(const TType* type, const NDqProto::TData& data, NUdf::TUnboxedValue& value, const THolderFactory& holderFactory) { - TUnboxedValueVector buffer; - DeserializeBufferArrowV1(data, type, holderFactory, buffer); - value = buffer[0]; -} - + TUnboxedValueVector buffer; + DeserializeBufferArrowV1(data, type, holderFactory, buffer); + value = buffer[0]; +} + NDqProto::TData SerializeBufferPickleV1(TUnboxedValueVector& buffer, const TType* itemType, const TTypeEnvironment& typeEnv, const THolderFactory& holderFactory) { const auto listType = TListType::Create(const_cast<TType*>(itemType), typeEnv); const NUdf::TUnboxedValue listValue = holderFactory.VectorAsArray(buffer); - auto data = SerializeValuePickleV1(listType, listValue); + auto data = SerializeValuePickleV1(listType, listValue); data.SetRows(buffer.size()); return data; } NDqProto::TData SerializeBufferArrowV1(TUnboxedValueVector& buffer, const TType* itemType) { - auto array = NArrow::MakeArray(buffer, itemType); - - auto serialized = NArrow::SerializeArray(array); - - NDqProto::TData data; - data.SetTransportVersion(NDqProto::DATA_TRANSPORT_ARROW_1_0); - data.SetRaw(serialized.data(), serialized.size()); - data.SetRows(buffer.size()); - return data; -} - + auto array = NArrow::MakeArray(buffer, itemType); + + auto serialized = NArrow::SerializeArray(array); + + NDqProto::TData data; + data.SetTransportVersion(NDqProto::DATA_TRANSPORT_ARROW_1_0); + data.SetRaw(serialized.data(), serialized.size()); + data.SetRows(buffer.size()); + return data; +} + void DeserializeBufferPickleV1(const NDqProto::TData& data, const TType* itemType, const TTypeEnvironment& typeEnv, const THolderFactory& holderFactory, TUnboxedValueVector& buffer) { auto listType = TListType::Create(const_cast<TType*>(itemType), typeEnv); NUdf::TUnboxedValue value; - DeserializeValuePickleV1(listType, data, value, holderFactory); + DeserializeValuePickleV1(listType, data, value, holderFactory); const auto iter = value.GetListIterator(); for (NUdf::TUnboxedValue item; iter.Next(item);) { @@ -96,94 +96,94 @@ void DeserializeBufferPickleV1(const NDqProto::TData& data, const TType* itemTyp void DeserializeBufferArrowV1(const NDqProto::TData& data, const TType* itemType, const THolderFactory& holderFactory, TUnboxedValueVector& buffer) { - YQL_ENSURE(data.GetTransportVersion() == (ui32) NDqProto::DATA_TRANSPORT_ARROW_1_0); - - auto array = NArrow::DeserializeArray(data.GetRaw(), NArrow::GetArrowType(itemType)); - YQL_ENSURE(array->length() == data.GetRows()); - auto newElements = NArrow::ExtractUnboxedValues(array, itemType, holderFactory); - for (NUdf::TUnboxedValue item: newElements) { - buffer.emplace_back(std::move(item)); - } -} - + YQL_ENSURE(data.GetTransportVersion() == (ui32) NDqProto::DATA_TRANSPORT_ARROW_1_0); + + auto array = NArrow::DeserializeArray(data.GetRaw(), NArrow::GetArrowType(itemType)); + YQL_ENSURE(array->length() == data.GetRows()); + auto newElements = NArrow::ExtractUnboxedValues(array, itemType, holderFactory); + for (NUdf::TUnboxedValue item: newElements) { + buffer.emplace_back(std::move(item)); + } +} + NDqProto::TData SerializeParamV1(const TMkqlValueRef& param, const TTypeEnvironment& typeEnv, const THolderFactory& holderFactory) { auto [type, value] = ImportValueFromProto(param.GetType(), param.GetValue(), typeEnv, holderFactory); - return SerializeValuePickleV1(type, value); + return SerializeValuePickleV1(type, value); } void DeserializeParamV1(const NDqProto::TData& data, const TType* type, const THolderFactory& holderFactory, NUdf::TUnboxedValue& value) { - DeserializeValuePickleV1(type, data, value, holderFactory); + DeserializeValuePickleV1(type, data, value, holderFactory); } } // namespace NDqProto::EDataTransportVersion TDqDataSerializer::GetTransportVersion() const { - return TransportVersion; + return TransportVersion; } NDqProto::TData TDqDataSerializer::Serialize(const NUdf::TUnboxedValue& value, const TType* itemType) const { - switch (TransportVersion) { - case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: - case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: - return SerializeValuePickleV1(itemType, value); - case NDqProto::DATA_TRANSPORT_ARROW_1_0: - return SerializeValueArrowV1(itemType, value); - default: - YQL_ENSURE(false, "Unsupported TransportVersion"); - } + switch (TransportVersion) { + case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: + case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: + return SerializeValuePickleV1(itemType, value); + case NDqProto::DATA_TRANSPORT_ARROW_1_0: + return SerializeValueArrowV1(itemType, value); + default: + YQL_ENSURE(false, "Unsupported TransportVersion"); + } } NDqProto::TData TDqDataSerializer::Serialize(TUnboxedValueVector& buffer, const TType* itemType) const { - switch (TransportVersion) { - case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: - case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: - return SerializeBufferPickleV1(buffer, itemType, TypeEnv, HolderFactory); - case NDqProto::DATA_TRANSPORT_ARROW_1_0: - return SerializeBufferArrowV1(buffer, itemType); - default: - YQL_ENSURE(false, "Unsupported TransportVersion"); - } + switch (TransportVersion) { + case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: + case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: + return SerializeBufferPickleV1(buffer, itemType, TypeEnv, HolderFactory); + case NDqProto::DATA_TRANSPORT_ARROW_1_0: + return SerializeBufferArrowV1(buffer, itemType); + default: + YQL_ENSURE(false, "Unsupported TransportVersion"); + } } void TDqDataSerializer::Deserialize(const NDqProto::TData& data, const TType* itemType, TUnboxedValueVector& buffer) const { - switch (TransportVersion) { - case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: - case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: { - DeserializeBufferPickleV1(data, itemType, TypeEnv, HolderFactory, buffer); - break; - } - case NDqProto::DATA_TRANSPORT_ARROW_1_0: { - DeserializeBufferArrowV1(data, itemType, HolderFactory, buffer); - break; - } - default: - YQL_ENSURE(false, "Unsupported TransportVersion"); - } + switch (TransportVersion) { + case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: + case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: { + DeserializeBufferPickleV1(data, itemType, TypeEnv, HolderFactory, buffer); + break; + } + case NDqProto::DATA_TRANSPORT_ARROW_1_0: { + DeserializeBufferArrowV1(data, itemType, HolderFactory, buffer); + break; + } + default: + YQL_ENSURE(false, "Unsupported TransportVersion"); + } } void TDqDataSerializer::Deserialize(const NDqProto::TData& data, const TType* itemType, NUdf::TUnboxedValue& value) const { - switch (TransportVersion) { - case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: - case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: { - DeserializeValuePickleV1(itemType, data, value, HolderFactory); - break; - } - case NDqProto::DATA_TRANSPORT_ARROW_1_0: { - DeserializeValueArrowV1(itemType, data, value, HolderFactory); - break; - } - default: - YQL_ENSURE(false, "Unsupported TransportVersion"); - } + switch (TransportVersion) { + case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: + case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: { + DeserializeValuePickleV1(itemType, data, value, HolderFactory); + break; + } + case NDqProto::DATA_TRANSPORT_ARROW_1_0: { + DeserializeValueArrowV1(itemType, data, value, HolderFactory); + break; + } + default: + YQL_ENSURE(false, "Unsupported TransportVersion"); + } } @@ -206,12 +206,12 @@ NDqProto::TData TDqDataSerializer::SerializeParamValue(const TType* type, const } ui64 TDqDataSerializer::CalcSerializedSize(NUdf::TUnboxedValue& value, const NKikimr::NMiniKQL::TType* itemType) { - auto data = SerializeValuePickleV1(itemType, value); - // YQL-9648 - DeserializeValuePickleV1(itemType, data, value, HolderFactory); - return data.GetRaw().size(); -} - + auto data = SerializeValuePickleV1(itemType, value); + // YQL-9648 + DeserializeValuePickleV1(itemType, data, value, HolderFactory); + return data.GetRaw().size(); +} + namespace { std::optional<ui64> EstimateIntegralDataSize(const TDataType* dataType) { diff --git a/ydb/library/yql/dq/runtime/dq_transport.h b/ydb/library/yql/dq/runtime/dq_transport.h index 1bd8101b26e..8f6b6f947ff 100644 --- a/ydb/library/yql/dq/runtime/dq_transport.h +++ b/ydb/library/yql/dq/runtime/dq_transport.h @@ -27,25 +27,25 @@ public: template <class TForwardIterator> NDqProto::TData Serialize(TForwardIterator first, TForwardIterator last, const NKikimr::NMiniKQL::TType* itemType) const { - switch (TransportVersion) { - case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: - case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: { - auto count = std::distance(first, last); + switch (TransportVersion) { + case NDqProto::DATA_TRANSPORT_VERSION_UNSPECIFIED: + case NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0: { + auto count = std::distance(first, last); const auto listType = NKikimr::NMiniKQL::TListType::Create( const_cast<NKikimr::NMiniKQL::TType*>(itemType), TypeEnv); - const NUdf::TUnboxedValue listValue = HolderFactory.RangeAsArray(first, last); + const NUdf::TUnboxedValue listValue = HolderFactory.RangeAsArray(first, last); - auto data = Serialize(listValue, listType); - data.SetRows(count); - return data; - } - case NDqProto::DATA_TRANSPORT_ARROW_1_0: { - NKikimr::NMiniKQL::TUnboxedValueVector buffer(first, last); - return Serialize(buffer, itemType); - } - default: - YQL_ENSURE(false, "Unsupported TransportVersion"); - } + auto data = Serialize(listValue, listType); + data.SetRows(count); + return data; + } + case NDqProto::DATA_TRANSPORT_ARROW_1_0: { + NKikimr::NMiniKQL::TUnboxedValueVector buffer(first, last); + return Serialize(buffer, itemType); + } + default: + YQL_ENSURE(false, "Unsupported TransportVersion"); + } } void Deserialize(const NDqProto::TData& data, const NKikimr::NMiniKQL::TType* itemType, @@ -54,7 +54,7 @@ public: ui64 CalcSerializedSize(NUdf::TUnboxedValue& value, const NKikimr::NMiniKQL::TType* type); static ui64 EstimateSize(const NUdf::TUnboxedValue& value, const NKikimr::NMiniKQL::TType* type, bool* fixed = nullptr); - + static NDqProto::TData SerializeParam(const TMkqlValueRef& param, const NKikimr::NMiniKQL::TTypeEnvironment& typeEnv, const NKikimr::NMiniKQL::THolderFactory& holderFactory); static void DeserializeParam(const NDqProto::TData& data, const NKikimr::NMiniKQL::TType* type, @@ -65,7 +65,7 @@ public: public: const NKikimr::NMiniKQL::TTypeEnvironment& TypeEnv; const NKikimr::NMiniKQL::THolderFactory& HolderFactory; - const NDqProto::EDataTransportVersion TransportVersion; + const NDqProto::EDataTransportVersion TransportVersion; }; } // namespace NYql::NDq diff --git a/ydb/library/yql/dq/runtime/ut/ya.make b/ydb/library/yql/dq/runtime/ut/ya.make index 7c0ceafe993..d0b21e0c247 100644 --- a/ydb/library/yql/dq/runtime/ut/ya.make +++ b/ydb/library/yql/dq/runtime/ut/ya.make @@ -9,7 +9,7 @@ IF (SANITIZER_TYPE OR WITH_VALGRIND) ENDIF() SRCS( - dq_arrow_helpers_ut.cpp + dq_arrow_helpers_ut.cpp dq_output_channel_ut.cpp ut_helper.cpp ) diff --git a/ydb/library/yql/dq/runtime/ya.make b/ydb/library/yql/dq/runtime/ya.make index 5ff520f713d..4cc8585431d 100644 --- a/ydb/library/yql/dq/runtime/ya.make +++ b/ydb/library/yql/dq/runtime/ya.make @@ -6,7 +6,7 @@ OWNER( ) PEERDIR( - contrib/libs/apache/arrow + contrib/libs/apache/arrow ydb/core/util ydb/library/mkql_proto ydb/library/yql/minikql/comp_nodes @@ -19,7 +19,7 @@ PEERDIR( ) SRCS( - dq_arrow_helpers.cpp + dq_arrow_helpers.cpp dq_columns_resolve.cpp dq_compute.cpp dq_input_channel.cpp diff --git a/ydb/library/yql/providers/dq/actors/proto_builder.cpp b/ydb/library/yql/providers/dq/actors/proto_builder.cpp index cb94e200c08..e4b16403c3b 100644 --- a/ydb/library/yql/providers/dq/actors/proto_builder.cpp +++ b/ydb/library/yql/providers/dq/actors/proto_builder.cpp @@ -135,8 +135,8 @@ bool TProtoBuilder::WriteData(const TVector<NDqProto::TData>& rows, const std::f for (const auto& item : buffer) { if (!func(item)) { return false; - } - } + } + } } return true; } diff --git a/ydb/library/yql/providers/dq/runtime/task_command_executor.cpp b/ydb/library/yql/providers/dq/runtime/task_command_executor.cpp index c9d2d5dae03..9b41920b177 100644 --- a/ydb/library/yql/providers/dq/runtime/task_command_executor.cpp +++ b/ydb/library/yql/providers/dq/runtime/task_command_executor.cpp @@ -308,24 +308,24 @@ public: request.Load(&input); auto guard = Runner->BindAllocator(0); // Explicitly reset memory limit - auto transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; - switch (request.GetData().GetTransportVersion()) { - case 10000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_YSON_1_0; - break; - } - case 20000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0; - break; - } - case 30000: { - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_ARROW_1_0; - break; - } - default: - transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; - } - NDq::TDqDataSerializer dataSerializer(Runner->GetTypeEnv(), Runner->GetHolderFactory(), transportVersion); + auto transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; + switch (request.GetData().GetTransportVersion()) { + case 10000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_YSON_1_0; + break; + } + case 20000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_UV_PICKLE_1_0; + break; + } + case 30000: { + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_ARROW_1_0; + break; + } + default: + transportVersion = NDqProto::EDataTransportVersion::DATA_TRANSPORT_VERSION_UNSPECIFIED; + } + NDq::TDqDataSerializer dataSerializer(Runner->GetTypeEnv(), Runner->GetHolderFactory(), transportVersion); NKikimr::NMiniKQL::TUnboxedValueVector buffer; buffer.reserve(request.GetData().GetRows()); if (request.GetString().empty() && request.GetChunks() == 0) { diff --git a/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp b/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp index add3386007c..cdcbd32ef5d 100644 --- a/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp +++ b/ydb/library/yql/providers/dq/task_runner/tasks_runner_pipe.cpp @@ -772,7 +772,7 @@ public: void Push(NKikimr::NMiniKQL::TUnboxedValueVector&& batch, i64 space) override { auto inputType = GetInputType(); NDqProto::TSourcePushRequest data; - TDqDataSerializer dataSerializer(TaskRunner->GetTypeEnv(), TaskRunner->GetHolderFactory(), NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0); + TDqDataSerializer dataSerializer(TaskRunner->GetTypeEnv(), TaskRunner->GetHolderFactory(), NDqProto::DATA_TRANSPORT_UV_PICKLE_1_0); *data.MutableData() = dataSerializer.Serialize(batch, static_cast<NKikimr::NMiniKQL::TType*>(inputType)); data.SetSpace(space); |