summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Ivan Nikolaev <[email protected]> 2024-09-23 12:02:51 +0300
committer GitHub <[email protected]> 2024-09-23 12:02:51 +0300
commit b47a8b8f9c631d8bc6a79dc4c5cb839d1b8356e0 (patch)
tree de32b7f9e063f56109235d8075bf1edf3409ce0b
parent 29fff7ec9f02e124dec51d9300686096ae000f31 (diff)
Support PG types in arrow and clickhouse (#9335)
-rw-r--r-- ydb/core/formats/arrow/converter.cpp 6
-rw-r--r-- ydb/core/formats/arrow/switch/switch_type.h 4
-rw-r--r-- ydb/core/formats/arrow/ut/ut_arrow.cpp 159
-rw-r--r-- ydb/core/formats/arrow/ut/ya.make 3
-rw-r--r-- ydb/core/formats/clickhouse_block.cpp 25
5 files changed, 151 insertions, 46 deletions
diff --git a/ydb/core/formats/arrow/converter.cpp b/ydb/core/formats/arrow/converter.cpp
index f0a38e2c814..08fd2b6d6c6 100644
--- a/ydb/core/formats/arrow/converter.cpp
+++ b/ydb/core/formats/arrow/converter.cpp
@@ -297,9 +297,6 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err
for (; row < rowsUnroll; row += unroll) {
ui32 col = 0;
for (auto& [colName, colType] : YdbSchema_) {
- // TODO: support pg types
- Y_ABORT_UNLESS(colType.GetTypeId() != NScheme::NTypeIds::Pg, "pg types are not supported");
-
auto& column = allColumns[col];
bool success = SwitchYqlTypeToArrowType(colType, [&]<typename TType>(TTypeWrapper<TType> typeHolder) {
Y_UNUSED(typeHolder);
@@ -347,9 +344,6 @@ bool TArrowToYdbConverter::Process(const arrow::RecordBatch& batch, TString& err
ui32 col = 0;
for (auto& [colName, colType] : YdbSchema_) {
- // TODO: support pg types
- Y_ABORT_UNLESS(colType.GetTypeId() != NScheme::NTypeIds::Pg, "pg types are not supported");
-
auto& column = allColumns[col];
auto& curCell = cells[0][col];
if (column->IsNull(row)) {
diff --git a/ydb/core/formats/arrow/switch/switch_type.h b/ydb/core/formats/arrow/switch/switch_type.h
index 75090fbc0a5..dc04eee244e 100644
--- a/ydb/core/formats/arrow/switch/switch_type.h
+++ b/ydb/core/formats/arrow/switch/switch_type.h
@@ -93,9 +93,9 @@ template <typename TFunc>
case TEXTOID:
return callback(TTypeWrapper<arrow::StringType>());
default:
- break;
+ return false;
}
- break; // TODO: support pg types
+ break;
}
return false;
}
diff --git a/ydb/core/formats/arrow/ut/ut_arrow.cpp b/ydb/core/formats/arrow/ut/ut_arrow.cpp
index b12fc5561b1..51caf2cb8d6 100644
--- a/ydb/core/formats/arrow/ut/ut_arrow.cpp
+++ b/ydb/core/formats/arrow/ut/ut_arrow.cpp
@@ -21,29 +21,39 @@ using TTypeId = NScheme::TTypeId;
using TTypeInfo = NScheme::TTypeInfo;
struct TDataRow {
- static const constexpr TTypeInfo Types[20] = {
- TTypeInfo(NTypeIds::Bool),
- TTypeInfo(NTypeIds::Int8),
- TTypeInfo(NTypeIds::Int16),
- TTypeInfo(NTypeIds::Int32),
- TTypeInfo(NTypeIds::Int64),
- TTypeInfo(NTypeIds::Uint8),
- TTypeInfo(NTypeIds::Uint16),
- TTypeInfo(NTypeIds::Uint32),
- TTypeInfo(NTypeIds::Uint64),
- TTypeInfo(NTypeIds::Float),
- TTypeInfo(NTypeIds::Double),
- TTypeInfo(NTypeIds::String),
- TTypeInfo(NTypeIds::Utf8),
- TTypeInfo(NTypeIds::Json),
- TTypeInfo(NTypeIds::Yson),
- TTypeInfo(NTypeIds::Date),
- TTypeInfo(NTypeIds::Datetime),
- TTypeInfo(NTypeIds::Timestamp),
- TTypeInfo(NTypeIds::Interval),
- TTypeInfo(NTypeIds::JsonDocument),
- // TODO: DyNumber, Decimal
- };
+ static const TTypeInfo* MakeTypeInfos() {
+ static const TTypeInfo types[27] = {
+ TTypeInfo(NTypeIds::Bool),
+ TTypeInfo(NTypeIds::Int8),
+ TTypeInfo(NTypeIds::Int16),
+ TTypeInfo(NTypeIds::Int32),
+ TTypeInfo(NTypeIds::Int64),
+ TTypeInfo(NTypeIds::Uint8),
+ TTypeInfo(NTypeIds::Uint16),
+ TTypeInfo(NTypeIds::Uint32),
+ TTypeInfo(NTypeIds::Uint64),
+ TTypeInfo(NTypeIds::Float),
+ TTypeInfo(NTypeIds::Double),
+ TTypeInfo(NTypeIds::String),
+ TTypeInfo(NTypeIds::Utf8),
+ TTypeInfo(NTypeIds::Json),
+ TTypeInfo(NTypeIds::Yson),
+ TTypeInfo(NTypeIds::Date),
+ TTypeInfo(NTypeIds::Datetime),
+ TTypeInfo(NTypeIds::Timestamp),
+ TTypeInfo(NTypeIds::Interval),
+ TTypeInfo(NTypeIds::JsonDocument),
+ TTypeInfo(NPg::TypeDescFromPgTypeId(INT2OID)),
+ TTypeInfo(NPg::TypeDescFromPgTypeId(INT4OID)),
+ TTypeInfo(NPg::TypeDescFromPgTypeId(INT8OID)),
+ TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT4OID)),
+ TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT8OID)),
+ TTypeInfo(NPg::TypeDescFromPgTypeId(BYTEAOID)),
+ TTypeInfo(NPg::TypeDescFromPgTypeId(TEXTOID)),
+ // TODO: DyNumber, Decimal
+ };
+ return types;
+ }
bool Bool;
i8 Int8;
@@ -65,6 +75,13 @@ struct TDataRow {
i64 Timestamp;
i64 Interval;
std::string JsonDocument;
+ i16 PgInt2;
+ i32 PgInt4;
+ i64 PgInt8;
+ float PgFloat4;
+ double PgFloat8;
+ std::string PgBytea;
+ std::string PgText;
//ui64 Decimal[2];
bool operator == (const TDataRow& r) const {
@@ -87,7 +104,14 @@ struct TDataRow {
(Datetime == r.Datetime) &&
(Timestamp == r.Timestamp) &&
(Interval == r.Interval) &&
- (JsonDocument == r.JsonDocument);
+ (JsonDocument == r.JsonDocument) &&
+ (PgInt2 == r.PgInt2) &&
+ (PgInt4 == r.PgInt4) &&
+ (PgInt8 == r.PgInt8) &&
+ (PgFloat4 == r.PgFloat4) &&
+ (PgFloat8 == r.PgFloat8) &&
+ (PgBytea == r.PgBytea) &&
+ (PgText == r.PgText);
//(Decimal[0] == r.Decimal[0] && Decimal[1] == r.Decimal[1]);
}
@@ -113,6 +137,13 @@ struct TDataRow {
arrow::field("ts", arrow::timestamp(arrow::TimeUnit::TimeUnit::MICRO)),
arrow::field("ival", arrow::duration(arrow::TimeUnit::TimeUnit::MICRO)),
arrow::field("json_doc", arrow::binary()),
+ arrow::field("pgint2", arrow::int16()),
+ arrow::field("pgint4", arrow::int32()),
+ arrow::field("pgint8", arrow::int64()),
+ arrow::field("pgfloat4", arrow::float32()),
+ arrow::field("pgfloat8", arrow::float64()),
+ arrow::field("pgbytea", arrow::binary()),
+ arrow::field("pgtext", arrow::utf8()),
//arrow::field("dec", arrow::decimal(NScheme::DECIMAL_PRECISION, NScheme::DECIMAL_SCALE)),
};
@@ -141,13 +172,20 @@ struct TDataRow {
{"ts", TTypeInfo(NTypeIds::Timestamp) },
{"ival", TTypeInfo(NTypeIds::Interval) },
{"json_doc", TTypeInfo(NTypeIds::JsonDocument) },
+ {"pgint2", TTypeInfo(NPg::TypeDescFromPgTypeId(INT2OID)) },
+ {"pgint4", TTypeInfo(NPg::TypeDescFromPgTypeId(INT4OID)) },
+ {"pgint8", TTypeInfo(NPg::TypeDescFromPgTypeId(INT8OID)) },
+ {"pgfloat4", TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT4OID)) },
+ {"pgfloat8", TTypeInfo(NPg::TypeDescFromPgTypeId(FLOAT8OID)) },
+ {"pgbytea", TTypeInfo(NPg::TypeDescFromPgTypeId(BYTEAOID)) },
+ {"pgtext", TTypeInfo(NPg::TypeDescFromPgTypeId(TEXTOID)) },
//{"dec", TTypeInfo(NTypeIds::Decimal) }
};
return columns;
}
NKikimr::TDbTupleRef ToDbTupleRef() const {
- static TCell Cells[20];
+ static TCell Cells[27];
Cells[0] = TCell::Make<bool>(Bool);
Cells[1] = TCell::Make<i8>(Int8);
Cells[2] = TCell::Make<i16>(Int16);
@@ -168,9 +206,16 @@ struct TDataRow {
Cells[17] = TCell::Make<i64>(Timestamp);
Cells[18] = TCell::Make<i64>(Interval);
Cells[19] = TCell(JsonDocument.data(), JsonDocument.size());
+ Cells[20] = TCell::Make<i16>(PgInt2);
+ Cells[21] = TCell::Make<i32>(PgInt4);
+ Cells[22] = TCell::Make<i64>(PgInt8);
+ Cells[23] = TCell::Make<float>(PgFloat4);
+ Cells[24] = TCell::Make<double>(PgFloat8);
+ Cells[25] = TCell(PgBytea.data(), PgBytea.size());
+ Cells[26] = TCell(PgText.data(), PgText.size());
//Cells[19] = TCell((const char *)&Decimal[0], 16);
- return NKikimr::TDbTupleRef(Types, Cells, 20);
+ return NKikimr::TDbTupleRef(MakeTypeInfos(), Cells, 27);
}
TOwnedCellVec SerializedCells() const {
@@ -216,6 +261,13 @@ std::vector<TDataRow> ToVector(const std::shared_ptr<T>& table) {
auto arival = std::static_pointer_cast<arrow::DurationArray>(GetColumn(*table, 18));
auto arjd = std::static_pointer_cast<arrow::BinaryArray>(GetColumn(*table, 19));
+ auto arpgi2 = std::static_pointer_cast<arrow::Int16Array>(GetColumn(*table, 20));
+ auto arpgi4 = std::static_pointer_cast<arrow::Int32Array>(GetColumn(*table, 21));
+ auto arpgi8 = std::static_pointer_cast<arrow::Int64Array>(GetColumn(*table, 22));
+ auto arpgf4 = std::static_pointer_cast<arrow::FloatArray>(GetColumn(*table, 23));
+ auto arpgf8 = std::static_pointer_cast<arrow::DoubleArray>(GetColumn(*table, 24));
+ auto arpgb = std::static_pointer_cast<arrow::BinaryArray>(GetColumn(*table, 25));
+ auto arpgt = std::static_pointer_cast<arrow::StringArray>(GetColumn(*table, 26));
//auto ardec = std::static_pointer_cast<arrow::Decimal128Array>(GetColumn(*table, 19));
for (int64_t i = 0; i < table->num_rows(); ++i) {
@@ -226,7 +278,9 @@ std::vector<TDataRow> ToVector(const std::shared_ptr<T>& table) {
aru8->Value(i), aru16->Value(i), aru32->Value(i), aru64->Value(i),
arf32->Value(i), arf64->Value(i),
arstr->GetString(i), arutf->GetString(i), arj->GetString(i), ary->GetString(i),
- ard->Value(i), ardt->Value(i), arts->Value(i), arival->Value(i), arjd->GetString(i)
+ ard->Value(i), ardt->Value(i), arts->Value(i), arival->Value(i), arjd->GetString(i),
+ arpgi2->Value(i), arpgi4->Value(i), arpgi8->Value(i), arpgf4->Value(i), arpgf8->Value(i),
+ arpgb->GetString(i), arpgt->GetString(i)
//{dec[0], dec[1]}
};
rows.emplace_back(std::move(r));
@@ -268,6 +322,13 @@ public:
UNIT_ASSERT(Bival.Append(row.Interval).ok());
UNIT_ASSERT(Bjd.Append(row.JsonDocument).ok());
+ UNIT_ASSERT(Bpgi2.Append(row.PgInt2).ok());
+ UNIT_ASSERT(Bpgi4.Append(row.PgInt4).ok());
+ UNIT_ASSERT(Bpgi8.Append(row.PgInt8).ok());
+ UNIT_ASSERT(Bpgf4.Append(row.PgFloat4).ok());
+ UNIT_ASSERT(Bpgf8.Append(row.PgFloat8).ok());
+ UNIT_ASSERT(Bpgb.Append(row.PgBytea).ok());
+ UNIT_ASSERT(Bpgt.Append(row.PgText).ok());
//UNIT_ASSERT(Bdec.Append((const char *)&row.Decimal).ok());
}
@@ -295,6 +356,13 @@ public:
std::shared_ptr<arrow::DurationArray> arival;
std::shared_ptr<arrow::BinaryArray> arjd;
+ std::shared_ptr<arrow::Int16Array> arpgi2;
+ std::shared_ptr<arrow::Int32Array> arpgi4;
+ std::shared_ptr<arrow::Int64Array> arpgi8;
+ std::shared_ptr<arrow::FloatArray> arpgf4;
+ std::shared_ptr<arrow::DoubleArray> arpgf8;
+ std::shared_ptr<arrow::BinaryArray> arpgb;
+ std::shared_ptr<arrow::StringArray> arpgt;
//std::shared_ptr<arrow::Decimal128Array> ardec;
UNIT_ASSERT(Bbool.Finish(&arbool).ok());
@@ -320,6 +388,13 @@ public:
UNIT_ASSERT(Bival.Finish(&arival).ok());
UNIT_ASSERT(Bjd.Finish(&arjd).ok());
+ UNIT_ASSERT(Bpgi2.Finish(&arpgi2).ok());
+ UNIT_ASSERT(Bpgi4.Finish(&arpgi4).ok());
+ UNIT_ASSERT(Bpgi8.Finish(&arpgi8).ok());
+ UNIT_ASSERT(Bpgf4.Finish(&arpgf4).ok());
+ UNIT_ASSERT(Bpgf8.Finish(&arpgf8).ok());
+ UNIT_ASSERT(Bpgb.Finish(&arpgb).ok());
+ UNIT_ASSERT(Bpgt.Finish(&arpgt).ok());
//UNIT_ASSERT(Bdec.Finish(&ardec).ok());
std::shared_ptr<arrow::Schema> schema = TDataRow::MakeArrowSchema();
@@ -329,7 +404,9 @@ public:
aru8, aru16, aru32, aru64,
arf32, arf64,
arstr, arutf, arj, ary,
- ard, ardt, arts, arival, arjd
+ ard, ardt, arts, arival, arjd,
+ arpgi2, arpgi4, arpgi8, arpgf4, arpgf8,
+ arpgb, arpgt
//ardec
});
}
@@ -363,6 +440,13 @@ private:
arrow::TimestampBuilder Bts;
arrow::DurationBuilder Bival;
arrow::BinaryBuilder Bjd;
+ arrow::Int16Builder Bpgi2;
+ arrow::Int32Builder Bpgi4;
+ arrow::Int64Builder Bpgi8;
+ arrow::FloatBuilder Bpgf4;
+ arrow::DoubleBuilder Bpgf8;
+ arrow::BinaryBuilder Bpgb;
+ arrow::StringBuilder Bpgt;
//arrow::Decimal128Builder Bdec;
};
@@ -370,6 +454,7 @@ std::shared_ptr<arrow::RecordBatch> VectorToBatch(const std::vector<struct TData
TString err;
NArrow::TArrowBatchBuilder batchBuilder;
batchBuilder.Start(TDataRow::MakeYdbSchema(), 0, 0, err);
+ UNIT_ASSERT_C(err.Empty(), err);
for (const TDataRow& row : rows) {
NKikimr::TDbTupleRef key;
@@ -382,10 +467,14 @@ std::shared_ptr<arrow::RecordBatch> VectorToBatch(const std::vector<struct TData
std::vector<TDataRow> TestRows() {
std::vector<TDataRow> rows = {
- {false, -1, -1, -1, -1, 1, 1, 1, 1, -1.0f, -1.0, "s1", "u1", "{\"j\":1}", "{y:1}", 0, 0, 0, 0, "{\"jd\":1}" },
- {false, 2, 2, 2, 2, 2, 2, 2, 2, 2.0f, 2.0, "s2", "u2", "{\"j\":2}", "{y:2}", 0, 0, 0, 0, "{\"jd\":1}" },
- {false, -3, -3, -3, -3, 3, 3, 3, 3, -3.0f, -3.0, "s3", "u3", "{\"j\":3}", "{y:3}", 0, 0, 0, 0, "{\"jd\":1}" },
- {false, -4, -4, -4, -4, 4, 4, 4, 4, 4.0f, 4.0, "s4", "u4", "{\"j\":4}", "{y:4}", 0, 0, 0, 0, "{\"jd\":1}" },
+ {false, -1, -1, -1, -1, 1, 1, 1, 1, -1.0f, -1.0, "s1", "u1", "{\"j\":1}", "{y:1}", 0, 0, 0, 0, "{\"jd\":1}",
+ -5, -5, -5, -5.1f, -5.1, "s5", "u5"},
+ {false, 2, 2, 2, 2, 2, 2, 2, 2, 2.0f, 2.0, "s2", "u2", "{\"j\":2}", "{y:2}", 0, 0, 0, 0, "{\"jd\":1}",
+ -3, -3, -3, -3.1f, -3.1, "s3", "u3"},
+ {false, -3, -3, -3, -3, 3, 3, 3, 3, -3.0f, -3.0, "s3", "u3", "{\"j\":3}", "{y:3}", 0, 0, 0, 0, "{\"jd\":1}",
+ -2, -2, -2, -2.1f, -2.1, "s2", "u2"},
+ {false, -4, -4, -4, -4, 4, 4, 4, 4, 4.0f, 4.0, "s4", "u4", "{\"j\":4}", "{y:4}", 0, 0, 0, 0, "{\"jd\":1}",
+ -7, -7, -7, -7.1f, -7.1, "s7", "u7"},
};
return rows;
}
@@ -412,7 +501,9 @@ std::shared_ptr<arrow::Table> MakeTable1000() {
i8 a = i/100;
i16 b = (i%100)/10;
i32 c = i%10;
- builder.AddRow(TDataRow{false, a, b, c, i, 1, 1, 1, 1, 1.0f, 1.0, "", "", "", "", 0, 0, 0, 0, {0,0} });
+ builder.AddRow(
+ TDataRow{false, a, b, c, i, 1, 1, 1, 1, 1.0f, 1.0, "", "", "", "", 0, 0, 0, 0, {0,0},
+ 0, 0, 0, 0.0f, 0.0, "", ""});
}
auto table = builder.Finish();
@@ -575,7 +666,7 @@ Y_UNIT_TEST_SUITE(ArrowTest) {
for (size_t i = 0; i < rows.size(); ++i) {
UNIT_ASSERT(0 == CompareTypedCellVectors(
cellRows[i].data(), rowWriter.Rows[i].data(),
- TDataRow::Types,
+ TDataRow::MakeTypeInfos(),
cellRows[i].size(), rowWriter.Rows[i].size()));
}
}
diff --git a/ydb/core/formats/arrow/ut/ya.make b/ydb/core/formats/arrow/ut/ya.make
index 54fa4d35773..463a62486d9 100644
--- a/ydb/core/formats/arrow/ut/ya.make
+++ b/ydb/core/formats/arrow/ut/ya.make
@@ -10,7 +10,8 @@ PEERDIR(
# for NYql::NUdf alloc stuff used in binary_json
ydb/library/yql/public/udf/service/exception_policy
- ydb/library/yql/sql/pg_dummy
+ ydb/library/yql/sql/pg
+ ydb/library/yql/parser/pg_wrapper
)
ADDINCL(
diff --git a/ydb/core/formats/clickhouse_block.cpp b/ydb/core/formats/clickhouse_block.cpp
index b0b20eb1e2f..a26c8f5393a 100644
--- a/ydb/core/formats/clickhouse_block.cpp
+++ b/ydb/core/formats/clickhouse_block.cpp
@@ -12,6 +12,10 @@
#include <util/generic/string.h>
#include <util/generic/hash.h>
+extern "C" {
+#include <ydb/library/yql/parser/pg_wrapper/postgresql/src/include/catalog/pg_type_d.h>
+}
+
namespace NKikHouse {
namespace NSerialization {
@@ -480,9 +484,24 @@ public:
CONVERT(StepOrderId, String);
case NScheme::NTypeIds::Pg:
- // TODO: support pg types
- throw yexception() << "Unsupported pg type";
-
+ switch (NPg::PgTypeIdFromTypeDesc(type.GetPgTypeDesc())) {
+ case INT2OID:
+ return Get("Int16");
+ case INT4OID:
+ return Get("Int32");
+ case INT8OID:
+ return Get("Int64");
+ case FLOAT4OID:
+ return Get("Float");
+ case FLOAT8OID:
+ return Get("Double");
+ case BYTEAOID:
+ return Get("String");
+ case TEXTOID:
+ return Get("String");
+ default:
+ throw yexception() << "Unsupported pg type";
+ }
default:
throw yexception() << "Unsupported type: " << type.GetTypeId();
}