author | snaury <snaury@ydb.tech> | 2023-03-30 17:31:01 +0300
committer | snaury <snaury@ydb.tech> | 2023-03-30 17:31:01 +0300
commit | 5edad89a8a68a895e0486b5cd1d7ae6e95a9359c (patch)
tree | 6058029cabf2a2f2655f33257d8582ef76d3643f
parent | c3fb936192b5b2247d507d2a5f4edc6dd111ac7a (diff)
download | ydb-5edad89a8a68a895e0486b5cd1d7ae6e95a9359c.tar.gz
Support checking change visibility in ydb serializable
-rw-r--r-- | ydb/tests/library/serializability/checker.py | 31
-rw-r--r-- | ydb/tests/tools/ydb_serializable/__main__.py | 6
-rw-r--r-- | ydb/tests/tools/ydb_serializable/lib/__init__.py | 335
3 files changed, 261 insertions, 111 deletions
diff --git a/ydb/tests/library/serializability/checker.py b/ydb/tests/library/serializability/checker.py
index 0e39175a64..1f8ae9d873 100644
--- a/ydb/tests/library/serializability/checker.py
+++ b/ydb/tests/library/serializability/checker.py
@@ -216,6 +216,7 @@ class SerializabilityChecker(object):
        self.nodes[0] = self.initial_node
        self.committed = set()
        self.committed.add(self.initial_node)
+       self.observed_commits = set()
        self.aborted = set()
        self._next_value = 1
@@ -361,12 +362,15 @@ class SerializabilityChecker(object):
        """For node ensures that it reads key/value pair correctly"""
        assert key in self.keys, 'Key %r was not prepared properly' % (key,)
        write = self.find_node(value)
-       assert key in write.writes, 'Key %r was not in write set for value %r' % (key, value)
+       if key not in write.writes:
+           raise SerializabilityError('Node %r observes key %r = %r which was not written' % (node.value, key, value))
        assert key not in node.reads or node.reads[key] is write, 'Key %r has multiple conflicting reads' % (key,)
        node.reads[key] = write
        if self.explain:
            self._add_reason(Reason.Observed(key, write.value, node.value))
        node.dependencies.add(write)
+       # We observed this write, so it must be committed
+       self.observed_commits.add(write)

    def ensure_write_value(self, node, key):
        """For node ensure that it writes key/value pair correctly"""
@@ -389,7 +393,29 @@ class SerializabilityChecker(object):
        self.committed.add(node)

    def abort(self, node):
-       self.aborted.add(node)
+       if node.reads:
+           # We want to check reads are correct even in aborted transactions
+           # To do that we pretend like we committed a readonly tx
+           #
+           # This assumes linearizable dependencies (calculated before tx is
+           # started) are correct, which is true for ydb since successful
+           # reads imply interactivity, and interactive transactions take a
+           # global snapshot, linearizing with everything.
+           for key in node.writes:
+               assert key in self.keys
+               info = self.keys[key]
+               del info.writes[node.value]
+           node.writes.clear()
+           self.commit(node)
+       else:
+           self.aborted.add(node)
+
+   def _process_observed_commits(self):
+       """Makes sure all observed nodes are in a committed set"""
+       for node in self.observed_commits:
+           if node in self.aborted or not node.writes:
+               raise SerializabilityError('Node %r was aborted but observed committed' % (node.value,))
+           self.committed.add(node)

    def _extend_committed(self):
        """Makes sure all nodes reachable from committed set are in a committed set"""
@@ -660,6 +686,7 @@ class SerializabilityChecker(object):
            self.logger.warning('WARNING: found some unexpected indirect dependencies')

    def verify(self):
+       self._process_observed_commits()
        self._extend_committed()
        self._flatten_dependencies()
        self._fill_reverse_dependencies()
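The checker change above boils down to one rule: once any transaction observes a written value, the transaction that wrote it may no longer count as aborted, and an aborted transaction that only read is replayed as a committed read-only transaction so its reads are still validated. Below is a minimal, self-contained sketch of that rule; ToyChecker and Node are hypothetical stand-ins, not the real SerializabilityChecker API.

    # Illustrative sketch only: ToyChecker and Node are hypothetical stand-ins
    # for SerializabilityChecker and its node objects in checker.py.
    class SerializabilityError(Exception):
        pass

    class Node:
        def __init__(self, value):
            self.value = value
            self.writes = {}  # key -> written value
            self.reads = {}   # key -> node whose write was observed

    class ToyChecker:
        def __init__(self):
            self.committed, self.aborted, self.observed_commits = set(), set(), set()

        def ensure_read_value(self, reader, key, writer):
            # Observing a write forces the writing transaction to commit eventually.
            reader.reads[key] = writer
            self.observed_commits.add(writer)

        def abort(self, node):
            if node.reads:
                # Replay an aborted reader as a committed read-only tx so its
                # reads still take part in the serializability check.
                node.writes.clear()
                self.committed.add(node)
            else:
                self.aborted.add(node)

        def verify_visibility(self):
            for node in self.observed_commits:
                if node in self.aborted or not node.writes:
                    raise SerializabilityError(
                        'tx %r was observed by a reader but never committed its writes' % node.value)
                self.committed.add(node)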
diff --git a/ydb/tests/tools/ydb_serializable/__main__.py b/ydb/tests/tools/ydb_serializable/__main__.py
index 8ae95dc9b8..f03cc08bce 100644
--- a/ydb/tests/tools/ydb_serializable/__main__.py
+++ b/ydb/tests/tools/ydb_serializable/__main__.py
@@ -24,6 +24,7 @@ def main():
    parser.add_argument('-rw', dest='nreadwriters', type=int, default=100, help='Number of coroutines with point-key read/writes, default 100')
    parser.add_argument('-rt', dest='nreadtablers', type=int, default=100, help='Number of coroutines with read-table transactions, default 100')
    parser.add_argument('-rr', dest='nrangereaders', type=int, default=100, help='Number of coroutines with range-key reads, default 100')
+   parser.add_argument('-rwr', dest='nrwrs', type=int, default=0, help='Number of coroutines with read-write-read txs, default 0')
    parser.add_argument('--keys', dest='numkeys', type=int, default=40, help='Number of distinct keys in a table, default 40')
    parser.add_argument('--shards', dest='numshards', type=int, default=4, help='Number of shards for a table, default 4')
    parser.add_argument('--seconds', type=float, default=2.0, help='Minimum number of seconds per iteration, default 2 seconds')
@@ -33,6 +34,8 @@ def main():
    parser.add_argument('--rt-snapshot', dest='read_table_snapshot', action='store_true', default=None, help='Use server-side snapshots for read-table transactions')
    parser.add_argument('--ignore-rt', dest='ignore_read_table', action='store_true', help='Ignore read-table results (e.g. for interference only, legacy option)')
    parser.add_argument('--processes', type=int, default=1, help='Number of processes to fork into, default is 1')
+   parser.add_argument('--print-unique-errors', dest='print_unique_errors', action='store_const', const=1, default=0, help='Print unique errors that happen during execution')
+   parser.add_argument('--print-unique-traceback', dest='print_unique_errors', action='store_const', const=2, help='Print traceback for unique errors that happen during execution')
    args = parser.parse_args()

    logger = DummyLogger()
@@ -42,6 +45,7 @@ def main():
    options.shards = args.numshards
    options.readers = args.nreaders
    options.writers = args.nwriters
+   options.rwrs = args.nrwrs
    options.readwriters = args.nreadwriters
    options.readtablers = args.nreadtablers
    options.rangereaders = args.nrangereaders
@@ -53,7 +57,7 @@ def main():

    async def async_run_single():
        iterations = args.iterations
-       async with DatabaseChecker(args.endpoint, args.database, path=args.path, logger=logger) as checker:
+       async with DatabaseChecker(args.endpoint, args.database, path=args.path, logger=logger, print_unique_errors=args.print_unique_errors) as checker:
            while iterations is None or iterations > 0:
                try:
                    await checker.async_run(options)
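The new command-line switches map onto DatabaseCheckerOptions.rwrs and the print_unique_errors constructor argument added below. A hedged sketch of driving the checker programmatically with these knobs; the endpoint, database path, import form, and the use of a stdlib logging.Logger are illustrative assumptions, not taken from this commit.

    # Hypothetical driver code; constructor and option names come from this
    # commit, endpoint/database values and the stdlib logger are assumptions.
    import asyncio
    import logging

    from ydb.tests.tools.ydb_serializable.lib import DatabaseChecker, DatabaseCheckerOptions

    async def run_once():
        options = DatabaseCheckerOptions()
        options.rwrs = 50  # same effect as passing `-rwr 50` on the command line
        async with DatabaseChecker('localhost:2135', '/Root/test',
                                   logger=logging.getLogger('serializable'),
                                   print_unique_errors=1) as checker:
            await checker.async_run(options)

    asyncio.run(run_once())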
diff --git a/ydb/tests/tools/ydb_serializable/lib/__init__.py b/ydb/tests/tools/ydb_serializable/lib/__init__.py
index 791b8cc139..189156151f 100644
--- a/ydb/tests/tools/ydb_serializable/lib/__init__.py
+++ b/ydb/tests/tools/ydb_serializable/lib/__init__.py
@@ -11,83 +11,92 @@ import ydb.aio

KEY_PREFIX_TYPE = ydb.TupleType().add_element(ydb.OptionalType(ydb.PrimitiveType.Uint64))

+
+class Query(ydb.DataQuery):
+   __slots__ = ydb.DataQuery.__slots__ + ('declares', 'statements')
+
+   def __init__(self, declares, statements, types):
+       super().__init__(''.join(declares + statements), types)
+       self.declares = declares
+       self.statements = statements
+
+   def __add__(self, other):
+       if not isinstance(other, Query):
+           raise TypeError('cannot add %r and %r' % (self, other))
+       return Query(
+           self.declares + other.declares,
+           self.statements + other.statements,
+           dict(**self.parameters_types, **other.parameters_types))
+
+
def generate_query_point_reads(table, var='$data'):
-   text = '''\
+   declares = [
+       '''\
        DECLARE {var} AS List<Struct<
            key: Uint64>>;
-
+       '''.format(var=var),
+   ]
+   statements = [
+       '''\
        SELECT t.key AS key, t.value AS value
        FROM AS_TABLE({var}) AS d
-       INNER JOIN `{TABLE}` AS t ON t.key = d.key;
-       '''.format(TABLE=table, var=var)
+       INNER JOIN `{table}` AS t ON t.key = d.key;
+       '''.format(table=table, var=var),
+   ]
    row_type = (
        ydb.StructType()
        .add_member('key', ydb.PrimitiveType.Uint64))
-   return ydb.DataQuery(text, {
-       var: ydb.ListType(row_type),
-   })
+   return Query(declares, statements, {var: ydb.ListType(row_type)})


def generate_query_point_writes(table, var='$data'):
-   text = '''\
+   declares = [
+       '''\
        DECLARE {var} AS List<Struct<
            key: Uint64,
            value: Uint64>>;
-
-       UPSERT INTO `{TABLE}`
+       '''.format(var=var),
+   ]
+   statements = [
+       '''\
+       UPSERT INTO `{table}`
        SELECT key, value
        FROM AS_TABLE({var});
-       '''.format(TABLE=table, var=var)
+       '''.format(table=table, var=var),
+   ]
    row_type = (
        ydb.StructType()
        .add_member('key', ydb.PrimitiveType.Uint64)
        .add_member('value', ydb.PrimitiveType.Uint64))
-   return ydb.DataQuery(text, {
-       var: ydb.ListType(row_type),
-   })
+   return Query(declares, statements, {var: ydb.ListType(row_type)})


def generate_query_point_reads_writes(table, readsvar='$reads', writesvar='$writes'):
-   text = '''\
-       DECLARE {readsvar} AS List<Struct<
-           key: Uint64>>;
-       DECLARE {writesvar} AS List<Struct<
-           key: Uint64,
-           value: Uint64>>;
-
-       SELECT t.key AS key, t.value AS value
-       FROM AS_TABLE({readsvar}) AS d
-       INNER JOIN `{TABLE}` AS t ON t.key = d.key;
-
-       UPSERT INTO `{TABLE}`
-       SELECT key, value
-       FROM AS_TABLE({writesvar});
-       '''.format(TABLE=table, readsvar=readsvar, writesvar=writesvar)
-   reads_row_type = (
-       ydb.StructType()
-       .add_member('key', ydb.PrimitiveType.Uint64))
-   writes_row_type = (
-       ydb.StructType()
-       .add_member('key', ydb.PrimitiveType.Uint64)
-       .add_member('value', ydb.PrimitiveType.Uint64))
-   return ydb.DataQuery(text, {
-       readsvar: ydb.ListType(reads_row_type),
-       writesvar: ydb.ListType(writes_row_type),
-   })
+   return (
+       generate_query_point_reads(table, readsvar) +
+       generate_query_point_writes(table, writesvar)
+   )


def generate_query_range_reads(table, minvar='$minKey', maxvar='$maxKey'):
-   text = '''\
+   declares = [
+       '''\
        DECLARE {minvar} AS Uint64;
        DECLARE {maxvar} AS Uint64;
-
+       '''.format(minvar=minvar, maxvar=maxvar),
+   ]
+   statements = [
+       '''\
        SELECT key, value
-       FROM `{TABLE}`
+       FROM `{table}`
        WHERE key >= {minvar} AND key <= {maxvar};
-       '''.format(TABLE=table, minvar=minvar, maxvar=maxvar)
-   return ydb.DataQuery(text, {
-       minvar: ydb.PrimitiveType.Uint64,
-       maxvar: ydb.PrimitiveType.Uint64,
+       '''.format(table=table, minvar=minvar, maxvar=maxvar),
+   ]
+   return Query(
+       declares,
+       statements,
+       {
+           minvar: ydb.PrimitiveType.Uint64,
+           maxvar: ydb.PrimitiveType.Uint64,
    })
@@ -154,23 +163,42 @@ class History(object):
            elif self.linearizable == 'global':
                checker.ensure_linearizable_globally(node)
            else:
-               assert False, 'Unexpected value for linearizable'
+               raise ValueError('Unexpected value for linearizable: %r' % (self.linearizable,))

            for key in self.write_keys:
                checker.ensure_write_value(node, key)
            return node

    class Abort(object):
-       def __init__(self, op, value):
+       def __init__(self, op, value, observed=None):
            self.op = op
            self.value = value
+           if observed is None:
+               self.observed = None
+           elif isinstance(observed, (list, tuple)):
+               self.observed = list(observed)
+           elif isinstance(observed, dict):
+               self.observed = sorted(observed.items())
+           else:
+               raise TypeError('Unexpected value for observed: %r' % (observed,))

        def to_json(self):
-           return ['abort', self.op, self.value]
+           result = ['abort', self.op, self.value]
+           if self.observed is not None:
+               result.append(self.observed)
+           return result

        def apply_to(self, checker):
            node = checker.nodes.get(self.value)
            if node is None:
                raise RuntimeError('Abort of unknown node %r' % (self.value,))
+           if self.observed is not None:
+               seen = set()
+               for key, value in self.observed:
+                   checker.ensure_read_value(node, key, value)
+                   seen.add(key)
+               for key in node.expected_read_keys:
+                   if key not in seen:
+                       checker.ensure_read_value(node, key, 0)
            checker.abort(node)
            return node
@@ -255,6 +283,7 @@ class DatabaseCheckerOptions(object):
        self.shards = 4
        self.readers = 100
        self.writers = 100
+       self.rwrs = 0
        self.readwriters = 100
        self.readtablers = 100
        self.rangereaders = 100
@@ -265,7 +294,7 @@ class DatabaseCheckerOptions(object):


class DatabaseChecker(object):
-   def __init__(self, endpoint, database, path=None, logger=None):
+   def __init__(self, endpoint, database, path=None, logger=None, print_unique_errors=0):
        if not database.startswith('/'):
            database = '/' + database

@@ -280,6 +309,10 @@ class DatabaseChecker(object):
        self.logger = logger
        self.pool = None

+       self._stopping = False
+       self._unique_errors = {}
+       self._print_unique_errors = int(print_unique_errors)
+
    async def async_init(self):
        if self.pool is None:
            self.driver = ydb.aio.Driver(ydb.ConnectionParams(self.endpoint, self.database))
@@ -291,15 +324,39 @@ class DatabaseChecker(object):
        return self

    async def __aexit__(self, exc_type=None, exc_val=None, exc_tb=None):
+       self._stopping = True
        await self.pool.stop()
        await self.driver.stop()

+   def is_stopping(self):
+       return self._stopping
+
+   def _report_error(self, e):
+       if not self._print_unique_errors:
+           return
+       text = '%s: %s' % (type(e), str(e))
+       if text not in self._unique_errors:
+           self._unique_errors[text] = 1
+           if self._print_unique_errors > 1:
+               import traceback
+               traceback.print_exc()
+           else:
+               self.logger.debug(text)
+       else:
+           self._unique_errors[text] += 1
+           if (self._unique_errors[text] % 1000) == 0:
+               self.logger.debug('%s x%d' % (text, self._unique_errors[text]))
+
    async def async_retry_operation(self, callable, deadline):
-       while time.time() < deadline:
+       while time.time() < deadline and not self.is_stopping():
            try:
                async with self.pool.checkout() as session:
                    try:
-                       result = await callable(session)
+                       try:
+                           result = await callable(session)
+                       except Exception as e:
+                           self._report_error(e)
+                           raise
                    except (ydb.Aborted, ydb.Undetermined, ydb.NotFound, ydb.InternalError):
                        raise  # these are not retried
                    except (ydb.Unavailable, ydb.Overloaded, ydb.ConnectionError):
@@ -308,12 +365,15 @@ class DatabaseChecker(object):
                        continue  # retry
                return result
        else:
-           raise ydb.Aborted('deadline reached')
+           if self.is_stopping():
+               raise ydb.Aborted('stopping')
+           else:
+               raise ydb.Aborted('deadline reached')

    async def async_perform_point_reads(self, history, table, options, checker, deadline):
        read_query = generate_query_point_reads(table)

-       while time.time() < deadline:
+       while time.time() < deadline and not self.is_stopping():
            keys = checker.select_read_from_write_keys(cnt=random.randint(1, options.shards))
            if not keys:
                # There are not enough in-progress writes to this table yet, spin a little
@@ -321,10 +381,12 @@ class DatabaseChecker(object):
                continue

            node = history.add(History.Begin('reads', None, read_keys=keys)).apply_to(checker)
+           observed_values = None

            async def perform(session):
-               tx = session.transaction(ydb.SerializableReadWrite())
-               try:
+               nonlocal observed_values
+               observed_values = None
+               async with session.transaction(ydb.SerializableReadWrite()) as tx:
                    simple_tx = bool(random.randint(0, 1))
                    rss = await tx.execute(
                        read_query,
@@ -335,40 +397,31 @@ class DatabaseChecker(object):
                        },
                        commit_tx=simple_tx,
                    )
+                   observed_values = {}
+                   for row in rss[0].rows:
+                       observed_values[row.key] = row.value
                    if not simple_tx:
                        await tx.commit()
-                   return rss
-               finally:
-                   if tx.tx_id is not None:
-                       try:
-                           await tx.rollback()
-                       except ydb.Error:
-                           pass

            try:
-               rss = await self.async_retry_operation(perform, deadline)
+               await self.async_retry_operation(perform, deadline)
            except ydb.Aborted:
-               history.add(History.Abort('reads', node.value)).apply_to(checker)
+               history.add(History.Abort('reads', node.value, observed_values)).apply_to(checker)
            except ydb.Undetermined:
                pass  # transaction outcome unknown
            else:
-               values = {}
-               for row in rss[0].rows:
-                   values[row.key] = row.value
-
-               history.add(History.Commit('reads', node.value, values)).apply_to(checker)
+               history.add(History.Commit('reads', node.value, observed_values)).apply_to(checker)

    async def async_perform_point_writes(self, history, table, options, checker, deadline):
        write_query = generate_query_point_writes(table)

-       while time.time() < deadline:
+       while time.time() < deadline and not self.is_stopping():
            keys = checker.select_write_keys(cnt=random.randint(1, options.shards))

            node = history.add(History.Begin('writes', None, write_keys=keys)).apply_to(checker)

            async def perform(session):
-               tx = session.transaction(ydb.SerializableReadWrite())
-               try:
+               async with session.transaction(ydb.SerializableReadWrite()) as tx:
                    simple_tx = bool(random.randint(0, 1))
                    await tx.execute(
                        write_query,
@@ -381,12 +434,6 @@ class DatabaseChecker(object):
                    )
                    if not simple_tx:
                        await tx.commit()
-               finally:
-                   if tx.tx_id is not None:
-                       try:
-                           await tx.rollback()
-                       except ydb.Error:
-                           pass

            try:
                await self.async_retry_operation(perform, deadline)
@@ -399,12 +446,102 @@ class DatabaseChecker(object):

            checker.release_write_keys(keys)

+   async def async_perform_point_rwr(self, history, table, options, checker, deadline):
+       read1_query = generate_query_point_reads(table, '$reads1')
+       write_query = generate_query_point_writes(table, '$writes')
+       read2_query = generate_query_point_reads(table, '$reads2')
+       rwr_query = read1_query + write_query + read2_query
+       wr_query = write_query + read2_query
+
+       while time.time() < deadline and not self.is_stopping():
+           read1_cnt = random.randint(0, options.shards * 2)
+           read1_keys = checker.select_read_keys(cnt=read1_cnt) if read1_cnt else []
+           write_keys = checker.select_write_keys(cnt=random.randint(1, options.shards * 2))
+           read2_keys = checker.select_read_from_write_keys(cnt=random.randint(1, options.shards * 2))
+
+           read_keys = sorted(set(read1_keys).union(set(read2_keys).difference(set(write_keys))))
+           write_keys_set = set(write_keys)
+
+           node = history.add(History.Begin('rwr', None, read_keys=read_keys, write_keys=write_keys)).apply_to(checker)
+           observed_values = None
+
+           async def perform(session):
+               nonlocal observed_values
+               observed_values = None
+               async with session.transaction(ydb.SerializableReadWrite()) as tx:
+                   if read1_keys:
+                       read1_params = {
+                           '$reads1': [
+                               {'key': key} for key in read1_keys
+                           ],
+                       }
+                   else:
+                       read1_params = {}
+                   write_params = {
+                       '$writes': [
+                           {'key': key, 'value': node.value} for key in write_keys
+                       ],
+                   }
+                   read2_params = {
+                       '$reads2': [
+                           {'key': key} for key in read2_keys
+                       ],
+                   }
+                   simple_tx = bool(random.randint(0, 1))
+                   fuse_commit = bool(random.randint(0, 1))
+                   if simple_tx:
+                       query = rwr_query if read1_keys else wr_query
+                       parameters = dict(**read1_params, **write_params, **read2_params)
+                       rss = await tx.execute(query, parameters, commit_tx=True)
+                       if read1_keys:
+                           read1 = rss[0]
+                           read2 = rss[1]
+                       else:
+                           read1 = None
+                           read2 = rss[0]
+                   elif read1_keys:
+                       rss = await tx.execute(read1_query, read1_params, commit_tx=False)
+                       read1 = rss[0]
+                   else:
+                       read1 = None
+                   if read1_keys:
+                       observed_values = {}
+                       for row in read1.rows:
+                           observed_values[row.key] = row.value
+                       node.expected_read_keys = tuple(read1_keys)
+                   if not simple_tx:
+                       await tx.execute(write_query, write_params, commit_tx=False)
+                       rss = await tx.execute(read2_query, read2_params, commit_tx=fuse_commit)
+                       read2 = rss[0]
+                   for row in read2.rows:
+                       if row.key in write_keys_set:
+                           if row.value != node.value:
+                               raise SerializabilityError('Tx %r writes to key %r but then reads %r' % (node.value, row.key, row.value))
+                       else:
+                           if observed_values is None:
+                               observed_values = {}
+                           observed_values[row.key] = row.value
+                   node.expected_read_keys = tuple(read_keys)
+                   if not simple_tx and not fuse_commit:
+                       await tx.commit()
+
+           try:
+               await self.async_retry_operation(perform, deadline)
+           except ydb.Aborted:
+               history.add(History.Abort('rwr', node.value, observed_values)).apply_to(checker)
+           except ydb.Undetermined:
+               pass  # transaction outcome unknown
+           else:
+               history.add(History.Commit('rwr', node.value, observed_values)).apply_to(checker)
+
+           checker.release_write_keys(write_keys)
+
    async def async_perform_point_reads_writes(self, history, table, options, checker, deadline, keysets):
        read_query = generate_query_point_reads(table)
        write_query = generate_query_point_writes(table)
        read_write_query = generate_query_point_reads_writes(table)

-       while time.time() < deadline:
+       while time.time() < deadline and not self.is_stopping():
            read_keys = checker.select_read_keys(cnt=random.randint(1, options.shards))
            write_keys = checker.select_write_keys(cnt=random.randint(1, options.shards))

@@ -414,8 +551,7 @@ class DatabaseChecker(object):

            async def perform(session):
                # Read/Write tx may fail with TLI
-               tx = session.transaction(ydb.SerializableReadWrite())
-               try:
+               async with session.transaction(ydb.SerializableReadWrite()) as tx:
                    simple_tx = bool(random.randint(0, 1))
                    if simple_tx:
                        rss = await tx.execute(
@@ -450,12 +586,6 @@ class DatabaseChecker(object):
                            commit_tx=True,
                        )
                    return rss
-               finally:
-                   if tx.tx_id is not None:
-                       try:
-                           await tx.rollback()
-                       except ydb.Error:
-                           pass

            try:
                rss = await self.async_retry_operation(perform, deadline)
@@ -477,7 +607,7 @@ class DatabaseChecker(object):
    async def async_perform_verifying_reads(self, history, table, options, checker, deadline, keysets):
        read_query = generate_query_point_reads(table)

-       while time.time() < deadline:
+       while time.time() < deadline and not self.is_stopping():
            if not keysets:
                # There are not enough in-progress writes to this table yet, spin a little
                await asyncio.sleep(0.000001)
@@ -488,8 +618,7 @@ class DatabaseChecker(object):
            node = history.add(History.Begin('reads_of_writes', None, read_keys=keys)).apply_to(checker)

            async def perform(session):
-               tx = session.transaction(ydb.SerializableReadWrite())
-               try:
+               async with session.transaction(ydb.SerializableReadWrite()) as tx:
                    simple_tx = bool(random.randint(0, 1))
                    rss = await tx.execute(
                        read_query,
@@ -503,12 +632,6 @@ class DatabaseChecker(object):
                    if not simple_tx:
                        await tx.commit()
                    return rss
-               finally:
-                   if tx.tx_id is not None:
-                       try:
-                           await tx.rollback()
-                       except ydb.Error:
-                           pass

            try:
                rss = await self.async_retry_operation(perform, deadline)
@@ -526,7 +649,7 @@ class DatabaseChecker(object):
    async def async_perform_range_reads(self, history, table, options, checker, deadline):
        range_query = generate_query_range_reads(table)

-       while time.time() < deadline:
+       while time.time() < deadline and not self.is_stopping():
            min_key = random.randint(0, options.keys)
            max_key = random.randint(min_key, options.keys)
            read_keys = list(range(min_key, max_key + 1))
@@ -534,8 +657,7 @@ class DatabaseChecker(object):
            node = history.add(History.Begin('read_range', None, read_keys=read_keys)).apply_to(checker)

            async def perform(session):
-               tx = session.transaction(ydb.SerializableReadWrite())
-               try:
+               async with session.transaction(ydb.SerializableReadWrite()) as tx:
                    simple_tx = bool(random.randint(0, 1))
                    rss = await tx.execute(
                        range_query,
@@ -548,12 +670,6 @@ class DatabaseChecker(object):
                    if not simple_tx:
                        await tx.commit()
                    return rss
-               finally:
-                   if tx.tx_id is not None:
-                       try:
-                           await tx.rollback()
-                       except ydb.Error:
-                           pass

            try:
                rss = await self.async_retry_operation(perform, deadline)
@@ -569,7 +685,7 @@ class DatabaseChecker(object):
            history.add(History.Commit('read_range', node.value, values)).apply_to(checker)

    async def async_perform_read_tables(self, history, table, options, checker, deadline):
-       while time.time() < deadline:
+       while time.time() < deadline and not self.is_stopping():
            if options.read_table_ranges:
                min_key = random.randint(0, options.keys)
                max_key = random.randint(min_key, options.keys)
@@ -618,6 +734,9 @@ class DatabaseChecker(object):
        for _ in range(options.writers):
            futures.append(self.async_perform_point_writes(history, table, options, checker, deadline=deadline))

+       for _ in range(options.rwrs):
+           futures.append(self.async_perform_point_rwr(history, table, options, checker, deadline=deadline))
+
        readwrite_keysets = set()
        for _ in range(options.readwriters):
            futures.append(self.async_perform_point_reads_writes(history, table, options, checker, deadline=deadline, keysets=readwrite_keysets))
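The Query subclass introduced above keeps DECLARE sections and statements separate, so point-read and point-write queries compose by concatenation; that is how the read-write-read coroutine builds its single-statement variant (rwr_query = read1_query + write_query + read2_query). A small usage sketch, assuming only what the diff shows plus yql_text as the DataQuery attribute holding the joined query text; the table name is an example.

    # Illustrative only; generate_query_* and Query come from
    # ydb/tests/tools/ydb_serializable/lib/__init__.py.
    from ydb.tests.tools.ydb_serializable.lib import (
        generate_query_point_reads,
        generate_query_point_writes,
    )

    read1_query = generate_query_point_reads('example_table', '$reads1')
    write_query = generate_query_point_writes('example_table', '$writes')
    read2_query = generate_query_point_reads('example_table', '$reads2')

    # Query.__add__ concatenates declares and statements and merges the
    # parameter type maps, producing one executable YQL text.
    rwr_query = read1_query + write_query + read2_query
    print(sorted(rwr_query.parameters_types))  # ['$reads1', '$reads2', '$writes']
    print(rwr_query.yql_text)                  # assumed attribute with the joined query text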