diff options
author | Maxim Akhmedov <max@tracto.ai> | 2024-08-29 13:10:31 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-08-29 13:34:50 +0300 |
commit | c73194dcb0bc7f7d74d54d342675fe490dd8611f (patch) | |
tree | 1018912c4bfacbe1991e3ff384fcf198f43f2bc9 | |
parent | feda194119e23531dfef4473c657ff97da2305eb (diff) | |
download | ydb-c73194dcb0bc7f7d74d54d342675fe490dd8611f.tar.gz |
Support Decimal128/Decimal256 in Arrow parser
No description
---
b24f71e64f22e615ebb32f33fb2cfc5c88198c1a
Pull Request resolved: https://github.com/ytsaurus/ytsaurus/pull/769
Co-authored-by: nadya02 <nadya02@yandex-team.com>
-rw-r--r-- | yt/yt/library/decimal/decimal.cpp | 15 | ||||
-rw-r--r-- | yt/yt/library/decimal/decimal.h | 3 | ||||
-rw-r--r-- | yt/yt/library/formats/arrow_parser.cpp | 118 |
3 files changed, 127 insertions, 9 deletions
diff --git a/yt/yt/library/decimal/decimal.cpp b/yt/yt/library/decimal/decimal.cpp index 890df73dd8..0d3363baba 100644 --- a/yt/yt/library/decimal/decimal.cpp +++ b/yt/yt/library/decimal/decimal.cpp @@ -535,6 +535,21 @@ TStringBuf TDecimal::WriteBinary128(int precision, TValue128 value, char* buffer return TStringBuf{buffer, sizeof(TValue128)}; } +TStringBuf TDecimal::WriteBinaryVariadic(int precision, TValue128 value, char* buffer, size_t bufferLength) +{ + const size_t resultLength = GetValueBinarySize(precision); + switch (resultLength) { + case 4: + return WriteBinary32(precision, static_cast<i32>(value.Low), buffer, bufferLength); + case 8: + return WriteBinary64(precision, static_cast<i64>(value.Low), buffer, bufferLength); + case 16: + return WriteBinary128(precision, value, buffer, bufferLength); + default: + THROW_ERROR_EXCEPTION("Invalid precision %v", precision); + } +} + template <typename T> Y_FORCE_INLINE void CheckBufferLength(int precision, size_t bufferLength) { diff --git a/yt/yt/library/decimal/decimal.h b/yt/yt/library/decimal/decimal.h index 7cb9bdd7c0..8a6cf34a8f 100644 --- a/yt/yt/library/decimal/decimal.h +++ b/yt/yt/library/decimal/decimal.h @@ -50,6 +50,9 @@ public: static TStringBuf WriteBinary64(int precision, i64 value, char* buffer, size_t bufferLength); static TStringBuf WriteBinary128(int precision, TValue128 value, char* buffer, size_t bufferLength); + // Writes either 32-bit, 64-bit or 128-bit binary value depending on precision, provided a TValue128. 
+ static TStringBuf WriteBinaryVariadic(int precision, TValue128 value, char* buffer, size_t bufferLength); + static i32 ParseBinary32(int precision, TStringBuf buffer); static i64 ParseBinary64(int precision, TStringBuf buffer); static TValue128 ParseBinary128(int precision, TStringBuf buffer); diff --git a/yt/yt/library/formats/arrow_parser.cpp b/yt/yt/library/formats/arrow_parser.cpp index 217ddedf21..f815aaefed 100644 --- a/yt/yt/library/formats/arrow_parser.cpp +++ b/yt/yt/library/formats/arrow_parser.cpp @@ -7,6 +7,8 @@ #include <yt/yt/client/formats/parser.h> +#include <yt/yt/library/decimal/decimal.h> + #include <library/cpp/yt/memory/chunked_output_stream.h> #include <util/stream/buffer.h> @@ -19,10 +21,13 @@ #include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/util/decimal.h> + namespace NYT::NFormats { using namespace NTableClient; using TUnversionedRowValues = std::vector<NTableClient::TUnversionedValue>; +using namespace NDecimal; namespace { @@ -31,7 +36,7 @@ namespace { void ThrowOnError(const arrow::Status& status) { if (!status.ok()) { - THROW_ERROR_EXCEPTION("Arrow error occurred: %Qv", status.message()); + THROW_ERROR_EXCEPTION("Arrow error [%v]: %Qv", status.CodeAsString(), status.message()); } } @@ -158,6 +163,31 @@ public: return ParseNull(); } + // Decimal types. For now, YT natively supports only Decimal128 with precision up to 35. + // Thus, we represent short enough decimals as native YT decimals, and wider decimals as + // their decimal string representation; but the latter is subject to change whenever we + // get the native support for Decimal128 with precision up to 38 or Decimal256 with precision up to 76. 
+ arrow::Status Visit(const arrow::Decimal128Type& type) override + { + constexpr int MaximumYTDecimalPrecision = 35; + if (type.precision() <= MaximumYTDecimalPrecision) { + return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value, i64 columnId) { + return MakeDecimalBinaryValue(value, columnId, type.precision()); + }); + } else { + return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value, i64 columnId) { + return MakeDecimalTextValue<arrow::Decimal128>(value, columnId, type.scale()); + }); + } + } + + arrow::Status Visit(const arrow::Decimal256Type& type) override + { + return ParseStringLikeArray<arrow::Decimal256Array>([&] (const TStringBuf& value, i64 columnId) { + return MakeDecimalTextValue<arrow::Decimal256>(value, columnId, type.scale()); + }); + } + private: const i64 ColumnId_; @@ -209,7 +239,7 @@ private: } template <typename ArrayType> - arrow::Status ParseStringLikeArray() + arrow::Status ParseStringLikeArray(auto makeUnversionedValueFunc) { auto array = std::static_pointer_cast<ArrayType>(Array_); for (int rowIndex = 0; rowIndex < array->length(); ++rowIndex) { @@ -225,12 +255,23 @@ private: BufferForStringLikeValues_->Advance(element.size()); auto value = TStringBuf(buffer, element.size()); - (*RowValues_)[rowIndex] = MakeUnversionedStringValue(value, ColumnId_); + (*RowValues_)[rowIndex] = makeUnversionedValueFunc(value, ColumnId_); } } return arrow::Status::OK(); } + template <typename ArrayType> + arrow::Status ParseStringLikeArray() + { + // Note that MakeUnversionedStringValue actually has third argument in its signature, + // which leads to a "too few arguments" in the point of its invocation if we try to pass + // it directly to ParseStringLikeArray. 
+ return ParseStringLikeArray<ArrayType>([] (const TStringBuf& value, i64 columnId) { + return MakeUnversionedStringValue(value, columnId); + }); + } + arrow::Status ParseBoolean() { auto array = std::static_pointer_cast<arrow::BooleanArray>(Array_); @@ -252,6 +293,34 @@ private: } return arrow::Status::OK(); } + + TUnversionedValue MakeDecimalBinaryValue(const TStringBuf& value, i64 columnId, int precision) + { + // NB: arrow wire representation of Decimal128 is little-endian and (obviously) 128 bit, + // while YT in-memory representation of Decimal is big-endian, variadic-length of either 32 bit, 64 bit or 128 bit, + // and MSB-flipped to ensure lexical sorting order. + TDecimal::TValue128 value128; + YT_VERIFY(value.size() == sizeof(value128)); + std::memcpy(&value128, value.data(), value.size()); + + const auto maxByteCount = sizeof(value128); + char* buffer = BufferForStringLikeValues_->Preallocate(maxByteCount); + auto decimalBinary = TDecimal::WriteBinaryVariadic(precision, value128, buffer, maxByteCount); + BufferForStringLikeValues_->Advance(decimalBinary.size()); + + return MakeUnversionedStringValue(decimalBinary, columnId); + } + + template <class TArrowDecimalType> + TUnversionedValue MakeDecimalTextValue(const TStringBuf& value, i64 columnId, int scale) + { + TArrowDecimalType decimal(reinterpret_cast<const uint8_t*>(value.data())); + auto string = decimal.ToString(scale); + char* buffer = BufferForStringLikeValues_->Preallocate(string.size()); + std::memcpy(buffer, string.data(), string.size()); + BufferForStringLikeValues_->Advance(string.size()); + return MakeUnversionedStringValue(TStringBuf(buffer, string.size()), columnId); + } }; //////////////////////////////////////////////////////////////////////////////// @@ -552,12 +621,14 @@ private: //////////////////////////////////////////////////////////////////////////////// void CheckArrowType( + auto ytTypeOrMetatype, const std::shared_ptr<arrow::DataType>& arrowType, 
std::initializer_list<arrow::Type::type> allowedTypes) { if (std::find(allowedTypes.begin(), allowedTypes.end(), arrowType->id()) == allowedTypes.end()) { - THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv", - arrowType->name()); + THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv for YT type or metatype %Qlv", + arrowType->name(), + ytTypeOrMetatype); } } @@ -573,6 +644,7 @@ void CheckMatchingArrowTypes( case ESimpleLogicalValueType::Interval: CheckArrowType( + columnType, column->type(), { arrow::Type::INT8, @@ -597,6 +669,7 @@ void CheckMatchingArrowTypes( case ESimpleLogicalValueType::Datetime: case ESimpleLogicalValueType::Timestamp: CheckArrowType( + columnType, column->type(), { arrow::Type::UINT8, @@ -611,6 +684,7 @@ void CheckMatchingArrowTypes( case ESimpleLogicalValueType::Json: case ESimpleLogicalValueType::Utf8: CheckArrowType( + columnType, column->type(), { arrow::Type::STRING, @@ -618,13 +692,16 @@ void CheckMatchingArrowTypes( arrow::Type::LARGE_STRING, arrow::Type::LARGE_BINARY, arrow::Type::FIXED_SIZE_BINARY, - arrow::Type::DICTIONARY + arrow::Type::DICTIONARY, + arrow::Type::DECIMAL128, + arrow::Type::DECIMAL256, }); break; case ESimpleLogicalValueType::Float: case ESimpleLogicalValueType::Double: CheckArrowType( + columnType, column->type(), { arrow::Type::HALF_FLOAT, @@ -636,12 +713,14 @@ void CheckMatchingArrowTypes( case ESimpleLogicalValueType::Boolean: CheckArrowType( + columnType, column->type(), {arrow::Type::BOOL, arrow::Type::DICTIONARY}); break; case ESimpleLogicalValueType::Any: CheckArrowType( + columnType, column->type(), { arrow::Type::INT8, @@ -679,6 +758,7 @@ void CheckMatchingArrowTypes( case ESimpleLogicalValueType::Null: case ESimpleLogicalValueType::Void: CheckArrowType( + columnType, column->type(), { arrow::Type::NA, @@ -688,6 +768,7 @@ void CheckMatchingArrowTypes( case ESimpleLogicalValueType::Uuid: CheckArrowType( + columnType, column->type(), { arrow::Type::STRING, @@ -749,9 +830,10 @@ void PrepareArrayForComplexType( 
int columnIndex, int columnId) { - switch (denullifiedLogicalType->GetMetatype()) { + switch (auto metatype = denullifiedLogicalType->GetMetatype()) { case ELogicalMetatype::List: CheckArrowType( + metatype, column->type(), { arrow::Type::LIST, @@ -761,6 +843,7 @@ void PrepareArrayForComplexType( case ELogicalMetatype::Dict: CheckArrowType( + metatype, column->type(), { arrow::Type::MAP, @@ -770,32 +853,49 @@ void PrepareArrayForComplexType( case ELogicalMetatype::Struct: CheckArrowType( + metatype, column->type(), { arrow::Type::STRUCT, arrow::Type::BINARY }); break; + case ELogicalMetatype::Decimal: + CheckArrowType( + metatype, + column->type(), + { + arrow::Type::DECIMAL128, + arrow::Type::DECIMAL256 + }); + break; + case ELogicalMetatype::Optional: case ELogicalMetatype::Tuple: case ELogicalMetatype::VariantTuple: case ELogicalMetatype::VariantStruct: - CheckArrowType(column->type(), {arrow::Type::BINARY}); + CheckArrowType(metatype, column->type(), {arrow::Type::BINARY}); break; default: THROW_ERROR_EXCEPTION("Unexpected arrow type in complex type %Qv", column->type()->name()); } - if (column->type()->id() == arrow::Type::BINARY) { + if (column->type()->id() == arrow::Type::BINARY || + column->type()->id() == arrow::Type::DECIMAL128 || + column->type()->id() == arrow::Type::DECIMAL256) + { TUnversionedRowValues stringValues(rowsValues[columnIndex].size()); TArraySimpleVisitor visitor(columnId, column, bufferForStringLikeValues, &stringValues); ThrowOnError(column->type()->Accept(&visitor)); for (int offset = 0; offset < std::ssize(rowsValues[columnIndex]); offset++) { if (column->IsNull(offset)) { rowsValues[columnIndex][offset] = MakeUnversionedNullValue(columnId); + } else if (column->type()->id() == arrow::Type::DECIMAL128 || column->type()->id() == arrow::Type::DECIMAL256) { + rowsValues[columnIndex][offset] = MakeUnversionedStringValue(stringValues[offset].AsStringBuf(), columnId); } else { + // TODO(max): is it even correct? 
Binary is not necessarily a correct YSON... rowsValues[columnIndex][offset] = MakeUnversionedCompositeValue(stringValues[offset].AsStringBuf(), columnId); } } |