aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaxim Akhmedov <max@tracto.ai>2024-08-29 13:10:31 +0300
committerrobot-piglet <robot-piglet@yandex-team.com>2024-08-29 13:34:50 +0300
commitc73194dcb0bc7f7d74d54d342675fe490dd8611f (patch)
tree1018912c4bfacbe1991e3ff384fcf198f43f2bc9
parentfeda194119e23531dfef4473c657ff97da2305eb (diff)
downloadydb-c73194dcb0bc7f7d74d54d342675fe490dd8611f.tar.gz
Support Decimal128/Decimal256 in Arrow parser
No description --- b24f71e64f22e615ebb32f33fb2cfc5c88198c1a Pull Request resolved: https://github.com/ytsaurus/ytsaurus/pull/769 Co-authored-by: nadya02 <nadya02@yandex-team.com>
-rw-r--r--yt/yt/library/decimal/decimal.cpp15
-rw-r--r--yt/yt/library/decimal/decimal.h3
-rw-r--r--yt/yt/library/formats/arrow_parser.cpp118
3 files changed, 127 insertions, 9 deletions
diff --git a/yt/yt/library/decimal/decimal.cpp b/yt/yt/library/decimal/decimal.cpp
index 890df73dd8..0d3363baba 100644
--- a/yt/yt/library/decimal/decimal.cpp
+++ b/yt/yt/library/decimal/decimal.cpp
@@ -535,6 +535,21 @@ TStringBuf TDecimal::WriteBinary128(int precision, TValue128 value, char* buffer
return TStringBuf{buffer, sizeof(TValue128)};
}
+// Serializes |value| using the binary width that GetValueBinarySize reports for |precision|
+// (4, 8 or 16 bytes) by delegating to WriteBinary32, WriteBinary64 or WriteBinary128,
+// and returns the result of the selected writer.
+//
+// The 4- and 8-byte encodings can only carry values representable as i32 / i64.
+// Such values are range-checked here rather than silently truncated, because a truncated
+// value could otherwise be written as if it were the original one.
+//
+// Throws if |value| does not fit into the selected width, or if the width reported
+// for |precision| is not one of 4, 8 or 16.
+TStringBuf TDecimal::WriteBinaryVariadic(int precision, TValue128 value, char* buffer, size_t bufferLength)
+{
+    const size_t resultLength = GetValueBinarySize(precision);
+
+    // A 128-bit two's complement value fits into 64 bits iff its high half
+    // is exactly the sign extension of its low half.
+    const auto low64 = static_cast<i64>(value.Low);
+    const bool fitsInto64 = static_cast<i64>(value.High) == (low64 < 0 ? i64(-1) : i64(0));
+
+    switch (resultLength) {
+        case 4: {
+            const auto low32 = static_cast<i32>(value.Low);
+            // Widening the narrowed value back must reproduce the original low half.
+            if (!fitsInto64 || static_cast<i64>(low32) != low64) {
+                THROW_ERROR_EXCEPTION("Decimal value does not fit into %v bytes required by precision %v",
+                    resultLength,
+                    precision);
+            }
+            return WriteBinary32(precision, low32, buffer, bufferLength);
+        }
+        case 8:
+            if (!fitsInto64) {
+                THROW_ERROR_EXCEPTION("Decimal value does not fit into %v bytes required by precision %v",
+                    resultLength,
+                    precision);
+            }
+            return WriteBinary64(precision, low64, buffer, bufferLength);
+        case 16:
+            return WriteBinary128(precision, value, buffer, bufferLength);
+        default:
+            THROW_ERROR_EXCEPTION("Invalid precision %v", precision);
+    }
+}
+
template <typename T>
Y_FORCE_INLINE void CheckBufferLength(int precision, size_t bufferLength)
{
diff --git a/yt/yt/library/decimal/decimal.h b/yt/yt/library/decimal/decimal.h
index 7cb9bdd7c0..8a6cf34a8f 100644
--- a/yt/yt/library/decimal/decimal.h
+++ b/yt/yt/library/decimal/decimal.h
@@ -50,6 +50,9 @@ public:
static TStringBuf WriteBinary64(int precision, i64 value, char* buffer, size_t bufferLength);
static TStringBuf WriteBinary128(int precision, TValue128 value, char* buffer, size_t bufferLength);
+ // Writes either 32-bit, 64-bit or 128-bit binary value depending on precision, provided a TValue128.
+ static TStringBuf WriteBinaryVariadic(int precision, TValue128 value, char* buffer, size_t bufferLength);
+
static i32 ParseBinary32(int precision, TStringBuf buffer);
static i64 ParseBinary64(int precision, TStringBuf buffer);
static TValue128 ParseBinary128(int precision, TStringBuf buffer);
diff --git a/yt/yt/library/formats/arrow_parser.cpp b/yt/yt/library/formats/arrow_parser.cpp
index 217ddedf21..f815aaefed 100644
--- a/yt/yt/library/formats/arrow_parser.cpp
+++ b/yt/yt/library/formats/arrow_parser.cpp
@@ -7,6 +7,8 @@
#include <yt/yt/client/formats/parser.h>
+#include <yt/yt/library/decimal/decimal.h>
+
#include <library/cpp/yt/memory/chunked_output_stream.h>
#include <util/stream/buffer.h>
@@ -19,10 +21,13 @@
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/api.h>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/util/decimal.h>
+
namespace NYT::NFormats {
using namespace NTableClient;
using TUnversionedRowValues = std::vector<NTableClient::TUnversionedValue>;
+using namespace NDecimal;
namespace {
@@ -31,7 +36,7 @@ namespace {
void ThrowOnError(const arrow::Status& status)
{
if (!status.ok()) {
- THROW_ERROR_EXCEPTION("Arrow error occurred: %Qv", status.message());
+ THROW_ERROR_EXCEPTION("Arrow error [%v]: %Qv", status.CodeAsString(), status.message());
}
}
@@ -158,6 +163,31 @@ public:
return ParseNull();
}
+ // Decimal types. For now, YT natively supports only Decimal128 with precision up to 35.
+ // Thus, we represent sufficiently narrow decimals as native YT decimals, and wider decimals as
+ // their decimal string representation; the latter is subject to change once we
+ // get native support for Decimal128 with precision up to 38 or Decimal256 with precision up to 76.
+ arrow::Status Visit(const arrow::Decimal128Type& type) override
+ {
+     constexpr int MaximumYTDecimalPrecision = 35;
+
+     // Precision above the limit: emit the value in textual form.
+     if (type.precision() > MaximumYTDecimalPrecision) {
+         return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& bytes, i64 id) {
+             return MakeDecimalTextValue<arrow::Decimal128>(bytes, id, type.scale());
+         });
+     }
+
+     // Precision within the limit: emit the value in binary form.
+     return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& bytes, i64 id) {
+         return MakeDecimalBinaryValue(bytes, id, type.precision());
+     });
+ }
+
+ arrow::Status Visit(const arrow::Decimal256Type& type) override
+ {
+     // Decimal256 values are always emitted in textual form, using the scale of the arrow type.
+     auto makeTextValue = [&] (const TStringBuf& bytes, i64 id) {
+         return MakeDecimalTextValue<arrow::Decimal256>(bytes, id, type.scale());
+     };
+     return ParseStringLikeArray<arrow::Decimal256Array>(makeTextValue);
+ }
+
private:
const i64 ColumnId_;
@@ -209,7 +239,7 @@ private:
}
template <typename ArrayType>
- arrow::Status ParseStringLikeArray()
+ arrow::Status ParseStringLikeArray(auto makeUnversionedValueFunc)
{
auto array = std::static_pointer_cast<ArrayType>(Array_);
for (int rowIndex = 0; rowIndex < array->length(); ++rowIndex) {
@@ -225,12 +255,23 @@ private:
BufferForStringLikeValues_->Advance(element.size());
auto value = TStringBuf(buffer, element.size());
- (*RowValues_)[rowIndex] = MakeUnversionedStringValue(value, ColumnId_);
+ (*RowValues_)[rowIndex] = makeUnversionedValueFunc(value, ColumnId_);
}
}
return arrow::Status::OK();
}
+ template <typename ArrayType>
+ arrow::Status ParseStringLikeArray()
+ {
+     // MakeUnversionedStringValue has a third argument in its signature, so passing it
+     // directly to ParseStringLikeArray fails with "too few arguments" at the point
+     // where it is invoked. Wrap it in a two-argument lambda instead.
+     auto makeStringValue = [] (const TStringBuf& bytes, i64 id) {
+         return MakeUnversionedStringValue(bytes, id);
+     };
+     return ParseStringLikeArray<ArrayType>(makeStringValue);
+ }
+
arrow::Status ParseBoolean()
{
auto array = std::static_pointer_cast<arrow::BooleanArray>(Array_);
@@ -252,6 +293,34 @@ private:
}
return arrow::Status::OK();
}
+
+ TUnversionedValue MakeDecimalBinaryValue(const TStringBuf& value, i64 columnId, int precision)
+ {
+     // NB: the arrow wire representation of Decimal128 is little-endian and always 128 bits wide,
+     // while the YT in-memory representation of Decimal is big-endian, variable-length
+     // (32, 64 or 128 bits) and MSB-flipped to ensure lexical sorting order.
+     using TWideValue = TDecimal::TValue128;
+     YT_VERIFY(value.size() == sizeof(TWideValue));
+
+     TWideValue wideValue;
+     std::memcpy(&wideValue, value.data(), value.size());
+
+     // Reserve room for the widest encoding, then advance only by the bytes actually written.
+     const auto maxEncodedSize = sizeof(TWideValue);
+     char* destination = BufferForStringLikeValues_->Preallocate(maxEncodedSize);
+     auto encoded = TDecimal::WriteBinaryVariadic(precision, wideValue, destination, maxEncodedSize);
+     BufferForStringLikeValues_->Advance(encoded.size());
+
+     return MakeUnversionedStringValue(encoded, columnId);
+ }
+
+ template <class TArrowDecimalType>
+ TUnversionedValue MakeDecimalTextValue(const TStringBuf& value, i64 columnId, int scale)
+ {
+     // Build the arrow decimal from the raw bytes and render it as text with the given scale.
+     const auto* rawBytes = reinterpret_cast<const uint8_t*>(value.data());
+     auto text = TArrowDecimalType(rawBytes).ToString(scale);
+
+     // Copy the text into the shared buffer so that the returned value can refer to it.
+     const auto textSize = text.size();
+     char* destination = BufferForStringLikeValues_->Preallocate(textSize);
+     std::memcpy(destination, text.data(), textSize);
+     BufferForStringLikeValues_->Advance(textSize);
+
+     return MakeUnversionedStringValue(TStringBuf(destination, textSize), columnId);
+ }
};
////////////////////////////////////////////////////////////////////////////////
@@ -552,12 +621,14 @@ private:
////////////////////////////////////////////////////////////////////////////////
void CheckArrowType(
+ auto ytTypeOrMetatype,
const std::shared_ptr<arrow::DataType>& arrowType,
std::initializer_list<arrow::Type::type> allowedTypes)
{
if (std::find(allowedTypes.begin(), allowedTypes.end(), arrowType->id()) == allowedTypes.end()) {
- THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv",
- arrowType->name());
+ THROW_ERROR_EXCEPTION("Unexpected arrow type %Qv for YT type or metatype %Qlv",
+ arrowType->name(),
+ ytTypeOrMetatype);
}
}
@@ -573,6 +644,7 @@ void CheckMatchingArrowTypes(
case ESimpleLogicalValueType::Interval:
CheckArrowType(
+ columnType,
column->type(),
{
arrow::Type::INT8,
@@ -597,6 +669,7 @@ void CheckMatchingArrowTypes(
case ESimpleLogicalValueType::Datetime:
case ESimpleLogicalValueType::Timestamp:
CheckArrowType(
+ columnType,
column->type(),
{
arrow::Type::UINT8,
@@ -611,6 +684,7 @@ void CheckMatchingArrowTypes(
case ESimpleLogicalValueType::Json:
case ESimpleLogicalValueType::Utf8:
CheckArrowType(
+ columnType,
column->type(),
{
arrow::Type::STRING,
@@ -618,13 +692,16 @@ void CheckMatchingArrowTypes(
arrow::Type::LARGE_STRING,
arrow::Type::LARGE_BINARY,
arrow::Type::FIXED_SIZE_BINARY,
- arrow::Type::DICTIONARY
+ arrow::Type::DICTIONARY,
+ arrow::Type::DECIMAL128,
+ arrow::Type::DECIMAL256,
});
break;
case ESimpleLogicalValueType::Float:
case ESimpleLogicalValueType::Double:
CheckArrowType(
+ columnType,
column->type(),
{
arrow::Type::HALF_FLOAT,
@@ -636,12 +713,14 @@ void CheckMatchingArrowTypes(
case ESimpleLogicalValueType::Boolean:
CheckArrowType(
+ columnType,
column->type(),
{arrow::Type::BOOL, arrow::Type::DICTIONARY});
break;
case ESimpleLogicalValueType::Any:
CheckArrowType(
+ columnType,
column->type(),
{
arrow::Type::INT8,
@@ -679,6 +758,7 @@ void CheckMatchingArrowTypes(
case ESimpleLogicalValueType::Null:
case ESimpleLogicalValueType::Void:
CheckArrowType(
+ columnType,
column->type(),
{
arrow::Type::NA,
@@ -688,6 +768,7 @@ void CheckMatchingArrowTypes(
case ESimpleLogicalValueType::Uuid:
CheckArrowType(
+ columnType,
column->type(),
{
arrow::Type::STRING,
@@ -749,9 +830,10 @@ void PrepareArrayForComplexType(
int columnIndex,
int columnId)
{
- switch (denullifiedLogicalType->GetMetatype()) {
+ switch (auto metatype = denullifiedLogicalType->GetMetatype()) {
case ELogicalMetatype::List:
CheckArrowType(
+ metatype,
column->type(),
{
arrow::Type::LIST,
@@ -761,6 +843,7 @@ void PrepareArrayForComplexType(
case ELogicalMetatype::Dict:
CheckArrowType(
+ metatype,
column->type(),
{
arrow::Type::MAP,
@@ -770,32 +853,49 @@ void PrepareArrayForComplexType(
case ELogicalMetatype::Struct:
CheckArrowType(
+ metatype,
column->type(),
{
arrow::Type::STRUCT,
arrow::Type::BINARY
});
break;
+
case ELogicalMetatype::Decimal:
+ CheckArrowType(
+ metatype,
+ column->type(),
+ {
+ arrow::Type::DECIMAL128,
+ arrow::Type::DECIMAL256
+ });
+ break;
+
case ELogicalMetatype::Optional:
case ELogicalMetatype::Tuple:
case ELogicalMetatype::VariantTuple:
case ELogicalMetatype::VariantStruct:
- CheckArrowType(column->type(), {arrow::Type::BINARY});
+ CheckArrowType(metatype, column->type(), {arrow::Type::BINARY});
break;
default:
THROW_ERROR_EXCEPTION("Unexpected arrow type in complex type %Qv", column->type()->name());
}
- if (column->type()->id() == arrow::Type::BINARY) {
+ if (column->type()->id() == arrow::Type::BINARY ||
+ column->type()->id() == arrow::Type::DECIMAL128 ||
+ column->type()->id() == arrow::Type::DECIMAL256)
+ {
TUnversionedRowValues stringValues(rowsValues[columnIndex].size());
TArraySimpleVisitor visitor(columnId, column, bufferForStringLikeValues, &stringValues);
ThrowOnError(column->type()->Accept(&visitor));
for (int offset = 0; offset < std::ssize(rowsValues[columnIndex]); offset++) {
if (column->IsNull(offset)) {
rowsValues[columnIndex][offset] = MakeUnversionedNullValue(columnId);
+ } else if (column->type()->id() == arrow::Type::DECIMAL128 || column->type()->id() == arrow::Type::DECIMAL256) {
+ rowsValues[columnIndex][offset] = MakeUnversionedStringValue(stringValues[offset].AsStringBuf(), columnId);
} else {
+ // TODO(max): is it even correct? Binary is not necessarily a correct YSON...
rowsValues[columnIndex][offset] = MakeUnversionedCompositeValue(stringValues[offset].AsStringBuf(), columnId);
}
}