about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Vitaliy Filippov <vitaliff@ydb.tech> 2025-07-21 17:13:37 +0300
committer GitHub <noreply@github.com> 2025-07-21 17:13:37 +0300
commit 571c0dfd0e8bfcaa33355aac1a55aa41f5db9a10 (patch)
tree 56814a7c74d7102bbe3e728812ea1bacc235f21f
parent e5c8e37272b713f8c52fc3b01e09e172f9861bf0 (diff)
download ydb-571c0dfd0e8bfcaa33355aac1a55aa41f5db9a10.tar.gz
Use another window function in vector workload: TOP_BY/BOTTOM_BY (#21137)
-rw-r--r-- ydb/library/workload/vector/vector_recall_evaluator.cpp 39
1 files changed, 13 insertions, 26 deletions
diff --git a/ydb/library/workload/vector/vector_recall_evaluator.cpp b/ydb/library/workload/vector/vector_recall_evaluator.cpp
index 7095b3f142b..062b4497edd 100644
--- a/ydb/library/workload/vector/vector_recall_evaluator.cpp
+++ b/ydb/library/workload/vector/vector_recall_evaluator.cpp
@@ -36,26 +36,20 @@ void TVectorRecallEvaluator::SelectReferenceResults(const TVectorSampler& sample
refQueryBuilder << ", prefix: " << Params.PrefixType;
}
refQueryBuilder << ">>;\n"
- << "SELECT * FROM ("
- << " SELECT s.id AS id"
- << ", UNWRAP(CAST(m." << Params.KeyColumn << " AS string)) AS result_id"
- << ", UNWRAP(Knn::" << functionName << "(m." << Params.EmbeddingColumn << ", s.embedding)) AS distance"
- << ", (ROW_NUMBER() OVER w) AS position"
+ << "SELECT s.id AS id"
+ << ", " << (isAscending ? "BOTTOM_BY" : "TOP_BY") << "(UNWRAP(CAST(m." << Params.KeyColumn << " AS string))" <<
+ ", Knn::" << functionName << "(m." << Params.EmbeddingColumn << ", s.embedding), " << Params.Limit << ") result_ids"
<< " FROM " << Params.TableName << " m"
<< (Params.PrefixColumn ? " INNER JOIN " : " CROSS JOIN ") << "AS_TABLE($Samples) AS s";
if (Params.PrefixColumn) {
refQueryBuilder << " ON s.prefix = m." << *Params.PrefixColumn;
}
- refQueryBuilder << " WINDOW w AS (PARTITION BY s.id"
- << " ORDER BY Knn::" << functionName << "(m." << Params.EmbeddingColumn << ", s.embedding)"
- << (isAscending ? " ASC" : " DESC") << ")"
- << ") AS t WHERE position <= " << Params.Limit
- << " ORDER BY id, position";
+ refQueryBuilder << " GROUP BY s.id";
std::string refQuery = refQueryBuilder;
- // Process targets in batches (batch size should be ~10000 / Limit)
- const ui64 batchSize = 10000 / Params.Limit;
+ // Process targets in batches
+ const ui64 batchSize = 10;
for (ui64 batchStart = 0; batchStart < sampler.GetTargetCount(); batchStart += batchSize) {
const size_t batchEnd = (batchStart + batchSize < sampler.GetTargetCount() ? batchStart + batchSize : sampler.GetTargetCount());
NYdb::TParamsBuilder paramsBuilder;
@@ -82,24 +76,17 @@ void TVectorRecallEvaluator::SelectReferenceResults(const TVectorSampler& sample
return result;
}));
- ui64 refId = 0;
- std::vector<std::string> refList;
-
NYdb::TResultSetParser parser(*resultSet);
while (parser.TryNextRow()) {
ui64 id = parser.ColumnParser("id").GetUint64();
- std::string res = parser.ColumnParser("result_id").GetString();
- if (id != refId) {
- if (refList.size()) {
- References[refId] = refList;
- }
- refList.clear();
- refId = id;
+ std::vector<std::string> refList;
+ auto& lst = parser.ColumnParser("result_ids");
+ lst.OpenList();
+ while (lst.TryNextListItem()) {
+ refList.push_back(lst.GetString());
}
- refList.push_back(res);
- }
- if (refList.size()) {
- References[refId] = std::move(refList);
+ lst.CloseList();
+ References[id] = refList;
}
}
Cout << "Reference results for " << sampler.GetTargetCount()