diff options
author | Vitaliy Filippov <vitaliff@ydb.tech> | 2025-07-21 17:13:37 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-21 17:13:37 +0300 |
commit | 571c0dfd0e8bfcaa33355aac1a55aa41f5db9a10 (patch) | |
tree | 56814a7c74d7102bbe3e728812ea1bacc235f21f | |
parent | e5c8e37272b713f8c52fc3b01e09e172f9861bf0 (diff) | |
download | ydb-571c0dfd0e8bfcaa33355aac1a55aa41f5db9a10.tar.gz |
Use another window function in vector workload: TOP_BY/BOTTOM_BY (#21137)
-rw-r--r-- | ydb/library/workload/vector/vector_recall_evaluator.cpp | 39 |
1 files changed, 13 insertions, 26 deletions
diff --git a/ydb/library/workload/vector/vector_recall_evaluator.cpp b/ydb/library/workload/vector/vector_recall_evaluator.cpp index 7095b3f142b..062b4497edd 100644 --- a/ydb/library/workload/vector/vector_recall_evaluator.cpp +++ b/ydb/library/workload/vector/vector_recall_evaluator.cpp @@ -36,26 +36,20 @@ void TVectorRecallEvaluator::SelectReferenceResults(const TVectorSampler& sample refQueryBuilder << ", prefix: " << Params.PrefixType; } refQueryBuilder << ">>;\n" - << "SELECT * FROM (" - << " SELECT s.id AS id" - << ", UNWRAP(CAST(m." << Params.KeyColumn << " AS string)) AS result_id" - << ", UNWRAP(Knn::" << functionName << "(m." << Params.EmbeddingColumn << ", s.embedding)) AS distance" - << ", (ROW_NUMBER() OVER w) AS position" + << "SELECT s.id AS id" + << ", " << (isAscending ? "BOTTOM_BY" : "TOP_BY") << "(UNWRAP(CAST(m." << Params.KeyColumn << " AS string))" << + ", Knn::" << functionName << "(m." << Params.EmbeddingColumn << ", s.embedding), " << Params.Limit << ") result_ids" << " FROM " << Params.TableName << " m" << (Params.PrefixColumn ? " INNER JOIN " : " CROSS JOIN ") << "AS_TABLE($Samples) AS s"; if (Params.PrefixColumn) { refQueryBuilder << " ON s.prefix = m." << *Params.PrefixColumn; } - refQueryBuilder << " WINDOW w AS (PARTITION BY s.id" - << " ORDER BY Knn::" << functionName << "(m." << Params.EmbeddingColumn << ", s.embedding)" - << (isAscending ? " ASC" : " DESC") << ")" - << ") AS t WHERE position <= " << Params.Limit - << " ORDER BY id, position"; + refQueryBuilder << " GROUP BY s.id"; std::string refQuery = refQueryBuilder; - // Process targets in batches (batch size should be ~10000 / Limit) - const ui64 batchSize = 10000 / Params.Limit; + // Process targets in batches + const ui64 batchSize = 10; for (ui64 batchStart = 0; batchStart < sampler.GetTargetCount(); batchStart += batchSize) { const size_t batchEnd = (batchStart + batchSize < sampler.GetTargetCount() ? batchStart + batchSize : sampler.GetTargetCount()); NYdb::TParamsBuilder paramsBuilder; @@ -82,24 +76,17 @@ void TVectorRecallEvaluator::SelectReferenceResults(const TVectorSampler& sample return result; })); - ui64 refId = 0; - std::vector<std::string> refList; - NYdb::TResultSetParser parser(*resultSet); while (parser.TryNextRow()) { ui64 id = parser.ColumnParser("id").GetUint64(); - std::string res = parser.ColumnParser("result_id").GetString(); - if (id != refId) { - if (refList.size()) { - References[refId] = refList; - } - refList.clear(); - refId = id; + std::vector<std::string> refList; + auto& lst = parser.ColumnParser("result_ids"); + lst.OpenList(); + while (lst.TryNextListItem()) { + refList.push_back(lst.GetString()); } - refList.push_back(res); - } - if (refList.size()) { - References[refId] = std::move(refList); + lst.CloseList(); + References[id] = refList; } } Cout << "Reference results for " << sampler.GetTargetCount() |