aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorkungurtsev <kungasc@ydb.tech>2025-04-08 19:25:21 +0200
committerGitHub <noreply@github.com>2025-04-08 19:25:21 +0200
commit16688838aecc0561c93778b50e9ea199673fc004 (patch)
treed60a5993c8b093b36aace6e90794bd8304e6c3ab
parent0b84ae0cbe15d2e5bc5fb8a7c2fcff8bcfb2ee2f (diff)
downloadydb-16688838aecc0561c93778b50e9ea199673fc004.tar.gz
Vector Index Local KMeans with one scan (#16909)
-rw-r--r--ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp267
-rw-r--r--ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp13
-rw-r--r--ydb/core/tx/datashard/local_kmeans.cpp633
-rw-r--r--ydb/core/tx/datashard/prefix_kmeans.cpp30
-rw-r--r--ydb/core/tx/datashard/reshuffle_kmeans.cpp4
-rw-r--r--ydb/core/tx/datashard/sample_k.cpp3
6 files changed, 630 insertions, 320 deletions
diff --git a/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp b/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp
index 3d081ef34ad..931af10560d 100644
--- a/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp
+++ b/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp
@@ -90,9 +90,9 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
}
static std::tuple<TString, TString> DoLocalKMeans(
- Tests::TServer::TPtr server, TActorId sender, NTableIndex::TClusterId parent, ui64 seed, ui64 k,
+ Tests::TServer::TPtr server, TActorId sender, NTableIndex::TClusterId parentFrom, NTableIndex::TClusterId parentTo, ui64 seed, ui64 k,
NKikimrTxDataShard::EKMeansState upload, VectorIndexSettings::VectorType type,
- VectorIndexSettings::Metric metric)
+ VectorIndexSettings::Metric metric, ui32 maxBatchRows = 50000)
{
auto id = sId.fetch_add(1, std::memory_order_relaxed);
auto& runtime = *server->GetRuntime();
@@ -131,15 +131,17 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
rec.SetNeedsRounds(300);
- rec.SetParentFrom(parent);
- rec.SetParentTo(parent);
- rec.SetChild(parent + 1);
+ rec.SetParentFrom(parentFrom);
+ rec.SetParentTo(parentTo);
+ rec.SetChild(parentTo + 1);
rec.SetEmbeddingColumn("embedding");
rec.AddDataColumns("data");
rec.SetLevelName(kLevelTable);
rec.SetPostingName(kPostingTable);
+
+ rec.MutableScanSettings()->SetMaxBatchRows(maxBatchRows);
};
fill(ev1);
fill(ev2);
@@ -158,6 +160,10 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
auto level = ReadShardedTable(server, kLevelTable);
auto posting = ReadShardedTable(server, kPostingTable);
+ Cerr << "Level:" << Endl;
+ Cerr << level << Endl;
+ Cerr << "Posting:" << Endl;
+ Cerr << posting << Endl;
return {std::move(level), std::move(posting)};
}
@@ -223,6 +229,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
auto sender = runtime.AllocateEdgeActor();
runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG);
+ runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE);
InitRoot(server, sender);
@@ -310,6 +317,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
auto sender = runtime.AllocateEdgeActor();
runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG);
+ runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE);
InitRoot(server, sender);
@@ -346,7 +354,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
seed = 0;
for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) {
- auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_POSTING,
VectorIndexSettings::VECTOR_TYPE_UINT8, distance);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = mm\3\n"
@@ -361,7 +369,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
seed = 111;
for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) {
- auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_POSTING,
VectorIndexSettings::VECTOR_TYPE_UINT8, distance);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = 11\3\n"
@@ -377,7 +385,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
for (auto similarity : {VectorIndexSettings::SIMILARITY_INNER_PRODUCT, VectorIndexSettings::SIMILARITY_COSINE,
VectorIndexSettings::DISTANCE_COSINE})
{
- auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_POSTING,
VectorIndexSettings::VECTOR_TYPE_UINT8, similarity);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = II\3\n");
@@ -400,6 +408,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
auto sender = runtime.AllocateEdgeActor();
runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG);
+ runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE);
InitRoot(server, sender);
@@ -436,7 +445,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
seed = 0;
for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) {
- auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD,
VectorIndexSettings::VECTOR_TYPE_UINT8, distance);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = mm\3\n"
@@ -451,7 +460,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
seed = 111;
for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) {
- auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD,
VectorIndexSettings::VECTOR_TYPE_UINT8, distance);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = 11\3\n"
@@ -467,7 +476,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
for (auto similarity : {VectorIndexSettings::SIMILARITY_INNER_PRODUCT, VectorIndexSettings::SIMILARITY_COSINE,
VectorIndexSettings::DISTANCE_COSINE})
{
- auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD,
VectorIndexSettings::VECTOR_TYPE_UINT8, similarity);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = II\3\n");
@@ -490,6 +499,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
auto sender = runtime.AllocateEdgeActor();
runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG);
+ runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE);
InitRoot(server, sender);
@@ -528,7 +538,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
seed = 0;
for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) {
- auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_POSTING,
VectorIndexSettings::VECTOR_TYPE_UINT8, distance);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = mm\3\n"
@@ -543,7 +553,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
seed = 111;
for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) {
- auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_POSTING,
VectorIndexSettings::VECTOR_TYPE_UINT8, distance);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = 11\3\n"
@@ -559,7 +569,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
for (auto similarity : {VectorIndexSettings::SIMILARITY_INNER_PRODUCT, VectorIndexSettings::SIMILARITY_COSINE,
VectorIndexSettings::DISTANCE_COSINE})
{
- auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_POSTING,
VectorIndexSettings::VECTOR_TYPE_UINT8, similarity);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = II\3\n");
@@ -582,6 +592,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
auto sender = runtime.AllocateEdgeActor();
runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG);
+ runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE);
InitRoot(server, sender);
@@ -620,7 +631,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
seed = 0;
for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) {
- auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD,
VectorIndexSettings::VECTOR_TYPE_UINT8, distance);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = mm\3\n"
@@ -635,7 +646,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
seed = 111;
for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) {
- auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD,
VectorIndexSettings::VECTOR_TYPE_UINT8, distance);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = 11\3\n"
@@ -651,7 +662,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
for (auto similarity : {VectorIndexSettings::SIMILARITY_INNER_PRODUCT, VectorIndexSettings::SIMILARITY_COSINE,
VectorIndexSettings::DISTANCE_COSINE})
{
- auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k,
+ auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k,
NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD,
VectorIndexSettings::VECTOR_TYPE_UINT8, similarity);
UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = II\3\n");
@@ -663,6 +674,228 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
recreate();
}
}
+
+ Y_UNIT_TEST (BuildToBuild_Ranges) {
+ TPortManager pm;
+ TServerSettings serverSettings(pm.GetPort(2134));
+ serverSettings.SetDomainName("Root");
+
+ Tests::TServer::TPtr server = new TServer(serverSettings);
+ auto& runtime = *server->GetRuntime();
+ auto sender = runtime.AllocateEdgeActor();
+
+ runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG);
+ runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE);
+
+ InitRoot(server, sender);
+
+ TShardedTableOptions options;
+ options.EnableOutOfOrder(true); // TODO(mbkkt) what is it?
+ options.Shards(1);
+
+ CreateBuildTable(server, sender, options, "table-main");
+ // Upsert some initial values
+ ExecSQL(server, sender,
+ R"(
+ UPSERT INTO `/Root/table-main`
+ (__ydb_parent, key, embedding, data)
+ VALUES )"
+ "(39, 1, \"\x30\x30\3\", \"one\"),"
+ "(39, 2, \"\x32\x32\3\", \"two\"),"
+ "(40, 1, \"\x30\x30\3\", \"one\"),"
+ "(40, 2, \"\x31\x31\3\", \"two\"),"
+ "(40, 3, \"\x32\x32\3\", \"three\"),"
+ "(40, 4, \"\x65\x65\3\", \"four\"),"
+ "(40, 5, \"\x75\x75\3\", \"five\"),"
+ "(41, 5, \"\x75\x75\3\", \"five2\"),"
+ "(41, 6, \"\x76\x76\3\", \"six\");");
+
+ auto create = [&] {
+ CreateLevelTable(server, sender, options);
+ CreateBuildTable(server, sender, options, "table-posting");
+ };
+ create();
+ auto recreate = [&] {
+ DropTable(server, sender, "table-level");
+ DropTable(server, sender, "table-posting");
+ create();
+ };
+
+ { // ParentFrom = 39 ParentTo = 39
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 39, 39, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level,
+ "__ydb_parent = 39, __ydb_id = 40, __ydb_centroid = 00\3\n"
+ "__ydb_parent = 39, __ydb_id = 41, __ydb_centroid = 22\3\n");
+ UNIT_ASSERT_VALUES_EQUAL(posting,
+ "__ydb_parent = 40, key = 1, embedding = 00\3, data = one\n"
+ "__ydb_parent = 41, key = 2, embedding = 22\3, data = two\n");
+ recreate();
+ }
+ }
+
+ { // ParentFrom = 40 ParentTo = 40
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 40, 40, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level,
+ "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = 11\3\n"
+ "__ydb_parent = 40, __ydb_id = 42, __ydb_centroid = mm\3\n");
+ UNIT_ASSERT_VALUES_EQUAL(posting,
+ "__ydb_parent = 41, key = 1, embedding = \x30\x30\3, data = one\n"
+ "__ydb_parent = 41, key = 2, embedding = \x31\x31\3, data = two\n"
+ "__ydb_parent = 41, key = 3, embedding = \x32\x32\3, data = three\n"
+ "__ydb_parent = 42, key = 4, embedding = \x65\x65\3, data = four\n"
+ "__ydb_parent = 42, key = 5, embedding = \x75\x75\3, data = five\n");
+ recreate();
+ }
+ }
+
+ { // ParentFrom = 41 ParentTo = 41
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 41, 41, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level,
+ "__ydb_parent = 41, __ydb_id = 42, __ydb_centroid = uu\3\n"
+ "__ydb_parent = 41, __ydb_id = 43, __ydb_centroid = vv\3\n");
+ UNIT_ASSERT_VALUES_EQUAL(posting,
+ "__ydb_parent = 42, key = 5, embedding = uu\3, data = five2\n"
+ "__ydb_parent = 43, key = 6, embedding = vv\3, data = six\n");
+ recreate();
+ }
+ }
+
+ { // ParentFrom = 39 ParentTo = 40
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 39, 40, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level,
+ "__ydb_parent = 39, __ydb_id = 41, __ydb_centroid = 00\3\n"
+ "__ydb_parent = 39, __ydb_id = 42, __ydb_centroid = 22\3\n"
+ "__ydb_parent = 40, __ydb_id = 43, __ydb_centroid = 11\3\n"
+ "__ydb_parent = 40, __ydb_id = 44, __ydb_centroid = mm\3\n");
+ UNIT_ASSERT_VALUES_EQUAL(posting,
+ "__ydb_parent = 41, key = 1, embedding = 00\3, data = one\n"
+ "__ydb_parent = 42, key = 2, embedding = 22\3, data = two\n"
+ "__ydb_parent = 43, key = 1, embedding = \x30\x30\3, data = one\n"
+ "__ydb_parent = 43, key = 2, embedding = \x31\x31\3, data = two\n"
+ "__ydb_parent = 43, key = 3, embedding = \x32\x32\3, data = three\n"
+ "__ydb_parent = 44, key = 4, embedding = \x65\x65\3, data = four\n"
+ "__ydb_parent = 44, key = 5, embedding = \x75\x75\3, data = five\n");
+ recreate();
+ }
+ }
+
+ { // ParentFrom = 40 ParentTo = 41
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 40, 41, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level,
+ "__ydb_parent = 40, __ydb_id = 42, __ydb_centroid = 11\3\n"
+ "__ydb_parent = 40, __ydb_id = 43, __ydb_centroid = mm\3\n"
+ "__ydb_parent = 41, __ydb_id = 44, __ydb_centroid = uu\3\n"
+ "__ydb_parent = 41, __ydb_id = 45, __ydb_centroid = vv\3\n");
+ UNIT_ASSERT_VALUES_EQUAL(posting,
+ "__ydb_parent = 42, key = 1, embedding = \x30\x30\3, data = one\n"
+ "__ydb_parent = 42, key = 2, embedding = \x31\x31\3, data = two\n"
+ "__ydb_parent = 42, key = 3, embedding = \x32\x32\3, data = three\n"
+ "__ydb_parent = 43, key = 4, embedding = \x65\x65\3, data = four\n"
+ "__ydb_parent = 43, key = 5, embedding = \x75\x75\3, data = five\n"
+ "__ydb_parent = 44, key = 5, embedding = uu\3, data = five2\n"
+ "__ydb_parent = 45, key = 6, embedding = vv\3, data = six\n");
+ recreate();
+ }
+ }
+
+ { // ParentFrom = 39 ParentTo = 41
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 39, 41, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level,
+ "__ydb_parent = 39, __ydb_id = 42, __ydb_centroid = 00\3\n"
+ "__ydb_parent = 39, __ydb_id = 43, __ydb_centroid = 22\3\n"
+ "__ydb_parent = 40, __ydb_id = 44, __ydb_centroid = 11\3\n"
+ "__ydb_parent = 40, __ydb_id = 45, __ydb_centroid = mm\3\n"
+ "__ydb_parent = 41, __ydb_id = 46, __ydb_centroid = uu\3\n"
+ "__ydb_parent = 41, __ydb_id = 47, __ydb_centroid = vv\3\n");
+ UNIT_ASSERT_VALUES_EQUAL(posting,
+ "__ydb_parent = 42, key = 1, embedding = 00\3, data = one\n"
+ "__ydb_parent = 43, key = 2, embedding = 22\3, data = two\n"
+ "__ydb_parent = 44, key = 1, embedding = \x30\x30\3, data = one\n"
+ "__ydb_parent = 44, key = 2, embedding = \x31\x31\3, data = two\n"
+ "__ydb_parent = 44, key = 3, embedding = \x32\x32\3, data = three\n"
+ "__ydb_parent = 45, key = 4, embedding = \x65\x65\3, data = four\n"
+ "__ydb_parent = 45, key = 5, embedding = \x75\x75\3, data = five\n"
+ "__ydb_parent = 46, key = 5, embedding = uu\3, data = five2\n"
+ "__ydb_parent = 47, key = 6, embedding = vv\3, data = six\n");
+ recreate();
+ }
+ }
+
+ { // ParentFrom = 30 ParentTo = 50
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 30, 50, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level,
+ "__ydb_parent = 39, __ydb_id = 69, __ydb_centroid = 00\3\n"
+ "__ydb_parent = 39, __ydb_id = 70, __ydb_centroid = 22\3\n"
+ "__ydb_parent = 40, __ydb_id = 71, __ydb_centroid = 11\3\n"
+ "__ydb_parent = 40, __ydb_id = 72, __ydb_centroid = mm\3\n"
+ "__ydb_parent = 41, __ydb_id = 73, __ydb_centroid = uu\3\n"
+ "__ydb_parent = 41, __ydb_id = 74, __ydb_centroid = vv\3\n");
+ UNIT_ASSERT_VALUES_EQUAL(posting,
+ "__ydb_parent = 69, key = 1, embedding = 00\3, data = one\n"
+ "__ydb_parent = 70, key = 2, embedding = 22\3, data = two\n"
+ "__ydb_parent = 71, key = 1, embedding = \x30\x30\3, data = one\n"
+ "__ydb_parent = 71, key = 2, embedding = \x31\x31\3, data = two\n"
+ "__ydb_parent = 71, key = 3, embedding = \x32\x32\3, data = three\n"
+ "__ydb_parent = 72, key = 4, embedding = \x65\x65\3, data = four\n"
+ "__ydb_parent = 72, key = 5, embedding = \x75\x75\3, data = five\n"
+ "__ydb_parent = 73, key = 5, embedding = uu\3, data = five2\n"
+ "__ydb_parent = 74, key = 6, embedding = vv\3, data = six\n");
+ recreate();
+ }
+ }
+
+ { // ParentFrom = 30 ParentTo = 31
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 30, 31, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level, "");
+ UNIT_ASSERT_VALUES_EQUAL(posting, "");
+ recreate();
+ }
+ }
+
+ { // ParentFrom = 100 ParentTo = 101
+ for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) {
+ auto [level, posting] = DoLocalKMeans(server, sender,
+ 100, 101, 111, 2,
+ NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN,
+ maxBatchRows);
+ UNIT_ASSERT_VALUES_EQUAL(level, "");
+ UNIT_ASSERT_VALUES_EQUAL(posting, "");
+ recreate();
+ }
+ }
+ }
}
}
diff --git a/ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp b/ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp
index 4bab0e10981..e221986fee4 100644
--- a/ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp
+++ b/ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp
@@ -291,19 +291,6 @@ Y_UNIT_TEST_SUITE (TTxDataShardPrefixKMeansScan) {
DoBadRequest(server, sender, ev, 2, VectorIndexSettings::VECTOR_TYPE_FLOAT,
VectorIndexSettings::METRIC_UNSPECIFIED);
}
- // TODO(mbkkt) For now all build_index, sample_k, build_columns, local_kmeans, prefix_kmeans doesn't really check this
- // {
- // auto ev = std::make_unique<TEvDataShard::TEvPrefixKMeansRequest>();
- // auto snapshotCopy = snapshot;
- // snapshotCopy.Step++;
- // DoBadRequest(server, sender, ev);
- // }
- // {
- // auto ev = std::make_unique<TEvDataShard::TEvPrefixKMeansRequest>();
- // auto snapshotCopy = snapshot;
- // snapshotCopy.TxId++;
- // DoBadRequest(server, sender, ev);
- // }
}
Y_UNIT_TEST (BuildToPosting) {
diff --git a/ydb/core/tx/datashard/local_kmeans.cpp b/ydb/core/tx/datashard/local_kmeans.cpp
index b39a09c3f60..5b83216c0df 100644
--- a/ydb/core/tx/datashard/local_kmeans.cpp
+++ b/ydb/core/tx/datashard/local_kmeans.cpp
@@ -22,44 +22,6 @@
namespace NKikimr::NDataShard {
using namespace NKMeans;
-class TResult {
-public:
- explicit TResult(const TActorId& responseActorId, TAutoPtr<TEvDataShard::TEvLocalKMeansResponse> response)
- : ResponseActorId{responseActorId}
- , Response{std::move(response)} {
- Y_ASSERT(Response);
- }
-
- void SizeAdd(i64 size) {
- if (size != 0) {
- std::lock_guard lock{Mutex};
- Size += size;
- }
- }
-
- template<typename Func>
- void Send(const TActorContext& ctx, Func&& func) {
- std::unique_lock lock{Mutex};
- if (Size <= 0) {
- return;
- }
- if (func(Response->Record) && --Size > 0) {
- return;
- }
- Size = 0;
- lock.unlock();
-
- LOG_N("Finish TLocalKMeansScan " << Response->Record.ShortDebugString());
- ctx.Send(ResponseActorId, std::move(Response));
- }
-
-private:
- std::mutex Mutex;
- i64 Size = 1;
- TActorId ResponseActorId;
- TAutoPtr<TEvDataShard::TEvLocalKMeansResponse> Response;
-};
-
// This scan needed to run local (not distributed) kmeans.
// We have this local stage because we construct kmeans tree from top to bottom.
// And bottom kmeans can be constructed completely locally in datashards to avoid extra communication.
@@ -81,7 +43,7 @@ private:
// NTable::IScan::Seek used to switch from current state to the next one.
// If less than 1% of vectors are reassigned to new clusters we want to stop
-// TODO(mbkkt) 1% is choosed by common sense and should be adjusted in future
+// TODO(mbkkt) 1% is choosen by common sense and should be adjusted in future
static constexpr double MinVectorsNeedsReassigned = 0.01;
class TLocalKMeansScanBase: public TActor<TLocalKMeansScanBase>, public NTable::IScan {
@@ -92,12 +54,13 @@ protected:
NTableIndex::TClusterId Child = 0;
ui32 Round = 0;
- ui32 MaxRounds = 0;
+ const ui32 MaxRounds = 0;
+ const ui32 InitK = 0;
ui32 K = 0;
EState State;
- EState UploadState;
+ const EState UploadState;
IDriver* Driver = nullptr;
@@ -124,14 +87,17 @@ protected:
std::vector<ui64> ClusterSizes;
// Upload
- std::shared_ptr<NTxProxy::TUploadTypes> TargetTypes;
- std::shared_ptr<NTxProxy::TUploadTypes> NextTypes;
+ std::shared_ptr<NTxProxy::TUploadTypes> LevelTypes;
+ std::shared_ptr<NTxProxy::TUploadTypes> PostingTypes;
+ std::shared_ptr<NTxProxy::TUploadTypes> UploadTypes;
- TString TargetTable;
- TString NextTable;
+ const TString LevelTable;
+ const TString PostingTable;
+ TString UploadTable;
- TBufferData ReadBuf;
- TBufferData WriteBuf;
+ TBufferData LevelBuf;
+ TBufferData PostingBuf;
+ TBufferData UploadBuf;
NTable::TPos EmbeddingPos = 0;
NTable::TPos DataPos = 1;
@@ -141,15 +107,25 @@ protected:
TActorId Uploader;
const TIndexBuildScanSettings ScanSettings;
- NTable::TTag KMeansScan;
- TTags UploadScan;
+ NTable::TTag EmbeddingTag;
+ TTags ScanTags;
TUploadStatus UploadStatus;
ui64 UploadRows = 0;
ui64 UploadBytes = 0;
- std::shared_ptr<TResult> Result;
+ TActorId ResponseActorId;
+ TAutoPtr<TEvDataShard::TEvLocalKMeansResponse> Response;
+
+ // FIXME: save PrefixRows as std::vector<std::pair<TSerializedCellVec, TSerializedCellVec>> to avoid parsing
+ const ui32 PrefixColumns;
+ TSerializedCellVec Prefix;
+ TBufferData PrefixRows;
+ bool IsFirstPrefixFeed = true;
+ bool IsPrefixRowsValid = true;
+
+ bool IsExhausted = false;
public:
static constexpr NKikimrServices::TActivity::EType ActorActivityType()
@@ -157,38 +133,45 @@ public:
return NKikimrServices::TActivity::LOCAL_KMEANS_SCAN_ACTOR;
}
- TLocalKMeansScanBase(ui64 buildId, const TUserTable& table, TLead&& lead, NTableIndex::TClusterId parent, NTableIndex::TClusterId child,
+ TLocalKMeansScanBase(const TUserTable& table,
const NKikimrTxDataShard::TEvLocalKMeansRequest& request,
- std::shared_ptr<TResult> result)
+ const TActorId& responseActorId,
+ TAutoPtr<TEvDataShard::TEvLocalKMeansResponse>&& response,
+ TLead&& lead)
: TActor{&TThis::StateWork}
- , Parent{parent}
- , Child{child}
+ , Parent{request.GetParentFrom()}
+ , Child{request.GetChild()}
, MaxRounds{request.GetNeedsRounds()}
+ , InitK{request.GetK()}
, K{request.GetK()}
, State{EState::SAMPLE}
, UploadState{request.GetUpload()}
, Lead{std::move(lead)}
- , BuildId{buildId}
+ , BuildId{request.GetId()}
, Rng{request.GetSeed()}
- , TargetTable{request.GetLevelName()}
- , NextTable{request.GetPostingName()}
+ , LevelTable{request.GetLevelName()}
+ , PostingTable{request.GetPostingName()}
, ScanSettings(request.GetScanSettings())
- , Result{std::move(result)}
+ , ResponseActorId{responseActorId}
+ , Response{std::move(response)}
+ , PrefixColumns{request.GetParentFrom() == 0 && request.GetParentTo() == 0 ? 0u : 1u}
{
const auto& embedding = request.GetEmbeddingColumn();
const auto& data = request.GetDataColumns();
// scan tags
- UploadScan = MakeUploadTags(table, embedding, data, EmbeddingPos, DataPos, KMeansScan);
+ ScanTags = MakeUploadTags(table, embedding, data, EmbeddingPos, DataPos, EmbeddingTag);
+ Lead.SetTags(ScanTags);
// upload types
- if (Ydb::Type type; State <= EState::KMEANS) {
- TargetTypes = std::make_shared<NTxProxy::TUploadTypes>(3);
+ {
+ Ydb::Type type;
+ LevelTypes = std::make_shared<NTxProxy::TUploadTypes>(3);
type.set_type_id(NTableIndex::ClusterIdType);
- (*TargetTypes)[0] = {NTableIndex::NTableVectorKmeansTreeIndex::ParentColumn, type};
- (*TargetTypes)[1] = {NTableIndex::NTableVectorKmeansTreeIndex::IdColumn, type};
+ (*LevelTypes)[0] = {NTableIndex::NTableVectorKmeansTreeIndex::ParentColumn, type};
+ (*LevelTypes)[1] = {NTableIndex::NTableVectorKmeansTreeIndex::IdColumn, type};
type.set_type_id(Ydb::Type::STRING);
- (*TargetTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn, type};
+ (*LevelTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn, type};
}
- NextTypes = MakeUploadTypes(table, UploadState, embedding, data);
+ PostingTypes = MakeUploadTypes(table, UploadState, embedding, data);
}
TInitialState Prepare(IDriver* driver, TIntrusiveConstPtr<TScheme>) final
@@ -202,30 +185,27 @@ public:
TAutoPtr<IDestructable> Finish(EAbort abort) final
{
- LOG_D("Finish " << Debug());
-
if (Uploader) {
- Send(Uploader, new TEvents::TEvPoisonPill);
+ Send(Uploader, new TEvents::TEvPoison);
Uploader = {};
}
- Result->Send(TActivationContext::AsActorContext(), [&] (NKikimrTxDataShard::TEvLocalKMeansResponse& response) {
- response.SetReadRows(ReadRows);
- response.SetReadBytes(ReadBytes);
- response.SetUploadRows(UploadRows);
- response.SetUploadBytes(UploadBytes);
- NYql::IssuesToMessage(UploadStatus.Issues, response.MutableIssues());
- if (abort != EAbort::None) {
- response.SetStatus(NKikimrIndexBuilder::EBuildStatus::ABORTED);
- return false;
- } else if (UploadStatus.IsSuccess()) {
- response.SetStatus(NKikimrIndexBuilder::EBuildStatus::DONE);
- return true;
- } else {
- response.SetStatus(NKikimrIndexBuilder::EBuildStatus::BUILD_ERROR);
- return false;
- }
- });
+ auto& record = Response->Record;
+ record.SetReadRows(ReadRows);
+ record.SetReadBytes(ReadBytes);
+ record.SetUploadRows(UploadRows);
+ record.SetUploadBytes(UploadBytes);
+ if (abort != EAbort::None) {
+ record.SetStatus(NKikimrIndexBuilder::EBuildStatus::ABORTED);
+ } else if (UploadStatus.IsNone() || UploadStatus.IsSuccess()) {
+ record.SetStatus(NKikimrIndexBuilder::EBuildStatus::DONE);
+ } else {
+ record.SetStatus(NKikimrIndexBuilder::EBuildStatus::BUILD_ERROR);
+ }
+ NYql::IssuesToMessage(UploadStatus.Issues, record.MutableIssues());
+
+ LOG_N("Finish " << Debug() << " " << Response->Record.ShortDebugString());
+ Send(ResponseActorId, Response.Release());
Driver = nullptr;
this->PassAway();
@@ -240,9 +220,10 @@ public:
TString Debug() const
{
return TStringBuilder() << "TLocalKMeansScan Id: " << BuildId << " Parent: " << Parent << " Child: " << Child
- << " Target: " << TargetTable << " K: " << K << " Clusters: " << Clusters.size()
+ << " K: " << K << " Clusters: " << Clusters.size()
<< " State: " << State << " Round: " << Round << " / " << MaxRounds
- << " ReadBuf size: " << ReadBuf.Size() << " WriteBuf size: " << WriteBuf.Size();
+ << " LevelBuf size: " << LevelBuf.Size() << " PostingBuf size: " << PostingBuf.Size()
+ << " UploadTable: " << UploadTable << " UploadBuf size: " << UploadBuf.Size() << " RetryCount: " << RetryCount;
}
EScan PageFault() final
@@ -258,28 +239,31 @@ protected:
HFunc(TEvTxUserProxy::TEvUploadRowsResponse, Handle);
CFunc(TEvents::TSystem::Wakeup, HandleWakeup);
default:
- LOG_E("TLocalKMeansScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: "
- << ev->ToString() << " " << Debug());
+ LOG_E("StateWork unexpected event type: " << ev->GetTypeRewrite()
+ << " event: " << ev->ToString() << " " << Debug());
}
}
void HandleWakeup(const NActors::TActorContext& /*ctx*/)
{
- LOG_I("Retry upload " << Debug());
+ LOG_D("Retry upload " << Debug());
- if (!WriteBuf.IsEmpty()) {
- Upload(true);
+ if (UploadInProgress()) {
+ RetryUpload();
}
}
void Handle(TEvTxUserProxy::TEvUploadRowsResponse::TPtr& ev, const TActorContext& ctx)
{
LOG_D("Handle TEvUploadRowsResponse " << Debug()
- << " Uploader: " << Uploader.ToString() << " ev->Sender: " << ev->Sender.ToString());
+ << " Uploader: " << (Uploader ? Uploader.ToString() : "<null>")
+ << " ev->Sender: " << ev->Sender.ToString());
if (Uploader) {
- Y_ENSURE(Uploader == ev->Sender, "Mismatch Uploader: " << Uploader.ToString() << " ev->Sender: "
- << ev->Sender.ToString() << Debug());
+ Y_ENSURE(Uploader == ev->Sender, "Mismatch"
+ << " Uploader: " << Uploader.ToString()
+ << " Sender: " << ev->Sender.ToString());
+ Uploader = {};
} else {
Y_ENSURE(Driver == nullptr);
return;
@@ -288,81 +272,109 @@ protected:
UploadStatus.StatusCode = ev->Get()->Status;
UploadStatus.Issues = ev->Get()->Issues;
if (UploadStatus.IsSuccess()) {
- UploadRows += WriteBuf.GetRows();
- UploadBytes += WriteBuf.GetBytes();
- WriteBuf.Clear();
- if (HasReachedLimits(ReadBuf, ScanSettings)) {
- ReadBuf.FlushTo(WriteBuf);
- Upload(false);
- }
+ UploadRows += UploadBuf.GetRows();
+ UploadBytes += UploadBuf.GetBytes();
+ UploadBuf.Clear();
+
+ TryUpload(LevelBuf, LevelTable, LevelTypes, true)
+ || TryUpload(PostingBuf, PostingTable, PostingTypes, true);
Driver->Touch(EScan::Feed);
return;
}
if (RetryCount < ScanSettings.GetMaxBatchRetries() && UploadStatus.IsRetriable()) {
- LOG_N("Got retriable error, " << Debug() << UploadStatus.ToString());
+ LOG_N("Got retriable error, " << Debug() << " " << UploadStatus.ToString());
ctx.Schedule(GetRetryWakeupTimeoutBackoff(RetryCount), new TEvents::TEvWakeup());
return;
}
- LOG_N("Got error, abort scan, " << Debug() << UploadStatus.ToString());
+ LOG_N("Got error, abort scan, " << Debug() << " " << UploadStatus.ToString());
Driver->Touch(EScan::Final);
}
- EScan FeedUpload()
+ bool ShouldWaitUpload()
{
- if (!HasReachedLimits(ReadBuf, ScanSettings)) {
- return EScan::Feed;
+ if (!HasReachedLimits(LevelBuf, ScanSettings) && !HasReachedLimits(PostingBuf, ScanSettings)) {
+ return false;
}
- if (!WriteBuf.IsEmpty()) {
- return EScan::Sleep;
+
+ if (UploadInProgress()) {
+ return true;
}
- ReadBuf.FlushTo(WriteBuf);
- Upload(false);
- return EScan::Feed;
- }
+
+ TryUpload(LevelBuf, LevelTable, LevelTypes, true)
+ || TryUpload(PostingBuf, PostingTable, PostingTypes, true);
- ui64 GetProbability()
- {
- return Rng.GenRand64();
+ return !HasReachedLimits(LevelBuf, ScanSettings) && !HasReachedLimits(PostingBuf, ScanSettings);
}
- void Upload(bool isRetry)
+ void UploadImpl()
{
- if (isRetry) {
- ++RetryCount;
- } else {
- RetryCount = 0;
- if (State != EState::KMEANS && NextTypes) {
- TargetTypes = std::exchange(NextTypes, {});
- TargetTable = std::move(NextTable);
- }
- }
+ LOG_D("Uploading " << Debug());
+ Y_ASSERT(!UploadBuf.IsEmpty());
+ Y_ASSERT(!Uploader);
auto actor = NTxProxy::CreateUploadRowsInternal(
- this->SelfId(), TargetTable, TargetTypes, WriteBuf.GetRowsData(),
+ this->SelfId(), UploadTable, UploadTypes, UploadBuf.GetRowsData(),
NTxProxy::EUploadRowsMode::WriteToTableShadow, true /*writeToPrivateTable*/);
Uploader = this->Register(actor);
}
- void UploadSample()
+ void InitUpload(std::string_view table, std::shared_ptr<NTxProxy::TUploadTypes> types)
+ {
+ RetryCount = 0;
+ UploadTable = table;
+ UploadTypes = std::move(types);
+ UploadImpl();
+ }
+
+ void RetryUpload()
+ {
+ ++RetryCount;
+ UploadImpl();
+ }
+
+ bool UploadInProgress()
+ {
+ return !UploadBuf.IsEmpty();
+ }
+
+ bool TryUpload(TBufferData& buffer, const TString& table, const std::shared_ptr<NTxProxy::TUploadTypes>& types, bool byLimit)
+ {
+ if (Y_UNLIKELY(UploadInProgress())) {
+ // already uploading something
+ return true;
+ }
+
+ if (!buffer.IsEmpty() && (!byLimit || HasReachedLimits(buffer, ScanSettings))) {
+ buffer.FlushTo(UploadBuf);
+ InitUpload(table, types);
+ return true;
+ }
+
+ return false;
+ }
+
+ void FormLevelRows()
{
- Y_ASSERT(ReadBuf.IsEmpty());
- Y_ASSERT(WriteBuf.IsEmpty());
std::array<TCell, 2> pk;
std::array<TCell, 1> data;
for (NTable::TPos pos = 0; const auto& row : Clusters) {
pk[0] = TCell::Make(Parent);
pk[1] = TCell::Make(Child + pos);
data[0] = TCell{row};
- WriteBuf.AddRow(TSerializedCellVec{pk}, TSerializedCellVec::Serialize(data));
+ LevelBuf.AddRow(TSerializedCellVec{pk}, TSerializedCellVec::Serialize(data));
++pos;
}
- Upload(false);
+ }
+
+ ui64 GetProbability()
+ {
+ return Rng.GenRand64();
}
};
@@ -377,10 +389,29 @@ class TLocalKMeansScan final: public TLocalKMeansScanBase, private TCalculation<
};
std::vector<TAggregatedCluster> AggregatedClusters;
+ void StartNewPrefix() {
+ Round = 0;
+ K = InitK;
+ State = EState::SAMPLE;
+ Lead.Valid = true;
+ Lead.Key = TSerializedCellVec(Prefix.GetCells()); // seek to (prefix, inf)
+ Lead.Relation = NTable::ESeek::Upper;
+ Prefix = {};
+ IsFirstPrefixFeed = true;
+ IsPrefixRowsValid = true;
+ PrefixRows.Clear();
+ MaxProbability = std::numeric_limits<ui64>::max();
+ MaxRows.clear();
+ Clusters.clear();
+ ClusterSizes.clear();
+ AggregatedClusters.clear();
+ }
+
public:
- TLocalKMeansScan(ui64 buildId, const TUserTable& table, TLead&& lead, NTableIndex::TClusterId parent, NTableIndex::TClusterId child, NKikimrTxDataShard::TEvLocalKMeansRequest& request,
- std::shared_ptr<TResult> result)
- : TLocalKMeansScanBase{buildId, table, std::move(lead), parent, child, request, std::move(result)}
+ TLocalKMeansScan(const TUserTable& table, NKikimrTxDataShard::TEvLocalKMeansRequest& request,
+ const TActorId& responseActorId, TAutoPtr<TEvDataShard::TEvLocalKMeansResponse>&& response,
+ TLead&& lead)
+ : TLocalKMeansScanBase{table, request, responseActorId, std::move(response), std::move(lead)}
{
this->Dimensions = request.GetSettings().vector_dimension();
LOG_I("Create " << Debug());
@@ -388,90 +419,143 @@ public:
EScan Seek(TLead& lead, ui64 seq) final
{
- LOG_D("Seek " << Debug());
- if (State == UploadState) {
- if (!WriteBuf.IsEmpty()) {
- return EScan::Sleep;
- }
- if (!ReadBuf.IsEmpty()) {
- ReadBuf.FlushTo(WriteBuf);
- Upload(false);
+ LOG_D("Seek " << seq << " " << Debug());
+
+ if (IsExhausted) {
+ if (UploadInProgress()
+ || TryUpload(LevelBuf, LevelTable, LevelTypes, false)
+ || TryUpload(PostingBuf, PostingTable, PostingTypes, false))
+ {
return EScan::Sleep;
}
- if (UploadStatus.IsNone()) {
- UploadStatus.StatusCode = Ydb::StatusIds::SUCCESS;
- }
return EScan::Final;
}
- if (State == EState::SAMPLE) {
- lead = Lead;
- lead.SetTags({&KMeansScan, 1});
- if (seq == 0) {
- return EScan::Feed;
+ lead = Lead;
+
+ return EScan::Feed;
+ }
+
+ EScan Feed(TArrayRef<const TCell> key, const TRow& row) final
+ {
+ LOG_T("Feed " << Debug());
+
+ ++ReadRows;
+ ReadBytes += CountBytes(key, row);
+
+ if (PrefixColumns && Prefix && !TCellVectorsEquals{}(Prefix.GetCells(), key.subspan(0, PrefixColumns))) {
+ if (!FinishPrefix()) {
+ // scan current prefix rows with a new state again
+ return EScan::Reset;
}
- State = EState::KMEANS;
- if (!InitAggregatedClusters()) {
- // We don't need to do anything,
- // because this datashard doesn't have valid embeddings for this parent
- if (UploadStatus.IsNone()) {
- UploadStatus.StatusCode = Ydb::StatusIds::SUCCESS;
- }
- return EScan::Final;
+ }
+
+ if (PrefixColumns && !Prefix) {
+ Prefix = TSerializedCellVec{key.subspan(0, PrefixColumns)};
+ auto newParent = key.at(0).template AsValue<ui64>();
+ Child += (newParent - Parent) * InitK;
+ Parent = newParent;
+ }
+
+ if (IsFirstPrefixFeed && IsPrefixRowsValid) {
+ PrefixRows.AddRow(TSerializedCellVec{key}, TSerializedCellVec::Serialize(*row));
+ if (HasReachedLimits(PrefixRows, ScanSettings)) {
+ PrefixRows.Clear();
+ IsPrefixRowsValid = false;
}
- ++Round;
- return EScan::Feed;
}
- Y_ASSERT(State == EState::KMEANS);
- if (RecomputeClusters()) {
- lead = std::move(Lead);
- lead.SetTags(UploadScan);
+ Feed(key, *row);
+
+ return ShouldWaitUpload() ? EScan::Sleep : EScan::Feed;
+ }
- UploadSample();
- State = UploadState;
+ EScan Exhausted() final
+ {
+ LOG_D("Exhausted " << Debug());
+
+ if (!FinishPrefix()) {
+ return EScan::Reset;
+ }
+
+ IsExhausted = true;
+
+ // call Seek to wait uploads
+ return EScan::Reset;
+ }
+
+private:
+ bool FinishPrefix()
+ {
+ if (FinishPrefixImpl()) {
+ StartNewPrefix();
+ LOG_D("FinishPrefix finished " << Debug());
+ return true;
} else {
- lead = Lead;
- lead.SetTags({&KMeansScan, 1});
- ++Round;
+ IsFirstPrefixFeed = false;
+
+ if (IsPrefixRowsValid) {
+ LOG_D("FinishPrefix not finished, manually feeding " << PrefixRows.Size() << " saved rows " << Debug());
+ for (ui64 iteration = 0; ; iteration++) {
+ for (const auto& [key, row_] : *PrefixRows.GetRowsData()) {
+ TSerializedCellVec row(row_);
+ Feed(key.GetCells(), row.GetCells());
+ }
+ if (FinishPrefixImpl()) {
+ StartNewPrefix();
+ LOG_D("FinishPrefix finished in " << iteration << " iterations " << Debug());
+ return true;
+ } else {
+ LOG_D("FinishPrefix not finished in " << iteration << " iterations " << Debug());
+ }
+ }
+ } else {
+ LOG_D("FinishPrefix not finished, rescanning rows " << Debug());
+ }
+
+ return false;
}
- return EScan::Feed;
}
- EScan Feed(TArrayRef<const TCell> key, const TRow& row_) final
+ bool FinishPrefixImpl()
{
- LOG_T("Feed " << Debug());
+ if (State == EState::SAMPLE) {
+ State = EState::KMEANS;
+ if (!InitAggregatedClusters()) {
+ // We don't need to do anything,
+ // because this datashard doesn't have valid embeddings for this prefix
+ return true;
+ }
+ Round = 1;
+ return false; // do KMEANS
+ }
- ++ReadRows;
- ReadBytes += CountBytes(key, row_);
- auto row = *row_;
-
- switch (State) {
- case EState::SAMPLE:
- return FeedSample(row);
- case EState::KMEANS:
- return FeedKMeans(row);
- case EState::UPLOAD_MAIN_TO_BUILD:
- return FeedUploadMain2Build(key, row);
- case EState::UPLOAD_MAIN_TO_POSTING:
- return FeedUploadMain2Posting(key, row);
- case EState::UPLOAD_BUILD_TO_BUILD:
- return FeedUploadBuild2Build(key, row);
- case EState::UPLOAD_BUILD_TO_POSTING:
- return FeedUploadBuild2Posting(key, row);
- default:
- return EScan::Final;
+ if (State == EState::KMEANS) {
+ if (RecomputeClusters()) {
+ FormLevelRows();
+ State = UploadState;
+ return false; // do UPLOAD_*
+ } else {
+ ++Round;
+ return false; // recompute KMEANS
+ }
}
+
+ if (State == UploadState) {
+ return true;
+ }
+
+ Y_ASSERT(false);
+ return true;
}
-private:
bool InitAggregatedClusters()
{
if (Clusters.size() == 0) {
return false;
}
if (Clusters.size() < K) {
- // if this datashard have smaller than K count of valid embeddings for this parent
+ // if this datashard have less than K valid embeddings for this parent
// lets make single centroid for it
K = 1;
Clusters.resize(K);
@@ -544,12 +628,37 @@ private:
return true;
}
- EScan FeedSample(TArrayRef<const TCell> row)
+ void Feed(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
- Y_ASSERT(row.size() == 1);
- const auto embedding = row.at(0).AsRef();
+ switch (State) {
+ case EState::SAMPLE:
+ FeedSample(row);
+ break;
+ case EState::KMEANS:
+ FeedKMeans(row);
+ break;
+ case EState::UPLOAD_MAIN_TO_BUILD:
+ FeedUploadMain2Build(key, row);
+ break;
+ case EState::UPLOAD_MAIN_TO_POSTING:
+ FeedUploadMain2Posting(key, row);
+ break;
+ case EState::UPLOAD_BUILD_TO_BUILD:
+ FeedUploadBuild2Build(key, row);
+ break;
+ case EState::UPLOAD_BUILD_TO_POSTING:
+ FeedUploadBuild2Posting(key, row);
+ break;
+ default:
+ Y_ASSERT(false);
+ }
+ }
+
+ void FeedSample(TArrayRef<const TCell> row)
+ {
+ const auto embedding = row.at(EmbeddingPos).AsRef();
if (!this->IsExpectedSize(embedding)) {
- return EScan::Feed;
+ return;
}
const auto probability = GetProbability();
@@ -568,55 +677,44 @@ private:
std::push_heap(MaxRows.begin(), MaxRows.end());
MaxProbability = MaxRows.front().P;
}
- return MaxProbability != 0 ? EScan::Feed : EScan::Reset;
}
- EScan FeedKMeans(TArrayRef<const TCell> row)
+ void FeedKMeans(TArrayRef<const TCell> row)
{
- Y_ASSERT(row.size() == 1);
- const ui32 pos = FeedEmbedding(*this, Clusters, row, 0);
- AggregateToCluster(pos, row.at(0).Data());
- return EScan::Feed;
+ const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos);
+ AggregateToCluster(pos, row.at(EmbeddingPos).Data());
}
- EScan FeedUploadMain2Build(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
+ void FeedUploadMain2Build(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos);
- if (pos >= K) {
- return EScan::Feed;
+ if (pos < K) {
+ AddRowMain2Build(PostingBuf, Child + pos, key, row);
}
- AddRowMain2Build(ReadBuf, Child + pos, key, row);
- return FeedUpload();
}
- EScan FeedUploadMain2Posting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
+ void FeedUploadMain2Posting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos);
- if (pos >= K) {
- return EScan::Feed;
+ if (pos < K) {
+ AddRowMain2Posting(PostingBuf, Child + pos, key, row, DataPos);
}
- AddRowMain2Posting(ReadBuf, Child + pos, key, row, DataPos);
- return FeedUpload();
}
- EScan FeedUploadBuild2Build(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
+ void FeedUploadBuild2Build(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos);
- if (pos >= K) {
- return EScan::Feed;
+ if (pos < K) {
+ AddRowBuild2Build(PostingBuf, Child + pos, key, row);
}
- AddRowBuild2Build(ReadBuf, Child + pos, key, row);
- return FeedUpload();
}
- EScan FeedUploadBuild2Posting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
+ void FeedUploadBuild2Posting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos);
- if (pos >= K) {
- return EScan::Feed;
+ if (pos < K) {
+ AddRowBuild2Posting(PostingBuf, Child + pos, key, row, DataPos);
}
- AddRowBuild2Posting(ReadBuf, Child + pos, key, row, DataPos);
- return FeedUpload();
}
};
@@ -671,22 +769,14 @@ void TDataShard::HandleSafe(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const
TScanRecord::TSeqNo seqNo = {request.GetSeqNoGeneration(), request.GetSeqNoRound()};
response->Record.SetRequestSeqNoGeneration(seqNo.Generation);
response->Record.SetRequestSeqNoRound(seqNo.Round);
- auto result = std::make_shared<TResult>(ev->Sender, std::move(response));
- ui32 localTid = 0;
- TScanRecord::TScanIds scanIds;
auto badRequest = [&](const TString& error) {
- for (auto scanId : scanIds) {
- CancelScan(localTid, scanId);
- }
- result->Send(ctx, [&] (NKikimrTxDataShard::TEvLocalKMeansResponse& response) {
- response.SetStatus(NKikimrIndexBuilder::EBuildStatus::BAD_REQUEST);
- auto issue = response.AddIssues();
- issue->set_severity(NYql::TSeverityIds::S_ERROR);
- issue->set_message(error);
- return false;
- });
- result.reset();
+ response->Record.SetStatus(NKikimrIndexBuilder::EBuildStatus::BAD_REQUEST);
+ auto issue = response->Record.AddIssues();
+ issue->set_severity(NYql::TSeverityIds::S_ERROR);
+ issue->set_message(error);
+ ctx.Send(ev->Sender, std::move(response));
+ response.Reset();
};
if (const ui64 shardId = request.GetTabletId(); shardId != TabletID()) {
@@ -715,8 +805,6 @@ void TDataShard::HandleSafe(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const
ScanManager.Drop(id);
}
- localTid = userTable.LocalTid;
-
if (request.HasSnapshotStep() || request.HasSnapshotTxId()) {
const TSnapshotKey snapshotKey(pathId, rowVersion.Step, rowVersion.TxId);
if (!SnapshotManager.FindAvailable(snapshotKey)) {
@@ -743,45 +831,44 @@ void TDataShard::HandleSafe(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const
badRequest(TStringBuilder() << "Parent from " << parentFrom << " should be less or equal to parent to " << parentTo);
return;
}
- const i64 expectedSize = parentTo - parentFrom + 1;
- result->SizeAdd(expectedSize);
- for (auto parent = parentFrom; parent <= parentTo; ++parent) {
- TCell from, to;
- const auto range = CreateRangeFrom(userTable, parent, from, to);
- if (range.IsEmptyRange(userTable.KeyColumnTypes)) {
- LOG_D("TEvLocalKMeansRequst " << request.GetId() << " parent " << parent << " is empty");
- continue;
- }
-
- TAutoPtr<NTable::IScan> scan;
- auto createScan = [&]<typename T> {
- scan = new TLocalKMeansScan<T>{
- request.GetId(), userTable,
- CreateLeadFrom(range), parent, request.GetChild() + request.GetK() * (parent - parentFrom),
- request, result,
- };
- };
- MakeScan(request, createScan, badRequest);
- if (!scan) {
- Y_ASSERT(!result);
+ NTable::TLead lead;
+ if (parentFrom == 0 && parentTo == 0) {
+ lead.To({}, NTable::ESeek::Lower);
+ } else {
+ TCell from = TCell::Make(parentFrom - 1);
+ TCell to = TCell::Make(parentTo);
+ TTableRange range{{&from, 1}, false, {&to, 1}, true};
+ auto scanRange = Intersect(userTable.KeyColumnTypes, range, userTable.Range.ToTableRange());
+ if (scanRange.IsEmptyRange(userTable.KeyColumnTypes)) {
+ badRequest(TStringBuilder() << " requested range doesn't intersect with table range"
+ << " requestedRange: " << DebugPrintRange(userTable.KeyColumnTypes, range, *AppData()->TypeRegistry)
+ << " tableRange: " << DebugPrintRange(userTable.KeyColumnTypes, userTable.Range.ToTableRange(), *AppData()->TypeRegistry)
+ << " scanRange: " << DebugPrintRange(userTable.KeyColumnTypes, scanRange, *AppData()->TypeRegistry));
return;
}
-
- TScanOptions scanOpts;
- scanOpts.SetSnapshotRowVersion(rowVersion);
- scanOpts.SetResourceBroker("build_index", 10); // TODO(mbkkt) Should be different group?
- const auto scanId = QueueScan(userTable.LocalTid, std::move(scan), 0, scanOpts);
- scanIds.push_back(scanId);
+ lead.To(range.From, NTable::ESeek::Upper);
+ lead.Until(range.To, true);
}
-
- if (scanIds.empty()) {
- badRequest("Requested range doesn't intersect with table range");
+
+ TAutoPtr<NTable::IScan> scan;
+ auto createScan = [&]<typename T> {
+ scan = new TLocalKMeansScan<T>{
+ userTable, request, ev->Sender, std::move(response),
+ std::move(lead)
+ };
+ };
+ MakeScan(request, createScan, badRequest);
+ if (!scan) {
+ Y_ASSERT(!response);
return;
}
- result->SizeAdd(static_cast<i64>(scanIds.size()) - expectedSize);
- result->Send(ctx, [] (auto&) { return true; }); // decrement extra one
- ScanManager.Set(id, seqNo) = std::move(scanIds);
+
+ TScanOptions scanOpts;
+ scanOpts.SetSnapshotRowVersion(rowVersion);
+ scanOpts.SetResourceBroker("build_index", 10); // TODO(mbkkt) Should be different group?
+ const auto scanId = QueueScan(userTable.LocalTid, std::move(scan), 0, scanOpts);
+ ScanManager.Set(id, seqNo).push_back(scanId);
}
}
diff --git a/ydb/core/tx/datashard/prefix_kmeans.cpp b/ydb/core/tx/datashard/prefix_kmeans.cpp
index 5a36035d782..8742c3cf8ff 100644
--- a/ydb/core/tx/datashard/prefix_kmeans.cpp
+++ b/ydb/core/tx/datashard/prefix_kmeans.cpp
@@ -101,12 +101,12 @@ protected:
TAutoPtr<TEvDataShard::TEvPrefixKMeansResponse> Response;
// FIXME: save PrefixRows as std::vector<std::pair<TSerializedCellVec, TSerializedCellVec>> to avoid parsing
- ui32 PrefixColumns;
+ const ui32 PrefixColumns;
TSerializedCellVec Prefix;
TBufferData PrefixRows;
bool IsFirstPrefixFeed = true;
bool IsPrefixRowsValid = true;
-
+
bool IsExhausted = false;
public:
@@ -153,6 +153,7 @@ public:
(*LevelTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn, type};
}
PostingTypes = MakeUploadTypes(table, UploadState, embedding, data, PrefixColumns);
+ // prefix types
{
auto types = GetAllTypes(table);
@@ -242,8 +243,8 @@ protected:
HFunc(TEvTxUserProxy::TEvUploadRowsResponse, Handle);
CFunc(TEvents::TSystem::Wakeup, HandleWakeup);
default:
- LOG_E("TPrefixKMeansScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: "
- << ev->ToString() << " " << Debug());
+ LOG_E("StateWork unexpected event type: " << ev->GetTypeRewrite()
+ << " event: " << ev->ToString() << " " << Debug());
}
}
@@ -299,11 +300,6 @@ protected:
Driver->Touch(EScan::Final);
}
- ui64 GetProbability()
- {
- return Rng.GenRand64();
- }
-
bool ShouldWaitUpload()
{
if (!HasReachedLimits(LevelBuf, ScanSettings) && !HasReachedLimits(PostingBuf, ScanSettings) && !HasReachedLimits(PrefixBuf, ScanSettings)) {
@@ -381,6 +377,11 @@ protected:
++pos;
}
}
+
+ ui64 GetProbability()
+ {
+ return Rng.GenRand64();
+ }
};
template <typename TMetric>
@@ -394,17 +395,13 @@ class TPrefixKMeansScan final: public TPrefixKMeansScanBase, private TCalculatio
};
std::vector<TAggregatedCluster> AggregatedClusters;
-
void StartNewPrefix() {
Parent = Child + K;
Child = Parent + 1;
Round = 0;
K = InitK;
State = EState::SAMPLE;
- // TODO(mbkkt) Upper or Lower doesn't matter here, because we seek to (prefix, inf)
- // so we can choose Lower if it's faster.
- // Exact seek with Lower also possible but needs to rewrite some code in Feed
- Lead.To(Prefix.GetCells(), NTable::ESeek::Upper);
+ Lead.To(Prefix.GetCells(), NTable::ESeek::Upper); // seek to (prefix, inf)
Prefix = {};
IsFirstPrefixFeed = true;
IsPrefixRowsValid = true;
@@ -804,6 +801,11 @@ void TDataShard::HandleSafe(TEvDataShard::TEvPrefixKMeansRequest::TPtr& ev, cons
return;
}
+ if (request.GetPrefixColumns() <= 0) {
+ badRequest("Should be requested on at least one prefix column");
+ return;
+ }
+
TAutoPtr<NTable::IScan> scan;
auto createScan = [&]<typename T> {
scan = new TPrefixKMeansScan<T>{
diff --git a/ydb/core/tx/datashard/reshuffle_kmeans.cpp b/ydb/core/tx/datashard/reshuffle_kmeans.cpp
index f7db951e60b..f3a7f900876 100644
--- a/ydb/core/tx/datashard/reshuffle_kmeans.cpp
+++ b/ydb/core/tx/datashard/reshuffle_kmeans.cpp
@@ -189,8 +189,8 @@ protected:
hFunc(TEvTxUserProxy::TEvUploadRowsResponse, Handle);
cFunc(TEvents::TSystem::Wakeup, HandleWakeup);
default:
- LOG_E("TReshuffleKMeansScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: "
- << ev->ToString() << " " << Debug());
+ LOG_E("StateWork unexpected event type: " << ev->GetTypeRewrite()
+ << " event: " << ev->ToString() << " " << Debug());
}
}
diff --git a/ydb/core/tx/datashard/sample_k.cpp b/ydb/core/tx/datashard/sample_k.cpp
index e8472a4d7ec..011418b5e8a 100644
--- a/ydb/core/tx/datashard/sample_k.cpp
+++ b/ydb/core/tx/datashard/sample_k.cpp
@@ -180,7 +180,8 @@ private:
STFUNC(StateWork) {
switch (ev->GetTypeRewrite()) {
default:
- LOG_E("TSampleKScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: " << ev->ToString() << " " << Debug());
+ LOG_E("StateWork unexpected event type: " << ev->GetTypeRewrite()
+ << " event: " << ev->ToString() << " " << Debug());
}
}