diff options
author | Ivan Nikolaev <[email protected]> | 2024-09-23 12:02:51 +0300 |
---|---|---|
committer | GitHub <[email protected]> | 2024-09-23 12:02:51 +0300 |
commit | b47a8b8f9c631d8bc6a79dc4c5cb839d1b8356e0 (patch) | |
tree | de32b7f9e063f56109235d8075bf1edf3409ce0b | |
parent | 29fff7ec9f02e124dec51d9300686096ae000f31 (diff) |
Support PG types in arrow and clickhouse (#9335)
-rw-r--r-- | ydb/core/formats/arrow/converter.cpp | 6 | ||||
-rw-r--r-- | ydb/core/formats/arrow/switch/switch_type.h | 4 | ||||
-rw-r--r-- | ydb/core/formats/arrow/ut/ut_arrow.cpp | 159 | ||||
-rw-r--r-- | ydb/core/formats/arrow/ut/ya.make | 3 | ||||
-rw-r--r-- | ydb/core/formats/clickhouse_block.cpp | 25 |
5 files changed, 151 insertions, 46 deletions
diff --git a/ydb/core/formats/arrow/converter.cpp b/ydb/core/formats/arrow/converter.cpp index f0a38e2c814..08fd2b6d6c6 100644 --- a/ydb/core/formats/arrow/converter.cpp +++ b/ydb/core/formats/arrow/converter.cpp @@ -297,9 +297,6 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err for (; row < rowsUnroll; row += unroll) { ui32 col = 0; for (auto& [colName, colType] : YdbSchema_) { - // TODO: support pg types - Y_ABORT_UNLESS(colType.GetTypeId() != NScheme::NTypeIds::Pg, "pg types are not supported"); - auto& column = allColumns[col]; bool success = SwitchYqlTypeToArrowType(colType, [&]<typename TType>(TTypeWrapper<TType> typeHolder) { Y_UNUSED(typeHolder); @@ -347,9 +344,6 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err ui32 col = 0; for (auto& [colName, colType] : YdbSchema_) { - // TODO: support pg types - Y_ABORT_UNLESS(colType.GetTypeId() != NScheme::NTypeIds::Pg, "pg types are not supported"); - auto& column = allColumns[col]; auto& curCell = cells[0][col]; if (column->IsNull(row)) { diff --git a/ydb/core/formats/arrow/switch/switch_type.h b/ydb/core/formats/arrow/switch/switch_type.h index 75090fbc0a5..dc04eee244e 100644 --- a/ydb/core/formats/arrow/switch/switch_type.h +++ b/ydb/core/formats/arrow/switch/switch_type.h @@ -93,9 +93,9 @@ template <typename TFunc> case TEXTOID: return callback(TTypeWrapper<arrow::StringType>()); default: - break; + return false; } - break; // TODO: support pg types + break; } return false; } diff --git a/ydb/core/formats/arrow/ut/ut_arrow.cpp b/ydb/core/formats/arrow/ut/ut_arrow.cpp index b12fc5561b1..51caf2cb8d6 100644 --- a/ydb/core/formats/arrow/ut/ut_arrow.cpp +++ b/ydb/core/formats/arrow/ut/ut_arrow.cpp @@ -21,29 +21,39 @@ using TTypeId = NScheme::TTypeId; using TTypeInfo = NScheme::TTypeInfo; struct TDataRow { - static const constexpr TTypeInfo Types[20] = { - TTypeInfo(NTypeIds::Bool), - TTypeInfo(NTypeIds::Int8), - TTypeInfo(NTypeIds::Int16), - TTypeInfo(NTypeIds::Int32), - TTypeInfo(NTypeIds::Int64), - TTypeInfo(NTypeIds::Uint8), - TTypeInfo(NTypeIds::Uint16), - TTypeInfo(NTypeIds::Uint32), - TTypeInfo(NTypeIds::Uint64), - TTypeInfo(NTypeIds::Float), - TTypeInfo(NTypeIds::Double), - TTypeInfo(NTypeIds::String), - TTypeInfo(NTypeIds::Utf8), - TTypeInfo(NTypeIds::Json), - TTypeInfo(NTypeIds::Yson), - TTypeInfo(NTypeIds::Date), - TTypeInfo(NTypeIds::Datetime), - TTypeInfo(NTypeIds::Timestamp), - TTypeInfo(NTypeIds::Interval), - TTypeInfo(NTypeIds::JsonDocument), - // TODO: DyNumber, Decimal - }; + static const TTypeInfo* MakeTypeInfos() { + static const TTypeInfo types[27] = { + TTypeInfo(NTypeIds::Bool), + TTypeInfo(NTypeIds::Int8), + TTypeInfo(NTypeIds::Int16), + TTypeInfo(NTypeIds::Int32), + TTypeInfo(NTypeIds::Int64), + TTypeInfo(NTypeIds::Uint8), + TTypeInfo(NTypeIds::Uint16), + TTypeInfo(NTypeIds::Uint32), + TTypeInfo(NTypeIds::Uint64), + TTypeInfo(NTypeIds::Float), + TTypeInfo(NTypeIds::Double), + TTypeInfo(NTypeIds::String), + TTypeInfo(NTypeIds::Utf8), + TTypeInfo(NTypeIds::Json), + TTypeInfo(NTypeIds::Yson), + TTypeInfo(NTypeIds::Date), + TTypeInfo(NTypeIds::Datetime), + TTypeInfo(NTypeIds::Timestamp), + TTypeInfo(NTypeIds::Interval), + TTypeInfo(NTypeIds::JsonDocument), + TTypeInfo(NPg::TypeDescFromPgTypeId(INT2OID)), + TTypeInfo(NPg::TypeDescFromPgTypeId(INT4OID)), + TTypeInfo(NPg::TypeDescFromPgTypeId(INT8OID)), + TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT4OID)), + TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT8OID)), + TTypeInfo(NPg::TypeDescFromPgTypeId(BYTEAOID)), + TTypeInfo(NPg::TypeDescFromPgTypeId(TEXTOID)), + // TODO: DyNumber, Decimal + }; + return types; + } bool Bool; i8 Int8; @@ -65,6 +75,13 @@ struct TDataRow { i64 Timestamp; i64 Interval; std::string JsonDocument; + i16 PgInt2; + i32 PgInt4; + i64 PgInt8; + float PgFloat4; + double PgFloat8; + std::string PgBytea; + std::string PgText; //ui64 Decimal[2]; bool operator == (const TDataRow& r) const { @@ -87,7 +104,14 @@ struct TDataRow { (Datetime == r.Datetime) && (Timestamp == r.Timestamp) && (Interval == r.Interval) && - (JsonDocument == r.JsonDocument); + (JsonDocument == r.JsonDocument) && + (PgInt2 == r.PgInt2) && + (PgInt4 == r.PgInt4) && + (PgInt8 == r.PgInt8) && + (PgFloat4 == r.PgFloat4) && + (PgFloat8 == r.PgFloat8) && + (PgBytea == r.PgBytea) && + (PgText == r.PgText); //(Decimal[0] == r.Decimal[0] && Decimal[1] == r.Decimal[1]); } @@ -113,6 +137,13 @@ struct TDataRow { arrow::field("ts", arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO)), arrow::field("ival", arrow::duration(arrow::TimeUnit::TimeUnit::MICRO)), arrow::field("json_doc", arrow::binary()), + arrow::field("pgint2", arrow::int16()), + arrow::field("pgint4", arrow::int32()), + arrow::field("pgint8", arrow::int64()), + arrow::field("pgfloat4", arrow::float32()), + arrow::field("pgfloat8", arrow::float64()), + arrow::field("pgbytea", arrow::binary()), + arrow::field("pgtext", arrow::utf8()), //arrow::field("dec", arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE)), }; @@ -141,13 +172,20 @@ struct TDataRow { {"ts", TTypeInfo(NTypeIds::Timestamp) }, {"ival", TTypeInfo(NTypeIds::Interval) }, {"json_doc", TTypeInfo(NTypeIds::JsonDocument) }, + {"pgint2", TTypeInfo(NPg::TypeDescFromPgTypeId(INT2OID)) }, + {"pgint4", TTypeInfo(NPg::TypeDescFromPgTypeId(INT4OID)) }, + {"pgint8", TTypeInfo(NPg::TypeDescFromPgTypeId(INT8OID)) }, + {"pgfloat4", TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT4OID)) }, + {"pgfloat8", TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT8OID)) }, + {"pgbytea", TTypeInfo(NPg::TypeDescFromPgTypeId(BYTEAOID)) }, + {"pgtext", TTypeInfo(NPg::TypeDescFromPgTypeId(TEXTOID)) }, //{"dec", TTypeInfo(NTypeIds::Decimal) } }; return columns; } NKikimr::TDbTupleRef ToDbTupleRef() const { - static TCell Cells[20]; + static TCell Cells[27]; Cells[0] = TCell::Make<bool>(Bool); Cells[1] = TCell::Make<i8>(Int8); Cells[2] = TCell::Make<i16>(Int16); @@ -168,9 +206,16 @@ struct TDataRow { Cells[17] = TCell::Make<i64>(Timestamp); Cells[18] = TCell::Make<i64>(Interval); Cells[19] = TCell(JsonDocument.data(), JsonDocument.size()); + Cells[20] = TCell::Make<i16>(PgInt2); + Cells[21] = TCell::Make<i32>(PgInt4); + Cells[22] = TCell::Make<i64>(PgInt8); + Cells[23] = TCell::Make<float>(PgFloat4); + Cells[24] = TCell::Make<double>(PgFloat8); + Cells[25] = TCell(PgBytea.data(), PgBytea.size()); + Cells[26] = TCell(PgText.data(), PgText.size()); //Cells[19] = TCell((const char *)&Decimal[0], 16); - return NKikimr::TDbTupleRef(Types, Cells, 20); + return NKikimr::TDbTupleRef(MakeTypeInfos(), Cells, 27); } TOwnedCellVec SerializedCells() const { @@ -216,6 +261,13 @@ std::vector<TDataRow> ToVector(const std::shared_ptr<T>& table) { auto arival = std::static_pointer_cast<arrow::DurationArray>(GetColumn(*table, 18)); auto arjd = std::static_pointer_cast<arrow::BinaryArray>(GetColumn(*table, 19)); + auto arpgi2 = std::static_pointer_cast<arrow::Int16Array>(GetColumn(*table, 20)); + auto arpgi4 = std::static_pointer_cast<arrow::Int32Array>(GetColumn(*table, 21)); + auto arpgi8 = std::static_pointer_cast<arrow::Int64Array>(GetColumn(*table, 22)); + auto arpgf4 = std::static_pointer_cast<arrow::FloatArray>(GetColumn(*table, 23)); + auto arpgf8 = std::static_pointer_cast<arrow::DoubleArray>(GetColumn(*table, 24)); + auto arpgb = std::static_pointer_cast<arrow::BinaryArray>(GetColumn(*table, 25)); + auto arpgt = std::static_pointer_cast<arrow::StringArray>(GetColumn(*table, 26)); //auto ardec = std::static_pointer_cast<arrow::Decimal128Array>(GetColumn(*table, 19)); for (int64_t i = 0; i < table->num_rows(); ++i) { @@ -226,7 +278,9 @@ std::vector<TDataRow> ToVector(const std::shared_ptr<T>& table) { aru8->Value(i), aru16->Value(i), aru32->Value(i), aru64->Value(i), arf32->Value(i), arf64->Value(i), arstr->GetString(i), arutf->GetString(i), arj->GetString(i), ary->GetString(i), - ard->Value(i), ardt->Value(i), arts->Value(i), arival->Value(i), arjd->GetString(i) + ard->Value(i), ardt->Value(i), arts->Value(i), arival->Value(i), arjd->GetString(i), + arpgi2->Value(i), arpgi4->Value(i), arpgi8->Value(i), arpgf4->Value(i), arpgf8->Value(i), + arpgb->GetString(i), arpgt->GetString(i) //{dec[0], dec[1]} }; rows.emplace_back(std::move(r)); @@ -268,6 +322,13 @@ public: UNIT_ASSERT(Bival.Append(row.Interval).ok()); UNIT_ASSERT(Bjd.Append(row.JsonDocument).ok()); + UNIT_ASSERT(Bpgi2.Append(row.PgInt2).ok()); + UNIT_ASSERT(Bpgi4.Append(row.PgInt4).ok()); + UNIT_ASSERT(Bpgi8.Append(row.PgInt8).ok()); + UNIT_ASSERT(Bpgf4.Append(row.PgFloat4).ok()); + UNIT_ASSERT(Bpgf8.Append(row.PgFloat8).ok()); + UNIT_ASSERT(Bpgb.Append(row.PgBytea).ok()); + UNIT_ASSERT(Bpgt.Append(row.PgText).ok()); //UNIT_ASSERT(Bdec.Append((const char *)&row.Decimal).ok()); } @@ -295,6 +356,13 @@ public: std::shared_ptr<arrow::DurationArray> arival; std::shared_ptr<arrow::BinaryArray> arjd; + std::shared_ptr<arrow::Int16Array> arpgi2; + std::shared_ptr<arrow::Int32Array> arpgi4; + std::shared_ptr<arrow::Int64Array> arpgi8; + std::shared_ptr<arrow::FloatArray> arpgf4; + std::shared_ptr<arrow::DoubleArray> arpgf8; + std::shared_ptr<arrow::BinaryArray> arpgb; + std::shared_ptr<arrow::StringArray> arpgt; //std::shared_ptr<arrow::Decimal128Array> ardec; UNIT_ASSERT(Bbool.Finish(&arbool).ok()); @@ -320,6 +388,13 @@ public: UNIT_ASSERT(Bival.Finish(&arival).ok()); UNIT_ASSERT(Bjd.Finish(&arjd).ok()); + UNIT_ASSERT(Bpgi2.Finish(&arpgi2).ok()); + UNIT_ASSERT(Bpgi4.Finish(&arpgi4).ok()); + UNIT_ASSERT(Bpgi8.Finish(&arpgi8).ok()); + UNIT_ASSERT(Bpgf4.Finish(&arpgf4).ok()); + UNIT_ASSERT(Bpgf8.Finish(&arpgf8).ok()); + UNIT_ASSERT(Bpgb.Finish(&arpgb).ok()); + UNIT_ASSERT(Bpgt.Finish(&arpgt).ok()); //UNIT_ASSERT(Bdec.Finish(&ardec).ok()); std::shared_ptr<arrow::Schema> schema = TDataRow::MakeArrowSchema(); @@ -329,7 +404,9 @@ public: aru8, aru16, aru32, aru64, arf32, arf64, arstr, arutf, arj, ary, - ard, ardt, arts, arival, arjd + ard, ardt, arts, arival, arjd, + arpgi2, arpgi4, arpgi8, arpgf4, arpgf8, + arpgb, arpgt //ardec }); } @@ -363,6 +440,13 @@ private: arrow::TimestampBuilder Bts; arrow::DurationBuilder Bival; arrow::BinaryBuilder Bjd; + arrow::Int16Builder Bpgi2; + arrow::Int32Builder Bpgi4; + arrow::Int64Builder Bpgi8; + arrow::FloatBuilder Bpgf4; + arrow::DoubleBuilder Bpgf8; + arrow::BinaryBuilder Bpgb; + arrow::StringBuilder Bpgt; //arrow::Decimal128Builder Bdec; }; @@ -370,6 +454,7 @@ std::shared_ptr<arrow::RecordBatch> VectorToBatch(const std::vector<struct TData TString err; NArrow::TArrowBatchBuilder batchBuilder; batchBuilder.Start(TDataRow::MakeYdbSchema(), 0, 0, err); + UNIT_ASSERT_C(err.Empty(), err); for (const TDataRow& row : rows) { NKikimr::TDbTupleRef key; @@ -382,10 +467,14 @@ std::shared_ptr<arrow::RecordBatch> VectorToBatch(const std::vector<struct TData std::vector<TDataRow> TestRows() { std::vector<TDataRow> rows = { - {false, -1, -1, -1, -1, 1, 1, 1, 1, -1.0f, -1.0, "s1", "u1", "{\"j\":1}", "{y:1}", 0, 0, 0, 0, "{\"jd\":1}" }, - {false, 2, 2, 2, 2, 2, 2, 2, 2, 2.0f, 2.0, "s2", "u2", "{\"j\":2}", "{y:2}", 0, 0, 0, 0, "{\"jd\":1}" }, - {false, -3, -3, -3, -3, 3, 3, 3, 3, -3.0f, -3.0, "s3", "u3", "{\"j\":3}", "{y:3}", 0, 0, 0, 0, "{\"jd\":1}" }, - {false, -4, -4, -4, -4, 4, 4, 4, 4, 4.0f, 4.0, "s4", "u4", "{\"j\":4}", "{y:4}", 0, 0, 0, 0, "{\"jd\":1}" }, + {false, -1, -1, -1, -1, 1, 1, 1, 1, -1.0f, -1.0, "s1", "u1", "{\"j\":1}", "{y:1}", 0, 0, 0, 0, "{\"jd\":1}", + -5, -5, -5, -5.1f, -5.1, "s5", "u5"}, + {false, 2, 2, 2, 2, 2, 2, 2, 2, 2.0f, 2.0, "s2", "u2", "{\"j\":2}", "{y:2}", 0, 0, 0, 0, "{\"jd\":1}", + -3, -3, -3, -3.1f, -3.1, "s3", "u3"}, + {false, -3, -3, -3, -3, 3, 3, 3, 3, -3.0f, -3.0, "s3", "u3", "{\"j\":3}", "{y:3}", 0, 0, 0, 0, "{\"jd\":1}", + -2, -2, -2, -2.1f, -2.1, "s2", "u2"}, + {false, -4, -4, -4, -4, 4, 4, 4, 4, 4.0f, 4.0, "s4", "u4", "{\"j\":4}", "{y:4}", 0, 0, 0, 0, "{\"jd\":1}", + -7, -7, -7, -7.1f, -7.1, "s7", "u7"}, }; return rows; } @@ -412,7 +501,9 @@ std::shared_ptr<arrow::Table> MakeTable1000() { i8 a = i/100; i16 b = (i%100)/10; i32 c = i%10; - builder.AddRow(TDataRow{false, a, b, c, i, 1, 1, 1, 1, 1.0f, 1.0, "", "", "", "", 0, 0, 0, 0, {0,0} }); + builder.AddRow( + TDataRow{false, a, b, c, i, 1, 1, 1, 1, 1.0f, 1.0, "", "", "", "", 0, 0, 0, 0, {0,0}, + 0, 0, 0, 0.0f, 0.0, "", ""}); } auto table = builder.Finish(); @@ -575,7 +666,7 @@ Y_UNIT_TEST_SUITE(ArrowTest) { for (size_t i = 0; i < rows.size(); ++i) { UNIT_ASSERT(0 == CompareTypedCellVectors( cellRows[i].data(), rowWriter.Rows[i].data(), - TDataRow::Types, + TDataRow::MakeTypeInfos(), cellRows[i].size(), rowWriter.Rows[i].size())); } } diff --git a/ydb/core/formats/arrow/ut/ya.make b/ydb/core/formats/arrow/ut/ya.make index 54fa4d35773..463a62486d9 100644 --- a/ydb/core/formats/arrow/ut/ya.make +++ b/ydb/core/formats/arrow/ut/ya.make @@ -10,7 +10,8 @@ PEERDIR( # for NYql::NUdf alloc stuff used in binary_json ydb/library/yql/public/udf/service/exception_policy - ydb/library/yql/sql/pg_dummy + ydb/library/yql/sql/pg + ydb/library/yql/parser/pg_wrapper ) ADDINCL( diff --git a/ydb/core/formats/clickhouse_block.cpp b/ydb/core/formats/clickhouse_block.cpp index b0b20eb1e2f..a26c8f5393a 100644 --- a/ydb/core/formats/clickhouse_block.cpp +++ b/ydb/core/formats/clickhouse_block.cpp @@ -12,6 +12,10 @@ #include <util/generic/string.h> #include <util/generic/hash.h> +extern "C" { +#include <ydb/library/yql/parser/pg_wrapper/postgresql/src/include/catalog/pg_type_d.h> +} + namespace NKikHouse { namespace NSerialization { @@ -480,9 +484,24 @@ public: CONVERT(StepOrderId, String); case NScheme::NTypeIds::Pg: - // TODO: support pg types - throw yexception() << "Unsupported pg type"; - + switch (NPg::PgTypeIdFromTypeDesc(type.GetPgTypeDesc())) { + case INT2OID: + return Get("Int16"); + case INT4OID: + return Get("Int32"); + case INT8OID: + return Get("Int64"); + case FLOAT4OID: + return Get("Float"); + case FLOAT8OID: + return Get("Double"); + case BYTEAOID: + return Get("String"); + case TEXTOID: + return Get("String"); + default: + throw yexception() << "Unsupported pg type"; + } default: throw yexception() << "Unsupported type: " << type.GetTypeId(); } |