author    rekby <rekby@ydb.tech>  2023-03-21 19:27:13 +0300
committer rekby <rekby@ydb.tech>  2023-03-21 19:27:13 +0300
commit    0be64eefa9aa954902612e7c54beaf7f2bc6b89d (patch)
tree      4f9af91d1aec584ce4c5a9d5c909f57945bbcbd9
parent    208093ca374000ddbce32ab91ed8b81aa03ca3e7 (diff)
update ydb python sdk to 3.0.1b9
-rw-r--r--  ydb/public/sdk/python3/ya.make  24
-rw-r--r--  ydb/public/sdk/python3/ydb/__init__.py  10
-rw-r--r--  ydb/public/sdk/python3/ydb/_apis.py  18
-rw-r--r--  ydb/public/sdk/python3/ydb/_dbapi/__init__.py  36
-rw-r--r--  ydb/public/sdk/python3/ydb/_dbapi/connection.py  73
-rw-r--r--  ydb/public/sdk/python3/ydb/_dbapi/cursor.py  172
-rw-r--r--  ydb/public/sdk/python3/ydb/_dbapi/errors.py  92
-rw-r--r--  ydb/public/sdk/python3/ydb/_errors.py  2
-rw-r--r--  ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/__init__.py  0
-rw-r--r--  ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/common_utils.py  309
-rw-r--r--  ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_scheme.py  36
-rw-r--r--  ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic.py  1164
-rw-r--r--  ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py  200
-rw-r--r--  ydb/public/sdk/python3/ydb/_sp_impl.py  2
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_common/__init__.py  0
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_common/common.py  147
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_common/test_helpers.py  76
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_reader/__init__.py  0
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_reader/datatypes.py  189
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_reader/topic_reader.py  116
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_asyncio.py  697
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_sync.py  210
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_writer/__init__.py  0
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_writer/topic_writer.py  220
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_asyncio.py  684
-rw-r--r--  ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_sync.py  123
-rw-r--r--  ydb/public/sdk/python3/ydb/_utilities.py  27
-rw-r--r--  ydb/public/sdk/python3/ydb/aio/connection.py  9
-rw-r--r--  ydb/public/sdk/python3/ydb/aio/credentials.py  2
-rw-r--r--  ydb/public/sdk/python3/ydb/aio/driver.py  48
-rw-r--r--  ydb/public/sdk/python3/ydb/aio/iam.py  2
-rw-r--r--  ydb/public/sdk/python3/ydb/aio/table.py  7
-rw-r--r--  ydb/public/sdk/python3/ydb/auth_helpers.py  43
-rw-r--r--  ydb/public/sdk/python3/ydb/convert.py  21
-rw-r--r--  ydb/public/sdk/python3/ydb/credentials.py  19
-rw-r--r--  ydb/public/sdk/python3/ydb/dbapi/cursor.py  4
-rw-r--r--  ydb/public/sdk/python3/ydb/default_pem.py  7
-rw-r--r--  ydb/public/sdk/python3/ydb/driver.py  28
-rw-r--r--  ydb/public/sdk/python3/ydb/export.py  4
-rw-r--r--  ydb/public/sdk/python3/ydb/global_settings.py  22
-rw-r--r--  ydb/public/sdk/python3/ydb/iam/auth.py  5
-rw-r--r--  ydb/public/sdk/python3/ydb/import_client.py  4
-rw-r--r--  ydb/public/sdk/python3/ydb/issues.py  13
-rw-r--r--  ydb/public/sdk/python3/ydb/pool.py  11
-rw-r--r--  ydb/public/sdk/python3/ydb/scheme.py  4
-rw-r--r--  ydb/public/sdk/python3/ydb/scheme_test.py  30
-rw-r--r--  ydb/public/sdk/python3/ydb/scripting.py  5
-rw-r--r--  ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py  18
-rw-r--r--  ydb/public/sdk/python3/ydb/table.py  37
-rw-r--r--  ydb/public/sdk/python3/ydb/table_test.py  140
-rw-r--r--  ydb/public/sdk/python3/ydb/topic.py  389
-rw-r--r--  ydb/public/sdk/python3/ydb/types.py  14
-rw-r--r--  ydb/public/sdk/python3/ydb/ydb_version.py  2
53 files changed, 5132 insertions, 383 deletions
diff --git a/ydb/public/sdk/python3/ya.make b/ydb/public/sdk/python3/ya.make
index d878b0da83..d8342cee2b 100644
--- a/ydb/public/sdk/python3/ya.make
+++ b/ydb/public/sdk/python3/ya.make
@@ -4,11 +4,32 @@ PY_SRCS(
TOP_LEVEL
ydb/__init__.py
ydb/_apis.py
+ ydb/_dbapi/__init__.py
+ ydb/_dbapi/connection.py
+ ydb/_dbapi/cursor.py
+ ydb/_dbapi/errors.py
ydb/_errors.py
ydb/_grpc/__init__.py
ydb/_grpc/common/__init__.py
+ ydb/_grpc/grpcwrapper/__init__.py
+ ydb/_grpc/grpcwrapper/common_utils.py
+ ydb/_grpc/grpcwrapper/ydb_scheme.py
+ ydb/_grpc/grpcwrapper/ydb_topic.py
+ ydb/_grpc/grpcwrapper/ydb_topic_public_types.py
ydb/_session_impl.py
ydb/_sp_impl.py
+ ydb/_topic_common/__init__.py
+ ydb/_topic_common/common.py
+ ydb/_topic_common/test_helpers.py
+ ydb/_topic_reader/__init__.py
+ ydb/_topic_reader/datatypes.py
+ ydb/_topic_reader/topic_reader.py
+ ydb/_topic_reader/topic_reader_asyncio.py
+ ydb/_topic_reader/topic_reader_sync.py
+ ydb/_topic_writer/__init__.py
+ ydb/_topic_writer/topic_writer.py
+ ydb/_topic_writer/topic_writer_asyncio.py
+ ydb/_topic_writer/topic_writer_sync.py
ydb/_tx_ctx_impl.py
ydb/_utilities.py
ydb/aio/__init__.py
@@ -42,13 +63,12 @@ PY_SRCS(
ydb/pool.py
ydb/resolver.py
ydb/scheme.py
- ydb/scheme_test.py
ydb/scripting.py
ydb/settings.py
ydb/sqlalchemy/__init__.py
ydb/sqlalchemy/types.py
ydb/table.py
- ydb/table_test.py
+ ydb/topic.py
ydb/tracing.py
ydb/types.py
ydb/ydb_version.py
diff --git a/ydb/public/sdk/python3/ydb/__init__.py b/ydb/public/sdk/python3/ydb/__init__.py
index 7395ae36a1..648077880e 100644
--- a/ydb/public/sdk/python3/ydb/__init__.py
+++ b/ydb/public/sdk/python3/ydb/__init__.py
@@ -13,16 +13,22 @@ from .operation import * # noqa
from .scripting import * # noqa
from .import_client import * # noqa
from .tracing import * # noqa
+from .topic import * # noqa
try:
import ydb.aio as aio # noqa
except Exception:
pass
+
try:
import kikimr.public.sdk.python.ydb_v3_new_behavior # noqa
global_allow_split_transactions(False) # noqa
- global_allow_split_transactions(False) # noqa
+ global_allow_truncated_result(False) # noqa
except ModuleNotFoundError:
# Old, deprecated
+
+ import warnings
+    warnings.warn("Deprecated behavior is in use; to fix it, ADD PEERDIR kikimr/public/sdk/python/ydb_v3_new_behavior")
+
global_allow_split_transactions(True) # noqa
- global_allow_split_transactions(True) # noqa
+ global_allow_truncated_result(True) # noqa
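The hunk above fixes a copy-paste bug: the second call used to toggle split
transactions a second time instead of truncated results. A minimal sketch of the
two compatibility switches this module flips, assuming the helpers exported from
ydb.global_settings via the star imports above:

    import ydb

    # New behavior: both compatibility switches disabled.
    ydb.global_allow_split_transactions(False)
    ydb.global_allow_truncated_result(False)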
diff --git a/ydb/public/sdk/python3/ydb/_apis.py b/ydb/public/sdk/python3/ydb/_apis.py
index 6f2fc3ab6a..27bc1bbec8 100644
--- a/ydb/public/sdk/python3/ydb/_apis.py
+++ b/ydb/public/sdk/python3/ydb/_apis.py
@@ -1,13 +1,15 @@
# -*- coding: utf-8 -*-
+import typing
+
# Workaround for good IDE and universal for runtime
-# noinspection PyUnreachableCode
-if False:
+if typing.TYPE_CHECKING:
from ._grpc.v4 import (
ydb_cms_v1_pb2_grpc,
ydb_discovery_v1_pb2_grpc,
ydb_scheme_v1_pb2_grpc,
ydb_table_v1_pb2_grpc,
ydb_operation_v1_pb2_grpc,
+ ydb_topic_v1_pb2_grpc,
)
from ._grpc.v4.protos import (
@@ -26,6 +28,7 @@ else:
ydb_scheme_v1_pb2_grpc,
ydb_table_v1_pb2_grpc,
ydb_operation_v1_pb2_grpc,
+ ydb_topic_v1_pb2_grpc,
)
from ._grpc.common.protos import (
@@ -38,6 +41,7 @@ else:
ydb_common_pb2,
)
+
StatusIds = ydb_status_codes_pb2.StatusIds
FeatureFlag = ydb_common_pb2.FeatureFlag
primitive_types = ydb_value_pb2.Type.PrimitiveTypeId
@@ -95,3 +99,13 @@ class TableService(object):
KeepAlive = "KeepAlive"
StreamReadTable = "StreamReadTable"
BulkUpsert = "BulkUpsert"
+
+
+class TopicService(object):
+ Stub = ydb_topic_v1_pb2_grpc.TopicServiceStub
+
+ CreateTopic = "CreateTopic"
+ DescribeTopic = "DescribeTopic"
+ DropTopic = "DropTopic"
+ StreamRead = "StreamRead"
+ StreamWrite = "StreamWrite"
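TopicService follows the same convention as the service classes above: a gRPC
stub class plus the method names the SDK hands to the driver's generic call
path. A hedged sketch of that pattern (a connected driver, the request object,
and call settings are assumed):

    from ydb import _apis

    stub_class = _apis.TopicService.Stub       # generated TopicServiceStub
    rpc_name = _apis.TopicService.CreateTopic  # "CreateTopic"
    # The SDK then dispatches roughly as: driver(request, stub_class, rpc_name, ...)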
diff --git a/ydb/public/sdk/python3/ydb/_dbapi/__init__.py b/ydb/public/sdk/python3/ydb/_dbapi/__init__.py
new file mode 100644
index 0000000000..8756b0f2d4
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_dbapi/__init__.py
@@ -0,0 +1,36 @@
+from .connection import Connection
+from .errors import (
+ Warning,
+ Error,
+ InterfaceError,
+ DatabaseError,
+ DataError,
+ OperationalError,
+ IntegrityError,
+ InternalError,
+ ProgrammingError,
+ NotSupportedError,
+)
+
+apilevel = "1.0"
+
+threadsafety = 0
+
+paramstyle = "pyformat"
+
+errors = (
+ Warning,
+ Error,
+ InterfaceError,
+ DatabaseError,
+ DataError,
+ OperationalError,
+ IntegrityError,
+ InternalError,
+ ProgrammingError,
+ NotSupportedError,
+)
+
+
+def connect(*args, **kwargs):
+ return Connection(*args, **kwargs)
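The module implements PEP 249, so the usual connect/cursor flow applies. A
minimal usage sketch; the endpoint and database below are placeholders for a
reachable YDB instance:

    from ydb import _dbapi as dbapi

    conn = dbapi.connect("grpc://localhost:2136", database="/local")
    cur = conn.cursor()
    cur.execute("SELECT 1 AS value")
    print(cur.fetchall())  # one row with a single column
    cur.close()
    conn.close()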
diff --git a/ydb/public/sdk/python3/ydb/_dbapi/connection.py b/ydb/public/sdk/python3/ydb/_dbapi/connection.py
new file mode 100644
index 0000000000..75bfeb582f
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_dbapi/connection.py
@@ -0,0 +1,73 @@
+import posixpath
+
+import ydb
+from .cursor import Cursor
+from .errors import DatabaseError
+
+
+class Connection:
+ def __init__(self, endpoint, database=None, **conn_kwargs):
+ self.endpoint = endpoint
+ self.database = database
+ self.driver = self._create_driver(self.endpoint, self.database, **conn_kwargs)
+ self.pool = ydb.SessionPool(self.driver)
+
+ def cursor(self):
+ return Cursor(self)
+
+ def describe(self, table_path):
+ full_path = posixpath.join(self.database, table_path)
+ try:
+ res = self.pool.retry_operation_sync(
+ lambda cli: cli.describe_table(full_path)
+ )
+ return res.columns
+ except ydb.Error as e:
+ raise DatabaseError(e.message, e.issues, e.status)
+ except Exception:
+ raise DatabaseError(f"Failed to describe table {table_path}")
+
+ def check_exists(self, table_path):
+ try:
+ self.driver.scheme_client.describe_path(table_path)
+ return True
+ except ydb.SchemeError:
+ return False
+
+ def commit(self):
+ pass
+
+ def rollback(self):
+ pass
+
+ def close(self):
+ if self.pool:
+ self.pool.stop()
+ if self.driver:
+ self.driver.stop()
+
+ @staticmethod
+ def _create_driver(endpoint, database, **conn_kwargs):
+ # TODO: add cache for initialized drivers/pools?
+ driver_config = ydb.DriverConfig(
+ endpoint,
+ database=database,
+ table_client_settings=ydb.TableClientSettings()
+ .with_native_date_in_result_sets(True)
+ .with_native_datetime_in_result_sets(True)
+ .with_native_timestamp_in_result_sets(True)
+ .with_native_interval_in_result_sets(True)
+ .with_native_json_in_result_sets(True),
+ **conn_kwargs,
+ )
+ driver = ydb.Driver(driver_config)
+ try:
+ driver.wait(timeout=5, fail_fast=True)
+ except ydb.Error as e:
+ raise DatabaseError(e.message, e.issues, e.status)
+ except Exception:
+ driver.stop()
+ raise DatabaseError(
+ f"Failed to connect to YDB, details {driver.discovery_debug_details()}"
+ )
+ return driver
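Note the asymmetry in path handling: describe() joins the table path onto the
database root, while check_exists() passes its argument to the scheme client
as-is. A short sketch with placeholder names:

    conn = Connection("grpc://localhost:2136", database="/local")
    columns = conn.describe("series")            # describes /local/series
    exists = conn.check_exists("/local/series")  # scheme path is absolute here
    conn.close()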
diff --git a/ydb/public/sdk/python3/ydb/_dbapi/cursor.py b/ydb/public/sdk/python3/ydb/_dbapi/cursor.py
new file mode 100644
index 0000000000..57659c7abf
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_dbapi/cursor.py
@@ -0,0 +1,172 @@
+import datetime
+import itertools
+import logging
+import uuid
+import decimal
+
+import ydb
+from .errors import DatabaseError, ProgrammingError
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_column_type(type_obj):
+ return str(ydb.convert.type_to_native(type_obj))
+
+
+def _generate_type_str(value):
+ tvalue = type(value)
+
+ stype = {
+ bool: "Bool",
+ bytes: "String",
+ str: "Utf8",
+ int: "Int64",
+ float: "Double",
+ decimal.Decimal: "Decimal(22, 9)",
+ datetime.date: "Date",
+ datetime.datetime: "Timestamp",
+ datetime.timedelta: "Interval",
+ uuid.UUID: "Uuid",
+ }.get(tvalue)
+
+ if tvalue == dict:
+ types_lst = ", ".join(f"{k}: {_generate_type_str(v)}" for k, v in value.items())
+ stype = f"Struct<{types_lst}>"
+
+ elif tvalue == tuple:
+ types_lst = ", ".join(_generate_type_str(x) for x in value)
+ stype = f"Tuple<{types_lst}>"
+
+ elif tvalue == list:
+ nested_type = _generate_type_str(value[0])
+ stype = f"List<{nested_type}>"
+
+ elif tvalue == set:
+ nested_type = _generate_type_str(next(iter(value)))
+ stype = f"Set<{nested_type}>"
+
+ if stype is None:
+ raise ProgrammingError(
+ "Cannot translate python type to ydb type.", tvalue, value
+ )
+
+ return stype
+
+
+def _generate_declare_stms(params: dict) -> str:
+ return "".join(
+ f"DECLARE {k} AS {_generate_type_str(t)}; " for k, t in params.items()
+ )
+
+
+class Cursor(object):
+ def __init__(self, connection):
+ self.connection = connection
+ self.description = None
+ self.arraysize = 1
+ self.rows = None
+ self._rows_prefetched = None
+
+ def execute(self, sql, parameters=None, context=None):
+ self.description = None
+ sql_params = None
+
+ if parameters:
+ sql = sql % {k: f"${k}" for k, v in parameters.items()}
+ sql_params = {f"${k}": v for k, v in parameters.items()}
+ declare_stms = _generate_declare_stms(sql_params)
+ sql = f"{declare_stms}{sql}"
+
+ logger.info("execute sql: %s, params: %s", sql, sql_params)
+
+ def _execute_in_pool(cli):
+ try:
+ if context and context.get("isddl"):
+ return cli.execute_scheme(sql)
+ else:
+ prepared_query = cli.prepare(sql)
+ return cli.transaction().execute(
+ prepared_query, sql_params, commit_tx=True
+ )
+ except ydb.Error as e:
+ raise DatabaseError(e.message, e.issues, e.status)
+
+ chunks = self.connection.pool.retry_operation_sync(_execute_in_pool)
+ rows = self._rows_iterable(chunks)
+ # Prefetch the description:
+ try:
+ first_row = next(rows)
+ except StopIteration:
+ pass
+ else:
+ rows = itertools.chain((first_row,), rows)
+ if self.rows is not None:
+ rows = itertools.chain(self.rows, rows)
+
+ self.rows = rows
+
+ def _rows_iterable(self, chunks_iterable):
+ try:
+ for chunk in chunks_iterable:
+ self.description = [
+ (
+ col.name,
+ get_column_type(col.type),
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+ for col in chunk.columns
+ ]
+ for row in chunk.rows:
+                    # return a tuple to stay compatible with SQLAlchemy and with
+                    # PEP 249, which requires rows to be sequences: https://www.python.org/dev/peps/pep-0249/#fetchmany
+ yield row[::]
+ except ydb.Error as e:
+ raise DatabaseError(e.message, e.issues, e.status)
+
+ def _ensure_prefetched(self):
+ if self.rows is not None and self._rows_prefetched is None:
+ self._rows_prefetched = list(self.rows)
+ self.rows = iter(self._rows_prefetched)
+ return self._rows_prefetched
+
+ def executemany(self, sql, seq_of_parameters):
+ for parameters in seq_of_parameters:
+ self.execute(sql, parameters)
+
+ def executescript(self, script):
+ return self.execute(script)
+
+ def fetchone(self):
+ if self.rows is None:
+ return None
+ return next(self.rows, None)
+
+ def fetchmany(self, size=None):
+ size = self.arraysize if size is None else size
+ return list(itertools.islice(self.rows, size))
+
+ def fetchall(self):
+ return list(self.rows)
+
+ def nextset(self):
+ self.fetchall()
+
+ def setinputsizes(self, sizes):
+ pass
+
+ def setoutputsize(self, column=None):
+ pass
+
+ def close(self):
+ self.rows = None
+ self._rows_prefetched = None
+
+ @property
+ def rowcount(self):
+ return len(self._ensure_prefetched())
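Cursor.execute rewrites pyformat placeholders into YQL parameters and prepends
DECLARE statements derived from the Python value types. A worked sketch of the
transformation (hypothetical table and parameter):

    sql = "SELECT * FROM series WHERE series_id = %(series_id)s"
    parameters = {"series_id": 1}
    # After substitution inside execute():
    #   sql        == "DECLARE $series_id AS Int64; SELECT * FROM series WHERE series_id = $series_id"
    #   sql_params == {"$series_id": 1}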
diff --git a/ydb/public/sdk/python3/ydb/_dbapi/errors.py b/ydb/public/sdk/python3/ydb/_dbapi/errors.py
new file mode 100644
index 0000000000..ddb55b4c90
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_dbapi/errors.py
@@ -0,0 +1,92 @@
+class Warning(Exception):
+ pass
+
+
+class Error(Exception):
+ def __init__(self, message, issues=None, status=None):
+ super(Error, self).__init__(message)
+
+ pretty_issues = _pretty_issues(issues)
+ self.issues = issues
+ self.message = pretty_issues or message
+ self.status = status
+
+
+class InterfaceError(Error):
+ pass
+
+
+class DatabaseError(Error):
+ pass
+
+
+class DataError(DatabaseError):
+ pass
+
+
+class OperationalError(DatabaseError):
+ pass
+
+
+class IntegrityError(DatabaseError):
+ pass
+
+
+class InternalError(DatabaseError):
+ pass
+
+
+class ProgrammingError(DatabaseError):
+ pass
+
+
+class NotSupportedError(DatabaseError):
+ pass
+
+
+def _pretty_issues(issues):
+ if issues is None:
+ return None
+
+ children_messages = [_get_messages(issue, root=True) for issue in issues]
+
+ if None in children_messages:
+ return None
+
+ return "\n" + "\n".join(children_messages)
+
+
+def _get_messages(issue, max_depth=100, indent=2, depth=0, root=False):
+ if depth >= max_depth:
+ return None
+
+ margin_str = " " * depth * indent
+ pre_message = ""
+ children = ""
+
+ if issue.issues:
+ collapsed_messages = []
+ while not root and len(issue.issues) == 1:
+ collapsed_messages.append(issue.message)
+ issue = issue.issues[0]
+
+ if collapsed_messages:
+ pre_message = f"{margin_str}{', '.join(collapsed_messages)}\n"
+ depth += 1
+ margin_str = " " * depth * indent
+
+ children_messages = [
+ _get_messages(iss, max_depth=max_depth, indent=indent, depth=depth + 1)
+ for iss in issue.issues
+ ]
+
+ if None in children_messages:
+ return None
+
+ children = "\n".join(children_messages)
+
+ return (
+ f"{pre_message}{margin_str}{issue.message}\n{margin_str}"
+ f"severity level: {issue.severity}\n{margin_str}"
+ f"issue code: {issue.issue_code}\n{children}"
+ )
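_pretty_issues renders a nested issue tree, collapsing chains of single-child
issues into one line. A small sketch with hypothetical stand-ins for the
protobuf issue objects:

    from types import SimpleNamespace

    leaf = SimpleNamespace(message="type mismatch", severity=1, issue_code=1030, issues=[])
    root = SimpleNamespace(message="query failed", severity=1, issue_code=1, issues=[leaf])
    print(_pretty_issues([root]))  # both messages, each with severity and issue code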
diff --git a/ydb/public/sdk/python3/ydb/_errors.py b/ydb/public/sdk/python3/ydb/_errors.py
index 8c6f072049..ae3057b6d2 100644
--- a/ydb/public/sdk/python3/ydb/_errors.py
+++ b/ydb/public/sdk/python3/ydb/_errors.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Optional
-from . import issues
+from ydb import issues
_errors_retriable_fast_backoff_types = [
issues.Unavailable,
diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/__init__.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/__init__.py
diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/common_utils.py
new file mode 100644
index 0000000000..6c624520ea
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/common_utils.py
@@ -0,0 +1,309 @@
+from __future__ import annotations
+
+import abc
+import asyncio
+import contextvars
+import datetime
+import functools
+import typing
+from typing import (
+ Optional,
+ Any,
+ Iterator,
+ AsyncIterator,
+ Callable,
+ Iterable,
+ Union,
+ Coroutine,
+)
+from dataclasses import dataclass
+
+import grpc
+from google.protobuf.message import Message
+from google.protobuf.duration_pb2 import Duration as ProtoDuration
+from google.protobuf.timestamp_pb2 import Timestamp as ProtoTimeStamp
+
+import ydb.aio
+
+# Workaround for good IDE and universal for runtime
+if typing.TYPE_CHECKING:
+ from ..v4.protos import ydb_topic_pb2, ydb_issue_message_pb2
+else:
+ from ..common.protos import ydb_topic_pb2, ydb_issue_message_pb2
+
+from ... import issues, connection
+
+
+class IFromProto(abc.ABC):
+ @staticmethod
+ @abc.abstractmethod
+ def from_proto(msg: Message) -> Any:
+ ...
+
+
+class IFromProtoWithProtoType(IFromProto):
+ @staticmethod
+ @abc.abstractmethod
+ def empty_proto_message() -> Message:
+ ...
+
+
+class IToProto(abc.ABC):
+ @abc.abstractmethod
+ def to_proto(self) -> Message:
+ ...
+
+
+class IFromPublic(abc.ABC):
+ @staticmethod
+ @abc.abstractmethod
+ def from_public(o: typing.Any) -> typing.Any:
+ ...
+
+
+class IToPublic(abc.ABC):
+ @abc.abstractmethod
+ def to_public(self) -> typing.Any:
+ ...
+
+
+class UnknownGrpcMessageError(issues.Error):
+ pass
+
+
+_stop_grpc_connection_marker = object()
+
+
+class QueueToIteratorAsyncIO:
+ __slots__ = ("_queue",)
+
+ def __init__(self, q: asyncio.Queue):
+ self._queue = q
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ item = await self._queue.get()
+ if item is _stop_grpc_connection_marker:
+ raise StopAsyncIteration()
+ return item
+
+
+class AsyncQueueToSyncIteratorAsyncIO:
+ __slots__ = (
+ "_loop",
+ "_queue",
+ )
+ _queue: asyncio.Queue
+
+ def __init__(self, q: asyncio.Queue):
+ self._loop = asyncio.get_running_loop()
+ self._queue = q
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result()
+ if item is _stop_grpc_connection_marker:
+ raise StopIteration()
+ return item
+
+
+class SyncIteratorToAsyncIterator:
+ def __init__(self, sync_iterator: Iterator):
+ self._sync_iterator = sync_iterator
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ try:
+ res = await to_thread(self._sync_iterator.__next__)
+ return res
+        except StopIteration:
+            raise StopAsyncIteration()
+
+
+class IGrpcWrapperAsyncIO(abc.ABC):
+ @abc.abstractmethod
+ async def receive(self) -> Any:
+ ...
+
+ @abc.abstractmethod
+ def write(self, wrap_message: IToProto):
+ ...
+
+ @abc.abstractmethod
+ def close(self):
+ ...
+
+
+SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver]
+
+
+class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
+ from_client_grpc: asyncio.Queue
+ from_server_grpc: AsyncIterator
+ convert_server_grpc_to_wrapper: Callable[[Any], Any]
+ _connection_state: str
+ _stream_call: Optional[
+ Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"]
+ ]
+
+ def __init__(self, convert_server_grpc_to_wrapper):
+ self.from_client_grpc = asyncio.Queue()
+ self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
+ self._connection_state = "new"
+ self._stream_call = None
+
+ async def start(self, driver: SupportedDriverType, stub, method):
+ if asyncio.iscoroutinefunction(driver.__call__):
+ await self._start_asyncio_driver(driver, stub, method)
+ else:
+ await self._start_sync_driver(driver, stub, method)
+ self._connection_state = "started"
+
+ def close(self):
+ self.from_client_grpc.put_nowait(_stop_grpc_connection_marker)
+ if self._stream_call:
+ self._stream_call.cancel()
+
+ async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
+ requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
+ stream_call = await driver(
+ requests_iterator,
+ stub,
+ method,
+ )
+ self._stream_call = stream_call
+ self.from_server_grpc = stream_call.__aiter__()
+
+ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
+ requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc)
+ stream_call = await to_thread(
+ driver,
+ requests_iterator,
+ stub,
+ method,
+ )
+ self._stream_call = stream_call
+ self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__())
+
+ async def receive(self) -> Any:
+        # TODO: handle grpc exceptions and convert them to internal exceptions
+ try:
+ grpc_message = await self.from_server_grpc.__anext__()
+ except grpc.RpcError as e:
+ raise connection._rpc_error_handler(self._connection_state, e)
+
+ issues._process_response(grpc_message)
+
+ if self._connection_state != "has_received_messages":
+ self._connection_state = "has_received_messages"
+
+ # print("rekby, grpc, received", grpc_message)
+ return self.convert_server_grpc_to_wrapper(grpc_message)
+
+ def write(self, wrap_message: IToProto):
+ grpc_message = wrap_message.to_proto()
+ # print("rekby, grpc, send", grpc_message)
+ self.from_client_grpc.put_nowait(grpc_message)
+
+
+@dataclass(init=False)
+class ServerStatus(IFromProto):
+    __slots__ = ("status", "issues")
+
+ def __init__(
+ self,
+ status: issues.StatusCode,
+ issues: Iterable[Any],
+ ):
+ self.status = status
+ self.issues = issues
+
+ def __str__(self):
+ return self.__repr__()
+
+ @staticmethod
+ def from_proto(
+ msg: Union[
+ ydb_topic_pb2.StreamReadMessage.FromServer,
+ ydb_topic_pb2.StreamWriteMessage.FromServer,
+ ]
+ ) -> "ServerStatus":
+ return ServerStatus(msg.status, msg.issues)
+
+ def is_success(self) -> bool:
+ return self.status == issues.StatusCode.SUCCESS
+
+ @classmethod
+ def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage):
+ res = """code: %s message: "%s" """ % (issue.issue_code, issue.message)
+ if len(issue.issues) > 0:
+ d = ", "
+ res += d + d.join(str(sub_issue) for sub_issue in issue.issues)
+ return res
+
+
+def callback_from_asyncio(
+ callback: Union[Callable, Coroutine]
+) -> Union[asyncio.Future, asyncio.Task]:
+ loop = asyncio.get_running_loop()
+
+ if asyncio.iscoroutinefunction(callback):
+ return loop.create_task(callback())
+ else:
+ return loop.run_in_executor(None, callback)
+
+
+async def to_thread(func, /, *args, **kwargs):
+ """Asynchronously run function *func* in a separate thread.
+
+ Any *args and **kwargs supplied for this function are directly passed
+ to *func*. Also, the current :class:`contextvars.Context` is propagated,
+ allowing context variables from the main thread to be accessed in the
+ separate thread.
+
+ Return a coroutine that can be awaited to get the eventual result of *func*.
+
+    Backported copy of asyncio.to_thread (added in Python 3.9).
+ """
+
+ loop = asyncio.get_running_loop()
+ ctx = contextvars.copy_context()
+ func_call = functools.partial(ctx.run, func, *args, **kwargs)
+ return await loop.run_in_executor(None, func_call)
+
+
+def proto_duration_from_timedelta(t: Optional[datetime.timedelta]) -> Optional[ProtoDuration]:
+    if t is None:
+        return None
+    res = ProtoDuration()
+    res.FromTimedelta(t)
+    return res
+
+
+def proto_timestamp_from_datetime(t: Optional[datetime.datetime]) -> Optional[ProtoTimeStamp]:
+    if t is None:
+        return None
+
+    res = ProtoTimeStamp()
+    res.FromDatetime(t)
+    return res
+
+
+def datetime_from_proto_timestamp(
+ ts: Optional[ProtoTimeStamp],
+) -> Optional[datetime.datetime]:
+ if ts is None:
+ return None
+ return ts.ToDatetime()
+
+
+def timedelta_from_proto_duration(
+ d: Optional[ProtoDuration],
+) -> Optional[datetime.timedelta]:
+ if d is None:
+ return None
+ return d.ToTimedelta()
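to_thread is what lets the wrapper drive a synchronous ydb.Driver from the
asyncio event loop (see _start_sync_driver above). A self-contained sketch of
the backport in action:

    import asyncio

    def blocking_call():
        return 42  # stands in for a blocking gRPC invocation

    async def main():
        assert await to_thread(blocking_call) == 42

    asyncio.run(main())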
diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_scheme.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_scheme.py
new file mode 100644
index 0000000000..b992203570
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_scheme.py
@@ -0,0 +1,36 @@
+import datetime
+import enum
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class Entry:
+ name: str
+ owner: str
+ type: "Entry.Type"
+ effective_permissions: "Permissions"
+ permissions: "Permissions"
+ size_bytes: int
+ created_at: datetime.datetime
+
+ class Type(enum.IntEnum):
+ UNSPECIFIED = 0
+ DIRECTORY = 1
+ TABLE = 2
+ PERS_QUEUE_GROUP = 3
+ DATABASE = 4
+ RTMR_VOLUME = 5
+ BLOCK_STORE_VOLUME = 6
+ COORDINATION_NODE = 7
+ COLUMN_STORE = 12
+ COLUMN_TABLE = 13
+ SEQUENCE = 15
+ REPLICATION = 16
+ TOPIC = 17
+
+
+@dataclass
+class Permissions:
+ subject: str
+ permission_names: List[str]
diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic.py
new file mode 100644
index 0000000000..4784d4866b
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic.py
@@ -0,0 +1,1164 @@
+import datetime
+import enum
+import typing
+from dataclasses import dataclass, field
+from typing import List, Union, Dict, Optional
+
+from google.protobuf.message import Message
+
+from . import ydb_topic_public_types
+from ... import scheme
+
+# Workaround for good IDE and universal for runtime
+if typing.TYPE_CHECKING:
+ from ..v4.protos import ydb_scheme_pb2, ydb_topic_pb2
+else:
+ from ..common.protos import ydb_scheme_pb2, ydb_topic_pb2
+
+from .common_utils import (
+ IFromProto,
+ IFromProtoWithProtoType,
+ IToProto,
+ IToPublic,
+ IFromPublic,
+ ServerStatus,
+ UnknownGrpcMessageError,
+ proto_duration_from_timedelta,
+ proto_timestamp_from_datetime,
+ datetime_from_proto_timestamp,
+ timedelta_from_proto_duration,
+)
+
+
+class Codec(int, IToPublic):
+ CODEC_UNSPECIFIED = 0
+ CODEC_RAW = 1
+ CODEC_GZIP = 2
+ CODEC_LZOP = 3
+ CODEC_ZSTD = 4
+
+ @staticmethod
+ def from_proto_iterable(codecs: typing.Iterable[int]) -> List["Codec"]:
+ return [Codec(int(codec)) for codec in codecs]
+
+ def to_public(self) -> ydb_topic_public_types.PublicCodec:
+ return ydb_topic_public_types.PublicCodec(int(self))
+
+
+@dataclass
+class SupportedCodecs(IToProto, IFromProto, IToPublic):
+ codecs: List[Codec]
+
+ def to_proto(self) -> ydb_topic_pb2.SupportedCodecs:
+ return ydb_topic_pb2.SupportedCodecs(
+ codecs=self.codecs,
+ )
+
+ @staticmethod
+ def from_proto(msg: Optional[ydb_topic_pb2.SupportedCodecs]) -> "SupportedCodecs":
+ if msg is None:
+ return SupportedCodecs(codecs=[])
+
+ return SupportedCodecs(
+ codecs=Codec.from_proto_iterable(msg.codecs),
+ )
+
+ def to_public(self) -> List[ydb_topic_public_types.PublicCodec]:
+ return list(map(Codec.to_public, self.codecs))
+
+
+@dataclass(order=True)
+class OffsetsRange(IFromProto, IToProto):
+ """
+    Half-open interval of offsets: [start, end).
+ """
+
+ __slots__ = ("start", "end")
+
+    start: int  # first offset in the range
+    end: int  # offset after the last one; not included in the range
+
+ def __post_init__(self):
+ if self.end < self.start:
+            raise ValueError(
+                "offset end must not be less than start. Got start=%s end=%s"
+                % (self.start, self.end)
+            )
+
+ @staticmethod
+ def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange":
+ return OffsetsRange(
+ start=msg.start,
+ end=msg.end,
+ )
+
+ def to_proto(self) -> ydb_topic_pb2.OffsetsRange:
+ return ydb_topic_pb2.OffsetsRange(
+ start=self.start,
+ end=self.end,
+ )
+
+ def is_intersected_with(self, other: "OffsetsRange") -> bool:
+ return (
+ self.start <= other.start < self.end
+ or self.start < other.end <= self.end
+ or other.start <= self.start < other.end
+ or other.start < self.end <= other.end
+ )
+
+
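Because end is exclusive, adjacent ranges do not intersect. A quick sketch of
is_intersected_with:

    assert not OffsetsRange(0, 5).is_intersected_with(OffsetsRange(5, 10))
    assert OffsetsRange(0, 5).is_intersected_with(OffsetsRange(4, 6))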
+@dataclass
+class UpdateTokenRequest(IToProto):
+ token: str
+
+ def to_proto(self) -> Message:
+ res = ydb_topic_pb2.UpdateTokenRequest()
+ res.token = self.token
+ return res
+
+
+@dataclass
+class UpdateTokenResponse(IFromProto):
+ @staticmethod
+ def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any:
+ return UpdateTokenResponse()
+
+
+########################################################################################################################
+# StreamWrite
+########################################################################################################################
+
+
+class StreamWriteMessage:
+ @dataclass()
+ class InitRequest(IToProto):
+ path: str
+ producer_id: str
+ write_session_meta: typing.Dict[str, str]
+ partitioning: "StreamWriteMessage.PartitioningType"
+ get_last_seq_no: bool
+
+ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.InitRequest:
+ proto = ydb_topic_pb2.StreamWriteMessage.InitRequest()
+ proto.path = self.path
+ proto.producer_id = self.producer_id
+
+ if self.partitioning is None:
+ pass
+ elif isinstance(
+ self.partitioning, StreamWriteMessage.PartitioningMessageGroupID
+ ):
+ proto.message_group_id = self.partitioning.message_group_id
+ elif isinstance(
+ self.partitioning, StreamWriteMessage.PartitioningPartitionID
+ ):
+ proto.partition_id = self.partitioning.partition_id
+ else:
+ raise Exception(
+ "Bad partitioning type at StreamWriteMessage.InitRequest"
+ )
+
+ if self.write_session_meta:
+ for key in self.write_session_meta:
+ proto.write_session_meta[key] = self.write_session_meta[key]
+
+ proto.get_last_seq_no = self.get_last_seq_no
+ return proto
+
+ @dataclass
+ class InitResponse(IFromProto):
+ last_seq_no: Union[int, None]
+ session_id: str
+ partition_id: int
+ supported_codecs: typing.List[int]
+        status: Optional[ServerStatus] = None
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamWriteMessage.InitResponse,
+ ) -> "StreamWriteMessage.InitResponse":
+ codecs = [] # type: typing.List[int]
+ if msg.supported_codecs:
+ for codec in msg.supported_codecs.codecs:
+ codecs.append(codec)
+
+ return StreamWriteMessage.InitResponse(
+ last_seq_no=msg.last_seq_no,
+ session_id=msg.session_id,
+ partition_id=msg.partition_id,
+ supported_codecs=codecs,
+ )
+
+ @dataclass
+ class WriteRequest(IToProto):
+ messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"]
+ codec: int
+
+ @dataclass
+ class MessageData(IToProto):
+ seq_no: int
+ created_at: datetime.datetime
+ data: bytes
+ uncompressed_size: int
+ partitioning: "StreamWriteMessage.PartitioningType"
+
+ def to_proto(
+ self,
+ ) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData:
+ proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData()
+ proto.seq_no = self.seq_no
+ proto.created_at.FromDatetime(self.created_at)
+ proto.data = self.data
+ proto.uncompressed_size = self.uncompressed_size
+
+ if self.partitioning is None:
+ pass
+ elif isinstance(
+ self.partitioning, StreamWriteMessage.PartitioningPartitionID
+ ):
+ proto.partition_id = self.partitioning.partition_id
+ elif isinstance(
+ self.partitioning, StreamWriteMessage.PartitioningMessageGroupID
+ ):
+ proto.message_group_id = self.partitioning.message_group_id
+ else:
+                    raise Exception(
+                        "Bad partitioning type at StreamWriteMessage.WriteRequest.MessageData"
+                    )
+
+ return proto
+
+ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest:
+ proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest()
+ proto.codec = self.codec
+
+ for message in self.messages:
+ proto_mess = proto.messages.add()
+ proto_mess.CopyFrom(message.to_proto())
+
+ return proto
+
+ @dataclass
+ class WriteResponse(IFromProto):
+ partition_id: int
+ acks: typing.List["StreamWriteMessage.WriteResponse.WriteAck"]
+ write_statistics: "StreamWriteMessage.WriteResponse.WriteStatistics"
+ status: Optional[ServerStatus] = field(default=None)
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamWriteMessage.WriteResponse,
+ ) -> "StreamWriteMessage.WriteResponse":
+ acks = []
+ for proto_ack in msg.acks:
+ ack = StreamWriteMessage.WriteResponse.WriteAck.from_proto(proto_ack)
+ acks.append(ack)
+ write_statistics = StreamWriteMessage.WriteResponse.WriteStatistics(
+ persisting_time=msg.write_statistics.persisting_time.ToTimedelta(),
+ min_queue_wait_time=msg.write_statistics.min_queue_wait_time.ToTimedelta(),
+ max_queue_wait_time=msg.write_statistics.max_queue_wait_time.ToTimedelta(),
+ partition_quota_wait_time=msg.write_statistics.partition_quota_wait_time.ToTimedelta(),
+ topic_quota_wait_time=msg.write_statistics.topic_quota_wait_time.ToTimedelta(),
+ )
+ return StreamWriteMessage.WriteResponse(
+ partition_id=msg.partition_id,
+ acks=acks,
+ write_statistics=write_statistics,
+ status=None,
+ )
+
+ @dataclass
+ class WriteAck(IFromProto):
+ seq_no: int
+ message_write_status: Union[
+ "StreamWriteMessage.WriteResponse.WriteAck.StatusWritten",
+ "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped",
+ int,
+ ]
+
+ @classmethod
+ def from_proto(
+ cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.WriteAck
+ ):
+ if proto_ack.HasField("written"):
+ message_write_status = (
+ StreamWriteMessage.WriteResponse.WriteAck.StatusWritten(
+ proto_ack.written.offset
+ )
+ )
+ elif proto_ack.HasField("skipped"):
+ reason = proto_ack.skipped.reason
+ try:
+ message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped(
+ reason=StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason.from_protobuf_code(
+ reason
+ )
+ )
+ except ValueError:
+ message_write_status = reason
+ else:
+ raise NotImplementedError("unexpected ack status")
+
+ return StreamWriteMessage.WriteResponse.WriteAck(
+ seq_no=proto_ack.seq_no,
+ message_write_status=message_write_status,
+ )
+
+ @dataclass
+ class StatusWritten:
+ offset: int
+
+ @dataclass
+ class StatusSkipped:
+ reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason"
+
+ class Reason(enum.Enum):
+ UNSPECIFIED = 0
+ ALREADY_WRITTEN = 1
+
+ @classmethod
+ def from_protobuf_code(
+ cls, code: int
+ ) -> Union[
+ "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason",
+ int,
+ ]:
+ try:
+ return StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason(
+ code
+ )
+ except ValueError:
+ return code
+
+ @dataclass
+ class WriteStatistics:
+ persisting_time: datetime.timedelta
+ min_queue_wait_time: datetime.timedelta
+ max_queue_wait_time: datetime.timedelta
+ partition_quota_wait_time: datetime.timedelta
+ topic_quota_wait_time: datetime.timedelta
+
+ @dataclass
+ class PartitioningMessageGroupID:
+ message_group_id: str
+
+ @dataclass
+ class PartitioningPartitionID:
+ partition_id: int
+
+ PartitioningType = Union[PartitioningMessageGroupID, PartitioningPartitionID, None]
+
+ @dataclass
+ class FromClient(IToProto):
+ value: "WriterMessagesFromClientToServer"
+
+ def __init__(self, value: "WriterMessagesFromClientToServer"):
+ self.value = value
+
+ def to_proto(self) -> Message:
+ res = ydb_topic_pb2.StreamWriteMessage.FromClient()
+ value = self.value
+ if isinstance(value, StreamWriteMessage.WriteRequest):
+ res.write_request.CopyFrom(value.to_proto())
+ elif isinstance(value, StreamWriteMessage.InitRequest):
+ res.init_request.CopyFrom(value.to_proto())
+ elif isinstance(value, UpdateTokenRequest):
+ res.update_token_request.CopyFrom(value.to_proto())
+ else:
+            raise Exception("Unknown outgoing grpc message: %s" % value)
+ return res
+
+ class FromServer(IFromProto):
+ @staticmethod
+ def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.FromServer) -> typing.Any:
+ message_type = msg.WhichOneof("server_message")
+ if message_type == "write_response":
+ res = StreamWriteMessage.WriteResponse.from_proto(msg.write_response)
+ elif message_type == "init_response":
+ res = StreamWriteMessage.InitResponse.from_proto(msg.init_response)
+ elif message_type == "update_token_response":
+ res = UpdateTokenResponse.from_proto(msg.update_token_response)
+ else:
+            # TODO: log instead of raising, to allow new server message types in the future
+ raise UnknownGrpcMessageError("Unexpected proto message: %s" % msg)
+
+ res.status = ServerStatus(msg.status, msg.issues)
+ return res
+
+
+WriterMessagesFromClientToServer = Union[
+ StreamWriteMessage.InitRequest, StreamWriteMessage.WriteRequest, UpdateTokenRequest
+]
+WriterMessagesFromServerToClient = Union[
+ StreamWriteMessage.InitResponse,
+ StreamWriteMessage.WriteResponse,
+ UpdateTokenResponse,
+]
+
+
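A short sketch (hypothetical topic path and producer id) of wrapping a client
message for the write stream; FromClient.to_proto dispatches on the wrapped
message type:

    init = StreamWriteMessage.InitRequest(
        path="/local/my-topic",
        producer_id="producer-1",
        write_session_meta={},
        partitioning=StreamWriteMessage.PartitioningMessageGroupID("group-1"),
        get_last_seq_no=True,
    )
    proto = StreamWriteMessage.FromClient(init).to_proto()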
+########################################################################################################################
+# StreamRead
+########################################################################################################################
+
+
+class StreamReadMessage:
+ @dataclass
+ class PartitionSession(IFromProto):
+ partition_session_id: int
+ path: str
+ partition_id: int
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.PartitionSession,
+ ) -> "StreamReadMessage.PartitionSession":
+ return StreamReadMessage.PartitionSession(
+ partition_session_id=msg.partition_session_id,
+ path=msg.path,
+ partition_id=msg.partition_id,
+ )
+
+ @dataclass
+ class InitRequest(IToProto):
+ topics_read_settings: List["StreamReadMessage.InitRequest.TopicReadSettings"]
+ consumer: str
+
+ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.InitRequest:
+ res = ydb_topic_pb2.StreamReadMessage.InitRequest()
+ res.consumer = self.consumer
+ for settings in self.topics_read_settings:
+ res.topics_read_settings.append(settings.to_proto())
+ return res
+
+ @dataclass
+ class TopicReadSettings(IToProto):
+ path: str
+ partition_ids: List[int] = field(default_factory=list)
+ max_lag_seconds: Union[datetime.timedelta, None] = None
+ read_from: Union[int, float, datetime.datetime, None] = None
+
+ def to_proto(
+ self,
+ ) -> ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings:
+ res = ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings()
+ res.path = self.path
+ res.partition_ids.extend(self.partition_ids)
+ if self.max_lag_seconds is not None:
+            # protobuf composite fields cannot be assigned directly
+            res.max_lag.CopyFrom(proto_duration_from_timedelta(self.max_lag_seconds))
+ return res
+
+ @dataclass
+ class InitResponse(IFromProto):
+ session_id: str
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.InitResponse,
+ ) -> "StreamReadMessage.InitResponse":
+ return StreamReadMessage.InitResponse(session_id=msg.session_id)
+
+ @dataclass
+ class ReadRequest(IToProto):
+ bytes_size: int
+
+ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.ReadRequest:
+ res = ydb_topic_pb2.StreamReadMessage.ReadRequest()
+ res.bytes_size = self.bytes_size
+ return res
+
+ @dataclass
+ class ReadResponse(IFromProto):
+ partition_data: List["StreamReadMessage.ReadResponse.PartitionData"]
+ bytes_size: int
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.ReadResponse,
+ ) -> "StreamReadMessage.ReadResponse":
+ partition_data = []
+ for proto_partition_data in msg.partition_data:
+ partition_data.append(
+ StreamReadMessage.ReadResponse.PartitionData.from_proto(
+ proto_partition_data
+ )
+ )
+ return StreamReadMessage.ReadResponse(
+ partition_data=partition_data,
+ bytes_size=msg.bytes_size,
+ )
+
+ @dataclass
+ class MessageData(IFromProto):
+ offset: int
+ seq_no: int
+ created_at: datetime.datetime
+ data: bytes
+ uncompresed_size: int
+ message_group_id: str
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData,
+ ) -> "StreamReadMessage.ReadResponse.MessageData":
+ return StreamReadMessage.ReadResponse.MessageData(
+ offset=msg.offset,
+ seq_no=msg.seq_no,
+ created_at=msg.created_at.ToDatetime(),
+ data=msg.data,
+ uncompresed_size=msg.uncompressed_size,
+ message_group_id=msg.message_group_id,
+ )
+
+ @dataclass
+ class Batch(IFromProto):
+ message_data: List["StreamReadMessage.ReadResponse.MessageData"]
+ producer_id: str
+ write_session_meta: Dict[str, str]
+ codec: int
+ written_at: datetime.datetime
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch,
+ ) -> "StreamReadMessage.ReadResponse.Batch":
+ message_data = []
+ for message in msg.message_data:
+ message_data.append(
+ StreamReadMessage.ReadResponse.MessageData.from_proto(message)
+ )
+ return StreamReadMessage.ReadResponse.Batch(
+ message_data=message_data,
+ producer_id=msg.producer_id,
+ write_session_meta=dict(msg.write_session_meta),
+ codec=msg.codec,
+ written_at=msg.written_at.ToDatetime(),
+ )
+
+ @dataclass
+ class PartitionData(IFromProto):
+ partition_session_id: int
+ batches: List["StreamReadMessage.ReadResponse.Batch"]
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData,
+ ) -> "StreamReadMessage.ReadResponse.PartitionData":
+ batches = []
+ for proto_batch in msg.batches:
+ batches.append(
+ StreamReadMessage.ReadResponse.Batch.from_proto(proto_batch)
+ )
+ return StreamReadMessage.ReadResponse.PartitionData(
+ partition_session_id=msg.partition_session_id,
+ batches=batches,
+ )
+
+ @dataclass
+ class CommitOffsetRequest(IToProto):
+ commit_offsets: List["PartitionCommitOffset"]
+
+ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest:
+ res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest(
+ commit_offsets=list(
+ map(
+ StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset.to_proto,
+ self.commit_offsets,
+ )
+ ),
+ )
+ return res
+
+ @dataclass
+ class PartitionCommitOffset(IToProto):
+ partition_session_id: int
+ offsets: List["OffsetsRange"]
+
+ def to_proto(
+ self,
+ ) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset:
+ res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset(
+ partition_session_id=self.partition_session_id,
+ offsets=list(map(OffsetsRange.to_proto, self.offsets)),
+ )
+ return res
+
+ @dataclass
+ class CommitOffsetResponse(IFromProto):
+ partitions_committed_offsets: List[
+ "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset"
+ ]
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse,
+ ) -> "StreamReadMessage.CommitOffsetResponse":
+ return StreamReadMessage.CommitOffsetResponse(
+ partitions_committed_offsets=list(
+ map(
+ StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset.from_proto,
+ msg.partitions_committed_offsets,
+ )
+ )
+ )
+
+ @dataclass
+ class PartitionCommittedOffset(IFromProto):
+ partition_session_id: int
+ committed_offset: int
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset,
+ ) -> "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset":
+ return StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset(
+ partition_session_id=msg.partition_session_id,
+ committed_offset=msg.committed_offset,
+ )
+
+ @dataclass
+ class PartitionSessionStatusRequest:
+ partition_session_id: int
+
+ @dataclass
+ class PartitionSessionStatusResponse:
+ partition_session_id: int
+ partition_offsets: "OffsetsRange"
+ committed_offset: int
+ write_time_high_watermark: float
+
+ @dataclass
+ class StartPartitionSessionRequest(IFromProto):
+ partition_session: "StreamReadMessage.PartitionSession"
+ committed_offset: int
+ partition_offsets: "OffsetsRange"
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest,
+ ) -> "StreamReadMessage.StartPartitionSessionRequest":
+ return StreamReadMessage.StartPartitionSessionRequest(
+ partition_session=StreamReadMessage.PartitionSession.from_proto(
+ msg.partition_session
+ ),
+ committed_offset=msg.committed_offset,
+ partition_offsets=OffsetsRange.from_proto(msg.partition_offsets),
+ )
+
+ @dataclass
+ class StartPartitionSessionResponse(IToProto):
+ partition_session_id: int
+ read_offset: Optional[int]
+ commit_offset: Optional[int]
+
+ def to_proto(
+ self,
+ ) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse:
+ res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse()
+ res.partition_session_id = self.partition_session_id
+ if self.read_offset is not None:
+ res.read_offset = self.read_offset
+ if self.commit_offset is not None:
+ res.commit_offset = self.commit_offset
+ return res
+
+ @dataclass
+ class StopPartitionSessionRequest:
+ partition_session_id: int
+ graceful: bool
+ committed_offset: int
+
+ @dataclass
+ class StopPartitionSessionResponse:
+ partition_session_id: int
+
+ @dataclass
+ class FromClient(IToProto):
+ client_message: "ReaderMessagesFromClientToServer"
+
+ def __init__(self, client_message: "ReaderMessagesFromClientToServer"):
+ self.client_message = client_message
+
+ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient:
+ res = ydb_topic_pb2.StreamReadMessage.FromClient()
+ if isinstance(self.client_message, StreamReadMessage.ReadRequest):
+ res.read_request.CopyFrom(self.client_message.to_proto())
+ elif isinstance(self.client_message, StreamReadMessage.CommitOffsetRequest):
+ res.commit_offset_request.CopyFrom(self.client_message.to_proto())
+ elif isinstance(self.client_message, StreamReadMessage.InitRequest):
+ res.init_request.CopyFrom(self.client_message.to_proto())
+ elif isinstance(self.client_message, UpdateTokenRequest):
+ res.update_token_request.CopyFrom(self.client_message.to_proto())
+ elif isinstance(
+ self.client_message, StreamReadMessage.StartPartitionSessionResponse
+ ):
+ res.start_partition_session_response.CopyFrom(
+ self.client_message.to_proto()
+ )
+ else:
+ raise NotImplementedError(
+ "Unknown message type: %s" % type(self.client_message)
+ )
+ return res
+
+ @dataclass
+ class FromServer(IFromProto):
+ server_message: "ReaderMessagesFromServerToClient"
+ server_status: ServerStatus
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.StreamReadMessage.FromServer,
+ ) -> "StreamReadMessage.FromServer":
+ mess_type = msg.WhichOneof("server_message")
+ server_status = ServerStatus.from_proto(msg)
+ if mess_type == "read_response":
+ return StreamReadMessage.FromServer(
+ server_status=server_status,
+ server_message=StreamReadMessage.ReadResponse.from_proto(
+ msg.read_response
+ ),
+ )
+ elif mess_type == "commit_offset_response":
+ return StreamReadMessage.FromServer(
+ server_status=server_status,
+ server_message=StreamReadMessage.CommitOffsetResponse.from_proto(
+ msg.commit_offset_response
+ ),
+ )
+ elif mess_type == "init_response":
+ return StreamReadMessage.FromServer(
+ server_status=server_status,
+ server_message=StreamReadMessage.InitResponse.from_proto(
+ msg.init_response
+ ),
+ )
+ elif mess_type == "start_partition_session_request":
+ return StreamReadMessage.FromServer(
+ server_status=server_status,
+ server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto(
+ msg.start_partition_session_request
+ ),
+ )
+ elif mess_type == "update_token_response":
+ return StreamReadMessage.FromServer(
+ server_status=server_status,
+ server_message=UpdateTokenResponse.from_proto(
+ msg.update_token_response
+ ),
+ )
+
+        # TODO: replace this exception with logging
+ raise NotImplementedError()
+
+
+ReaderMessagesFromClientToServer = Union[
+ StreamReadMessage.InitRequest,
+ StreamReadMessage.ReadRequest,
+ StreamReadMessage.CommitOffsetRequest,
+ StreamReadMessage.PartitionSessionStatusRequest,
+ UpdateTokenRequest,
+ StreamReadMessage.StartPartitionSessionResponse,
+ StreamReadMessage.StopPartitionSessionResponse,
+]
+
+ReaderMessagesFromServerToClient = Union[
+ StreamReadMessage.InitResponse,
+ StreamReadMessage.ReadResponse,
+ StreamReadMessage.CommitOffsetResponse,
+ StreamReadMessage.PartitionSessionStatusResponse,
+ UpdateTokenResponse,
+ StreamReadMessage.StartPartitionSessionRequest,
+ StreamReadMessage.StopPartitionSessionRequest,
+]
+
+
+@dataclass
+class MultipleWindowsStat(IFromProto, IToPublic):
+ per_minute: int
+ per_hour: int
+ per_day: int
+
+ @staticmethod
+ def from_proto(
+ msg: Optional[ydb_topic_pb2.MultipleWindowsStat],
+ ) -> Optional["MultipleWindowsStat"]:
+ if msg is None:
+ return None
+ return MultipleWindowsStat(
+ per_minute=msg.per_minute,
+ per_hour=msg.per_hour,
+ per_day=msg.per_day,
+ )
+
+ def to_public(self) -> ydb_topic_public_types.PublicMultipleWindowsStat:
+ return ydb_topic_public_types.PublicMultipleWindowsStat(
+ per_minute=self.per_minute,
+ per_hour=self.per_hour,
+ per_day=self.per_day,
+ )
+
+
+@dataclass
+class Consumer(IToProto, IFromProto, IFromPublic, IToPublic):
+ name: str
+ important: bool
+ read_from: typing.Optional[datetime.datetime]
+ supported_codecs: SupportedCodecs
+ attributes: Dict[str, str]
+ consumer_stats: typing.Optional["Consumer.ConsumerStats"]
+
+ def to_proto(self) -> ydb_topic_pb2.Consumer:
+ return ydb_topic_pb2.Consumer(
+ name=self.name,
+ important=self.important,
+ read_from=proto_timestamp_from_datetime(self.read_from),
+ supported_codecs=self.supported_codecs.to_proto(),
+ attributes=self.attributes,
+ # consumer_stats - readonly field
+ )
+
+ @staticmethod
+    def from_proto(msg: Optional[ydb_topic_pb2.Consumer]) -> Optional["Consumer"]:
+        if msg is None:
+            return None
+
+        return Consumer(
+ name=msg.name,
+ important=msg.important,
+ read_from=datetime_from_proto_timestamp(msg.read_from),
+ supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs),
+ attributes=dict(msg.attributes),
+ consumer_stats=Consumer.ConsumerStats.from_proto(msg.consumer_stats),
+ )
+
+ @staticmethod
+ def from_public(consumer: ydb_topic_public_types.PublicConsumer):
+ if consumer is None:
+ return None
+
+ supported_codecs = []
+ if consumer.supported_codecs is not None:
+ supported_codecs = consumer.supported_codecs
+
+ return Consumer(
+ name=consumer.name,
+ important=consumer.important,
+ read_from=consumer.read_from,
+ supported_codecs=SupportedCodecs(codecs=supported_codecs),
+ attributes=consumer.attributes,
+ consumer_stats=None,
+ )
+
+ def to_public(self) -> ydb_topic_public_types.PublicConsumer:
+ return ydb_topic_public_types.PublicConsumer(
+ name=self.name,
+ important=self.important,
+ read_from=self.read_from,
+ supported_codecs=self.supported_codecs.to_public(),
+ attributes=self.attributes,
+ )
+
+ @dataclass
+ class ConsumerStats(IFromProto):
+ min_partitions_last_read_time: datetime.datetime
+ max_read_time_lag: datetime.timedelta
+ max_write_time_lag: datetime.timedelta
+ bytes_read: MultipleWindowsStat
+
+ @staticmethod
+ def from_proto(
+ msg: ydb_topic_pb2.Consumer.ConsumerStats,
+ ) -> "Consumer.ConsumerStats":
+ return Consumer.ConsumerStats(
+ min_partitions_last_read_time=datetime_from_proto_timestamp(
+ msg.min_partitions_last_read_time
+ ),
+ max_read_time_lag=timedelta_from_proto_duration(msg.max_read_time_lag),
+ max_write_time_lag=timedelta_from_proto_duration(
+ msg.max_write_time_lag
+ ),
+ bytes_read=MultipleWindowsStat.from_proto(msg.bytes_read),
+ )
+
+
+@dataclass
+class PartitioningSettings(IToProto, IFromProto):
+ min_active_partitions: int
+ partition_count_limit: int
+
+ @staticmethod
+ def from_proto(msg: ydb_topic_pb2.PartitioningSettings) -> "PartitioningSettings":
+ return PartitioningSettings(
+ min_active_partitions=msg.min_active_partitions,
+ partition_count_limit=msg.partition_count_limit,
+ )
+
+ def to_proto(self) -> ydb_topic_pb2.PartitioningSettings:
+ return ydb_topic_pb2.PartitioningSettings(
+ min_active_partitions=self.min_active_partitions,
+ partition_count_limit=self.partition_count_limit,
+ )
+
+
+class MeteringMode(int, IFromProto, IFromPublic, IToPublic):
+ UNSPECIFIED = 0
+ RESERVED_CAPACITY = 1
+ REQUEST_UNITS = 2
+
+ @staticmethod
+ def from_public(
+ m: Optional[ydb_topic_public_types.PublicMeteringMode],
+ ) -> Optional["MeteringMode"]:
+ if m is None:
+ return None
+
+ return MeteringMode(m)
+
+ @staticmethod
+ def from_proto(code: Optional[int]) -> Optional["MeteringMode"]:
+ if code is None:
+ return None
+
+ return MeteringMode(code)
+
+    def to_public(self) -> ydb_topic_public_types.PublicMeteringMode:
+        try:
+            return ydb_topic_public_types.PublicMeteringMode(int(self))
+        except ValueError:
+            return ydb_topic_public_types.PublicMeteringMode.UNSPECIFIED
+
+
+@dataclass
+class CreateTopicRequest(IToProto, IFromPublic):
+ path: str
+ partitioning_settings: "PartitioningSettings"
+ retention_period: typing.Optional[datetime.timedelta]
+ retention_storage_mb: typing.Optional[int]
+ supported_codecs: "SupportedCodecs"
+ partition_write_speed_bytes_per_second: typing.Optional[int]
+ partition_write_burst_bytes: typing.Optional[int]
+ attributes: Dict[str, str]
+ consumers: List["Consumer"]
+ metering_mode: "MeteringMode"
+
+ def to_proto(self) -> ydb_topic_pb2.CreateTopicRequest:
+ return ydb_topic_pb2.CreateTopicRequest(
+ path=self.path,
+ partitioning_settings=self.partitioning_settings.to_proto(),
+ retention_period=proto_duration_from_timedelta(self.retention_period),
+ retention_storage_mb=self.retention_storage_mb,
+ supported_codecs=self.supported_codecs.to_proto(),
+ partition_write_speed_bytes_per_second=self.partition_write_speed_bytes_per_second,
+ partition_write_burst_bytes=self.partition_write_burst_bytes,
+ attributes=self.attributes,
+ consumers=[consumer.to_proto() for consumer in self.consumers],
+ metering_mode=self.metering_mode,
+ )
+
+ @staticmethod
+ def from_public(req: ydb_topic_public_types.CreateTopicRequestParams):
+ supported_codecs = []
+
+ if req.supported_codecs is not None:
+ supported_codecs = req.supported_codecs
+
+ consumers = []
+ if req.consumers is not None:
+ for consumer in req.consumers:
+ if isinstance(consumer, str):
+ consumer = ydb_topic_public_types.PublicConsumer(name=consumer)
+ consumers.append(Consumer.from_public(consumer))
+
+ return CreateTopicRequest(
+ path=req.path,
+ partitioning_settings=PartitioningSettings(
+ min_active_partitions=req.min_active_partitions,
+ partition_count_limit=req.partition_count_limit,
+ ),
+ retention_period=req.retention_period,
+ retention_storage_mb=req.retention_storage_mb,
+ supported_codecs=SupportedCodecs(
+ codecs=supported_codecs,
+ ),
+ partition_write_speed_bytes_per_second=req.partition_write_speed_bytes_per_second,
+ partition_write_burst_bytes=req.partition_write_burst_bytes,
+ attributes=req.attributes,
+ consumers=consumers,
+ metering_mode=MeteringMode.from_public(req.metering_mode),
+ )
+
+
+@dataclass
+class CreateTopicResult:
+ pass
+
+
+@dataclass
+class DescribeTopicRequest:
+ path: str
+ include_stats: bool
+
+
+@dataclass
+class DescribeTopicResult(IFromProtoWithProtoType, IToPublic):
+ self_proto: ydb_scheme_pb2.Entry
+ partitioning_settings: PartitioningSettings
+ partitions: List["DescribeTopicResult.PartitionInfo"]
+ retention_period: datetime.timedelta
+ retention_storage_mb: int
+ supported_codecs: SupportedCodecs
+ partition_write_speed_bytes_per_second: int
+ partition_write_burst_bytes: int
+ attributes: Dict[str, str]
+ consumers: List["Consumer"]
+ metering_mode: MeteringMode
+ topic_stats: "DescribeTopicResult.TopicStats"
+
+ @staticmethod
+ def from_proto(msg: ydb_topic_pb2.DescribeTopicResult) -> "DescribeTopicResult":
+ return DescribeTopicResult(
+ self_proto=msg.self,
+ partitioning_settings=PartitioningSettings.from_proto(
+ msg.partitioning_settings
+ ),
+ partitions=list(
+ map(DescribeTopicResult.PartitionInfo.from_proto, msg.partitions)
+ ),
+            retention_period=timedelta_from_proto_duration(msg.retention_period),
+ retention_storage_mb=msg.retention_storage_mb,
+ supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs),
+ partition_write_speed_bytes_per_second=msg.partition_write_speed_bytes_per_second,
+ partition_write_burst_bytes=msg.partition_write_burst_bytes,
+ attributes=dict(msg.attributes),
+ consumers=list(map(Consumer.from_proto, msg.consumers)),
+ metering_mode=MeteringMode.from_proto(msg.metering_mode),
+ topic_stats=DescribeTopicResult.TopicStats.from_proto(msg.topic_stats),
+ )
+
+ @staticmethod
+ def empty_proto_message() -> ydb_topic_pb2.DescribeTopicResult:
+ return ydb_topic_pb2.DescribeTopicResult()
+
+ def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult:
+ return ydb_topic_public_types.PublicDescribeTopicResult(
+ self=scheme._wrap_scheme_entry(self.self_proto),
+ min_active_partitions=self.partitioning_settings.min_active_partitions,
+ partition_count_limit=self.partitioning_settings.partition_count_limit,
+ partitions=list(
+ map(DescribeTopicResult.PartitionInfo.to_public, self.partitions)
+ ),
+ retention_period=self.retention_period,
+ retention_storage_mb=self.retention_storage_mb,
+ supported_codecs=self.supported_codecs.to_public(),
+ partition_write_speed_bytes_per_second=self.partition_write_speed_bytes_per_second,
+ partition_write_burst_bytes=self.partition_write_burst_bytes,
+ attributes=self.attributes,
+ consumers=list(map(Consumer.to_public, self.consumers)),
+ metering_mode=self.metering_mode.to_public(),
+ topic_stats=self.topic_stats.to_public(),
+ )
+
+ @dataclass
+ class PartitionInfo(IFromProto, IToPublic):
+ partition_id: int
+ active: bool
+ child_partition_ids: List[int]
+ parent_partition_ids: List[int]
+ partition_stats: "PartitionStats"
+
+ @staticmethod
+ def from_proto(
+ msg: Optional[ydb_topic_pb2.DescribeTopicResult.PartitionInfo],
+ ) -> Optional["DescribeTopicResult.PartitionInfo"]:
+ if msg is None:
+ return None
+
+ return DescribeTopicResult.PartitionInfo(
+ partition_id=msg.partition_id,
+ active=msg.active,
+ child_partition_ids=list(msg.child_partition_ids),
+ parent_partition_ids=list(msg.parent_partition_ids),
+ partition_stats=PartitionStats.from_proto(msg.partition_stats),
+ )
+
+ def to_public(
+ self,
+ ) -> ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo:
+ partition_stats = None
+ if self.partition_stats is not None:
+ partition_stats = self.partition_stats.to_public()
+ return ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo(
+ partition_id=self.partition_id,
+ active=self.active,
+ child_partition_ids=self.child_partition_ids,
+ parent_partition_ids=self.parent_partition_ids,
+ partition_stats=partition_stats,
+ )
+
+ @dataclass
+ class TopicStats(IFromProto, IToPublic):
+ store_size_bytes: int
+ min_last_write_time: datetime.datetime
+ max_write_time_lag: datetime.timedelta
+ bytes_written: "MultipleWindowsStat"
+
+ @staticmethod
+ def from_proto(
+ msg: Optional[ydb_topic_pb2.DescribeTopicResult.TopicStats],
+ ) -> Optional["DescribeTopicResult.TopicStats"]:
+ if msg is None:
+ return None
+
+ return DescribeTopicResult.TopicStats(
+ store_size_bytes=msg.store_size_bytes,
+ min_last_write_time=datetime_from_proto_timestamp(
+ msg.min_last_write_time
+ ),
+ max_write_time_lag=timedelta_from_proto_duration(
+ msg.max_write_time_lag
+ ),
+ bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written),
+ )
+
+ def to_public(
+ self,
+ ) -> ydb_topic_public_types.PublicDescribeTopicResult.TopicStats:
+ return ydb_topic_public_types.PublicDescribeTopicResult.TopicStats(
+ store_size_bytes=self.store_size_bytes,
+ min_last_write_time=self.min_last_write_time,
+ max_write_time_lag=self.max_write_time_lag,
+ bytes_written=self.bytes_written.to_public(),
+ )
+
+
+@dataclass
+class PartitionStats(IFromProto, IToPublic):
+ partition_offsets: OffsetsRange
+ store_size_bytes: int
+ last_write_time: datetime.datetime
+ max_write_time_lag: datetime.timedelta
+ bytes_written: "MultipleWindowsStat"
+ partition_node_id: int
+
+ @staticmethod
+ def from_proto(
+ msg: Optional[ydb_topic_pb2.PartitionStats],
+ ) -> Optional["PartitionStats"]:
+ if msg is None:
+ return None
+ return PartitionStats(
+ partition_offsets=OffsetsRange.from_proto(msg.partition_offsets),
+ store_size_bytes=msg.store_size_bytes,
+ last_write_time=datetime_from_proto_timestamp(msg.last_write_time),
+ max_write_time_lag=timedelta_from_proto_duration(msg.max_write_time_lag),
+ bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written),
+ partition_node_id=msg.partition_node_id,
+ )
+
+ def to_public(self) -> ydb_topic_public_types.PublicPartitionStats:
+ return ydb_topic_public_types.PublicPartitionStats(
+ partition_start=self.partition_offsets.start,
+ partition_end=self.partition_offsets.end,
+ store_size_bytes=self.store_size_bytes,
+ last_write_time=self.last_write_time,
+ max_write_time_lag=self.max_write_time_lag,
+ bytes_written=self.bytes_written.to_public(),
+ partition_node_id=self.partition_node_id,
+ )
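+
+# Editor's sketch (illustrative only, not part of the patch): how the wrapper
+# types above compose. Public create-topic params convert to the internal
+# CreateTopicRequest, which then produces the protobuf message. The topic path
+# and consumer name are hypothetical; None fields are assumed to be tolerated
+# by the converters, as when the public client fills defaults.
+#
+#     params = ydb_topic_public_types.CreateTopicRequestParams(
+#         path="/local/my-topic",
+#         min_active_partitions=1,
+#         partition_count_limit=None,
+#         retention_period=None,
+#         retention_storage_mb=None,
+#         supported_codecs=None,
+#         partition_write_speed_bytes_per_second=None,
+#         partition_write_burst_bytes=None,
+#         attributes=None,
+#         consumers=["my-consumer"],
+#         metering_mode=None,
+#     )
+#     request = CreateTopicRequest.from_public(params)
+#     proto_request = request.to_proto()  # ydb_topic_pb2.CreateTopicRequest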
diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py
new file mode 100644
index 0000000000..4582f19a02
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py
@@ -0,0 +1,200 @@
+import datetime
+import typing
+from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import Optional, List, Union, Dict
+
+# Workaround: gives the IDE real types while staying universal at runtime
+if typing.TYPE_CHECKING:
+ from ..v4.protos import ydb_topic_pb2
+else:
+ from ..common.protos import ydb_topic_pb2
+
+from .common_utils import IToProto
+from ...scheme import SchemeEntry
+
+
+@dataclass
+# Keep this struct similar to PublicDescribeTopicResult
+class CreateTopicRequestParams:
+ path: str
+ min_active_partitions: Optional[int]
+ partition_count_limit: Optional[int]
+ retention_period: Optional[datetime.timedelta]
+ retention_storage_mb: Optional[int]
+ supported_codecs: Optional[List[Union["PublicCodec", int]]]
+ partition_write_speed_bytes_per_second: Optional[int]
+ partition_write_burst_bytes: Optional[int]
+ attributes: Optional[Dict[str, str]]
+ consumers: Optional[List[Union["PublicConsumer", str]]]
+ metering_mode: Optional["PublicMeteringMode"]
+
+
+class PublicCodec(int):
+ """
+ Codec value may contain any int number.
+
+ Values below is only well-known predefined values,
+ but protocol support custom codecs.
+ """
+
+ UNSPECIFIED = 0
+ RAW = 1
+ GZIP = 2
+ LZOP = 3 # No codec available in the standard library
+ ZSTD = 4 # No codec available in the standard library
+
+
+class PublicMeteringMode(IntEnum):
+ UNSPECIFIED = 0
+ RESERVED_CAPACITY = 1
+ REQUEST_UNITS = 2
+
+
+@dataclass
+class PublicConsumer:
+ name: str
+ important: bool = False
+ """
+ A consumer may be marked as 'important': messages for this consumer will never expire due to retention.
+ The user should take care that such a consumer never stalls, to prevent running out of disk space.
+ """
+
+ read_from: Optional[datetime.datetime] = None
+ "All messages with smaller server written_at timestamp will be skipped."
+
+ supported_codecs: List[PublicCodec] = field(default_factory=lambda: list())
+ """
+ List of codecs supported by this consumer.
+ The topic's supported_codecs must be a subset of this list.
+ """
+
+ attributes: Dict[str, str] = field(default_factory=lambda: dict())
+ "Attributes of consumer"
+
+
+@dataclass
+class DropTopicRequestParams(IToProto):
+ path: str
+
+ def to_proto(self) -> ydb_topic_pb2.DropTopicRequest:
+ return ydb_topic_pb2.DropTopicRequest(path=self.path)
+
+
+@dataclass
+class DescribeTopicRequestParams(IToProto):
+ path: str
+ include_stats: bool
+
+ def to_proto(self) -> ydb_topic_pb2.DescribeTopicRequest:
+ return ydb_topic_pb2.DescribeTopicRequest(
+ path=self.path, include_stats=self.include_stats
+ )
+
+
+@dataclass
+# Keep this struct similar to CreateTopicRequestParams
+class PublicDescribeTopicResult:
+ self: SchemeEntry
+ "Description of scheme object"
+
+ min_active_partitions: int
+ "Minimum partition count auto merge would stop working at"
+
+ partition_count_limit: int
+ "Limit for total partition count, including active (open for write) and read-only partitions"
+
+ partitions: List["PublicDescribeTopicResult.PartitionInfo"]
+ "Partitions description"
+
+ retention_period: datetime.timedelta
+ "How long data in partition should be stored"
+
+ retention_storage_mb: int
+ "How much data in partition should be stored. Zero value means infinite limit"
+
+ supported_codecs: List[PublicCodec]
+ "List of allowed codecs for writers"
+
+ partition_write_speed_bytes_per_second: int
+ "Partition write speed in bytes per second"
+
+ partition_write_burst_bytes: int
+ "Burst size for write in partition, in bytes"
+
+ attributes: Dict[str, str]
+ """User and server attributes of topic. Server attributes starts from "_" and will be validated by server."""
+
+ consumers: List[PublicConsumer]
+ """List of consumers for this topic"""
+
+ metering_mode: PublicMeteringMode
+ "Metering settings"
+
+ topic_stats: "PublicDescribeTopicResult.TopicStats"
+ "Statistics of topic"
+
+ @dataclass
+ class PartitionInfo:
+ partition_id: int
+ "Partition identifier"
+
+ active: bool
+ "Is partition open for write"
+
+ child_partition_ids: List[int]
+ "Ids of partitions which was formed when this partition was split or merged"
+
+ parent_partition_ids: List[int]
+ "Ids of partitions from which this partition was formed by split or merge"
+
+ partition_stats: Optional["PublicPartitionStats"]
+ "Stats for partition, filled only when include_stats in request is true"
+
+ @dataclass
+ class TopicStats:
+ store_size_bytes: int
+ "Approximate size of topic"
+
+ min_last_write_time: datetime.datetime
+ "Minimum of timestamps of last write among all partitions."
+
+ max_write_time_lag: datetime.timedelta
+ """
+ Maximum difference between the write timestamp and the create timestamp across all messages
+ written during the last minute.
+ """
+
+ bytes_written: "PublicMultipleWindowsStat"
+ "How much bytes were written statistics."
+
+
+@dataclass
+class PublicPartitionStats:
+ partition_start: int
+ "first message offset in the partition"
+
+ partition_end: int
+ "offset after last stored message offset in the partition (last offset + 1)"
+
+ store_size_bytes: int
+ "Approximate size of partition"
+
+ last_write_time: datetime.datetime
+ "Timestamp of last write"
+
+ max_write_time_lag: datetime.timedelta
+ "Maximum of differences between write timestamp and create timestamp for all messages, written during last minute."
+
+ bytes_written: "PublicMultipleWindowsStat"
+ "How much bytes were written during several windows in this partition."
+
+ partition_node_id: int
+ "Host where tablet for this partition works. Useful for debugging purposes."
+
+
+@dataclass
+class PublicMultipleWindowsStat:
+ per_minute: int
+ per_hour: int
+ per_day: int
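+
+# Editor's sketch (illustrative only, not part of the patch): constructing the
+# public types above. The names and path are hypothetical.
+#
+#     consumer = PublicConsumer(
+#         name="my-consumer",
+#         supported_codecs=[PublicCodec.RAW, PublicCodec.GZIP],
+#     )
+#     drop_request = DropTopicRequestParams(path="/local/my-topic").to_proto()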
diff --git a/ydb/public/sdk/python3/ydb/_sp_impl.py b/ydb/public/sdk/python3/ydb/_sp_impl.py
index a8529d7321..5974a3014b 100644
--- a/ydb/public/sdk/python3/ydb/_sp_impl.py
+++ b/ydb/public/sdk/python3/ydb/_sp_impl.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import collections
from concurrent import futures
-from six.moves import queue
+import queue
import time
import threading
from . import settings, issues, _utilities, tracing
diff --git a/ydb/public/sdk/python3/ydb/_topic_common/__init__.py b/ydb/public/sdk/python3/ydb/_topic_common/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_common/__init__.py
diff --git a/ydb/public/sdk/python3/ydb/_topic_common/common.py b/ydb/public/sdk/python3/ydb/_topic_common/common.py
new file mode 100644
index 0000000000..9e8f1326ed
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_common/common.py
@@ -0,0 +1,147 @@
+import asyncio
+import concurrent.futures
+import threading
+import typing
+from typing import Optional
+
+from .. import operation, issues
+from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType
+
+TimeoutType = typing.Union[int, float, None]
+
+
+def wrap_operation(rpc_state, response_pb, driver=None):
+ return operation.Operation(rpc_state, response_pb, driver)
+
+
+ResultType = typing.TypeVar("ResultType", bound=IFromProtoWithProtoType)
+
+
+def create_result_wrapper(
+ result_type: typing.Type[ResultType],
+) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]:
+ def wrapper(rpc_state, response_pb, driver=None):
+ issues._process_response(response_pb.operation)
+ msg = result_type.empty_proto_message()
+ response_pb.operation.result.Unpack(msg)
+ return result_type.from_proto(msg)
+
+ return wrapper
+
+
+_shared_event_loop_lock = threading.Lock()
+_shared_event_loop: Optional[asyncio.AbstractEventLoop] = None
+
+
+def _get_shared_event_loop() -> asyncio.AbstractEventLoop:
+ global _shared_event_loop
+
+ if _shared_event_loop is not None:
+ return _shared_event_loop
+
+ with _shared_event_loop_lock:
+ if _shared_event_loop is not None:
+ return _shared_event_loop
+
+ event_loop_set_done = concurrent.futures.Future()
+
+ def start_event_loop():
+ event_loop = asyncio.new_event_loop()
+ event_loop_set_done.set_result(event_loop)
+ asyncio.set_event_loop(event_loop)
+ event_loop.run_forever()
+
+ t = threading.Thread(
+ target=start_event_loop,
+ name="Common ydb topic event loop",
+ daemon=True,
+ )
+ t.start()
+
+ _shared_event_loop = event_loop_set_done.result()
+ return _shared_event_loop
+
+
+class CallFromSyncToAsync:
+ _loop: asyncio.AbstractEventLoop
+
+ def __init__(self, loop: asyncio.AbstractEventLoop):
+ self._loop = loop
+
+ def unsafe_call_with_future(
+ self, coro: typing.Coroutine
+ ) -> concurrent.futures.Future:
+ """
+ The result returned from coro may be lost.
+ """
+ return asyncio.run_coroutine_threadsafe(coro, self._loop)
+
+ def unsafe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType):
+ """
+ The result returned from coro may be lost due to a race between the future being cancelled on timeout and the coroutine returning a value.
+ """
+ f = self.unsafe_call_with_future(coro)
+ try:
+ return f.result(timeout)
+ except concurrent.futures.TimeoutError:
+ raise TimeoutError()
+ finally:
+ if not f.done():
+ f.cancel()
+
+ def safe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType):
+ """
+ The value returned from coro is never lost, but the call may be slower (especially the timeout latency): it waits for the coroutine to be cancelled.
+ """
+
+ if timeout is not None and timeout <= 0:
+ return self._safe_call_fast(coro)
+
+ async def call_coro():
+ task = self._loop.create_task(coro)
+ try:
+ res = await asyncio.wait_for(task, timeout)
+ return res
+ except asyncio.TimeoutError:
+ try:
+ res = await task
+ return res
+ except asyncio.CancelledError:
+ pass
+
+ # return builtin TimeoutError instead of asyncio.TimeoutError
+ raise TimeoutError()
+
+ return asyncio.run_coroutine_threadsafe(call_coro(), self._loop).result()
+
+ def _safe_call_fast(self, coro: typing.Coroutine):
+ """
+ The value returned from coro is never lost.
+ Waits for the coroutine result for only one loop iteration.
+ """
+ res = concurrent.futures.Future()
+
+ async def call_coro():
+ try:
+ res.set_result(await coro)
+ except asyncio.CancelledError:
+ res.set_exception(TimeoutError())
+
+ coro_future = asyncio.run_coroutine_threadsafe(call_coro(), self._loop)
+ asyncio.run_coroutine_threadsafe(asyncio.sleep(0), self._loop).result()
+ coro_future.cancel()
+ return res.result()
+
+ def call_sync(self, callback: typing.Callable[[], typing.Any]) -> typing.Any:
+ result = concurrent.futures.Future()
+
+ def call_callback():
+ try:
+ res = callback()
+ result.set_result(res)
+ except BaseException as err:
+ result.set_exception(err)
+
+ self._loop.call_soon_threadsafe(call_callback)
+
+ return result.result()
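+
+# Editor's sketch (illustrative only, not part of the patch): bridging sync
+# code to the shared background event loop. safe_call_with_result never loses
+# the coroutine result; it must be called from a thread other than the loop's
+# own (the shared loop runs in its own daemon thread).
+#
+#     caller = CallFromSyncToAsync(_get_shared_event_loop())
+#
+#     async def answer():
+#         return 42
+#
+#     value = caller.safe_call_with_result(answer(), timeout=1.0)  # 42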
diff --git a/ydb/public/sdk/python3/ydb/_topic_common/test_helpers.py b/ydb/public/sdk/python3/ydb/_topic_common/test_helpers.py
new file mode 100644
index 0000000000..96a812ab72
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_common/test_helpers.py
@@ -0,0 +1,76 @@
+import asyncio
+import time
+import typing
+
+from .._grpc.grpcwrapper.common_utils import IToProto, IGrpcWrapperAsyncIO
+
+
+class StreamMock(IGrpcWrapperAsyncIO):
+ from_server: asyncio.Queue
+ from_client: asyncio.Queue
+ _closed: bool
+
+ def __init__(self):
+ self.from_server = asyncio.Queue()
+ self.from_client = asyncio.Queue()
+ self._closed = False
+
+ async def receive(self) -> typing.Any:
+ if self._closed:
+ raise Exception("read from closed StreamMock")
+
+ item = await self.from_server.get()
+ if item is None:
+ raise StopAsyncIteration()
+ if isinstance(item, Exception):
+ raise item
+ return item
+
+ def write(self, wrap_message: IToProto):
+ if self._closed:
+ raise Exception("write to closed StreamMock")
+ self.from_client.put_nowait(wrap_message)
+
+ def close(self):
+ if self._closed:
+ return
+
+ self._closed = True
+ self.from_server.put_nowait(None)
+
+
+class WaitConditionError(Exception):
+ pass
+
+
+async def wait_condition(
+ f: typing.Callable[[], bool],
+ timeout: typing.Optional[typing.Union[float, int]] = None,
+):
+ """
+ The default timeout is 1 second.
+ If timeout is 0, only the loop counter works. This is useful when a test needs a fast timeout for the condition (without waiting for the full timeout).
+ """
+ if timeout is None:
+ timeout = 1
+
+ minimal_loop_count_for_wait = 1000
+
+ start = time.monotonic()
+ counter = 0
+ while (time.monotonic() - start < timeout) or counter < minimal_loop_count_for_wait:
+ counter += 1
+ if f():
+ return
+ await asyncio.sleep(0)
+
+ raise WaitConditionError("Bad condition in test")
+
+
+async def wait_for_fast(
+ awaitable: typing.Awaitable,
+ timeout: typing.Optional[typing.Union[float, int]] = None,
+):
+ fut = asyncio.ensure_future(awaitable)
+ await wait_condition(lambda: fut.done(), timeout)
+ return fut.result()
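+
+# Editor's sketch (illustrative only, not part of the patch): a minimal test
+# using the helpers above, assuming a pytest-asyncio style runner.
+#
+#     async def test_stream_mock_roundtrip():
+#         stream = StreamMock()
+#         stream.from_server.put_nowait("hello")
+#         assert await wait_for_fast(stream.receive()) == "hello"
+#         stream.close()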
diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/__init__.py b/ydb/public/sdk/python3/ydb/_topic_reader/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_reader/__init__.py
diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/datatypes.py b/ydb/public/sdk/python3/ydb/_topic_reader/datatypes.py
new file mode 100644
index 0000000000..3845995fcf
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_reader/datatypes.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+import abc
+import asyncio
+import bisect
+import enum
+from collections import deque
+from dataclasses import dataclass, field
+import datetime
+from typing import Mapping, Union, Any, List, Dict, Deque, Optional
+
+from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange, Codec
+from ydb._topic_reader import topic_reader_asyncio
+
+
+class ICommittable(abc.ABC):
+ @abc.abstractmethod
+ def _commit_get_partition_session(self) -> PartitionSession:
+ ...
+
+ @abc.abstractmethod
+ def _commit_get_offsets_range(self) -> OffsetsRange:
+ ...
+
+
+class ISessionAlive(abc.ABC):
+ @property
+ @abc.abstractmethod
+ def is_alive(self) -> bool:
+ pass
+
+
+@dataclass
+class PublicMessage(ICommittable, ISessionAlive):
+ seqno: int
+ created_at: datetime.datetime
+ message_group_id: str
+ session_metadata: Dict[str, str]
+ offset: int
+ written_at: datetime.datetime
+ producer_id: str
+ data: Union[
+ bytes, Any
+ ] # set to the original decompressed bytes, or to a deserialized object if a deserializer is set on the reader
+ _partition_session: PartitionSession
+ _commit_start_offset: int
+ _commit_end_offset: int
+
+ def _commit_get_partition_session(self) -> PartitionSession:
+ return self._partition_session
+
+ def _commit_get_offsets_range(self) -> OffsetsRange:
+ return OffsetsRange(self._commit_start_offset, self._commit_end_offset)
+
+ # ISessionAlive implementation
+ @property
+ def is_alive(self) -> bool:
+ raise NotImplementedError()
+
+
+@dataclass
+class PartitionSession:
+ id: int
+ state: "PartitionSession.State"
+ topic_path: str
+ partition_id: int
+ committed_offset: int # last committed offset acked by the server; messages up to committed_offset - 1 have been processed
+ reader_reconnector_id: int
+ reader_stream_id: int
+ _next_message_start_commit_offset: int = field(init=False)
+
+ # todo: check if deque is optimal
+ _ack_waiters: Deque["PartitionSession.CommitAckWaiter"] = field(
+ init=False, default_factory=lambda: deque()
+ )
+
+ _state_changed: asyncio.Event = field(
+ init=False, default_factory=lambda: asyncio.Event(), compare=False
+ )
+ _loop: Optional[asyncio.AbstractEventLoop] = field(
+ init=False
+ ) # may be None in tests
+
+ def __post_init__(self):
+ self._next_message_start_commit_offset = self.committed_offset
+
+ try:
+ self._loop = asyncio.get_running_loop()
+ except RuntimeError:
+ self._loop = None
+
+ def add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter":
+ waiter = PartitionSession.CommitAckWaiter(end_offset, self._create_future())
+ if end_offset <= self.committed_offset:
+ waiter._finish_ok()
+ return waiter
+
+ # fast path
+ if len(self._ack_waiters) > 0 and self._ack_waiters[-1].end_offset < end_offset:
+ self._ack_waiters.append(waiter)
+ else:
+ bisect.insort(self._ack_waiters, waiter)
+
+ return waiter
+
+ def _create_future(self) -> asyncio.Future:
+ if self._loop:
+ return self._loop.create_future()
+ else:
+ return asyncio.Future()
+
+ def ack_notify(self, offset: int):
+ self._ensure_not_closed()
+
+ self.committed_offset = offset
+
+ if len(self._ack_waiters) == 0:
+ # todo log warning
+ # we should never receive an ack for a request that was not sent
+ return
+
+ while len(self._ack_waiters) > 0:
+ if self._ack_waiters[0].end_offset <= offset:
+ waiter = self._ack_waiters.popleft()
+ waiter._finish_ok()
+ else:
+ break
+
+ def close(self):
+ try:
+ self._ensure_not_closed()
+ except topic_reader_asyncio.TopicReaderCommitToExpiredPartition:
+ return
+
+ self.state = PartitionSession.State.Stopped
+ exception = topic_reader_asyncio.TopicReaderCommitToExpiredPartition()
+ for waiter in self._ack_waiters:
+ waiter._finish_error(exception)
+
+ def _ensure_not_closed(self):
+ if self.state == PartitionSession.State.Stopped:
+ raise topic_reader_asyncio.TopicReaderCommitToExpiredPartition()
+
+ class State(enum.Enum):
+ Active = 1
+ GracefulShutdown = 2
+ Stopped = 3
+
+ @dataclass(order=True)
+ class CommitAckWaiter:
+ end_offset: int
+ future: asyncio.Future = field(compare=False)
+ _done: bool = field(default=False, init=False)
+ _exception: Optional[Exception] = field(default=None, init=False)
+
+ def _finish_ok(self):
+ self._done = True
+ self.future.set_result(None)
+
+ def _finish_error(self, error: Exception):
+ self._exception = error
+ self.future.set_exception(error)
+
+
+@dataclass
+class PublicBatch(ICommittable, ISessionAlive):
+ session_metadata: Mapping[str, str]
+ messages: List[PublicMessage]
+ _partition_session: PartitionSession
+ _bytes_size: int
+ _codec: Codec
+
+ def _commit_get_partition_session(self) -> PartitionSession:
+ return self.messages[0]._commit_get_partition_session()
+
+ def _commit_get_offsets_range(self) -> OffsetsRange:
+ return OffsetsRange(
+ self.messages[0]._commit_get_offsets_range().start,
+ self.messages[-1]._commit_get_offsets_range().end,
+ )
+
+ # ISessionAlive implementation
+ @property
+ def is_alive(self) -> bool:
+ state = self._partition_session.state
+ return (
+ state == PartitionSession.State.Active
+ or state == PartitionSession.State.GracefulShutdown
+ )
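+
+# Editor's sketch (illustrative only, not part of the patch): the commit-ack
+# flow on a partition session. add_waiter registers a future for an end
+# offset; ack_notify resolves all waiters whose end offset is covered by the
+# server-acked offset. Values are hypothetical; without a running loop, plain
+# asyncio.Future objects are used (as in tests).
+#
+#     session = PartitionSession(
+#         id=1,
+#         state=PartitionSession.State.Active,
+#         topic_path="/local/my-topic",
+#         partition_id=0,
+#         committed_offset=0,
+#         reader_reconnector_id=1,
+#         reader_stream_id=1,
+#     )
+#     waiter = session.add_waiter(end_offset=10)
+#     session.ack_notify(10)  # resolves waiter.future with None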
diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader.py b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader.py
new file mode 100644
index 0000000000..148d63b33b
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader.py
@@ -0,0 +1,116 @@
+import concurrent.futures
+import enum
+import datetime
+from dataclasses import dataclass
+from typing import (
+ Union,
+ Optional,
+ List,
+ Mapping,
+ Callable,
+)
+
+from ..table import RetrySettings
+from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange
+
+
+class Selector:
+ path: str
+ partitions: Union[None, int, List[int]]
+ read_from_timestamp_ms: Optional[int]
+ max_time_lag_ms: Optional[int]
+
+ def __init__(self, path, *, partitions: Union[None, int, List[int]] = None):
+ self.path = path
+ self.partitions = partitions
+
+
+@dataclass
+class PublicReaderSettings:
+ consumer: str
+ topic: str
+ buffer_size_bytes: int = 50 * 1024 * 1024
+
+ decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None
+ """decoders: map[codec_code] func(encoded_bytes)->decoded_bytes"""
+
+ # decoder_executor must be set to handle non-raw messages
+ decoder_executor: Optional[concurrent.futures.Executor] = None
+
+ # on_commit: Callable[["Events.OnCommit"], None] = None
+ # on_get_partition_start_offset: Callable[
+ # ["Events.OnPartitionGetStartOffsetRequest"],
+ # "Events.OnPartitionGetStartOffsetResponse",
+ # ] = None
+ # on_partition_session_start: Callable[["StubEvent"], None] = None
+ # on_partition_session_stop: Callable[["StubEvent"], None] = None
+ # on_partition_session_close: Callable[["StubEvent"], None] = None # todo?
+ # deserializer: Union[Callable[[bytes], Any], None] = None
+ # one_attempt_connection_timeout: Union[float, None] = 1
+ # connection_timeout: Union[float, None] = None
+ # retry_policy: Union["RetryPolicy", None] = None
+ update_token_interval: Union[int, float] = 3600
+
+ def _init_message(self) -> StreamReadMessage.InitRequest:
+ return StreamReadMessage.InitRequest(
+ topics_read_settings=[
+ StreamReadMessage.InitRequest.TopicReadSettings(
+ path=self.topic,
+ )
+ ],
+ consumer=self.consumer,
+ )
+
+ def _retry_settings(self) -> RetrySettings:
+ return RetrySettings(idempotent=True)
+
+
+class Events:
+ class OnCommit:
+ topic: str
+ offset: int
+
+ class OnPartitionGetStartOffsetRequest:
+ topic: str
+ partition_id: int
+
+ class OnPartitionGetStartOffsetResponse:
+ start_offset: int
+
+ class OnInitPartition:
+ pass
+
+ class OnShutdownPatition:
+ pass
+
+
+class RetryPolicy:
+ connection_timeout_sec: float
+ overload_timeout_sec: float
+ retry_access_denied: bool = False
+
+
+class CommitResult:
+ topic: str
+ partition: int
+ offset: int
+ state: "CommitResult.State"
+ details: str # for humans only; the message content may change at any time
+
+ class State(enum.Enum):
+ UNSENT = 1 # commit has not been sent to the server
+ SENT = 2 # commit was sent to the server, but the ack has not been received
+ ACKED = 3 # ack from the server has been received
+
+
+class SessionStat:
+ path: str
+ partition_id: str
+ partition_offsets: OffsetsRange
+ committed_offset: int
+ write_time_high_watermark: datetime.datetime
+ write_time_high_watermark_timestamp_nano: int
+
+
+class StubEvent:
+ pass
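+
+# Editor's sketch (illustrative only, not part of the patch): minimal reader
+# settings; the consumer and topic names are hypothetical. _init_message
+# builds the InitRequest sent to the server when the read stream starts.
+#
+#     settings = PublicReaderSettings(
+#         consumer="my-consumer",
+#         topic="/local/my-topic",
+#     )
+#     init_message = settings._init_message()  # StreamReadMessage.InitRequest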
diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_asyncio.py
new file mode 100644
index 0000000000..bb87d3ccc8
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_asyncio.py
@@ -0,0 +1,697 @@
+from __future__ import annotations
+
+import asyncio
+import concurrent.futures
+import gzip
+import typing
+from asyncio import Task
+from collections import deque
+from typing import Optional, Set, Dict, Union, Callable
+
+from .. import _apis, issues
+from .._utilities import AtomicCounter
+from ..aio import Driver
+from ..issues import Error as YdbError, _process_response
+from . import datatypes
+from . import topic_reader
+from .._grpc.grpcwrapper.common_utils import (
+ IGrpcWrapperAsyncIO,
+ SupportedDriverType,
+ GrpcWrapperAsyncIO,
+)
+from .._grpc.grpcwrapper.ydb_topic import (
+ StreamReadMessage,
+ UpdateTokenRequest,
+ UpdateTokenResponse,
+ Codec,
+)
+from .._errors import check_retriable_error
+
+
+class TopicReaderError(YdbError):
+ pass
+
+
+class TopicReaderUnexpectedCodec(YdbError):
+ pass
+
+
+class TopicReaderCommitToExpiredPartition(TopicReaderError):
+ """
+ Raised on a commit attempt when the partition read session has been dropped.
+ This is OK: the message/batch will not be committed to the server and will be received again in another
+ read session (with this or another reader).
+ """
+
+ def __init__(self, message: str = "Topic reader partition session is closed"):
+ super().__init__(message)
+
+
+class TopicReaderStreamClosedError(TopicReaderError):
+ def __init__(self):
+ super().__init__("Topic reader stream is closed")
+
+
+class TopicReaderClosedError(TopicReaderError):
+ def __init__(self):
+ super().__init__("Topic reader is closed already")
+
+
+class PublicAsyncIOReader:
+ _loop: asyncio.AbstractEventLoop
+ _closed: bool
+ _reconnector: ReaderReconnector
+
+ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
+ self._loop = asyncio.get_running_loop()
+ self._closed = False
+ self._reconnector = ReaderReconnector(driver, settings)
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ def __del__(self):
+ if not self._closed:
+ self._loop.create_task(self.close(), name="close reader")
+
+ async def sessions_stat(self) -> typing.List["topic_reader.SessionStat"]:
+ """
+ Receive stats from the server.
+
+ Use asyncio.wait_for to wait with a timeout.
+ """
+ raise NotImplementedError()
+
+ def messages(
+ self, *, timeout: typing.Union[float, None] = None
+ ) -> typing.AsyncIterable[topic_reader.PublicMessage]:
+ """
+ Block until a new message is received.
+
+ If there are no new messages within timeout seconds, stop iteration by raising StopAsyncIteration.
+ """
+ raise NotImplementedError()
+
+ async def receive_message(self) -> typing.Union[topic_reader.PublicMessage, None]:
+ """
+ Block until a new message is received.
+
+ Use asyncio.wait_for to wait with a timeout.
+ """
+ raise NotImplementedError()
+
+ def batches(
+ self,
+ *,
+ max_messages: typing.Union[int, None] = None,
+ max_bytes: typing.Union[int, None] = None,
+ timeout: typing.Union[float, None] = None,
+ ) -> typing.AsyncIterable[datatypes.PublicBatch]:
+ """
+ Block until a new batch is received.
+ All messages in a batch are from the same partition.
+
+ If there is no new message within timeout seconds (default: infinite), stop iteration by raising StopAsyncIteration.
+ """
+ raise NotImplementedError()
+
+ async def receive_batch(
+ self,
+ *,
+ max_messages: typing.Union[int, None] = None,
+ max_bytes: typing.Union[int, None] = None,
+ ) -> typing.Union[datatypes.PublicBatch, None]:
+ """
+ Get one message batch from the reader.
+ All messages in a batch are from the same partition.
+
+ Use asyncio.wait_for to wait with a timeout.
+ """
+ await self._reconnector.wait_message()
+ return self._reconnector.receive_batch_nowait()
+
+ async def commit_on_exit(
+ self, mess: datatypes.ICommittable
+ ) -> typing.AsyncContextManager:
+ """
+ Commit the message/batch if the context manager exits without an exception.
+
+ The reader will be closed if the context manager exits with an exception.
+ """
+ raise NotImplementedError()
+
+ def commit(
+ self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
+ ):
+ """
+ Write a commit message to a buffer.
+
+ This method gives no way to check the commit result
+ (for example, if the connection is lost, commits are not re-sent and committed messages will be received again).
+ """
+ self._reconnector.commit(batch)
+
+ async def commit_with_ack(
+ self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
+ ):
+ """
+ Write a commit message to a buffer and wait for the ack from the server.
+
+ Use asyncio.wait_for to wait with a timeout.
+ """
+ waiter = self._reconnector.commit(batch)
+ await waiter.future
+
+ async def flush(self):
+ """
+ Force-send all commit messages from the internal buffers to the server and wait for acks for all of them.
+
+ Use asyncio.wait_for to wait with a timeout.
+ """
+ raise NotImplementedError()
+
+ async def close(self):
+ if self._closed:
+ raise TopicReaderClosedError()
+
+ self._closed = True
+ await self._reconnector.close()
+
+
+class ReaderReconnector:
+ _static_reader_reconnector_counter = AtomicCounter()
+
+ _id: int
+ _settings: topic_reader.PublicReaderSettings
+ _driver: Driver
+ _background_tasks: Set[Task]
+
+ _state_changed: asyncio.Event
+ _stream_reader: Optional["ReaderStream"]
+ _first_error: asyncio.Future[YdbError]
+
+ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
+ self._id = self._static_reader_reconnector_counter.inc_and_get()
+
+ self._settings = settings
+ self._driver = driver
+ self._background_tasks = set()
+
+ self._state_changed = asyncio.Event()
+ self._stream_reader = None
+ self._background_tasks.add(asyncio.create_task(self._connection_loop()))
+ self._first_error = asyncio.get_running_loop().create_future()
+
+ async def _connection_loop(self):
+ attempt = 0
+ while True:
+ try:
+ self._stream_reader = await ReaderStream.create(
+ self._id, self._driver, self._settings
+ )
+ attempt = 0
+ self._state_changed.set()
+ await self._stream_reader.wait_error()
+ except issues.Error as err:
+ retry_info = check_retriable_error(
+ err, self._settings._retry_settings(), attempt
+ )
+ if not retry_info.is_retriable:
+ self._set_first_error(err)
+ return
+ await asyncio.sleep(retry_info.sleep_timeout_seconds)
+
+ attempt += 1
+
+ async def wait_message(self):
+ while True:
+ if self._first_error.done():
+ raise self._first_error.result()
+
+ if self._stream_reader:
+ try:
+ await self._stream_reader.wait_messages()
+ return
+ except YdbError:
+ pass # handle errors in reconnection loop
+
+ await self._state_changed.wait()
+ self._state_changed.clear()
+
+ def receive_batch_nowait(self):
+ return self._stream_reader.receive_batch_nowait()
+
+ def commit(
+ self, batch: datatypes.ICommittable
+ ) -> datatypes.PartitionSession.CommitAckWaiter:
+ return self._stream_reader.commit(batch)
+
+ async def close(self):
+ if self._stream_reader:
+ await self._stream_reader.close()
+ for task in self._background_tasks:
+ task.cancel()
+
+ await asyncio.wait(self._background_tasks)
+
+ def _set_first_error(self, err: issues.Error):
+ try:
+ self._first_error.set_result(err)
+ self._state_changed.set()
+ except asyncio.InvalidStateError:
+ # skip if already has result
+ pass
+
+
+class ReaderStream:
+ _static_id_counter = AtomicCounter()
+
+ _loop: asyncio.AbstractEventLoop
+ _id: int
+ _reader_reconnector_id: int
+ _session_id: str
+ _stream: Optional[IGrpcWrapperAsyncIO]
+ _started: bool
+ _background_tasks: Set[asyncio.Task]
+ _partition_sessions: Dict[int, datatypes.PartitionSession]
+ _buffer_size_bytes: int # used for the init request, then for debug purposes only
+ _decode_executor: concurrent.futures.Executor
+ _decoders: Dict[
+ int, typing.Callable[[bytes], bytes]
+ ] # dict[codec_code] func(encoded_bytes)->decoded_bytes
+
+ if typing.TYPE_CHECKING:
+ _batches_to_decode: asyncio.Queue[datatypes.PublicBatch]
+ else:
+ _batches_to_decode: asyncio.Queue
+
+ _state_changed: asyncio.Event
+ _closed: bool
+ _message_batches: typing.Deque[datatypes.PublicBatch]
+ _first_error: asyncio.Future[YdbError]
+
+ _update_token_interval: Union[int, float]
+ _update_token_event: asyncio.Event
+ _get_token_function: Callable[[], str]
+
+ def __init__(
+ self,
+ reader_reconnector_id: int,
+ settings: topic_reader.PublicReaderSettings,
+ get_token_function: Optional[Callable[[], str]] = None,
+ ):
+ self._loop = asyncio.get_running_loop()
+ self._id = ReaderStream._static_id_counter.inc_and_get()
+ self._reader_reconnector_id = reader_reconnector_id
+ self._session_id = "not initialized"
+ self._stream = None
+ self._started = False
+ self._background_tasks = set()
+ self._partition_sessions = dict()
+ self._buffer_size_bytes = settings.buffer_size_bytes
+ self._decode_executor = settings.decoder_executor
+
+ self._decoders = {Codec.CODEC_GZIP: gzip.decompress}
+ if settings.decoders:
+ self._decoders.update(settings.decoders)
+
+ self._state_changed = asyncio.Event()
+ self._closed = False
+ self._first_error = asyncio.get_running_loop().create_future()
+ self._batches_to_decode = asyncio.Queue()
+ self._message_batches = deque()
+
+ self._update_token_interval = settings.update_token_interval
+ self._get_token_function = get_token_function
+ self._update_token_event = asyncio.Event()
+
+ @staticmethod
+ async def create(
+ reader_reconnector_id: int,
+ driver: SupportedDriverType,
+ settings: topic_reader.PublicReaderSettings,
+ ) -> "ReaderStream":
+ stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto)
+
+ await stream.start(
+ driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead
+ )
+
+ creds = driver._credentials
+ reader = ReaderStream(
+ reader_reconnector_id,
+ settings,
+ get_token_function=creds.get_auth_token if creds else None,
+ )
+ await reader._start(stream, settings._init_message())
+ return reader
+
+ async def _start(
+ self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest
+ ):
+ if self._started:
+ raise TopicReaderError("Double start ReaderStream")
+
+ self._started = True
+ self._stream = stream
+
+ stream.write(StreamReadMessage.FromClient(client_message=init_message))
+ init_response = await stream.receive() # type: StreamReadMessage.FromServer
+ if isinstance(init_response.server_message, StreamReadMessage.InitResponse):
+ self._session_id = init_response.server_message.session_id
+ else:
+ raise TopicReaderError(
+ "Unexpected message after InitRequest: %s", init_response
+ )
+
+ self._update_token_event.set()
+
+ self._background_tasks.add(
+ asyncio.create_task(self._read_messages_loop(), name="read_messages_loop")
+ )
+ self._background_tasks.add(asyncio.create_task(self._decode_batches_loop()))
+ if self._get_token_function:
+ self._background_tasks.add(
+ asyncio.create_task(self._update_token_loop(), name="update_token_loop")
+ )
+
+ async def wait_error(self):
+ raise await self._first_error
+
+ async def wait_messages(self):
+ while True:
+ if self._get_first_error():
+ raise self._get_first_error()
+
+ if self._message_batches:
+ return
+
+ await self._state_changed.wait()
+ self._state_changed.clear()
+
+ def receive_batch_nowait(self):
+ if self._get_first_error():
+ raise self._get_first_error()
+
+ if not self._message_batches:
+ return
+
+ batch = self._message_batches.popleft()
+ self._buffer_release_bytes(batch._bytes_size)
+ return batch
+
+ def commit(
+ self, batch: datatypes.ICommittable
+ ) -> datatypes.PartitionSession.CommitAckWaiter:
+ partition_session = batch._commit_get_partition_session()
+
+ if partition_session.reader_reconnector_id != self._reader_reconnector_id:
+ raise TopicReaderError("reader can commit only self-produced messages")
+
+ if partition_session.reader_stream_id != self._id:
+ raise TopicReaderCommitToExpiredPartition(
+ "commit messages after reconnect to server"
+ )
+
+ if partition_session.id not in self._partition_sessions:
+ raise TopicReaderCommitToExpiredPartition(
+ "commit messages after server stop the partition read session"
+ )
+
+ commit_range = batch._commit_get_offsets_range()
+ waiter = partition_session.add_waiter(commit_range.end)
+
+ if not waiter.future.done():
+ client_message = StreamReadMessage.CommitOffsetRequest(
+ commit_offsets=[
+ StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset(
+ partition_session_id=partition_session.id,
+ offsets=[commit_range],
+ )
+ ]
+ )
+ self._stream.write(
+ StreamReadMessage.FromClient(client_message=client_message)
+ )
+
+ return waiter
+
+ async def _read_messages_loop(self):
+ try:
+ self._stream.write(
+ StreamReadMessage.FromClient(
+ client_message=StreamReadMessage.ReadRequest(
+ bytes_size=self._buffer_size_bytes,
+ ),
+ )
+ )
+ while True:
+ message = (
+ await self._stream.receive()
+ ) # type: StreamReadMessage.FromServer
+ _process_response(message.server_status)
+
+ if isinstance(message.server_message, StreamReadMessage.ReadResponse):
+ self._on_read_response(message.server_message)
+
+ elif isinstance(
+ message.server_message, StreamReadMessage.CommitOffsetResponse
+ ):
+ self._on_commit_response(message.server_message)
+
+ elif isinstance(
+ message.server_message,
+ StreamReadMessage.StartPartitionSessionRequest,
+ ):
+ self._on_start_partition_session(message.server_message)
+
+ elif isinstance(
+ message.server_message,
+ StreamReadMessage.StopPartitionSessionRequest,
+ ):
+ self._on_partition_session_stop(message.server_message)
+
+ elif isinstance(message.server_message, UpdateTokenResponse):
+ self._update_token_event.set()
+
+ else:
+ raise NotImplementedError(
+ "Unexpected type of StreamReadMessage.FromServer message: %s"
+ % message.server_message
+ )
+
+ self._state_changed.set()
+ except Exception as e:
+ self._set_first_error(e)
+ raise
+
+ async def _update_token_loop(self):
+ while True:
+ await asyncio.sleep(self._update_token_interval)
+ await self._update_token(token=self._get_token_function())
+
+ async def _update_token(self, token: str):
+ await self._update_token_event.wait()
+ try:
+ msg = StreamReadMessage.FromClient(UpdateTokenRequest(token))
+ self._stream.write(msg)
+ finally:
+ self._update_token_event.clear()
+
+ def _on_start_partition_session(
+ self, message: StreamReadMessage.StartPartitionSessionRequest
+ ):
+ try:
+ if (
+ message.partition_session.partition_session_id
+ in self._partition_sessions
+ ):
+ raise TopicReaderError(
+ "Double start partition session: %s"
+ % message.partition_session.partition_session_id
+ )
+
+ self._partition_sessions[
+ message.partition_session.partition_session_id
+ ] = datatypes.PartitionSession(
+ id=message.partition_session.partition_session_id,
+ state=datatypes.PartitionSession.State.Active,
+ topic_path=message.partition_session.path,
+ partition_id=message.partition_session.partition_id,
+ committed_offset=message.committed_offset,
+ reader_reconnector_id=self._reader_reconnector_id,
+ reader_stream_id=self._id,
+ )
+ self._stream.write(
+ StreamReadMessage.FromClient(
+ client_message=StreamReadMessage.StartPartitionSessionResponse(
+ partition_session_id=message.partition_session.partition_session_id,
+ read_offset=None,
+ commit_offset=None,
+ )
+ ),
+ )
+ except YdbError as err:
+ self._set_first_error(err)
+
+ def _on_partition_session_stop(
+ self, message: StreamReadMessage.StopPartitionSessionRequest
+ ):
+ if message.partition_session_id not in self._partition_sessions:
+ # This can happen if we receive a stop-partition request with graceful=false after we have already
+ # responded to a stop-partition request with graceful=true and removed the partition from the internal dictionary.
+ return
+
+ partition = self._partition_sessions.pop(message.partition_session_id)
+ partition.close()
+
+ if message.graceful:
+ self._stream.write(
+ StreamReadMessage.FromClient(
+ client_message=StreamReadMessage.StopPartitionSessionResponse(
+ partition_session_id=message.partition_session_id,
+ )
+ )
+ )
+
+ def _on_read_response(self, message: StreamReadMessage.ReadResponse):
+ self._buffer_consume_bytes(message.bytes_size)
+
+ batches = self._read_response_to_batches(message)
+ for batch in batches:
+ self._batches_to_decode.put_nowait(batch)
+
+ def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse):
+ for partition_offset in message.partitions_committed_offsets:
+ if partition_offset.partition_session_id not in self._partition_sessions:
+ continue
+
+ session = self._partition_sessions[partition_offset.partition_session_id]
+ session.ack_notify(partition_offset.committed_offset)
+
+ def _buffer_consume_bytes(self, bytes_size):
+ self._buffer_size_bytes -= bytes_size
+
+ def _buffer_release_bytes(self, bytes_size):
+ self._buffer_size_bytes += bytes_size
+ self._stream.write(
+ StreamReadMessage.FromClient(
+ client_message=StreamReadMessage.ReadRequest(
+ bytes_size=bytes_size,
+ )
+ )
+ )
+
+ def _read_response_to_batches(
+ self, message: StreamReadMessage.ReadResponse
+ ) -> typing.List[datatypes.PublicBatch]:
+ batches = []
+
+ batch_count = sum(len(p.batches) for p in message.partition_data)
+ if batch_count == 0:
+ return batches
+
+ bytes_per_batch = message.bytes_size // batch_count
+ additional_bytes_to_last_batch = (
+ message.bytes_size - bytes_per_batch * batch_count
+ )
+
+ for partition_data in message.partition_data:
+ partition_session = self._partition_sessions[
+ partition_data.partition_session_id
+ ]
+ for server_batch in partition_data.batches:
+ messages = []
+ for message_data in server_batch.message_data:
+ mess = datatypes.PublicMessage(
+ seqno=message_data.seq_no,
+ created_at=message_data.created_at,
+ message_group_id=message_data.message_group_id,
+ session_metadata=server_batch.write_session_meta,
+ offset=message_data.offset,
+ written_at=server_batch.written_at,
+ producer_id=server_batch.producer_id,
+ data=message_data.data,
+ _partition_session=partition_session,
+ _commit_start_offset=partition_session._next_message_start_commit_offset,
+ _commit_end_offset=message_data.offset + 1,
+ )
+ messages.append(mess)
+ partition_session._next_message_start_commit_offset = (
+ mess._commit_end_offset
+ )
+
+ if messages:
+ batch = datatypes.PublicBatch(
+ session_metadata=server_batch.write_session_meta,
+ messages=messages,
+ _partition_session=partition_session,
+ _bytes_size=bytes_per_batch,
+ _codec=Codec(server_batch.codec),
+ )
+ batches.append(batch)
+
+ batches[-1]._bytes_size += additional_bytes_to_last_batch
+ return batches
+
+ async def _decode_batches_loop(self):
+ while True:
+ batch = await self._batches_to_decode.get()
+ await self._decode_batch_inplace(batch)
+ self._message_batches.append(batch)
+ self._state_changed.set()
+
+ async def _decode_batch_inplace(self, batch):
+ if batch._codec == Codec.CODEC_RAW:
+ return
+
+ try:
+ decode_func = self._decoders[batch._codec]
+ except KeyError:
+ raise TopicReaderUnexpectedCodec(
+ "Receive message with unexpected codec: %s" % batch._codec
+ )
+
+ decode_data_futures = []
+ for message in batch.messages:
+ future = self._loop.run_in_executor(
+ self._decode_executor, decode_func, message.data
+ )
+ decode_data_futures.append(future)
+
+ decoded_data = await asyncio.gather(*decode_data_futures)
+ for index, message in enumerate(batch.messages):
+ message.data = decoded_data[index]
+
+ batch._codec = Codec.CODEC_RAW
+
+ def _set_first_error(self, err: YdbError):
+ try:
+ self._first_error.set_result(err)
+ self._state_changed.set()
+ except asyncio.InvalidStateError:
+ # skip errors set after the first one
+ pass
+
+ def _get_first_error(self) -> Optional[YdbError]:
+ if self._first_error.done():
+ return self._first_error.result()
+
+ async def close(self):
+ if self._closed:
+ return
+ self._closed = True
+
+ self._set_first_error(TopicReaderStreamClosedError())
+ self._state_changed.set()
+ self._stream.close()
+
+ for session in self._partition_sessions.values():
+ session.close()
+
+ for task in self._background_tasks:
+ task.cancel()
+ await asyncio.wait(self._background_tasks)
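+
+# Editor's sketch (illustrative only, not part of the patch): reading and
+# committing one batch, assuming a connected ydb.aio.Driver and reader
+# settings as in topic_reader.PublicReaderSettings.
+#
+#     async def read_one_batch(driver, settings):
+#         async with PublicAsyncIOReader(driver, settings) as reader:
+#             batch = await reader.receive_batch()
+#             await reader.commit_with_ack(batch)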
diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_sync.py b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_sync.py
new file mode 100644
index 0000000000..30bf92a10e
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_sync.py
@@ -0,0 +1,210 @@
+import asyncio
+import concurrent.futures
+import typing
+from typing import List, Union, Iterable, Optional
+
+from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType
+from ydb._topic_common.common import (
+ _get_shared_event_loop,
+ CallFromSyncToAsync,
+ TimeoutType,
+)
+from ydb._topic_reader import datatypes
+from ydb._topic_reader.datatypes import PublicMessage, PublicBatch, ICommittable
+from ydb._topic_reader.topic_reader import (
+ PublicReaderSettings,
+ SessionStat,
+ CommitResult,
+)
+from ydb._topic_reader.topic_reader_asyncio import (
+ PublicAsyncIOReader,
+ TopicReaderClosedError,
+)
+
+
+class TopicReaderSync:
+ _caller: CallFromSyncToAsync
+ _async_reader: PublicAsyncIOReader
+ _closed: bool
+
+ def __init__(
+ self,
+ driver: SupportedDriverType,
+ settings: PublicReaderSettings,
+ *,
+ eventloop: Optional[asyncio.AbstractEventLoop] = None,
+ ):
+ self._closed = False
+
+ if eventloop:
+ loop = eventloop
+ else:
+ loop = _get_shared_event_loop()
+
+ self._caller = CallFromSyncToAsync(loop)
+
+ async def create_reader():
+ return PublicAsyncIOReader(driver, settings)
+
+ self._async_reader = asyncio.run_coroutine_threadsafe(
+ create_reader(), loop
+ ).result()
+
+ def __del__(self):
+ self.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def async_sessions_stat(self) -> concurrent.futures.Future:
+ """
+ Receive stats from the server; return a future.
+ """
+ raise NotImplementedError()
+
+ def sessions_stat(self) -> List[SessionStat]:
+ """
+ Receive stats from the server.
+
+ Use async_sessions_stat to set an explicit wait timeout.
+ """
+ raise NotImplementedError()
+
+ def messages(
+ self, *, timeout: Union[float, None] = None
+ ) -> Iterable[PublicMessage]:
+ """
+ todo?
+
+ Block until a new message is received.
+ There is no async_ version, to prevent lost messages; use async_wait_message as a signal that new batches are available.
+
+ If there is no new message within timeout seconds (default: infinite), stop iteration by raising StopIteration.
+ If timeout <= 0, this is a fast non-blocking call that reads messages from the internal buffer only.
+ """
+ raise NotImplementedError()
+
+ def receive_message(self, *, timeout: Union[float, None] = None) -> PublicMessage:
+ """
+ Block until a new message is received.
+ There is no async_ version, to prevent lost messages; use async_wait_message as a signal that new batches are available.
+
+ If there is no new message within timeout seconds (default: infinite), raise TimeoutError().
+ If timeout <= 0, this is a fast non-blocking call that reads messages from the internal buffer only.
+ """
+ raise NotImplementedError()
+
+ def async_wait_message(self) -> concurrent.futures.Future:
+ """
+ Return a future that completes when the reader has at least one message in its queue.
+ If the reader already has a message, the future is returned already completed.
+
+ It is possible to receive the signal that a message is available but find no messages when trying to receive one,
+ e.g. if the message expired between the event being sent and the retrieval attempt (for example, the connection broke).
+ """
+ raise NotImplementedError()
+
+ def batches(
+ self,
+ *,
+ max_messages: Union[int, None] = None,
+ max_bytes: Union[int, None] = None,
+ timeout: Union[float, None] = None,
+ ) -> Iterable[PublicBatch]:
+ """
+ Block until a new batch is received.
+ There is no async_ version, to prevent lost messages; use async_wait_message as a signal that new batches are available.
+
+ If there is no new message within timeout seconds (default: infinite), stop iteration by raising StopIteration.
+ If timeout <= 0, this is a fast non-blocking call that reads messages from the internal buffer only.
+ """
+ raise NotImplementedError()
+
+ def receive_batch(
+ self,
+ *,
+ max_messages: typing.Union[int, None] = None,
+ max_bytes: typing.Union[int, None] = None,
+ timeout: Union[float, None] = None,
+ ) -> Union[PublicBatch, None]:
+ """
+ Get one message batch from the reader.
+ There is no async_ version, to prevent lost messages; use async_wait_message as a signal that new batches are available.
+
+ If there is no new message within timeout seconds (default: infinite), raise TimeoutError().
+ If timeout <= 0, this is a fast non-blocking call that reads messages from the internal buffer only.
+ """
+ self._check_closed()
+
+ return self._caller.safe_call_with_result(
+ self._async_reader.receive_batch(
+ max_messages=max_messages, max_bytes=max_bytes
+ ),
+ timeout,
+ )
+
+ def commit(
+ self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
+ ):
+ """
+ Put a commit message into the internal buffer.
+
+ This method gives no way to check the commit result
+ (for example, if the connection is lost, commits are not re-sent and committed messages will be received again).
+ """
+ self._check_closed()
+
+ # call_sync expects a callable, so defer the commit call into the event-loop thread
+ self._caller.call_sync(lambda: self._async_reader.commit(mess))
+
+ def commit_with_ack(
+ self, mess: ICommittable, timeout: TimeoutType = None
+ ) -> Union[CommitResult, List[CommitResult]]:
+ """
+ Write a commit message to a buffer and wait for the ack from the server.
+
+ If no ack is received within timeout seconds (default: infinite), raise TimeoutError().
+ """
+ self._check_closed()
+
+ return self._caller.unsafe_call_with_result(
+ self._async_reader.commit_with_ack(mess), timeout
+ )
+
+ def async_commit_with_ack(
+ self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
+ ) -> concurrent.futures.Future:
+ """
+ Write a commit message to a buffer and return a Future to wait for the result.
+ """
+ self._check_closed()
+
+ return self._caller.unsafe_call_with_future(
+ self._async_reader.commit_with_ack(mess)
+ )
+
+ def async_flush(self) -> concurrent.futures.Future:
+ """
+ Force-send all commit messages from the internal buffers to the server and return a Future to wait for the server acks.
+ """
+ raise NotImplementedError()
+
+ def flush(self):
+ """
+ Force-send all commit messages from the internal buffers to the server and wait for acks for all of them.
+ """
+ raise NotImplementedError()
+
+ def close(self, *, timeout: TimeoutType = None):
+ if self._closed:
+ return
+
+ self._closed = True
+
+ self._caller.safe_call_with_result(self._async_reader.close(), timeout)
+
+ def _check_closed(self):
+ if self._closed:
+ raise TopicReaderClosedError()
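+
+# Editor's sketch (illustrative only, not part of the patch): the sync wrapper
+# in use, assuming a connected ydb.Driver and PublicReaderSettings as above.
+#
+#     reader = TopicReaderSync(driver, settings)
+#     batch = reader.receive_batch(timeout=5)
+#     if batch is not None:
+#         reader.commit(batch)
+#     reader.close()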
diff --git a/ydb/public/sdk/python3/ydb/_topic_writer/__init__.py b/ydb/public/sdk/python3/ydb/_topic_writer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_writer/__init__.py
diff --git a/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer.py b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer.py
new file mode 100644
index 0000000000..59ad74ff80
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer.py
@@ -0,0 +1,220 @@
+import concurrent.futures
+import datetime
+import enum
+import uuid
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Union, Optional, Any, Dict
+
+import typing
+
+import ydb.aio
+from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage
+from .._grpc.grpcwrapper.common_utils import IToProto
+from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
+
+Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"]
+
+
+@dataclass
+class PublicWriterSettings:
+ """
+ Settings for the topic writer.
+
+ The order of fields IS NOT stable; use keyword arguments only.
+ """
+
+ topic: str
+ producer_id: Optional[str] = None
+ session_metadata: Optional[Dict[str, str]] = None
+ partition_id: Optional[int] = None
+ auto_seqno: bool = True
+ auto_created_at: bool = True
+ codec: Optional[PublicCodec] = None # default means auto-select
+ encoder_executor: Optional[
+ concurrent.futures.Executor
+ ] = None # defaults to the shared client executor pool
+ encoders: Optional[
+ typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]]
+ ] = None
+ # get_last_seqno: bool = False
+ # serializer: Union[Callable[[Any], bytes], None] = None
+ # send_buffer_count: Optional[int] = 10000
+ # send_buffer_bytes: Optional[int] = 100 * 1024 * 1024
+ # codec: Optional[int] = None
+ # codec_autoselect: bool = True
+ # retry_policy: Optional["RetryPolicy"] = None
+ update_token_interval: Union[int, float] = 3600
+
+ def __post_init__(self):
+ if self.producer_id is None:
+ self.producer_id = uuid.uuid4().hex
+
+
+@dataclass
+class PublicWriteResult:
+ @dataclass(eq=True)
+ class Written:
+ __slots__ = "offset"
+ offset: int
+
+ @dataclass(eq=True)
+ class Skipped:
+ pass
+
+
+PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped]
+
+
+class WriterSettings(PublicWriterSettings):
+ def __init__(self, settings: PublicWriterSettings):
+ self.__dict__ = settings.__dict__.copy()
+
+ def create_init_request(self) -> StreamWriteMessage.InitRequest:
+ return StreamWriteMessage.InitRequest(
+ path=self.topic,
+ producer_id=self.producer_id,
+ write_session_meta=self.session_metadata,
+ partitioning=self.get_partitioning(),
+ get_last_seq_no=True,
+ )
+
+ def get_partitioning(self) -> StreamWriteMessage.PartitioningType:
+ if self.partition_id is not None:
+ return StreamWriteMessage.PartitioningPartitionID(self.partition_id)
+ return StreamWriteMessage.PartitioningMessageGroupID(self.producer_id)
+
+
+class SendMode(Enum):
+ ASYNC = 1
+ SYNC = 2
+
+
+@dataclass
+class PublicWriterInitInfo:
+ __slots__ = ("last_seqno", "supported_codecs")
+ last_seqno: Optional[int]
+ supported_codecs: List[PublicCodec]
+
+
+class PublicMessage:
+ seqno: Optional[int]
+ created_at: Optional[datetime.datetime]
+ data: "PublicMessage.SimpleMessageSourceType"
+
+ SimpleMessageSourceType = Union[str, bytes] # Will be extended
+
+ def __init__(
+ self,
+ data: SimpleMessageSourceType,
+ *,
+ seqno: Optional[int] = None,
+ created_at: Optional[datetime.datetime] = None,
+ ):
+ self.seqno = seqno
+ self.created_at = created_at
+ self.data = data
+
+ @staticmethod
+ def _create_message(data: Message) -> "PublicMessage":
+ if isinstance(data, PublicMessage):
+ return data
+ return PublicMessage(data=data)
+
+
+class InternalMessage(StreamWriteMessage.WriteRequest.MessageData, IToProto):
+ codec: PublicCodec
+
+ def __init__(self, mess: PublicMessage):
+ super().__init__(
+ seq_no=mess.seqno,
+ created_at=mess.created_at,
+ data=mess.data,
+ uncompressed_size=len(mess.data),
+ partitioning=None,
+ )
+ self.codec = PublicCodec.RAW
+
+ def get_bytes(self) -> bytes:
+ if self.data is None:
+ return bytes()
+ if isinstance(self.data, bytes):
+ return self.data
+ if isinstance(self.data, str):
+ return self.data.encode("utf-8")
+ raise ValueError("Bad data type")
+
+ def to_message_data(self) -> StreamWriteMessage.WriteRequest.MessageData:
+ data = self.get_bytes()
+ return StreamWriteMessage.WriteRequest.MessageData(
+ seq_no=self.seq_no,
+ created_at=self.created_at,
+ data=data,
+ uncompressed_size=len(data),
+            partitioning=None,  # not currently supported by the server
+ )
+
+
+class MessageSendResult:
+ offset: Optional[int]
+ write_status: "MessageWriteStatus"
+
+
+class MessageWriteStatus(enum.Enum):
+ Written = 1
+ AlreadyWritten = 2
+
+
+class RetryPolicy:
+ connection_timeout_sec: float
+ overload_timeout_sec: float
+ retry_access_denied: bool = False
+
+
+class TopicWriterError(ydb.Error):
+ def __init__(self, message: str):
+ super(TopicWriterError, self).__init__(message)
+
+
+class TopicWriterClosedError(ydb.Error):
+ def __init__(self):
+ super().__init__("Topic writer already closed")
+
+
+class TopicWriterRepeatableError(TopicWriterError):
+ pass
+
+
+class TopicWriterStopped(TopicWriterError):
+ def __init__(self):
+ super(TopicWriterStopped, self).__init__(
+ "topic writer was stopped by call close"
+ )
+
+
+def default_serializer_message_content(data: Any) -> bytes:
+ if data is None:
+ return bytes()
+ if isinstance(data, bytes):
+ return data
+ if isinstance(data, bytearray):
+ return bytes(data)
+ if isinstance(data, str):
+ return data.encode(encoding="utf-8")
+ raise ValueError("can't serialize type %s to bytes" % type(data))
+
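A quick illustration of the payload types the serializer above accepts (sketch; the asserts follow directly from the branches):

    assert default_serializer_message_content(None) == b""
    assert default_serializer_message_content(b"raw") == b"raw"
    assert default_serializer_message_content(bytearray(b"buf")) == b"buf"
    assert default_serializer_message_content("text") == b"text"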
+
+def messages_to_proto_requests(
+ messages: List[InternalMessage],
+) -> List[StreamWriteMessage.FromClient]:
+ # todo split by proto message size and codec
+ res = []
+ for msg in messages:
+ req = StreamWriteMessage.FromClient(
+ StreamWriteMessage.WriteRequest(
+ messages=[msg.to_message_data()],
+ codec=msg.codec,
+ )
+ )
+ res.append(req)
+ return res
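Until the batching todo above lands, every message becomes its own request. A hedged sketch of the current behavior:

    msgs = [
        InternalMessage(PublicMessage("a", seqno=1)),
        InternalMessage(PublicMessage("b", seqno=2)),
    ]
    requests = messages_to_proto_requests(msgs)
    assert len(requests) == 2  # one StreamWriteMessage.FromClient per message for now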
diff --git a/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_asyncio.py
new file mode 100644
index 0000000000..7cb1f1db0b
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_asyncio.py
@@ -0,0 +1,684 @@
+import asyncio
+import concurrent.futures
+import datetime
+import gzip
+import typing
+from collections import deque
+from typing import Deque, AsyncIterator, Union, List, Optional, Dict, Callable
+
+import ydb
+from .topic_writer import (
+ PublicWriterSettings,
+ WriterSettings,
+ PublicMessage,
+ PublicWriterInitInfo,
+ InternalMessage,
+ TopicWriterStopped,
+ TopicWriterError,
+ messages_to_proto_requests,
+ PublicWriteResultTypes,
+ Message,
+)
+from .. import (
+ _apis,
+ issues,
+ check_retriable_error,
+ RetrySettings,
+)
+from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
+from .._grpc.grpcwrapper.ydb_topic import (
+ UpdateTokenRequest,
+ UpdateTokenResponse,
+ StreamWriteMessage,
+ WriterMessagesFromServerToClient,
+)
+from .._grpc.grpcwrapper.common_utils import (
+ IGrpcWrapperAsyncIO,
+ SupportedDriverType,
+ GrpcWrapperAsyncIO,
+)
+
+
+class WriterAsyncIO:
+ _loop: asyncio.AbstractEventLoop
+ _reconnector: "WriterAsyncIOReconnector"
+ _closed: bool
+
+ @property
+ def last_seqno(self) -> int:
+ raise NotImplementedError()
+
+ def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
+ self._loop = asyncio.get_running_loop()
+ self._closed = False
+ self._reconnector = WriterAsyncIOReconnector(
+ driver=driver, settings=WriterSettings(settings)
+ )
+
+ async def __aenter__(self) -> "WriterAsyncIO":
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ def __del__(self):
+ if self._closed or self._loop.is_closed():
+ return
+
+ self._loop.call_soon(self.close)
+
+ async def close(self, *, flush: bool = True):
+ if self._closed:
+ return
+
+ self._closed = True
+
+ await self._reconnector.close(flush)
+
+ async def write_with_ack(
+ self,
+ messages: Union[Message, List[Message]],
+ ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]:
+ """
+        This is the slow path and a bad choice in most cases.
+        Prefer write() with an optional flush(), or write_with_ack_future() and await the returned futures.
+
+        Send one or more messages to the server and wait for the acks.
+
+        To wait with a timeout, use asyncio.wait_for.
+ """
+ futures = await self.write_with_ack_future(messages)
+ if not isinstance(futures, list):
+ futures = [futures]
+
+ await asyncio.wait(futures)
+ results = [f.result() for f in futures]
+
+ return results if isinstance(messages, list) else results[0]
+
+ async def write_with_ack_future(
+ self,
+ messages: Union[Message, List[Message]],
+ ) -> Union[asyncio.Future, List[asyncio.Future]]:
+ """
+        Send one or more messages to the server.
+        Returns a future (or a list of futures) that can be awaited to check the send result.
+
+        Usually this method is fast, but it can block while the internal buffer is full.
+
+        To wait with a timeout, use asyncio.wait_for.
+ """
+ input_single_message = not isinstance(messages, list)
+ converted_messages = []
+ if isinstance(messages, list):
+ for m in messages:
+ converted_messages.append(PublicMessage._create_message(m))
+ else:
+ converted_messages = [PublicMessage._create_message(messages)]
+
+ futures = await self._reconnector.write_with_ack_future(converted_messages)
+ if input_single_message:
+ return futures[0]
+ else:
+ return futures
+
+ async def write(
+ self,
+ messages: Union[Message, List[Message]],
+ ):
+ """
+        Send one or more messages to the server.
+        The messages are placed into the internal buffer; delivery happens in the background.
+
+        To wait with a timeout, use asyncio.wait_for.
+ """
+ await self.write_with_ack_future(messages)
+
+ async def flush(self):
+ """
+        Force sending of all messages from the internal buffer and wait for the
+        server acks of all of them.
+
+        To wait with a timeout, use asyncio.wait_for.
+ """
+ return await self._reconnector.flush()
+
+ async def wait_init(self) -> PublicWriterInitInfo:
+ """
+        Wait until a real connection to the server is established.
+
+        To wait with a timeout, use asyncio.wait_for().
+ """
+ return await self._reconnector.wait_init()
+
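A minimal usage sketch of the class above, assuming an already-connected ydb.aio.Driver (the driver variable and topic path are hypothetical):

    async def produce(driver):
        settings = PublicWriterSettings(topic="/local/my-topic")
        async with WriterAsyncIO(driver, settings) as writer:
            await writer.write(["event-1", b"event-2"])  # str or bytes payloads
            await writer.flush()  # wait for the server acks before leaving the block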
+
+class WriterAsyncIOReconnector:
+ _closed: bool
+ _loop: asyncio.AbstractEventLoop
+ _credentials: Union[ydb.credentials.Credentials, None]
+ _driver: ydb.aio.Driver
+ _init_message: StreamWriteMessage.InitRequest
+ _init_info: asyncio.Future
+ _stream_connected: asyncio.Event
+ _settings: WriterSettings
+ _codec: PublicCodec
+ _codec_functions: Dict[PublicCodec, Callable[[bytes], bytes]]
+ _encode_executor: Optional[concurrent.futures.Executor]
+ _codec_selector_batch_num: int
+ _codec_selector_last_codec: Optional[PublicCodec]
+ _codec_selector_check_batches_interval: int
+
+ _last_known_seq_no: int
+ if typing.TYPE_CHECKING:
+ _messages_for_encode: asyncio.Queue[List[InternalMessage]]
+ else:
+ _messages_for_encode: asyncio.Queue
+ _messages: Deque[InternalMessage]
+ _messages_future: Deque[asyncio.Future]
+ _new_messages: asyncio.Queue
+ _stop_reason: asyncio.Future
+ _background_tasks: List[asyncio.Task]
+
+ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
+ self._closed = False
+ self._loop = asyncio.get_running_loop()
+ self._driver = driver
+ self._credentials = driver._credentials
+ self._init_message = settings.create_init_request()
+ self._new_messages = asyncio.Queue()
+ self._init_info = self._loop.create_future()
+ self._stream_connected = asyncio.Event()
+ self._settings = settings
+
+ self._codec_functions = {
+ PublicCodec.RAW: lambda data: data,
+ PublicCodec.GZIP: gzip.compress,
+ }
+
+ if settings.encoders:
+ self._codec_functions.update(settings.encoders)
+
+ self._encode_executor = settings.encoder_executor
+
+ self._codec_selector_batch_num = 0
+ self._codec_selector_last_codec = None
+ self._codec_selector_check_batches_interval = 10000
+
+ self._codec = self._settings.codec
+ if self._codec and self._codec not in self._codec_functions:
+ known_codecs = sorted(self._codec_functions.keys())
+ raise ValueError(
+ "Unknown codec for writer: %s, supported codecs: %s"
+ % (self._codec, known_codecs)
+ )
+
+ self._last_known_seq_no = 0
+ self._messages_for_encode = asyncio.Queue()
+ self._messages = deque()
+ self._messages_future = deque()
+ self._new_messages = asyncio.Queue()
+ self._stop_reason = self._loop.create_future()
+ self._background_tasks = [
+ asyncio.create_task(self._connection_loop(), name="connection_loop"),
+ asyncio.create_task(self._encode_loop(), name="encode_loop"),
+ ]
+
+ async def close(self, flush: bool):
+ if self._closed:
+ return
+
+ if flush:
+ await self.flush()
+
+ self._closed = True
+ self._stop(TopicWriterStopped())
+
+ for task in self._background_tasks:
+ task.cancel()
+ await asyncio.wait(self._background_tasks)
+
+        # if the writer was stopped by an error before close - re-raise that error
+ try:
+ self._check_stop()
+ except TopicWriterStopped:
+ pass
+
+ async def wait_init(self) -> PublicWriterInitInfo:
+ done, _ = await asyncio.wait(
+ [self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED
+ )
+ res = done.pop() # type: asyncio.Future
+ res_val = res.result()
+
+ if isinstance(res_val, BaseException):
+ raise res_val
+
+ return res_val
+
+ async def wait_stop(self) -> Exception:
+ return await self._stop_reason
+
+ async def write_with_ack_future(
+ self, messages: List[PublicMessage]
+ ) -> List[asyncio.Future]:
+ # todo check internal buffer limit
+ self._check_stop()
+
+ if self._settings.auto_seqno:
+ await self.wait_init()
+
+ internal_messages = self._prepare_internal_messages(messages)
+ messages_future = [self._loop.create_future() for _ in internal_messages]
+
+ self._messages_future.extend(messages_future)
+
+ if self._codec == PublicCodec.RAW:
+ self._add_messages_to_send_queue(internal_messages)
+ else:
+ self._messages_for_encode.put_nowait(internal_messages)
+
+ return messages_future
+
+ def _add_messages_to_send_queue(self, internal_messages: List[InternalMessage]):
+ self._messages.extend(internal_messages)
+ for m in internal_messages:
+ self._new_messages.put_nowait(m)
+
+ def _prepare_internal_messages(
+ self, messages: List[PublicMessage]
+ ) -> List[InternalMessage]:
+ if self._settings.auto_created_at:
+ now = datetime.datetime.now()
+ else:
+ now = None
+
+ res = []
+ for m in messages:
+ internal_message = InternalMessage(m)
+ if self._settings.auto_seqno:
+ if internal_message.seq_no is None:
+ self._last_known_seq_no += 1
+ internal_message.seq_no = self._last_known_seq_no
+ else:
+ raise TopicWriterError(
+ "Explicit seqno and auto_seq setting is mutual exclusive"
+ )
+ else:
+ if internal_message.seq_no is None or internal_message.seq_no == 0:
+ raise TopicWriterError(
+ "Empty seqno and auto_seq setting is disabled"
+ )
+ elif internal_message.seq_no <= self._last_known_seq_no:
+ raise TopicWriterError(
+ "Message seqno is duplicated: %s" % internal_message.seq_no
+ )
+ else:
+ self._last_known_seq_no = internal_message.seq_no
+
+ if self._settings.auto_created_at:
+ if internal_message.created_at is not None:
+ raise TopicWriterError(
+ "Explicit set auto_created_at and setting auto_created_at is mutual exclusive"
+ )
+ else:
+ internal_message.created_at = now
+
+ res.append(internal_message)
+
+ return res
+
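The seqno rules enforced above, illustrated (hedged sketch using PublicMessage from topic_writer.py):

    # auto_seqno=True (default): messages must not carry an explicit seqno
    ok_auto = [PublicMessage("a"), PublicMessage("b")]  # writer assigns 1, 2, ...
    # auto_seqno=False: every message needs a strictly increasing positive seqno
    ok_manual = [PublicMessage("a", seqno=1), PublicMessage("b", seqno=2)]
    bad = [PublicMessage("a", seqno=2), PublicMessage("b", seqno=2)]  # duplicate -> TopicWriterError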
+ def _check_stop(self):
+ if self._stop_reason.done():
+ raise self._stop_reason.result()
+
+ async def _connection_loop(self):
+ retry_settings = RetrySettings() # todo
+
+ while True:
+ attempt = 0 # todo calc and reset
+ tasks = []
+
+ # noinspection PyBroadException
+ stream_writer = None
+ try:
+ stream_writer = await WriterAsyncIOStream.create(
+ self._driver,
+ self._init_message,
+ self._settings.update_token_interval,
+ )
+ try:
+ self._last_known_seq_no = stream_writer.last_seqno
+ self._init_info.set_result(
+ PublicWriterInitInfo(
+ last_seqno=stream_writer.last_seqno,
+ supported_codecs=stream_writer.supported_codecs,
+ )
+ )
+ except asyncio.InvalidStateError:
+ pass
+
+ self._stream_connected.set()
+
+ send_loop = asyncio.create_task(
+ self._send_loop(stream_writer), name="writer send loop"
+ )
+ receive_loop = asyncio.create_task(
+ self._read_loop(stream_writer), name="writer receive loop"
+ )
+
+ tasks = [send_loop, receive_loop]
+ done, _ = await asyncio.wait(
+ [send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED
+ )
+ await stream_writer.close()
+ done.pop().result()
+ except issues.Error as err:
+ # todo log error
+ print(err)
+
+ err_info = check_retriable_error(err, retry_settings, attempt)
+ if not err_info.is_retriable:
+ self._stop(err)
+ return
+
+ await asyncio.sleep(err_info.sleep_timeout_seconds)
+
+ except (asyncio.CancelledError, Exception) as err:
+ self._stop(err)
+ return
+ finally:
+ if stream_writer:
+ await stream_writer.close()
+                for task in tasks:
+                    task.cancel()
+                if tasks:  # asyncio.wait raises ValueError on an empty set
+                    await asyncio.wait(tasks)
+
+ async def _encode_loop(self):
+ while True:
+ messages = await self._messages_for_encode.get()
+ while not self._messages_for_encode.empty():
+ messages.extend(self._messages_for_encode.get_nowait())
+
+ batch_codec = await self._codec_selector(messages)
+ await self._encode_data_inplace(batch_codec, messages)
+ self._add_messages_to_send_queue(messages)
+
+ async def _encode_data_inplace(
+ self, codec: PublicCodec, messages: List[InternalMessage]
+ ):
+ if codec == PublicCodec.RAW:
+ return
+
+ eventloop = asyncio.get_running_loop()
+ encode_waiters = []
+ encoder_function = self._codec_functions[codec]
+
+ for message in messages:
+ encoded_data_futures = eventloop.run_in_executor(
+ self._encode_executor, encoder_function, message.get_bytes()
+ )
+ encode_waiters.append(encoded_data_futures)
+
+ encoded_datas = await asyncio.gather(*encode_waiters)
+
+ for index, data in enumerate(encoded_datas):
+ message = messages[index]
+ message.codec = codec
+ message.data = data
+
+ async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec:
+ if self._codec is not None:
+ return self._codec
+
+ if self._codec_selector_last_codec is None:
+ available_codecs = await self._get_available_codecs()
+
+            # use each of the available encoders at the start, to surface problems
+            # with rarely used encoders early (on the writer or the reader side)
+ if self._codec_selector_batch_num < len(available_codecs):
+ codec = available_codecs[self._codec_selector_batch_num]
+ else:
+ codec = await self._codec_selector_by_check_compress(messages)
+ self._codec_selector_last_codec = codec
+ else:
+ if (
+ self._codec_selector_batch_num
+ % self._codec_selector_check_batches_interval
+ == 0
+ ):
+ self._codec_selector_last_codec = (
+ await self._codec_selector_by_check_compress(messages)
+ )
+ codec = self._codec_selector_last_codec
+ self._codec_selector_batch_num += 1
+ return codec
+
+ async def _get_available_codecs(self) -> List[PublicCodec]:
+ info = await self.wait_init()
+ topic_supported_codecs = info.supported_codecs
+ if not topic_supported_codecs:
+ topic_supported_codecs = [PublicCodec.RAW, PublicCodec.GZIP]
+
+ res = []
+ for codec in topic_supported_codecs:
+ if codec in self._codec_functions:
+ res.append(codec)
+
+ if not res:
+ raise TopicWriterError("Writer does not support topic's codecs")
+
+ res.sort()
+
+ return res
+
+ async def _codec_selector_by_check_compress(
+ self, messages: List[InternalMessage]
+ ) -> PublicCodec:
+ """
+        Try compressing the messages and choose the codec with the smallest result size.
+ """
+
+ test_messages = messages[:10]
+
+ available_codecs = await self._get_available_codecs()
+ if len(available_codecs) == 1:
+ return available_codecs[0]
+
+ def get_compressed_size(codec) -> int:
+ s = 0
+ f = self._codec_functions[codec]
+
+ for m in test_messages:
+ encoded = f(m.get_bytes())
+ s += len(encoded)
+
+ return s
+
+ def select_codec() -> PublicCodec:
+ min_codec = available_codecs[0]
+ min_size = get_compressed_size(min_codec)
+ for codec in available_codecs[1:]:
+ size = get_compressed_size(codec)
+ if size < min_size:
+ min_codec = codec
+ min_size = size
+ return min_codec
+
+ loop = asyncio.get_running_loop()
+ codec = await loop.run_in_executor(self._encode_executor, select_codec)
+ return codec
+
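The same size-comparison idea as a standalone, runnable sketch (RAW is modeled as the identity function, GZIP via the stdlib):

    import gzip

    def smallest_codec(payloads):
        candidates = {"raw": lambda b: b, "gzip": gzip.compress}
        sizes = {name: sum(len(f(p)) for p in payloads) for name, f in candidates.items()}
        return min(sizes, key=sizes.get)  # codec with the smallest total size

    print(smallest_codec([b"x" * 1000] * 10))  # "gzip" wins on repetitive data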
+ async def _read_loop(self, writer: "WriterAsyncIOStream"):
+ while True:
+ resp = await writer.receive()
+
+ for ack in resp.acks:
+ self._handle_receive_ack(ack)
+
+ def _handle_receive_ack(self, ack):
+ current_message = self._messages.popleft()
+ message_future = self._messages_future.popleft()
+ if current_message.seq_no != ack.seq_no:
+ raise TopicWriterError(
+ "internal error - receive unexpected ack. Expected seqno: %s, received seqno: %s"
+ % (current_message.seq_no, ack.seq_no)
+ )
+ message_future.set_result(
+ None
+ ) # todo - return result with offset or skip status
+
+ async def _send_loop(self, writer: "WriterAsyncIOStream"):
+ try:
+ messages = list(self._messages)
+
+ last_seq_no = 0
+ for m in messages:
+ writer.write([m])
+ last_seq_no = m.seq_no
+
+ while True:
+ m = await self._new_messages.get() # type: InternalMessage
+ if m.seq_no > last_seq_no:
+ writer.write([m])
+        except Exception as e:
+            self._stop(e)
+
+ def _stop(self, reason: Exception):
+ if reason is None:
+ raise Exception("writer stop reason can not be None")
+
+ if self._stop_reason.done():
+ return
+
+ self._stop_reason.set_result(reason)
+
+ async def flush(self):
+ self._check_stop()
+ if not self._messages_future:
+ return
+
+        # wait for the ack of the last message
+ await asyncio.wait((self._messages_future[-1],))
+
+
+class WriterAsyncIOStream:
+ # todo slots
+ _closed: bool
+
+ last_seqno: int
+ supported_codecs: Optional[List[PublicCodec]]
+
+ _stream: IGrpcWrapperAsyncIO
+ _requests: asyncio.Queue
+ _responses: AsyncIterator
+
+ _update_token_interval: Optional[Union[int, float]]
+ _update_token_task: Optional[asyncio.Task]
+ _update_token_event: asyncio.Event
+ _get_token_function: Optional[Callable[[], str]]
+
+ def __init__(
+ self,
+ update_token_interval: Optional[Union[int, float]] = None,
+ get_token_function: Optional[Callable[[], str]] = None,
+ ):
+ self._closed = False
+
+ self._update_token_interval = update_token_interval
+ self._get_token_function = get_token_function
+ self._update_token_event = asyncio.Event()
+ self._update_token_task = None
+
+ async def close(self):
+ if self._closed:
+ return
+ self._closed = True
+
+ if self._update_token_task:
+ self._update_token_task.cancel()
+ await asyncio.wait([self._update_token_task])
+
+ self._stream.close()
+
+ @staticmethod
+ async def create(
+ driver: SupportedDriverType,
+ init_request: StreamWriteMessage.InitRequest,
+ update_token_interval: Optional[Union[int, float]] = None,
+ ) -> "WriterAsyncIOStream":
+ stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)
+
+ await stream.start(
+ driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite
+ )
+
+ creds = driver._credentials
+ writer = WriterAsyncIOStream(
+ update_token_interval=update_token_interval,
+ get_token_function=creds.get_auth_token if creds else lambda: "",
+ )
+ await writer._start(stream, init_request)
+ return writer
+
+ async def receive(self) -> StreamWriteMessage.WriteResponse:
+ while True:
+ item = await self._stream.receive()
+
+ if isinstance(item, StreamWriteMessage.WriteResponse):
+ return item
+ if isinstance(item, UpdateTokenResponse):
+ self._update_token_event.set()
+ continue
+
+            # todo log unknown messages instead of raising an exception
+            raise Exception("Unknown message while reading writer answers: %s" % item)
+
+ async def _start(
+ self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest
+ ):
+ stream.write(StreamWriteMessage.FromClient(init_message))
+
+ resp = await stream.receive()
+ self._ensure_ok(resp)
+ if not isinstance(resp, StreamWriteMessage.InitResponse):
+ raise TopicWriterError("Unexpected answer for init request: %s" % resp)
+
+ self.last_seqno = resp.last_seq_no
+ self.supported_codecs = [PublicCodec(codec) for codec in resp.supported_codecs]
+
+ self._stream = stream
+
+ if self._update_token_interval is not None:
+ self._update_token_event.set()
+ self._update_token_task = asyncio.create_task(
+ self._update_token_loop(), name="update_token_loop"
+ )
+
+ @staticmethod
+ def _ensure_ok(message: WriterMessagesFromServerToClient):
+ if not message.status.is_success():
+ raise TopicWriterError(
+ "status error from server in writer: %s", message.status
+ )
+
+ def write(self, messages: List[InternalMessage]):
+ if self._closed:
+ raise RuntimeError("Can not write on closed stream.")
+
+ for request in messages_to_proto_requests(messages):
+ self._stream.write(request)
+
+ async def _update_token_loop(self):
+ while True:
+ await asyncio.sleep(self._update_token_interval)
+ await self._update_token(token=self._get_token_function())
+
+ async def _update_token(self, token: str):
+ await self._update_token_event.wait()
+ try:
+ msg = StreamWriteMessage.FromClient(UpdateTokenRequest(token))
+ self._stream.write(msg)
+ finally:
+ self._update_token_event.clear()
diff --git a/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_sync.py b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_sync.py
new file mode 100644
index 0000000000..e6b512387f
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_sync.py
@@ -0,0 +1,123 @@
+from __future__ import annotations
+
+import asyncio
+from concurrent.futures import Future
+from typing import Union, List, Optional
+
+from .._grpc.grpcwrapper.common_utils import SupportedDriverType
+from .topic_writer import (
+ PublicWriterSettings,
+ PublicWriterInitInfo,
+ PublicWriteResult,
+ Message,
+ TopicWriterClosedError,
+)
+
+from .topic_writer_asyncio import WriterAsyncIO
+from .._topic_common.common import (
+ _get_shared_event_loop,
+ TimeoutType,
+ CallFromSyncToAsync,
+)
+
+
+class WriterSync:
+ _caller: CallFromSyncToAsync
+ _async_writer: WriterAsyncIO
+ _closed: bool
+
+ def __init__(
+ self,
+ driver: SupportedDriverType,
+ settings: PublicWriterSettings,
+ *,
+ eventloop: Optional[asyncio.AbstractEventLoop] = None,
+ ):
+
+ self._closed = False
+
+ if eventloop:
+ loop = eventloop
+ else:
+ loop = _get_shared_event_loop()
+
+ self._caller = CallFromSyncToAsync(loop)
+
+ async def create_async_writer():
+ return WriterAsyncIO(driver, settings)
+
+ self._async_writer = self._caller.safe_call_with_result(
+ create_async_writer(), None
+ )
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def close(self, *, flush: bool = True, timeout: TimeoutType = None):
+ if self._closed:
+ return
+
+ self._closed = True
+
+ self._caller.safe_call_with_result(
+ self._async_writer.close(flush=flush), timeout
+ )
+
+ def _check_closed(self):
+ if self._closed:
+ raise TopicWriterClosedError()
+
+ def async_flush(self) -> Future:
+ self._check_closed()
+
+ return self._caller.unsafe_call_with_future(self._async_writer.flush())
+
+ def flush(self, *, timeout=None):
+ self._check_closed()
+
+ return self._caller.unsafe_call_with_result(self._async_writer.flush(), timeout)
+
+ def async_wait_init(self) -> Future[PublicWriterInitInfo]:
+ self._check_closed()
+
+ return self._caller.unsafe_call_with_future(self._async_writer.wait_init())
+
+ def wait_init(self, *, timeout: TimeoutType = None) -> PublicWriterInitInfo:
+ self._check_closed()
+
+ return self._caller.unsafe_call_with_result(
+ self._async_writer.wait_init(), timeout
+ )
+
+ def write(
+ self,
+ messages: Union[Message, List[Message]],
+ timeout: TimeoutType = None,
+ ):
+ self._check_closed()
+
+ self._caller.safe_call_with_result(self._async_writer.write(messages), timeout)
+
+ def async_write_with_ack(
+ self,
+ messages: Union[Message, List[Message]],
+ ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]:
+ self._check_closed()
+
+ return self._caller.unsafe_call_with_future(
+ self._async_writer.write_with_ack(messages)
+ )
+
+ def write_with_ack(
+ self,
+ messages: Union[Message, List[Message]],
+ timeout: Union[float, None] = None,
+ ) -> Union[PublicWriteResult, List[PublicWriteResult]]:
+ self._check_closed()
+
+ return self._caller.unsafe_call_with_result(
+ self._async_writer.write_with_ack(messages), timeout=timeout
+ )
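A minimal usage sketch of the sync wrapper, assuming an already-connected ydb.Driver (the driver variable and topic path are hypothetical):

    with WriterSync(driver, PublicWriterSettings(topic="/local/my-topic")) as writer:
        writer.write(["event-1", b"event-2"], timeout=10)
        writer.flush(timeout=10)  # block until the server acks everything written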
diff --git a/ydb/public/sdk/python3/ydb/_utilities.py b/ydb/public/sdk/python3/ydb/_utilities.py
index 32419b1bf9..0b72a19897 100644
--- a/ydb/public/sdk/python3/ydb/_utilities.py
+++ b/ydb/public/sdk/python3/ydb/_utilities.py
@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
-import six
+import threading
import codecs
from concurrent import futures
import functools
import hashlib
import collections
+import urllib.parse
from . import ydb_version
try:
@@ -55,8 +56,8 @@ def parse_connection_string(connection_string):
# default is grpcs
cs = _grpcs_protocol + cs
- p = six.moves.urllib.parse.urlparse(connection_string)
- b = six.moves.urllib.parse.parse_qs(p.query)
+ p = urllib.parse.urlparse(connection_string)
+ b = urllib.parse.parse_qs(p.query)
database = b.get("database", [])
assert len(database) > 0
@@ -77,11 +78,9 @@ def wrap_async_call_exceptions(f):
def get_query_hash(yql_text):
try:
- return hashlib.sha256(
- six.text_type(yql_text, "utf-8").encode("utf-8")
- ).hexdigest()
+ return hashlib.sha256(str(yql_text, "utf-8").encode("utf-8")).hexdigest()
except TypeError:
- return hashlib.sha256(six.text_type(yql_text).encode("utf-8")).hexdigest()
+ return hashlib.sha256(str(yql_text).encode("utf-8")).hexdigest()
class LRUCache(object):
@@ -159,3 +158,17 @@ class SyncResponseIterator(object):
def __next__(self):
return self._next()
+
+
+class AtomicCounter:
+ _lock: threading.Lock
+ _value: int
+
+ def __init__(self, initial_value: int = 0):
+ self._lock = threading.Lock()
+ self._value = initial_value
+
+ def inc_and_get(self) -> int:
+ with self._lock:
+ self._value += 1
+ return self._value
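Behavior of the new helper, in brief:

    counter = AtomicCounter()
    assert counter.inc_and_get() == 1
    assert counter.inc_and_get() == 2  # increments are serialized by the internal lock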
diff --git a/ydb/public/sdk/python3/ydb/aio/connection.py b/ydb/public/sdk/python3/ydb/aio/connection.py
index 88ab738c6a..fbfcfaaf6c 100644
--- a/ydb/public/sdk/python3/ydb/aio/connection.py
+++ b/ydb/public/sdk/python3/ydb/aio/connection.py
@@ -1,5 +1,6 @@
import logging
import asyncio
+import typing
from typing import Any, Tuple, Callable, Iterable
import collections
import grpc
@@ -24,11 +25,19 @@ from ydb.driver import DriverConfig
from ydb.settings import BaseRequestSettings
from ydb import issues
+# Workaround: keeps IDEs happy while remaining universal at runtime
+if typing.TYPE_CHECKING:
+ from ydb._grpc.v4 import ydb_topic_v1_pb2_grpc
+else:
+ from ydb._grpc.common import ydb_topic_v1_pb2_grpc
+
+
_stubs_list = (
_apis.TableService.Stub,
_apis.SchemeService.Stub,
_apis.DiscoveryService.Stub,
_apis.CmsService.Stub,
+ ydb_topic_v1_pb2_grpc.TopicServiceStub,
)
logger = logging.getLogger(__name__)
diff --git a/ydb/public/sdk/python3/ydb/aio/credentials.py b/ydb/public/sdk/python3/ydb/aio/credentials.py
index e98404407b..93868b279a 100644
--- a/ydb/public/sdk/python3/ydb/aio/credentials.py
+++ b/ydb/public/sdk/python3/ydb/aio/credentials.py
@@ -3,7 +3,6 @@ import time
import abc
import asyncio
import logging
-import six
from ydb import issues, credentials
logger = logging.getLogger(__name__)
@@ -55,7 +54,6 @@ class _AtMostOneExecution(object):
asyncio.ensure_future(self._wrapped_execution(callback))
-@six.add_metaclass(abc.ABCMeta)
class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials):
def __init__(self):
super(AbstractExpiringTokenCredentials, self).__init__()
diff --git a/ydb/public/sdk/python3/ydb/aio/driver.py b/ydb/public/sdk/python3/ydb/aio/driver.py
index 3bf6cca82e..1aa3ad2757 100644
--- a/ydb/public/sdk/python3/ydb/aio/driver.py
+++ b/ydb/public/sdk/python3/ydb/aio/driver.py
@@ -1,42 +1,7 @@
-import os
-
from . import pool, scheme, table
import ydb
-from ydb.driver import get_config
-
-
-def default_credentials(credentials=None):
- if credentials is not None:
- return credentials
-
- service_account_key_file = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS")
- if service_account_key_file is not None:
- from .iam import ServiceAccountCredentials
-
- return ServiceAccountCredentials.from_file(service_account_key_file)
-
- anonymous_credetials = os.getenv("YDB_ANONYMOUS_CREDENTIALS", "0") == "1"
- if anonymous_credetials:
- return ydb.credentials.AnonymousCredentials()
-
- metadata_credentials = os.getenv("YDB_METADATA_CREDENTIALS", "0") == "1"
- if metadata_credentials:
- from .iam import MetadataUrlCredentials
-
- return MetadataUrlCredentials()
-
- access_token = os.getenv("YDB_ACCESS_TOKEN_CREDENTIALS")
- if access_token is not None:
- return ydb.credentials.AccessTokenCredentials(access_token)
-
- # (legacy instantiation)
- creds = ydb.auth_helpers.construct_credentials_from_environ()
- if creds is not None:
- return creds
-
- from .iam import MetadataUrlCredentials
-
- return MetadataUrlCredentials()
+from .. import _utilities
+from ydb.driver import get_config, default_credentials
class DriverConfig(ydb.DriverConfig):
@@ -56,7 +21,7 @@ class DriverConfig(ydb.DriverConfig):
def default_from_connection_string(
cls, connection_string, root_certificates=None, credentials=None, **kwargs
):
- endpoint, database = ydb.parse_connection_string(connection_string)
+ endpoint, database = _utilities.parse_connection_string(connection_string)
return cls(
endpoint,
database,
@@ -67,6 +32,8 @@ class DriverConfig(ydb.DriverConfig):
class Driver(pool.ConnectionPool):
+ _credentials: ydb.Credentials # used for topic clients
+
def __init__(
self,
driver_config=None,
@@ -77,6 +44,8 @@ class Driver(pool.ConnectionPool):
credentials=None,
**kwargs
):
+        from .. import topic  # local import to avoid an import cycle
+
config = get_config(
driver_config,
connection_string,
@@ -89,5 +58,8 @@ class Driver(pool.ConnectionPool):
super(Driver, self).__init__(config)
+ self._credentials = config.credentials
+
self.scheme_client = scheme.SchemeClient(self)
self.table_client = table.TableClient(self, config.table_client_settings)
+ self.topic_client = topic.TopicClientAsyncIO(self, config.topic_client_settings)
diff --git a/ydb/public/sdk/python3/ydb/aio/iam.py b/ydb/public/sdk/python3/ydb/aio/iam.py
index 51b650f24b..b56c066043 100644
--- a/ydb/public/sdk/python3/ydb/aio/iam.py
+++ b/ydb/public/sdk/python3/ydb/aio/iam.py
@@ -3,7 +3,6 @@ import time
import abc
import logging
-import six
from ydb.iam import auth
from .credentials import AbstractExpiringTokenCredentials
@@ -24,7 +23,6 @@ except ImportError:
aiohttp = None
-@six.add_metaclass(abc.ABCMeta)
class TokenServiceCredentials(AbstractExpiringTokenCredentials):
def __init__(self, iam_endpoint=None, iam_channel_credentials=None):
super(TokenServiceCredentials, self).__init__()
diff --git a/ydb/public/sdk/python3/ydb/aio/table.py b/ydb/public/sdk/python3/ydb/aio/table.py
index f937a9283c..06f8ca7c81 100644
--- a/ydb/public/sdk/python3/ydb/aio/table.py
+++ b/ydb/public/sdk/python3/ydb/aio/table.py
@@ -13,7 +13,6 @@ from ydb.table import (
_scan_query_request_factory,
_wrap_scan_query_response,
BaseTxContext,
- _allow_split_transaction,
)
from . import _utilities
from ydb import _apis, _session_impl
@@ -121,9 +120,7 @@ class Session(BaseSession):
set_read_replicas_settings,
)
- def transaction(
- self, tx_mode=None, *, allow_split_transactions=_allow_split_transaction
- ):
+ def transaction(self, tx_mode=None, *, allow_split_transactions=None):
return TxContext(
self._driver,
self._state,
@@ -194,8 +191,6 @@ class TxContext(BaseTxContext):
self, query, parameters=None, commit_tx=False, settings=None
): # pylint: disable=W0236
- self._check_split()
-
return await super().execute(query, parameters, commit_tx, settings)
async def commit(self, settings=None): # pylint: disable=W0236
diff --git a/ydb/public/sdk/python3/ydb/auth_helpers.py b/ydb/public/sdk/python3/ydb/auth_helpers.py
index 5d889555df..6399c3cfdf 100644
--- a/ydb/public/sdk/python3/ydb/auth_helpers.py
+++ b/ydb/public/sdk/python3/ydb/auth_helpers.py
@@ -1,9 +1,6 @@
# -*- coding: utf-8 -*-
import os
-from . import credentials, tracing
-import warnings
-
def read_bytes(f):
with open(f, "rb") as fr:
@@ -15,43 +12,3 @@ def load_ydb_root_certificate():
if path is not None and os.path.exists(path):
return read_bytes(path)
return None
-
-
-def construct_credentials_from_environ(tracer=None):
- tracer = tracer if tracer is not None else tracing.Tracer(None)
- warnings.warn(
- "using construct_credentials_from_environ method for credentials instantiation is deprecated and will be "
- "removed in the future major releases. Please instantialize credentials by default or provide correct credentials "
- "instance to the Driver."
- )
-
- # dynamically import required authentication libraries
- if (
- os.getenv("USE_METADATA_CREDENTIALS") is not None
- and int(os.getenv("USE_METADATA_CREDENTIALS")) == 1
- ):
- import ydb.iam
-
- tracing.trace(tracer, {"credentials.metadata": True})
- return ydb.iam.MetadataUrlCredentials()
-
- if os.getenv("YDB_TOKEN") is not None:
- tracing.trace(tracer, {"credentials.access_token": True})
- return credentials.AuthTokenCredentials(os.getenv("YDB_TOKEN"))
-
- if os.getenv("SA_KEY_FILE") is not None:
-
- import ydb.iam
-
- tracing.trace(tracer, {"credentials.sa_key_file": True})
- root_certificates_file = os.getenv("SSL_ROOT_CERTIFICATES_FILE", None)
- iam_channel_credentials = {}
- if root_certificates_file is not None:
- iam_channel_credentials = {
- "root_certificates": read_bytes(root_certificates_file)
- }
- return ydb.iam.ServiceAccountCredentials.from_file(
- os.getenv("SA_KEY_FILE"),
- iam_channel_credentials=iam_channel_credentials,
- iam_endpoint=os.getenv("IAM_ENDPOINT", "iam.api.cloud.yandex.net:443"),
- )
diff --git a/ydb/public/sdk/python3/ydb/convert.py b/ydb/public/sdk/python3/ydb/convert.py
index b231bb1091..81348d311c 100644
--- a/ydb/public/sdk/python3/ydb/convert.py
+++ b/ydb/public/sdk/python3/ydb/convert.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
import decimal
from google.protobuf import struct_pb2
-import six
from . import issues, types, _apis
@@ -13,7 +12,7 @@ _DecimalNanRepr = 10**35 + 1
_DecimalInfRepr = 10**35
_DecimalSignedInfRepr = -(10**35)
_primitive_type_by_id = {}
-_default_allow_truncated_result = True
+_default_allow_truncated_result = False
def _initialize():
@@ -82,9 +81,7 @@ def _pb_to_list(type_pb, value_pb, table_client_settings):
def _pb_to_tuple(type_pb, value_pb, table_client_settings):
return tuple(
_to_native_value(item_type, item_value, table_client_settings)
- for item_type, item_value in six.moves.zip(
- type_pb.tuple_type.elements, value_pb.items
- )
+ for item_type, item_value in zip(type_pb.tuple_type.elements, value_pb.items)
)
@@ -107,7 +104,7 @@ class _Struct(_DotDict):
def _pb_to_struct(type_pb, value_pb, table_client_settings):
result = _Struct()
- for member, item in six.moves.zip(type_pb.struct_type.members, value_pb.items):
+ for member, item in zip(type_pb.struct_type.members, value_pb.items):
result[member.name] = _to_native_value(member.type, item, table_client_settings)
return result
@@ -202,9 +199,7 @@ def _list_to_pb(type_pb, value):
def _tuple_to_pb(type_pb, value):
value_pb = _apis.ydb_value.Value()
- for element_type, element_value in six.moves.zip(
- type_pb.tuple_type.elements, value
- ):
+ for element_type, element_value in zip(type_pb.tuple_type.elements, value):
value_item_proto = value_pb.items.add()
value_item_proto.MergeFrom(_from_native_value(element_type, element_value))
return value_pb
@@ -290,7 +285,7 @@ def parameters_to_pb(parameters_types, parameters_values):
return {}
param_values_pb = {}
- for name, type_pb in six.iteritems(parameters_types):
+ for name, type_pb in parameters_types.items():
result = _apis.ydb_value.TypedValue()
ttype = type_pb
if isinstance(type_pb, types.AbstractTypeBuilder):
@@ -332,7 +327,7 @@ class _ResultSet(object):
for row_proto in message.rows:
row = _Row(message.columns)
- for column, value, column_info in six.moves.zip(
+ for column, value, column_info in zip(
message.columns, row_proto.items, column_parsers
):
v_type = value.WhichOneof("value")
@@ -400,9 +395,7 @@ class _LazyRow(_DotDict):
super(_LazyRow, self).__init__()
self._columns = columns
self._table_client_settings = table_client_settings
- for i, (column, row_item) in enumerate(
- six.moves.zip(self._columns, proto_row.items)
- ):
+ for i, (column, row_item) in enumerate(zip(self._columns, proto_row.items)):
super(_LazyRow, self).__setitem__(
column.name,
_LazyRowItem(row_item, column.type, table_client_settings, parsers[i]),
diff --git a/ydb/public/sdk/python3/ydb/credentials.py b/ydb/public/sdk/python3/ydb/credentials.py
index 8e22fe2a84..2a2dea3b33 100644
--- a/ydb/public/sdk/python3/ydb/credentials.py
+++ b/ydb/public/sdk/python3/ydb/credentials.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import abc
-import six
+import typing
+
from . import tracing, issues, connection
from . import settings as settings_impl
import threading
@@ -9,8 +10,7 @@ import logging
import time
# Workaround for good IDE and universal for runtime
-# noinspection PyUnreachableCode
-if False:
+if typing.TYPE_CHECKING:
from ._grpc.v4.protos import ydb_auth_pb2
from ._grpc.v4 import ydb_auth_v1_pb2_grpc
else:
@@ -22,15 +22,13 @@ YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket"
logger = logging.getLogger(__name__)
-@six.add_metaclass(abc.ABCMeta)
-class AbstractCredentials(object):
+class AbstractCredentials(abc.ABC):
"""
An abstract class that provides auth metadata
"""
-@six.add_metaclass(abc.ABCMeta)
-class Credentials(object):
+class Credentials(abc.ABC):
def __init__(self, tracer=None):
self.tracer = tracer if tracer is not None else tracing.Tracer(None)
@@ -41,6 +39,12 @@ class Credentials(object):
"""
pass
+ def get_auth_token(self) -> str:
+ for header, token in self.auth_metadata():
+ if header == YDB_AUTH_TICKET_HEADER:
+ return token
+ return ""
+
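A sketch of the new helper in action (the subclass is hypothetical; real providers implement auth_metadata themselves):

    class StaticCredentials(Credentials):
        def auth_metadata(self):
            return [(YDB_AUTH_TICKET_HEADER, "my-token")]

    assert StaticCredentials().get_auth_token() == "my-token"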
class OneToManyValue(object):
def __init__(self):
@@ -87,7 +91,6 @@ class AtMostOneExecution(object):
self._can_schedule = True
-@six.add_metaclass(abc.ABCMeta)
class AbstractExpiringTokenCredentials(Credentials):
def __init__(self, tracer=None):
super(AbstractExpiringTokenCredentials, self).__init__(tracer)
diff --git a/ydb/public/sdk/python3/ydb/dbapi/cursor.py b/ydb/public/sdk/python3/ydb/dbapi/cursor.py
index 71175abf4e..eb26dc2b0f 100644
--- a/ydb/public/sdk/python3/ydb/dbapi/cursor.py
+++ b/ydb/public/sdk/python3/ydb/dbapi/cursor.py
@@ -5,8 +5,6 @@ import datetime
import itertools
import logging
-import six
-
import ydb
from .errors import DatabaseError
@@ -42,7 +40,7 @@ def render_datetime(value):
def render(value):
if value is None:
return "NULL"
- if isinstance(value, six.string_types):
+ if isinstance(value, str):
return render_str(value)
if isinstance(value, datetime.datetime):
return render_datetime(value)
diff --git a/ydb/public/sdk/python3/ydb/default_pem.py b/ydb/public/sdk/python3/ydb/default_pem.py
index 92286ba237..b8272efd29 100644
--- a/ydb/public/sdk/python3/ydb/default_pem.py
+++ b/ydb/public/sdk/python3/ydb/default_pem.py
@@ -1,6 +1,3 @@
-import six
-
-
data = """
# Issuer: CN=GlobalSign Root CA O=GlobalSign nv-sa OU=Root CA
# Subject: CN=GlobalSign Root CA O=GlobalSign nv-sa OU=Root CA
@@ -4686,6 +4683,4 @@ LpuQKbSbIERsmR+QqQ==
def load_default_pem():
global data
- if six.PY3:
- return data.encode("utf-8")
- return data
+ return data.encode("utf-8")
diff --git a/ydb/public/sdk/python3/ydb/driver.py b/ydb/public/sdk/python3/ydb/driver.py
index 9b3fa99cfa..e3274687ea 100644
--- a/ydb/public/sdk/python3/ydb/driver.py
+++ b/ydb/public/sdk/python3/ydb/driver.py
@@ -1,15 +1,11 @@
# -*- coding: utf-8 -*-
from . import credentials as credentials_impl, table, scheme, pool
from . import tracing
-import six
import os
import grpc
from . import _utilities
-if six.PY2:
- Any = None
-else:
- from typing import Any # noqa
+from typing import Any # noqa
class RPCCompression:
@@ -23,10 +19,17 @@ class RPCCompression:
def default_credentials(credentials=None, tracer=None):
tracer = tracer if tracer is not None else tracing.Tracer(None)
with tracer.trace("Driver.default_credentials") as ctx:
- if credentials is not None:
+ if credentials is None:
+ ctx.trace({"credentials.anonymous": True})
+ return credentials_impl.AnonymousCredentials()
+ else:
ctx.trace({"credentials.prepared": True})
return credentials
+
+def credentials_from_env_variables(tracer=None):
+ tracer = tracer if tracer is not None else tracing.Tracer(None)
+ with tracer.trace("Driver.credentials_from_env_variables") as ctx:
service_account_key_file = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS")
if service_account_key_file is not None:
ctx.trace({"credentials.service_account_key_file": True})
@@ -51,9 +54,7 @@ def default_credentials(credentials=None, tracer=None):
ctx.trace({"credentials.access_token": True})
return credentials_impl.AuthTokenCredentials(access_token)
- import ydb.iam
-
- return ydb.iam.MetadataUrlCredentials(tracer=tracer)
+ return default_credentials(None, tracer)
class DriverConfig(object):
@@ -70,6 +71,7 @@ class DriverConfig(object):
"grpc_keep_alive_timeout",
"secure_channel",
"table_client_settings",
+ "topic_client_settings",
"endpoints",
"primary_user_agent",
"tracer",
@@ -92,6 +94,7 @@ class DriverConfig(object):
private_key=None,
grpc_keep_alive_timeout=None,
table_client_settings=None,
+ topic_client_settings=None,
endpoints=None,
primary_user_agent="python-library",
tracer=None,
@@ -138,6 +141,7 @@ class DriverConfig(object):
self.private_key = private_key
self.grpc_keep_alive_timeout = grpc_keep_alive_timeout
self.table_client_settings = table_client_settings
+ self.topic_client_settings = topic_client_settings
self.primary_user_agent = primary_user_agent
self.tracer = tracer if tracer is not None else tracing.Tracer(None)
self.grpc_lb_policy_name = grpc_lb_policy_name
@@ -228,6 +232,8 @@ class Driver(pool.ConnectionPool):
:param database: A database path
:param credentials: A credentials. If not specifed credentials constructed by default.
"""
+    from . import topic  # local import to avoid an import cycle
+
driver_config = get_config(
driver_config,
connection_string,
@@ -238,5 +244,9 @@ class Driver(pool.ConnectionPool):
)
super(Driver, self).__init__(driver_config)
+
+ self._credentials = driver_config.credentials
+
self.scheme_client = scheme.SchemeClient(self)
self.table_client = table.TableClient(self, driver_config.table_client_settings)
+ self.topic_client = topic.TopicClient(self, driver_config.topic_client_settings)
diff --git a/ydb/public/sdk/python3/ydb/export.py b/ydb/public/sdk/python3/ydb/export.py
index 30898cbb42..bc35bd284c 100644
--- a/ydb/public/sdk/python3/ydb/export.py
+++ b/ydb/public/sdk/python3/ydb/export.py
@@ -1,12 +1,12 @@
import enum
+import typing
from . import _apis
from . import settings_impl as s_impl
# Workaround for good IDE and universal for runtime
-# noinspection PyUnreachableCode
-if False:
+if typing.TYPE_CHECKING:
from ._grpc.v4.protos import ydb_export_pb2
from ._grpc.v4 import ydb_export_v1_pb2_grpc
else:
diff --git a/ydb/public/sdk/python3/ydb/global_settings.py b/ydb/public/sdk/python3/ydb/global_settings.py
index de8b0b1b06..8edac3f4b4 100644
--- a/ydb/public/sdk/python3/ydb/global_settings.py
+++ b/ydb/public/sdk/python3/ydb/global_settings.py
@@ -1,16 +1,24 @@
+import warnings
+
from . import convert
from . import table
def global_allow_truncated_result(enabled: bool = True):
- """
- call global_allow_truncated_result(False) for more safe execution and compatible with future changes
- """
+ if convert._default_allow_truncated_result == enabled:
+ return
+
+ if enabled:
+ warnings.warn("Global allow truncated response is deprecated behaviour.")
+
convert._default_allow_truncated_result = enabled
def global_allow_split_transactions(enabled: bool):
- """
- call global_allow_truncated_result(False) for more safe execution and compatible with future changes
- """
- table._allow_split_transaction = enabled
+ if table._default_allow_split_transaction == enabled:
+ return
+
+ if enabled:
+ warnings.warn("Global allow split transaction is deprecated behaviour.")
+
+ table._default_allow_split_transaction = enabled
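The intended effect, sketched: switching a global back to its safe default is silent, while enabling the legacy behaviour emits a warning first:

    global_allow_truncated_result(False)  # safe default, returns without a warning
    global_allow_truncated_result(True)   # warns, then enables the legacy behaviour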
diff --git a/ydb/public/sdk/python3/ydb/iam/auth.py b/ydb/public/sdk/python3/ydb/iam/auth.py
index 06b07e917e..50d98b4b4a 100644
--- a/ydb/public/sdk/python3/ydb/iam/auth.py
+++ b/ydb/public/sdk/python3/ydb/iam/auth.py
@@ -3,7 +3,6 @@ from ydb import credentials, tracing
import grpc
import time
import abc
-import six
from datetime import datetime
import json
import os
@@ -45,7 +44,6 @@ def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout):
)
-@six.add_metaclass(abc.ABCMeta)
class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials):
def __init__(self, iam_endpoint=None, iam_channel_credentials=None, tracer=None):
super(TokenServiceCredentials, self).__init__(tracer)
@@ -84,8 +82,7 @@ class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials):
return {"access_token": response.iam_token, "expires_in": expires_in}
-@six.add_metaclass(abc.ABCMeta)
-class BaseJWTCredentials(object):
+class BaseJWTCredentials(abc.ABC):
def __init__(self, account_id, access_key_id, private_key):
self._account_id = account_id
self._jwt_expiration_timeout = 60.0 * 60
diff --git a/ydb/public/sdk/python3/ydb/import_client.py b/ydb/public/sdk/python3/ydb/import_client.py
index d1ccc99af6..d94294ca7c 100644
--- a/ydb/public/sdk/python3/ydb/import_client.py
+++ b/ydb/public/sdk/python3/ydb/import_client.py
@@ -1,12 +1,12 @@
import enum
+import typing
from . import _apis
from . import settings_impl as s_impl
# Workaround for good IDE and universal for runtime
-# noinspection PyUnreachableCode
-if False:
+if typing.TYPE_CHECKING:
from ._grpc.v4.protos import ydb_import_pb2
from ._grpc.v4 import ydb_import_v1_pb2_grpc
else:
diff --git a/ydb/public/sdk/python3/ydb/issues.py b/ydb/public/sdk/python3/ydb/issues.py
index 0a0d6a907e..100af01d29 100644
--- a/ydb/public/sdk/python3/ydb/issues.py
+++ b/ydb/public/sdk/python3/ydb/issues.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from google.protobuf import text_format
import enum
-from six.moves import queue
+import queue
from . import _apis
@@ -159,11 +159,18 @@ class SessionPoolEmpty(Error, queue.Empty):
def _format_issues(issues):
if not issues:
return ""
+
return " ,".join(
- [text_format.MessageToString(issue, False, True) for issue in issues]
+ text_format.MessageToString(issue, as_utf8=False, as_one_line=True)
+ for issue in issues
)
+def _format_response(response):
+ fmt_issues = _format_issues(response.issues)
+ return f"{fmt_issues} (server_code: {response.status})"
+
+
_success_status_codes = {StatusCode.STATUS_CODE_UNSPECIFIED, StatusCode.SUCCESS}
_server_side_error_map = {
StatusCode.BAD_REQUEST: BadRequest,
@@ -190,4 +197,4 @@ _server_side_error_map = {
def _process_response(response_proto):
if response_proto.status not in _success_status_codes:
exc_obj = _server_side_error_map.get(response_proto.status)
- raise exc_obj(_format_issues(response_proto.issues), response_proto.issues)
+ raise exc_obj(_format_response(response_proto), response_proto.issues)
diff --git a/ydb/public/sdk/python3/ydb/pool.py b/ydb/public/sdk/python3/ydb/pool.py
index dfda0adff2..007aa94d33 100644
--- a/ydb/public/sdk/python3/ydb/pool.py
+++ b/ydb/public/sdk/python3/ydb/pool.py
@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
+import abc
import threading
import logging
from concurrent import futures
import collections
import random
-import six
-
from . import connection as connection_impl, issues, resolver, _utilities, tracing
-from abc import abstractmethod, ABCMeta
+from abc import abstractmethod
+from .connection import Connection
logger = logging.getLogger(__name__)
@@ -127,7 +127,7 @@ class ConnectionsCache(object):
return subscription
@tracing.with_trace()
- def get(self, preferred_endpoint=None):
+ def get(self, preferred_endpoint=None) -> Connection:
with self.lock:
if (
preferred_endpoint is not None
@@ -295,8 +295,7 @@ class Discovery(threading.Thread):
self.logger.info("Successfully terminated discovery process")
-@six.add_metaclass(ABCMeta)
-class IConnectionPool:
+class IConnectionPool(abc.ABC):
@abstractmethod
def __init__(self, driver_config):
"""
diff --git a/ydb/public/sdk/python3/ydb/scheme.py b/ydb/public/sdk/python3/ydb/scheme.py
index 88eca78c77..96a6c25f8d 100644
--- a/ydb/public/sdk/python3/ydb/scheme.py
+++ b/ydb/public/sdk/python3/ydb/scheme.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
import abc
import enum
-import six
from abc import abstractmethod
from . import issues, operation, settings as settings_impl, _apis
@@ -347,8 +346,7 @@ def _wrap_describe_path_response(rpc_state, response):
return _wrap_scheme_entry(message.self)
-@six.add_metaclass(abc.ABCMeta)
-class ISchemeClient:
+class ISchemeClient(abc.ABC):
@abstractmethod
def __init__(self, driver):
pass
diff --git a/ydb/public/sdk/python3/ydb/scheme_test.py b/ydb/public/sdk/python3/ydb/scheme_test.py
deleted file mode 100644
index d63850bf93..0000000000
--- a/ydb/public/sdk/python3/ydb/scheme_test.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from .scheme import (
- SchemeEntryType,
- _wrap_scheme_entry,
- _wrap_list_directory_response,
-)
-from ._apis import ydb_scheme
-
-
-def test_wrap_scheme_entry():
- assert (
- _wrap_scheme_entry(ydb_scheme.Entry(type=1)).type is SchemeEntryType.DIRECTORY
- )
- assert _wrap_scheme_entry(ydb_scheme.Entry(type=17)).type is SchemeEntryType.TOPIC
-
- assert (
- _wrap_scheme_entry(ydb_scheme.Entry()).type is SchemeEntryType.TYPE_UNSPECIFIED
- )
- assert (
- _wrap_scheme_entry(ydb_scheme.Entry(type=10)).type
- is SchemeEntryType.TYPE_UNSPECIFIED
- )
- assert (
- _wrap_scheme_entry(ydb_scheme.Entry(type=1001)).type
- is SchemeEntryType.TYPE_UNSPECIFIED
- )
-
-
-def test_wrap_list_directory_response():
- d = _wrap_list_directory_response(None, ydb_scheme.ListDirectoryResponse())
- assert d.type is SchemeEntryType.TYPE_UNSPECIFIED
diff --git a/ydb/public/sdk/python3/ydb/scripting.py b/ydb/public/sdk/python3/ydb/scripting.py
index 9fed037aec..131324301c 100644
--- a/ydb/public/sdk/python3/ydb/scripting.py
+++ b/ydb/public/sdk/python3/ydb/scripting.py
@@ -1,6 +1,7 @@
+import typing
+
# Workaround for good IDE and universal for runtime
-# noinspection PyUnreachableCode
-if False:
+if typing.TYPE_CHECKING:
from ._grpc.v4.protos import ydb_scripting_pb2
from ._grpc.v4 import ydb_scripting_v1_pb2_grpc
else:
diff --git a/ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py b/ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py
index 9e065d8f0e..aa9b2d006c 100644
--- a/ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py
+++ b/ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py
@@ -191,11 +191,16 @@ try:
ydb.PrimitiveType.DyNumber: sa.TEXT,
}
- def _get_column_type(t):
- if isinstance(t.item, ydb.DecimalType):
- return sa.DECIMAL(precision=t.item.precision, scale=t.item.scale)
+ def _get_column_info(t):
+ nullable = False
+ if isinstance(t, ydb.OptionalType):
+ nullable = True
+ t = t.item
- return COLUMN_TYPES[t.item]
+ if isinstance(t, ydb.DecimalType):
+ return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable
+
+ return COLUMN_TYPES[t], nullable
class YqlDialect(DefaultDialect):
name = "yql"
@@ -250,11 +255,12 @@ try:
columns = raw_conn.describe(qt)
as_compatible = []
for column in columns:
+ col_type, nullable = _get_column_info(column.type)
as_compatible.append(
{
"name": column.name,
- "type": _get_column_type(column.type),
- "nullable": True,
+ "type": col_type,
+ "nullable": nullable,
}
)
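The effect of the new _get_column_info helper, sketched (assumes COLUMN_TYPES maps ydb.PrimitiveType.Int32 to sa.INTEGER):

    col_type, nullable = _get_column_info(ydb.OptionalType(ydb.PrimitiveType.Int32))
    # -> (sa.INTEGER, True): the Optional wrapper is unwrapped and marks the column nullable
    col_type, nullable = _get_column_info(ydb.DecimalType(22, 9))
    # -> (sa.DECIMAL(precision=22, scale=9), False)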
diff --git a/ydb/public/sdk/python3/ydb/table.py b/ydb/public/sdk/python3/ydb/table.py
index 660959bfc6..799a542666 100644
--- a/ydb/public/sdk/python3/ydb/table.py
+++ b/ydb/public/sdk/python3/ydb/table.py
@@ -7,7 +7,6 @@ import time
import random
import enum
-import six
from . import (
issues,
convert,
@@ -28,7 +27,7 @@ try:
except ImportError:
interceptor = None
-_allow_split_transaction = True
+_default_allow_split_transaction = False
logger = logging.getLogger(__name__)
@@ -770,8 +769,7 @@ class TableDescription(object):
return self
-@six.add_metaclass(abc.ABCMeta)
-class AbstractTransactionModeBuilder(object):
+class AbstractTransactionModeBuilder(abc.ABC):
@property
@abc.abstractmethod
def name(self):
@@ -949,7 +947,7 @@ def retry_operation_impl(callee, retry_settings=None, *args, **kwargs):
retry_settings = RetrySettings() if retry_settings is None else retry_settings
status = None
- for attempt in six.moves.range(retry_settings.max_retries + 1):
+ for attempt in range(retry_settings.max_retries + 1):
try:
result = YdbRetryOperationFinalResult(callee(*args, **kwargs))
yield result
@@ -1103,8 +1101,7 @@ def _scan_query_request_factory(query, parameters=None, settings=None):
)
-@six.add_metaclass(abc.ABCMeta)
-class ISession:
+class ISession(abc.ABC):
@abstractmethod
def __init__(self, driver, table_client_settings):
pass
@@ -1184,9 +1181,7 @@ class ISession:
pass
@abstractmethod
- def transaction(
- self, tx_mode=None, allow_split_transactions=_allow_split_transaction
- ):
+ def transaction(self, tx_mode=None, allow_split_transactions=None):
pass
@abstractmethod
@@ -1268,8 +1263,7 @@ class ISession:
pass
-@six.add_metaclass(abc.ABCMeta)
-class ITableClient:
+class ITableClient(abc.ABC):
def __init__(self, driver, table_client_settings=None):
pass
@@ -1691,9 +1685,7 @@ class BaseSession(ISession):
self._state.endpoint,
)
- def transaction(
- self, tx_mode=None, allow_split_transactions=_allow_split_transaction
- ):
+ def transaction(self, tx_mode=None, allow_split_transactions=None):
return TxContext(
self._driver,
self._state,
@@ -2100,8 +2092,7 @@ class Session(BaseSession):
)
-@six.add_metaclass(abc.ABCMeta)
-class ITxContext:
+class ITxContext(abc.ABC):
@abstractmethod
def __init__(self, driver, session_state, session, tx_mode=None):
"""
@@ -2231,7 +2222,7 @@ class BaseTxContext(ITxContext):
session,
tx_mode=None,
*,
- allow_split_transactions=_allow_split_transaction
+ allow_split_transactions=None
):
"""
An object that provides a simple transaction context manager that allows statements execution
@@ -2316,6 +2307,8 @@ class BaseTxContext(ITxContext):
"""
self._check_split()
+ if commit_tx:
+ self._set_finish(self._COMMIT)
return self._driver(
_tx_ctx_impl.execute_request_factory(
@@ -2416,7 +2409,13 @@ class BaseTxContext(ITxContext):
Deny all operaions with transaction after commit/rollback.
Exception: double commit and double rollbacks, because it is safe
"""
- if self._allow_split_transactions:
+ allow_split_transaction = (
+ self._allow_split_transactions
+ if self._allow_split_transactions is not None
+ else _default_allow_split_transaction
+ )
+
+ if allow_split_transaction:
return
if self._finished != "" and self._finished != allow:
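The fallback rule introduced above, restated as a tiny sketch:

    def resolve_allow_split(per_tx_setting, global_default):
        # the per-transaction value wins; None falls back to the module-level default
        return per_tx_setting if per_tx_setting is not None else global_default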
diff --git a/ydb/public/sdk/python3/ydb/table_test.py b/ydb/public/sdk/python3/ydb/table_test.py
deleted file mode 100644
index 8e93a698f4..0000000000
--- a/ydb/public/sdk/python3/ydb/table_test.py
+++ /dev/null
@@ -1,140 +0,0 @@
-from unittest import mock
-from . import (
- retry_operation_impl,
- YdbRetryOperationFinalResult,
- issues,
- YdbRetryOperationSleepOpt,
- RetrySettings,
-)
-
-
-def test_retry_operation_impl(monkeypatch):
- monkeypatch.setattr("random.random", lambda: 0.5)
- monkeypatch.setattr(
- issues.Error,
- "__eq__",
- lambda self, other: type(self) == type(other) and self.message == other.message,
- )
-
- retry_once_settings = RetrySettings(
- max_retries=1,
- on_ydb_error_callback=mock.Mock(),
- )
- retry_once_settings.unknown_error_handler = mock.Mock()
-
- def get_results(callee):
- res_generator = retry_operation_impl(callee, retry_settings=retry_once_settings)
- results = []
- exc = None
- try:
- for res in res_generator:
- results.append(res)
- if isinstance(res, YdbRetryOperationFinalResult):
- break
- except Exception as e:
- exc = e
-
- return results, exc
-
- class TestException(Exception):
- def __init__(self, message):
- super(TestException, self).__init__(message)
- self.message = message
-
- def __eq__(self, other):
- return type(self) == type(other) and self.message == other.message
-
- def check_unretriable_error(err_type, call_ydb_handler):
- retry_once_settings.on_ydb_error_callback.reset_mock()
- retry_once_settings.unknown_error_handler.reset_mock()
-
- results = get_results(
- mock.Mock(side_effect=[err_type("test1"), err_type("test2")])
- )
- yields = results[0]
- exc = results[1]
-
- assert yields == []
- assert exc == err_type("test1")
-
- if call_ydb_handler:
- assert retry_once_settings.on_ydb_error_callback.call_count == 1
- retry_once_settings.on_ydb_error_callback.assert_called_with(
- err_type("test1")
- )
-
- assert retry_once_settings.unknown_error_handler.call_count == 0
- else:
- assert retry_once_settings.on_ydb_error_callback.call_count == 0
-
- assert retry_once_settings.unknown_error_handler.call_count == 1
- retry_once_settings.unknown_error_handler.assert_called_with(
- err_type("test1")
- )
-
- def check_retriable_error(err_type, backoff):
- retry_once_settings.on_ydb_error_callback.reset_mock()
-
- results = get_results(
- mock.Mock(side_effect=[err_type("test1"), err_type("test2")])
- )
- yields = results[0]
- exc = results[1]
-
- if backoff:
- assert [
- YdbRetryOperationSleepOpt(backoff.calc_timeout(0)),
- YdbRetryOperationSleepOpt(backoff.calc_timeout(1)),
- ] == yields
- else:
- assert [] == yields
-
- assert exc == err_type("test2")
-
- assert retry_once_settings.on_ydb_error_callback.call_count == 2
- retry_once_settings.on_ydb_error_callback.assert_any_call(err_type("test1"))
- retry_once_settings.on_ydb_error_callback.assert_called_with(err_type("test2"))
-
- assert retry_once_settings.unknown_error_handler.call_count == 0
-
- # check ok
- assert get_results(lambda: True) == ([YdbRetryOperationFinalResult(True)], None)
-
- # check retry error and return result
- assert get_results(mock.Mock(side_effect=[issues.Overloaded("test"), True])) == (
- [
- YdbRetryOperationSleepOpt(retry_once_settings.slow_backoff.calc_timeout(0)),
- YdbRetryOperationFinalResult(True),
- ],
- None,
- )
-
- # check errors
- check_retriable_error(issues.Aborted, None)
- check_retriable_error(issues.BadSession, None)
-
- check_retriable_error(issues.NotFound, None)
- with mock.patch.object(retry_once_settings, "retry_not_found", False):
- check_unretriable_error(issues.NotFound, True)
-
- check_retriable_error(issues.InternalError, None)
- with mock.patch.object(retry_once_settings, "retry_internal_error", False):
- check_unretriable_error(issues.InternalError, True)
-
- check_retriable_error(issues.Overloaded, retry_once_settings.slow_backoff)
- check_retriable_error(issues.SessionPoolEmpty, retry_once_settings.slow_backoff)
- check_retriable_error(issues.ConnectionError, retry_once_settings.slow_backoff)
-
- check_retriable_error(issues.Unavailable, retry_once_settings.fast_backoff)
-
- check_unretriable_error(issues.Undetermined, True)
- with mock.patch.object(retry_once_settings, "idempotent", True):
- check_retriable_error(issues.Unavailable, retry_once_settings.fast_backoff)
-
- check_unretriable_error(issues.Error, True)
- with mock.patch.object(retry_once_settings, "idempotent", True):
- check_unretriable_error(issues.Error, True)
-
- check_unretriable_error(TestException, False)
- with mock.patch.object(retry_once_settings, "idempotent", True):
- check_unretriable_error(TestException, False)
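The deleted test documented the retry generator protocol: retry_operation_impl yields YdbRetryOperationSleepOpt items (carrying the computed backoff) between attempts and a YdbRetryOperationFinalResult on success, while non-retriable errors escape as exceptions. A driver loop for that protocol looks roughly like this (a sketch of what ydb.retry_operation_sync does internally; it assumes the sleep option exposes its delay as `timeout` and the final result exposes `result`):

import time

from ydb import (
    RetrySettings,
    YdbRetryOperationFinalResult,
    YdbRetryOperationSleepOpt,
    retry_operation_impl,
)


def retry_sync(callee, settings=None):
    settings = settings or RetrySettings()
    for opt in retry_operation_impl(callee, retry_settings=settings):
        if isinstance(opt, YdbRetryOperationSleepOpt):
            time.sleep(opt.timeout)  # back off before the next attempt
        else:  # YdbRetryOperationFinalResult
            return opt.result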
diff --git a/ydb/public/sdk/python3/ydb/topic.py b/ydb/public/sdk/python3/ydb/topic.py
new file mode 100644
index 0000000000..efe62219cb
--- /dev/null
+++ b/ydb/public/sdk/python3/ydb/topic.py
@@ -0,0 +1,389 @@
+from __future__ import annotations
+
+import concurrent.futures
+import datetime
+from dataclasses import dataclass
+from typing import List, Union, Mapping, Optional, Dict, Callable
+
+from . import aio, Credentials, _apis, issues
+
+from . import driver
+
+from ._topic_reader.topic_reader import (
+ PublicReaderSettings as TopicReaderSettings,
+)
+
+from ._topic_reader.topic_reader_sync import TopicReaderSync as TopicReader
+
+from ._topic_reader.topic_reader_asyncio import (
+ PublicAsyncIOReader as TopicReaderAsyncIO,
+)
+
+from ._topic_writer.topic_writer import ( # noqa: F401
+ PublicWriterSettings as TopicWriterSettings,
+ PublicMessage as TopicWriterMessage,
+ RetryPolicy as TopicWriterRetryPolicy,
+)
+
+from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter
+
+from ._topic_common.common import (
+ wrap_operation as _wrap_operation,
+ create_result_wrapper as _create_result_wrapper,
+)
+
+from ._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO
+
+from ._grpc.grpcwrapper import ydb_topic as _ydb_topic
+from ._grpc.grpcwrapper import ydb_topic_public_types as _ydb_topic_public_types
+from ._grpc.grpcwrapper.ydb_topic_public_types import ( # noqa: F401
+ PublicDescribeTopicResult as TopicDescription,
+ PublicMultipleWindowsStat as TopicStatWindow,
+ PublicPartitionStats as TopicPartitionStats,
+ PublicCodec as TopicCodec,
+ PublicConsumer as TopicConsumer,
+ PublicMeteringMode as TopicMeteringMode,
+)
+
+
+class TopicClientAsyncIO:
+ _closed: bool
+ _driver: aio.Driver
+ _credentials: Union[Credentials, None]
+ _settings: TopicClientSettings
+ _executor: concurrent.futures.Executor
+
+ def __init__(
+ self, driver: aio.Driver, settings: Optional[TopicClientSettings] = None
+ ):
+ if not settings:
+ settings = TopicClientSettings()
+ self._closed = False
+ self._driver = driver
+ self._settings = settings
+ self._executor = concurrent.futures.ThreadPoolExecutor(
+ max_workers=settings.encode_decode_threads_count,
+ thread_name_prefix="topic_asyncio_executor",
+ )
+
+ def __del__(self):
+ self.close()
+
+ async def create_topic(
+ self,
+ path: str,
+ min_active_partitions: Optional[int] = None,
+ partition_count_limit: Optional[int] = None,
+ retention_period: Optional[datetime.timedelta] = None,
+ retention_storage_mb: Optional[int] = None,
+ supported_codecs: Optional[List[Union[TopicCodec, int]]] = None,
+ partition_write_speed_bytes_per_second: Optional[int] = None,
+ partition_write_burst_bytes: Optional[int] = None,
+ attributes: Optional[Dict[str, str]] = None,
+ consumers: Optional[List[Union[TopicConsumer, str]]] = None,
+ metering_mode: Optional[TopicMeteringMode] = None,
+ ):
+ """
+        Create a topic.
+
+ :param path: full path to topic
+        :param min_active_partitions: Minimum count of active partitions; automatic partition merge will not reduce the partition count below it.
+ :param partition_count_limit: Limit for total partition count, including active (open for write)
+ and read-only partitions.
+        :param retention_period: How long data in the partition should be stored
+        :param retention_storage_mb: How much data in the partition should be stored, in megabytes
+        :param supported_codecs: List of allowed codecs for writers. Writes with a codec not in this list are forbidden.
+            An empty list disables codec compatibility checks for the topic.
+ :param partition_write_speed_bytes_per_second: Partition write speed in bytes per second
+ :param partition_write_burst_bytes: Burst size for write in partition, in bytes
+        :param attributes: User and server attributes of the topic.
+            Server attribute names start with "_" and are validated by the server.
+ :param consumers: List of consumers for this topic
+ :param metering_mode: Metering mode for the topic in a serverless database
+ """
+ args = locals().copy()
+ del args["self"]
+ req = _ydb_topic_public_types.CreateTopicRequestParams(**args)
+ req = _ydb_topic.CreateTopicRequest.from_public(req)
+ await self._driver(
+ req.to_proto(),
+ _apis.TopicService.Stub,
+ _apis.TopicService.CreateTopic,
+ _wrap_operation,
+ )
+
+ async def describe_topic(
+ self, path: str, include_stats: bool = False
+ ) -> TopicDescription:
+ args = locals().copy()
+ del args["self"]
+ req = _ydb_topic_public_types.DescribeTopicRequestParams(**args)
+ res = await self._driver(
+ req.to_proto(),
+ _apis.TopicService.Stub,
+ _apis.TopicService.DescribeTopic,
+ _create_result_wrapper(_ydb_topic.DescribeTopicResult),
+ ) # type: _ydb_topic.DescribeTopicResult
+ return res.to_public()
+
+ async def drop_topic(self, path: str):
+ req = _ydb_topic_public_types.DropTopicRequestParams(path=path)
+ await self._driver(
+ req.to_proto(),
+ _apis.TopicService.Stub,
+ _apis.TopicService.DropTopic,
+ _wrap_operation,
+ )
+
+ def reader(
+ self,
+ consumer: str,
+ topic: str,
+ buffer_size_bytes: int = 50 * 1024 * 1024,
+ # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes
+ decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
+ decoder_executor: Optional[
+ concurrent.futures.Executor
+ ] = None, # default shared client executor pool
+ # on_commit: Callable[["Events.OnCommit"], None] = None
+ # on_get_partition_start_offset: Callable[
+ # ["Events.OnPartitionGetStartOffsetRequest"],
+ # "Events.OnPartitionGetStartOffsetResponse",
+ # ] = None
+ # on_partition_session_start: Callable[["StubEvent"], None] = None
+ # on_partition_session_stop: Callable[["StubEvent"], None] = None
+ # on_partition_session_close: Callable[["StubEvent"], None] = None # todo?
+ # deserializer: Union[Callable[[bytes], Any], None] = None
+ # one_attempt_connection_timeout: Union[float, None] = 1
+ # connection_timeout: Union[float, None] = None
+ # retry_policy: Union["RetryPolicy", None] = None
+ ) -> TopicReaderAsyncIO:
+
+ if not decoder_executor:
+ decoder_executor = self._executor
+
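+        # locals() is captured after decoder_executor has been defaulted, so the
+        # shared client executor ends up in the reader settings.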
+ args = locals()
+ del args["self"]
+
+ settings = TopicReaderSettings(**args)
+
+ return TopicReaderAsyncIO(self._driver, settings)
+
+ def writer(
+ self,
+ topic,
+ *,
+ producer_id: Optional[str] = None, # default - random
+        session_metadata: Optional[Mapping[str, str]] = None,
+ partition_id: Union[int, None] = None,
+ auto_seqno: bool = True,
+ auto_created_at: bool = True,
+        codec: Optional[TopicCodec] = None,  # default means auto-select
+ encoders: Optional[
+ Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]
+ ] = None,
+ encoder_executor: Optional[
+ concurrent.futures.Executor
+ ] = None, # default shared client executor pool
+ ) -> TopicWriterAsyncIO:
+ args = locals()
+ del args["self"]
+
+ settings = TopicWriterSettings(**args)
+
+ if not settings.encoder_executor:
+ settings.encoder_executor = self._executor
+
+ return TopicWriterAsyncIO(self._driver, settings)
+
+ def close(self):
+ if self._closed:
+ return
+
+ self._closed = True
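+        # Note: the cancel_futures argument requires Python 3.9+.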
+ self._executor.shutdown(wait=False, cancel_futures=True)
+
+ def _check_closed(self):
+ if not self._closed:
+ return
+
+ raise RuntimeError("Topic client closed")
+
+
+class TopicClient:
+ _closed: bool
+ _driver: driver.Driver
+ _credentials: Union[Credentials, None]
+ _settings: TopicClientSettings
+ _executor: concurrent.futures.Executor
+
+    def __init__(self, driver: driver.Driver, settings: Optional[TopicClientSettings] = None):
+ if not settings:
+ settings = TopicClientSettings()
+
+ self._closed = False
+ self._driver = driver
+ self._settings = settings
+ self._executor = concurrent.futures.ThreadPoolExecutor(
+ max_workers=settings.encode_decode_threads_count,
+ thread_name_prefix="topic_asyncio_executor",
+ )
+
+ def __del__(self):
+ self.close()
+
+ def create_topic(
+ self,
+ path: str,
+ min_active_partitions: Optional[int] = None,
+ partition_count_limit: Optional[int] = None,
+ retention_period: Optional[datetime.timedelta] = None,
+ retention_storage_mb: Optional[int] = None,
+ supported_codecs: Optional[List[Union[TopicCodec, int]]] = None,
+ partition_write_speed_bytes_per_second: Optional[int] = None,
+ partition_write_burst_bytes: Optional[int] = None,
+ attributes: Optional[Dict[str, str]] = None,
+ consumers: Optional[List[Union[TopicConsumer, str]]] = None,
+ metering_mode: Optional[TopicMeteringMode] = None,
+ ):
+ """
+        Create a topic.
+
+ :param path: full path to topic
+        :param min_active_partitions: Minimum count of active partitions; automatic partition merge will not reduce the partition count below it.
+ :param partition_count_limit: Limit for total partition count, including active (open for write)
+ and read-only partitions.
+        :param retention_period: How long data in the partition should be stored
+        :param retention_storage_mb: How much data in the partition should be stored, in megabytes
+        :param supported_codecs: List of allowed codecs for writers. Writes with a codec not in this list are forbidden.
+            An empty list disables codec compatibility checks for the topic.
+ :param partition_write_speed_bytes_per_second: Partition write speed in bytes per second
+ :param partition_write_burst_bytes: Burst size for write in partition, in bytes
+        :param attributes: User and server attributes of the topic.
+            Server attribute names start with "_" and are validated by the server.
+ :param consumers: List of consumers for this topic
+ :param metering_mode: Metering mode for the topic in a serverless database
+ """
+ args = locals().copy()
+ del args["self"]
+ self._check_closed()
+
+ req = _ydb_topic_public_types.CreateTopicRequestParams(**args)
+ req = _ydb_topic.CreateTopicRequest.from_public(req)
+ self._driver(
+ req.to_proto(),
+ _apis.TopicService.Stub,
+ _apis.TopicService.CreateTopic,
+ _wrap_operation,
+ )
+
+ def describe_topic(
+ self, path: str, include_stats: bool = False
+ ) -> TopicDescription:
+ args = locals().copy()
+ del args["self"]
+ self._check_closed()
+
+ req = _ydb_topic_public_types.DescribeTopicRequestParams(**args)
+ res = self._driver(
+ req.to_proto(),
+ _apis.TopicService.Stub,
+ _apis.TopicService.DescribeTopic,
+ _create_result_wrapper(_ydb_topic.DescribeTopicResult),
+ ) # type: _ydb_topic.DescribeTopicResult
+ return res.to_public()
+
+ def drop_topic(self, path: str):
+ self._check_closed()
+
+ req = _ydb_topic_public_types.DropTopicRequestParams(path=path)
+ self._driver(
+ req.to_proto(),
+ _apis.TopicService.Stub,
+ _apis.TopicService.DropTopic,
+ _wrap_operation,
+ )
+
+ def reader(
+ self,
+ consumer: str,
+ topic: str,
+ buffer_size_bytes: int = 50 * 1024 * 1024,
+ # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes
+ decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
+ decoder_executor: Optional[
+ concurrent.futures.Executor
+ ] = None, # default shared client executor pool
+ # on_commit: Callable[["Events.OnCommit"], None] = None
+ # on_get_partition_start_offset: Callable[
+ # ["Events.OnPartitionGetStartOffsetRequest"],
+ # "Events.OnPartitionGetStartOffsetResponse",
+ # ] = None
+ # on_partition_session_start: Callable[["StubEvent"], None] = None
+ # on_partition_session_stop: Callable[["StubEvent"], None] = None
+ # on_partition_session_close: Callable[["StubEvent"], None] = None # todo?
+ # deserializer: Union[Callable[[bytes], Any], None] = None
+ # one_attempt_connection_timeout: Union[float, None] = 1
+ # connection_timeout: Union[float, None] = None
+ # retry_policy: Union["RetryPolicy", None] = None
+ ) -> TopicReader:
+ if not decoder_executor:
+ decoder_executor = self._executor
+
+ args = locals()
+ del args["self"]
+ self._check_closed()
+
+ settings = TopicReaderSettings(**args)
+
+ return TopicReader(self._driver, settings)
+
+ def writer(
+ self,
+ topic,
+ *,
+ producer_id: Optional[str] = None, # default - random
+        session_metadata: Optional[Mapping[str, str]] = None,
+ partition_id: Union[int, None] = None,
+ auto_seqno: bool = True,
+ auto_created_at: bool = True,
+        codec: Optional[TopicCodec] = None,  # default means auto-select
+ encoders: Optional[
+ Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]
+ ] = None,
+ encoder_executor: Optional[
+ concurrent.futures.Executor
+ ] = None, # default shared client executor pool
+ ) -> TopicWriter:
+ args = locals()
+ del args["self"]
+ self._check_closed()
+
+ settings = TopicWriterSettings(**args)
+
+ if not settings.encoder_executor:
+ settings.encoder_executor = self._executor
+
+ return TopicWriter(self._driver, settings)
+
+ def close(self):
+ if self._closed:
+ return
+
+ self._closed = True
+ self._executor.shutdown(wait=False, cancel_futures=True)
+
+ def _check_closed(self):
+ if not self._closed:
+ return
+
+ raise RuntimeError("Topic client closed")
+
+
+@dataclass
+class TopicClientSettings:
+ encode_decode_threads_count: int = 4
+
+
+class TopicError(issues.Error):
+ pass
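Taken together, the new topic.py exposes both admin calls and streaming I/O. Below is a hedged end-to-end sketch of the asyncio variant; the endpoint and database values are placeholders, and the writer/reader method names (write, receive_message, commit, close) are assumptions about this SDK version's reader/writer classes rather than confirmed public API:

import asyncio

import ydb
from ydb.topic import (
    TopicClientAsyncIO,
    TopicClientSettings,
    TopicWriterMessage,
)


async def main():
    driver = ydb.aio.Driver(endpoint="grpc://localhost:2136", database="/local")
    await driver.wait(timeout=5)
    client = TopicClientAsyncIO(
        driver, TopicClientSettings(encode_decode_threads_count=8)
    )
    try:
        await client.create_topic("/local/demo-topic", consumers=["demo-consumer"])

        writer = client.writer("/local/demo-topic")
        await writer.write(TopicWriterMessage(data=b"hello"))  # assumed method
        await writer.close()

        reader = client.reader("demo-consumer", "/local/demo-topic")
        message = await reader.receive_message()  # assumed method
        reader.commit(message)  # assumed fire-and-forget commit
        await reader.close()

        info = await client.describe_topic("/local/demo-topic", include_stats=True)
        print(info)
    finally:
        client.close()
        await driver.stop()


asyncio.run(main())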
diff --git a/ydb/public/sdk/python3/ydb/types.py b/ydb/public/sdk/python3/ydb/types.py
index a62c8a74a0..5ffa16e6b6 100644
--- a/ydb/public/sdk/python3/ydb/types.py
+++ b/ydb/public/sdk/python3/ydb/types.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
import abc
import enum
-import six
import json
from . import _utilities, _apis
from datetime import date, datetime, timedelta
@@ -13,13 +12,6 @@ from google.protobuf import struct_pb2
_SECONDS_IN_DAY = 60 * 60 * 24
_EPOCH = datetime(1970, 1, 1)
-if six.PY3:
- _from_bytes = None
-else:
-
- def _from_bytes(x, table_client_settings):
- return _utilities.from_bytes(x)
-
def _from_date(x, table_client_settings):
if (
@@ -52,8 +44,6 @@ def _from_json(x, table_client_settings):
and table_client_settings._native_json_in_result_sets
):
return json.loads(x)
- if _from_bytes is not None:
- return _from_bytes(x, table_client_settings)
return x
@@ -122,7 +112,7 @@ class PrimitiveType(enum.Enum):
Float = _apis.primitive_types.FLOAT, "float_value"
String = _apis.primitive_types.STRING, "bytes_value"
- Utf8 = _apis.primitive_types.UTF8, "text_value", _from_bytes
+ Utf8 = _apis.primitive_types.UTF8, "text_value"
Yson = _apis.primitive_types.YSON, "bytes_value"
Json = _apis.primitive_types.JSON, "text_value", _from_json
@@ -152,7 +142,7 @@ class PrimitiveType(enum.Enum):
_to_interval,
)
- DyNumber = _apis.primitive_types.DYNUMBER, "text_value", _from_bytes
+ DyNumber = _apis.primitive_types.DYNUMBER, "text_value"
def __init__(self, idn, proto_field, to_obj=None, from_obj=None):
self._idn_ = idn
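With the Python 2 branch gone, values of Utf8 and DyNumber columns arrive from the protobuf text_value field as native str, so no converter needs to be attached. A small illustration using the standard table API (the table layout and column names are arbitrary):

import ydb

description = (
    ydb.TableDescription()
    .with_columns(
        ydb.Column("id", ydb.OptionalType(ydb.PrimitiveType.Uint64)),
        ydb.Column("title", ydb.OptionalType(ydb.PrimitiveType.Utf8)),  # read back as str
    )
    .with_primary_key("id")
)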
diff --git a/ydb/public/sdk/python3/ydb/ydb_version.py b/ydb/public/sdk/python3/ydb/ydb_version.py
index dde8252c21..ef5ee52a9b 100644
--- a/ydb/public/sdk/python3/ydb/ydb_version.py
+++ b/ydb/public/sdk/python3/ydb/ydb_version.py
@@ -1 +1 @@
-VERSION = "2.13.2"
+VERSION = "3.0.1b9"