aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author ivanmorozov <ivanmorozov@yandex-team.com> 2023-06-29 20:59:25 +0300
committer ivanmorozov <ivanmorozov@yandex-team.com> 2023-06-29 20:59:25 +0300
commit 68d778d6318cc6a38942ba13a0cf21faea919b77 (patch)
tree cf3f59b5e437f077a0deae4bc153ae63044de25d
parent d1915c974c742501871fb6a49f12aeff142a4068 (diff)
download ydb-68d778d6318cc6a38942ba13a0cf21faea919b77.tar.gz
fix dictionary size calculation
-rw-r--r-- ydb/core/formats/arrow/dictionary/conversion.cpp 8
-rw-r--r-- ydb/core/formats/arrow/dictionary/conversion.h 2
-rw-r--r-- ydb/core/formats/arrow/size_calcer.cpp 5
-rw-r--r-- ydb/core/formats/arrow/ut/ut_size_calcer.cpp 6
4 files changed, 18 insertions, 3 deletions
diff --git a/ydb/core/formats/arrow/dictionary/conversion.cpp b/ydb/core/formats/arrow/dictionary/conversion.cpp
index 6250b8c934..bf7f5c0288 100644
--- a/ydb/core/formats/arrow/dictionary/conversion.cpp
+++ b/ydb/core/formats/arrow/dictionary/conversion.cpp
@@ -2,6 +2,7 @@
#include <ydb/core/formats/arrow/switch/switch_type.h>
#include <ydb/core/formats/arrow/simple_builder/filler.h>
#include <ydb/core/formats/arrow/simple_builder/array.h>
+#include <ydb/core/formats/arrow/size_calcer.h>
namespace NKikimr::NArrow {
@@ -130,4 +131,11 @@ bool IsDictionableArray(const std::shared_ptr<arrow::Array>& data) {
return result;
}
+ui64 GetDictionarySize(const std::shared_ptr<arrow::DictionaryArray>& data) { // Data size of a dictionary-encoded array: its dictionary values plus its indices, each measured with GetArrayDataSize.
+ if (!data) { // A null pointer is reported as size 0 instead of being dereferenced.
+ return 0;
+ }
+ return GetArrayDataSize(data->dictionary()) + GetArrayDataSize(data->indices()); // NOTE(review): GetArrayDataSize calls column->type() on its argument without a null check; assumes dictionary() and indices() are non-null — confirm.
+}
+
}
diff --git a/ydb/core/formats/arrow/dictionary/conversion.h b/ydb/core/formats/arrow/dictionary/conversion.h
index 787fd1050c..ee044bfd51 100644
--- a/ydb/core/formats/arrow/dictionary/conversion.h
+++ b/ydb/core/formats/arrow/dictionary/conversion.h
@@ -2,10 +2,12 @@
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_dict.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_base.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h>
+#include <util/system/types.h>
namespace NKikimr::NArrow {
bool IsDictionableArray(const std::shared_ptr<arrow::Array>& data);
+ui64 GetDictionarySize(const std::shared_ptr<arrow::DictionaryArray>& data);
std::shared_ptr<arrow::DictionaryArray> ArrayToDictionary(const std::shared_ptr<arrow::Array>& data);
std::shared_ptr<arrow::RecordBatch> ArrayToDictionary(const std::shared_ptr<arrow::RecordBatch>& data);
std::shared_ptr<arrow::Array> DictionaryToArray(const std::shared_ptr<arrow::DictionaryArray>& data);
diff --git a/ydb/core/formats/arrow/size_calcer.cpp b/ydb/core/formats/arrow/size_calcer.cpp
index 41d9634595..c09fc15c89 100644
--- a/ydb/core/formats/arrow/size_calcer.cpp
+++ b/ydb/core/formats/arrow/size_calcer.cpp
@@ -1,6 +1,7 @@
#include "size_calcer.h"
#include "switch_type.h"
#include "arrow_helpers.h"
+#include "dictionary/conversion.h"
#include <contrib/libs/apache/arrow/cpp/src/arrow/type.h>
#include <util/system/yassert.h>
@@ -153,6 +154,10 @@ ui64 GetArrayDataSizeImpl<arrow::Decimal128Type>(const std::shared_ptr<arrow::Ar
ui64 GetArrayDataSize(const std::shared_ptr<arrow::Array>& column) {
auto type = column->type();
+ if (type->id() == arrow::Type::DICTIONARY) {
+ auto dictArray = static_pointer_cast<arrow::DictionaryArray>(column);
+ return GetDictionarySize(dictArray);
+ }
ui64 bytes = 0;
bool success = SwitchTypeWithNull(type->id(), [&]<typename TType>(TTypeWrapper<TType> typeHolder) {
Y_UNUSED(typeHolder);
diff --git a/ydb/core/formats/arrow/ut/ut_size_calcer.cpp b/ydb/core/formats/arrow/ut/ut_size_calcer.cpp
index 1b9835d801..b7c959e325 100644
--- a/ydb/core/formats/arrow/ut/ut_size_calcer.cpp
+++ b/ydb/core/formats/arrow/ut/ut_size_calcer.cpp
@@ -13,7 +13,7 @@ Y_UNIT_TEST_SUITE(SizeCalcer) {
Y_UNIT_TEST(SimpleStrings) {
NConstruction::IArrayBuilder::TPtr column = std::make_shared<NConstruction::TSimpleArrayConstructor<NConstruction::TStringPoolFiller>>(
- "field", NConstruction::TStringPoolFiller(1024, 512));
+ "field", NConstruction::TStringPoolFiller(8, 512));
std::shared_ptr<arrow::RecordBatch> batch = NConstruction::TRecordBatchConstructor({ column }).BuildBatch(2048);
Cerr << GetBatchDataSize(batch) << Endl;
UNIT_ASSERT(GetBatchDataSize(batch) == 2048 * 512);
@@ -21,10 +21,10 @@ Y_UNIT_TEST_SUITE(SizeCalcer) {
Y_UNIT_TEST(DictionaryStrings) { // Checks GetBatchDataSize on a single column built by TDictionaryArrayConstructor over a TStringPoolFiller.
NConstruction::IArrayBuilder::TPtr column = std::make_shared<NConstruction::TDictionaryArrayConstructor<NConstruction::TStringPoolFiller>>(
- "field", NConstruction::TStringPoolFiller(1024, 512));
+ "field", NConstruction::TStringPoolFiller(8, 512));
std::shared_ptr<arrow::RecordBatch> batch = NConstruction::TRecordBatchConstructor({ column }).BuildBatch(2048); // The batch is built with 2048 rows.
Cerr << GetBatchDataSize(batch) << Endl; // Prints the computed size to stderr for diagnostics.
- UNIT_ASSERT(GetBatchDataSize(batch) == 2048 * 512);
+ UNIT_ASSERT(GetBatchDataSize(batch) == 8 * 512 + 2048); // NOTE(review): presumably 8 pooled strings of 512 bytes in the dictionary plus 1 byte per index for 2048 rows — confirm against TStringPoolFiller and the index type used by TDictionaryArrayConstructor.
}
Y_UNIT_TEST(SimpleInt64) {