diff options
author | ivanmorozov <[email protected]> | 2023-11-16 11:11:14 +0300 |
---|---|---|
committer | ivanmorozov <[email protected]> | 2023-11-16 11:42:23 +0300 |
commit | 0189677704e6b4969cc08a9a824ca1b7fa46701c (patch) | |
tree | 2d6ea1553806016f34685675462a898433e1e01e | |
parent | 0d32e31b28b3fd1e1ede9b32126cba1c8e5f18f2 (diff) |
KIKIMR-20087: add helpers for compute-sharding
-rw-r--r-- | ydb/core/formats/arrow/permutations.cpp | 183 | ||||
-rw-r--r-- | ydb/core/formats/arrow/permutations.h | 73 | ||||
-rw-r--r-- | ydb/core/formats/arrow/replace_key.h | 87 |
3 files changed, 256 insertions, 87 deletions
diff --git a/ydb/core/formats/arrow/permutations.cpp b/ydb/core/formats/arrow/permutations.cpp index 3abc3bb2cec..f00d621f56f 100644 --- a/ydb/core/formats/arrow/permutations.cpp +++ b/ydb/core/formats/arrow/permutations.cpp @@ -1,8 +1,10 @@ #include "arrow_helpers.h" #include "permutations.h" #include "replace_key.h" +#include "size_calcer.h" #include <ydb/core/formats/arrow/common/validation.h> #include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h> +#include <library/cpp/actors/core/log.h> namespace NKikimr::NArrow { @@ -170,4 +172,185 @@ std::shared_ptr<arrow::Array> CopyRecords(const std::shared_ptr<arrow::Array>& s return result; } +bool THashConstructor::BuildHashUI64(std::shared_ptr<arrow::RecordBatch>& batch, const std::vector<std::string>& fieldNames, const std::string& hashFieldName) { + if (fieldNames.size() == 0) { + return false; + } + Y_ABORT_UNLESS(!batch->GetColumnByName(hashFieldName)); + if (fieldNames.size() == 1) { + auto column = batch->GetColumnByName(fieldNames.front()); + if (!column) { + return false; + } + Y_ABORT_UNLESS(column); + if (column->type()->id() == arrow::Type::UINT64 || column->type()->id() == arrow::Type::UINT32 || column->type()->id() == arrow::Type::INT64 || column->type()->id() == arrow::Type::INT32) { + batch = TStatusValidator::GetValid(batch->AddColumn(batch->num_columns(), hashFieldName, column)); + return true; + } + } + auto builder = NArrow::MakeBuilder(std::make_shared<arrow::Field>(hashFieldName, arrow::TypeTraits<arrow::UInt64Type>::type_singleton())); + { + auto& intBuilder = static_cast<arrow::UInt64Builder&>(*builder); + TStatusValidator::Validate(intBuilder.Reserve(batch->num_rows())); + std::vector<std::shared_ptr<arrow::Array>> columns; + for (auto&& i : fieldNames) { + auto column = batch->GetColumnByName(i); +// AFL_VERIFY(column)("column_name", i)("all_columns", JoinSeq(",", fieldNames)); + if (column) { + columns.emplace_back(column); + } + } + if (columns.empty()) { + return false; + } + for (i64 i = 0; i < batch->num_rows(); ++i) { + intBuilder.UnsafeAppend(TypedHash(columns, i)); + } + } + batch = TStatusValidator::GetValid(batch->AddColumn(batch->num_columns(), hashFieldName, NArrow::TStatusValidator::GetValid(builder->Finish()))); + return true; +} + +size_t THashConstructor::TypedHash(const std::vector<std::shared_ptr<arrow::Array>>& ar, const int pos) { + size_t result = 0; + for (auto&& i : ar) { + result = CombineHashes(result, TypedHash(*i, pos)); + } + return result; +} + +size_t THashConstructor::TypedHash(const arrow::Array& ar, const int pos) { + switch (ar.type_id()) { + case arrow::Type::BOOL: + return (size_t)(static_cast<const arrow::BooleanArray&>(ar).Value(pos)); + case arrow::Type::UINT8: + return THash<ui8>()(static_cast<const arrow::UInt8Array&>(ar).Value(pos)); + case arrow::Type::INT8: + return THash<i8>()(static_cast<const arrow::Int8Array&>(ar).Value(pos)); + case arrow::Type::UINT16: + return THash<ui16>()(static_cast<const arrow::UInt16Array&>(ar).Value(pos)); + case arrow::Type::INT16: + return THash<i16>()(static_cast<const arrow::Int16Array&>(ar).Value(pos)); + case arrow::Type::UINT32: + return THash<ui32>()(static_cast<const arrow::UInt32Array&>(ar).Value(pos)); + case arrow::Type::INT32: + return THash<i32>()(static_cast<const arrow::Int32Array&>(ar).Value(pos)); + case arrow::Type::UINT64: + return THash<ui64>()(static_cast<const arrow::UInt64Array&>(ar).Value(pos)); + case arrow::Type::INT64: + return THash<i64>()(static_cast<const arrow::Int64Array&>(ar).Value(pos)); + case arrow::Type::HALF_FLOAT: + break; + case arrow::Type::FLOAT: + return THash<float>()(static_cast<const arrow::FloatArray&>(ar).Value(pos)); + case arrow::Type::DOUBLE: + return THash<double>()(static_cast<const arrow::DoubleArray&>(ar).Value(pos)); + case arrow::Type::STRING: + { + const auto& str = static_cast<const arrow::StringArray&>(ar).GetView(pos); + return CityHash64(str.data(), str.size()); + } + case arrow::Type::BINARY: + { + const auto& str = static_cast<const arrow::BinaryArray&>(ar).GetView(pos); + return CityHash64(str.data(), str.size()); + } + case arrow::Type::FIXED_SIZE_BINARY: + { + const auto& str = static_cast<const arrow::FixedSizeBinaryArray&>(ar).GetView(pos); + return CityHash64(str.data(), str.size()); + } + case arrow::Type::TIMESTAMP: + return THash<i64>()(static_cast<const arrow::TimestampArray&>(ar).Value(pos)); + case arrow::Type::TIME32: + return THash<i32>()(static_cast<const arrow::Time32Array&>(ar).Value(pos)); + case arrow::Type::TIME64: + return THash<i64>()(static_cast<const arrow::Time64Array&>(ar).Value(pos)); + case arrow::Type::DURATION: + return THash<i64>()(static_cast<const arrow::DurationArray&>(ar).Value(pos)); + case arrow::Type::DATE32: + case arrow::Type::DATE64: + case arrow::Type::NA: + case arrow::Type::DECIMAL256: + case arrow::Type::DECIMAL: + case arrow::Type::DENSE_UNION: + case arrow::Type::DICTIONARY: + case arrow::Type::EXTENSION: + case arrow::Type::FIXED_SIZE_LIST: + case arrow::Type::INTERVAL_DAY_TIME: + case arrow::Type::INTERVAL_MONTHS: + case arrow::Type::LARGE_BINARY: + case arrow::Type::LARGE_LIST: + case arrow::Type::LARGE_STRING: + case arrow::Type::LIST: + case arrow::Type::MAP: + case arrow::Type::MAX_ID: + case arrow::Type::SPARSE_UNION: + case arrow::Type::STRUCT: + Y_ABORT("not implemented"); + break; + } + return 0; +} + +ui64 TShardedRecordBatch::GetMemorySize() const { + return NArrow::GetBatchMemorySize(RecordBatch); +} + +std::vector<std::shared_ptr<arrow::RecordBatch>> TShardingSplitIndex::Apply(const std::shared_ptr<arrow::RecordBatch>& input) { + Y_ABORT_UNLESS(input); + Y_ABORT_UNLESS(input->num_rows() == RecordsCount); + auto permutation = BuildPermutation(); + auto resultBatch = NArrow::TStatusValidator::GetValid(arrow::compute::Take(input, *permutation)).record_batch(); + Y_ABORT_UNLESS(resultBatch->num_rows() == RecordsCount); + std::vector<std::shared_ptr<arrow::RecordBatch>> result; + ui64 startIndex = 0; + for (auto&& i : Remapping) { + result.emplace_back(resultBatch->Slice(startIndex, i.size())); + startIndex += i.size(); + } + return result; +} + +NKikimr::NArrow::TShardedRecordBatch TShardingSplitIndex::Apply(const ui32 shardsCount, const std::shared_ptr<arrow::RecordBatch>& input, const std::string& hashColumnName) { + if (!input) { + return TShardedRecordBatch(); + } + if (shardsCount == 1) { + return TShardedRecordBatch(input); + } + auto hashColumn = input->GetColumnByName(hashColumnName); + if (!hashColumn) { + return TShardedRecordBatch(input); + } + std::optional<TShardingSplitIndex> splitter; + if (hashColumn->type()->id() == arrow::Type::UINT64) { + splitter = TShardingSplitIndex::Build<arrow::UInt64Array>(shardsCount, *hashColumn); + } else if (hashColumn->type()->id() == arrow::Type::UINT32) { + splitter = TShardingSplitIndex::Build<arrow::UInt32Array>(shardsCount, *hashColumn); + } else if (hashColumn->type()->id() == arrow::Type::INT64) { + splitter = TShardingSplitIndex::Build<arrow::Int64Array>(shardsCount, *hashColumn); + } else if (hashColumn->type()->id() == arrow::Type::INT32) { + splitter = TShardingSplitIndex::Build<arrow::Int32Array>(shardsCount, *hashColumn); + } else { + Y_ABORT_UNLESS(false); + } + return TShardedRecordBatch(input, splitter->Apply(input)); +} + +std::shared_ptr<arrow::UInt64Array> TShardingSplitIndex::BuildPermutation() const { + arrow::UInt64Builder builder; + Y_ABORT_UNLESS(builder.Reserve(RecordsCount).ok()); + + for (auto&& i : Remapping) { + for (auto&& idx : i) { + TStatusValidator::Validate(builder.Append(idx)); + } + } + + std::shared_ptr<arrow::UInt64Array> out; + Y_ABORT_UNLESS(builder.Finish(&out).ok()); + return out; +} + } diff --git a/ydb/core/formats/arrow/permutations.h b/ydb/core/formats/arrow/permutations.h index 3e380f7236a..918240dd837 100644 --- a/ydb/core/formats/arrow/permutations.h +++ b/ydb/core/formats/arrow/permutations.h @@ -5,6 +5,79 @@ namespace NKikimr::NArrow { +class THashConstructor { +public: + static bool BuildHashUI64(std::shared_ptr<arrow::RecordBatch>& batch, const std::vector<std::string>& fieldNames, const std::string& hashFieldName); + + static size_t TypedHash(const std::vector<std::shared_ptr<arrow::Array>>& ar, const int pos); + static size_t TypedHash(const arrow::Array& ar, const int pos); +}; + +class TShardedRecordBatch { +private: + YDB_READONLY_DEF(std::shared_ptr<arrow::RecordBatch>, RecordBatch); + YDB_READONLY_DEF(std::vector<std::shared_ptr<arrow::RecordBatch>>, SplittedByShards); +public: + TShardedRecordBatch() = default; + TShardedRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch) + : RecordBatch(batch) { + SplittedByShards = {RecordBatch}; + } + + TShardedRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch, std::vector<std::shared_ptr<arrow::RecordBatch>>&& splittedByShards) + : RecordBatch(batch) + , SplittedByShards(std::move(splittedByShards)) + { + } + + ui64 GetMemorySize() const; + + ui64 GetRecordsCount() const { + return RecordBatch->num_rows(); + } +}; + +class TShardingSplitIndex { +private: + ui32 ShardsCount = 0; + std::vector<std::vector<ui64>> Remapping; + ui32 RecordsCount = 0; + + template <class TIntArrowArray> + void Initialize(const TIntArrowArray& arrowHashArray) { + Y_ABORT_UNLESS(ShardsCount); + Remapping.resize(ShardsCount); + for (ui64 i = 0; i < (ui64)arrowHashArray.length(); ++i) { + const auto v = arrowHashArray.GetView(i); + if (v < 0) { + Remapping[(-v) % ShardsCount].emplace_back(i); + } else { + Remapping[v % ShardsCount].emplace_back(i); + } + } + } + + TShardingSplitIndex(const ui32 shardsCount, const arrow::Array& arrowHashArray) + : ShardsCount(shardsCount) + , RecordsCount(arrowHashArray.length()) { + } + +public: + + template <class TArrayClass> + static TShardingSplitIndex Build(const ui32 shardsCount, const arrow::Array& arrowHashArray) { + TShardingSplitIndex result(shardsCount, arrowHashArray); + result.Initialize<TArrayClass>(static_cast<const TArrayClass&>(arrowHashArray)); + return result; + } + + std::shared_ptr<arrow::UInt64Array> BuildPermutation() const; + + std::vector<std::shared_ptr<arrow::RecordBatch>> Apply(const std::shared_ptr<arrow::RecordBatch>& input); + + static TShardedRecordBatch Apply(const ui32 shardsCount, const std::shared_ptr<arrow::RecordBatch>& input, const std::string& hashColumnName); +}; + std::shared_ptr<arrow::UInt64Array> MakePermutation(const int size, const bool reverse = false); std::shared_ptr<arrow::UInt64Array> MakeFilterPermutation(const std::vector<ui64>& indexes); std::shared_ptr<arrow::UInt64Array> MakeSortPermutation(const std::shared_ptr<arrow::RecordBatch>& batch, diff --git a/ydb/core/formats/arrow/replace_key.h b/ydb/core/formats/arrow/replace_key.h index 3b0dde3e64d..165cccdbf6d 100644 --- a/ydb/core/formats/arrow/replace_key.h +++ b/ydb/core/formats/arrow/replace_key.h @@ -64,10 +64,6 @@ public: Y_ABORT_UNLESS(Size() > 0 && Position < (ui64)Column(0).length()); } - size_t Hash() const { - return TypedHash(Column(0), Position, Column(0).type_id()); - } - template<typename T> bool operator == (const TReplaceKeyTemplate<T>& key) const { Y_ABORT_UNLESS(Size() == key.Size()); @@ -223,75 +219,6 @@ private: TArrayVecPtr Columns = nullptr; ui64 Position = 0; - static size_t TypedHash(const arrow::Array& ar, int pos, arrow::Type::type typeId) { - switch (typeId) { - case arrow::Type::NA: - case arrow::Type::BOOL: - break; - case arrow::Type::UINT8: - return THash<ui8>()(static_cast<const arrow::UInt8Array&>(ar).Value(pos)); - case arrow::Type::INT8: - return THash<i8>()(static_cast<const arrow::Int8Array&>(ar).Value(pos)); - case arrow::Type::UINT16: - return THash<ui16>()(static_cast<const arrow::UInt16Array&>(ar).Value(pos)); - case arrow::Type::INT16: - return THash<i16>()(static_cast<const arrow::Int16Array&>(ar).Value(pos)); - case arrow::Type::UINT32: - return THash<ui32>()(static_cast<const arrow::UInt32Array&>(ar).Value(pos)); - case arrow::Type::INT32: - return THash<i32>()(static_cast<const arrow::Int32Array&>(ar).Value(pos)); - case arrow::Type::UINT64: - return THash<ui64>()(static_cast<const arrow::UInt64Array&>(ar).Value(pos)); - case arrow::Type::INT64: - return THash<i64>()(static_cast<const arrow::Int64Array&>(ar).Value(pos)); - case arrow::Type::HALF_FLOAT: - break; - case arrow::Type::FLOAT: - return THash<float>()(static_cast<const arrow::FloatArray&>(ar).Value(pos)); - case arrow::Type::DOUBLE: - return THash<double>()(static_cast<const arrow::DoubleArray&>(ar).Value(pos)); - case arrow::Type::STRING: { - const auto& str = static_cast<const arrow::StringArray&>(ar).GetView(pos); - return THash<std::string_view>()(std::string_view(str.data(), str.size())); - } - case arrow::Type::BINARY: { - const auto& str = static_cast<const arrow::BinaryArray&>(ar).GetView(pos); - return THash<std::string_view>()(std::string_view(str.data(), str.size())); - } - case arrow::Type::FIXED_SIZE_BINARY: - case arrow::Type::DATE32: - case arrow::Type::DATE64: - break; - case arrow::Type::TIMESTAMP: - return THash<i64>()(static_cast<const arrow::TimestampArray&>(ar).Value(pos)); - case arrow::Type::TIME32: - return THash<i32>()(static_cast<const arrow::Time32Array&>(ar).Value(pos)); - case arrow::Type::TIME64: - return THash<i64>()(static_cast<const arrow::Time64Array&>(ar).Value(pos)); - case arrow::Type::DURATION: - return THash<i64>()(static_cast<const arrow::DurationArray&>(ar).Value(pos)); - case arrow::Type::DECIMAL256: - case arrow::Type::DECIMAL: - case arrow::Type::DENSE_UNION: - case arrow::Type::DICTIONARY: - case arrow::Type::EXTENSION: - case arrow::Type::FIXED_SIZE_LIST: - case arrow::Type::INTERVAL_DAY_TIME: - case arrow::Type::INTERVAL_MONTHS: - case arrow::Type::LARGE_BINARY: - case arrow::Type::LARGE_LIST: - case arrow::Type::LARGE_STRING: - case arrow::Type::LIST: - case arrow::Type::MAP: - case arrow::Type::MAX_ID: - case arrow::Type::SPARSE_UNION: - case arrow::Type::STRUCT: - Y_ABORT("not implemented"); - break; - } - return 0; - } - template <bool notNull> static std::partial_ordering TypedCompare(const arrow::Array& lhs, int lpos, const arrow::Array& rhs, int rpos) { arrow::Type::type typeId = lhs.type_id(); @@ -420,17 +347,3 @@ public: } -template<> -struct THash<NKikimr::NArrow::TReplaceKey> { - inline ui64 operator()(const NKikimr::NArrow::TReplaceKey& x) const noexcept { - return x.Hash(); - } -}; - -template<> -struct THash<NKikimr::NArrow::TRawReplaceKey> { - inline ui64 operator()(const NKikimr::NArrow::TRawReplaceKey& x) const noexcept { - return x.Hash(); - } -}; - |