diff options
author | Vitaliy Filippov <vitalif@mail.ru> | 2025-05-19 11:54:39 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-19 11:54:39 +0300 |
commit | b7d90103f4b5381fa94e1c08a58c123713e47007 (patch) | |
tree | 7b45c87768405ad846d23ecd1bebd83fb7d84ec7 | |
parent | 4ad14e0c307069246bc5b1a07b7f541a11fe4788 (diff) | |
download | ydb-b7d90103f4b5381fa94e1c08a58c123713e47007.tar.gz |
Implement index-only searches with covering vector indexes (#17770) (#18137)
11 files changed, 288 insertions, 97 deletions
diff --git a/ydb/core/base/table_index.cpp b/ydb/core/base/table_index.cpp index f9bc4df1c3b..f3c828d8db1 100644 --- a/ydb/core/base/table_index.cpp +++ b/ydb/core/base/table_index.cpp @@ -154,7 +154,11 @@ bool IsCompatibleIndex(NKikimrSchemeOp::EIndexType indexType, const TTableColumn } tmp.clear(); tmp.insert(table.Keys.begin(), table.Keys.end()); - tmp.insert(index.KeyColumns.begin(), index.KeyColumns.end() - (isSecondaryIndex ? 0 : 1)); + if (isSecondaryIndex) { + tmp.insert(index.KeyColumns.begin(), index.KeyColumns.end()); + } else { + // Vector indexes allow to add all columns both to index & data + } if (const auto* broken = IsContains(index.DataColumns, tmp, true)) { explain = TStringBuilder() << "the same column can't be used as key and data column for one index, for example " << *broken; diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log_indexes.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log_indexes.cpp index 104b0c199a7..5012241b5c8 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log_indexes.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log_indexes.cpp @@ -17,8 +17,7 @@ using namespace NYql::NNodes; namespace { -TCoAtomList BuildKeyColumnsList(const TKikimrTableDescription& /* table */, TPositionHandle pos, TExprContext& ctx, - const auto& columnsToSelect) { +TCoAtomList BuildKeyColumnsList(TPositionHandle pos, TExprContext& ctx, const auto& columnsToSelect) { TVector<TExprBase> columnsList; columnsList.reserve(columnsToSelect.size()); for (auto column : columnsToSelect) { @@ -35,7 +34,7 @@ TCoAtomList BuildKeyColumnsList(const TKikimrTableDescription& /* table */, TPos } TCoAtomList BuildKeyColumnsList(const TKikimrTableDescription& table, TPositionHandle pos, TExprContext& ctx) { - return BuildKeyColumnsList(table, pos, ctx, table.Metadata->KeyColumnNames); + return BuildKeyColumnsList(pos, ctx, table.Metadata->KeyColumnNames); } TCoAtomList MergeColumns(const NNodes::TCoAtomList& col1, const TVector<TString>& col2, TExprContext& ctx) { @@ -306,21 +305,20 @@ struct TReadMatch { } }; -template<typename TRead> -bool CheckIndexCovering(const TRead& read, const TIntrusivePtr<TKikimrTableMetadata>& indexMeta) { - for (const auto& col : read.Columns()) { +bool CheckIndexCovering(const TCoAtomList& readColumns, const TIntrusivePtr<TKikimrTableMetadata>& indexMeta) { + for (const auto& col : readColumns) { if (!indexMeta->Columns.contains(col.StringValue())) { - return true; + return false; } } - return false; + return true; } TExprBase DoRewriteIndexRead(const TReadMatch& read, TExprContext& ctx, const TKikimrTableDescription& tableDesc, TIntrusivePtr<TKikimrTableMetadata> indexMeta, const TVector<TString>& extraColumns, const std::function<TExprBase(const TExprBase&)>& middleFilter = {}) { - const bool needDataRead = CheckIndexCovering(read, indexMeta); + const bool isCovered = CheckIndexCovering(read.Columns(), indexMeta); if (read.FullScan()) { const auto indexName = read.Index().StringValue(); @@ -329,7 +327,7 @@ TExprBase DoRewriteIndexRead(const TReadMatch& read, TExprContext& ctx, ctx.AddWarning(issue); } - if (!needDataRead) { + if (isCovered) { // We can read all data from index table. auto ret = read.BuildRead(ctx, BuildTableMeta(*indexMeta, read.Pos(), ctx), read.Columns()); @@ -517,26 +515,41 @@ void VectorReadLevel( void VectorReadMain( TExprContext& ctx, TPositionHandle pos, - const TKqpTable& postingTable, const TCoAtomList& postingColumns, - const TKqpTable& mainTable, const TCoAtomList& mainColumns, + const TKqpTable& postingTable, + const TIntrusivePtr<TKikimrTableMetadata> & postingTableMeta, + const TKqpTable& mainTable, + const TIntrusivePtr<TKikimrTableMetadata> & mainTableMeta, + const TCoAtomList& mainColumns, TExprNodePtr& read) { - // TODO(mbkkt) handle covered index columns TKqpStreamLookupSettings settings; settings.Strategy = EStreamLookupStrategyType::LookupRows; - read = Build<TKqlStreamLookupTable>(ctx, pos) - .Table(postingTable) - .LookupKeys(read) - .Columns(postingColumns) - .Settings(settings.BuildNode(ctx, pos)) - .Done().Ptr(); + const bool isCovered = CheckIndexCovering(mainColumns, postingTableMeta); - read = Build<TKqlStreamLookupTable>(ctx, pos) - .Table(mainTable) - .LookupKeys(read) - .Columns(mainColumns) - .Settings(settings.BuildNode(ctx, pos)) - .Done().Ptr(); + if (!isCovered) { + const auto postingColumns = BuildKeyColumnsList(pos, ctx, mainTableMeta->KeyColumnNames); + + read = Build<TKqlStreamLookupTable>(ctx, pos) + .Table(postingTable) + .LookupKeys(read) + .Columns(postingColumns) + .Settings(settings.BuildNode(ctx, pos)) + .Done().Ptr(); + + read = Build<TKqlStreamLookupTable>(ctx, pos) + .Table(mainTable) + .LookupKeys(read) + .Columns(mainColumns) + .Settings(settings.BuildNode(ctx, pos)) + .Done().Ptr(); + } else { + read = Build<TKqlStreamLookupTable>(ctx, pos) + .Table(postingTable) + .LookupKeys(read) + .Columns(mainColumns) + .Settings(settings.BuildNode(ctx, pos)) + .Done().Ptr(); + } } void VectorTopMain(TExprContext& ctx, const TCoTopBase& top, TExprNodePtr& read) { @@ -568,9 +581,8 @@ TExprBase DoRewriteTopSortOverKMeansTree( const auto postingTable = BuildTableMeta(*postingTableDesc->Metadata, pos, ctx); const auto mainTable = BuildTableMeta(*tableDesc.Metadata, pos, ctx); - const auto levelColumns = BuildKeyColumnsList(*levelTableDesc, pos, ctx, + const auto levelColumns = BuildKeyColumnsList(pos, ctx, std::initializer_list<std::string_view>{NTableIndex::NTableVectorKmeansTreeIndex::IdColumn, NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn}); - const auto postingColumns = BuildKeyColumnsList(*postingTableDesc, pos, ctx, tableDesc.Metadata->KeyColumnNames); const auto& mainColumns = match.Columns(); TNodeOnNodeOwnedMap replaces; @@ -601,7 +613,7 @@ TExprBase DoRewriteTopSortOverKMeansTree( VectorReadLevel(indexDesc, ctx, pos, kqpCtx, levelLambda, top, levelTable, levelColumns, read); - VectorReadMain(ctx, pos, postingTable, postingColumns, mainTable, mainColumns, read); + VectorReadMain(ctx, pos, postingTable, postingTableDesc->Metadata, mainTable, tableDesc.Metadata, mainColumns, read); if (flatMap) { read = Build<TCoFlatMap>(ctx, flatMap.Cast().Pos()) @@ -638,13 +650,12 @@ TExprBase DoRewriteTopSortOverPrefixedKMeansTree( const auto prefixTable = BuildTableMeta(*prefixTableDesc->Metadata, pos, ctx); const auto mainTable = BuildTableMeta(*tableDesc.Metadata, pos, ctx); - const auto levelColumns = BuildKeyColumnsList(*levelTableDesc, pos, ctx, + const auto levelColumns = BuildKeyColumnsList(pos, ctx, std::initializer_list<std::string_view>{NTableIndex::NTableVectorKmeansTreeIndex::IdColumn, NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn}); - const auto postingColumns = BuildKeyColumnsList(*postingTableDesc, pos, ctx, tableDesc.Metadata->KeyColumnNames); const auto prefixColumns = [&] { auto columns = indexDesc.KeyColumns; columns.back().assign(NTableIndex::NTableVectorKmeansTreeIndex::IdColumn); - return BuildKeyColumnsList(*prefixTableDesc, pos, ctx, columns); + return BuildKeyColumnsList(pos, ctx, columns); }(); const auto& mainColumns = match.Columns(); @@ -688,7 +699,7 @@ TExprBase DoRewriteTopSortOverPrefixedKMeansTree( VectorReadLevel(indexDesc, ctx, pos, kqpCtx, levelLambda, top, levelTable, levelColumns, read); - VectorReadMain(ctx, pos, postingTable, postingColumns, mainTable, mainColumns, read); + VectorReadMain(ctx, pos, postingTable, postingTableDesc->Metadata, mainTable, tableDesc.Metadata, mainColumns, read); if (mainLambda) { read = Build<TCoMap>(ctx, flatMap.Pos()) @@ -735,9 +746,9 @@ TExprBase KqpRewriteLookupIndex(const TExprBase& node, TExprContext& ctx, const YQL_ENSURE(indexDesc->Type != TIndexDescription::EType::GlobalSyncVectorKMeansTree, "lookup doesn't support vector index: " << indexName); - const bool needDataRead = CheckIndexCovering(lookupIndex, implTable); + const bool isCovered = CheckIndexCovering(lookupIndex.Columns(), implTable); - if (!needDataRead) { + if (isCovered) { TKqpStreamLookupSettings settings; settings.Strategy = EStreamLookupStrategyType::LookupRows; return Build<TKqlStreamLookupTable>(ctx, node.Pos()) @@ -785,8 +796,8 @@ TExprBase KqpRewriteStreamLookupIndex(const TExprBase& node, TExprContext& ctx, YQL_ENSURE(indexDesc->Type != TIndexDescription::EType::GlobalSyncVectorKMeansTree, "stream lookup doesn't support vector index: " << indexName); - const bool needDataRead = CheckIndexCovering(streamLookupIndex, implTable); - if (!needDataRead) { + const bool isCovered = CheckIndexCovering(streamLookupIndex.Columns(), implTable); + if (isCovered) { return Build<TKqlStreamLookupTable>(ctx, node.Pos()) .Table(BuildTableMeta(*implTable, node.Pos(), ctx)) .LookupKeys(streamLookupIndex.LookupKeys()) diff --git a/ydb/core/kqp/ut/indexes/kqp_indexes_prefixed_vector_ut.cpp b/ydb/core/kqp/ut/indexes/kqp_indexes_prefixed_vector_ut.cpp index 266570f3fec..33279262c0a 100644 --- a/ydb/core/kqp/ut/indexes/kqp_indexes_prefixed_vector_ut.cpp +++ b/ydb/core/kqp/ut/indexes/kqp_indexes_prefixed_vector_ut.cpp @@ -26,11 +26,20 @@ using namespace NYdb::NTable; Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { - std::vector<i64> DoPositiveQueryVectorIndex(TSession& session, const TString& query) { + std::vector<i64> DoPositiveQueryVectorIndex(TSession& session, const TString& query, bool covered = false) { { auto result = session.ExplainDataQuery(query).ExtractValueSync(); UNIT_ASSERT_C(result.IsSuccess(), "Failed to explain: `" << query << "` with " << result.GetIssues().ToString()); + + if (covered) { + // Check that the query doesn't use main table + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + UNIT_ASSERT(ValidatePlanNodeIds(plan)); + auto mainTableAccess = CountPlanNodesByKv(plan, "Table", "TestTable"); + UNIT_ASSERT_VALUES_EQUAL(mainTableAccess, 0); + } } { auto result = session.ExecuteDataQuery(query, @@ -53,7 +62,7 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { } } - void DoPositiveQueriesVectorIndex(TSession& session, const TString& mainQuery, const TString& indexQuery) { + void DoPositiveQueriesVectorIndex(TSession& session, const TString& mainQuery, const TString& indexQuery, bool covered = false) { auto toStr = [](const auto& rs) -> TString { TStringBuilder b; for (const auto& r : rs) { @@ -66,7 +75,7 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { UNIT_ASSERT_EQUAL_C(mainResults.size(), 3, toStr(mainResults)); UNIT_ASSERT_C(std::unique(mainResults.begin(), mainResults.end()) == mainResults.end(), toStr(mainResults)); - auto indexResults = DoPositiveQueryVectorIndex(session, indexQuery); + auto indexResults = DoPositiveQueryVectorIndex(session, indexQuery, covered); absl::c_sort(indexResults); UNIT_ASSERT_EQUAL_C(indexResults.size(), 3, toStr(indexResults)); UNIT_ASSERT_C(std::unique(indexResults.begin(), indexResults.end()) == indexResults.end(), toStr(indexResults)); @@ -79,13 +88,17 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { std::string_view function, std::string_view direction, std::string_view left, - std::string_view right) { + std::string_view right, + bool covered = false) { constexpr std::string_view init = "$target = \"\x67\x68\x03\";\n" "$user = \"user_b\";"; std::string metric = std::format("Knn::{}({}, {})", function, left, right); // no metric in result { + // TODO(vitaliff): Exclude index-covered WHERE fields from KqpReadTableRanges. + // Currently even if we SELECT only pk, emb, data WHERE user=xxx we also get `user` + // in SELECT columns and thus it's required to add it to covered columns. const TString plainQuery(Q1_(std::format(R"({} SELECT * FROM `/Root/TestTable` WHERE user = $user @@ -100,7 +113,7 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { ORDER BY {} {} LIMIT 3; )", init, metric, direction))); - DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery); + DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery, covered); } // metric in result { @@ -117,7 +130,7 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { ORDER BY {} {} LIMIT 3; )", init, metric, metric, direction))); - DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery); + DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery, covered); } // metric as result // TODO(mbkkt) fix this behavior too @@ -136,27 +149,28 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { ORDER BY m {} LIMIT 3; )", init, metric, direction))); - DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery); + DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery, covered); } } void DoPositiveQueriesPrefixedVectorIndexOrderBy( TSession& session, std::string_view function, - std::string_view direction) { + std::string_view direction, + bool covered = false) { // target is left, member is right - DoPositiveQueriesPrefixedVectorIndexOrderBy(session, function, direction, "$target", "emb"); + DoPositiveQueriesPrefixedVectorIndexOrderBy(session, function, direction, "$target", "emb", covered); // target is right, member is left - DoPositiveQueriesPrefixedVectorIndexOrderBy(session, function, direction, "emb", "$target"); + DoPositiveQueriesPrefixedVectorIndexOrderBy(session, function, direction, "emb", "$target", covered); } - void DoPositiveQueriesPrefixedVectorIndexOrderByCosine(TSession& session) { + void DoPositiveQueriesPrefixedVectorIndexOrderByCosine(TSession& session, bool covered = false) { // distance, default direction - DoPositiveQueriesPrefixedVectorIndexOrderBy(session, "CosineDistance", ""); + DoPositiveQueriesPrefixedVectorIndexOrderBy(session, "CosineDistance", "", covered); // distance, asc direction - DoPositiveQueriesPrefixedVectorIndexOrderBy(session, "CosineDistance", "ASC"); + DoPositiveQueriesPrefixedVectorIndexOrderBy(session, "CosineDistance", "ASC", covered); // similarity, desc direction - DoPositiveQueriesPrefixedVectorIndexOrderBy(session, "CosineSimilarity", "DESC"); + DoPositiveQueriesPrefixedVectorIndexOrderBy(session, "CosineSimilarity", "DESC", covered); } TSession DoCreateTableForPrefixedVectorIndex(TTableClient& db, bool nullable) { @@ -406,6 +420,7 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { .ExtractValueSync(); UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + // FIXME: result does not return failure/issues when index is created but fails to be filled with data } { auto result = session.DescribeTable("/Root/TestTable").ExtractValueSync(); @@ -425,6 +440,54 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) { DoPositiveQueriesPrefixedVectorIndexOrderByCosine(session); } + Y_UNIT_TEST_TWIN(PrefixedVectorIndexOrderByCosineDistanceWithCover, Nullable) { + NKikimrConfig::TFeatureFlags featureFlags; + featureFlags.SetEnableVectorIndex(true); + auto setting = NKikimrKqp::TKqpSetting(); + auto serverSettings = TKikimrSettings() + .SetFeatureFlags(featureFlags) + .SetKqpSettings({setting}); + + TKikimrRunner kikimr(serverSettings); + kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::BUILD_INDEX, NActors::NLog::PRI_TRACE); + kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::FLAT_TX_SCHEMESHARD, NActors::NLog::PRI_TRACE); + + auto db = kikimr.GetTableClient(); + auto session = DoCreateTableForPrefixedVectorIndex(db, Nullable); + { + const TString createIndex(Q_(R"( + ALTER TABLE `/Root/TestTable` + ADD INDEX index + GLOBAL USING vector_kmeans_tree + ON (user, emb) COVER (user, emb, data) + WITH (distance=cosine, vector_type="uint8", vector_dimension=2, levels=2, clusters=2); + )")); + + auto result = session.ExecuteSchemeQuery(createIndex) + .ExtractValueSync(); + + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto result = session.DescribeTable("/Root/TestTable").ExtractValueSync(); + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), NYdb::EStatus::SUCCESS); + const auto& indexes = result.GetTableDescription().GetIndexDescriptions(); + UNIT_ASSERT_EQUAL(indexes.size(), 1); + UNIT_ASSERT_EQUAL(indexes[0].GetIndexName(), "index"); + std::vector<std::string> indexKeyColumns{"user", "emb"}; + UNIT_ASSERT_EQUAL(indexes[0].GetIndexColumns(), indexKeyColumns); + std::vector<std::string> indexDataColumns{"user", "emb", "data"}; + UNIT_ASSERT_EQUAL(indexes[0].GetDataColumns(), indexDataColumns); + const auto& settings = std::get<TKMeansTreeSettings>(indexes[0].GetIndexSettings()); + UNIT_ASSERT_EQUAL(settings.Settings.Metric, NYdb::NTable::TVectorIndexSettings::EMetric::CosineDistance); + UNIT_ASSERT_EQUAL(settings.Settings.VectorType, NYdb::NTable::TVectorIndexSettings::EVectorType::Uint8); + UNIT_ASSERT_EQUAL(settings.Settings.VectorDimension, 2); + UNIT_ASSERT_EQUAL(settings.Levels, 2); + UNIT_ASSERT_EQUAL(settings.Clusters, 2); + } + DoPositiveQueriesPrefixedVectorIndexOrderByCosine(session, true /*covered*/); + } + } } diff --git a/ydb/core/kqp/ut/indexes/kqp_indexes_vector_ut.cpp b/ydb/core/kqp/ut/indexes/kqp_indexes_vector_ut.cpp index 9e29f71bc31..d74cb5327a5 100644 --- a/ydb/core/kqp/ut/indexes/kqp_indexes_vector_ut.cpp +++ b/ydb/core/kqp/ut/indexes/kqp_indexes_vector_ut.cpp @@ -27,11 +27,20 @@ using namespace NYdb::NTable; Y_UNIT_TEST_SUITE(KqpVectorIndexes) { - std::vector<i64> DoPositiveQueryVectorIndex(TSession& session, const TString& query) { + std::vector<i64> DoPositiveQueryVectorIndex(TSession& session, const TString& query, bool covered = false) { { auto result = session.ExplainDataQuery(query).ExtractValueSync(); UNIT_ASSERT_C(result.IsSuccess(), "Failed to explain: `" << query << "` with " << result.GetIssues().ToString()); + + if (covered) { + // Check that the query doesn't use main table + NJson::TJsonValue plan; + NJson::ReadJsonTree(result.GetPlan(), &plan, true); + UNIT_ASSERT(ValidatePlanNodeIds(plan)); + auto mainTableAccess = CountPlanNodesByKv(plan, "Table", "TestTable"); + UNIT_ASSERT_VALUES_EQUAL(mainTableAccess, 0); + } } { auto result = session.ExecuteDataQuery(query, @@ -54,7 +63,7 @@ Y_UNIT_TEST_SUITE(KqpVectorIndexes) { } } - void DoPositiveQueriesVectorIndex(TSession& session, const TString& mainQuery, const TString& indexQuery) { + void DoPositiveQueriesVectorIndex(TSession& session, const TString& mainQuery, const TString& indexQuery, bool covered = false) { auto toStr = [](const auto& rs) -> TString { TStringBuilder b; for (const auto& r : rs) { @@ -67,7 +76,7 @@ Y_UNIT_TEST_SUITE(KqpVectorIndexes) { UNIT_ASSERT_EQUAL_C(mainResults.size(), 3, toStr(mainResults)); UNIT_ASSERT_C(std::unique(mainResults.begin(), mainResults.end()) == mainResults.end(), toStr(mainResults)); - auto indexResults = DoPositiveQueryVectorIndex(session, indexQuery); + auto indexResults = DoPositiveQueryVectorIndex(session, indexQuery, covered); absl::c_sort(indexResults); UNIT_ASSERT_EQUAL_C(indexResults.size(), 3, toStr(indexResults)); UNIT_ASSERT_C(std::unique(indexResults.begin(), indexResults.end()) == indexResults.end(), toStr(indexResults)); @@ -80,7 +89,8 @@ Y_UNIT_TEST_SUITE(KqpVectorIndexes) { std::string_view function, std::string_view direction, std::string_view left, - std::string_view right) { + std::string_view right, + bool covered = false) { constexpr std::string_view target = "$target = \"\x67\x71\x03\";"; std::string metric = std::format("Knn::{}({}, {})", function, left, right); // no metric in result @@ -97,7 +107,7 @@ Y_UNIT_TEST_SUITE(KqpVectorIndexes) { ORDER BY {} {} LIMIT 3; )", target, metric, direction))); - DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery); + DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery, covered); } // metric in result { @@ -112,7 +122,7 @@ Y_UNIT_TEST_SUITE(KqpVectorIndexes) { ORDER BY {} {} LIMIT 3; )", target, metric, metric, direction))); - DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery); + DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery, covered); } // metric as result { @@ -128,27 +138,28 @@ Y_UNIT_TEST_SUITE(KqpVectorIndexes) { ORDER BY m {} LIMIT 3; )", target, metric, direction))); - DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery); + DoPositiveQueriesVectorIndex(session, plainQuery, indexQuery, covered); } } void DoPositiveQueriesVectorIndexOrderBy( TSession& session, std::string_view function, - std::string_view direction) { + std::string_view direction, + bool covered = false) { // target is left, member is right - DoPositiveQueriesVectorIndexOrderBy(session, function, direction, "$target", "emb"); + DoPositiveQueriesVectorIndexOrderBy(session, function, direction, "$target", "emb", covered); // target is right, member is left - DoPositiveQueriesVectorIndexOrderBy(session, function, direction, "emb", "$target"); + DoPositiveQueriesVectorIndexOrderBy(session, function, direction, "emb", "$target", covered); } - void DoPositiveQueriesVectorIndexOrderByCosine(TSession& session) { + void DoPositiveQueriesVectorIndexOrderByCosine(TSession& session, bool covered = false) { // distance, default direction - DoPositiveQueriesVectorIndexOrderBy(session, "CosineDistance", ""); + DoPositiveQueriesVectorIndexOrderBy(session, "CosineDistance", "", covered); // distance, asc direction - DoPositiveQueriesVectorIndexOrderBy(session, "CosineDistance", "ASC"); + DoPositiveQueriesVectorIndexOrderBy(session, "CosineDistance", "ASC", covered); // similarity, desc direction - DoPositiveQueriesVectorIndexOrderBy(session, "CosineSimilarity", "DESC"); + DoPositiveQueriesVectorIndexOrderBy(session, "CosineSimilarity", "DESC", covered); } TSession DoCreateTableForVectorIndex(TTableClient& db, bool nullable) { @@ -461,6 +472,54 @@ Y_UNIT_TEST_SUITE(KqpVectorIndexes) { UNIT_ASSERT_STRINGS_UNEQUAL(originalPostingTable, postingTable2); } + Y_UNIT_TEST_TWIN(SimpleVectorIndexOrderByCosineDistanceWithCover, Nullable) { + NKikimrConfig::TFeatureFlags featureFlags; + featureFlags.SetEnableVectorIndex(true); + auto setting = NKikimrKqp::TKqpSetting(); + auto serverSettings = TKikimrSettings() + .SetFeatureFlags(featureFlags) + .SetKqpSettings({setting}); + + TKikimrRunner kikimr(serverSettings); + kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::BUILD_INDEX, NActors::NLog::PRI_TRACE); + kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::FLAT_TX_SCHEMESHARD, NActors::NLog::PRI_TRACE); + + auto db = kikimr.GetTableClient(); + auto session = DoCreateTableForVectorIndex(db, Nullable); + { + const TString createIndex(Q_(R"( + ALTER TABLE `/Root/TestTable` + ADD INDEX index + GLOBAL USING vector_kmeans_tree + ON (emb) COVER (emb, data) + WITH (distance=cosine, vector_type="uint8", vector_dimension=2, levels=2, clusters=2); + )")); + + auto result = session.ExecuteSchemeQuery(createIndex) + .ExtractValueSync(); + + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + { + auto result = session.DescribeTable("/Root/TestTable").ExtractValueSync(); + UNIT_ASSERT_VALUES_EQUAL(result.GetStatus(), NYdb::EStatus::SUCCESS); + const auto& indexes = result.GetTableDescription().GetIndexDescriptions(); + UNIT_ASSERT_EQUAL(indexes.size(), 1); + UNIT_ASSERT_EQUAL(indexes[0].GetIndexName(), "index"); + std::vector<std::string> indexKeyColumns{"emb"}; + UNIT_ASSERT_EQUAL(indexes[0].GetIndexColumns(), indexKeyColumns); + std::vector<std::string> indexDataColumns{"emb", "data"}; + UNIT_ASSERT_EQUAL(indexes[0].GetDataColumns(), indexDataColumns); + const auto& settings = std::get<TKMeansTreeSettings>(indexes[0].GetIndexSettings()); + UNIT_ASSERT_EQUAL(settings.Settings.Metric, NYdb::NTable::TVectorIndexSettings::EMetric::CosineDistance); + UNIT_ASSERT_EQUAL(settings.Settings.VectorType, NYdb::NTable::TVectorIndexSettings::EVectorType::Uint8); + UNIT_ASSERT_EQUAL(settings.Settings.VectorDimension, 2); + UNIT_ASSERT_EQUAL(settings.Levels, 2); + UNIT_ASSERT_EQUAL(settings.Clusters, 2); + } + DoPositiveQueriesVectorIndexOrderByCosine(session, true /*covered*/); + } + } } diff --git a/ydb/core/tx/datashard/build_index/kmeans_helper.cpp b/ydb/core/tx/datashard/build_index/kmeans_helper.cpp index 2e2e6c8b40e..17326a93018 100644 --- a/ydb/core/tx/datashard/build_index/kmeans_helper.cpp +++ b/ydb/core/tx/datashard/build_index/kmeans_helper.cpp @@ -118,7 +118,9 @@ MakeUploadTypes(const TUserTable& table, NKikimrTxDataShard::EKMeansState upload switch (uploadState) { case NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD: case NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_BUILD: - addType(embedding); + if (auto it = std::find(data.begin(), data.end(), embedding); it == data.end()) { + addType(embedding); + } [[fallthrough]]; case NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_POSTING: case NKikimrTxDataShard::EKMeansState::UPLOAD_BUILD_TO_POSTING: { diff --git a/ydb/core/tx/schemeshard/schemeshard__operation_create_build_index.cpp b/ydb/core/tx/schemeshard/schemeshard__operation_create_build_index.cpp index 31f792184af..c906eacb828 100644 --- a/ydb/core/tx/schemeshard/schemeshard__operation_create_build_index.cpp +++ b/ydb/core/tx/schemeshard/schemeshard__operation_create_build_index.cpp @@ -141,11 +141,12 @@ TVector<ISubOperation::TPtr> CreateBuildIndex(TOperationId opId, const TTxTransa indexPrefixTableDesc = indexDesc.GetIndexImplTableDescriptions(2); } } - const THashSet<TString> indexKeyColumns{indexDesc.GetKeyColumnNames().begin(), indexDesc.GetKeyColumnNames().end() - 1}; + const THashSet<TString> indexDataColumns{indexDesc.GetDataColumnNames().begin(), indexDesc.GetDataColumnNames().end()}; result.push_back(createImplTable(CalcVectorKmeansTreeLevelImplTableDesc(tableInfo->PartitionConfig(), indexLevelTableDesc))); - result.push_back(createImplTable(CalcVectorKmeansTreePostingImplTableDesc(indexKeyColumns, tableInfo, tableInfo->PartitionConfig(), implTableColumns, indexPostingTableDesc))); + result.push_back(createImplTable(CalcVectorKmeansTreePostingImplTableDesc(tableInfo, tableInfo->PartitionConfig(), indexDataColumns, indexPostingTableDesc))); if (prefixVectorIndex) { - result.push_back(createImplTable(CalcVectorKmeansTreePrefixImplTableDesc(indexKeyColumns, tableInfo, tableInfo->PartitionConfig(), implTableColumns, indexPrefixTableDesc))); + const THashSet<TString> prefixColumns{indexDesc.GetKeyColumnNames().begin(), indexDesc.GetKeyColumnNames().end() - 1}; + result.push_back(createImplTable(CalcVectorKmeansTreePrefixImplTableDesc(prefixColumns, tableInfo, tableInfo->PartitionConfig(), implTableColumns, indexPrefixTableDesc))); } } else { NKikimrSchemeOp::TTableDescription indexTableDesc; diff --git a/ydb/core/tx/schemeshard/schemeshard__operation_create_indexed_table.cpp b/ydb/core/tx/schemeshard/schemeshard__operation_create_indexed_table.cpp index c47eb8fde2b..c7c45ec3f32 100644 --- a/ydb/core/tx/schemeshard/schemeshard__operation_create_indexed_table.cpp +++ b/ydb/core/tx/schemeshard/schemeshard__operation_create_indexed_table.cpp @@ -294,11 +294,12 @@ TVector<ISubOperation::TPtr> CreateIndexedTable(TOperationId nextId, const TTxTr userPrefixDesc = indexDescription.GetIndexImplTableDescriptions(2); } } - const THashSet<TString> indexKeyColumns{indexDescription.GetKeyColumnNames().begin(), indexDescription.GetKeyColumnNames().end() - 1}; + const THashSet<TString> indexDataColumns{indexDescription.GetDataColumnNames().begin(), indexDescription.GetDataColumnNames().end()}; result.push_back(createIndexImplTable(CalcVectorKmeansTreeLevelImplTableDesc(baseTableDescription.GetPartitionConfig(), userLevelDesc))); - result.push_back(createIndexImplTable(CalcVectorKmeansTreePostingImplTableDesc(indexKeyColumns, baseTableDescription, baseTableDescription.GetPartitionConfig(), implTableColumns, userPostingDesc))); + result.push_back(createIndexImplTable(CalcVectorKmeansTreePostingImplTableDesc(baseTableDescription, baseTableDescription.GetPartitionConfig(), indexDataColumns, userPostingDesc))); if (prefixVectorIndex) { - result.push_back(createIndexImplTable(CalcVectorKmeansTreePrefixImplTableDesc(indexKeyColumns, baseTableDescription, baseTableDescription.GetPartitionConfig(), implTableColumns, userPrefixDesc))); + const THashSet<TString> prefixColumns{indexDescription.GetKeyColumnNames().begin(), indexDescription.GetKeyColumnNames().end() - 1}; + result.push_back(createIndexImplTable(CalcVectorKmeansTreePrefixImplTableDesc(prefixColumns, baseTableDescription, baseTableDescription.GetPartitionConfig(), implTableColumns, userPrefixDesc))); } } else { NKikimrSchemeOp::TTableDescription userIndexDesc; diff --git a/ydb/core/tx/schemeshard/schemeshard_build_index__progress.cpp b/ydb/core/tx/schemeshard/schemeshard_build_index__progress.cpp index e8ae21558e9..5a700bcd2e6 100644 --- a/ydb/core/tx/schemeshard/schemeshard_build_index__progress.cpp +++ b/ydb/core/tx/schemeshard/schemeshard_build_index__progress.cpp @@ -298,6 +298,7 @@ THolder<TEvSchemeShard::TEvModifySchemeTransaction> CreateBuildPropose( auto path = TPath::Init(buildInfo.TablePathId, ss); const auto& tableInfo = ss->Tables.at(path->PathId); NTableIndex::TTableColumns implTableColumns; + THashSet<TString> indexDataColumns; { buildInfo.SerializeToProto(ss, modifyScheme.MutableInitiateIndexBuild()); const auto& indexDesc = modifyScheme.GetInitiateIndexBuild().GetIndex(); @@ -311,6 +312,8 @@ THolder<TEvSchemeShard::TEvModifySchemeTransaction> CreateBuildPropose( Y_ABORT_UNLESS(indexKeys.KeyColumns.size() >= 1); implTableColumns.Columns.emplace(indexKeys.KeyColumns.back()); modifyScheme.ClearInitiateIndexBuild(); + indexDataColumns = THashSet<TString>(buildInfo.DataColumns.begin(), buildInfo.DataColumns.end()); + indexDataColumns.insert(indexKeys.KeyColumns.back()); } using namespace NTableIndex::NTableVectorKmeansTreeIndex; @@ -343,7 +346,7 @@ THolder<TEvSchemeShard::TEvModifySchemeTransaction> CreateBuildPropose( return propose; } - op = CalcVectorKmeansTreePostingImplTableDesc({}, tableInfo, tableInfo->PartitionConfig(), implTableColumns, {}, suffix); + op = NTableIndex::CalcVectorKmeansTreePostingImplTableDesc(tableInfo, tableInfo->PartitionConfig(), indexDataColumns, {}, suffix); const auto [count, parts, step] = ComputeKMeansBoundaries(*tableInfo, buildInfo); auto& policy = *resetPartitionsSettings(); diff --git a/ydb/core/tx/schemeshard/schemeshard_utils.cpp b/ydb/core/tx/schemeshard/schemeshard_utils.cpp index 7b874f08bf2..9f7be56424c 100644 --- a/ydb/core/tx/schemeshard/schemeshard_utils.cpp +++ b/ydb/core/tx/schemeshard/schemeshard_utils.cpp @@ -257,13 +257,18 @@ auto CalcImplTableDescImpl( } auto CalcVectorKmeansTreePostingImplTableDescImpl( - const THashSet<TString>& indexKeyColumns, const auto& baseTable, const NKikimrSchemeOp::TPartitionConfig& baseTablePartitionConfig, - const TTableColumns& implTableColumns, + const THashSet<TString>& indexDataColumns, const NKikimrSchemeOp::TTableDescription& indexTableDesc, std::string_view suffix) { + auto tableColumns = ExtractInfo(baseTable); + THashSet<TString> indexColumns = indexDataColumns; + for (const auto & keyColumn: tableColumns.Keys) { + indexColumns.insert(keyColumn); + } + NKikimrSchemeOp::TTableDescription implTableDesc; implTableDesc.SetName(TString::Join(NTableVectorKmeansTreeIndex::PostingTable, suffix)); SetImplTablePartitionConfig(baseTablePartitionConfig, indexTableDesc, implTableDesc); @@ -275,15 +280,7 @@ auto CalcVectorKmeansTreePostingImplTableDescImpl( parentColumn->SetNotNull(true); } implTableDesc.AddKeyColumnNames(NTableVectorKmeansTreeIndex::ParentColumn); - if (indexKeyColumns.empty()) { - FillIndexImplTableColumns(GetColumns(baseTable), implTableColumns.Keys, implTableColumns.Columns, implTableDesc); - } else { - auto keys = implTableColumns.Keys; - auto columns = implTableColumns.Columns; - std::erase_if(keys, [&](const auto& key) { return indexKeyColumns.contains(key); }); - EraseNodesIf(columns, [&](const auto& key) { return indexKeyColumns.contains(key); }); - FillIndexImplTableColumns(GetColumns(baseTable), keys, columns, implTableDesc); - } + FillIndexImplTableColumns(GetColumns(baseTable), tableColumns.Keys, indexColumns, implTableDesc); implTableDesc.SetSystemColumnNamesAllowed(true); @@ -384,25 +381,23 @@ NKikimrSchemeOp::TTableDescription CalcVectorKmeansTreeLevelImplTableDesc( } NKikimrSchemeOp::TTableDescription CalcVectorKmeansTreePostingImplTableDesc( - const THashSet<TString>& indexKeyColumns, const NSchemeShard::TTableInfo::TPtr& baseTableInfo, const NKikimrSchemeOp::TPartitionConfig& baseTablePartitionConfig, - const TTableColumns& implTableColumns, + const THashSet<TString>& indexDataColumns, const NKikimrSchemeOp::TTableDescription& indexTableDesc, std::string_view suffix) { - return CalcVectorKmeansTreePostingImplTableDescImpl(indexKeyColumns, baseTableInfo, baseTablePartitionConfig, implTableColumns, indexTableDesc, suffix); + return CalcVectorKmeansTreePostingImplTableDescImpl(baseTableInfo, baseTablePartitionConfig, indexDataColumns, indexTableDesc, suffix); } NKikimrSchemeOp::TTableDescription CalcVectorKmeansTreePostingImplTableDesc( - const THashSet<TString>& indexKeyColumns, const NKikimrSchemeOp::TTableDescription& baseTableDescr, const NKikimrSchemeOp::TPartitionConfig& baseTablePartitionConfig, - const TTableColumns& implTableColumns, + const THashSet<TString>& indexDataColumns, const NKikimrSchemeOp::TTableDescription& indexTableDesc, std::string_view suffix) { - return CalcVectorKmeansTreePostingImplTableDescImpl(indexKeyColumns, baseTableDescr, baseTablePartitionConfig, implTableColumns, indexTableDesc, suffix); + return CalcVectorKmeansTreePostingImplTableDescImpl(baseTableDescr, baseTablePartitionConfig, indexDataColumns, indexTableDesc, suffix); } NKikimrSchemeOp::TTableDescription CalcVectorKmeansTreePrefixImplTableDesc( diff --git a/ydb/core/tx/schemeshard/schemeshard_utils.h b/ydb/core/tx/schemeshard/schemeshard_utils.h index 5d8ed833310..1fdf1186ce8 100644 --- a/ydb/core/tx/schemeshard/schemeshard_utils.h +++ b/ydb/core/tx/schemeshard/schemeshard_utils.h @@ -63,18 +63,16 @@ NKikimrSchemeOp::TTableDescription CalcVectorKmeansTreeLevelImplTableDesc( const NKikimrSchemeOp::TTableDescription& indexTableDesc); NKikimrSchemeOp::TTableDescription CalcVectorKmeansTreePostingImplTableDesc( - const THashSet<TString>& indexKeyColumns, const NSchemeShard::TTableInfo::TPtr& baseTableInfo, const NKikimrSchemeOp::TPartitionConfig& baseTablePartitionConfig, - const TTableColumns& implTableColumns, + const THashSet<TString>& indexDataColumns, const NKikimrSchemeOp::TTableDescription& indexTableDesc, std::string_view suffix = {}); NKikimrSchemeOp::TTableDescription CalcVectorKmeansTreePostingImplTableDesc( - const THashSet<TString>& indexKeyColumns, const NKikimrSchemeOp::TTableDescription& baseTableDescr, const NKikimrSchemeOp::TPartitionConfig& baseTablePartitionConfig, - const TTableColumns& implTableColumns, + const THashSet<TString>& indexDataColumns, const NKikimrSchemeOp::TTableDescription& indexTableDesc, std::string_view suffix = {}); diff --git a/ydb/core/tx/schemeshard/ut_index/ut_vector_index.cpp b/ydb/core/tx/schemeshard/ut_index/ut_vector_index.cpp index e3af6797673..956bf14eec6 100644 --- a/ydb/core/tx/schemeshard/ut_index/ut_vector_index.cpp +++ b/ydb/core/tx/schemeshard/ut_index/ut_vector_index.cpp @@ -78,6 +78,60 @@ Y_UNIT_TEST_SUITE(TVectorIndexTests) { Name: "vectors" Columns { Name: "id" Type: "Uint64" } Columns { Name: "embedding" Type: "String" } + Columns { Name: "prefix" Type: "String" } + KeyColumnNames: ["id"] + } + IndexDescription { + Name: "idx_vector" + KeyColumnNames: ["prefix", "embedding"] + Type: EIndexTypeGlobalVectorKmeansTree + VectorIndexKmeansTreeDescription: { Settings: { settings: { metric: DISTANCE_COSINE, vector_type: VECTOR_TYPE_FLOAT, vector_dimension: 1024 }, clusters: 4, levels: 5 } } + } + )"); + env.TestWaitNotification(runtime, txId); + + TestDescribeResult(DescribePrivatePath(runtime, "/MyRoot/vectors/idx_vector"), + { NLs::PathExist, + NLs::IndexType(NKikimrSchemeOp::EIndexTypeGlobalVectorKmeansTree), + NLs::IndexState(NKikimrSchemeOp::EIndexStateReady), + NLs::IndexKeys({"prefix", "embedding"}), + NLs::IndexDataColumns({}), + NLs::KMeansTreeDescription(Ydb::Table::VectorIndexSettings::DISTANCE_COSINE, + Ydb::Table::VectorIndexSettings::VECTOR_TYPE_FLOAT, + 1024, + 4, + 5 + ), + }); + + TestDescribeResult(DescribePrivatePath(runtime, "/MyRoot/vectors/idx_vector/indexImplPrefixTable"), + { NLs::PathExist, + NLs::CheckColumns(PrefixTable, {"prefix", IdColumn}, {}, {"prefix", IdColumn}, true) }); + + TestDescribeResult(DescribePrivatePath(runtime, "/MyRoot/vectors/idx_vector/indexImplLevelTable"), + { NLs::PathExist, + NLs::CheckColumns(LevelTable, {ParentColumn, IdColumn, CentroidColumn}, {}, {ParentColumn, IdColumn}, true) }); + + TestDescribeResult(DescribePrivatePath(runtime, "/MyRoot/vectors/idx_vector/indexImplPostingTable"), + { NLs::PathExist, + NLs::CheckColumns(PostingTable, {ParentColumn, "id"}, {}, {ParentColumn, "id"}, true) }); + + + TVector<ui64> dropTxIds; + TestDropTable(runtime, dropTxIds.emplace_back(++txId), "/MyRoot", "vectors"); + env.TestWaitNotification(runtime, dropTxIds); + } + + Y_UNIT_TEST(CreateTablePrefixCovering) { + TTestBasicRuntime runtime; + TTestEnv env(runtime); + ui64 txId = 100; + + TestCreateIndexedTable(runtime, ++txId, "/MyRoot", R"( + TableDescription { + Name: "vectors" + Columns { Name: "id" Type: "Uint64" } + Columns { Name: "embedding" Type: "String" } Columns { Name: "covered" Type: "String" } Columns { Name: "prefix" Type: "String" } KeyColumnNames: ["id"] @@ -274,9 +328,9 @@ Y_UNIT_TEST_SUITE(TVectorIndexTests) { } } { - NTableIndex::TTableColumns implTableColumns = {{"data2", "data1"}, {}}; - auto desc = CalcVectorKmeansTreePostingImplTableDesc({}, baseTableDescr, baseTablePartitionConfig, implTableColumns, indexTableDesc, "something"); - std::string_view expected[] = {ParentColumn, "data1", "data2"}; + THashSet<TString> indexDataColumns = {"data2", "data1"}; + auto desc = NTableIndex::CalcVectorKmeansTreePostingImplTableDesc(baseTableDescr, baseTablePartitionConfig, indexDataColumns, indexTableDesc, "something"); + std::string_view expected[] = {NTableIndex::NTableVectorKmeansTreeIndex::ParentColumn, "data1", "data2"}; for (size_t i = 0; auto& column : desc.GetColumns()) { UNIT_ASSERT_STRINGS_EQUAL(column.GetName(), expected[i]); ++i; |