aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python
diff options
context:
space:
mode:
authorrobot-piglet <robot-piglet@yandex-team.com>2025-03-31 10:42:46 +0300
committerrobot-piglet <robot-piglet@yandex-team.com>2025-03-31 10:52:09 +0300
commit1d42097b0f04723927549b4e89bbfdd02ecb7a4c (patch)
treec0431882b81ab94a5c05484a16a01906c3637f5a /contrib/python
parent79ff08a2da256291def48e53ca02343cf6f752be (diff)
downloadydb-1d42097b0f04723927549b4e89bbfdd02ecb7a4c.tar.gz
Intermediate changes
commit_hash:d5b3c59b6e8a4b2975f385625c72a7a6aef30691
Diffstat (limited to 'contrib/python')
-rw-r--r--contrib/python/ydb/py3/.dist-info/METADATA2
-rw-r--r--contrib/python/ydb/py3/ya.make2
-rw-r--r--contrib/python/ydb/py3/ydb/_apis.py1
-rw-r--r--contrib/python/ydb/py3/ydb/_errors.py1
-rw-r--r--contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py3
-rw-r--r--contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py67
-rw-r--r--contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py9
-rw-r--r--contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py121
-rw-r--r--contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py34
-rw-r--r--contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py10
-rw-r--r--contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py85
-rw-r--r--contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py52
-rw-r--r--contrib/python/ydb/py3/ydb/aio/driver.py1
-rw-r--r--contrib/python/ydb/py3/ydb/aio/query/pool.py8
-rw-r--r--contrib/python/ydb/py3/ydb/aio/query/transaction.py48
-rw-r--r--contrib/python/ydb/py3/ydb/driver.py1
-rw-r--r--contrib/python/ydb/py3/ydb/issues.py4
-rw-r--r--contrib/python/ydb/py3/ydb/query/base.py67
-rw-r--r--contrib/python/ydb/py3/ydb/query/pool.py5
-rw-r--r--contrib/python/ydb/py3/ydb/query/transaction.py87
-rw-r--r--contrib/python/ydb/py3/ydb/topic.py76
-rw-r--r--contrib/python/ydb/py3/ydb/ydb_version.py2
22 files changed, 635 insertions, 51 deletions
diff --git a/contrib/python/ydb/py3/.dist-info/METADATA b/contrib/python/ydb/py3/.dist-info/METADATA
index b6911ce75e8..1f52419b882 100644
--- a/contrib/python/ydb/py3/.dist-info/METADATA
+++ b/contrib/python/ydb/py3/.dist-info/METADATA
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: ydb
-Version: 3.19.3
+Version: 3.20.0
Summary: YDB Python SDK
Home-page: http://github.com/ydb-platform/ydb-python-sdk
Author: Yandex LLC
diff --git a/contrib/python/ydb/py3/ya.make b/contrib/python/ydb/py3/ya.make
index 71cfb8fa720..847d0def900 100644
--- a/contrib/python/ydb/py3/ya.make
+++ b/contrib/python/ydb/py3/ya.make
@@ -2,7 +2,7 @@
PY3_LIBRARY()
-VERSION(3.19.3)
+VERSION(3.20.0)
LICENSE(Apache-2.0)
diff --git a/contrib/python/ydb/py3/ydb/_apis.py b/contrib/python/ydb/py3/ydb/_apis.py
index fc28d0ceb29..fc6f16e287c 100644
--- a/contrib/python/ydb/py3/ydb/_apis.py
+++ b/contrib/python/ydb/py3/ydb/_apis.py
@@ -115,6 +115,7 @@ class TopicService(object):
DropTopic = "DropTopic"
StreamRead = "StreamRead"
StreamWrite = "StreamWrite"
+ UpdateOffsetsInTransaction = "UpdateOffsetsInTransaction"
class QueryService(object):
diff --git a/contrib/python/ydb/py3/ydb/_errors.py b/contrib/python/ydb/py3/ydb/_errors.py
index 17002d25749..1e2308ef394 100644
--- a/contrib/python/ydb/py3/ydb/_errors.py
+++ b/contrib/python/ydb/py3/ydb/_errors.py
@@ -5,6 +5,7 @@ from . import issues
_errors_retriable_fast_backoff_types = [
issues.Unavailable,
+ issues.ClientInternalError,
]
_errors_retriable_slow_backoff_types = [
issues.Aborted,
diff --git a/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py b/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py
index 95a5744313e..10d98918c30 100644
--- a/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py
+++ b/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py
@@ -160,9 +160,6 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
self._stream_call = None
self._wait_executor = None
- def __del__(self):
- self._clean_executor(wait=False)
-
async def start(self, driver: SupportedDriverType, stub, method):
if asyncio.iscoroutinefunction(driver.__call__):
await self._start_asyncio_driver(driver, stub, method)
diff --git a/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py b/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py
index 5b22c7cf862..0f8a0f03a7a 100644
--- a/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py
+++ b/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py
@@ -141,6 +141,18 @@ class UpdateTokenResponse(IFromProto):
########################################################################################################################
+@dataclass
+class TransactionIdentity(IToProto):
+ tx_id: str
+ session_id: str
+
+ def to_proto(self) -> ydb_topic_pb2.TransactionIdentity:
+ return ydb_topic_pb2.TransactionIdentity(
+ id=self.tx_id,
+ session=self.session_id,
+ )
+
+
class StreamWriteMessage:
@dataclass()
class InitRequest(IToProto):
@@ -199,6 +211,7 @@ class StreamWriteMessage:
class WriteRequest(IToProto):
messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"]
codec: int
+ tx_identity: Optional[TransactionIdentity]
@dataclass
class MessageData(IToProto):
@@ -237,6 +250,9 @@ class StreamWriteMessage:
proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest()
proto.codec = self.codec
+ if self.tx_identity is not None:
+ proto.tx.CopyFrom(self.tx_identity.to_proto())
+
for message in self.messages:
proto_mess = proto.messages.add()
proto_mess.CopyFrom(message.to_proto())
@@ -297,6 +313,8 @@ class StreamWriteMessage:
)
except ValueError:
message_write_status = reason
+ elif proto_ack.HasField("written_in_tx"):
+ message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusWrittenInTx()
else:
raise NotImplementedError("unexpected ack status")
@@ -309,6 +327,9 @@ class StreamWriteMessage:
class StatusWritten:
offset: int
+ class StatusWrittenInTx:
+ pass
+
@dataclass
class StatusSkipped:
reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason"
@@ -1197,6 +1218,52 @@ class MeteringMode(int, IFromProto, IFromPublic, IToPublic):
@dataclass
+class UpdateOffsetsInTransactionRequest(IToProto):
+ tx: TransactionIdentity
+ topics: List[UpdateOffsetsInTransactionRequest.TopicOffsets]
+ consumer: str
+
+ def to_proto(self):
+ return ydb_topic_pb2.UpdateOffsetsInTransactionRequest(
+ tx=self.tx.to_proto(),
+ consumer=self.consumer,
+ topics=list(
+ map(
+ UpdateOffsetsInTransactionRequest.TopicOffsets.to_proto,
+ self.topics,
+ )
+ ),
+ )
+
+ @dataclass
+ class TopicOffsets(IToProto):
+ path: str
+ partitions: List[UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets]
+
+ def to_proto(self):
+ return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets(
+ path=self.path,
+ partitions=list(
+ map(
+ UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets.to_proto,
+ self.partitions,
+ )
+ ),
+ )
+
+ @dataclass
+ class PartitionOffsets(IToProto):
+ partition_id: int
+ partition_offsets: List[OffsetsRange]
+
+ def to_proto(self) -> ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets:
+ return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
+ partition_id=self.partition_id,
+ partition_offsets=list(map(OffsetsRange.to_proto, self.partition_offsets)),
+ )
+
+
+@dataclass
class CreateTopicRequest(IToProto, IFromPublic):
path: str
partitioning_settings: "PartitioningSettings"
diff --git a/contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py b/contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py
index b48501aff2f..74f06a086fc 100644
--- a/contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py
+++ b/contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py
@@ -108,6 +108,9 @@ class PartitionSession:
waiter = self._ack_waiters.popleft()
waiter._finish_ok()
+ def _update_last_commited_offset_if_needed(self, offset: int):
+ self.committed_offset = max(self.committed_offset, offset)
+
def close(self):
if self.closed:
return
@@ -211,3 +214,9 @@ class PublicBatch(ICommittable, ISessionAlive):
self._bytes_size = self._bytes_size - new_batch._bytes_size
return new_batch
+
+ def _update_partition_offsets(self, tx, exc=None):
+ if exc is not None:
+ return
+ offsets = self._commit_get_offsets_range()
+ self._partition_session._update_last_commited_offset_if_needed(offsets.end)
diff --git a/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py b/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py
index 7061b4e449c..c9704d5542a 100644
--- a/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py
+++ b/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py
@@ -5,7 +5,7 @@ import concurrent.futures
import gzip
import typing
from asyncio import Task
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
from typing import Optional, Set, Dict, Union, Callable
import ydb
@@ -19,17 +19,24 @@ from . import topic_reader
from .._grpc.grpcwrapper.common_utils import (
IGrpcWrapperAsyncIO,
SupportedDriverType,
+ to_thread,
GrpcWrapperAsyncIO,
)
from .._grpc.grpcwrapper.ydb_topic import (
StreamReadMessage,
UpdateTokenRequest,
UpdateTokenResponse,
+ UpdateOffsetsInTransactionRequest,
Codec,
)
from .._errors import check_retriable_error
import logging
+from ..query.base import TxEvent
+
+if typing.TYPE_CHECKING:
+ from ..query.transaction import BaseQueryTxContext
+
logger = logging.getLogger(__name__)
@@ -77,7 +84,7 @@ class PublicAsyncIOReader:
):
self._loop = asyncio.get_running_loop()
self._closed = False
- self._reconnector = ReaderReconnector(driver, settings)
+ self._reconnector = ReaderReconnector(driver, settings, self._loop)
self._parent = _parent
async def __aenter__(self):
@@ -88,8 +95,7 @@ class PublicAsyncIOReader:
def __del__(self):
if not self._closed:
- task = self._loop.create_task(self.close(flush=False))
- topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader")
+ logger.warning("Topic reader was not closed properly. Consider using method close().")
async def wait_message(self):
"""
@@ -112,6 +118,23 @@ class PublicAsyncIOReader:
max_messages=max_messages,
)
+ async def receive_batch_with_tx(
+ self,
+ tx: "BaseQueryTxContext",
+ max_messages: typing.Union[int, None] = None,
+ ) -> typing.Union[datatypes.PublicBatch, None]:
+ """
+ Get one messages batch with tx from reader.
+ All messages in a batch from same partition.
+
+ use asyncio.wait_for for wait with timeout.
+ """
+ await self._reconnector.wait_message()
+ return self._reconnector.receive_batch_with_tx_nowait(
+ tx=tx,
+ max_messages=max_messages,
+ )
+
async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
"""
Block until receive new message
@@ -165,11 +188,18 @@ class ReaderReconnector:
_state_changed: asyncio.Event
_stream_reader: Optional["ReaderStream"]
_first_error: asyncio.Future[YdbError]
+ _tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]]
- def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
+ def __init__(
+ self,
+ driver: Driver,
+ settings: topic_reader.PublicReaderSettings,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ ):
self._id = self._static_reader_reconnector_counter.inc_and_get()
self._settings = settings
self._driver = driver
+ self._loop = loop if loop is not None else asyncio.get_running_loop()
self._background_tasks = set()
self._state_changed = asyncio.Event()
@@ -177,6 +207,8 @@ class ReaderReconnector:
self._background_tasks.add(asyncio.create_task(self._connection_loop()))
self._first_error = asyncio.get_running_loop().create_future()
+ self._tx_to_batches_map = dict()
+
async def _connection_loop(self):
attempt = 0
while True:
@@ -190,6 +222,7 @@ class ReaderReconnector:
if not retry_info.is_retriable:
self._set_first_error(err)
return
+
await asyncio.sleep(retry_info.sleep_timeout_seconds)
attempt += 1
@@ -222,9 +255,87 @@ class ReaderReconnector:
max_messages=max_messages,
)
+ def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None):
+ batch = self._stream_reader.receive_batch_nowait(
+ max_messages=max_messages,
+ )
+
+ self._init_tx(tx)
+
+ self._tx_to_batches_map[tx.tx_id].append(batch)
+
+ tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, self._loop)
+
+ return batch
+
def receive_message_nowait(self):
return self._stream_reader.receive_message_nowait()
+ def _init_tx(self, tx: "BaseQueryTxContext"):
+ if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks
+ self._tx_to_batches_map[tx.tx_id] = []
+ tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, self._loop)
+ tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, self._loop)
+ tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, self._loop)
+
+ async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"):
+ grouped_batches = defaultdict(lambda: defaultdict(list))
+ for batch in self._tx_to_batches_map[tx.tx_id]:
+ grouped_batches[batch._partition_session.topic_path][batch._partition_session.partition_id].append(batch)
+
+ request = UpdateOffsetsInTransactionRequest(tx=tx._tx_identity(), consumer=self._settings.consumer, topics=[])
+
+ for topic_path in grouped_batches:
+ topic_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets(path=topic_path, partitions=[])
+ for partition_id in grouped_batches[topic_path]:
+ partition_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
+ partition_id=partition_id,
+ partition_offsets=[
+ batch._commit_get_offsets_range() for batch in grouped_batches[topic_path][partition_id]
+ ],
+ )
+ topic_offsets.partitions.append(partition_offsets)
+ request.topics.append(topic_offsets)
+
+ try:
+ return await self._do_commit_batches_with_tx_call(request)
+ except BaseException:
+ exc = issues.ClientInternalError("Failed to update offsets in tx.")
+ tx._set_external_error(exc)
+ self._stream_reader._set_first_error(exc)
+ finally:
+ del self._tx_to_batches_map[tx.tx_id]
+
+ async def _do_commit_batches_with_tx_call(self, request: UpdateOffsetsInTransactionRequest):
+ args = [
+ request.to_proto(),
+ _apis.TopicService.Stub,
+ _apis.TopicService.UpdateOffsetsInTransaction,
+ topic_common.wrap_operation,
+ ]
+
+ if asyncio.iscoroutinefunction(self._driver.__call__):
+ res = await self._driver(*args)
+ else:
+ res = await to_thread(self._driver, *args, executor=None)
+
+ return res
+
+ async def _handle_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None:
+ if tx.tx_id in self._tx_to_batches_map:
+ del self._tx_to_batches_map[tx.tx_id]
+ exc = issues.ClientInternalError("Reconnect due to transaction rollback")
+ self._stream_reader._set_first_error(exc)
+
+ async def _handle_after_tx_commit(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None:
+ if tx.tx_id in self._tx_to_batches_map:
+ del self._tx_to_batches_map[tx.tx_id]
+
+ if exc is not None:
+ self._stream_reader._set_first_error(
+ issues.ClientInternalError("Reconnect due to transaction commit failed")
+ )
+
def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter:
return self._stream_reader.commit(batch)
diff --git a/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py b/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py
index eda1d374fc3..3e6806d06fe 100644
--- a/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py
+++ b/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py
@@ -1,5 +1,6 @@
import asyncio
import concurrent.futures
+import logging
import typing
from typing import List, Union, Optional
@@ -20,6 +21,11 @@ from ydb._topic_reader.topic_reader_asyncio import (
TopicReaderClosedError,
)
+if typing.TYPE_CHECKING:
+ from ..query.transaction import BaseQueryTxContext
+
+logger = logging.getLogger(__name__)
+
class TopicReaderSync:
_caller: CallFromSyncToAsync
@@ -52,7 +58,8 @@ class TopicReaderSync:
self._parent = _parent
def __del__(self):
- self.close(flush=False)
+ if not self._closed:
+ logger.warning("Topic reader was not closed properly. Consider using method close().")
def __enter__(self):
return self
@@ -109,6 +116,31 @@ class TopicReaderSync:
timeout,
)
+ def receive_batch_with_tx(
+ self,
+ tx: "BaseQueryTxContext",
+ *,
+ max_messages: typing.Union[int, None] = None,
+ max_bytes: typing.Union[int, None] = None,
+ timeout: Union[float, None] = None,
+ ) -> Union[PublicBatch, None]:
+ """
+ Get one messages batch with tx from reader
+ It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available.
+
+ if no new message in timeout seconds (default - infinite): raise TimeoutError()
+ if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only.
+ """
+ self._check_closed()
+
+ return self._caller.safe_call_with_result(
+ self._async_reader.receive_batch_with_tx(
+ tx=tx,
+ max_messages=max_messages,
+ ),
+ timeout,
+ )
+
def commit(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]):
"""
Put commit message to internal buffer.
diff --git a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py
index aa5fe9749a7..a3e407ed86d 100644
--- a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py
+++ b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py
@@ -11,6 +11,7 @@ import typing
import ydb.aio
from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage
+from .._grpc.grpcwrapper.ydb_topic import TransactionIdentity
from .._grpc.grpcwrapper.common_utils import IToProto
from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
from .. import connection
@@ -53,8 +54,12 @@ class PublicWriteResult:
class Skipped:
pass
+ @dataclass(eq=True)
+ class WrittenInTx:
+ pass
+
-PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped]
+PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped, PublicWriteResult.WrittenInTx]
class WriterSettings(PublicWriterSettings):
@@ -205,6 +210,7 @@ def default_serializer_message_content(data: Any) -> bytes:
def messages_to_proto_requests(
messages: List[InternalMessage],
+ tx_identity: Optional[TransactionIdentity],
) -> List[StreamWriteMessage.FromClient]:
gropus = _slit_messages_for_send(messages)
@@ -215,6 +221,7 @@ def messages_to_proto_requests(
StreamWriteMessage.WriteRequest(
messages=list(map(InternalMessage.to_message_data, group)),
codec=group[0].codec,
+ tx_identity=tx_identity,
)
)
res.append(req)
@@ -239,6 +246,7 @@ _message_data_overhead = (
),
],
codec=20000,
+ tx_identity=None,
)
)
.to_proto()
diff --git a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py
index 32d8fefe51c..1ea6c25028b 100644
--- a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py
+++ b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py
@@ -1,7 +1,6 @@
import asyncio
import concurrent.futures
import datetime
-import functools
import gzip
import typing
from collections import deque
@@ -35,6 +34,7 @@ from .._grpc.grpcwrapper.ydb_topic import (
UpdateTokenRequest,
UpdateTokenResponse,
StreamWriteMessage,
+ TransactionIdentity,
WriterMessagesFromServerToClient,
)
from .._grpc.grpcwrapper.common_utils import (
@@ -43,6 +43,11 @@ from .._grpc.grpcwrapper.common_utils import (
GrpcWrapperAsyncIO,
)
+from ..query.base import TxEvent
+
+if typing.TYPE_CHECKING:
+ from ..query.transaction import BaseQueryTxContext
+
logger = logging.getLogger(__name__)
@@ -74,10 +79,8 @@ class WriterAsyncIO:
raise
def __del__(self):
- if self._closed or self._loop.is_closed():
- return
-
- self._loop.call_soon(functools.partial(self.close, flush=False))
+ if not self._closed:
+ logger.warning("Topic writer was not closed properly. Consider using method close().")
async def close(self, *, flush: bool = True):
if self._closed:
@@ -164,6 +167,57 @@ class WriterAsyncIO:
return await self._reconnector.wait_init()
+class TxWriterAsyncIO(WriterAsyncIO):
+ _tx: "BaseQueryTxContext"
+
+ def __init__(
+ self,
+ tx: "BaseQueryTxContext",
+ driver: SupportedDriverType,
+ settings: PublicWriterSettings,
+ _client=None,
+ _is_implicit=False,
+ ):
+ self._tx = tx
+ self._loop = asyncio.get_running_loop()
+ self._closed = False
+ self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings), tx=self._tx)
+ self._parent = _client
+ self._is_implicit = _is_implicit
+
+ # For some reason, creating partition could conflict with other session operations.
+ # Could be removed later.
+ self._first_write = True
+
+ tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, self._loop)
+ tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, self._loop)
+
+ async def write(
+ self,
+ messages: Union[Message, List[Message]],
+ ):
+ """
+ send one or number of messages to server.
+ it put message to internal buffer
+
+ For wait with timeout use asyncio.wait_for.
+ """
+ if self._first_write:
+ self._first_write = False
+ return await super().write_with_ack(messages)
+ return await super().write(messages)
+
+ async def _on_before_commit(self, tx: "BaseQueryTxContext"):
+ if self._is_implicit:
+ return
+ await self.close()
+
+ async def _on_before_rollback(self, tx: "BaseQueryTxContext"):
+ if self._is_implicit:
+ return
+ await self.close(flush=False)
+
+
class WriterAsyncIOReconnector:
_closed: bool
_loop: asyncio.AbstractEventLoop
@@ -178,6 +232,7 @@ class WriterAsyncIOReconnector:
_codec_selector_batch_num: int
_codec_selector_last_codec: Optional[PublicCodec]
_codec_selector_check_batches_interval: int
+ _tx: Optional["BaseQueryTxContext"]
if typing.TYPE_CHECKING:
_messages_for_encode: asyncio.Queue[List[InternalMessage]]
@@ -195,7 +250,9 @@ class WriterAsyncIOReconnector:
_stop_reason: asyncio.Future
_init_info: Optional[PublicWriterInitInfo]
- def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
+ def __init__(
+ self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None
+ ):
self._closed = False
self._loop = asyncio.get_running_loop()
self._driver = driver
@@ -205,6 +262,7 @@ class WriterAsyncIOReconnector:
self._init_info = None
self._stream_connected = asyncio.Event()
self._settings = settings
+ self._tx = tx
self._codec_functions = {
PublicCodec.RAW: lambda data: data,
@@ -354,10 +412,12 @@ class WriterAsyncIOReconnector:
# noinspection PyBroadException
stream_writer = None
try:
+ tx_identity = None if self._tx is None else self._tx._tx_identity()
stream_writer = await WriterAsyncIOStream.create(
self._driver,
self._init_message,
self._settings.update_token_interval,
+ tx_identity=tx_identity,
)
try:
if self._init_info is None:
@@ -387,7 +447,7 @@ class WriterAsyncIOReconnector:
done.pop().result() # need for raise exception - reason of stop task
except issues.Error as err:
err_info = check_retriable_error(err, retry_settings, attempt)
- if not err_info.is_retriable:
+ if not err_info.is_retriable or self._tx is not None: # no retries in tx writer
self._stop(err)
return
@@ -533,6 +593,8 @@ class WriterAsyncIOReconnector:
result = PublicWriteResult.Skipped()
elif isinstance(status, write_ack_msg.StatusWritten):
result = PublicWriteResult.Written(offset=status.offset)
+ elif isinstance(status, write_ack_msg.StatusWrittenInTx):
+ result = PublicWriteResult.WrittenInTx()
else:
raise TopicWriterError("internal error - receive unexpected ack message.")
message_future.set_result(result)
@@ -597,10 +659,13 @@ class WriterAsyncIOStream:
_update_token_event: asyncio.Event
_get_token_function: Optional[Callable[[], str]]
+ _tx_identity: Optional[TransactionIdentity]
+
def __init__(
self,
update_token_interval: Optional[Union[int, float]] = None,
get_token_function: Optional[Callable[[], str]] = None,
+ tx_identity: Optional[TransactionIdentity] = None,
):
self._closed = False
@@ -609,6 +674,8 @@ class WriterAsyncIOStream:
self._update_token_event = asyncio.Event()
self._update_token_task = None
+ self._tx_identity = tx_identity
+
async def close(self):
if self._closed:
return
@@ -625,6 +692,7 @@ class WriterAsyncIOStream:
driver: SupportedDriverType,
init_request: StreamWriteMessage.InitRequest,
update_token_interval: Optional[Union[int, float]] = None,
+ tx_identity: Optional[TransactionIdentity] = None,
) -> "WriterAsyncIOStream":
stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)
@@ -634,6 +702,7 @@ class WriterAsyncIOStream:
writer = WriterAsyncIOStream(
update_token_interval=update_token_interval,
get_token_function=creds.get_auth_token if creds else lambda: "",
+ tx_identity=tx_identity,
)
await writer._start(stream, init_request)
return writer
@@ -680,7 +749,7 @@ class WriterAsyncIOStream:
if self._closed:
raise RuntimeError("Can not write on closed stream.")
- for request in messages_to_proto_requests(messages):
+ for request in messages_to_proto_requests(messages, self._tx_identity):
self._stream.write(request)
async def _update_token_loop(self):
diff --git a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py
index a5193caf7c5..4796d7ac2d6 100644
--- a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py
+++ b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
+import logging
import typing
from concurrent.futures import Future
from typing import Union, List, Optional
@@ -14,13 +15,23 @@ from .topic_writer import (
TopicWriterClosedError,
)
-from .topic_writer_asyncio import WriterAsyncIO
+from ..query.base import TxEvent
+
+from .topic_writer_asyncio import (
+ TxWriterAsyncIO,
+ WriterAsyncIO,
+)
from .._topic_common.common import (
_get_shared_event_loop,
TimeoutType,
CallFromSyncToAsync,
)
+if typing.TYPE_CHECKING:
+ from ..query.transaction import BaseQueryTxContext
+
+logger = logging.getLogger(__name__)
+
class WriterSync:
_caller: CallFromSyncToAsync
@@ -63,7 +74,8 @@ class WriterSync:
raise
def __del__(self):
- self.close(flush=False)
+ if not self._closed:
+ logger.warning("Topic writer was not closed properly. Consider using method close().")
def close(self, *, flush: bool = True, timeout: TimeoutType = None):
if self._closed:
@@ -122,3 +134,39 @@ class WriterSync:
self._check_closed()
return self._caller.unsafe_call_with_result(self._async_writer.write_with_ack(messages), timeout=timeout)
+
+
+class TxWriterSync(WriterSync):
+ def __init__(
+ self,
+ tx: "BaseQueryTxContext",
+ driver: SupportedDriverType,
+ settings: PublicWriterSettings,
+ *,
+ eventloop: Optional[asyncio.AbstractEventLoop] = None,
+ _parent=None,
+ ):
+
+ self._closed = False
+
+ if eventloop:
+ loop = eventloop
+ else:
+ loop = _get_shared_event_loop()
+
+ self._caller = CallFromSyncToAsync(loop)
+
+ async def create_async_writer():
+ return TxWriterAsyncIO(tx, driver, settings, _is_implicit=True)
+
+ self._async_writer = self._caller.safe_call_with_result(create_async_writer(), None)
+ self._parent = _parent
+
+ tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, None)
+ tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, None)
+
+ def _on_before_commit(self, tx: "BaseQueryTxContext"):
+ self.close()
+
+ def _on_before_rollback(self, tx: "BaseQueryTxContext"):
+ self.close(flush=False)
diff --git a/contrib/python/ydb/py3/ydb/aio/driver.py b/contrib/python/ydb/py3/ydb/aio/driver.py
index 9cd6fd2b74d..267997fbcc3 100644
--- a/contrib/python/ydb/py3/ydb/aio/driver.py
+++ b/contrib/python/ydb/py3/ydb/aio/driver.py
@@ -62,4 +62,5 @@ class Driver(pool.ConnectionPool):
async def stop(self, timeout=10):
await self.table_client._stop_pool_if_needed(timeout=timeout)
+ self.topic_client.close()
await super().stop(timeout=timeout)
diff --git a/contrib/python/ydb/py3/ydb/aio/query/pool.py b/contrib/python/ydb/py3/ydb/aio/query/pool.py
index 947db658726..f1ca68d1cf0 100644
--- a/contrib/python/ydb/py3/ydb/aio/query/pool.py
+++ b/contrib/python/ydb/py3/ydb/aio/query/pool.py
@@ -158,6 +158,8 @@ class QuerySessionPool:
async def wrapped_callee():
async with self.checkout() as session:
async with session.transaction(tx_mode=tx_mode) as tx:
+ if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]:
+ await tx.begin()
result = await callee(tx, *args, **kwargs)
await tx.commit()
return result
@@ -213,12 +215,6 @@ class QuerySessionPool:
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.stop()
- def __del__(self):
- if self._should_stop.is_set() or self._loop.is_closed():
- return
-
- self._loop.call_soon(self.stop)
-
class SimpleQuerySessionCheckoutAsync:
def __init__(self, pool: QuerySessionPool):
diff --git a/contrib/python/ydb/py3/ydb/aio/query/transaction.py b/contrib/python/ydb/py3/ydb/aio/query/transaction.py
index 5b63a32b489..f0547e5f01f 100644
--- a/contrib/python/ydb/py3/ydb/aio/query/transaction.py
+++ b/contrib/python/ydb/py3/ydb/aio/query/transaction.py
@@ -16,6 +16,28 @@ logger = logging.getLogger(__name__)
class QueryTxContext(BaseQueryTxContext):
+ def __init__(self, driver, session_state, session, tx_mode):
+ """
+ An object that provides a simple transaction context manager that allows statements execution
+ in a transaction. You don't have to open transaction explicitly, because context manager encapsulates
+ transaction control logic, and opens new transaction if:
+
+ 1) By explicit .begin() method;
+ 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip
+
+ This context manager is not thread-safe, so you should not manipulate on it concurrently.
+
+ :param driver: A driver instance
+ :param session_state: A state of session
+ :param tx_mode: Transaction mode, which is a one from the following choises:
+ 1) QuerySerializableReadWrite() which is default mode;
+ 2) QueryOnlineReadOnly(allow_inconsistent_reads=False);
+ 3) QuerySnapshotReadOnly();
+ 4) QueryStaleReadOnly().
+ """
+ super().__init__(driver, session_state, session, tx_mode)
+ self._init_callback_handler(base.CallbackHandlerMode.ASYNC)
+
async def __aenter__(self) -> "QueryTxContext":
"""
Enters a context manager and returns a transaction
@@ -30,7 +52,7 @@ class QueryTxContext(BaseQueryTxContext):
it is not finished explicitly
"""
await self._ensure_prev_stream_finished()
- if self._tx_state._state == QueryTxStateEnum.BEGINED:
+ if self._tx_state._state == QueryTxStateEnum.BEGINED and self._external_error is None:
# It's strictly recommended to close transactions directly
# by using commit_tx=True flag while executing statement or by
# .commit() or .rollback() methods, but here we trying to do best
@@ -65,7 +87,9 @@ class QueryTxContext(BaseQueryTxContext):
:return: A committed transaction or exception if commit is failed
"""
- if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
+ self._check_external_error_set()
+
+ if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED):
return
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
@@ -74,7 +98,13 @@ class QueryTxContext(BaseQueryTxContext):
await self._ensure_prev_stream_finished()
- await self._commit_call(settings)
+ try:
+ await self._execute_callbacks_async(base.TxEvent.BEFORE_COMMIT)
+ await self._commit_call(settings)
+ await self._execute_callbacks_async(base.TxEvent.AFTER_COMMIT, exc=None)
+ except BaseException as e:
+ await self._execute_callbacks_async(base.TxEvent.AFTER_COMMIT, exc=e)
+ raise e
async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None:
"""Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution
@@ -84,7 +114,9 @@ class QueryTxContext(BaseQueryTxContext):
:return: A committed transaction or exception if commit is failed
"""
- if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
+ self._check_external_error_set()
+
+ if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED):
return
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
@@ -93,7 +125,13 @@ class QueryTxContext(BaseQueryTxContext):
await self._ensure_prev_stream_finished()
- await self._rollback_call(settings)
+ try:
+ await self._execute_callbacks_async(base.TxEvent.BEFORE_ROLLBACK)
+ await self._rollback_call(settings)
+ await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=None)
+ except BaseException as e:
+ await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=e)
+ raise e
async def execute(
self,
diff --git a/contrib/python/ydb/py3/ydb/driver.py b/contrib/python/ydb/py3/ydb/driver.py
index 49bd223c901..3998aeeef5f 100644
--- a/contrib/python/ydb/py3/ydb/driver.py
+++ b/contrib/python/ydb/py3/ydb/driver.py
@@ -288,4 +288,5 @@ class Driver(pool.ConnectionPool):
def stop(self, timeout=10):
self.table_client._stop_pool_if_needed(timeout=timeout)
+ self.topic_client.close()
super().stop(timeout=timeout)
diff --git a/contrib/python/ydb/py3/ydb/issues.py b/contrib/python/ydb/py3/ydb/issues.py
index f38f99f9257..4e76f5ed2b0 100644
--- a/contrib/python/ydb/py3/ydb/issues.py
+++ b/contrib/python/ydb/py3/ydb/issues.py
@@ -178,6 +178,10 @@ class SessionPoolEmpty(Error, queue.Empty):
status = StatusCode.SESSION_POOL_EMPTY
+class ClientInternalError(Error):
+ status = StatusCode.CLIENT_INTERNAL_ERROR
+
+
class UnexpectedGrpcMessage(Error):
def __init__(self, message: str):
super().__init__(message)
diff --git a/contrib/python/ydb/py3/ydb/query/base.py b/contrib/python/ydb/py3/ydb/query/base.py
index 57a769bb1a1..a5ebedd95b3 100644
--- a/contrib/python/ydb/py3/ydb/query/base.py
+++ b/contrib/python/ydb/py3/ydb/query/base.py
@@ -1,6 +1,8 @@
import abc
+import asyncio
import enum
import functools
+from collections import defaultdict
import typing
from typing import (
@@ -17,6 +19,10 @@ from .. import issues
from .. import _utilities
from .. import _apis
+from ydb._topic_common.common import CallFromSyncToAsync, _get_shared_event_loop
+from ydb._grpc.grpcwrapper.common_utils import to_thread
+
+
if typing.TYPE_CHECKING:
from .transaction import BaseQueryTxContext
@@ -196,3 +202,64 @@ def wrap_execute_query_response(
return convert.ResultSet.from_message(response_pb.result_set, settings)
return None
+
+
+class TxEvent(enum.Enum):
+ BEFORE_COMMIT = "BEFORE_COMMIT"
+ AFTER_COMMIT = "AFTER_COMMIT"
+ BEFORE_ROLLBACK = "BEFORE_ROLLBACK"
+ AFTER_ROLLBACK = "AFTER_ROLLBACK"
+
+
+class CallbackHandlerMode(enum.Enum):
+ SYNC = "SYNC"
+ ASYNC = "ASYNC"
+
+
+def _get_sync_callback(method: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]):
+ if asyncio.iscoroutinefunction(method):
+ if loop is None:
+ loop = _get_shared_event_loop()
+
+ def async_to_sync_callback(*args, **kwargs):
+ caller = CallFromSyncToAsync(loop)
+ return caller.safe_call_with_result(method(*args, **kwargs), 10)
+
+ return async_to_sync_callback
+ return method
+
+
+def _get_async_callback(method: typing.Callable):
+ if asyncio.iscoroutinefunction(method):
+ return method
+
+ async def sync_to_async_callback(*args, **kwargs):
+ return await to_thread(method, *args, **kwargs, executor=None)
+
+ return sync_to_async_callback
+
+
+class CallbackHandler:
+ def _init_callback_handler(self, mode: CallbackHandlerMode) -> None:
+ self._callbacks = defaultdict(list)
+ self._callback_mode = mode
+
+ def _execute_callbacks_sync(self, event_name: str, *args, **kwargs) -> None:
+ for callback in self._callbacks[event_name]:
+ callback(self, *args, **kwargs)
+
+ async def _execute_callbacks_async(self, event_name: str, *args, **kwargs) -> None:
+ tasks = [asyncio.create_task(callback(self, *args, **kwargs)) for callback in self._callbacks[event_name]]
+ if not tasks:
+ return
+ await asyncio.gather(*tasks)
+
+ def _prepare_callback(
+ self, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]
+ ) -> typing.Callable:
+ if self._callback_mode == CallbackHandlerMode.SYNC:
+ return _get_sync_callback(callback, loop)
+ return _get_async_callback(callback)
+
+ def _add_callback(self, event_name: str, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]):
+ self._callbacks[event_name].append(self._prepare_callback(callback, loop))
diff --git a/contrib/python/ydb/py3/ydb/query/pool.py b/contrib/python/ydb/py3/ydb/query/pool.py
index e3775c4dd12..b25f7db855c 100644
--- a/contrib/python/ydb/py3/ydb/query/pool.py
+++ b/contrib/python/ydb/py3/ydb/query/pool.py
@@ -167,6 +167,8 @@ class QuerySessionPool:
def wrapped_callee():
with self.checkout(timeout=retry_settings.max_session_acquire_timeout) as session:
with session.transaction(tx_mode=tx_mode) as tx:
+ if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]:
+ tx.begin()
result = callee(tx, *args, **kwargs)
tx.commit()
return result
@@ -224,9 +226,6 @@ class QuerySessionPool:
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
- def __del__(self):
- self.stop()
-
class SimpleQuerySessionCheckout:
def __init__(self, pool: QuerySessionPool, timeout: Optional[float]):
diff --git a/contrib/python/ydb/py3/ydb/query/transaction.py b/contrib/python/ydb/py3/ydb/query/transaction.py
index 414401da4d1..ae7642dbe21 100644
--- a/contrib/python/ydb/py3/ydb/query/transaction.py
+++ b/contrib/python/ydb/py3/ydb/query/transaction.py
@@ -11,6 +11,7 @@ from .. import (
_apis,
issues,
)
+from .._grpc.grpcwrapper import ydb_topic as _ydb_topic
from .._grpc.grpcwrapper import ydb_query as _ydb_query
from ..connection import _RpcState as RpcState
@@ -42,11 +43,23 @@ class QueryTxStateHelper(abc.ABC):
QueryTxStateEnum.DEAD: [],
}
+ _SKIP_TRANSITIONS = {
+ QueryTxStateEnum.NOT_INITIALIZED: [],
+ QueryTxStateEnum.BEGINED: [],
+ QueryTxStateEnum.COMMITTED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED],
+ QueryTxStateEnum.ROLLBACKED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED],
+ QueryTxStateEnum.DEAD: [],
+ }
+
@classmethod
def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool:
return after in cls._VALID_TRANSITIONS[before]
@classmethod
+ def should_skip(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool:
+ return after in cls._SKIP_TRANSITIONS[before]
+
+ @classmethod
def terminal(cls, state: QueryTxStateEnum) -> bool:
return len(cls._VALID_TRANSITIONS[state]) == 0
@@ -88,8 +101,8 @@ class QueryTxState:
if QueryTxStateHelper.terminal(self._state):
raise RuntimeError(f"Transaction is in terminal state: {self._state.value}")
- def _already_in(self, target: QueryTxStateEnum) -> bool:
- return self._state == target
+ def _should_skip(self, target: QueryTxStateEnum) -> bool:
+ return QueryTxStateHelper.should_skip(self._state, target)
def _construct_tx_settings(tx_state: QueryTxState) -> _ydb_query.TransactionSettings:
@@ -170,7 +183,7 @@ def wrap_tx_rollback_response(
return tx
-class BaseQueryTxContext:
+class BaseQueryTxContext(base.CallbackHandler):
def __init__(self, driver, session_state, session, tx_mode):
"""
An object that provides a simple transaction context manager that allows statements execution
@@ -196,6 +209,7 @@ class BaseQueryTxContext:
self._session_state = session_state
self.session = session
self._prev_stream = None
+ self._external_error = None
@property
def session_id(self) -> str:
@@ -215,6 +229,19 @@ class BaseQueryTxContext:
"""
return self._tx_state.tx_id
+ def _tx_identity(self) -> _ydb_topic.TransactionIdentity:
+ if not self.tx_id:
+ raise RuntimeError("Unable to get tx identity without started tx.")
+ return _ydb_topic.TransactionIdentity(self.tx_id, self.session_id)
+
+ def _set_external_error(self, exc: BaseException) -> None:
+ self._external_error = exc
+
+ def _check_external_error_set(self):
+ if self._external_error is None:
+ return
+ raise issues.ClientInternalError("Transaction was failed by external error.") from self._external_error
+
def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED)
@@ -228,6 +255,7 @@ class BaseQueryTxContext:
)
def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
+ self._check_external_error_set()
self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED)
return self._driver(
@@ -240,6 +268,7 @@ class BaseQueryTxContext:
)
def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
+ self._check_external_error_set()
self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED)
return self._driver(
@@ -262,6 +291,7 @@ class BaseQueryTxContext:
settings: Optional[BaseRequestSettings],
) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]:
self._tx_state._check_tx_ready_to_use()
+ self._check_external_error_set()
request = base.create_execute_query_request(
query=query,
@@ -283,18 +313,41 @@ class BaseQueryTxContext:
)
def _move_to_beginned(self, tx_id: str) -> None:
- if self._tx_state._already_in(QueryTxStateEnum.BEGINED) or not tx_id:
+ if self._tx_state._should_skip(QueryTxStateEnum.BEGINED) or not tx_id:
return
self._tx_state._change_state(QueryTxStateEnum.BEGINED)
self._tx_state.tx_id = tx_id
def _move_to_commited(self) -> None:
- if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
+ if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED):
return
self._tx_state._change_state(QueryTxStateEnum.COMMITTED)
class QueryTxContext(BaseQueryTxContext):
+ def __init__(self, driver, session_state, session, tx_mode):
+ """
+ An object that provides a simple transaction context manager that allows statements execution
+ in a transaction. You don't have to open transaction explicitly, because context manager encapsulates
+ transaction control logic, and opens new transaction if:
+
+ 1) By explicit .begin() method;
+ 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip
+
+ This context manager is not thread-safe, so you should not manipulate on it concurrently.
+
+ :param driver: A driver instance
+ :param session_state: A state of session
+ :param tx_mode: Transaction mode, which is a one from the following choises:
+ 1) QuerySerializableReadWrite() which is default mode;
+ 2) QueryOnlineReadOnly(allow_inconsistent_reads=False);
+ 3) QuerySnapshotReadOnly();
+ 4) QueryStaleReadOnly().
+ """
+
+ super().__init__(driver, session_state, session, tx_mode)
+ self._init_callback_handler(base.CallbackHandlerMode.SYNC)
+
def __enter__(self) -> "BaseQueryTxContext":
"""
Enters a context manager and returns a transaction
@@ -309,7 +362,7 @@ class QueryTxContext(BaseQueryTxContext):
it is not finished explicitly
"""
self._ensure_prev_stream_finished()
- if self._tx_state._state == QueryTxStateEnum.BEGINED:
+ if self._tx_state._state == QueryTxStateEnum.BEGINED and self._external_error is None:
# It's strictly recommended to close transactions directly
# by using commit_tx=True flag while executing statement or by
# .commit() or .rollback() methods, but here we trying to do best
@@ -345,7 +398,8 @@ class QueryTxContext(BaseQueryTxContext):
:return: A committed transaction or exception if commit is failed
"""
- if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
+ self._check_external_error_set()
+ if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED):
return
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
@@ -354,7 +408,13 @@ class QueryTxContext(BaseQueryTxContext):
self._ensure_prev_stream_finished()
- self._commit_call(settings)
+ try:
+ self._execute_callbacks_sync(base.TxEvent.BEFORE_COMMIT)
+ self._commit_call(settings)
+ self._execute_callbacks_sync(base.TxEvent.AFTER_COMMIT, exc=None)
+ except BaseException as e: # TODO: probably should be less wide
+ self._execute_callbacks_sync(base.TxEvent.AFTER_COMMIT, exc=e)
+ raise e
def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None:
"""Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution
@@ -364,7 +424,8 @@ class QueryTxContext(BaseQueryTxContext):
:return: A committed transaction or exception if commit is failed
"""
- if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
+ self._check_external_error_set()
+ if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED):
return
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
@@ -373,7 +434,13 @@ class QueryTxContext(BaseQueryTxContext):
self._ensure_prev_stream_finished()
- self._rollback_call(settings)
+ try:
+ self._execute_callbacks_sync(base.TxEvent.BEFORE_ROLLBACK)
+ self._rollback_call(settings)
+ self._execute_callbacks_sync(base.TxEvent.AFTER_ROLLBACK, exc=None)
+ except BaseException as e: # TODO: probably should be less wide
+ self._execute_callbacks_sync(base.TxEvent.AFTER_ROLLBACK, exc=e)
+ raise e
def execute(
self,
diff --git a/contrib/python/ydb/py3/ydb/topic.py b/contrib/python/ydb/py3/ydb/topic.py
index 55f4ea04c5c..52f98e61d85 100644
--- a/contrib/python/ydb/py3/ydb/topic.py
+++ b/contrib/python/ydb/py3/ydb/topic.py
@@ -25,6 +25,8 @@ __all__ = [
"TopicWriteResult",
"TopicWriter",
"TopicWriterAsyncIO",
+ "TopicTxWriter",
+ "TopicTxWriterAsyncIO",
"TopicWriterInitInfo",
"TopicWriterMessage",
"TopicWriterSettings",
@@ -33,6 +35,7 @@ __all__ = [
import concurrent.futures
import datetime
from dataclasses import dataclass
+import logging
from typing import List, Union, Mapping, Optional, Dict, Callable
from . import aio, Credentials, _apis, issues
@@ -65,8 +68,10 @@ from ._topic_writer.topic_writer import ( # noqa: F401
PublicWriteResult as TopicWriteResult,
)
+from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO
from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO
from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter
+from ._topic_writer.topic_writer_sync import TxWriterSync as TopicTxWriter
from ._topic_common.common import (
wrap_operation as _wrap_operation,
@@ -88,6 +93,8 @@ from ._grpc.grpcwrapper.ydb_topic_public_types import ( # noqa: F401
PublicAlterAutoPartitioningSettings as TopicAlterAutoPartitioningSettings,
)
+logger = logging.getLogger(__name__)
+
class TopicClientAsyncIO:
_closed: bool
@@ -108,7 +115,8 @@ class TopicClientAsyncIO:
)
def __del__(self):
- self.close()
+ if not self._closed:
+ logger.warning("Topic client was not closed properly. Consider using method close().")
async def create_topic(
self,
@@ -276,6 +284,35 @@ class TopicClientAsyncIO:
return TopicWriterAsyncIO(self._driver, settings, _client=self)
+ def tx_writer(
+ self,
+ tx,
+ topic,
+ *,
+ producer_id: Optional[str] = None, # default - random
+ session_metadata: Mapping[str, str] = None,
+ partition_id: Union[int, None] = None,
+ auto_seqno: bool = True,
+ auto_created_at: bool = True,
+ codec: Optional[TopicCodec] = None, # default mean auto-select
+ # encoders: map[codec_code] func(encoded_bytes)->decoded_bytes
+ # the func will be called from multiply threads in parallel.
+ encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None,
+ # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool.
+ # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel.
+ encoder_executor: Optional[concurrent.futures.Executor] = None,
+ ) -> TopicTxWriterAsyncIO:
+ args = locals().copy()
+ del args["self"]
+ del args["tx"]
+
+ settings = TopicWriterSettings(**args)
+
+ if not settings.encoder_executor:
+ settings.encoder_executor = self._executor
+
+ return TopicTxWriterAsyncIO(tx=tx, driver=self._driver, settings=settings, _client=self)
+
def close(self):
if self._closed:
return
@@ -287,7 +324,7 @@ class TopicClientAsyncIO:
if not self._closed:
return
- raise RuntimeError("Topic client closed")
+ raise issues.Error("Topic client closed")
class TopicClient:
@@ -310,7 +347,8 @@ class TopicClient:
)
def __del__(self):
- self.close()
+ if not self._closed:
+ logger.warning("Topic client was not closed properly. Consider using method close().")
def create_topic(
self,
@@ -487,6 +525,36 @@ class TopicClient:
return TopicWriter(self._driver, settings, _parent=self)
+ def tx_writer(
+ self,
+ tx,
+ topic,
+ *,
+ producer_id: Optional[str] = None, # default - random
+ session_metadata: Mapping[str, str] = None,
+ partition_id: Union[int, None] = None,
+ auto_seqno: bool = True,
+ auto_created_at: bool = True,
+ codec: Optional[TopicCodec] = None, # default mean auto-select
+ # encoders: map[codec_code] func(encoded_bytes)->decoded_bytes
+ # the func will be called from multiply threads in parallel.
+ encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None,
+ # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool.
+ # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel.
+ encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool
+ ) -> TopicWriter:
+ args = locals().copy()
+ del args["self"]
+ del args["tx"]
+ self._check_closed()
+
+ settings = TopicWriterSettings(**args)
+
+ if not settings.encoder_executor:
+ settings.encoder_executor = self._executor
+
+ return TopicTxWriter(tx, self._driver, settings, _parent=self)
+
def close(self):
if self._closed:
return
@@ -498,7 +566,7 @@ class TopicClient:
if not self._closed:
return
- raise RuntimeError("Topic client closed")
+ raise issues.Error("Topic client closed")
@dataclass
diff --git a/contrib/python/ydb/py3/ydb/ydb_version.py b/contrib/python/ydb/py3/ydb/ydb_version.py
index 8bd658d49e4..070a2455ef3 100644
--- a/contrib/python/ydb/py3/ydb/ydb_version.py
+++ b/contrib/python/ydb/py3/ydb/ydb_version.py
@@ -1 +1 @@
-VERSION = "3.19.3"
+VERSION = "3.20.0"