summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorivanmorozov <[email protected]>2023-11-16 11:11:14 +0300
committerivanmorozov <[email protected]>2023-11-16 11:42:23 +0300
commit0189677704e6b4969cc08a9a824ca1b7fa46701c (patch)
tree2d6ea1553806016f34685675462a898433e1e01e
parent0d32e31b28b3fd1e1ede9b32126cba1c8e5f18f2 (diff)
KIKIMR-20087: add helpers for compute-sharding
-rw-r--r--ydb/core/formats/arrow/permutations.cpp183
-rw-r--r--ydb/core/formats/arrow/permutations.h73
-rw-r--r--ydb/core/formats/arrow/replace_key.h87
3 files changed, 256 insertions, 87 deletions
diff --git a/ydb/core/formats/arrow/permutations.cpp b/ydb/core/formats/arrow/permutations.cpp
index 3abc3bb2cec..f00d621f56f 100644
--- a/ydb/core/formats/arrow/permutations.cpp
+++ b/ydb/core/formats/arrow/permutations.cpp
@@ -1,8 +1,10 @@
#include "arrow_helpers.h"
#include "permutations.h"
#include "replace_key.h"
+#include "size_calcer.h"
#include <ydb/core/formats/arrow/common/validation.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h>
+#include <library/cpp/actors/core/log.h>
namespace NKikimr::NArrow {
@@ -170,4 +172,185 @@ std::shared_ptr<arrow::Array> CopyRecords(const std::shared_ptr<arrow::Array>& s
return result;
}
+bool THashConstructor::BuildHashUI64(std::shared_ptr<arrow::RecordBatch>& batch, const std::vector<std::string>& fieldNames, const std::string& hashFieldName) {
+ if (fieldNames.size() == 0) {
+ return false;
+ }
+ Y_ABORT_UNLESS(!batch->GetColumnByName(hashFieldName));
+ if (fieldNames.size() == 1) {
+ auto column = batch->GetColumnByName(fieldNames.front());
+ if (!column) {
+ return false;
+ }
+ Y_ABORT_UNLESS(column);
+ if (column->type()->id() == arrow::Type::UINT64 || column->type()->id() == arrow::Type::UINT32 || column->type()->id() == arrow::Type::INT64 || column->type()->id() == arrow::Type::INT32) {
+ batch = TStatusValidator::GetValid(batch->AddColumn(batch->num_columns(), hashFieldName, column));
+ return true;
+ }
+ }
+ auto builder = NArrow::MakeBuilder(std::make_shared<arrow::Field>(hashFieldName, arrow::TypeTraits<arrow::UInt64Type>::type_singleton()));
+ {
+ auto& intBuilder = static_cast<arrow::UInt64Builder&>(*builder);
+ TStatusValidator::Validate(intBuilder.Reserve(batch->num_rows()));
+ std::vector<std::shared_ptr<arrow::Array>> columns;
+ for (auto&& i : fieldNames) {
+ auto column = batch->GetColumnByName(i);
+// AFL_VERIFY(column)("column_name", i)("all_columns", JoinSeq(",", fieldNames));
+ if (column) {
+ columns.emplace_back(column);
+ }
+ }
+ if (columns.empty()) {
+ return false;
+ }
+ for (i64 i = 0; i < batch->num_rows(); ++i) {
+ intBuilder.UnsafeAppend(TypedHash(columns, i));
+ }
+ }
+ batch = TStatusValidator::GetValid(batch->AddColumn(batch->num_columns(), hashFieldName, NArrow::TStatusValidator::GetValid(builder->Finish())));
+ return true;
+}
+
+size_t THashConstructor::TypedHash(const std::vector<std::shared_ptr<arrow::Array>>& ar, const int pos) {
+ size_t result = 0;
+ for (auto&& i : ar) {
+ result = CombineHashes(result, TypedHash(*i, pos));
+ }
+ return result;
+}
+
+size_t THashConstructor::TypedHash(const arrow::Array& ar, const int pos) {
+ switch (ar.type_id()) {
+ case arrow::Type::BOOL:
+ return (size_t)(static_cast<const arrow::BooleanArray&>(ar).Value(pos));
+ case arrow::Type::UINT8:
+ return THash<ui8>()(static_cast<const arrow::UInt8Array&>(ar).Value(pos));
+ case arrow::Type::INT8:
+ return THash<i8>()(static_cast<const arrow::Int8Array&>(ar).Value(pos));
+ case arrow::Type::UINT16:
+ return THash<ui16>()(static_cast<const arrow::UInt16Array&>(ar).Value(pos));
+ case arrow::Type::INT16:
+ return THash<i16>()(static_cast<const arrow::Int16Array&>(ar).Value(pos));
+ case arrow::Type::UINT32:
+ return THash<ui32>()(static_cast<const arrow::UInt32Array&>(ar).Value(pos));
+ case arrow::Type::INT32:
+ return THash<i32>()(static_cast<const arrow::Int32Array&>(ar).Value(pos));
+ case arrow::Type::UINT64:
+ return THash<ui64>()(static_cast<const arrow::UInt64Array&>(ar).Value(pos));
+ case arrow::Type::INT64:
+ return THash<i64>()(static_cast<const arrow::Int64Array&>(ar).Value(pos));
+ case arrow::Type::HALF_FLOAT:
+ break;
+ case arrow::Type::FLOAT:
+ return THash<float>()(static_cast<const arrow::FloatArray&>(ar).Value(pos));
+ case arrow::Type::DOUBLE:
+ return THash<double>()(static_cast<const arrow::DoubleArray&>(ar).Value(pos));
+ case arrow::Type::STRING:
+ {
+ const auto& str = static_cast<const arrow::StringArray&>(ar).GetView(pos);
+ return CityHash64(str.data(), str.size());
+ }
+ case arrow::Type::BINARY:
+ {
+ const auto& str = static_cast<const arrow::BinaryArray&>(ar).GetView(pos);
+ return CityHash64(str.data(), str.size());
+ }
+ case arrow::Type::FIXED_SIZE_BINARY:
+ {
+ const auto& str = static_cast<const arrow::FixedSizeBinaryArray&>(ar).GetView(pos);
+ return CityHash64(str.data(), str.size());
+ }
+ case arrow::Type::TIMESTAMP:
+ return THash<i64>()(static_cast<const arrow::TimestampArray&>(ar).Value(pos));
+ case arrow::Type::TIME32:
+ return THash<i32>()(static_cast<const arrow::Time32Array&>(ar).Value(pos));
+ case arrow::Type::TIME64:
+ return THash<i64>()(static_cast<const arrow::Time64Array&>(ar).Value(pos));
+ case arrow::Type::DURATION:
+ return THash<i64>()(static_cast<const arrow::DurationArray&>(ar).Value(pos));
+ case arrow::Type::DATE32:
+ case arrow::Type::DATE64:
+ case arrow::Type::NA:
+ case arrow::Type::DECIMAL256:
+ case arrow::Type::DECIMAL:
+ case arrow::Type::DENSE_UNION:
+ case arrow::Type::DICTIONARY:
+ case arrow::Type::EXTENSION:
+ case arrow::Type::FIXED_SIZE_LIST:
+ case arrow::Type::INTERVAL_DAY_TIME:
+ case arrow::Type::INTERVAL_MONTHS:
+ case arrow::Type::LARGE_BINARY:
+ case arrow::Type::LARGE_LIST:
+ case arrow::Type::LARGE_STRING:
+ case arrow::Type::LIST:
+ case arrow::Type::MAP:
+ case arrow::Type::MAX_ID:
+ case arrow::Type::SPARSE_UNION:
+ case arrow::Type::STRUCT:
+ Y_ABORT("not implemented");
+ break;
+ }
+ return 0;
+}
+
+ui64 TShardedRecordBatch::GetMemorySize() const {
+ return NArrow::GetBatchMemorySize(RecordBatch);
+}
+
+std::vector<std::shared_ptr<arrow::RecordBatch>> TShardingSplitIndex::Apply(const std::shared_ptr<arrow::RecordBatch>& input) {
+ Y_ABORT_UNLESS(input);
+ Y_ABORT_UNLESS(input->num_rows() == RecordsCount);
+ auto permutation = BuildPermutation();
+ auto resultBatch = NArrow::TStatusValidator::GetValid(arrow::compute::Take(input, *permutation)).record_batch();
+ Y_ABORT_UNLESS(resultBatch->num_rows() == RecordsCount);
+ std::vector<std::shared_ptr<arrow::RecordBatch>> result;
+ ui64 startIndex = 0;
+ for (auto&& i : Remapping) {
+ result.emplace_back(resultBatch->Slice(startIndex, i.size()));
+ startIndex += i.size();
+ }
+ return result;
+}
+
+NKikimr::NArrow::TShardedRecordBatch TShardingSplitIndex::Apply(const ui32 shardsCount, const std::shared_ptr<arrow::RecordBatch>& input, const std::string& hashColumnName) {
+ if (!input) {
+ return TShardedRecordBatch();
+ }
+ if (shardsCount == 1) {
+ return TShardedRecordBatch(input);
+ }
+ auto hashColumn = input->GetColumnByName(hashColumnName);
+ if (!hashColumn) {
+ return TShardedRecordBatch(input);
+ }
+ std::optional<TShardingSplitIndex> splitter;
+ if (hashColumn->type()->id() == arrow::Type::UINT64) {
+ splitter = TShardingSplitIndex::Build<arrow::UInt64Array>(shardsCount, *hashColumn);
+ } else if (hashColumn->type()->id() == arrow::Type::UINT32) {
+ splitter = TShardingSplitIndex::Build<arrow::UInt32Array>(shardsCount, *hashColumn);
+ } else if (hashColumn->type()->id() == arrow::Type::INT64) {
+ splitter = TShardingSplitIndex::Build<arrow::Int64Array>(shardsCount, *hashColumn);
+ } else if (hashColumn->type()->id() == arrow::Type::INT32) {
+ splitter = TShardingSplitIndex::Build<arrow::Int32Array>(shardsCount, *hashColumn);
+ } else {
+ Y_ABORT_UNLESS(false);
+ }
+ return TShardedRecordBatch(input, splitter->Apply(input));
+}
+
+std::shared_ptr<arrow::UInt64Array> TShardingSplitIndex::BuildPermutation() const {
+ arrow::UInt64Builder builder;
+ Y_ABORT_UNLESS(builder.Reserve(RecordsCount).ok());
+
+ for (auto&& i : Remapping) {
+ for (auto&& idx : i) {
+ TStatusValidator::Validate(builder.Append(idx));
+ }
+ }
+
+ std::shared_ptr<arrow::UInt64Array> out;
+ Y_ABORT_UNLESS(builder.Finish(&out).ok());
+ return out;
+}
+
}
diff --git a/ydb/core/formats/arrow/permutations.h b/ydb/core/formats/arrow/permutations.h
index 3e380f7236a..918240dd837 100644
--- a/ydb/core/formats/arrow/permutations.h
+++ b/ydb/core/formats/arrow/permutations.h
@@ -5,6 +5,79 @@
namespace NKikimr::NArrow {
+class THashConstructor {
+public:
+ static bool BuildHashUI64(std::shared_ptr<arrow::RecordBatch>& batch, const std::vector<std::string>& fieldNames, const std::string& hashFieldName);
+
+ static size_t TypedHash(const std::vector<std::shared_ptr<arrow::Array>>& ar, const int pos);
+ static size_t TypedHash(const arrow::Array& ar, const int pos);
+};
+
+class TShardedRecordBatch {
+private:
+ YDB_READONLY_DEF(std::shared_ptr<arrow::RecordBatch>, RecordBatch);
+ YDB_READONLY_DEF(std::vector<std::shared_ptr<arrow::RecordBatch>>, SplittedByShards);
+public:
+ TShardedRecordBatch() = default;
+ TShardedRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch)
+ : RecordBatch(batch) {
+ SplittedByShards = {RecordBatch};
+ }
+
+ TShardedRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch, std::vector<std::shared_ptr<arrow::RecordBatch>>&& splittedByShards)
+ : RecordBatch(batch)
+ , SplittedByShards(std::move(splittedByShards))
+ {
+ }
+
+ ui64 GetMemorySize() const;
+
+ ui64 GetRecordsCount() const {
+ return RecordBatch->num_rows();
+ }
+};
+
+class TShardingSplitIndex {
+private:
+ ui32 ShardsCount = 0;
+ std::vector<std::vector<ui64>> Remapping;
+ ui32 RecordsCount = 0;
+
+ template <class TIntArrowArray>
+ void Initialize(const TIntArrowArray& arrowHashArray) {
+ Y_ABORT_UNLESS(ShardsCount);
+ Remapping.resize(ShardsCount);
+ for (ui64 i = 0; i < (ui64)arrowHashArray.length(); ++i) {
+ const auto v = arrowHashArray.GetView(i);
+ if (v < 0) {
+ Remapping[(-v) % ShardsCount].emplace_back(i);
+ } else {
+ Remapping[v % ShardsCount].emplace_back(i);
+ }
+ }
+ }
+
+ TShardingSplitIndex(const ui32 shardsCount, const arrow::Array& arrowHashArray)
+ : ShardsCount(shardsCount)
+ , RecordsCount(arrowHashArray.length()) {
+ }
+
+public:
+
+ template <class TArrayClass>
+ static TShardingSplitIndex Build(const ui32 shardsCount, const arrow::Array& arrowHashArray) {
+ TShardingSplitIndex result(shardsCount, arrowHashArray);
+ result.Initialize<TArrayClass>(static_cast<const TArrayClass&>(arrowHashArray));
+ return result;
+ }
+
+ std::shared_ptr<arrow::UInt64Array> BuildPermutation() const;
+
+ std::vector<std::shared_ptr<arrow::RecordBatch>> Apply(const std::shared_ptr<arrow::RecordBatch>& input);
+
+ static TShardedRecordBatch Apply(const ui32 shardsCount, const std::shared_ptr<arrow::RecordBatch>& input, const std::string& hashColumnName);
+};
+
std::shared_ptr<arrow::UInt64Array> MakePermutation(const int size, const bool reverse = false);
std::shared_ptr<arrow::UInt64Array> MakeFilterPermutation(const std::vector<ui64>& indexes);
std::shared_ptr<arrow::UInt64Array> MakeSortPermutation(const std::shared_ptr<arrow::RecordBatch>& batch,
diff --git a/ydb/core/formats/arrow/replace_key.h b/ydb/core/formats/arrow/replace_key.h
index 3b0dde3e64d..165cccdbf6d 100644
--- a/ydb/core/formats/arrow/replace_key.h
+++ b/ydb/core/formats/arrow/replace_key.h
@@ -64,10 +64,6 @@ public:
Y_ABORT_UNLESS(Size() > 0 && Position < (ui64)Column(0).length());
}
- size_t Hash() const {
- return TypedHash(Column(0), Position, Column(0).type_id());
- }
-
template<typename T>
bool operator == (const TReplaceKeyTemplate<T>& key) const {
Y_ABORT_UNLESS(Size() == key.Size());
@@ -223,75 +219,6 @@ private:
TArrayVecPtr Columns = nullptr;
ui64 Position = 0;
- static size_t TypedHash(const arrow::Array& ar, int pos, arrow::Type::type typeId) {
- switch (typeId) {
- case arrow::Type::NA:
- case arrow::Type::BOOL:
- break;
- case arrow::Type::UINT8:
- return THash<ui8>()(static_cast<const arrow::UInt8Array&>(ar).Value(pos));
- case arrow::Type::INT8:
- return THash<i8>()(static_cast<const arrow::Int8Array&>(ar).Value(pos));
- case arrow::Type::UINT16:
- return THash<ui16>()(static_cast<const arrow::UInt16Array&>(ar).Value(pos));
- case arrow::Type::INT16:
- return THash<i16>()(static_cast<const arrow::Int16Array&>(ar).Value(pos));
- case arrow::Type::UINT32:
- return THash<ui32>()(static_cast<const arrow::UInt32Array&>(ar).Value(pos));
- case arrow::Type::INT32:
- return THash<i32>()(static_cast<const arrow::Int32Array&>(ar).Value(pos));
- case arrow::Type::UINT64:
- return THash<ui64>()(static_cast<const arrow::UInt64Array&>(ar).Value(pos));
- case arrow::Type::INT64:
- return THash<i64>()(static_cast<const arrow::Int64Array&>(ar).Value(pos));
- case arrow::Type::HALF_FLOAT:
- break;
- case arrow::Type::FLOAT:
- return THash<float>()(static_cast<const arrow::FloatArray&>(ar).Value(pos));
- case arrow::Type::DOUBLE:
- return THash<double>()(static_cast<const arrow::DoubleArray&>(ar).Value(pos));
- case arrow::Type::STRING: {
- const auto& str = static_cast<const arrow::StringArray&>(ar).GetView(pos);
- return THash<std::string_view>()(std::string_view(str.data(), str.size()));
- }
- case arrow::Type::BINARY: {
- const auto& str = static_cast<const arrow::BinaryArray&>(ar).GetView(pos);
- return THash<std::string_view>()(std::string_view(str.data(), str.size()));
- }
- case arrow::Type::FIXED_SIZE_BINARY:
- case arrow::Type::DATE32:
- case arrow::Type::DATE64:
- break;
- case arrow::Type::TIMESTAMP:
- return THash<i64>()(static_cast<const arrow::TimestampArray&>(ar).Value(pos));
- case arrow::Type::TIME32:
- return THash<i32>()(static_cast<const arrow::Time32Array&>(ar).Value(pos));
- case arrow::Type::TIME64:
- return THash<i64>()(static_cast<const arrow::Time64Array&>(ar).Value(pos));
- case arrow::Type::DURATION:
- return THash<i64>()(static_cast<const arrow::DurationArray&>(ar).Value(pos));
- case arrow::Type::DECIMAL256:
- case arrow::Type::DECIMAL:
- case arrow::Type::DENSE_UNION:
- case arrow::Type::DICTIONARY:
- case arrow::Type::EXTENSION:
- case arrow::Type::FIXED_SIZE_LIST:
- case arrow::Type::INTERVAL_DAY_TIME:
- case arrow::Type::INTERVAL_MONTHS:
- case arrow::Type::LARGE_BINARY:
- case arrow::Type::LARGE_LIST:
- case arrow::Type::LARGE_STRING:
- case arrow::Type::LIST:
- case arrow::Type::MAP:
- case arrow::Type::MAX_ID:
- case arrow::Type::SPARSE_UNION:
- case arrow::Type::STRUCT:
- Y_ABORT("not implemented");
- break;
- }
- return 0;
- }
-
template <bool notNull>
static std::partial_ordering TypedCompare(const arrow::Array& lhs, int lpos, const arrow::Array& rhs, int rpos) {
arrow::Type::type typeId = lhs.type_id();
@@ -420,17 +347,3 @@ public:
}
-template<>
-struct THash<NKikimr::NArrow::TReplaceKey> {
- inline ui64 operator()(const NKikimr::NArrow::TReplaceKey& x) const noexcept {
- return x.Hash();
- }
-};
-
-template<>
-struct THash<NKikimr::NArrow::TRawReplaceKey> {
- inline ui64 operator()(const NKikimr::NArrow::TRawReplaceKey& x) const noexcept {
- return x.Hash();
- }
-};
-