diff options
| -rw-r--r-- | ydb/tests/stress/oltp_workload/workload/type/vector_index.py | 138 |
1 files changed, 72 insertions, 66 deletions
diff --git a/ydb/tests/stress/oltp_workload/workload/type/vector_index.py b/ydb/tests/stress/oltp_workload/workload/type/vector_index.py index aacd237bcd8..54c138e7462 100644 --- a/ydb/tests/stress/oltp_workload/workload/type/vector_index.py +++ b/ydb/tests/stress/oltp_workload/workload/type/vector_index.py @@ -19,38 +19,28 @@ class WorkloadVectorIndex(WorkloadBase): super().__init__(client, prefix, "vector_index", stop) self.table_name = "table" self.index_name = "vector_idx" - self.rows_count = 100 + self.rows_count = 10 self.limit = 10 self.to_binary_string_converters = { "float": BinaryStringConverter( - name="Knn::ToBinaryStringFloat", - data_type="Float", - vector_type="FloatVector" + name="Knn::ToBinaryStringFloat", data_type="Float", vector_type="FloatVector" ), "uint8": BinaryStringConverter( - name="Knn::ToBinaryStringUint8", - data_type="Uint8", - vector_type="Uint8Vector" - ), - "int8": BinaryStringConverter( - name="Knn::ToBinaryStringInt8", - data_type="Int8", - vector_type="Int8Vector" + name="Knn::ToBinaryStringUint8", data_type="Uint8", vector_type="Uint8Vector" ), + "int8": BinaryStringConverter(name="Knn::ToBinaryStringInt8", data_type="Int8", vector_type="Int8Vector"), } self.targets = { - "similarity": { - "inner_product": "Knn::InnerProductSimilarity", - "cosine": "Knn::CosineSimilarity" - }, + "similarity": {"inner_product": "Knn::InnerProductSimilarity", "cosine": "Knn::CosineSimilarity"}, "distance": { "cosine": "Knn::CosineDistance", "manhattan": "Knn::ManhattanDistance", - "euclidean": "Knn::EuclideanDistance" - } + "euclidean": "Knn::EuclideanDistance", + }, } def _get_random_vector(self, type, size): + logger.info(f"random vector type: {type}") if type == "float": values = [round(random.uniform(-100, 100), 2) for _ in range(size)] return ",".join(f'{val}f' for val in values) @@ -87,14 +77,16 @@ class WorkloadVectorIndex(WorkloadBase): """ self.client.query(drop_index_sql, True) - def _create_index(self, table_path, vector_type, - vector_dimension, levels, clusters, - distance=None, similarity=None): - logger.info(f"""Create index vector_type={vector_type}, + def _create_index( + self, table_path, vector_type, vector_dimension, levels, clusters, distance=None, similarity=None + ): + logger.info( + f"""Create index vector_type={vector_type}, vector_dimension={vector_dimension}, levels={levels}, clusters={clusters}, distance={distance}, - similarity={similarity}""") + similarity={similarity}""" + ) if distance is not None: create_index_sql = f""" ALTER TABLE `{table_path}` @@ -130,10 +122,8 @@ class WorkloadVectorIndex(WorkloadBase): for key in range(self.rows_count): vector = self._get_random_vector(vector_type, vector_dimension) name = converter.name - vector_type = converter.vector_type - values.append( - f'({key}, Untag({name}([{vector}]), "{vector_type}"))' - ) + vector_types = converter.vector_type + values.append(f'({key}, Untag({name}([{vector}]), "{vector_types}"))') upsert_sql = f""" UPSERT INTO `{table_path}` (pk, embedding) @@ -141,8 +131,7 @@ class WorkloadVectorIndex(WorkloadBase): """ self.client.query(upsert_sql, False) - def _select(self, table_path, vector_type, - vector_dimension, distance, similarity): + def _select(self, table_path, vector_type, vector_dimension, distance, similarity): if distance is not None: target = self.targets["distance"][distance] else: @@ -162,15 +151,14 @@ class WorkloadVectorIndex(WorkloadBase): """ return self.client.query(select_sql, False) - def _select_top(self, table_path, vector_type, - vector_dimension, distance, similarity): + def _select_top(self, table_path, vector_type, vector_dimension, distance, similarity): logger.info("Select values from table") result_set = self._select( table_path=table_path, vector_type=vector_type, vector_dimension=vector_dimension, distance=distance, - similarity=similarity + similarity=similarity, ) if len(result_set) == 0: raise Exception("Query returned an empty set") @@ -178,17 +166,18 @@ class WorkloadVectorIndex(WorkloadBase): rows = result_set[0].rows logger.info(f"Rows count {len(rows)}") - prev = 0.0 if distance is not None else 1.0 + prev = rows[0]['target'] for row in rows: cur = row['target'] condition = prev <= cur if distance is not None else prev >= cur if not condition: - raise Exception(f"""The set of rows does not satisfy the - condition, prev: {prev}, cur: {cur}""") + raise Exception( + f"""The set of rows does not satisfy the + condition, prev: {prev}, cur: {cur}""" + ) prev = cur - def _wait_inddex_ready(self, table_path, vector_type, - vector_dimension, distance, similarity): + def _wait_inddex_ready(self, table_path, vector_type, vector_dimension, distance, similarity): for i in range(10): time.sleep(7) @@ -198,7 +187,7 @@ class WorkloadVectorIndex(WorkloadBase): vector_type=vector_type, vector_dimension=vector_dimension, distance=distance, - similarity=similarity + similarity=similarity, ) except Exception as ex: if "No global indexes for table" in str(ex): @@ -208,9 +197,7 @@ class WorkloadVectorIndex(WorkloadBase): return raise Exception("Error getting index status") - def _check_loop(self, table_path, vector_type, - vector_dimension, levels, clusters, - distance=None, similarity=None): + def _check_loop(self, table_path, vector_type, vector_dimension, levels, clusters, distance=None, similarity=None): self._create_index( table_path=table_path, vector_type=vector_type, @@ -218,72 +205,91 @@ class WorkloadVectorIndex(WorkloadBase): levels=levels, clusters=clusters, distance=distance, - similarity=similarity + similarity=similarity, ) self._wait_inddex_ready( table_path=table_path, vector_type=vector_type, vector_dimension=vector_dimension, distance=distance, - similarity=similarity + similarity=similarity, ) self._select_top( table_path=table_path, vector_type=vector_type, vector_dimension=vector_dimension, distance=distance, - similarity=similarity + similarity=similarity, ) self._drop_index(table_path) logger.info('check was completed successfully') def _loop(self): table_path = self.get_table_path(self.table_name) - distance_data = ["cosine"] # "cosine", "manhattan", "euclidean" - similarity_data = ["cosine"] # "inner_product", "cosine" - vector_type_data = ["float", "int8"] + distance_data = ["cosine", "manhattan", "euclidean"] + similarity_data = ["cosine", "inner_product"] + vector_type_data = ["float", "int8", "uint8"] levels_data = [1, 3] clusters_data = [1, 17] vector_dimension_data = [5] - - try: + self._create_table(table_path) + while not self.is_stop_requested(): for vector_type in vector_type_data: for vector_dimension in vector_dimension_data: - self._create_table(table_path) - self._upsert_values( - table_path=table_path, - vector_type=vector_type, - vector_dimension=vector_dimension + table_path=table_path, vector_type=vector_type, vector_dimension=vector_dimension ) - - for levels in levels_data: - for clusters in clusters_data: - for distance in distance_data: + for levels in levels_data: + for clusters in clusters_data: + for distance in distance_data: + logger.info( + f"""vector_type: {vector_type} + vector_dimension: {vector_dimension} + levels: {levels} + clusters: {clusters} + distance: {distance} + """ + ) + try: self._check_loop( table_path=table_path, vector_type=vector_type, vector_dimension=vector_dimension, levels=levels, clusters=clusters, - distance=distance + distance=distance, ) + except Exception as ex: + logger.info(f"ERRROR {ex}") + raise str(ex) + if self.is_stop_requested(): + return - for levels in levels_data: - for clusters in clusters_data: - for similarity in similarity_data: + for similarity in similarity_data: + logger.info( + f"""vector_type: {vector_type} + vector_dimension: {vector_dimension} + levels: {levels} + clusters: {clusters} + similarity: {similarity} + """ + ) + try: self._check_loop( table_path=table_path, vector_type=vector_type, vector_dimension=vector_dimension, levels=levels, clusters=clusters, - similarity=similarity + similarity=similarity, ) + except Exception as ex: + logger.info(f"ERRROR {ex}") + raise str(ex) + if self.is_stop_requested(): + return - self._drop_table(table_path) - except Exception as ex: - logger.info(f"ERRROR {ex}") + self._drop_table(table_path) def get_stat(self): return "" |
