diff options
author | kungurtsev <kungasc@ydb.tech> | 2025-04-08 19:25:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-04-08 19:25:21 +0200 |
commit | 16688838aecc0561c93778b50e9ea199673fc004 (patch) | |
tree | d60a5993c8b093b36aace6e90794bd8304e6c3ab | |
parent | 0b84ae0cbe15d2e5bc5fb8a7c2fcff8bcfb2ee2f (diff) | |
download | ydb-16688838aecc0561c93778b50e9ea199673fc004.tar.gz |
Vector Index Local KMeans with one scan (#16909)
-rw-r--r-- | ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp | 267 | ||||
-rw-r--r-- | ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp | 13 | ||||
-rw-r--r-- | ydb/core/tx/datashard/local_kmeans.cpp | 633 | ||||
-rw-r--r-- | ydb/core/tx/datashard/prefix_kmeans.cpp | 30 | ||||
-rw-r--r-- | ydb/core/tx/datashard/reshuffle_kmeans.cpp | 4 | ||||
-rw-r--r-- | ydb/core/tx/datashard/sample_k.cpp | 3 |
6 files changed, 630 insertions, 320 deletions
diff --git a/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp b/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp index 3d081ef34ad..931af10560d 100644 --- a/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp +++ b/ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp @@ -90,9 +90,9 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { } static std::tuple<TString, TString> DoLocalKMeans( - Tests::TServer::TPtr server, TActorId sender, NTableIndex::TClusterId parent, ui64 seed, ui64 k, + Tests::TServer::TPtr server, TActorId sender, NTableIndex::TClusterId parentFrom, NTableIndex::TClusterId parentTo, ui64 seed, ui64 k, NKikimrTxDataShard::EKMeansState upload, VectorIndexSettings::VectorType type, - VectorIndexSettings::Metric metric) + VectorIndexSettings::Metric metric, ui32 maxBatchRows = 50000) { auto id = sId.fetch_add(1, std::memory_order_relaxed); auto& runtime = *server->GetRuntime(); @@ -131,15 +131,17 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { rec.SetNeedsRounds(300); - rec.SetParentFrom(parent); - rec.SetParentTo(parent); - rec.SetChild(parent + 1); + rec.SetParentFrom(parentFrom); + rec.SetParentTo(parentTo); + rec.SetChild(parentTo + 1); rec.SetEmbeddingColumn("embedding"); rec.AddDataColumns("data"); rec.SetLevelName(kLevelTable); rec.SetPostingName(kPostingTable); + + rec.MutableScanSettings()->SetMaxBatchRows(maxBatchRows); }; fill(ev1); fill(ev2); @@ -158,6 +160,10 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { auto level = ReadShardedTable(server, kLevelTable); auto posting = ReadShardedTable(server, kPostingTable); + Cerr << "Level:" << Endl; + Cerr << level << Endl; + Cerr << "Posting:" << Endl; + Cerr << posting << Endl; return {std::move(level), std::move(posting)}; } @@ -223,6 +229,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { auto sender = runtime.AllocateEdgeActor(); runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG); + runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE); InitRoot(server, sender); @@ -310,6 +317,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { auto sender = runtime.AllocateEdgeActor(); runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG); + runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE); InitRoot(server, sender); @@ -346,7 +354,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { seed = 0; for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { - auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_POSTING, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = mm\3\n" @@ -361,7 +369,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { seed = 111; for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { - auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_POSTING, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = 11\3\n" @@ -377,7 +385,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { for (auto similarity : {VectorIndexSettings::SIMILARITY_INNER_PRODUCT, VectorIndexSettings::SIMILARITY_COSINE, VectorIndexSettings::DISTANCE_COSINE}) { - auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_POSTING, VectorIndexSettings::VECTOR_TYPE_UINT8, similarity); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = II\3\n"); @@ -400,6 +408,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { auto sender = runtime.AllocateEdgeActor(); runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG); + runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE); InitRoot(server, sender); @@ -436,7 +445,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { seed = 0; for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { - auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = mm\3\n" @@ -451,7 +460,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { seed = 111; for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { - auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = 11\3\n" @@ -467,7 +476,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { for (auto similarity : {VectorIndexSettings::SIMILARITY_INNER_PRODUCT, VectorIndexSettings::SIMILARITY_COSINE, VectorIndexSettings::DISTANCE_COSINE}) { - auto [level, posting] = DoLocalKMeans(server, sender, 0, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 0, 0, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, similarity); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 0, __ydb_id = 1, __ydb_centroid = II\3\n"); @@ -490,6 +499,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { auto sender = runtime.AllocateEdgeActor(); runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG); + runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE); InitRoot(server, sender); @@ -528,7 +538,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { seed = 0; for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { - auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_POSTING, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = mm\3\n" @@ -543,7 +553,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { seed = 111; for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { - auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_POSTING, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = 11\3\n" @@ -559,7 +569,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { for (auto similarity : {VectorIndexSettings::SIMILARITY_INNER_PRODUCT, VectorIndexSettings::SIMILARITY_COSINE, VectorIndexSettings::DISTANCE_COSINE}) { - auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_POSTING, VectorIndexSettings::VECTOR_TYPE_UINT8, similarity); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = II\3\n"); @@ -582,6 +592,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { auto sender = runtime.AllocateEdgeActor(); runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG); + runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE); InitRoot(server, sender); @@ -620,7 +631,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { seed = 0; for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { - auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = mm\3\n" @@ -635,7 +646,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { seed = 111; for (auto distance : {VectorIndexSettings::DISTANCE_MANHATTAN, VectorIndexSettings::DISTANCE_EUCLIDEAN}) { - auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, distance); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = 11\3\n" @@ -651,7 +662,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { for (auto similarity : {VectorIndexSettings::SIMILARITY_INNER_PRODUCT, VectorIndexSettings::SIMILARITY_COSINE, VectorIndexSettings::DISTANCE_COSINE}) { - auto [level, posting] = DoLocalKMeans(server, sender, 40, seed, k, + auto [level, posting] = DoLocalKMeans(server, sender, 40, 40, seed, k, NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, similarity); UNIT_ASSERT_VALUES_EQUAL(level, "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = II\3\n"); @@ -663,6 +674,228 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) { recreate(); } } + + Y_UNIT_TEST (BuildToBuild_Ranges) { + TPortManager pm; + TServerSettings serverSettings(pm.GetPort(2134)); + serverSettings.SetDomainName("Root"); + + Tests::TServer::TPtr server = new TServer(serverSettings); + auto& runtime = *server->GetRuntime(); + auto sender = runtime.AllocateEdgeActor(); + + runtime.SetLogPriority(NKikimrServices::TX_DATASHARD, NLog::PRI_DEBUG); + runtime.SetLogPriority(NKikimrServices::BUILD_INDEX, NLog::PRI_TRACE); + + InitRoot(server, sender); + + TShardedTableOptions options; + options.EnableOutOfOrder(true); // TODO(mbkkt) what is it? + options.Shards(1); + + CreateBuildTable(server, sender, options, "table-main"); + // Upsert some initial values + ExecSQL(server, sender, + R"( + UPSERT INTO `/Root/table-main` + (__ydb_parent, key, embedding, data) + VALUES )" + "(39, 1, \"\x30\x30\3\", \"one\")," + "(39, 2, \"\x32\x32\3\", \"two\")," + "(40, 1, \"\x30\x30\3\", \"one\")," + "(40, 2, \"\x31\x31\3\", \"two\")," + "(40, 3, \"\x32\x32\3\", \"three\")," + "(40, 4, \"\x65\x65\3\", \"four\")," + "(40, 5, \"\x75\x75\3\", \"five\")," + "(41, 5, \"\x75\x75\3\", \"five2\")," + "(41, 6, \"\x76\x76\3\", \"six\");"); + + auto create = [&] { + CreateLevelTable(server, sender, options); + CreateBuildTable(server, sender, options, "table-posting"); + }; + create(); + auto recreate = [&] { + DropTable(server, sender, "table-level"); + DropTable(server, sender, "table-posting"); + create(); + }; + + { // ParentFrom = 39 ParentTo = 39 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 39, 39, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 39, __ydb_id = 40, __ydb_centroid = 00\3\n" + "__ydb_parent = 39, __ydb_id = 41, __ydb_centroid = 22\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 40, key = 1, embedding = 00\3, data = one\n" + "__ydb_parent = 41, key = 2, embedding = 22\3, data = two\n"); + recreate(); + } + } + + { // ParentFrom = 40 ParentTo = 40 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 40, 40, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 40, __ydb_id = 41, __ydb_centroid = 11\3\n" + "__ydb_parent = 40, __ydb_id = 42, __ydb_centroid = mm\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 41, key = 1, embedding = \x30\x30\3, data = one\n" + "__ydb_parent = 41, key = 2, embedding = \x31\x31\3, data = two\n" + "__ydb_parent = 41, key = 3, embedding = \x32\x32\3, data = three\n" + "__ydb_parent = 42, key = 4, embedding = \x65\x65\3, data = four\n" + "__ydb_parent = 42, key = 5, embedding = \x75\x75\3, data = five\n"); + recreate(); + } + } + + { // ParentFrom = 41 ParentTo = 41 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 41, 41, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 41, __ydb_id = 42, __ydb_centroid = uu\3\n" + "__ydb_parent = 41, __ydb_id = 43, __ydb_centroid = vv\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 42, key = 5, embedding = uu\3, data = five2\n" + "__ydb_parent = 43, key = 6, embedding = vv\3, data = six\n"); + recreate(); + } + } + + { // ParentFrom = 39 ParentTo = 40 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 39, 40, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 39, __ydb_id = 41, __ydb_centroid = 00\3\n" + "__ydb_parent = 39, __ydb_id = 42, __ydb_centroid = 22\3\n" + "__ydb_parent = 40, __ydb_id = 43, __ydb_centroid = 11\3\n" + "__ydb_parent = 40, __ydb_id = 44, __ydb_centroid = mm\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 41, key = 1, embedding = 00\3, data = one\n" + "__ydb_parent = 42, key = 2, embedding = 22\3, data = two\n" + "__ydb_parent = 43, key = 1, embedding = \x30\x30\3, data = one\n" + "__ydb_parent = 43, key = 2, embedding = \x31\x31\3, data = two\n" + "__ydb_parent = 43, key = 3, embedding = \x32\x32\3, data = three\n" + "__ydb_parent = 44, key = 4, embedding = \x65\x65\3, data = four\n" + "__ydb_parent = 44, key = 5, embedding = \x75\x75\3, data = five\n"); + recreate(); + } + } + + { // ParentFrom = 40 ParentTo = 41 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 40, 41, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 40, __ydb_id = 42, __ydb_centroid = 11\3\n" + "__ydb_parent = 40, __ydb_id = 43, __ydb_centroid = mm\3\n" + "__ydb_parent = 41, __ydb_id = 44, __ydb_centroid = uu\3\n" + "__ydb_parent = 41, __ydb_id = 45, __ydb_centroid = vv\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 42, key = 1, embedding = \x30\x30\3, data = one\n" + "__ydb_parent = 42, key = 2, embedding = \x31\x31\3, data = two\n" + "__ydb_parent = 42, key = 3, embedding = \x32\x32\3, data = three\n" + "__ydb_parent = 43, key = 4, embedding = \x65\x65\3, data = four\n" + "__ydb_parent = 43, key = 5, embedding = \x75\x75\3, data = five\n" + "__ydb_parent = 44, key = 5, embedding = uu\3, data = five2\n" + "__ydb_parent = 45, key = 6, embedding = vv\3, data = six\n"); + recreate(); + } + } + + { // ParentFrom = 39 ParentTo = 41 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 39, 41, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 39, __ydb_id = 42, __ydb_centroid = 00\3\n" + "__ydb_parent = 39, __ydb_id = 43, __ydb_centroid = 22\3\n" + "__ydb_parent = 40, __ydb_id = 44, __ydb_centroid = 11\3\n" + "__ydb_parent = 40, __ydb_id = 45, __ydb_centroid = mm\3\n" + "__ydb_parent = 41, __ydb_id = 46, __ydb_centroid = uu\3\n" + "__ydb_parent = 41, __ydb_id = 47, __ydb_centroid = vv\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 42, key = 1, embedding = 00\3, data = one\n" + "__ydb_parent = 43, key = 2, embedding = 22\3, data = two\n" + "__ydb_parent = 44, key = 1, embedding = \x30\x30\3, data = one\n" + "__ydb_parent = 44, key = 2, embedding = \x31\x31\3, data = two\n" + "__ydb_parent = 44, key = 3, embedding = \x32\x32\3, data = three\n" + "__ydb_parent = 45, key = 4, embedding = \x65\x65\3, data = four\n" + "__ydb_parent = 45, key = 5, embedding = \x75\x75\3, data = five\n" + "__ydb_parent = 46, key = 5, embedding = uu\3, data = five2\n" + "__ydb_parent = 47, key = 6, embedding = vv\3, data = six\n"); + recreate(); + } + } + + { // ParentFrom = 30 ParentTo = 50 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 30, 50, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, + "__ydb_parent = 39, __ydb_id = 69, __ydb_centroid = 00\3\n" + "__ydb_parent = 39, __ydb_id = 70, __ydb_centroid = 22\3\n" + "__ydb_parent = 40, __ydb_id = 71, __ydb_centroid = 11\3\n" + "__ydb_parent = 40, __ydb_id = 72, __ydb_centroid = mm\3\n" + "__ydb_parent = 41, __ydb_id = 73, __ydb_centroid = uu\3\n" + "__ydb_parent = 41, __ydb_id = 74, __ydb_centroid = vv\3\n"); + UNIT_ASSERT_VALUES_EQUAL(posting, + "__ydb_parent = 69, key = 1, embedding = 00\3, data = one\n" + "__ydb_parent = 70, key = 2, embedding = 22\3, data = two\n" + "__ydb_parent = 71, key = 1, embedding = \x30\x30\3, data = one\n" + "__ydb_parent = 71, key = 2, embedding = \x31\x31\3, data = two\n" + "__ydb_parent = 71, key = 3, embedding = \x32\x32\3, data = three\n" + "__ydb_parent = 72, key = 4, embedding = \x65\x65\3, data = four\n" + "__ydb_parent = 72, key = 5, embedding = \x75\x75\3, data = five\n" + "__ydb_parent = 73, key = 5, embedding = uu\3, data = five2\n" + "__ydb_parent = 74, key = 6, embedding = vv\3, data = six\n"); + recreate(); + } + } + + { // ParentFrom = 30 ParentTo = 31 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 30, 31, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, ""); + UNIT_ASSERT_VALUES_EQUAL(posting, ""); + recreate(); + } + } + + { // ParentFrom = 100 ParentTo = 101 + for (ui32 maxBatchRows : {0, 1, 4, 5, 6, 50000}) { + auto [level, posting] = DoLocalKMeans(server, sender, + 100, 101, 111, 2, + NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD, VectorIndexSettings::VECTOR_TYPE_UINT8, VectorIndexSettings::DISTANCE_MANHATTAN, + maxBatchRows); + UNIT_ASSERT_VALUES_EQUAL(level, ""); + UNIT_ASSERT_VALUES_EQUAL(posting, ""); + recreate(); + } + } + } } } diff --git a/ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp b/ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp index 4bab0e10981..e221986fee4 100644 --- a/ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp +++ b/ydb/core/tx/datashard/datashard_ut_prefix_kmeans.cpp @@ -291,19 +291,6 @@ Y_UNIT_TEST_SUITE (TTxDataShardPrefixKMeansScan) { DoBadRequest(server, sender, ev, 2, VectorIndexSettings::VECTOR_TYPE_FLOAT, VectorIndexSettings::METRIC_UNSPECIFIED); } - // TODO(mbkkt) For now all build_index, sample_k, build_columns, local_kmeans, prefix_kmeans doesn't really check this - // { - // auto ev = std::make_unique<TEvDataShard::TEvPrefixKMeansRequest>(); - // auto snapshotCopy = snapshot; - // snapshotCopy.Step++; - // DoBadRequest(server, sender, ev); - // } - // { - // auto ev = std::make_unique<TEvDataShard::TEvPrefixKMeansRequest>(); - // auto snapshotCopy = snapshot; - // snapshotCopy.TxId++; - // DoBadRequest(server, sender, ev); - // } } Y_UNIT_TEST (BuildToPosting) { diff --git a/ydb/core/tx/datashard/local_kmeans.cpp b/ydb/core/tx/datashard/local_kmeans.cpp index b39a09c3f60..5b83216c0df 100644 --- a/ydb/core/tx/datashard/local_kmeans.cpp +++ b/ydb/core/tx/datashard/local_kmeans.cpp @@ -22,44 +22,6 @@ namespace NKikimr::NDataShard { using namespace NKMeans; -class TResult { -public: - explicit TResult(const TActorId& responseActorId, TAutoPtr<TEvDataShard::TEvLocalKMeansResponse> response) - : ResponseActorId{responseActorId} - , Response{std::move(response)} { - Y_ASSERT(Response); - } - - void SizeAdd(i64 size) { - if (size != 0) { - std::lock_guard lock{Mutex}; - Size += size; - } - } - - template<typename Func> - void Send(const TActorContext& ctx, Func&& func) { - std::unique_lock lock{Mutex}; - if (Size <= 0) { - return; - } - if (func(Response->Record) && --Size > 0) { - return; - } - Size = 0; - lock.unlock(); - - LOG_N("Finish TLocalKMeansScan " << Response->Record.ShortDebugString()); - ctx.Send(ResponseActorId, std::move(Response)); - } - -private: - std::mutex Mutex; - i64 Size = 1; - TActorId ResponseActorId; - TAutoPtr<TEvDataShard::TEvLocalKMeansResponse> Response; -}; - // This scan needed to run local (not distributed) kmeans. // We have this local stage because we construct kmeans tree from top to bottom. // And bottom kmeans can be constructed completely locally in datashards to avoid extra communication. @@ -81,7 +43,7 @@ private: // NTable::IScan::Seek used to switch from current state to the next one. // If less than 1% of vectors are reassigned to new clusters we want to stop -// TODO(mbkkt) 1% is choosed by common sense and should be adjusted in future +// TODO(mbkkt) 1% is choosen by common sense and should be adjusted in future static constexpr double MinVectorsNeedsReassigned = 0.01; class TLocalKMeansScanBase: public TActor<TLocalKMeansScanBase>, public NTable::IScan { @@ -92,12 +54,13 @@ protected: NTableIndex::TClusterId Child = 0; ui32 Round = 0; - ui32 MaxRounds = 0; + const ui32 MaxRounds = 0; + const ui32 InitK = 0; ui32 K = 0; EState State; - EState UploadState; + const EState UploadState; IDriver* Driver = nullptr; @@ -124,14 +87,17 @@ protected: std::vector<ui64> ClusterSizes; // Upload - std::shared_ptr<NTxProxy::TUploadTypes> TargetTypes; - std::shared_ptr<NTxProxy::TUploadTypes> NextTypes; + std::shared_ptr<NTxProxy::TUploadTypes> LevelTypes; + std::shared_ptr<NTxProxy::TUploadTypes> PostingTypes; + std::shared_ptr<NTxProxy::TUploadTypes> UploadTypes; - TString TargetTable; - TString NextTable; + const TString LevelTable; + const TString PostingTable; + TString UploadTable; - TBufferData ReadBuf; - TBufferData WriteBuf; + TBufferData LevelBuf; + TBufferData PostingBuf; + TBufferData UploadBuf; NTable::TPos EmbeddingPos = 0; NTable::TPos DataPos = 1; @@ -141,15 +107,25 @@ protected: TActorId Uploader; const TIndexBuildScanSettings ScanSettings; - NTable::TTag KMeansScan; - TTags UploadScan; + NTable::TTag EmbeddingTag; + TTags ScanTags; TUploadStatus UploadStatus; ui64 UploadRows = 0; ui64 UploadBytes = 0; - std::shared_ptr<TResult> Result; + TActorId ResponseActorId; + TAutoPtr<TEvDataShard::TEvLocalKMeansResponse> Response; + + // FIXME: save PrefixRows as std::vector<std::pair<TSerializedCellVec, TSerializedCellVec>> to avoid parsing + const ui32 PrefixColumns; + TSerializedCellVec Prefix; + TBufferData PrefixRows; + bool IsFirstPrefixFeed = true; + bool IsPrefixRowsValid = true; + + bool IsExhausted = false; public: static constexpr NKikimrServices::TActivity::EType ActorActivityType() @@ -157,38 +133,45 @@ public: return NKikimrServices::TActivity::LOCAL_KMEANS_SCAN_ACTOR; } - TLocalKMeansScanBase(ui64 buildId, const TUserTable& table, TLead&& lead, NTableIndex::TClusterId parent, NTableIndex::TClusterId child, + TLocalKMeansScanBase(const TUserTable& table, const NKikimrTxDataShard::TEvLocalKMeansRequest& request, - std::shared_ptr<TResult> result) + const TActorId& responseActorId, + TAutoPtr<TEvDataShard::TEvLocalKMeansResponse>&& response, + TLead&& lead) : TActor{&TThis::StateWork} - , Parent{parent} - , Child{child} + , Parent{request.GetParentFrom()} + , Child{request.GetChild()} , MaxRounds{request.GetNeedsRounds()} + , InitK{request.GetK()} , K{request.GetK()} , State{EState::SAMPLE} , UploadState{request.GetUpload()} , Lead{std::move(lead)} - , BuildId{buildId} + , BuildId{request.GetId()} , Rng{request.GetSeed()} - , TargetTable{request.GetLevelName()} - , NextTable{request.GetPostingName()} + , LevelTable{request.GetLevelName()} + , PostingTable{request.GetPostingName()} , ScanSettings(request.GetScanSettings()) - , Result{std::move(result)} + , ResponseActorId{responseActorId} + , Response{std::move(response)} + , PrefixColumns{request.GetParentFrom() == 0 && request.GetParentTo() == 0 ? 0u : 1u} { const auto& embedding = request.GetEmbeddingColumn(); const auto& data = request.GetDataColumns(); // scan tags - UploadScan = MakeUploadTags(table, embedding, data, EmbeddingPos, DataPos, KMeansScan); + ScanTags = MakeUploadTags(table, embedding, data, EmbeddingPos, DataPos, EmbeddingTag); + Lead.SetTags(ScanTags); // upload types - if (Ydb::Type type; State <= EState::KMEANS) { - TargetTypes = std::make_shared<NTxProxy::TUploadTypes>(3); + { + Ydb::Type type; + LevelTypes = std::make_shared<NTxProxy::TUploadTypes>(3); type.set_type_id(NTableIndex::ClusterIdType); - (*TargetTypes)[0] = {NTableIndex::NTableVectorKmeansTreeIndex::ParentColumn, type}; - (*TargetTypes)[1] = {NTableIndex::NTableVectorKmeansTreeIndex::IdColumn, type}; + (*LevelTypes)[0] = {NTableIndex::NTableVectorKmeansTreeIndex::ParentColumn, type}; + (*LevelTypes)[1] = {NTableIndex::NTableVectorKmeansTreeIndex::IdColumn, type}; type.set_type_id(Ydb::Type::STRING); - (*TargetTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn, type}; + (*LevelTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn, type}; } - NextTypes = MakeUploadTypes(table, UploadState, embedding, data); + PostingTypes = MakeUploadTypes(table, UploadState, embedding, data); } TInitialState Prepare(IDriver* driver, TIntrusiveConstPtr<TScheme>) final @@ -202,30 +185,27 @@ public: TAutoPtr<IDestructable> Finish(EAbort abort) final { - LOG_D("Finish " << Debug()); - if (Uploader) { - Send(Uploader, new TEvents::TEvPoisonPill); + Send(Uploader, new TEvents::TEvPoison); Uploader = {}; } - Result->Send(TActivationContext::AsActorContext(), [&] (NKikimrTxDataShard::TEvLocalKMeansResponse& response) { - response.SetReadRows(ReadRows); - response.SetReadBytes(ReadBytes); - response.SetUploadRows(UploadRows); - response.SetUploadBytes(UploadBytes); - NYql::IssuesToMessage(UploadStatus.Issues, response.MutableIssues()); - if (abort != EAbort::None) { - response.SetStatus(NKikimrIndexBuilder::EBuildStatus::ABORTED); - return false; - } else if (UploadStatus.IsSuccess()) { - response.SetStatus(NKikimrIndexBuilder::EBuildStatus::DONE); - return true; - } else { - response.SetStatus(NKikimrIndexBuilder::EBuildStatus::BUILD_ERROR); - return false; - } - }); + auto& record = Response->Record; + record.SetReadRows(ReadRows); + record.SetReadBytes(ReadBytes); + record.SetUploadRows(UploadRows); + record.SetUploadBytes(UploadBytes); + if (abort != EAbort::None) { + record.SetStatus(NKikimrIndexBuilder::EBuildStatus::ABORTED); + } else if (UploadStatus.IsNone() || UploadStatus.IsSuccess()) { + record.SetStatus(NKikimrIndexBuilder::EBuildStatus::DONE); + } else { + record.SetStatus(NKikimrIndexBuilder::EBuildStatus::BUILD_ERROR); + } + NYql::IssuesToMessage(UploadStatus.Issues, record.MutableIssues()); + + LOG_N("Finish " << Debug() << " " << Response->Record.ShortDebugString()); + Send(ResponseActorId, Response.Release()); Driver = nullptr; this->PassAway(); @@ -240,9 +220,10 @@ public: TString Debug() const { return TStringBuilder() << "TLocalKMeansScan Id: " << BuildId << " Parent: " << Parent << " Child: " << Child - << " Target: " << TargetTable << " K: " << K << " Clusters: " << Clusters.size() + << " K: " << K << " Clusters: " << Clusters.size() << " State: " << State << " Round: " << Round << " / " << MaxRounds - << " ReadBuf size: " << ReadBuf.Size() << " WriteBuf size: " << WriteBuf.Size(); + << " LevelBuf size: " << LevelBuf.Size() << " PostingBuf size: " << PostingBuf.Size() + << " UploadTable: " << UploadTable << " UploadBuf size: " << UploadBuf.Size() << " RetryCount: " << RetryCount; } EScan PageFault() final @@ -258,28 +239,31 @@ protected: HFunc(TEvTxUserProxy::TEvUploadRowsResponse, Handle); CFunc(TEvents::TSystem::Wakeup, HandleWakeup); default: - LOG_E("TLocalKMeansScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: " - << ev->ToString() << " " << Debug()); + LOG_E("StateWork unexpected event type: " << ev->GetTypeRewrite() + << " event: " << ev->ToString() << " " << Debug()); } } void HandleWakeup(const NActors::TActorContext& /*ctx*/) { - LOG_I("Retry upload " << Debug()); + LOG_D("Retry upload " << Debug()); - if (!WriteBuf.IsEmpty()) { - Upload(true); + if (UploadInProgress()) { + RetryUpload(); } } void Handle(TEvTxUserProxy::TEvUploadRowsResponse::TPtr& ev, const TActorContext& ctx) { LOG_D("Handle TEvUploadRowsResponse " << Debug() - << " Uploader: " << Uploader.ToString() << " ev->Sender: " << ev->Sender.ToString()); + << " Uploader: " << (Uploader ? Uploader.ToString() : "<null>") + << " ev->Sender: " << ev->Sender.ToString()); if (Uploader) { - Y_ENSURE(Uploader == ev->Sender, "Mismatch Uploader: " << Uploader.ToString() << " ev->Sender: " - << ev->Sender.ToString() << Debug()); + Y_ENSURE(Uploader == ev->Sender, "Mismatch" + << " Uploader: " << Uploader.ToString() + << " Sender: " << ev->Sender.ToString()); + Uploader = {}; } else { Y_ENSURE(Driver == nullptr); return; @@ -288,81 +272,109 @@ protected: UploadStatus.StatusCode = ev->Get()->Status; UploadStatus.Issues = ev->Get()->Issues; if (UploadStatus.IsSuccess()) { - UploadRows += WriteBuf.GetRows(); - UploadBytes += WriteBuf.GetBytes(); - WriteBuf.Clear(); - if (HasReachedLimits(ReadBuf, ScanSettings)) { - ReadBuf.FlushTo(WriteBuf); - Upload(false); - } + UploadRows += UploadBuf.GetRows(); + UploadBytes += UploadBuf.GetBytes(); + UploadBuf.Clear(); + + TryUpload(LevelBuf, LevelTable, LevelTypes, true) + || TryUpload(PostingBuf, PostingTable, PostingTypes, true); Driver->Touch(EScan::Feed); return; } if (RetryCount < ScanSettings.GetMaxBatchRetries() && UploadStatus.IsRetriable()) { - LOG_N("Got retriable error, " << Debug() << UploadStatus.ToString()); + LOG_N("Got retriable error, " << Debug() << " " << UploadStatus.ToString()); ctx.Schedule(GetRetryWakeupTimeoutBackoff(RetryCount), new TEvents::TEvWakeup()); return; } - LOG_N("Got error, abort scan, " << Debug() << UploadStatus.ToString()); + LOG_N("Got error, abort scan, " << Debug() << " " << UploadStatus.ToString()); Driver->Touch(EScan::Final); } - EScan FeedUpload() + bool ShouldWaitUpload() { - if (!HasReachedLimits(ReadBuf, ScanSettings)) { - return EScan::Feed; + if (!HasReachedLimits(LevelBuf, ScanSettings) && !HasReachedLimits(PostingBuf, ScanSettings)) { + return false; } - if (!WriteBuf.IsEmpty()) { - return EScan::Sleep; + + if (UploadInProgress()) { + return true; } - ReadBuf.FlushTo(WriteBuf); - Upload(false); - return EScan::Feed; - } + + TryUpload(LevelBuf, LevelTable, LevelTypes, true) + || TryUpload(PostingBuf, PostingTable, PostingTypes, true); - ui64 GetProbability() - { - return Rng.GenRand64(); + return !HasReachedLimits(LevelBuf, ScanSettings) && !HasReachedLimits(PostingBuf, ScanSettings); } - void Upload(bool isRetry) + void UploadImpl() { - if (isRetry) { - ++RetryCount; - } else { - RetryCount = 0; - if (State != EState::KMEANS && NextTypes) { - TargetTypes = std::exchange(NextTypes, {}); - TargetTable = std::move(NextTable); - } - } + LOG_D("Uploading " << Debug()); + Y_ASSERT(!UploadBuf.IsEmpty()); + Y_ASSERT(!Uploader); auto actor = NTxProxy::CreateUploadRowsInternal( - this->SelfId(), TargetTable, TargetTypes, WriteBuf.GetRowsData(), + this->SelfId(), UploadTable, UploadTypes, UploadBuf.GetRowsData(), NTxProxy::EUploadRowsMode::WriteToTableShadow, true /*writeToPrivateTable*/); Uploader = this->Register(actor); } - void UploadSample() + void InitUpload(std::string_view table, std::shared_ptr<NTxProxy::TUploadTypes> types) + { + RetryCount = 0; + UploadTable = table; + UploadTypes = std::move(types); + UploadImpl(); + } + + void RetryUpload() + { + ++RetryCount; + UploadImpl(); + } + + bool UploadInProgress() + { + return !UploadBuf.IsEmpty(); + } + + bool TryUpload(TBufferData& buffer, const TString& table, const std::shared_ptr<NTxProxy::TUploadTypes>& types, bool byLimit) + { + if (Y_UNLIKELY(UploadInProgress())) { + // already uploading something + return true; + } + + if (!buffer.IsEmpty() && (!byLimit || HasReachedLimits(buffer, ScanSettings))) { + buffer.FlushTo(UploadBuf); + InitUpload(table, types); + return true; + } + + return false; + } + + void FormLevelRows() { - Y_ASSERT(ReadBuf.IsEmpty()); - Y_ASSERT(WriteBuf.IsEmpty()); std::array<TCell, 2> pk; std::array<TCell, 1> data; for (NTable::TPos pos = 0; const auto& row : Clusters) { pk[0] = TCell::Make(Parent); pk[1] = TCell::Make(Child + pos); data[0] = TCell{row}; - WriteBuf.AddRow(TSerializedCellVec{pk}, TSerializedCellVec::Serialize(data)); + LevelBuf.AddRow(TSerializedCellVec{pk}, TSerializedCellVec::Serialize(data)); ++pos; } - Upload(false); + } + + ui64 GetProbability() + { + return Rng.GenRand64(); } }; @@ -377,10 +389,29 @@ class TLocalKMeansScan final: public TLocalKMeansScanBase, private TCalculation< }; std::vector<TAggregatedCluster> AggregatedClusters; + void StartNewPrefix() { + Round = 0; + K = InitK; + State = EState::SAMPLE; + Lead.Valid = true; + Lead.Key = TSerializedCellVec(Prefix.GetCells()); // seek to (prefix, inf) + Lead.Relation = NTable::ESeek::Upper; + Prefix = {}; + IsFirstPrefixFeed = true; + IsPrefixRowsValid = true; + PrefixRows.Clear(); + MaxProbability = std::numeric_limits<ui64>::max(); + MaxRows.clear(); + Clusters.clear(); + ClusterSizes.clear(); + AggregatedClusters.clear(); + } + public: - TLocalKMeansScan(ui64 buildId, const TUserTable& table, TLead&& lead, NTableIndex::TClusterId parent, NTableIndex::TClusterId child, NKikimrTxDataShard::TEvLocalKMeansRequest& request, - std::shared_ptr<TResult> result) - : TLocalKMeansScanBase{buildId, table, std::move(lead), parent, child, request, std::move(result)} + TLocalKMeansScan(const TUserTable& table, NKikimrTxDataShard::TEvLocalKMeansRequest& request, + const TActorId& responseActorId, TAutoPtr<TEvDataShard::TEvLocalKMeansResponse>&& response, + TLead&& lead) + : TLocalKMeansScanBase{table, request, responseActorId, std::move(response), std::move(lead)} { this->Dimensions = request.GetSettings().vector_dimension(); LOG_I("Create " << Debug()); @@ -388,90 +419,143 @@ public: EScan Seek(TLead& lead, ui64 seq) final { - LOG_D("Seek " << Debug()); - if (State == UploadState) { - if (!WriteBuf.IsEmpty()) { - return EScan::Sleep; - } - if (!ReadBuf.IsEmpty()) { - ReadBuf.FlushTo(WriteBuf); - Upload(false); + LOG_D("Seek " << seq << " " << Debug()); + + if (IsExhausted) { + if (UploadInProgress() + || TryUpload(LevelBuf, LevelTable, LevelTypes, false) + || TryUpload(PostingBuf, PostingTable, PostingTypes, false)) + { return EScan::Sleep; } - if (UploadStatus.IsNone()) { - UploadStatus.StatusCode = Ydb::StatusIds::SUCCESS; - } return EScan::Final; } - if (State == EState::SAMPLE) { - lead = Lead; - lead.SetTags({&KMeansScan, 1}); - if (seq == 0) { - return EScan::Feed; + lead = Lead; + + return EScan::Feed; + } + + EScan Feed(TArrayRef<const TCell> key, const TRow& row) final + { + LOG_T("Feed " << Debug()); + + ++ReadRows; + ReadBytes += CountBytes(key, row); + + if (PrefixColumns && Prefix && !TCellVectorsEquals{}(Prefix.GetCells(), key.subspan(0, PrefixColumns))) { + if (!FinishPrefix()) { + // scan current prefix rows with a new state again + return EScan::Reset; } - State = EState::KMEANS; - if (!InitAggregatedClusters()) { - // We don't need to do anything, - // because this datashard doesn't have valid embeddings for this parent - if (UploadStatus.IsNone()) { - UploadStatus.StatusCode = Ydb::StatusIds::SUCCESS; - } - return EScan::Final; + } + + if (PrefixColumns && !Prefix) { + Prefix = TSerializedCellVec{key.subspan(0, PrefixColumns)}; + auto newParent = key.at(0).template AsValue<ui64>(); + Child += (newParent - Parent) * InitK; + Parent = newParent; + } + + if (IsFirstPrefixFeed && IsPrefixRowsValid) { + PrefixRows.AddRow(TSerializedCellVec{key}, TSerializedCellVec::Serialize(*row)); + if (HasReachedLimits(PrefixRows, ScanSettings)) { + PrefixRows.Clear(); + IsPrefixRowsValid = false; } - ++Round; - return EScan::Feed; } - Y_ASSERT(State == EState::KMEANS); - if (RecomputeClusters()) { - lead = std::move(Lead); - lead.SetTags(UploadScan); + Feed(key, *row); + + return ShouldWaitUpload() ? EScan::Sleep : EScan::Feed; + } - UploadSample(); - State = UploadState; + EScan Exhausted() final + { + LOG_D("Exhausted " << Debug()); + + if (!FinishPrefix()) { + return EScan::Reset; + } + + IsExhausted = true; + + // call Seek to wait uploads + return EScan::Reset; + } + +private: + bool FinishPrefix() + { + if (FinishPrefixImpl()) { + StartNewPrefix(); + LOG_D("FinishPrefix finished " << Debug()); + return true; } else { - lead = Lead; - lead.SetTags({&KMeansScan, 1}); - ++Round; + IsFirstPrefixFeed = false; + + if (IsPrefixRowsValid) { + LOG_D("FinishPrefix not finished, manually feeding " << PrefixRows.Size() << " saved rows " << Debug()); + for (ui64 iteration = 0; ; iteration++) { + for (const auto& [key, row_] : *PrefixRows.GetRowsData()) { + TSerializedCellVec row(row_); + Feed(key.GetCells(), row.GetCells()); + } + if (FinishPrefixImpl()) { + StartNewPrefix(); + LOG_D("FinishPrefix finished in " << iteration << " iterations " << Debug()); + return true; + } else { + LOG_D("FinishPrefix not finished in " << iteration << " iterations " << Debug()); + } + } + } else { + LOG_D("FinishPrefix not finished, rescanning rows " << Debug()); + } + + return false; } - return EScan::Feed; } - EScan Feed(TArrayRef<const TCell> key, const TRow& row_) final + bool FinishPrefixImpl() { - LOG_T("Feed " << Debug()); + if (State == EState::SAMPLE) { + State = EState::KMEANS; + if (!InitAggregatedClusters()) { + // We don't need to do anything, + // because this datashard doesn't have valid embeddings for this prefix + return true; + } + Round = 1; + return false; // do KMEANS + } - ++ReadRows; - ReadBytes += CountBytes(key, row_); - auto row = *row_; - - switch (State) { - case EState::SAMPLE: - return FeedSample(row); - case EState::KMEANS: - return FeedKMeans(row); - case EState::UPLOAD_MAIN_TO_BUILD: - return FeedUploadMain2Build(key, row); - case EState::UPLOAD_MAIN_TO_POSTING: - return FeedUploadMain2Posting(key, row); - case EState::UPLOAD_BUILD_TO_BUILD: - return FeedUploadBuild2Build(key, row); - case EState::UPLOAD_BUILD_TO_POSTING: - return FeedUploadBuild2Posting(key, row); - default: - return EScan::Final; + if (State == EState::KMEANS) { + if (RecomputeClusters()) { + FormLevelRows(); + State = UploadState; + return false; // do UPLOAD_* + } else { + ++Round; + return false; // recompute KMEANS + } } + + if (State == UploadState) { + return true; + } + + Y_ASSERT(false); + return true; } -private: bool InitAggregatedClusters() { if (Clusters.size() == 0) { return false; } if (Clusters.size() < K) { - // if this datashard have smaller than K count of valid embeddings for this parent + // if this datashard have less than K valid embeddings for this parent // lets make single centroid for it K = 1; Clusters.resize(K); @@ -544,12 +628,37 @@ private: return true; } - EScan FeedSample(TArrayRef<const TCell> row) + void Feed(TArrayRef<const TCell> key, TArrayRef<const TCell> row) { - Y_ASSERT(row.size() == 1); - const auto embedding = row.at(0).AsRef(); + switch (State) { + case EState::SAMPLE: + FeedSample(row); + break; + case EState::KMEANS: + FeedKMeans(row); + break; + case EState::UPLOAD_MAIN_TO_BUILD: + FeedUploadMain2Build(key, row); + break; + case EState::UPLOAD_MAIN_TO_POSTING: + FeedUploadMain2Posting(key, row); + break; + case EState::UPLOAD_BUILD_TO_BUILD: + FeedUploadBuild2Build(key, row); + break; + case EState::UPLOAD_BUILD_TO_POSTING: + FeedUploadBuild2Posting(key, row); + break; + default: + Y_ASSERT(false); + } + } + + void FeedSample(TArrayRef<const TCell> row) + { + const auto embedding = row.at(EmbeddingPos).AsRef(); if (!this->IsExpectedSize(embedding)) { - return EScan::Feed; + return; } const auto probability = GetProbability(); @@ -568,55 +677,44 @@ private: std::push_heap(MaxRows.begin(), MaxRows.end()); MaxProbability = MaxRows.front().P; } - return MaxProbability != 0 ? EScan::Feed : EScan::Reset; } - EScan FeedKMeans(TArrayRef<const TCell> row) + void FeedKMeans(TArrayRef<const TCell> row) { - Y_ASSERT(row.size() == 1); - const ui32 pos = FeedEmbedding(*this, Clusters, row, 0); - AggregateToCluster(pos, row.at(0).Data()); - return EScan::Feed; + const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos); + AggregateToCluster(pos, row.at(EmbeddingPos).Data()); } - EScan FeedUploadMain2Build(TArrayRef<const TCell> key, TArrayRef<const TCell> row) + void FeedUploadMain2Build(TArrayRef<const TCell> key, TArrayRef<const TCell> row) { const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos); - if (pos >= K) { - return EScan::Feed; + if (pos < K) { + AddRowMain2Build(PostingBuf, Child + pos, key, row); } - AddRowMain2Build(ReadBuf, Child + pos, key, row); - return FeedUpload(); } - EScan FeedUploadMain2Posting(TArrayRef<const TCell> key, TArrayRef<const TCell> row) + void FeedUploadMain2Posting(TArrayRef<const TCell> key, TArrayRef<const TCell> row) { const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos); - if (pos >= K) { - return EScan::Feed; + if (pos < K) { + AddRowMain2Posting(PostingBuf, Child + pos, key, row, DataPos); } - AddRowMain2Posting(ReadBuf, Child + pos, key, row, DataPos); - return FeedUpload(); } - EScan FeedUploadBuild2Build(TArrayRef<const TCell> key, TArrayRef<const TCell> row) + void FeedUploadBuild2Build(TArrayRef<const TCell> key, TArrayRef<const TCell> row) { const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos); - if (pos >= K) { - return EScan::Feed; + if (pos < K) { + AddRowBuild2Build(PostingBuf, Child + pos, key, row); } - AddRowBuild2Build(ReadBuf, Child + pos, key, row); - return FeedUpload(); } - EScan FeedUploadBuild2Posting(TArrayRef<const TCell> key, TArrayRef<const TCell> row) + void FeedUploadBuild2Posting(TArrayRef<const TCell> key, TArrayRef<const TCell> row) { const ui32 pos = FeedEmbedding(*this, Clusters, row, EmbeddingPos); - if (pos >= K) { - return EScan::Feed; + if (pos < K) { + AddRowBuild2Posting(PostingBuf, Child + pos, key, row, DataPos); } - AddRowBuild2Posting(ReadBuf, Child + pos, key, row, DataPos); - return FeedUpload(); } }; @@ -671,22 +769,14 @@ void TDataShard::HandleSafe(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const TScanRecord::TSeqNo seqNo = {request.GetSeqNoGeneration(), request.GetSeqNoRound()}; response->Record.SetRequestSeqNoGeneration(seqNo.Generation); response->Record.SetRequestSeqNoRound(seqNo.Round); - auto result = std::make_shared<TResult>(ev->Sender, std::move(response)); - ui32 localTid = 0; - TScanRecord::TScanIds scanIds; auto badRequest = [&](const TString& error) { - for (auto scanId : scanIds) { - CancelScan(localTid, scanId); - } - result->Send(ctx, [&] (NKikimrTxDataShard::TEvLocalKMeansResponse& response) { - response.SetStatus(NKikimrIndexBuilder::EBuildStatus::BAD_REQUEST); - auto issue = response.AddIssues(); - issue->set_severity(NYql::TSeverityIds::S_ERROR); - issue->set_message(error); - return false; - }); - result.reset(); + response->Record.SetStatus(NKikimrIndexBuilder::EBuildStatus::BAD_REQUEST); + auto issue = response->Record.AddIssues(); + issue->set_severity(NYql::TSeverityIds::S_ERROR); + issue->set_message(error); + ctx.Send(ev->Sender, std::move(response)); + response.Reset(); }; if (const ui64 shardId = request.GetTabletId(); shardId != TabletID()) { @@ -715,8 +805,6 @@ void TDataShard::HandleSafe(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const ScanManager.Drop(id); } - localTid = userTable.LocalTid; - if (request.HasSnapshotStep() || request.HasSnapshotTxId()) { const TSnapshotKey snapshotKey(pathId, rowVersion.Step, rowVersion.TxId); if (!SnapshotManager.FindAvailable(snapshotKey)) { @@ -743,45 +831,44 @@ void TDataShard::HandleSafe(TEvDataShard::TEvLocalKMeansRequest::TPtr& ev, const badRequest(TStringBuilder() << "Parent from " << parentFrom << " should be less or equal to parent to " << parentTo); return; } - const i64 expectedSize = parentTo - parentFrom + 1; - result->SizeAdd(expectedSize); - for (auto parent = parentFrom; parent <= parentTo; ++parent) { - TCell from, to; - const auto range = CreateRangeFrom(userTable, parent, from, to); - if (range.IsEmptyRange(userTable.KeyColumnTypes)) { - LOG_D("TEvLocalKMeansRequst " << request.GetId() << " parent " << parent << " is empty"); - continue; - } - - TAutoPtr<NTable::IScan> scan; - auto createScan = [&]<typename T> { - scan = new TLocalKMeansScan<T>{ - request.GetId(), userTable, - CreateLeadFrom(range), parent, request.GetChild() + request.GetK() * (parent - parentFrom), - request, result, - }; - }; - MakeScan(request, createScan, badRequest); - if (!scan) { - Y_ASSERT(!result); + NTable::TLead lead; + if (parentFrom == 0 && parentTo == 0) { + lead.To({}, NTable::ESeek::Lower); + } else { + TCell from = TCell::Make(parentFrom - 1); + TCell to = TCell::Make(parentTo); + TTableRange range{{&from, 1}, false, {&to, 1}, true}; + auto scanRange = Intersect(userTable.KeyColumnTypes, range, userTable.Range.ToTableRange()); + if (scanRange.IsEmptyRange(userTable.KeyColumnTypes)) { + badRequest(TStringBuilder() << " requested range doesn't intersect with table range" + << " requestedRange: " << DebugPrintRange(userTable.KeyColumnTypes, range, *AppData()->TypeRegistry) + << " tableRange: " << DebugPrintRange(userTable.KeyColumnTypes, userTable.Range.ToTableRange(), *AppData()->TypeRegistry) + << " scanRange: " << DebugPrintRange(userTable.KeyColumnTypes, scanRange, *AppData()->TypeRegistry)); return; } - - TScanOptions scanOpts; - scanOpts.SetSnapshotRowVersion(rowVersion); - scanOpts.SetResourceBroker("build_index", 10); // TODO(mbkkt) Should be different group? - const auto scanId = QueueScan(userTable.LocalTid, std::move(scan), 0, scanOpts); - scanIds.push_back(scanId); + lead.To(range.From, NTable::ESeek::Upper); + lead.Until(range.To, true); } - - if (scanIds.empty()) { - badRequest("Requested range doesn't intersect with table range"); + + TAutoPtr<NTable::IScan> scan; + auto createScan = [&]<typename T> { + scan = new TLocalKMeansScan<T>{ + userTable, request, ev->Sender, std::move(response), + std::move(lead) + }; + }; + MakeScan(request, createScan, badRequest); + if (!scan) { + Y_ASSERT(!response); return; } - result->SizeAdd(static_cast<i64>(scanIds.size()) - expectedSize); - result->Send(ctx, [] (auto&) { return true; }); // decrement extra one - ScanManager.Set(id, seqNo) = std::move(scanIds); + + TScanOptions scanOpts; + scanOpts.SetSnapshotRowVersion(rowVersion); + scanOpts.SetResourceBroker("build_index", 10); // TODO(mbkkt) Should be different group? + const auto scanId = QueueScan(userTable.LocalTid, std::move(scan), 0, scanOpts); + ScanManager.Set(id, seqNo).push_back(scanId); } } diff --git a/ydb/core/tx/datashard/prefix_kmeans.cpp b/ydb/core/tx/datashard/prefix_kmeans.cpp index 5a36035d782..8742c3cf8ff 100644 --- a/ydb/core/tx/datashard/prefix_kmeans.cpp +++ b/ydb/core/tx/datashard/prefix_kmeans.cpp @@ -101,12 +101,12 @@ protected: TAutoPtr<TEvDataShard::TEvPrefixKMeansResponse> Response; // FIXME: save PrefixRows as std::vector<std::pair<TSerializedCellVec, TSerializedCellVec>> to avoid parsing - ui32 PrefixColumns; + const ui32 PrefixColumns; TSerializedCellVec Prefix; TBufferData PrefixRows; bool IsFirstPrefixFeed = true; bool IsPrefixRowsValid = true; - + bool IsExhausted = false; public: @@ -153,6 +153,7 @@ public: (*LevelTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn, type}; } PostingTypes = MakeUploadTypes(table, UploadState, embedding, data, PrefixColumns); + // prefix types { auto types = GetAllTypes(table); @@ -242,8 +243,8 @@ protected: HFunc(TEvTxUserProxy::TEvUploadRowsResponse, Handle); CFunc(TEvents::TSystem::Wakeup, HandleWakeup); default: - LOG_E("TPrefixKMeansScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: " - << ev->ToString() << " " << Debug()); + LOG_E("StateWork unexpected event type: " << ev->GetTypeRewrite() + << " event: " << ev->ToString() << " " << Debug()); } } @@ -299,11 +300,6 @@ protected: Driver->Touch(EScan::Final); } - ui64 GetProbability() - { - return Rng.GenRand64(); - } - bool ShouldWaitUpload() { if (!HasReachedLimits(LevelBuf, ScanSettings) && !HasReachedLimits(PostingBuf, ScanSettings) && !HasReachedLimits(PrefixBuf, ScanSettings)) { @@ -381,6 +377,11 @@ protected: ++pos; } } + + ui64 GetProbability() + { + return Rng.GenRand64(); + } }; template <typename TMetric> @@ -394,17 +395,13 @@ class TPrefixKMeansScan final: public TPrefixKMeansScanBase, private TCalculatio }; std::vector<TAggregatedCluster> AggregatedClusters; - void StartNewPrefix() { Parent = Child + K; Child = Parent + 1; Round = 0; K = InitK; State = EState::SAMPLE; - // TODO(mbkkt) Upper or Lower doesn't matter here, because we seek to (prefix, inf) - // so we can choose Lower if it's faster. - // Exact seek with Lower also possible but needs to rewrite some code in Feed - Lead.To(Prefix.GetCells(), NTable::ESeek::Upper); + Lead.To(Prefix.GetCells(), NTable::ESeek::Upper); // seek to (prefix, inf) Prefix = {}; IsFirstPrefixFeed = true; IsPrefixRowsValid = true; @@ -804,6 +801,11 @@ void TDataShard::HandleSafe(TEvDataShard::TEvPrefixKMeansRequest::TPtr& ev, cons return; } + if (request.GetPrefixColumns() <= 0) { + badRequest("Should be requested on at least one prefix column"); + return; + } + TAutoPtr<NTable::IScan> scan; auto createScan = [&]<typename T> { scan = new TPrefixKMeansScan<T>{ diff --git a/ydb/core/tx/datashard/reshuffle_kmeans.cpp b/ydb/core/tx/datashard/reshuffle_kmeans.cpp index f7db951e60b..f3a7f900876 100644 --- a/ydb/core/tx/datashard/reshuffle_kmeans.cpp +++ b/ydb/core/tx/datashard/reshuffle_kmeans.cpp @@ -189,8 +189,8 @@ protected: hFunc(TEvTxUserProxy::TEvUploadRowsResponse, Handle); cFunc(TEvents::TSystem::Wakeup, HandleWakeup); default: - LOG_E("TReshuffleKMeansScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: " - << ev->ToString() << " " << Debug()); + LOG_E("StateWork unexpected event type: " << ev->GetTypeRewrite() + << " event: " << ev->ToString() << " " << Debug()); } } diff --git a/ydb/core/tx/datashard/sample_k.cpp b/ydb/core/tx/datashard/sample_k.cpp index e8472a4d7ec..011418b5e8a 100644 --- a/ydb/core/tx/datashard/sample_k.cpp +++ b/ydb/core/tx/datashard/sample_k.cpp @@ -180,7 +180,8 @@ private: STFUNC(StateWork) { switch (ev->GetTypeRewrite()) { default: - LOG_E("TSampleKScan: StateWork unexpected event type: " << ev->GetTypeRewrite() << " event: " << ev->ToString() << " " << Debug()); + LOG_E("StateWork unexpected event type: " << ev->GetTypeRewrite() + << " event: " << ev->ToString() << " " << Debug()); } } |