aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author nadya02 <nadya02@yandex-team.com> 2025-01-22 16:54:19 +0300
committer nadya02 <nadya02@yandex-team.com> 2025-01-22 18:02:09 +0300
commit fc450d0b4461cf0d0c4069b384f0efcf6d4541a5 (patch)
tree 0e37fb7a449fa2c7935597f129a74324315ab89c
parent 41e6a7be09fcb492f243779ced3f61375da47367 (diff)
download ydb-fc450d0b4461cf0d0c4069b384f0efcf6d4541a5.tar.gz
YT-23828: Add YT type to complex visitor in arrow parser
* Changelog entry
Type: fix
Component: proxy
Add a type compatibility check by adding YT type to complex visitor in arrow parser.
commit_hash:464ee201809f2da0d7a6178079a1ee8c5669a97d
-rw-r--r-- yt/yt/library/formats/arrow_parser.cpp 609
-rw-r--r-- yt/yt/library/formats/unittests/arrow_parser_ut.cpp 27
2 files changed, 390 insertions, 246 deletions
diff --git a/yt/yt/library/formats/arrow_parser.cpp b/yt/yt/library/formats/arrow_parser.cpp
index b324c575cf..7e73bd7798 100644
--- a/yt/yt/library/formats/arrow_parser.cpp
+++ b/yt/yt/library/formats/arrow_parser.cpp
@@ -4,6 +4,7 @@
#include <yt/yt/client/table_client/logical_type.h>
#include <yt/yt/client/table_client/table_consumer.h>
#include <yt/yt/client/table_client/unversioned_row.h>
+#include <yt/yt/client/table_client/validate_logical_type.h>
#include <yt/yt/client/formats/parser.h>
@@ -42,6 +43,226 @@ void ThrowOnError(const arrow::Status& status)
}
}
+// Throws unless |arrowType| is one of |allowedArrowTypes|.
+//
+// |ytTypeOrMetatype| is used only to format the error. Callers in this file pass
+// either a simple YT value type (see CheckArrowTypeMatch) or a YT metatype, so the
+// message must not claim that the value is always a metatype.
+// |arrowTypeName| is the printable name of |arrowType| reported in the error.
+void CheckArrowType(
+    auto ytTypeOrMetatype,
+    const std::initializer_list<arrow::Type::type>& allowedArrowTypes,
+    const std::string& arrowTypeName,
+    arrow::Type::type arrowType)
+{
+    if (std::find(allowedArrowTypes.begin(), allowedArrowTypes.end(), arrowType) == allowedArrowTypes.end()) {
+        THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv for YT type or metatype %Qlv",
+            arrowTypeName,
+            ytTypeOrMetatype);
+    }
+}
+
+// Verifies that an arrow type may back a value of the simple YT type |columnType|.
+//
+// Each handled YT type is paired with the list of arrow type ids accepted for it;
+// CheckArrowType throws when |arrowTypeId| is not in that list. Interval64 has no
+// accepted arrow type here and is always rejected.
+// |arrowTypeName| is only used for error reporting.
+void CheckArrowTypeMatch(
+    const ESimpleLogicalValueType& columnType,
+    const std::string& arrowTypeName,
+    arrow::Type::type arrowTypeId)
+{
+    // The YT type, arrow name and arrow id are the same in every branch; only the allowed list varies.
+    auto expectOneOf = [&] (std::initializer_list<arrow::Type::type> allowedArrowTypes) {
+        CheckArrowType(columnType, allowedArrowTypes, arrowTypeName, arrowTypeId);
+    };
+
+    switch (columnType) {
+        // Signed integers and Interval: signed arrow integers plus arrow date/time types.
+        case ESimpleLogicalValueType::Int8:
+        case ESimpleLogicalValueType::Int16:
+        case ESimpleLogicalValueType::Int32:
+        case ESimpleLogicalValueType::Int64:
+        case ESimpleLogicalValueType::Interval:
+            expectOneOf({
+                arrow::Type::INT8,
+                arrow::Type::INT16,
+                arrow::Type::INT32,
+                arrow::Type::INT64,
+                arrow::Type::DATE32,
+                arrow::Type::DATE64,
+                arrow::Type::TIMESTAMP,
+                arrow::Type::TIME32,
+                arrow::Type::TIME64,
+                arrow::Type::DICTIONARY,
+            });
+            return;
+
+        // Unsigned integers.
+        case ESimpleLogicalValueType::Uint8:
+        case ESimpleLogicalValueType::Uint16:
+        case ESimpleLogicalValueType::Uint32:
+        case ESimpleLogicalValueType::Uint64:
+            expectOneOf({
+                arrow::Type::UINT8,
+                arrow::Type::UINT16,
+                arrow::Type::UINT32,
+                arrow::Type::UINT64,
+                arrow::Type::DICTIONARY,
+            });
+            return;
+
+        // Date, Datetime and Timestamp: only 32- and 64-bit unsigned arrow integers.
+        case ESimpleLogicalValueType::Date:
+        case ESimpleLogicalValueType::Datetime:
+        case ESimpleLogicalValueType::Timestamp:
+            expectOneOf({
+                arrow::Type::UINT32,
+                arrow::Type::UINT64,
+                arrow::Type::DICTIONARY,
+            });
+            return;
+
+        // Date32, Datetime64 and Timestamp64: only 32- and 64-bit signed arrow integers.
+        case ESimpleLogicalValueType::Date32:
+        case ESimpleLogicalValueType::Datetime64:
+        case ESimpleLogicalValueType::Timestamp64:
+            expectOneOf({
+                arrow::Type::INT32,
+                arrow::Type::INT64,
+                arrow::Type::DICTIONARY,
+            });
+            return;
+
+        // String additionally accepts fixed-size binary and arrow decimals.
+        case ESimpleLogicalValueType::String:
+            expectOneOf({
+                arrow::Type::STRING,
+                arrow::Type::BINARY,
+                arrow::Type::LARGE_STRING,
+                arrow::Type::LARGE_BINARY,
+                arrow::Type::FIXED_SIZE_BINARY,
+                arrow::Type::DICTIONARY,
+                arrow::Type::DECIMAL128,
+                arrow::Type::DECIMAL256,
+            });
+            return;
+
+        case ESimpleLogicalValueType::Json:
+        case ESimpleLogicalValueType::Utf8:
+            expectOneOf({
+                arrow::Type::STRING,
+                arrow::Type::LARGE_STRING,
+                arrow::Type::BINARY,
+                arrow::Type::LARGE_BINARY,
+                arrow::Type::DICTIONARY,
+            });
+            return;
+
+        case ESimpleLogicalValueType::Float:
+        case ESimpleLogicalValueType::Double:
+            expectOneOf({
+                arrow::Type::HALF_FLOAT,
+                arrow::Type::FLOAT,
+                arrow::Type::DOUBLE,
+                arrow::Type::DICTIONARY,
+            });
+            return;
+
+        case ESimpleLogicalValueType::Boolean:
+            expectOneOf({arrow::Type::BOOL, arrow::Type::DICTIONARY});
+            return;
+
+        // Any: every scalar arrow type listed for the other branches (decimals excluded), plus NA.
+        case ESimpleLogicalValueType::Any:
+            expectOneOf({
+                arrow::Type::INT8,
+                arrow::Type::INT16,
+                arrow::Type::INT32,
+                arrow::Type::INT64,
+                arrow::Type::DATE32,
+                arrow::Type::DATE64,
+                arrow::Type::TIMESTAMP,
+                arrow::Type::TIME32,
+                arrow::Type::TIME64,
+
+                arrow::Type::UINT8,
+                arrow::Type::UINT16,
+                arrow::Type::UINT32,
+                arrow::Type::UINT64,
+
+                arrow::Type::HALF_FLOAT,
+                arrow::Type::FLOAT,
+                arrow::Type::DOUBLE,
+
+                arrow::Type::STRING,
+                arrow::Type::BINARY,
+                arrow::Type::LARGE_STRING,
+                arrow::Type::LARGE_BINARY,
+                arrow::Type::FIXED_SIZE_BINARY,
+
+                arrow::Type::BOOL,
+
+                arrow::Type::NA,
+                arrow::Type::DICTIONARY,
+            });
+            return;
+
+        case ESimpleLogicalValueType::Null:
+        case ESimpleLogicalValueType::Void:
+            expectOneOf({arrow::Type::NA, arrow::Type::DICTIONARY});
+            return;
+
+        case ESimpleLogicalValueType::Uuid:
+            expectOneOf({
+                arrow::Type::STRING,
+                arrow::Type::BINARY,
+                arrow::Type::LARGE_STRING,
+                arrow::Type::LARGE_BINARY,
+                arrow::Type::FIXED_SIZE_BINARY,
+                arrow::Type::DICTIONARY,
+            });
+            return;
+
+        // No arrow representation is accepted for Interval64.
+        case ESimpleLogicalValueType::Interval64:
+            THROW_ERROR_EXCEPTION("Unexpected column type %Qv",
+                columnType);
+    }
+}
+
+// Convenience overload: takes the arrow type name and id from |column| itself
+// and forwards to the three-argument CheckArrowTypeMatch.
+void CheckArrowTypeMatch(
+    const ESimpleLogicalValueType& columnType,
+    const std::shared_ptr<arrow::Array>& column)
+{
+    const auto& arrowType = column->type();
+    CheckArrowTypeMatch(columnType, arrowType->name(), arrowType->id());
+}
+
template <class TUnderlyingValueType>
TStringBuf SerializeDecimalBinary(TStringBuf value, int precision, char* buffer, size_t bufferLength)
{
@@ -104,24 +325,24 @@ public:
return ParseInt64<arrow::Int64Array>();
}
- arrow::Status Visit(const arrow::Date32Type& /*type*/) override
+ arrow::Status Visit(const arrow::Time32Type& /*type*/) override
{
- return ParseInt64<arrow::Date32Array>();
+ return ParseInt64<arrow::Time32Array>();
}
- arrow::Status Visit(const arrow::Time32Type& /*type*/) override
+ arrow::Status Visit(const arrow::Time64Type& /*type*/) override
{
- return ParseInt64<arrow::Time32Array>();
+ return ParseInt64<arrow::Time64Array>();
}
- arrow::Status Visit(const arrow::Date64Type& /*type*/) override
+ arrow::Status Visit(const arrow::Date32Type& /*type*/) override
{
- return ParseInt64<arrow::Date64Array>();
+ return ParseInt64<arrow::Date32Array>();
}
- arrow::Status Visit(const arrow::Time64Type& /*type*/) override
+ arrow::Status Visit(const arrow::Date64Type& /*type*/) override
{
- return ParseInt64<arrow::Time64Array>();
+ return ParseInt64<arrow::Date64Array>();
}
arrow::Status Visit(const arrow::TimestampType& /*type*/) override
@@ -224,7 +445,7 @@ private:
template <typename ArrayType>
arrow::Status ParseUInt64()
{
- auto makeUnversionedValue = [] (i64 value, i64 columnId) {
+ auto makeUnversionedValue = [] (ui64 value, i64 columnId) {
return MakeUnversionedUint64Value(value, columnId);
};
ParseSimpleNumeric<ArrayType, decltype(makeUnversionedValue)>(makeUnversionedValue);
@@ -336,10 +557,12 @@ class TArrayCompositeVisitor
{
public:
TArrayCompositeVisitor(
+ TLogicalTypePtr ytType,
const std::shared_ptr<arrow::Array>& array,
NYson::TCheckedInDebugYsonTokenWriter* writer,
int rowIndex)
- : RowIndex_(rowIndex)
+ : YTType_(DenullifyLogicalType(ytType))
+ , RowIndex_(rowIndex)
, Array_(array)
, Writer_(writer)
{
@@ -347,108 +570,129 @@ public:
}
// Signed integer types.
- arrow::Status Visit(const arrow::Int8Type& /*type*/) override
+ arrow::Status Visit(const arrow::Int8Type& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseInt64<arrow::Int8Array>();
}
- arrow::Status Visit(const arrow::Int16Type& /*type*/) override
+ arrow::Status Visit(const arrow::Int16Type& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseInt64<arrow::Int16Array>();
}
- arrow::Status Visit(const arrow::Int32Type& /*type*/) override
+ arrow::Status Visit(const arrow::Int32Type& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseInt64<arrow::Int32Array>();
}
- arrow::Status Visit(const arrow::Int64Type& /*type*/) override
+ arrow::Status Visit(const arrow::Int64Type& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseInt64<arrow::Int64Array>();
}
- arrow::Status Visit(const arrow::Date32Type& /*type*/) override
+ // Date types.
+ arrow::Status Visit(const arrow::Time32Type& type) override
{
- return ParseInt64<arrow::Date32Array>();
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
+ return ParseInt64<arrow::Time32Array>();
}
- arrow::Status Visit(const arrow::Time32Type& /*type*/) override
+ arrow::Status Visit(const arrow::Time64Type& type) override
{
- return ParseInt64<arrow::Time32Array>();
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
+ return ParseInt64<arrow::Time64Array>();
}
- arrow::Status Visit(const arrow::Date64Type& /*type*/) override
+ arrow::Status Visit(const arrow::Date32Type& type) override
{
- return ParseInt64<arrow::Date64Array>();
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
+ return ParseInt64<arrow::Date32Array>();
}
- arrow::Status Visit(const arrow::Time64Type& /*type*/) override
+ arrow::Status Visit(const arrow::Date64Type& type) override
{
- return ParseInt64<arrow::Time64Array>();
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
+ return ParseInt64<arrow::Date64Array>();
}
- arrow::Status Visit(const arrow::TimestampType& /*type*/) override
+ arrow::Status Visit(const arrow::TimestampType& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseInt64<arrow::TimestampArray>();
}
// Unsigned integer types.
- arrow::Status Visit(const arrow::UInt8Type& /*type*/) override
+ arrow::Status Visit(const arrow::UInt8Type& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseUInt64<arrow::UInt8Array>();
}
- arrow::Status Visit(const arrow::UInt16Type& /*type*/) override
+ arrow::Status Visit(const arrow::UInt16Type& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseUInt64<arrow::UInt16Array>();
}
- arrow::Status Visit(const arrow::UInt32Type& /*type*/) override
+ arrow::Status Visit(const arrow::UInt32Type& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseUInt64<arrow::UInt32Array>();
}
- arrow::Status Visit(const arrow::UInt64Type& /*type*/) override
+ arrow::Status Visit(const arrow::UInt64Type& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseUInt64<arrow::UInt64Array>();
}
// Float types.
- arrow::Status Visit(const arrow::HalfFloatType& /*type*/) override
+ arrow::Status Visit(const arrow::HalfFloatType& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseDouble<arrow::HalfFloatArray>();
}
- arrow::Status Visit(const arrow::FloatType& /*type*/) override
+ arrow::Status Visit(const arrow::FloatType& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseDouble<arrow::FloatArray>();
}
- arrow::Status Visit(const arrow::DoubleType& /*type*/) override
+ arrow::Status Visit(const arrow::DoubleType& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseDouble<arrow::DoubleArray>();
}
// Binary types.
- arrow::Status Visit(const arrow::StringType& /*type*/) override
+ arrow::Status Visit(const arrow::StringType& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseStringLikeArray<arrow::StringArray>();
}
- arrow::Status Visit(const arrow::BinaryType& /*type*/) override
+ arrow::Status Visit(const arrow::BinaryType& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseStringLikeArray<arrow::BinaryArray>();
}
// Boolean types.
- arrow::Status Visit(const arrow::BooleanType& /*type*/) override
+ arrow::Status Visit(const arrow::BooleanType& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseBoolean();
}
// Null types.
- arrow::Status Visit(const arrow::NullType& /*type*/) override
+ arrow::Status Visit(const arrow::NullType& type) override
{
+ CheckArrowTypeMatch(YTType_->AsSimpleTypeRef().GetElement(), type.type_name(), type.id());
return ParseNull();
}
@@ -483,6 +727,7 @@ public:
}
private:
+ const TLogicalTypePtr YTType_;
const int RowIndex_;
std::shared_ptr<arrow::Array> Array_;
@@ -569,6 +814,10 @@ private:
arrow::Status ParseList()
{
+ if (YTType_->GetMetatype() != ELogicalMetatype::List) {
+ THROW_ERROR_EXCEPTION("Unexpected arrow type \"list\" for YT metatype %Qlv",
+ YTType_->GetMetatype());
+ }
auto array = std::static_pointer_cast<arrow::ListArray>(Array_);
if (array->IsNull(RowIndex_)) {
Writer_->WriteEntity();
@@ -577,9 +826,14 @@ private:
auto listValue = array->value_slice(RowIndex_);
for (int offset = 0; offset < listValue->length(); ++offset) {
- TArrayCompositeVisitor visitor(listValue, Writer_, offset);
- ThrowOnError(listValue->type()->Accept(&visitor));
-
+ TArrayCompositeVisitor visitor(YTType_->AsListTypeRef().GetElement(), listValue, Writer_, offset);
+ try {
+ ThrowOnError(listValue->type()->Accept(&visitor));
+ } catch (const std::exception& ex) {
+ THROW_ERROR_EXCEPTION("Failed to parse arrow type \"list\"")
+ << TErrorAttribute("offset", offset)
+ << ex;
+ }
Writer_->WriteItemSeparator();
}
@@ -590,28 +844,47 @@ private:
arrow::Status ParseMap()
{
+ if (YTType_->GetMetatype() != ELogicalMetatype::Dict) {
+ THROW_ERROR_EXCEPTION("Unexpected arrow type \"map\" for YT metatype %Qlv",
+ YTType_->GetMetatype());
+ }
auto array = std::static_pointer_cast<arrow::MapArray>(Array_);
+ auto allKeys = array->keys();
+ auto allValues = array->items();
+
if (array->IsNull(RowIndex_)) {
Writer_->WriteEntity();
} else {
- auto element = std::static_pointer_cast<arrow::StructArray>(
- array->value_slice(RowIndex_));
+ auto offset = array->value_offset(RowIndex_);
+ auto length = array->value_length(RowIndex_);
- auto keyList = element->GetFieldByName("key");
- auto valueList = element->GetFieldByName("value");
+ auto keyList = allKeys->Slice(offset, length);
+ auto valueList = allValues->Slice(offset, length);
Writer_->WriteBeginList();
for (int offset = 0; offset < keyList->length(); ++offset) {
Writer_->WriteBeginList();
- TArrayCompositeVisitor keyVisitor(keyList, Writer_, offset);
- ThrowOnError(keyList->type()->Accept(&keyVisitor));
+ TArrayCompositeVisitor keyVisitor(YTType_->AsDictTypeRef().GetKey(), keyList, Writer_, offset);
+ try {
+ ThrowOnError(keyList->type()->Accept(&keyVisitor));
+ } catch (const std::exception& ex) {
+ THROW_ERROR_EXCEPTION("Failed to parse arrow key field of type \"map\"")
+ << TErrorAttribute("offset", offset)
+ << ex;
+ }
Writer_->WriteItemSeparator();
- TArrayCompositeVisitor valueVisitor(valueList, Writer_, offset);
- ThrowOnError(valueList->type()->Accept(&valueVisitor));
+ TArrayCompositeVisitor valueVisitor(YTType_->AsDictTypeRef().GetValue(), valueList, Writer_, offset);
+ try {
+ ThrowOnError(valueList->type()->Accept(&valueVisitor));
+ } catch (const std::exception& ex) {
+ THROW_ERROR_EXCEPTION("Failed to parse arrow value field type \"map\"")
+ << TErrorAttribute("offset", offset)
+ << ex;
+ }
Writer_->WriteItemSeparator();
@@ -626,16 +899,33 @@ private:
arrow::Status ParseStruct()
{
+ if (YTType_->GetMetatype() != ELogicalMetatype::Struct) {
+ THROW_ERROR_EXCEPTION("Unexpected arrow type \"struct\" for YT metatype %Qlv",
+ YTType_->GetMetatype());
+ }
auto array = std::static_pointer_cast<arrow::StructArray>(Array_);
if (array->IsNull(RowIndex_)) {
Writer_->WriteEntity();
} else {
Writer_->WriteBeginList();
-
- for (int offset = 0; offset < array->num_fields(); ++offset) {
- auto element = array->field(offset);
- TArrayCompositeVisitor visitor(element, Writer_, RowIndex_);
- ThrowOnError(element->type()->Accept(&visitor));
+ auto structFields = YTType_->AsStructTypeRef().GetFields();
+ if (std::ssize(structFields) != array->num_fields()) {
+ THROW_ERROR_EXCEPTION("The number of fields in the Arrow \"struct\" type does not match the number of fields in the YT \"struct\" type")
+ << TErrorAttribute("arrow_field_count", array->num_fields())
+ << TErrorAttribute("yt_field_count", std::ssize(structFields));
+ }
+ for (const auto& field : structFields) {
+ auto arrowField = array->GetFieldByName(field.Name);
+ if (!arrowField) {
+ THROW_ERROR_EXCEPTION("Field %Qv is not found in arrow type \"struct\"", field.Name);
+ }
+ TArrayCompositeVisitor visitor(field.Type, arrowField, Writer_, RowIndex_);
+ try {
+ ThrowOnError(arrowField->type()->Accept(&visitor));
+ } catch (const std::exception& ex) {
+ THROW_ERROR_EXCEPTION("Failed to parse arrow struct field %Qv", field.Name)
+ << ex;
+ }
Writer_->WriteItemSeparator();
}
@@ -657,177 +947,6 @@ private:
////////////////////////////////////////////////////////////////////////////////
-void CheckArrowType(
- auto ytTypeOrMetatype,
- const std::shared_ptr<arrow::DataType>& arrowType,
- std::initializer_list<arrow::Type::type> allowedTypes)
-{
- if (std::find(allowedTypes.begin(), allowedTypes.end(), arrowType->id()) == allowedTypes.end()) {
- THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv for YT type or metatype %Qlv",
- arrowType->name(),
- ytTypeOrMetatype);
- }
-}
-
-void CheckMatchingArrowTypes(
- const ESimpleLogicalValueType& columnType,
- const std::shared_ptr<arrow::Array>& column)
-{
- switch (columnType) {
- case ESimpleLogicalValueType::Int8:
- case ESimpleLogicalValueType::Int16:
- case ESimpleLogicalValueType::Int32:
- case ESimpleLogicalValueType::Int64:
-
- case ESimpleLogicalValueType::Interval:
- CheckArrowType(
- columnType,
- column->type(),
- {
- arrow::Type::INT8,
- arrow::Type::INT16,
- arrow::Type::INT32,
- arrow::Type::INT64,
- arrow::Type::DATE32,
- arrow::Type::DATE64,
- arrow::Type::TIMESTAMP,
- arrow::Type::TIME32,
- arrow::Type::TIME64,
- arrow::Type::DICTIONARY
- });
- break;
-
- case ESimpleLogicalValueType::Uint8:
- case ESimpleLogicalValueType::Uint16:
- case ESimpleLogicalValueType::Uint32:
- case ESimpleLogicalValueType::Uint64:
-
- case ESimpleLogicalValueType::Date:
- case ESimpleLogicalValueType::Datetime:
- case ESimpleLogicalValueType::Timestamp:
- CheckArrowType(
- columnType,
- column->type(),
- {
- arrow::Type::UINT8,
- arrow::Type::UINT16,
- arrow::Type::UINT32,
- arrow::Type::UINT64,
- arrow::Type::DICTIONARY
- });
- break;
-
- case ESimpleLogicalValueType::String:
- case ESimpleLogicalValueType::Json:
- case ESimpleLogicalValueType::Utf8:
- CheckArrowType(
- columnType,
- column->type(),
- {
- arrow::Type::STRING,
- arrow::Type::BINARY,
- arrow::Type::LARGE_STRING,
- arrow::Type::LARGE_BINARY,
- arrow::Type::FIXED_SIZE_BINARY,
- arrow::Type::DICTIONARY,
- arrow::Type::DECIMAL128,
- arrow::Type::DECIMAL256,
- });
- break;
-
- case ESimpleLogicalValueType::Float:
- case ESimpleLogicalValueType::Double:
- CheckArrowType(
- columnType,
- column->type(),
- {
- arrow::Type::HALF_FLOAT,
- arrow::Type::FLOAT,
- arrow::Type::DOUBLE,
- arrow::Type::DICTIONARY
- });
- break;
-
- case ESimpleLogicalValueType::Boolean:
- CheckArrowType(
- columnType,
- column->type(),
- {arrow::Type::BOOL, arrow::Type::DICTIONARY});
- break;
-
- case ESimpleLogicalValueType::Any:
- CheckArrowType(
- columnType,
- column->type(),
- {
- arrow::Type::INT8,
- arrow::Type::INT16,
- arrow::Type::INT32,
- arrow::Type::INT64,
- arrow::Type::DATE32,
- arrow::Type::DATE64,
- arrow::Type::TIMESTAMP,
- arrow::Type::TIME32,
- arrow::Type::TIME64,
-
- arrow::Type::UINT8,
- arrow::Type::UINT16,
- arrow::Type::UINT32,
- arrow::Type::UINT64,
-
- arrow::Type::HALF_FLOAT,
- arrow::Type::FLOAT,
- arrow::Type::DOUBLE,
-
- arrow::Type::STRING,
- arrow::Type::BINARY,
- arrow::Type::LARGE_STRING,
- arrow::Type::LARGE_BINARY,
- arrow::Type::FIXED_SIZE_BINARY,
-
- arrow::Type::BOOL,
-
- arrow::Type::NA,
- arrow::Type::DICTIONARY
- });
- break;
-
- case ESimpleLogicalValueType::Null:
- case ESimpleLogicalValueType::Void:
- CheckArrowType(
- columnType,
- column->type(),
- {
- arrow::Type::NA,
- arrow::Type::DICTIONARY
- });
- break;
-
- case ESimpleLogicalValueType::Uuid:
- CheckArrowType(
- columnType,
- column->type(),
- {
- arrow::Type::STRING,
- arrow::Type::BINARY,
- arrow::Type::LARGE_STRING,
- arrow::Type::LARGE_BINARY,
- arrow::Type::FIXED_SIZE_BINARY,
- arrow::Type::DICTIONARY
- });
- break;
-
- case ESimpleLogicalValueType::Date32:
- case ESimpleLogicalValueType::Datetime64:
- case ESimpleLogicalValueType::Timestamp64:
- case ESimpleLogicalValueType::Interval64:
- THROW_ERROR_EXCEPTION("Unexpected column type %Qv",
- columnType);
- }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
void PrepareArrayForSimpleLogicalType(
ESimpleLogicalValueType columnType,
const std::shared_ptr<TChunkedOutputStream>& bufferForStringLikeValues,
@@ -836,12 +955,12 @@ void PrepareArrayForSimpleLogicalType(
int columnIndex,
int columnId)
{
- CheckMatchingArrowTypes(columnType, column);
+ CheckArrowTypeMatch(columnType, column);
if (column->type()->id() == arrow::Type::DICTIONARY) {
auto dictionaryArrayColumn = std::static_pointer_cast<arrow::DictionaryArray>(column);
auto dictionary = dictionaryArrayColumn->dictionary();
TUnversionedRowValues dictionaryValues(dictionary->length());
- CheckMatchingArrowTypes(columnType, dictionary);
+ CheckArrowTypeMatch(columnType, dictionary);
TArraySimpleVisitor visitor(columnType, columnId, dictionary, bufferForStringLikeValues, &dictionaryValues);
ThrowOnError(dictionaryArrayColumn->dictionary()->type()->Accept(&visitor));
@@ -873,48 +992,52 @@ void PrepareArrayForComplexType(
case ELogicalMetatype::List:
CheckArrowType(
metatype,
- column->type(),
{
arrow::Type::LIST,
arrow::Type::BINARY
- });
+ },
+ column->type()->name(),
+ column->type_id());
break;
case ELogicalMetatype::Dict:
CheckArrowType(
metatype,
- column->type(),
{
arrow::Type::MAP,
arrow::Type::BINARY
- });
+ },
+ column->type()->name(),
+ column->type_id());
break;
case ELogicalMetatype::Struct:
CheckArrowType(
metatype,
- column->type(),
{
arrow::Type::STRUCT,
arrow::Type::BINARY
- });
+ },
+ column->type()->name(),
+ column->type_id());
break;
case ELogicalMetatype::Decimal:
CheckArrowType(
metatype,
- column->type(),
{
arrow::Type::DECIMAL128,
arrow::Type::DECIMAL256
- });
+ },
+ column->type()->name(),
+ column->type_id());
break;
case ELogicalMetatype::Optional:
case ELogicalMetatype::Tuple:
case ELogicalMetatype::VariantTuple:
case ELogicalMetatype::VariantStruct:
- CheckArrowType(metatype, column->type(), {arrow::Type::BINARY});
+ CheckArrowType(metatype, {arrow::Type::BINARY}, column->type()->name(), column->type_id());
break;
default:
@@ -947,7 +1070,7 @@ void PrepareArrayForComplexType(
TBufferOutput out(valueBuffer);
NYson::TCheckedInDebugYsonTokenWriter writer(&out);
- TArrayCompositeVisitor visitor(column, &writer, rowIndex);
+ TArrayCompositeVisitor visitor(denullifiedLogicalType, column, &writer, rowIndex);
ThrowOnError(column->type()->Accept(&visitor));
@@ -1063,14 +1186,18 @@ public:
? columnSchema->LogicalType()
: OptionalLogicalType(SimpleLogicalType(ESimpleLogicalValueType::Any));
auto denullifiedColumnType = DenullifyLogicalType(columnType);
-
- PrepareArray(
- denullifiedColumnType,
- bufferForStringLikeValues,
- batch->column(columnIndex),
- rowsValues,
- columnIndex,
- columnId);
+ try {
+ PrepareArray(
+ denullifiedColumnType,
+ bufferForStringLikeValues,
+ batch->column(columnIndex),
+ rowsValues,
+ columnIndex,
+ columnId);
+ } catch (const std::exception& ex) {
+ THROW_ERROR_EXCEPTION("Failed to parse column %Qv", columnName)
+ << ex;
+ }
}
for (int rowIndex = 0; rowIndex < numRows; ++rowIndex) {
diff --git a/yt/yt/library/formats/unittests/arrow_parser_ut.cpp b/yt/yt/library/formats/unittests/arrow_parser_ut.cpp
index 156d8005b5..852323af18 100644
--- a/yt/yt/library/formats/unittests/arrow_parser_ut.cpp
+++ b/yt/yt/library/formats/unittests/arrow_parser_ut.cpp
@@ -184,7 +184,7 @@ std::string MakeMapArrow(const std::vector<std::vector<int32_t>>& key, const std
auto* pool = arrow::default_memory_pool();
auto keyBuilder = std::make_shared<arrow::Int32Builder>(pool);
- auto valueBuilder = std::make_shared<arrow::Int32Builder>(pool);
+ auto valueBuilder = std::make_shared<arrow::UInt32Builder>(pool);
auto mapBuilder = std::make_unique<arrow::MapBuilder>(pool, keyBuilder, valueBuilder);
for (ssize_t mapIndex = 0; mapIndex < std::ssize(key); mapIndex++) {
@@ -526,10 +526,10 @@ TEST(TArrowParserTest, Map)
parser->Finish();
auto firstNode = GetComposite(collectedRows.GetRowValue(0, "map"));
- ASSERT_EQ(ConvertToYsonTextStringStable(firstNode), "[[1;2;];[3;2;];]");
+ ASSERT_EQ(ConvertToYsonTextStringStable(firstNode), "[[1;2u;];[3;2u;];]");
auto secondNode = GetComposite(collectedRows.GetRowValue(1, "map"));
- ASSERT_EQ(ConvertToYsonTextStringStable(secondNode), "[[3;2;];]");
+ ASSERT_EQ(ConvertToYsonTextStringStable(secondNode), "[[3;2u;];]");
}
TEST(TArrowParserTest, SeveralIntArrays)
@@ -559,8 +559,8 @@ TEST(TArrowParserTest, Struct)
{
auto tableSchema = New<TTableSchema>(std::vector<TColumnSchema>{
TColumnSchema("struct", StructLogicalType({
- {"bar", SimpleLogicalType(ESimpleLogicalValueType::String)},
- {"foo", SimpleLogicalType(ESimpleLogicalValueType::Int64)},
+ {"bar", SimpleLogicalType(ESimpleLogicalValueType::String)},
+ {"foo", SimpleLogicalType(ESimpleLogicalValueType::Int64)},
})),
});
@@ -578,6 +578,23 @@ TEST(TArrowParserTest, Struct)
ASSERT_EQ(ConvertToYsonTextStringStable(secondNode), "[\"two\";2;]");
}
+// A YT struct whose field count differs from the arrow struct's must be rejected.
+TEST(TArrowParserTest, StructError)
+{
+    // The YT side declares a single field; MakeStructArrow below is given two value
+    // lists (presumably producing a two-field arrow struct -- confirm against the helper).
+    auto structType = StructLogicalType({
+        {"bar", SimpleLogicalType(ESimpleLogicalValueType::String)},
+    });
+    auto tableSchema = New<TTableSchema>(std::vector<TColumnSchema>{
+        TColumnSchema("struct", structType),
+    });
+
+    TCollectingValueConsumer collectedRows(tableSchema);
+    auto parser = CreateParserForArrow(&collectedRows);
+
+    EXPECT_THROW_MESSAGE_HAS_SUBSTR(
+        parser->Read(MakeStructArrow({"one", "two"}, {1, 2})),
+        std::exception,
+        "The number of fields in the Arrow \"struct\" type does not match the number of fields in the YT \"struct\" type");
+}
+
TEST(TArrowParserTest, DecimalVariousPrecisions)
{
auto tableSchema = New<TTableSchema>(std::vector<TColumnSchema>{