diff options
author | nadya02 <nadya02@yandex-team.com> | 2025-01-22 16:54:19 +0300 |
---|---|---|
committer | nadya02 <nadya02@yandex-team.com> | 2025-01-22 18:02:09 +0300 |
commit | fc450d0b4461cf0d0c4069b384f0efcf6d4541a5 (patch) | |
tree | 0e37fb7a449fa2c7935597f129a74324315ab89c | |
parent | 41e6a7be09fcb492f243779ced3f61375da47367 (diff) | |
download | ydb-fc450d0b4461cf0d0c4069b384f0efcf6d4541a5.tar.gz |
YT-23828: Add YT type to complex visitor in arrow parser
* Changelog entry
Type: fix
Component: proxy
Add a type compatibility check by adding YT type to complex visitor in arrow parser
commit_hash:464ee201809f2da0d7a6178079a1ee8c5669a97d
-rw-r--r-- | yt/yt/library/formats/arrow_parser.cpp | 609 | ||||
-rw-r--r-- | yt/yt/library/formats/unittests/arrow_parser_ut.cpp | 27 |
2 files changed, 390 insertions, 246 deletions
diff --git a/yt/yt/library/formats/arrow_parser.cpp b/yt/yt/library/formats/arrow_parser.cpp index b324c575cf..7e73bd7798 100644 --- a/yt/yt/library/formats/arrow_parser.cpp +++ b/yt/yt/library/formats/arrow_parser.cpp @@ -4,6 +4,7 @@ #include <yt/yt/client/table_client/logical_type.h> #include <yt/yt/client/table_client/table_consumer.h> #include <yt/yt/client/table_client/unversioned_row.h> +#include <yt/yt/client/table_client/validate_logical_type.h> #include <yt/yt/client/formats/parser.h> @@ -42,6 +43,226 @@ void ThrowOnError(const arrow::Status& status) } } +void CheckArrowType( + auto ytTypeOrMetatype, + const std::initializer_list<arrow::Type::type>& allowedArrowTypes, + const std::string& arrowTypeName, + arrow::Type::type arrowType) +{ + if (std::find(allowedArrowTypes.begin(), allowedArrowTypes.end(), arrowType) == allowedArrowTypes.end()) { + THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv for YT metatype %Qlv", + arrowTypeName, + ytTypeOrMetatype); + } +} + +void CheckArrowTypeMatch( + const ESimpleLogicalValueType& columnType, + const std::string& arrowTypeName, + arrow::Type::type arrowTypeId) +{ + switch (columnType) { + case ESimpleLogicalValueType::Int8: + case ESimpleLogicalValueType::Int16: + case ESimpleLogicalValueType::Int32: + case ESimpleLogicalValueType::Int64: + + case ESimpleLogicalValueType::Interval: + CheckArrowType( + columnType, + { + arrow::Type::INT8, + arrow::Type::INT16, + arrow::Type::INT32, + arrow::Type::INT64, + arrow::Type::DATE32, + arrow::Type::DATE64, + arrow::Type::TIMESTAMP, + arrow::Type::TIME32, + arrow::Type::TIME64, + arrow::Type::DICTIONARY + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Uint8: + case ESimpleLogicalValueType::Uint16: + case ESimpleLogicalValueType::Uint32: + case ESimpleLogicalValueType::Uint64: + CheckArrowType( + columnType, + { + arrow::Type::UINT8, + arrow::Type::UINT16, + arrow::Type::UINT32, + arrow::Type::UINT64, + arrow::Type::DICTIONARY + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Date: + case ESimpleLogicalValueType::Datetime: + case ESimpleLogicalValueType::Timestamp: + CheckArrowType( + columnType, + { + arrow::Type::UINT32, + arrow::Type::UINT64, + arrow::Type::DICTIONARY, + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Date32: + case ESimpleLogicalValueType::Datetime64: + case ESimpleLogicalValueType::Timestamp64: + CheckArrowType( + columnType, + { + arrow::Type::INT32, + arrow::Type::INT64, + arrow::Type::DICTIONARY, + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::String: + CheckArrowType( + columnType, + { + arrow::Type::STRING, + arrow::Type::BINARY, + arrow::Type::LARGE_STRING, + arrow::Type::LARGE_BINARY, + arrow::Type::FIXED_SIZE_BINARY, + arrow::Type::DICTIONARY, + arrow::Type::DECIMAL128, + arrow::Type::DECIMAL256, + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Json: + case ESimpleLogicalValueType::Utf8: + CheckArrowType( + columnType, + { + arrow::Type::STRING, + arrow::Type::LARGE_STRING, + arrow::Type::BINARY, + arrow::Type::LARGE_BINARY, + arrow::Type::DICTIONARY, + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Float: + case ESimpleLogicalValueType::Double: + CheckArrowType( + columnType, + { + arrow::Type::HALF_FLOAT, + arrow::Type::FLOAT, + arrow::Type::DOUBLE, + arrow::Type::DICTIONARY + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Boolean: + CheckArrowType( + columnType, + {arrow::Type::BOOL, arrow::Type::DICTIONARY}, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Any: + CheckArrowType( + columnType, + { + arrow::Type::INT8, + arrow::Type::INT16, + arrow::Type::INT32, + arrow::Type::INT64, + arrow::Type::DATE32, + arrow::Type::DATE64, + arrow::Type::TIMESTAMP, + arrow::Type::TIME32, + arrow::Type::TIME64, + + arrow::Type::UINT8, + arrow::Type::UINT16, + arrow::Type::UINT32, + arrow::Type::UINT64, + + arrow::Type::HALF_FLOAT, + arrow::Type::FLOAT, + arrow::Type::DOUBLE, + + arrow::Type::STRING, + arrow::Type::BINARY, + arrow::Type::LARGE_STRING, + arrow::Type::LARGE_BINARY, + arrow::Type::FIXED_SIZE_BINARY, + + arrow::Type::BOOL, + + arrow::Type::NA, + arrow::Type::DICTIONARY + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Null: + case ESimpleLogicalValueType::Void: + CheckArrowType( + columnType, + { + arrow::Type::NA, + arrow::Type::DICTIONARY + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Uuid: + CheckArrowType( + columnType, + { + arrow::Type::STRING, + arrow::Type::BINARY, + arrow::Type::LARGE_STRING, + arrow::Type::LARGE_BINARY, + arrow::Type::FIXED_SIZE_BINARY, + arrow::Type::DICTIONARY + }, + arrowTypeName, + arrowTypeId); + break; + + case ESimpleLogicalValueType::Interval64: + THROW_ERROR_EXCEPTION("Unexpected column type %Qv", + columnType); + } +} + +void CheckArrowTypeMatch( + const ESimpleLogicalValueType& columnType, + const std::shared_ptr<arrow::Array>& column) +{ + CheckArrowTypeMatch(columnType, column->type()->name(), column->type_id()); +} + template <class TUnderlyingValueType> TStringBuf SerializeDecimalBinary(TStringBuf value, int precision, char* buffer, size_t bufferLength) { @@ -104,24 +325,24 @@ public: return ParseInt64<arrow::Int64Array>(); } - arrow::Status Visit(const arrow::Date32Type& /*type*/) override + arrow::Status Visit(const arrow::Time32Type& /*type*/) override { - return ParseInt64<arrow::Date32Array>(); + return ParseInt64<arrow::Time32Array>(); } - arrow::Status Visit(const arrow::Time32Type& /*type*/) override + arrow::Status Visit(const arrow::Time64Type& /*type*/) override { - return ParseInt64<arrow::Time32Array>(); + return ParseInt64<arrow::Time64Array>(); } - arrow::Status Visit(const arrow::Date64Type& /*type*/) override + arrow::Status Visit(const arrow::Date32Type& /*type*/) override { - return ParseInt64<arrow::Date64Array>(); + return ParseInt64<arrow::Date32Array>(); } - arrow::Status Visit(const arrow::Time64Type& /*type*/) override + arrow::Status Visit(const arrow::Date64Type& /*type*/) override { - return ParseInt64<arrow::Time64Array>(); + return ParseInt64<arrow::Date64Array>(); } arrow::Status Visit(const arrow::TimestampType& /*type*/) override @@ -224,7 +445,7 @@ private: template <typename ArrayType> arrow::Status ParseUInt64() { - auto makeUnversionedValue = [] (i64 value, i64 columnId) { + auto makeUnversionedValue = [] (ui64 value, i64 columnId) { return MakeUnversionedUint64Value(value, columnId); }; ParseSimpleNumeric<ArrayType, decltype(makeUnversionedValue)>(makeUnversionedValue); @@ -336,10 +557,12 @@ class TArrayCompositeVisitor { public: TArrayCompositeVisitor( + TLogicalTypePtr ytType, const std::shared_ptr<arrow::Array>& array, NYson::TCheckedInDebugYsonTokenWriter* writer, int rowIndex) - : RowIndex_(rowIndex) + : YTType_(DenullifyLogicalType(ytType)) + , RowIndex_(rowIndex) , Array_(array) , Writer_(writer) { @@ -347,108 +570,129 @@ public: } // Signed integer types. - arrow::Status Visit(const arrow::Int8Type& /*type*/) override + arrow::Status Visit(const arrow::Int8Type& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseInt64<arrow::Int8Array>(); } - arrow::Status Visit(const arrow::Int16Type& /*type*/) override + arrow::Status Visit(const arrow::Int16Type& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseInt64<arrow::Int16Array>(); } - arrow::Status Visit(const arrow::Int32Type& /*type*/) override + arrow::Status Visit(const arrow::Int32Type& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseInt64<arrow::Int32Array>(); } - arrow::Status Visit(const arrow::Int64Type& /*type*/) override + arrow::Status Visit(const arrow::Int64Type& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseInt64<arrow::Int64Array>(); } - arrow::Status Visit(const arrow::Date32Type& /*type*/) override + // Date types. + arrow::Status Visit(const arrow::Time32Type& type) override { - return ParseInt64<arrow::Date32Array>(); + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); + return ParseInt64<arrow::Time32Array>(); } - arrow::Status Visit(const arrow::Time32Type& /*type*/) override + arrow::Status Visit(const arrow::Time64Type& type) override { - return ParseInt64<arrow::Time32Array>(); + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); + return ParseInt64<arrow::Time64Array>(); } - arrow::Status Visit(const arrow::Date64Type& /*type*/) override + arrow::Status Visit(const arrow::Date32Type& type) override { - return ParseInt64<arrow::Date64Array>(); + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); + return ParseInt64<arrow::Date32Array>(); } - arrow::Status Visit(const arrow::Time64Type& /*type*/) override + arrow::Status Visit(const arrow::Date64Type& type) override { - return ParseInt64<arrow::Time64Array>(); + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); + return ParseInt64<arrow::Date64Array>(); } - arrow::Status Visit(const arrow::TimestampType& /*type*/) override + arrow::Status Visit(const arrow::TimestampType& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseInt64<arrow::TimestampArray>(); } // Unsigned integer types. - arrow::Status Visit(const arrow::UInt8Type& /*type*/) override + arrow::Status Visit(const arrow::UInt8Type& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseUInt64<arrow::UInt8Array>(); } - arrow::Status Visit(const arrow::UInt16Type& /*type*/) override + arrow::Status Visit(const arrow::UInt16Type& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseUInt64<arrow::UInt16Array>(); } - arrow::Status Visit(const arrow::UInt32Type& /*type*/) override + arrow::Status Visit(const arrow::UInt32Type& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseUInt64<arrow::UInt32Array>(); } - arrow::Status Visit(const arrow::UInt64Type& /*type*/) override + arrow::Status Visit(const arrow::UInt64Type& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseUInt64<arrow::UInt64Array>(); } // Float types. - arrow::Status Visit(const arrow::HalfFloatType& /*type*/) override + arrow::Status Visit(const arrow::HalfFloatType& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseDouble<arrow::HalfFloatArray>(); } - arrow::Status Visit(const arrow::FloatType& /*type*/) override + arrow::Status Visit(const arrow::FloatType& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseDouble<arrow::FloatArray>(); } - arrow::Status Visit(const arrow::DoubleType& /*type*/) override + arrow::Status Visit(const arrow::DoubleType& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseDouble<arrow::DoubleArray>(); } // Binary types. - arrow::Status Visit(const arrow::StringType& /*type*/) override + arrow::Status Visit(const arrow::StringType& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseStringLikeArray<arrow::StringArray>(); } - arrow::Status Visit(const arrow::BinaryType& /*type*/) override + arrow::Status Visit(const arrow::BinaryType& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseStringLikeArray<arrow::BinaryArray>(); } // Boolean types. - arrow::Status Visit(const arrow::BooleanType& /*type*/) override + arrow::Status Visit(const arrow::BooleanType& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseBoolean(); } // Null types. - arrow::Status Visit(const arrow::NullType& /*type*/) override + arrow::Status Visit(const arrow::NullType& type) override { + CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id()); return ParseNull(); } @@ -483,6 +727,7 @@ public: } private: + const TLogicalTypePtr YTType_; const int RowIndex_; std::shared_ptr<arrow::Array> Array_; @@ -569,6 +814,10 @@ private: arrow::Status ParseList() { + if (YTType_->GetMetatype() != ELogicalMetatype::List) { + THROW_ERROR_EXCEPTION("Unexpected arrow type \"list\" for YT metatype %Qlv", + YTType_->GetMetatype()); + } auto array = std::static_pointer_cast<arrow::ListArray>(Array_); if (array->IsNull(RowIndex_)) { Writer_->WriteEntity(); @@ -577,9 +826,14 @@ private: auto listValue = array->value_slice(RowIndex_); for (int offset = 0; offset < listValue->length(); ++offset) { - TArrayCompositeVisitor visitor(listValue, Writer_, offset); - ThrowOnError(listValue->type()->Accept(&visitor)); - + TArrayCompositeVisitor visitor(YTType_->AsListTypeRef().GetElement(), listValue, Writer_, offset); + try { + ThrowOnError(listValue->type()->Accept(&visitor)); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Failed to parse arrow type \"list\"") + << TErrorAttribute("offset", offset) + << ex; + } Writer_->WriteItemSeparator(); } @@ -590,28 +844,47 @@ private: arrow::Status ParseMap() { + if (YTType_->GetMetatype() != ELogicalMetatype::Dict) { + THROW_ERROR_EXCEPTION("Unexpected arrow type \"map\" for YT metatype %Qlv", + YTType_->GetMetatype()); + } auto array = std::static_pointer_cast<arrow::MapArray>(Array_); + auto allKeys = array->keys(); + auto allValues = array->items(); + if (array->IsNull(RowIndex_)) { Writer_->WriteEntity(); } else { - auto element = std::static_pointer_cast<arrow::StructArray>( - array->value_slice(RowIndex_)); + auto offset = array->value_offset(RowIndex_); + auto length = array->value_length(RowIndex_); - auto keyList = element->GetFieldByName("key"); - auto valueList = element->GetFieldByName("value"); + auto keyList = allKeys->Slice(offset, length); + auto valueList = allValues->Slice(offset, length); Writer_->WriteBeginList(); for (int offset = 0; offset < keyList->length(); ++offset) { Writer_->WriteBeginList(); - TArrayCompositeVisitor keyVisitor(keyList, Writer_, offset); - ThrowOnError(keyList->type()->Accept(&keyVisitor)); + TArrayCompositeVisitor keyVisitor(YTType_->AsDictTypeRef().GetKey(), keyList, Writer_, offset); + try { + ThrowOnError(keyList->type()->Accept(&keyVisitor)); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Failed to parse arrow key field of type \"map\"") + << TErrorAttribute("offset", offset) + << ex; + } Writer_->WriteItemSeparator(); - TArrayCompositeVisitor valueVisitor(valueList, Writer_, offset); - ThrowOnError(valueList->type()->Accept(&valueVisitor)); + TArrayCompositeVisitor valueVisitor(YTType_->AsDictTypeRef().GetValue(), valueList, Writer_, offset); + try { + ThrowOnError(valueList->type()->Accept(&valueVisitor)); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Failed to parse arrow value field type \"map\"") + << TErrorAttribute("offset", offset) + << ex; + } Writer_->WriteItemSeparator(); @@ -626,16 +899,33 @@ private: arrow::Status ParseStruct() { + if (YTType_->GetMetatype() != ELogicalMetatype::Struct) { + THROW_ERROR_EXCEPTION("Unexpected arrow type \"struct\" for YT metatype %Qlv", + YTType_->GetMetatype()); + } auto array = std::static_pointer_cast<arrow::StructArray>(Array_); if (array->IsNull(RowIndex_)) { Writer_->WriteEntity(); } else { Writer_->WriteBeginList(); - - for (int offset = 0; offset < array->num_fields(); ++offset) { - auto element = array->field(offset); - TArrayCompositeVisitor visitor(element, Writer_, RowIndex_); - ThrowOnError(element->type()->Accept(&visitor)); + auto structFields = YTType_->AsStructTypeRef().GetFields(); + if (std::ssize(structFields) != array->num_fields()) { + THROW_ERROR_EXCEPTION("The number of fields in the Arrow \"struct\" type does not match the number of fields in the YT \"struct\" type") + << TErrorAttribute("arrow_field_count", array->num_fields()) + << TErrorAttribute("yt_field_count", std::ssize(structFields)); + } + for (const auto& field : structFields) { + auto arrowField = array->GetFieldByName(field.Name); + if (!arrowField) { + THROW_ERROR_EXCEPTION("Field %Qv is not found in arrow type \"struct\"", field.Name); + } + TArrayCompositeVisitor visitor(field.Type, arrowField, Writer_, RowIndex_); + try { + ThrowOnError(arrowField->type()->Accept(&visitor)); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Failed to parse arrow struct field %Qv", field.Name) + << ex; + } Writer_->WriteItemSeparator(); } @@ -657,177 +947,6 @@ private: //////////////////////////////////////////////////////////////////////////////// -void CheckArrowType( - auto ytTypeOrMetatype, - const std::shared_ptr<arrow::DataType>& arrowType, - std::initializer_list<arrow::Type::type> allowedTypes) -{ - if (std::find(allowedTypes.begin(), allowedTypes.end(), arrowType->id()) == allowedTypes.end()) { - THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv for YT type or metatype %Qlv", - arrowType->name(), - ytTypeOrMetatype); - } -} - -void CheckMatchingArrowTypes( - const ESimpleLogicalValueType& columnType, - const std::shared_ptr<arrow::Array>& column) -{ - switch (columnType) { - case ESimpleLogicalValueType::Int8: - case ESimpleLogicalValueType::Int16: - case ESimpleLogicalValueType::Int32: - case ESimpleLogicalValueType::Int64: - - case ESimpleLogicalValueType::Interval: - CheckArrowType( - columnType, - column->type(), - { - arrow::Type::INT8, - arrow::Type::INT16, - arrow::Type::INT32, - arrow::Type::INT64, - arrow::Type::DATE32, - arrow::Type::DATE64, - arrow::Type::TIMESTAMP, - arrow::Type::TIME32, - arrow::Type::TIME64, - arrow::Type::DICTIONARY - }); - break; - - case ESimpleLogicalValueType::Uint8: - case ESimpleLogicalValueType::Uint16: - case ESimpleLogicalValueType::Uint32: - case ESimpleLogicalValueType::Uint64: - - case ESimpleLogicalValueType::Date: - case ESimpleLogicalValueType::Datetime: - case ESimpleLogicalValueType::Timestamp: - CheckArrowType( - columnType, - column->type(), - { - arrow::Type::UINT8, - arrow::Type::UINT16, - arrow::Type::UINT32, - arrow::Type::UINT64, - arrow::Type::DICTIONARY - }); - break; - - case ESimpleLogicalValueType::String: - case ESimpleLogicalValueType::Json: - case ESimpleLogicalValueType::Utf8: - CheckArrowType( - columnType, - column->type(), - { - arrow::Type::STRING, - arrow::Type::BINARY, - arrow::Type::LARGE_STRING, - arrow::Type::LARGE_BINARY, - arrow::Type::FIXED_SIZE_BINARY, - arrow::Type::DICTIONARY, - arrow::Type::DECIMAL128, - arrow::Type::DECIMAL256, - }); - break; - - case ESimpleLogicalValueType::Float: - case ESimpleLogicalValueType::Double: - CheckArrowType( - columnType, - column->type(), - { - arrow::Type::HALF_FLOAT, - arrow::Type::FLOAT, - arrow::Type::DOUBLE, - arrow::Type::DICTIONARY - }); - break; - - case ESimpleLogicalValueType::Boolean: - CheckArrowType( - columnType, - column->type(), - {arrow::Type::BOOL, arrow::Type::DICTIONARY}); - break; - - case ESimpleLogicalValueType::Any: - CheckArrowType( - columnType, - column->type(), - { - arrow::Type::INT8, - arrow::Type::INT16, - arrow::Type::INT32, - arrow::Type::INT64, - arrow::Type::DATE32, - arrow::Type::DATE64, - arrow::Type::TIMESTAMP, - arrow::Type::TIME32, - arrow::Type::TIME64, - - arrow::Type::UINT8, - arrow::Type::UINT16, - arrow::Type::UINT32, - arrow::Type::UINT64, - - arrow::Type::HALF_FLOAT, - arrow::Type::FLOAT, - arrow::Type::DOUBLE, - - arrow::Type::STRING, - arrow::Type::BINARY, - arrow::Type::LARGE_STRING, - arrow::Type::LARGE_BINARY, - arrow::Type::FIXED_SIZE_BINARY, - - arrow::Type::BOOL, - - arrow::Type::NA, - arrow::Type::DICTIONARY - }); - break; - - case ESimpleLogicalValueType::Null: - case ESimpleLogicalValueType::Void: - CheckArrowType( - columnType, - column->type(), - { - arrow::Type::NA, - arrow::Type::DICTIONARY - }); - break; - - case ESimpleLogicalValueType::Uuid: - CheckArrowType( - columnType, - column->type(), - { - arrow::Type::STRING, - arrow::Type::BINARY, - arrow::Type::LARGE_STRING, - arrow::Type::LARGE_BINARY, - arrow::Type::FIXED_SIZE_BINARY, - arrow::Type::DICTIONARY - }); - break; - - case ESimpleLogicalValueType::Date32: - case ESimpleLogicalValueType::Datetime64: - case ESimpleLogicalValueType::Timestamp64: - case ESimpleLogicalValueType::Interval64: - THROW_ERROR_EXCEPTION("Unexpected column type %Qv", - columnType); - } -} - -//////////////////////////////////////////////////////////////////////////////// - void PrepareArrayForSimpleLogicalType( ESimpleLogicalValueType columnType, const std::shared_ptr<TChunkedOutputStream>& bufferForStringLikeValues, @@ -836,12 +955,12 @@ void PrepareArrayForSimpleLogicalType( int columnIndex, int columnId) { - CheckMatchingArrowTypes(columnType, column); + CheckArrowTypeMatch(columnType, column); if (column->type()->id() == arrow::Type::DICTIONARY) { auto dictionaryArrayColumn = std::static_pointer_cast<arrow::DictionaryArray>(column); auto dictionary = dictionaryArrayColumn->dictionary(); TUnversionedRowValues dictionaryValues(dictionary->length()); - CheckMatchingArrowTypes(columnType, dictionary); + CheckArrowTypeMatch(columnType, dictionary); TArraySimpleVisitor visitor(columnType, columnId, dictionary, bufferForStringLikeValues, &dictionaryValues); ThrowOnError(dictionaryArrayColumn->dictionary()->type()->Accept(&visitor)); @@ -873,48 +992,52 @@ void PrepareArrayForComplexType( case ELogicalMetatype::List: CheckArrowType( metatype, - column->type(), { arrow::Type::LIST, arrow::Type::BINARY - }); + }, + column->type()->name(), + column->type_id()); break; case ELogicalMetatype::Dict: CheckArrowType( metatype, - column->type(), { arrow::Type::MAP, arrow::Type::BINARY - }); + }, + column->type()->name(), + column->type_id()); break; case ELogicalMetatype::Struct: CheckArrowType( metatype, - column->type(), { arrow::Type::STRUCT, arrow::Type::BINARY - }); + }, + column->type()->name(), + column->type_id()); break; case ELogicalMetatype::Decimal: CheckArrowType( metatype, - column->type(), { arrow::Type::DECIMAL128, arrow::Type::DECIMAL256 - }); + }, + column->type()->name(), + column->type_id()); break; case ELogicalMetatype::Optional: case ELogicalMetatype::Tuple: case ELogicalMetatype::VariantTuple: case ELogicalMetatype::VariantStruct: - CheckArrowType(metatype, column->type(), {arrow::Type::BINARY}); + CheckArrowType(metatype, {arrow::Type::BINARY}, column->type()->name(), column->type_id()); break; default: @@ -947,7 +1070,7 @@ void PrepareArrayForComplexType( TBufferOutput out(valueBuffer); NYson::TCheckedInDebugYsonTokenWriter writer(&out); - TArrayCompositeVisitor visitor(column, &writer, rowIndex); + TArrayCompositeVisitor visitor(denullifiedLogicalType, column, &writer, rowIndex); ThrowOnError(column->type()->Accept(&visitor)); @@ -1063,14 +1186,18 @@ public: ? columnSchema->LogicalType() : OptionalLogicalType(SimpleLogicalType(ESimpleLogicalValueType::Any)); auto denullifiedColumnType = DenullifyLogicalType(columnType); - - PrepareArray( - denullifiedColumnType, - bufferForStringLikeValues, - batch->column(columnIndex), - rowsValues, - columnIndex, - columnId); + try { + PrepareArray( + denullifiedColumnType, + bufferForStringLikeValues, + batch->column(columnIndex), + rowsValues, + columnIndex, + columnId); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Failed to parse column %Qv", columnName) + << ex; + } } for (int rowIndex = 0; rowIndex < numRows; ++rowIndex) { diff --git a/yt/yt/library/formats/unittests/arrow_parser_ut.cpp b/yt/yt/library/formats/unittests/arrow_parser_ut.cpp index 156d8005b5..852323af18 100644 --- a/yt/yt/library/formats/unittests/arrow_parser_ut.cpp +++ b/yt/yt/library/formats/unittests/arrow_parser_ut.cpp @@ -184,7 +184,7 @@ std::string MakeMapArrow(const std::vector<std::vector<int32_t>>& key, const std auto* pool = arrow::default_memory_pool(); auto keyBuilder = std::make_shared<arrow::Int32Builder>(pool); - auto valueBuilder = std::make_shared<arrow::Int32Builder>(pool); + auto valueBuilder = std::make_shared<arrow::UInt32Builder>(pool); auto mapBuilder = std::make_unique<arrow::MapBuilder>(pool, keyBuilder, valueBuilder); for (ssize_t mapIndex = 0; mapIndex < std::ssize(key); mapIndex++) { @@ -526,10 +526,10 @@ TEST(TArrowParserTest, Map) parser->Finish(); auto firstNode = GetComposite(collectedRows.GetRowValue(0, "map")); - ASSERT_EQ(ConvertToYsonTextStringStable(firstNode), "[[1;2;];[3;2;];]"); + ASSERT_EQ(ConvertToYsonTextStringStable(firstNode), "[[1;2u;];[3;2u;];]"); auto secondNode = GetComposite(collectedRows.GetRowValue(1, "map")); - ASSERT_EQ(ConvertToYsonTextStringStable(secondNode), "[[3;2;];]"); + ASSERT_EQ(ConvertToYsonTextStringStable(secondNode), "[[3;2u;];]"); } TEST(TArrowParserTest, SeveralIntArrays) @@ -559,8 +559,8 @@ TEST(TArrowParserTest, Struct) { auto tableSchema = New<TTableSchema>(std::vector<TColumnSchema>{ TColumnSchema("struct", StructLogicalType({ - {"bar", SimpleLogicalType(ESimpleLogicalValueType::String)}, - {"foo", SimpleLogicalType(ESimpleLogicalValueType::Int64)}, + {"bar", SimpleLogicalType(ESimpleLogicalValueType::String)}, + {"foo", SimpleLogicalType(ESimpleLogicalValueType::Int64)}, })), }); @@ -578,6 +578,23 @@ TEST(TArrowParserTest, Struct) ASSERT_EQ(ConvertToYsonTextStringStable(secondNode), "[\"two\";2;]"); } +TEST(TArrowParserTest, StructError) +{ + auto tableSchema = New<TTableSchema>(std::vector<TColumnSchema>{ + TColumnSchema("struct", StructLogicalType({ + {"bar", SimpleLogicalType(ESimpleLogicalValueType::String)}, + })), + }); + + TCollectingValueConsumer collectedRows(tableSchema); + + auto parser = CreateParserForArrow(&collectedRows); + EXPECT_THROW_MESSAGE_HAS_SUBSTR( + parser->Read(MakeStructArrow({"one", "two"}, {1, 2})), + std::exception, + "The number of fields in the Arrow \"struct\" type does not match the number of fields in the YT \"struct\" type"); +} + TEST(TArrowParserTest, DecimalVariousPrecisions) { auto tableSchema = New<TTableSchema>(std::vector<TColumnSchema>{ |