diff options
author | Valery Mironov <mbkkt@ydb.tech> | 2024-07-08 10:55:06 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-08 10:55:06 +0300 |
commit | 06a15253bdc1bd0d810f1fca0d79a49cea5d56b5 (patch) | |
tree | e1c48bf929be3de4b03b87e96517cd501c157b89 | |
parent | 1f88b11c2cb56fe35a57e2884f771cca6d4b1f3d (diff) | |
download | ydb-06a15253bdc1bd0d810f1fca0d79a49cea5d56b5.tar.gz |
Add flat vector index example to C++ SDK (#6262)
-rw-r--r-- | ydb/public/sdk/cpp/examples/vector_index/README.md | 74 | ||||
-rw-r--r-- | ydb/public/sdk/cpp/examples/vector_index/main.cpp | 72 | ||||
-rw-r--r-- | ydb/public/sdk/cpp/examples/vector_index/target.json | 1 | ||||
-rw-r--r-- | ydb/public/sdk/cpp/examples/vector_index/vector_index.cpp | 288 | ||||
-rw-r--r-- | ydb/public/sdk/cpp/examples/vector_index/vector_index.h | 51 | ||||
-rw-r--r-- | ydb/public/sdk/cpp/examples/vector_index/ya.make | 13 | ||||
-rw-r--r-- | ydb/public/sdk/cpp/examples/ya.make | 1 |
7 files changed, 500 insertions, 0 deletions
diff --git a/ydb/public/sdk/cpp/examples/vector_index/README.md b/ydb/public/sdk/cpp/examples/vector_index/README.md new file mode 100644 index 0000000000..244d1fcf52 --- /dev/null +++ b/ydb/public/sdk/cpp/examples/vector_index/README.md @@ -0,0 +1,74 @@ +# Vector Index + + +Parameters description + +``` +./vector_index --help +Usage: ./vector_index [OPTIONS] [ARG]... + +Required parameters: + {-e|--endpoint} HOST:PORT YDB endpoint + {-d|--database} PATH YDB database + {-c|--command} COMMAND execute command: [Create|Drop|Build|Recreate]Index + or TopK (used for search, read request) + --table TABLE table name + --index_type TYPE index type: [flat] + --index_quantizer QUANTIZER + index quantizer: [none|int8|uint8|bit] + --primary_key PK primary key column + --embedding EMBEDDING embedding (vector) column + --distance DISTANCE distance function: + [Cosine|Euclidean|Manhattan]Distance + --rows ROWS count of rows in table, used only for + [Build|Recreate]Index commands + --top_k TOPK count of rows in top, used only for TopK command + --data DATA list of columns to read, used only for TopK command + --target TARGET file with target vector, used only for TopK command +``` + +RecreateIndex -- command which sequentially executes Drop, Create and Build commands + +Index table will have name like `<table>_<index_type>_<index_quantizer>` + +## Examples + +### Flat Index + +It uses scalar quantization to speedup ANN search: +* an approximate search is performed using quantization +* an approximate list of primary keys is obtained +* we search this list without using quantization + +#### Create Flat Bit Index + +It creates index table, with two columns: `primary_key` and `embedding`, `embedding` will be trasformed from `<table>` `<embedding>` + +``` +./vector_index --endpoint=<endpoint> --database=<database> --command=RecreateIndex --table=<table> --index_type=flat --index_quantizer=bit --primary_key=<primary_key> --embedding=<embedding> --distance=CosineDistance --rows=490000 --top_k=0 --data="" --target="" +``` + +#### Search Flat Bit Index + +Execute query like this, to get approximate candidates + +```sql +$candidates = SELECT primary_key +FROM index_table +ORDER BY CosineDistance(target, embedding) +LIMIT top_k * 2 +``` + +And then execute and print + +```sql +SELECT CosineDistance(target, embedding) AS distance -- <list of columns from data parameter> +FROM table +WHERE primary_key IN $candidates +ORDER BY distance +LIMIT top_k +``` + +``` +./vector_index --endpoint=<endpoint> --database=<database> --command=RecreateIndex --table=<table> --index_type=flat --index_quantizer=bit --primary_key=<primary_key> --embedding=<embedding> --distance=CosineDistance --rows=0 --top_k=15 --data=<list of string columns, e.g. "title, text"> --target=<absolute-path-to-target.json> +``` diff --git a/ydb/public/sdk/cpp/examples/vector_index/main.cpp b/ydb/public/sdk/cpp/examples/vector_index/main.cpp new file mode 100644 index 0000000000..fe310f3cf5 --- /dev/null +++ b/ydb/public/sdk/cpp/examples/vector_index/main.cpp @@ -0,0 +1,72 @@ +#include "vector_index.h" + +#include <util/system/env.h> +#include <util/stream/file.h> + +using namespace NLastGetopt; +using namespace NYdb; + +int main(int argc, char** argv) { + TString endpoint; + TString command; + TOptions options; + + TOpts opts = TOpts::Default(); + + opts.AddLongOption('e', "endpoint", "YDB endpoint").Required().RequiredArgument("HOST:PORT").StoreResult(&endpoint); + opts.AddLongOption('d', "database", "YDB database").Required().RequiredArgument("PATH").StoreResult(&options.Database); + opts.AddLongOption('c', "command", "execute command: [Create|Drop|Build|Recreate]Index or TopK (used for search, read request)").Required().RequiredArgument("COMMAND").StoreResult(&command); + opts.AddLongOption("table", "table name").Required().RequiredArgument("TABLE").StoreResult(&options.Table); + opts.AddLongOption("index_type", "index type: [flat]").Required().RequiredArgument("TYPE").StoreResult(&options.IndexType); + opts.AddLongOption("index_quantizer", "index quantizer: [none|int8|uint8|bit]").Required().RequiredArgument("QUANTIZER").StoreResult(&options.IndexQuantizer); + opts.AddLongOption("primary_key", "primary key column").Required().RequiredArgument("PK").StoreResult(&options.PrimaryKey); + opts.AddLongOption("embedding", "embedding (vector) column").Required().RequiredArgument("EMBEDDING").StoreResult(&options.Embedding); + opts.AddLongOption("distance", "distance function: [Cosine|Euclidean|Manhattan]Distance").Required().RequiredArgument("DISTANCE").StoreResult(&options.Distance); + opts.AddLongOption("rows", "count of rows in table, used only for [Build|Recreate]Index commands").Required().RequiredArgument("ROWS").StoreResult(&options.Rows); + opts.AddLongOption("top_k", "count of rows in top, used only for TopK command").Required().RequiredArgument("TOPK").StoreResult(&options.TopK); + opts.AddLongOption("data", "list of columns to read, used only for TopK command").Required().RequiredArgument("DATA").StoreResult(&options.Data); + opts.AddLongOption("target", "file with target vector, used only for TopK command").Required().RequiredArgument("TARGET").StoreResult(&options.Target); + + opts.SetFreeArgsMin(0); + TOptsParseResult result(&opts, argc, argv); + + ECommand cmd = Parse(command); + + if (cmd == ECommand::None) { + Cerr << "Unsupported command: " << command << Endl; + return 1; + } + + auto config = TDriverConfig() + .SetEndpoint(endpoint) + .SetDatabase(options.Database) + .SetAuthToken(GetEnv("YDB_TOKEN")); + + TDriver driver(config); + + try { + switch (cmd) { + case ECommand::DropIndex: + return DropIndex(driver, options); + case ECommand::CreateIndex: + return CreateIndex(driver, options); + case ECommand::BuildIndex: + return BuildIndex(driver, options); + case ECommand::RecreateIndex: + if (auto r = DropIndex(driver, options); r != 0) { + return r; + } + if (auto r = CreateIndex(driver, options); r != 0) { + return r; + } + return BuildIndex(driver, options); + case ECommand::TopK: + return TopK(driver, options); + default: + break; + } + } catch (const std::exception& e) { + Cerr << "Execution failed: " << e.what() << Endl; + } + return 1; +} diff --git a/ydb/public/sdk/cpp/examples/vector_index/target.json b/ydb/public/sdk/cpp/examples/vector_index/target.json new file mode 100644 index 0000000000..7eeb1d3946 --- /dev/null +++ b/ydb/public/sdk/cpp/examples/vector_index/target.json @@ -0,0 +1 @@ +[0.1961289,0.51426697,0.03864574,0.5552187,-0.041873194,0.24177523,0.46322846,-0.3476358,-0.0802049,0.44246107,-0.06727136,-0.04970105,-0.0012320493,0.29773152,-0.3771864,0.047693416,0.30664062,0.15911901,0.27795044,0.11875397,-0.056650203,0.33322853,-0.28901896,-0.43791273,-0.014167095,0.36109218,-0.16923136,0.29162315,-0.22875166,0.122518055,0.030670911,-0.13762642,-0.13884683,0.31455114,-0.21587017,0.32154146,-0.4452795,-0.058932953,0.07103838,0.4289945,-0.6023675,-0.14161813,0.11005565,0.19201005,0.2591869,-0.24074492,0.18088372,-0.16547637,0.08194011,0.10669302,-0.049760908,0.15548608,0.011035396,0.16121127,-0.4862669,0.5691393,-0.4885568,0.90131176,0.20769958,0.010636337,-0.2094356,-0.15292564,-0.2704138,-0.01326699,0.11226809,0.37113565,-0.018971693,0.86532146,0.28991342,0.004782651,-0.0024367527,-0.0861291,0.39704522,0.25665164,-0.45121723,-0.2728092,0.1441502,-0.5042585,0.3507123,-0.38818485,0.5468399,0.16378048,-0.11177127,0.5224827,-0.05927702,0.44906104,-0.036211397,-0.08465567,-0.33162776,0.25222498,-0.22274417,0.15050206,-0.012386843,0.23640677,-0.18704978,0.1139806,0.19379948,-0.2326912,0.36477265,-0.2544955,0.27143118,-0.095495716,-0.1727166,0.29109988,0.32738894,0.0016002139,0.052142758,0.37208632,0.034044757,0.17740013,0.16472393,-0.20134833,0.055949032,-0.06671674,0.04691583,0.13196157,-0.13174891,-0.17132106,-0.4257385,-1.1067779,0.55262613,0.37117195,-0.37033138,-0.16229,-0.31594914,-0.87293816,0.62064904,-0.32178572,0.28461748,0.41640115,-0.050539408,0.009697271,0.3483608,0.4401717,-0.08273758,0.4873984,0.057845585,0.28128678,-0.43955156,-0.18790118,0.40001884,0.54413813,0.054571174,0.65416795,0.04503013,0.40744695,-0.048226677,0.4787822,0.09700139,0.07739511,0.6503141,0.39685145,-0.54047453,0.041596334,-0.22190939,0.25528133,0.17406437,-0.17308964,0.22076453,0.31207982,0.8434676,0.2086337,-0.014262581,0.05081182,-0.30908328,-0.35717097,0.17224313,0.5266846,0.58924395,-0.29272506,0.01910475,0.061457288,0.18099669,0.04807291,0.34706554,0.32477927,0.17174402,-0.070991516,0.5819317,0.71045977,0.07172716,0.32184732,0.19009985,0.04727492,0.3004647,0.26943457,0.61640364,0.1655051,-0.6033329,0.09797926,-0.20623252,0.10987298,1.016591,-0.29540864,0.25161317,0.19790122,0.14642714,0.5081536,-0.22128952,0.4286613,-0.029895071,0.23768105,-0.0023987228,0.086968,0.42884818,-0.33578634,-0.38033295,-0.16163215,-0.18072455,-0.5015756,0.28035417,-0.0066010267,0.67613393,-0.026721207,0.22796173,-0.008428602,-0.38017297,-0.33044866,0.4519961,-0.05542353,-0.2976922,0.37046987,0.23409955,-0.24246313,-0.12839256,-0.4206849,-0.049280513,-0.7651326,0.1649417,-0.2321146,0.106625736,-0.37506104,0.14470209,-0.114986554,-0.17738944,0.612335,0.25292027,-0.092776075,-0.3876576,-0.08905502,0.3793106,0.7376429,-0.3080258,-0.3869677,0.5239047,-0.41152182,0.22852719,0.42226496,-0.28244498,0.0651847,0.3525671,-0.5396397,-0.17514983,0.29470462,-0.47671098,0.43471992,0.38677526,0.054752454,0.2183725,0.06853758,-0.12792642,0.67841107,0.24607432,0.18936129,0.24056062,-0.30873874,0.62442464,0.5792256,0.20426203,0.54328054,0.56583667,-0.7724596,-0.08384111,-0.16767848,-0.21682987,0.05710991,-0.015403866,0.38889074,-0.6050326,0.4075437,0.40839496,0.2507789,-0.32695654,0.24276069,0.1271161,-0.010688765,-0.31864303,0.15747054,-0.4670915,-0.21059138,0.7470888,0.47273478,-0.119508654,-0.63659865,0.64500844,0.5370401,0.28596714,0.0046216915,0.12771192,-0.18660222,0.47342712,-0.32039297,0.10946048,0.25172964,0.021965463,-0.12397459,-0.048939236,0.2881649,-0.61231786,-0.33459276,-0.29495123,-0.14027011,-0.23020774,0.73250633,0.71871173,0.78408533,0.4140183,0.1398299,0.7395877,0.06801048,-0.8895956,-0.64981127,-0.37226167,0.1905936,0.12819989,-0.47098637,-0.14334664,-0.933116,0.4597078,0.09895813,0.38114703,0.14368558,-0.42793563,-0.10805895,0.025374172,0.40162122,-0.1686769,0.5257471,-0.3540743,0.08181256,-0.34759146,0.0053078625,0.09163392,0.074487045,-0.14934056,0.034427803,0.19613744,-0.00032829077,0.27792764,0.09889235,-0.029708104,0.3528952,0.22679164,-0.27263018,0.6655268,-0.21362385,0.13035864,0.41666874,0.1253278,-0.22861275,0.105085365,0.09412938,0.03228179,0.11568338,0.23504587,-0.044100706,0.0104857525,-0.07461301,0.1034835,0.3078725,0.5257031,-0.015183647,-0.0060899477,-0.02852683,-0.39821762,-0.20495597,-0.14892153,0.44850922,0.40366673,-0.10324784,0.4095244,0.8356313,0.21190739,-0.12822983,0.06830399,0.036365107,0.044244137,0.26112562,0.033477627,-0.41074416,-0.009961431,0.23717403,0.12438699,-0.05255729,-0.18411024,-0.18563229,-0.16543737,-0.122300245,0.40962145,-0.4751102,0.5309857,0.04474563,0.103834346,0.14118321,4.2373734,0.45751426,0.21709882,0.6866778,0.14838168,-0.1831362,0.10963214,-0.33557487,-0.1084519,0.3299757,0.076113895,0.12850489,-0.07326015,-0.23770756,0.11080451,0.29712623,-0.13904962,0.25797644,-0.5074562,0.4018296,-0.23186816,0.24427155,0.39540753,0.015477164,0.14021018,0.273185,0.013538655,0.47227964,0.52339536,0.54428,0.16983595,0.5470162,-0.0042650895,0.21768,0.090606116,-0.13433483,0.5818122,-0.1384567,0.2354754,0.08440857,-0.2166868,0.48664945,-0.13175073,0.45613387,0.089229666,0.15436831,0.08720108,0.37597507,0.52855235,-0.019367872,0.544358,-0.327109,-0.20839518,-0.33598265,0.033363096,0.42312673,0.13452567,0.40526676,0.08402101,-0.19661862,-0.24802914,0.23069139,0.5153508,0.13562717,-0.23842931,-0.23257096,-0.009195984,0.41388315,0.56304437,-0.23492545,-0.2642354,0.3038204,-0.09548942,-0.22467934,-0.2561862,-0.34057313,-0.19744347,0.0007430283,-0.12842518,-0.13980682,0.6849243,0.1795335,-0.5626032,-0.07626079,-0.062749654,0.6660117,-0.4479761,0.07978033,0.6269782,0.536793,0.6801336,-0.22563715,0.38902125,-0.09493616,0.21312712,0.17763247,0.1796997,-3.868085,0.08134122,0.10347531,-0.034904435,-0.2792477,-0.17850947,0.083218865,0.26535586,-0.25551575,0.28172702,0.1383222,0.10376686,-0.123248994,0.1985073,-0.40000066,0.44763976,0.028454497,0.37575415,0.071487874,-0.16965964,0.38927504,0.29088503,-0.011822928,-0.19522227,-0.1766321,0.1731763,0.49192554,0.44358602,-0.49064636,0.024170646,0.025736902,-0.17963372,0.38337404,0.07339889,0.042465065,0.5910191,0.07904464,-0.043729525,-0.16969916,0.4008944,-0.04921039,-0.3757768,0.6075314,-0.24661873,-0.1780646,0.60300773,-0.09518917,0.2213779,-0.46496615,-0.41421738,0.23309247,0.14687467,-0.36499617,0.04227981,0.88024706,0.57489127,0.21026954,-0.13666761,0.05710815,0.22095469,-0.033460964,0.13861561,0.22527887,0.1660716,-0.3286249,-0.060175333,-0.2971499,0.2454142,0.6536238,-0.22991207,0.046677545,-0.026631566,-0.04271381,-0.53681016,0.11866242,-0.24970472,-0.37882543,0.33650783,0.7634871,-0.2858582,0.029164914,0.28833458,-0.39263156,0.64842117,2.6358266,0.058920268,2.2507918,0.6809379,-0.41290292,0.36954543,-0.60793567,0.42561662,0.2498035,0.27133986,-0.005307673,0.32910514,-0.03169463,-0.02270061,-0.14702365,-0.25256258,0.54468036,-0.46112943,-0.07411629,-0.030253865,0.20578359,0.6495886,-0.11674013,0.029835526,0.019896187,-0.008101909,0.3706806,-0.26088533,-0.018712807,0.17228629,0.15223767,0.0675542,0.6338221,-0.15303946,0.02908536,0.27217266,-0.10829474,4.503505,-0.37745082,0.20543274,-0.087563366,-0.14404398,0.5562983,0.41639867,-0.38191214,-0.16266975,-0.46071815,0.51874137,0.36326376,0.027115177,-0.06804209,0.35159302,-0.41162485,0.30493516,0.18828706,0.63608,-0.04735176,0.13811842,0.09368063,0.037441075,-0.0012712433,-0.19929455,0.34804425,0.46975428,0.38857734,-0.061463855,0.122808196,0.37608445,5.2436657,0.25659403,-0.19236223,-0.25611007,0.22265173,0.5898642,-0.28255892,-0.4123271,-0.4214137,0.09197922,-0.060595497,-0.13819462,-0.13570791,0.25433356,0.5907837,0.2548469,-0.39375016,-0.37651995,0.701745,-0.0359955,-0.048193086,0.4458719,0.088069156,-0.015497342,0.52568024,-0.4795603,-0.025876174,0.76476455,-0.32245165,-0.038828112,0.6325802,0.06385053,-0.26389623,0.2439906,-0.4231506,0.19213657,0.5828574,0.053197365,0.45217928,0.040650904,0.83714896,0.63782233,-0.737095,-0.41026706,0.23113042,0.19471557,-0.24410644,-0.35155243,0.20881484,-0.01721743,-0.29494065,-0.114185065,1.2226206,-0.16469914,0.083336286,0.63608664,0.41011855,-0.032080106,-0.08833447,-0.6261006,0.22665286,0.08313674,-0.16372047,0.5235312,0.39580458,0.0007253827,0.10186727,-0.15955615,0.54162663,0.32992217,-0.02491269,0.16312002,0.118171245,-0.029900813,0.038405042,0.31396118,0.45241603,-0.07010825,0.07611299,0.084779754,0.34168348,-0.60676336,0.054825004,-0.16054128,0.2525291,0.20532744,-0.1510394,0.4857572,0.32150552,0.35749313,0.4483151,0.0057622716,0.28705776,-0.018361313,0.08605509,-0.08649293,0.26918742,0.4806176,0.098294765,0.3284613,0.00010664656,0.43832678,-0.33351916,0.02354738,0.004953976,-0.14319824,-0.33351237,-0.7268964,0.56292313,0.1275613,0.4438945,0.7984555,-0.19372283,0.2940397,-0.11770557] diff --git a/ydb/public/sdk/cpp/examples/vector_index/vector_index.cpp b/ydb/public/sdk/cpp/examples/vector_index/vector_index.cpp new file mode 100644 index 0000000000..d8b452bfc4 --- /dev/null +++ b/ydb/public/sdk/cpp/examples/vector_index/vector_index.cpp @@ -0,0 +1,288 @@ +#include "vector_index.h" +#include <format> +#include <fstream> +#include <sstream> +#include <thread> + +template <> +struct std::formatter<TString>: std::formatter<std::string_view> { + template <typename FormatContext> + auto format(const TString& param, FormatContext& fc) const { + return std::formatter<std::string_view>::format(std::string_view{param}, fc); + } +}; + +using namespace NYdb; +using namespace NTable; +namespace { + +constexpr ui64 kBulkSize = 1000; +constexpr std::string_view FlatIndex = "flat"; + +namespace NQuantizer { + +static constexpr std::string_view None = "None"; +static constexpr std::string_view Int8 = "Int8"; +static constexpr std::string_view Uint8 = "Uint8"; +static constexpr std::string_view Bit = "Bit"; + +} // namespace NQuantizer + +bool EqualsICase(std::string_view l, std::string_view r) { + return std::equal(l.begin(), l.end(), r.begin(), r.end(), [](char l, char r) { + return std::tolower(l) == std::tolower(r); + }); +} + +void PrintTop(TResultSetParser&& parser) { + while (parser.TryNextRow()) { + Y_ASSERT(parser.ColumnsCount() >= 1); + Cout << *parser.ColumnParser(0).GetOptionalFloat() << "\t"; + for (size_t i = 1; i < parser.ColumnsCount(); ++i) { + Cout << *parser.ColumnParser(1).GetOptionalUtf8() << "\t"; + } + Cout << "\n"; + } + Cout << Endl; +} + +TString FullName(const TOptions& options, const TString& name) { + return TString::Join(options.Database, "/", name); +} + +TString IndexName(const TOptions& options) { + return TString::Join(options.Table, "_", options.IndexType, "_", options.IndexQuantizer); +} + +TString FullIndexName(const TOptions& options) { + return FullName(options, IndexName(options)); +} + +void DropTable(TTableClient& client, const TString& table) { + auto r = client.RetryOperationSync([&](TSession session) { + TDropTableSettings settings; + return session.DropTable(table).ExtractValueSync(); + }); + if (!r.IsSuccess() && r.GetStatus() != EStatus::SCHEME_ERROR) { + ythrow TVectorException{r}; + } +} + +void DropIndex(TTableClient& client, const TOptions& options) { + DropTable(client, FullIndexName(options)); +} + +void CreateFlat(TTableClient& client, const TOptions& options) { + auto r = client.RetryOperationSync([&](TSession session) { + auto desc = TTableBuilder() + .AddNonNullableColumn(options.PrimaryKey, EPrimitiveType::Uint32) + .AddNullableColumn(options.Embedding, EPrimitiveType::String) + .SetPrimaryKeyColumn(options.PrimaryKey) + .Build(); + + return session.CreateTable(FullIndexName(options), std::move(desc)).ExtractValueSync(); + }); + if (!r.IsSuccess()) { + ythrow TVectorException{r}; + } +} + +void UpdateFlat(TTableClient& client, const TOptions& options, std::string_view type) { + TString query = std::format(R"( + DECLARE $begin AS Uint64; + DECLARE $rows AS Uint64; + + UPSERT INTO {1} + SELECT COALESCE(CAST({2} AS Uint32), 0) AS {2}, Untag(Knn::ToBinaryString{4}(CAST(Knn::FloatFromBinaryString({3}) AS List<{5}>)), "{4}Vector") AS {3} + FROM {0} + WHERE $begin <= {2} AND {2} < $begin + $rows; + )", + options.Table, + IndexName(options), + options.PrimaryKey, + options.Embedding, + type, + type == NQuantizer::Bit ? "Float" : type); + Cout << query << Endl; + + auto last = std::chrono::steady_clock::now(); + ui64 current = 0; + ui64 overall = (options.Rows + kBulkSize - 1) / kBulkSize; + auto report = [&](auto curr) { + Cout << "Already done " << current << " / " << overall << " upserts, time spent: " << std::chrono::duration<double>{curr - last}.count() << Endl; + last = curr; + }; + auto waitRequest = [&](auto& request) { + auto r = request.ExtractValueSync(); + if (!r.IsSuccess()) { + ythrow TVectorException{r}; + } + ++current; + if (auto curr = std::chrono::steady_clock::now(); (curr - last) >= std::chrono::seconds{1}) { + report(curr); + } + }; + + std::deque<TAsyncStatus> requests; + auto waitFirst = [&] { + if (requests.size() < std::thread::hardware_concurrency()) { + return; + } + waitRequest(requests.front()); + requests.pop_front(); + }; + + TParamsBuilder paramsBuilder; + TRetryOperationSettings retrySettings; + retrySettings + .MaxRetries(60) + .GetSessionClientTimeout(TDuration::Seconds(60)) + .Idempotent(true) + .RetryUndefined(true); + for (ui64 i = 0; i < options.Rows; i += kBulkSize) { + waitFirst(); + paramsBuilder.AddParam("$begin").Uint64(i).Build(); + paramsBuilder.AddParam("$rows").Uint64(kBulkSize).Build(); + auto f = client.RetryOperation([&, p = paramsBuilder.Build()](TSession session) { + auto params = p; + return session.ExecuteDataQuery( + query, + TTxControl::BeginTx(TTxSettings::SerializableRW()).CommitTx(), + std::move(params)) + .Apply([](auto result) -> TStatus { + return result.ExtractValueSync(); + }); + }, retrySettings); + requests.push_back(std::move(f)); + } + + for (auto& request : requests) { + waitRequest(request); + } + report(std::chrono::steady_clock::now()); +} + +void TopKFlat(TTableClient& client, const TOptions& options, std::string_view type) { + TString query = std::format(R"( + $TargetBinary = Knn::ToBinaryStringFloat($Target); + $TargetSQ = Knn::ToBinaryString{7}(CAST($Target AS List<{8}>)); + + $IndexIds = SELECT {1}, Knn::{0}({4}, $TargetSQ) as distance + FROM {2} + ORDER BY distance + LIMIT {6} * 2; + + SELECT Knn::{0}({4}, $TargetBinary) as distance, {5} + FROM {3} + WHERE {1} IN (SELECT {1} FROM $IndexIds) + ORDER BY distance + LIMIT {6}; + )", + options.Distance, + options.PrimaryKey, + IndexName(options), + options.Table, + options.Embedding, + options.Data, + options.TopK, + type, + type == NQuantizer::Bit ? "Float" : type); + Cout << query << Endl; + std::ifstream targetFileStream(options.Target); + std::stringstream targetStrStream; + targetStrStream << targetFileStream.rdbuf(); + query = std::format(R"($Target = CAST({0} AS List<Float>); {1})", targetStrStream.view(), query); + TExecDataQuerySettings settings; + settings.KeepInQueryCache(true); + auto r = client.RetryOperationSync([&](TSession session) -> TStatus { + auto f = session.ExecuteDataQuery(query, TTxControl::BeginTx(TTxSettings::SerializableRW()).CommitTx(),settings); + auto r = f.ExtractValueSync(); + if (r.IsSuccess()) { + PrintTop(r.GetResultSetParser(0)); + } + return r; + }); + if (!r.IsSuccess()) { + ythrow TVectorException{r}; + } +} + +} // namespace + +ECommand Parse(std::string_view command) { + if (EqualsICase(command, "DropIndex")) { + return ECommand::DropIndex; + } + if (EqualsICase(command, "CreateIndex")) { + return ECommand::CreateIndex; + } + if (EqualsICase(command, "BuildIndex")) { + return ECommand::BuildIndex; + } + if (EqualsICase(command, "RecreateIndex")) { + return ECommand::RecreateIndex; + } + if (EqualsICase(command, "TopK")) { + return ECommand::TopK; + } + return ECommand::None; +} + +int DropIndex(NYdb::TDriver& driver, const TOptions& options) { + TTableClient client(driver); + DropIndex(client, options); + return 0; +} + +int CreateIndex(NYdb::TDriver& driver, const TOptions& options) { + TTableClient client(driver); + if (options.IndexType == FlatIndex) { + CreateFlat(client, options); + return 0; + } + return 1; +} + +int BuildIndex(NYdb::TDriver& driver, const TOptions& options) { + TTableClient client(driver); + if (EqualsICase(options.IndexType, FlatIndex)) { + if (EqualsICase(options.IndexQuantizer, NQuantizer::None)) { + return 0; + } + if (EqualsICase(options.IndexQuantizer, NQuantizer::Int8)) { + UpdateFlat(client, options, NQuantizer::Int8); + return 0; + } + if (EqualsICase(options.IndexQuantizer, NQuantizer::Uint8)) { + UpdateFlat(client, options, NQuantizer::Uint8); + return 0; + } + if (EqualsICase(options.IndexQuantizer, NQuantizer::Bit)) { + UpdateFlat(client, options, NQuantizer::Bit); + return 0; + } + } + return 1; +} + +int TopK(NYdb::TDriver& driver, const TOptions& options) { + TTableClient client(driver); + if (EqualsICase(options.IndexType, FlatIndex)) { + if (EqualsICase(options.IndexQuantizer, NQuantizer::None)) { + return 0; + } + if (EqualsICase(options.IndexQuantizer, NQuantizer::Int8)) { + TopKFlat(client, options, NQuantizer::Int8); + return 0; + } + if (EqualsICase(options.IndexQuantizer, NQuantizer::Uint8)) { + TopKFlat(client, options, NQuantizer::Uint8); + return 0; + } + if (EqualsICase(options.IndexQuantizer, NQuantizer::Bit)) { + TopKFlat(client, options, NQuantizer::Bit); + return 0; + } + } + return 1; +}
\ No newline at end of file diff --git a/ydb/public/sdk/cpp/examples/vector_index/vector_index.h b/ydb/public/sdk/cpp/examples/vector_index/vector_index.h new file mode 100644 index 0000000000..8ebcf5ab85 --- /dev/null +++ b/ydb/public/sdk/cpp/examples/vector_index/vector_index.h @@ -0,0 +1,51 @@ +#pragma once + +#include <ydb/public/sdk/cpp/client/ydb_driver/driver.h> +#include <ydb/public/sdk/cpp/client/ydb_table/table.h> + +#include <library/cpp/getopt/last_getopt.h> +#include <util/generic/string.h> +#include <util/generic/yexception.h> +#include <util/stream/output.h> +#include <util/string/builder.h> +#include <util/string/printf.h> + +enum class ECommand { + DropIndex, + CreateIndex, + BuildIndex, + RecreateIndex, // Drop, Create, Build + TopK, + None, +}; + +ECommand Parse(std::string_view command); + +struct TOptions { + TString Database; + TString Table; + TString IndexType; + TString IndexQuantizer; + TString PrimaryKey; + TString Embedding; + TString Distance; + TString Data; + TString Target; + ui64 Rows = 0; + ui64 TopK = 0; +}; + +int DropIndex(NYdb::TDriver& driver, const TOptions& options); + +int CreateIndex(NYdb::TDriver& driver, const TOptions& options); + +int BuildIndex(NYdb::TDriver& driver, const TOptions& options); + +int TopK(NYdb::TDriver& driver, const TOptions& options); + +class TVectorException: public yexception { +public: + TVectorException(const NYdb::TStatus& status) { + *this << "Status:" << status; + } +}; diff --git a/ydb/public/sdk/cpp/examples/vector_index/ya.make b/ydb/public/sdk/cpp/examples/vector_index/ya.make new file mode 100644 index 0000000000..2c7a5c2a6f --- /dev/null +++ b/ydb/public/sdk/cpp/examples/vector_index/ya.make @@ -0,0 +1,13 @@ +PROGRAM() + +SRCS( + main.cpp + vector_index.cpp +) + +PEERDIR( + library/cpp/getopt + ydb/public/sdk/cpp/client/ydb_table +) + +END() diff --git a/ydb/public/sdk/cpp/examples/ya.make b/ydb/public/sdk/cpp/examples/ya.make index 2cb41f28d2..14a498dfd8 100644 --- a/ydb/public/sdk/cpp/examples/ya.make +++ b/ydb/public/sdk/cpp/examples/ya.make @@ -6,4 +6,5 @@ RECURSE( secondary_index_builtin topic_reader ttl + vector_index ) |