diff options
author | rekby <[email protected]> | 2023-03-21 19:27:13 +0300 |
---|---|---|
committer | rekby <[email protected]> | 2023-03-21 19:27:13 +0300 |
commit | 0be64eefa9aa954902612e7c54beaf7f2bc6b89d (patch) | |
tree | 4f9af91d1aec584ce4c5a9d5c909f57945bbcbd9 | |
parent | 208093ca374000ddbce32ab91ed8b81aa03ca3e7 (diff) |
update ydb python sdk to 3.0.1b9
53 files changed, 5132 insertions, 383 deletions
diff --git a/ydb/public/sdk/python3/ya.make b/ydb/public/sdk/python3/ya.make index d878b0da834..d8342cee2b8 100644 --- a/ydb/public/sdk/python3/ya.make +++ b/ydb/public/sdk/python3/ya.make @@ -4,11 +4,32 @@ PY_SRCS( TOP_LEVEL ydb/__init__.py ydb/_apis.py + ydb/_dbapi/__init__.py + ydb/_dbapi/connection.py + ydb/_dbapi/cursor.py + ydb/_dbapi/errors.py ydb/_errors.py ydb/_grpc/__init__.py ydb/_grpc/common/__init__.py + ydb/_grpc/grpcwrapper/__init__.py + ydb/_grpc/grpcwrapper/common_utils.py + ydb/_grpc/grpcwrapper/ydb_scheme.py + ydb/_grpc/grpcwrapper/ydb_topic.py + ydb/_grpc/grpcwrapper/ydb_topic_public_types.py ydb/_session_impl.py ydb/_sp_impl.py + ydb/_topic_common/__init__.py + ydb/_topic_common/common.py + ydb/_topic_common/test_helpers.py + ydb/_topic_reader/__init__.py + ydb/_topic_reader/datatypes.py + ydb/_topic_reader/topic_reader.py + ydb/_topic_reader/topic_reader_asyncio.py + ydb/_topic_reader/topic_reader_sync.py + ydb/_topic_writer/__init__.py + ydb/_topic_writer/topic_writer.py + ydb/_topic_writer/topic_writer_asyncio.py + ydb/_topic_writer/topic_writer_sync.py ydb/_tx_ctx_impl.py ydb/_utilities.py ydb/aio/__init__.py @@ -42,13 +63,12 @@ PY_SRCS( ydb/pool.py ydb/resolver.py ydb/scheme.py - ydb/scheme_test.py ydb/scripting.py ydb/settings.py ydb/sqlalchemy/__init__.py ydb/sqlalchemy/types.py ydb/table.py - ydb/table_test.py + ydb/topic.py ydb/tracing.py ydb/types.py ydb/ydb_version.py diff --git a/ydb/public/sdk/python3/ydb/__init__.py b/ydb/public/sdk/python3/ydb/__init__.py index 7395ae36a1e..648077880e0 100644 --- a/ydb/public/sdk/python3/ydb/__init__.py +++ b/ydb/public/sdk/python3/ydb/__init__.py @@ -13,16 +13,22 @@ from .operation import * # noqa from .scripting import * # noqa from .import_client import * # noqa from .tracing import * # noqa +from .topic import * # noqa try: import ydb.aio as aio # noqa except Exception: pass + try: import kikimr.public.sdk.python.ydb_v3_new_behavior # noqa global_allow_split_transactions(False) # noqa - global_allow_split_transactions(False) # noqa + global_allow_truncated_result(False) # noqa except ModuleNotFoundError: # Old, deprecated + + import warnings + warnings.warn("Used deprecated behavior, for fix ADD PEERDIR kikimr/public/sdk/python/ydb_v3_new_behavior") + global_allow_split_transactions(True) # noqa - global_allow_split_transactions(True) # noqa + global_allow_truncated_result(True) # noqa diff --git a/ydb/public/sdk/python3/ydb/_apis.py b/ydb/public/sdk/python3/ydb/_apis.py index 6f2fc3ab6a7..27bc1bbec81 100644 --- a/ydb/public/sdk/python3/ydb/_apis.py +++ b/ydb/public/sdk/python3/ydb/_apis.py @@ -1,13 +1,15 @@ # -*- coding: utf-8 -*- +import typing + # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4 import ( ydb_cms_v1_pb2_grpc, ydb_discovery_v1_pb2_grpc, ydb_scheme_v1_pb2_grpc, ydb_table_v1_pb2_grpc, ydb_operation_v1_pb2_grpc, + ydb_topic_v1_pb2_grpc, ) from ._grpc.v4.protos import ( @@ -26,6 +28,7 @@ else: ydb_scheme_v1_pb2_grpc, ydb_table_v1_pb2_grpc, ydb_operation_v1_pb2_grpc, + ydb_topic_v1_pb2_grpc, ) from ._grpc.common.protos import ( @@ -38,6 +41,7 @@ else: ydb_common_pb2, ) + StatusIds = ydb_status_codes_pb2.StatusIds FeatureFlag = ydb_common_pb2.FeatureFlag primitive_types = ydb_value_pb2.Type.PrimitiveTypeId @@ -95,3 +99,13 @@ class TableService(object): KeepAlive = "KeepAlive" StreamReadTable = "StreamReadTable" BulkUpsert = "BulkUpsert" + + +class TopicService(object): + Stub = ydb_topic_v1_pb2_grpc.TopicServiceStub + + CreateTopic = "CreateTopic" + DescribeTopic = "DescribeTopic" + DropTopic = "DropTopic" + StreamRead = "StreamRead" + StreamWrite = "StreamWrite" diff --git a/ydb/public/sdk/python3/ydb/_dbapi/__init__.py b/ydb/public/sdk/python3/ydb/_dbapi/__init__.py new file mode 100644 index 00000000000..8756b0f2d4b --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_dbapi/__init__.py @@ -0,0 +1,36 @@ +from .connection import Connection +from .errors import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) + +apilevel = "1.0" + +threadsafety = 0 + +paramstyle = "pyformat" + +errors = ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) + + +def connect(*args, **kwargs): + return Connection(*args, **kwargs) diff --git a/ydb/public/sdk/python3/ydb/_dbapi/connection.py b/ydb/public/sdk/python3/ydb/_dbapi/connection.py new file mode 100644 index 00000000000..75bfeb582f2 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_dbapi/connection.py @@ -0,0 +1,73 @@ +import posixpath + +import ydb +from .cursor import Cursor +from .errors import DatabaseError + + +class Connection: + def __init__(self, endpoint, database=None, **conn_kwargs): + self.endpoint = endpoint + self.database = database + self.driver = self._create_driver(self.endpoint, self.database, **conn_kwargs) + self.pool = ydb.SessionPool(self.driver) + + def cursor(self): + return Cursor(self) + + def describe(self, table_path): + full_path = posixpath.join(self.database, table_path) + try: + res = self.pool.retry_operation_sync( + lambda cli: cli.describe_table(full_path) + ) + return res.columns + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + except Exception: + raise DatabaseError(f"Failed to describe table {table_path}") + + def check_exists(self, table_path): + try: + self.driver.scheme_client.describe_path(table_path) + return True + except ydb.SchemeError: + return False + + def commit(self): + pass + + def rollback(self): + pass + + def close(self): + if self.pool: + self.pool.stop() + if self.driver: + self.driver.stop() + + @staticmethod + def _create_driver(endpoint, database, **conn_kwargs): + # TODO: add cache for initialized drivers/pools? + driver_config = ydb.DriverConfig( + endpoint, + database=database, + table_client_settings=ydb.TableClientSettings() + .with_native_date_in_result_sets(True) + .with_native_datetime_in_result_sets(True) + .with_native_timestamp_in_result_sets(True) + .with_native_interval_in_result_sets(True) + .with_native_json_in_result_sets(True), + **conn_kwargs, + ) + driver = ydb.Driver(driver_config) + try: + driver.wait(timeout=5, fail_fast=True) + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + except Exception: + driver.stop() + raise DatabaseError( + f"Failed to connect to YDB, details {driver.discovery_debug_details()}" + ) + return driver diff --git a/ydb/public/sdk/python3/ydb/_dbapi/cursor.py b/ydb/public/sdk/python3/ydb/_dbapi/cursor.py new file mode 100644 index 00000000000..57659c7abf5 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_dbapi/cursor.py @@ -0,0 +1,172 @@ +import datetime +import itertools +import logging +import uuid +import decimal + +import ydb +from .errors import DatabaseError, ProgrammingError + + +logger = logging.getLogger(__name__) + + +def get_column_type(type_obj): + return str(ydb.convert.type_to_native(type_obj)) + + +def _generate_type_str(value): + tvalue = type(value) + + stype = { + bool: "Bool", + bytes: "String", + str: "Utf8", + int: "Int64", + float: "Double", + decimal.Decimal: "Decimal(22, 9)", + datetime.date: "Date", + datetime.datetime: "Timestamp", + datetime.timedelta: "Interval", + uuid.UUID: "Uuid", + }.get(tvalue) + + if tvalue == dict: + types_lst = ", ".join(f"{k}: {_generate_type_str(v)}" for k, v in value.items()) + stype = f"Struct<{types_lst}>" + + elif tvalue == tuple: + types_lst = ", ".join(_generate_type_str(x) for x in value) + stype = f"Tuple<{types_lst}>" + + elif tvalue == list: + nested_type = _generate_type_str(value[0]) + stype = f"List<{nested_type}>" + + elif tvalue == set: + nested_type = _generate_type_str(next(iter(value))) + stype = f"Set<{nested_type}>" + + if stype is None: + raise ProgrammingError( + "Cannot translate python type to ydb type.", tvalue, value + ) + + return stype + + +def _generate_declare_stms(params: dict) -> str: + return "".join( + f"DECLARE {k} AS {_generate_type_str(t)}; " for k, t in params.items() + ) + + +class Cursor(object): + def __init__(self, connection): + self.connection = connection + self.description = None + self.arraysize = 1 + self.rows = None + self._rows_prefetched = None + + def execute(self, sql, parameters=None, context=None): + self.description = None + sql_params = None + + if parameters: + sql = sql % {k: f"${k}" for k, v in parameters.items()} + sql_params = {f"${k}": v for k, v in parameters.items()} + declare_stms = _generate_declare_stms(sql_params) + sql = f"{declare_stms}{sql}" + + logger.info("execute sql: %s, params: %s", sql, sql_params) + + def _execute_in_pool(cli): + try: + if context and context.get("isddl"): + return cli.execute_scheme(sql) + else: + prepared_query = cli.prepare(sql) + return cli.transaction().execute( + prepared_query, sql_params, commit_tx=True + ) + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + + chunks = self.connection.pool.retry_operation_sync(_execute_in_pool) + rows = self._rows_iterable(chunks) + # Prefetch the description: + try: + first_row = next(rows) + except StopIteration: + pass + else: + rows = itertools.chain((first_row,), rows) + if self.rows is not None: + rows = itertools.chain(self.rows, rows) + + self.rows = rows + + def _rows_iterable(self, chunks_iterable): + try: + for chunk in chunks_iterable: + self.description = [ + ( + col.name, + get_column_type(col.type), + None, + None, + None, + None, + None, + ) + for col in chunk.columns + ] + for row in chunk.rows: + # returns tuple to be compatible with SqlAlchemy and because + # of this PEP to return a sequence: https://www.python.org/dev/peps/pep-0249/#fetchmany + yield row[::] + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + + def _ensure_prefetched(self): + if self.rows is not None and self._rows_prefetched is None: + self._rows_prefetched = list(self.rows) + self.rows = iter(self._rows_prefetched) + return self._rows_prefetched + + def executemany(self, sql, seq_of_parameters): + for parameters in seq_of_parameters: + self.execute(sql, parameters) + + def executescript(self, script): + return self.execute(script) + + def fetchone(self): + if self.rows is None: + return None + return next(self.rows, None) + + def fetchmany(self, size=None): + size = self.arraysize if size is None else size + return list(itertools.islice(self.rows, size)) + + def fetchall(self): + return list(self.rows) + + def nextset(self): + self.fetchall() + + def setinputsizes(self, sizes): + pass + + def setoutputsize(self, column=None): + pass + + def close(self): + self.rows = None + self._rows_prefetched = None + + @property + def rowcount(self): + return len(self._ensure_prefetched()) diff --git a/ydb/public/sdk/python3/ydb/_dbapi/errors.py b/ydb/public/sdk/python3/ydb/_dbapi/errors.py new file mode 100644 index 00000000000..ddb55b4c900 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_dbapi/errors.py @@ -0,0 +1,92 @@ +class Warning(Exception): + pass + + +class Error(Exception): + def __init__(self, message, issues=None, status=None): + super(Error, self).__init__(message) + + pretty_issues = _pretty_issues(issues) + self.issues = issues + self.message = pretty_issues or message + self.status = status + + +class InterfaceError(Error): + pass + + +class DatabaseError(Error): + pass + + +class DataError(DatabaseError): + pass + + +class OperationalError(DatabaseError): + pass + + +class IntegrityError(DatabaseError): + pass + + +class InternalError(DatabaseError): + pass + + +class ProgrammingError(DatabaseError): + pass + + +class NotSupportedError(DatabaseError): + pass + + +def _pretty_issues(issues): + if issues is None: + return None + + children_messages = [_get_messages(issue, root=True) for issue in issues] + + if None in children_messages: + return None + + return "\n" + "\n".join(children_messages) + + +def _get_messages(issue, max_depth=100, indent=2, depth=0, root=False): + if depth >= max_depth: + return None + + margin_str = " " * depth * indent + pre_message = "" + children = "" + + if issue.issues: + collapsed_messages = [] + while not root and len(issue.issues) == 1: + collapsed_messages.append(issue.message) + issue = issue.issues[0] + + if collapsed_messages: + pre_message = f"{margin_str}{', '.join(collapsed_messages)}\n" + depth += 1 + margin_str = " " * depth * indent + + children_messages = [ + _get_messages(iss, max_depth=max_depth, indent=indent, depth=depth + 1) + for iss in issue.issues + ] + + if None in children_messages: + return None + + children = "\n".join(children_messages) + + return ( + f"{pre_message}{margin_str}{issue.message}\n{margin_str}" + f"severity level: {issue.severity}\n{margin_str}" + f"issue code: {issue.issue_code}\n{children}" + ) diff --git a/ydb/public/sdk/python3/ydb/_errors.py b/ydb/public/sdk/python3/ydb/_errors.py index 8c6f072049f..ae3057b6d2f 100644 --- a/ydb/public/sdk/python3/ydb/_errors.py +++ b/ydb/public/sdk/python3/ydb/_errors.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Optional -from . import issues +from ydb import issues _errors_retriable_fast_backoff_types = [ issues.Unavailable, diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/__init__.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/__init__.py new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/__init__.py diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/common_utils.py new file mode 100644 index 00000000000..6c624520ea8 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/common_utils.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import abc +import asyncio +import contextvars +import datetime +import functools +import typing +from typing import ( + Optional, + Any, + Iterator, + AsyncIterator, + Callable, + Iterable, + Union, + Coroutine, +) +from dataclasses import dataclass + +import grpc +from google.protobuf.message import Message +from google.protobuf.duration_pb2 import Duration as ProtoDuration +from google.protobuf.timestamp_pb2 import Timestamp as ProtoTimeStamp + +import ydb.aio + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_topic_pb2, ydb_issue_message_pb2 +else: + from ..common.protos import ydb_topic_pb2, ydb_issue_message_pb2 + +from ... import issues, connection + + +class IFromProto(abc.ABC): + @staticmethod + @abc.abstractmethod + def from_proto(msg: Message) -> Any: + ... + + +class IFromProtoWithProtoType(IFromProto): + @staticmethod + @abc.abstractmethod + def empty_proto_message() -> Message: + ... + + +class IToProto(abc.ABC): + @abc.abstractmethod + def to_proto(self) -> Message: + ... + + +class IFromPublic(abc.ABC): + @staticmethod + @abc.abstractmethod + def from_public(o: typing.Any) -> typing.Any: + ... + + +class IToPublic(abc.ABC): + @abc.abstractmethod + def to_public(self) -> typing.Any: + ... + + +class UnknownGrpcMessageError(issues.Error): + pass + + +_stop_grpc_connection_marker = object() + + +class QueueToIteratorAsyncIO: + __slots__ = ("_queue",) + + def __init__(self, q: asyncio.Queue): + self._queue = q + + def __aiter__(self): + return self + + async def __anext__(self): + item = await self._queue.get() + if item is _stop_grpc_connection_marker: + raise StopAsyncIteration() + return item + + +class AsyncQueueToSyncIteratorAsyncIO: + __slots__ = ( + "_loop", + "_queue", + ) + _queue: asyncio.Queue + + def __init__(self, q: asyncio.Queue): + self._loop = asyncio.get_running_loop() + self._queue = q + + def __iter__(self): + return self + + def __next__(self): + item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result() + if item is _stop_grpc_connection_marker: + raise StopIteration() + return item + + +class SyncIteratorToAsyncIterator: + def __init__(self, sync_iterator: Iterator): + self._sync_iterator = sync_iterator + + def __aiter__(self): + return self + + async def __anext__(self): + try: + res = await to_thread(self._sync_iterator.__next__) + return res + except StopAsyncIteration: + raise StopIteration() + + +class IGrpcWrapperAsyncIO(abc.ABC): + @abc.abstractmethod + async def receive(self) -> Any: + ... + + @abc.abstractmethod + def write(self, wrap_message: IToProto): + ... + + @abc.abstractmethod + def close(self): + ... + + +SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] + + +class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): + from_client_grpc: asyncio.Queue + from_server_grpc: AsyncIterator + convert_server_grpc_to_wrapper: Callable[[Any], Any] + _connection_state: str + _stream_call: Optional[ + Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"] + ] + + def __init__(self, convert_server_grpc_to_wrapper): + self.from_client_grpc = asyncio.Queue() + self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper + self._connection_state = "new" + self._stream_call = None + + async def start(self, driver: SupportedDriverType, stub, method): + if asyncio.iscoroutinefunction(driver.__call__): + await self._start_asyncio_driver(driver, stub, method) + else: + await self._start_sync_driver(driver, stub, method) + self._connection_state = "started" + + def close(self): + self.from_client_grpc.put_nowait(_stop_grpc_connection_marker) + if self._stream_call: + self._stream_call.cancel() + + async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): + requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) + stream_call = await driver( + requests_iterator, + stub, + method, + ) + self._stream_call = stream_call + self.from_server_grpc = stream_call.__aiter__() + + async def _start_sync_driver(self, driver: ydb.Driver, stub, method): + requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc) + stream_call = await to_thread( + driver, + requests_iterator, + stub, + method, + ) + self._stream_call = stream_call + self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__()) + + async def receive(self) -> Any: + # todo handle grpc exceptions and convert it to internal exceptions + try: + grpc_message = await self.from_server_grpc.__anext__() + except grpc.RpcError as e: + raise connection._rpc_error_handler(self._connection_state, e) + + issues._process_response(grpc_message) + + if self._connection_state != "has_received_messages": + self._connection_state = "has_received_messages" + + # print("rekby, grpc, received", grpc_message) + return self.convert_server_grpc_to_wrapper(grpc_message) + + def write(self, wrap_message: IToProto): + grpc_message = wrap_message.to_proto() + # print("rekby, grpc, send", grpc_message) + self.from_client_grpc.put_nowait(grpc_message) + + +@dataclass(init=False) +class ServerStatus(IFromProto): + __slots__ = ("_grpc_status_code", "_issues") + + def __init__( + self, + status: issues.StatusCode, + issues: Iterable[Any], + ): + self.status = status + self.issues = issues + + def __str__(self): + return self.__repr__() + + @staticmethod + def from_proto( + msg: Union[ + ydb_topic_pb2.StreamReadMessage.FromServer, + ydb_topic_pb2.StreamWriteMessage.FromServer, + ] + ) -> "ServerStatus": + return ServerStatus(msg.status, msg.issues) + + def is_success(self) -> bool: + return self.status == issues.StatusCode.SUCCESS + + @classmethod + def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage): + res = """code: %s message: "%s" """ % (issue.issue_code, issue.message) + if len(issue.issues) > 0: + d = ", " + res += d + d.join(str(sub_issue) for sub_issue in issue.issues) + return res + + +def callback_from_asyncio( + callback: Union[Callable, Coroutine] +) -> [asyncio.Future, asyncio.Task]: + loop = asyncio.get_running_loop() + + if asyncio.iscoroutinefunction(callback): + return loop.create_task(callback()) + else: + return loop.run_in_executor(None, callback) + + +async def to_thread(func, /, *args, **kwargs): + """Asynchronously run function *func* in a separate thread. + + Any *args and **kwargs supplied for this function are directly passed + to *func*. Also, the current :class:`contextvars.Context` is propagated, + allowing context variables from the main thread to be accessed in the + separate thread. + + Return a coroutine that can be awaited to get the eventual result of *func*. + + copy to_thread from 3.10 + """ + + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + +def proto_duration_from_timedelta(t: Optional[datetime.timedelta]) -> ProtoDuration: + if t is None: + return None + res = ProtoDuration() + res.FromTimedelta(t) + + +def proto_timestamp_from_datetime(t: Optional[datetime.datetime]) -> ProtoTimeStamp: + if t is None: + return None + + res = ProtoTimeStamp() + res.FromDatetime(t) + + +def datetime_from_proto_timestamp( + ts: Optional[ProtoTimeStamp], +) -> Optional[datetime.datetime]: + if ts is None: + return None + return ts.ToDatetime() + + +def timedelta_from_proto_duration( + d: Optional[ProtoDuration], +) -> Optional[datetime.timedelta]: + if d is None: + return None + return d.ToTimedelta() diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_scheme.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_scheme.py new file mode 100644 index 00000000000..b9922035703 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_scheme.py @@ -0,0 +1,36 @@ +import datetime +import enum +from dataclasses import dataclass +from typing import List + + +@dataclass +class Entry: + name: str + owner: str + type: "Entry.Type" + effective_permissions: "Permissions" + permissions: "Permissions" + size_bytes: int + created_at: datetime.datetime + + class Type(enum.IntEnum): + UNSPECIFIED = 0 + DIRECTORY = 1 + TABLE = 2 + PERS_QUEUE_GROUP = 3 + DATABASE = 4 + RTMR_VOLUME = 5 + BLOCK_STORE_VOLUME = 6 + COORDINATION_NODE = 7 + COLUMN_STORE = 12 + COLUMN_TABLE = 13 + SEQUENCE = 15 + REPLICATION = 16 + TOPIC = 17 + + +@dataclass +class Permissions: + subject: str + permission_names: List[str] diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic.py new file mode 100644 index 00000000000..4784d4866bf --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -0,0 +1,1164 @@ +import datetime +import enum +import typing +from dataclasses import dataclass, field +from typing import List, Union, Dict, Optional + +from google.protobuf.message import Message + +from . import ydb_topic_public_types +from ... import scheme + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_scheme_pb2, ydb_topic_pb2 +else: + from ..common.protos import ydb_scheme_pb2, ydb_topic_pb2 + +from .common_utils import ( + IFromProto, + IFromProtoWithProtoType, + IToProto, + IToPublic, + IFromPublic, + ServerStatus, + UnknownGrpcMessageError, + proto_duration_from_timedelta, + proto_timestamp_from_datetime, + datetime_from_proto_timestamp, + timedelta_from_proto_duration, +) + + +class Codec(int, IToPublic): + CODEC_UNSPECIFIED = 0 + CODEC_RAW = 1 + CODEC_GZIP = 2 + CODEC_LZOP = 3 + CODEC_ZSTD = 4 + + @staticmethod + def from_proto_iterable(codecs: typing.Iterable[int]) -> List["Codec"]: + return [Codec(int(codec)) for codec in codecs] + + def to_public(self) -> ydb_topic_public_types.PublicCodec: + return ydb_topic_public_types.PublicCodec(int(self)) + + +@dataclass +class SupportedCodecs(IToProto, IFromProto, IToPublic): + codecs: List[Codec] + + def to_proto(self) -> ydb_topic_pb2.SupportedCodecs: + return ydb_topic_pb2.SupportedCodecs( + codecs=self.codecs, + ) + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.SupportedCodecs]) -> "SupportedCodecs": + if msg is None: + return SupportedCodecs(codecs=[]) + + return SupportedCodecs( + codecs=Codec.from_proto_iterable(msg.codecs), + ) + + def to_public(self) -> List[ydb_topic_public_types.PublicCodec]: + return list(map(Codec.to_public, self.codecs)) + + +@dataclass(order=True) +class OffsetsRange(IFromProto, IToProto): + """ + half-opened interval, include [start, end) offsets + """ + + __slots__ = ("start", "end") + + start: int # first offset + end: int # offset after last, included to range + + def __post_init__(self): + if self.end < self.start: + raise ValueError( + "offset end must be not less then start. Got start=%s end=%s" + % (self.start, self.end) + ) + + @staticmethod + def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange": + return OffsetsRange( + start=msg.start, + end=msg.end, + ) + + def to_proto(self) -> ydb_topic_pb2.OffsetsRange: + return ydb_topic_pb2.OffsetsRange( + start=self.start, + end=self.end, + ) + + def is_intersected_with(self, other: "OffsetsRange") -> bool: + return ( + self.start <= other.start < self.end + or self.start < other.end <= self.end + or other.start <= self.start < other.end + or other.start < self.end <= other.end + ) + + +@dataclass +class UpdateTokenRequest(IToProto): + token: str + + def to_proto(self) -> Message: + res = ydb_topic_pb2.UpdateTokenRequest() + res.token = self.token + return res + + +@dataclass +class UpdateTokenResponse(IFromProto): + @staticmethod + def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any: + return UpdateTokenResponse() + + +######################################################################################################################## +# StreamWrite +######################################################################################################################## + + +class StreamWriteMessage: + @dataclass() + class InitRequest(IToProto): + path: str + producer_id: str + write_session_meta: typing.Dict[str, str] + partitioning: "StreamWriteMessage.PartitioningType" + get_last_seq_no: bool + + def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.InitRequest: + proto = ydb_topic_pb2.StreamWriteMessage.InitRequest() + proto.path = self.path + proto.producer_id = self.producer_id + + if self.partitioning is None: + pass + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningMessageGroupID + ): + proto.message_group_id = self.partitioning.message_group_id + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningPartitionID + ): + proto.partition_id = self.partitioning.partition_id + else: + raise Exception( + "Bad partitioning type at StreamWriteMessage.InitRequest" + ) + + if self.write_session_meta: + for key in self.write_session_meta: + proto.write_session_meta[key] = self.write_session_meta[key] + + proto.get_last_seq_no = self.get_last_seq_no + return proto + + @dataclass + class InitResponse(IFromProto): + last_seq_no: Union[int, None] + session_id: str + partition_id: int + supported_codecs: typing.List[int] + status: ServerStatus = None + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamWriteMessage.InitResponse, + ) -> "StreamWriteMessage.InitResponse": + codecs = [] # type: typing.List[int] + if msg.supported_codecs: + for codec in msg.supported_codecs.codecs: + codecs.append(codec) + + return StreamWriteMessage.InitResponse( + last_seq_no=msg.last_seq_no, + session_id=msg.session_id, + partition_id=msg.partition_id, + supported_codecs=codecs, + ) + + @dataclass + class WriteRequest(IToProto): + messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"] + codec: int + + @dataclass + class MessageData(IToProto): + seq_no: int + created_at: datetime.datetime + data: bytes + uncompressed_size: int + partitioning: "StreamWriteMessage.PartitioningType" + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData: + proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData() + proto.seq_no = self.seq_no + proto.created_at.FromDatetime(self.created_at) + proto.data = self.data + proto.uncompressed_size = self.uncompressed_size + + if self.partitioning is None: + pass + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningPartitionID + ): + proto.partition_id = self.partitioning.partition_id + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningMessageGroupID + ): + proto.message_group_id = self.partitioning.message_group_id + else: + raise Exception( + "Bad partition at StreamWriteMessage.WriteRequest.MessageData" + ) + + return proto + + def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest: + proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest() + proto.codec = self.codec + + for message in self.messages: + proto_mess = proto.messages.add() + proto_mess.CopyFrom(message.to_proto()) + + return proto + + @dataclass + class WriteResponse(IFromProto): + partition_id: int + acks: typing.List["StreamWriteMessage.WriteResponse.WriteAck"] + write_statistics: "StreamWriteMessage.WriteResponse.WriteStatistics" + status: Optional[ServerStatus] = field(default=None) + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamWriteMessage.WriteResponse, + ) -> "StreamWriteMessage.WriteResponse": + acks = [] + for proto_ack in msg.acks: + ack = StreamWriteMessage.WriteResponse.WriteAck.from_proto(proto_ack) + acks.append(ack) + write_statistics = StreamWriteMessage.WriteResponse.WriteStatistics( + persisting_time=msg.write_statistics.persisting_time.ToTimedelta(), + min_queue_wait_time=msg.write_statistics.min_queue_wait_time.ToTimedelta(), + max_queue_wait_time=msg.write_statistics.max_queue_wait_time.ToTimedelta(), + partition_quota_wait_time=msg.write_statistics.partition_quota_wait_time.ToTimedelta(), + topic_quota_wait_time=msg.write_statistics.topic_quota_wait_time.ToTimedelta(), + ) + return StreamWriteMessage.WriteResponse( + partition_id=msg.partition_id, + acks=acks, + write_statistics=write_statistics, + status=None, + ) + + @dataclass + class WriteAck(IFromProto): + seq_no: int + message_write_status: Union[ + "StreamWriteMessage.WriteResponse.WriteAck.StatusWritten", + "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped", + int, + ] + + @classmethod + def from_proto( + cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.WriteAck + ): + if proto_ack.HasField("written"): + message_write_status = ( + StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( + proto_ack.written.offset + ) + ) + elif proto_ack.HasField("skipped"): + reason = proto_ack.skipped.reason + try: + message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped( + reason=StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason.from_protobuf_code( + reason + ) + ) + except ValueError: + message_write_status = reason + else: + raise NotImplementedError("unexpected ack status") + + return StreamWriteMessage.WriteResponse.WriteAck( + seq_no=proto_ack.seq_no, + message_write_status=message_write_status, + ) + + @dataclass + class StatusWritten: + offset: int + + @dataclass + class StatusSkipped: + reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason" + + class Reason(enum.Enum): + UNSPECIFIED = 0 + ALREADY_WRITTEN = 1 + + @classmethod + def from_protobuf_code( + cls, code: int + ) -> Union[ + "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason", + int, + ]: + try: + return StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason( + code + ) + except ValueError: + return code + + @dataclass + class WriteStatistics: + persisting_time: datetime.timedelta + min_queue_wait_time: datetime.timedelta + max_queue_wait_time: datetime.timedelta + partition_quota_wait_time: datetime.timedelta + topic_quota_wait_time: datetime.timedelta + + @dataclass + class PartitioningMessageGroupID: + message_group_id: str + + @dataclass + class PartitioningPartitionID: + partition_id: int + + PartitioningType = Union[PartitioningMessageGroupID, PartitioningPartitionID, None] + + @dataclass + class FromClient(IToProto): + value: "WriterMessagesFromClientToServer" + + def __init__(self, value: "WriterMessagesFromClientToServer"): + self.value = value + + def to_proto(self) -> Message: + res = ydb_topic_pb2.StreamWriteMessage.FromClient() + value = self.value + if isinstance(value, StreamWriteMessage.WriteRequest): + res.write_request.CopyFrom(value.to_proto()) + elif isinstance(value, StreamWriteMessage.InitRequest): + res.init_request.CopyFrom(value.to_proto()) + elif isinstance(value, UpdateTokenRequest): + res.update_token_request.CopyFrom(value.to_proto()) + else: + raise Exception("Unknown outcoming grpc message: %s" % value) + return res + + class FromServer(IFromProto): + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.FromServer) -> typing.Any: + message_type = msg.WhichOneof("server_message") + if message_type == "write_response": + res = StreamWriteMessage.WriteResponse.from_proto(msg.write_response) + elif message_type == "init_response": + res = StreamWriteMessage.InitResponse.from_proto(msg.init_response) + elif message_type == "update_token_response": + res = UpdateTokenResponse.from_proto(msg.update_token_response) + else: + # todo log instead of exception - for allow add messages in the future + raise UnknownGrpcMessageError("Unexpected proto message: %s" % msg) + + res.status = ServerStatus(msg.status, msg.issues) + return res + + +WriterMessagesFromClientToServer = Union[ + StreamWriteMessage.InitRequest, StreamWriteMessage.WriteRequest, UpdateTokenRequest +] +WriterMessagesFromServerToClient = Union[ + StreamWriteMessage.InitResponse, + StreamWriteMessage.WriteResponse, + UpdateTokenResponse, +] + + +######################################################################################################################## +# StreamRead +######################################################################################################################## + + +class StreamReadMessage: + @dataclass + class PartitionSession(IFromProto): + partition_session_id: int + path: str + partition_id: int + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.PartitionSession, + ) -> "StreamReadMessage.PartitionSession": + return StreamReadMessage.PartitionSession( + partition_session_id=msg.partition_session_id, + path=msg.path, + partition_id=msg.partition_id, + ) + + @dataclass + class InitRequest(IToProto): + topics_read_settings: List["StreamReadMessage.InitRequest.TopicReadSettings"] + consumer: str + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.InitRequest: + res = ydb_topic_pb2.StreamReadMessage.InitRequest() + res.consumer = self.consumer + for settings in self.topics_read_settings: + res.topics_read_settings.append(settings.to_proto()) + return res + + @dataclass + class TopicReadSettings(IToProto): + path: str + partition_ids: List[int] = field(default_factory=list) + max_lag_seconds: Union[datetime.timedelta, None] = None + read_from: Union[int, float, datetime.datetime, None] = None + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings: + res = ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings() + res.path = self.path + res.partition_ids.extend(self.partition_ids) + if self.max_lag_seconds is not None: + res.max_lag = proto_duration_from_timedelta(self.max_lag_seconds) + return res + + @dataclass + class InitResponse(IFromProto): + session_id: str + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.InitResponse, + ) -> "StreamReadMessage.InitResponse": + return StreamReadMessage.InitResponse(session_id=msg.session_id) + + @dataclass + class ReadRequest(IToProto): + bytes_size: int + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.ReadRequest: + res = ydb_topic_pb2.StreamReadMessage.ReadRequest() + res.bytes_size = self.bytes_size + return res + + @dataclass + class ReadResponse(IFromProto): + partition_data: List["StreamReadMessage.ReadResponse.PartitionData"] + bytes_size: int + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse, + ) -> "StreamReadMessage.ReadResponse": + partition_data = [] + for proto_partition_data in msg.partition_data: + partition_data.append( + StreamReadMessage.ReadResponse.PartitionData.from_proto( + proto_partition_data + ) + ) + return StreamReadMessage.ReadResponse( + partition_data=partition_data, + bytes_size=msg.bytes_size, + ) + + @dataclass + class MessageData(IFromProto): + offset: int + seq_no: int + created_at: datetime.datetime + data: bytes + uncompresed_size: int + message_group_id: str + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData, + ) -> "StreamReadMessage.ReadResponse.MessageData": + return StreamReadMessage.ReadResponse.MessageData( + offset=msg.offset, + seq_no=msg.seq_no, + created_at=msg.created_at.ToDatetime(), + data=msg.data, + uncompresed_size=msg.uncompressed_size, + message_group_id=msg.message_group_id, + ) + + @dataclass + class Batch(IFromProto): + message_data: List["StreamReadMessage.ReadResponse.MessageData"] + producer_id: str + write_session_meta: Dict[str, str] + codec: int + written_at: datetime.datetime + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch, + ) -> "StreamReadMessage.ReadResponse.Batch": + message_data = [] + for message in msg.message_data: + message_data.append( + StreamReadMessage.ReadResponse.MessageData.from_proto(message) + ) + return StreamReadMessage.ReadResponse.Batch( + message_data=message_data, + producer_id=msg.producer_id, + write_session_meta=dict(msg.write_session_meta), + codec=msg.codec, + written_at=msg.written_at.ToDatetime(), + ) + + @dataclass + class PartitionData(IFromProto): + partition_session_id: int + batches: List["StreamReadMessage.ReadResponse.Batch"] + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData, + ) -> "StreamReadMessage.ReadResponse.PartitionData": + batches = [] + for proto_batch in msg.batches: + batches.append( + StreamReadMessage.ReadResponse.Batch.from_proto(proto_batch) + ) + return StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=msg.partition_session_id, + batches=batches, + ) + + @dataclass + class CommitOffsetRequest(IToProto): + commit_offsets: List["PartitionCommitOffset"] + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest: + res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest( + commit_offsets=list( + map( + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset.to_proto, + self.commit_offsets, + ) + ), + ) + return res + + @dataclass + class PartitionCommitOffset(IToProto): + partition_session_id: int + offsets: List["OffsetsRange"] + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset: + res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=self.partition_session_id, + offsets=list(map(OffsetsRange.to_proto, self.offsets)), + ) + return res + + @dataclass + class CommitOffsetResponse(IFromProto): + partitions_committed_offsets: List[ + "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset" + ] + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse, + ) -> "StreamReadMessage.CommitOffsetResponse": + return StreamReadMessage.CommitOffsetResponse( + partitions_committed_offsets=list( + map( + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset.from_proto, + msg.partitions_committed_offsets, + ) + ) + ) + + @dataclass + class PartitionCommittedOffset(IFromProto): + partition_session_id: int + committed_offset: int + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset, + ) -> "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset": + return StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=msg.partition_session_id, + committed_offset=msg.committed_offset, + ) + + @dataclass + class PartitionSessionStatusRequest: + partition_session_id: int + + @dataclass + class PartitionSessionStatusResponse: + partition_session_id: int + partition_offsets: "OffsetsRange" + committed_offset: int + write_time_high_watermark: float + + @dataclass + class StartPartitionSessionRequest(IFromProto): + partition_session: "StreamReadMessage.PartitionSession" + committed_offset: int + partition_offsets: "OffsetsRange" + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest, + ) -> "StreamReadMessage.StartPartitionSessionRequest": + return StreamReadMessage.StartPartitionSessionRequest( + partition_session=StreamReadMessage.PartitionSession.from_proto( + msg.partition_session + ), + committed_offset=msg.committed_offset, + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), + ) + + @dataclass + class StartPartitionSessionResponse(IToProto): + partition_session_id: int + read_offset: Optional[int] + commit_offset: Optional[int] + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: + res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse() + res.partition_session_id = self.partition_session_id + if self.read_offset is not None: + res.read_offset = self.read_offset + if self.commit_offset is not None: + res.commit_offset = self.commit_offset + return res + + @dataclass + class StopPartitionSessionRequest: + partition_session_id: int + graceful: bool + committed_offset: int + + @dataclass + class StopPartitionSessionResponse: + partition_session_id: int + + @dataclass + class FromClient(IToProto): + client_message: "ReaderMessagesFromClientToServer" + + def __init__(self, client_message: "ReaderMessagesFromClientToServer"): + self.client_message = client_message + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: + res = ydb_topic_pb2.StreamReadMessage.FromClient() + if isinstance(self.client_message, StreamReadMessage.ReadRequest): + res.read_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.CommitOffsetRequest): + res.commit_offset_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.InitRequest): + res.init_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, UpdateTokenRequest): + res.update_token_request.CopyFrom(self.client_message.to_proto()) + elif isinstance( + self.client_message, StreamReadMessage.StartPartitionSessionResponse + ): + res.start_partition_session_response.CopyFrom( + self.client_message.to_proto() + ) + else: + raise NotImplementedError( + "Unknown message type: %s" % type(self.client_message) + ) + return res + + @dataclass + class FromServer(IFromProto): + server_message: "ReaderMessagesFromServerToClient" + server_status: ServerStatus + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.FromServer, + ) -> "StreamReadMessage.FromServer": + mess_type = msg.WhichOneof("server_message") + server_status = ServerStatus.from_proto(msg) + if mess_type == "read_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.ReadResponse.from_proto( + msg.read_response + ), + ) + elif mess_type == "commit_offset_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.CommitOffsetResponse.from_proto( + msg.commit_offset_response + ), + ) + elif mess_type == "init_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.InitResponse.from_proto( + msg.init_response + ), + ) + elif mess_type == "start_partition_session_request": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto( + msg.start_partition_session_request + ), + ) + elif mess_type == "update_token_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=UpdateTokenResponse.from_proto( + msg.update_token_response + ), + ) + + # todo replace exception to log + raise NotImplementedError() + + +ReaderMessagesFromClientToServer = Union[ + StreamReadMessage.InitRequest, + StreamReadMessage.ReadRequest, + StreamReadMessage.CommitOffsetRequest, + StreamReadMessage.PartitionSessionStatusRequest, + UpdateTokenRequest, + StreamReadMessage.StartPartitionSessionResponse, + StreamReadMessage.StopPartitionSessionResponse, +] + +ReaderMessagesFromServerToClient = Union[ + StreamReadMessage.InitResponse, + StreamReadMessage.ReadResponse, + StreamReadMessage.CommitOffsetResponse, + StreamReadMessage.PartitionSessionStatusResponse, + UpdateTokenResponse, + StreamReadMessage.StartPartitionSessionRequest, + StreamReadMessage.StopPartitionSessionRequest, +] + + +@dataclass +class MultipleWindowsStat(IFromProto, IToPublic): + per_minute: int + per_hour: int + per_day: int + + @staticmethod + def from_proto( + msg: Optional[ydb_topic_pb2.MultipleWindowsStat], + ) -> Optional["MultipleWindowsStat"]: + if msg is None: + return None + return MultipleWindowsStat( + per_minute=msg.per_minute, + per_hour=msg.per_hour, + per_day=msg.per_day, + ) + + def to_public(self) -> ydb_topic_public_types.PublicMultipleWindowsStat: + return ydb_topic_public_types.PublicMultipleWindowsStat( + per_minute=self.per_minute, + per_hour=self.per_hour, + per_day=self.per_day, + ) + + +@dataclass +class Consumer(IToProto, IFromProto, IFromPublic, IToPublic): + name: str + important: bool + read_from: typing.Optional[datetime.datetime] + supported_codecs: SupportedCodecs + attributes: Dict[str, str] + consumer_stats: typing.Optional["Consumer.ConsumerStats"] + + def to_proto(self) -> ydb_topic_pb2.Consumer: + return ydb_topic_pb2.Consumer( + name=self.name, + important=self.important, + read_from=proto_timestamp_from_datetime(self.read_from), + supported_codecs=self.supported_codecs.to_proto(), + attributes=self.attributes, + # consumer_stats - readonly field + ) + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.Consumer]) -> Optional["Consumer"]: + return Consumer( + name=msg.name, + important=msg.important, + read_from=datetime_from_proto_timestamp(msg.read_from), + supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs), + attributes=dict(msg.attributes), + consumer_stats=Consumer.ConsumerStats.from_proto(msg.consumer_stats), + ) + + @staticmethod + def from_public(consumer: ydb_topic_public_types.PublicConsumer): + if consumer is None: + return None + + supported_codecs = [] + if consumer.supported_codecs is not None: + supported_codecs = consumer.supported_codecs + + return Consumer( + name=consumer.name, + important=consumer.important, + read_from=consumer.read_from, + supported_codecs=SupportedCodecs(codecs=supported_codecs), + attributes=consumer.attributes, + consumer_stats=None, + ) + + def to_public(self) -> ydb_topic_public_types.PublicConsumer: + return ydb_topic_public_types.PublicConsumer( + name=self.name, + important=self.important, + read_from=self.read_from, + supported_codecs=self.supported_codecs.to_public(), + attributes=self.attributes, + ) + + @dataclass + class ConsumerStats(IFromProto): + min_partitions_last_read_time: datetime.datetime + max_read_time_lag: datetime.timedelta + max_write_time_lag: datetime.timedelta + bytes_read: MultipleWindowsStat + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.Consumer.ConsumerStats, + ) -> "Consumer.ConsumerStats": + return Consumer.ConsumerStats( + min_partitions_last_read_time=datetime_from_proto_timestamp( + msg.min_partitions_last_read_time + ), + max_read_time_lag=timedelta_from_proto_duration(msg.max_read_time_lag), + max_write_time_lag=timedelta_from_proto_duration( + msg.max_write_time_lag + ), + bytes_read=MultipleWindowsStat.from_proto(msg.bytes_read), + ) + + +@dataclass +class PartitioningSettings(IToProto, IFromProto): + min_active_partitions: int + partition_count_limit: int + + @staticmethod + def from_proto(msg: ydb_topic_pb2.PartitioningSettings) -> "PartitioningSettings": + return PartitioningSettings( + min_active_partitions=msg.min_active_partitions, + partition_count_limit=msg.partition_count_limit, + ) + + def to_proto(self) -> ydb_topic_pb2.PartitioningSettings: + return ydb_topic_pb2.PartitioningSettings( + min_active_partitions=self.min_active_partitions, + partition_count_limit=self.partition_count_limit, + ) + + +class MeteringMode(int, IFromProto, IFromPublic, IToPublic): + UNSPECIFIED = 0 + RESERVED_CAPACITY = 1 + REQUEST_UNITS = 2 + + @staticmethod + def from_public( + m: Optional[ydb_topic_public_types.PublicMeteringMode], + ) -> Optional["MeteringMode"]: + if m is None: + return None + + return MeteringMode(m) + + @staticmethod + def from_proto(code: Optional[int]) -> Optional["MeteringMode"]: + if code is None: + return None + + return MeteringMode(code) + + def to_public(self) -> ydb_topic_public_types.PublicMeteringMode: + try: + ydb_topic_public_types.PublicMeteringMode(int(self)) + except KeyError: + return ydb_topic_public_types.PublicMeteringMode.UNSPECIFIED + + +@dataclass +class CreateTopicRequest(IToProto, IFromPublic): + path: str + partitioning_settings: "PartitioningSettings" + retention_period: typing.Optional[datetime.timedelta] + retention_storage_mb: typing.Optional[int] + supported_codecs: "SupportedCodecs" + partition_write_speed_bytes_per_second: typing.Optional[int] + partition_write_burst_bytes: typing.Optional[int] + attributes: Dict[str, str] + consumers: List["Consumer"] + metering_mode: "MeteringMode" + + def to_proto(self) -> ydb_topic_pb2.CreateTopicRequest: + return ydb_topic_pb2.CreateTopicRequest( + path=self.path, + partitioning_settings=self.partitioning_settings.to_proto(), + retention_period=proto_duration_from_timedelta(self.retention_period), + retention_storage_mb=self.retention_storage_mb, + supported_codecs=self.supported_codecs.to_proto(), + partition_write_speed_bytes_per_second=self.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=self.partition_write_burst_bytes, + attributes=self.attributes, + consumers=[consumer.to_proto() for consumer in self.consumers], + metering_mode=self.metering_mode, + ) + + @staticmethod + def from_public(req: ydb_topic_public_types.CreateTopicRequestParams): + supported_codecs = [] + + if req.supported_codecs is not None: + supported_codecs = req.supported_codecs + + consumers = [] + if req.consumers is not None: + for consumer in req.consumers: + if isinstance(consumer, str): + consumer = ydb_topic_public_types.PublicConsumer(name=consumer) + consumers.append(Consumer.from_public(consumer)) + + return CreateTopicRequest( + path=req.path, + partitioning_settings=PartitioningSettings( + min_active_partitions=req.min_active_partitions, + partition_count_limit=req.partition_count_limit, + ), + retention_period=req.retention_period, + retention_storage_mb=req.retention_storage_mb, + supported_codecs=SupportedCodecs( + codecs=supported_codecs, + ), + partition_write_speed_bytes_per_second=req.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=req.partition_write_burst_bytes, + attributes=req.attributes, + consumers=consumers, + metering_mode=MeteringMode.from_public(req.metering_mode), + ) + + +@dataclass +class CreateTopicResult: + pass + + +@dataclass +class DescribeTopicRequest: + path: str + include_stats: bool + + +@dataclass +class DescribeTopicResult(IFromProtoWithProtoType, IToPublic): + self_proto: ydb_scheme_pb2.Entry + partitioning_settings: PartitioningSettings + partitions: List["DescribeTopicResult.PartitionInfo"] + retention_period: datetime.timedelta + retention_storage_mb: int + supported_codecs: SupportedCodecs + partition_write_speed_bytes_per_second: int + partition_write_burst_bytes: int + attributes: Dict[str, str] + consumers: List["Consumer"] + metering_mode: MeteringMode + topic_stats: "DescribeTopicResult.TopicStats" + + @staticmethod + def from_proto(msg: ydb_topic_pb2.DescribeTopicResult) -> "DescribeTopicResult": + return DescribeTopicResult( + self_proto=msg.self, + partitioning_settings=PartitioningSettings.from_proto( + msg.partitioning_settings + ), + partitions=list( + map(DescribeTopicResult.PartitionInfo.from_proto, msg.partitions) + ), + retention_period=msg.retention_period, + retention_storage_mb=msg.retention_storage_mb, + supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs), + partition_write_speed_bytes_per_second=msg.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=msg.partition_write_burst_bytes, + attributes=dict(msg.attributes), + consumers=list(map(Consumer.from_proto, msg.consumers)), + metering_mode=MeteringMode.from_proto(msg.metering_mode), + topic_stats=DescribeTopicResult.TopicStats.from_proto(msg.topic_stats), + ) + + @staticmethod + def empty_proto_message() -> ydb_topic_pb2.DescribeTopicResult: + return ydb_topic_pb2.DescribeTopicResult() + + def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult: + return ydb_topic_public_types.PublicDescribeTopicResult( + self=scheme._wrap_scheme_entry(self.self_proto), + min_active_partitions=self.partitioning_settings.min_active_partitions, + partition_count_limit=self.partitioning_settings.partition_count_limit, + partitions=list( + map(DescribeTopicResult.PartitionInfo.to_public, self.partitions) + ), + retention_period=self.retention_period, + retention_storage_mb=self.retention_storage_mb, + supported_codecs=self.supported_codecs.to_public(), + partition_write_speed_bytes_per_second=self.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=self.partition_write_burst_bytes, + attributes=self.attributes, + consumers=list(map(Consumer.to_public, self.consumers)), + metering_mode=self.metering_mode.to_public(), + topic_stats=self.topic_stats.to_public(), + ) + + @dataclass + class PartitionInfo(IFromProto, IToPublic): + partition_id: int + active: bool + child_partition_ids: List[int] + parent_partition_ids: List[int] + partition_stats: "PartitionStats" + + @staticmethod + def from_proto( + msg: Optional[ydb_topic_pb2.DescribeTopicResult.PartitionInfo], + ) -> Optional["DescribeTopicResult.PartitionInfo"]: + if msg is None: + return None + + return DescribeTopicResult.PartitionInfo( + partition_id=msg.partition_id, + active=msg.active, + child_partition_ids=list(msg.child_partition_ids), + parent_partition_ids=list(msg.parent_partition_ids), + partition_stats=PartitionStats.from_proto(msg.partition_stats), + ) + + def to_public( + self, + ) -> ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo: + partition_stats = None + if self.partition_stats is not None: + partition_stats = self.partition_stats.to_public() + return ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo( + partition_id=self.partition_id, + active=self.active, + child_partition_ids=self.child_partition_ids, + parent_partition_ids=self.parent_partition_ids, + partition_stats=partition_stats, + ) + + @dataclass + class TopicStats(IFromProto, IToPublic): + store_size_bytes: int + min_last_write_time: datetime.datetime + max_write_time_lag: datetime.timedelta + bytes_written: "MultipleWindowsStat" + + @staticmethod + def from_proto( + msg: Optional[ydb_topic_pb2.DescribeTopicResult.TopicStats], + ) -> Optional["DescribeTopicResult.TopicStats"]: + if msg is None: + return None + + return DescribeTopicResult.TopicStats( + store_size_bytes=msg.store_size_bytes, + min_last_write_time=datetime_from_proto_timestamp( + msg.min_last_write_time + ), + max_write_time_lag=timedelta_from_proto_duration( + msg.max_write_time_lag + ), + bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written), + ) + + def to_public( + self, + ) -> ydb_topic_public_types.PublicDescribeTopicResult.TopicStats: + return ydb_topic_public_types.PublicDescribeTopicResult.TopicStats( + store_size_bytes=self.store_size_bytes, + min_last_write_time=self.min_last_write_time, + max_write_time_lag=self.max_write_time_lag, + bytes_written=self.bytes_written.to_public(), + ) + + +@dataclass +class PartitionStats(IFromProto, IToPublic): + partition_offsets: OffsetsRange + store_size_bytes: int + last_write_time: datetime.datetime + max_write_time_lag: datetime.timedelta + bytes_written: "MultipleWindowsStat" + partition_node_id: int + + @staticmethod + def from_proto( + msg: Optional[ydb_topic_pb2.PartitionStats], + ) -> Optional["PartitionStats"]: + if msg is None: + return None + return PartitionStats( + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), + store_size_bytes=msg.store_size_bytes, + last_write_time=datetime_from_proto_timestamp(msg.last_write_time), + max_write_time_lag=timedelta_from_proto_duration(msg.max_write_time_lag), + bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written), + partition_node_id=msg.partition_node_id, + ) + + def to_public(self) -> ydb_topic_public_types.PublicPartitionStats: + return ydb_topic_public_types.PublicPartitionStats( + partition_start=self.partition_offsets.start, + partition_end=self.partition_offsets.end, + store_size_bytes=self.store_size_bytes, + last_write_time=self.last_write_time, + max_write_time_lag=self.max_write_time_lag, + bytes_written=self.bytes_written.to_public(), + partition_node_id=self.partition_node_id, + ) diff --git a/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py new file mode 100644 index 00000000000..4582f19a024 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py @@ -0,0 +1,200 @@ +import datetime +import typing +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Optional, List, Union, Dict + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_topic_pb2 +else: + from ..common.protos import ydb_topic_pb2 + +from .common_utils import IToProto +from ...scheme import SchemeEntry + + +@dataclass +# need similar struct to PublicDescribeTopicResult +class CreateTopicRequestParams: + path: str + min_active_partitions: Optional[int] + partition_count_limit: Optional[int] + retention_period: Optional[datetime.timedelta] + retention_storage_mb: Optional[int] + supported_codecs: Optional[List[Union["PublicCodec", int]]] + partition_write_speed_bytes_per_second: Optional[int] + partition_write_burst_bytes: Optional[int] + attributes: Optional[Dict[str, str]] + consumers: Optional[List[Union["PublicConsumer", str]]] + metering_mode: Optional["PublicMeteringMode"] + + +class PublicCodec(int): + """ + Codec value may contain any int number. + + Values below is only well-known predefined values, + but protocol support custom codecs. + """ + + UNSPECIFIED = 0 + RAW = 1 + GZIP = 2 + LZOP = 3 # Has not supported codec in standard library + ZSTD = 4 # Has not supported codec in standard library + + +class PublicMeteringMode(IntEnum): + UNSPECIFIED = 0 + RESERVED_CAPACITY = 1 + REQUEST_UNITS = 2 + + +@dataclass +class PublicConsumer: + name: str + important: bool = False + """ + Consumer may be marked as 'important'. It means messages for this consumer will never expire due to retention. + User should take care that such consumer never stalls, to prevent running out of disk space. + """ + + read_from: Optional[datetime.datetime] = None + "All messages with smaller server written_at timestamp will be skipped." + + supported_codecs: List[PublicCodec] = field(default_factory=lambda: list()) + """ + List of supported codecs by this consumer. + supported_codecs on topic must be contained inside this list. + """ + + attributes: Dict[str, str] = field(default_factory=lambda: dict()) + "Attributes of consumer" + + +@dataclass +class DropTopicRequestParams(IToProto): + path: str + + def to_proto(self) -> ydb_topic_pb2.DropTopicRequest: + return ydb_topic_pb2.DropTopicRequest(path=self.path) + + +@dataclass +class DescribeTopicRequestParams(IToProto): + path: str + include_stats: bool + + def to_proto(self) -> ydb_topic_pb2.DescribeTopicRequest: + return ydb_topic_pb2.DescribeTopicRequest( + path=self.path, include_stats=self.include_stats + ) + + +@dataclass +# Need similar struct to CreateTopicRequestParams +class PublicDescribeTopicResult: + self: SchemeEntry + "Description of scheme object" + + min_active_partitions: int + "Minimum partition count auto merge would stop working at" + + partition_count_limit: int + "Limit for total partition count, including active (open for write) and read-only partitions" + + partitions: List["PublicDescribeTopicResult.PartitionInfo"] + "Partitions description" + + retention_period: datetime.timedelta + "How long data in partition should be stored" + + retention_storage_mb: int + "How much data in partition should be stored. Zero value means infinite limit" + + supported_codecs: List[PublicCodec] + "List of allowed codecs for writers" + + partition_write_speed_bytes_per_second: int + "Partition write speed in bytes per second" + + partition_write_burst_bytes: int + "Burst size for write in partition, in bytes" + + attributes: Dict[str, str] + """User and server attributes of topic. Server attributes starts from "_" and will be validated by server.""" + + consumers: List[PublicConsumer] + """List of consumers for this topic""" + + metering_mode: PublicMeteringMode + "Metering settings" + + topic_stats: "PublicDescribeTopicResult.TopicStats" + "Statistics of topic" + + @dataclass + class PartitionInfo: + partition_id: int + "Partition identifier" + + active: bool + "Is partition open for write" + + child_partition_ids: List[int] + "Ids of partitions which was formed when this partition was split or merged" + + parent_partition_ids: List[int] + "Ids of partitions from which this partition was formed by split or merge" + + partition_stats: Optional["PublicPartitionStats"] + "Stats for partition, filled only when include_stats in request is true" + + @dataclass + class TopicStats: + store_size_bytes: int + "Approximate size of topic" + + min_last_write_time: datetime.datetime + "Minimum of timestamps of last write among all partitions." + + max_write_time_lag: datetime.timedelta + """ + Maximum of differences between write timestamp and create timestamp for all messages, + written during last minute. + """ + + bytes_written: "PublicMultipleWindowsStat" + "How much bytes were written statistics." + + +@dataclass +class PublicPartitionStats: + partition_start: int + "first message offset in the partition" + + partition_end: int + "offset after last stored message offset in the partition (last offset + 1)" + + store_size_bytes: int + "Approximate size of partition" + + last_write_time: datetime.datetime + "Timestamp of last write" + + max_write_time_lag: datetime.timedelta + "Maximum of differences between write timestamp and create timestamp for all messages, written during last minute." + + bytes_written: "PublicMultipleWindowsStat" + "How much bytes were written during several windows in this partition." + + partition_node_id: int + "Host where tablet for this partition works. Useful for debugging purposes." + + +@dataclass +class PublicMultipleWindowsStat: + per_minute: int + per_hour: int + per_day: int diff --git a/ydb/public/sdk/python3/ydb/_sp_impl.py b/ydb/public/sdk/python3/ydb/_sp_impl.py index a8529d73213..5974a3014b2 100644 --- a/ydb/public/sdk/python3/ydb/_sp_impl.py +++ b/ydb/public/sdk/python3/ydb/_sp_impl.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import collections from concurrent import futures -from six.moves import queue +import queue import time import threading from . import settings, issues, _utilities, tracing diff --git a/ydb/public/sdk/python3/ydb/_topic_common/__init__.py b/ydb/public/sdk/python3/ydb/_topic_common/__init__.py new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_common/__init__.py diff --git a/ydb/public/sdk/python3/ydb/_topic_common/common.py b/ydb/public/sdk/python3/ydb/_topic_common/common.py new file mode 100644 index 00000000000..9e8f1326ed3 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_common/common.py @@ -0,0 +1,147 @@ +import asyncio +import concurrent.futures +import threading +import typing +from typing import Optional + +from .. import operation, issues +from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType + +TimeoutType = typing.Union[int, float, None] + + +def wrap_operation(rpc_state, response_pb, driver=None): + return operation.Operation(rpc_state, response_pb, driver) + + +ResultType = typing.TypeVar("ResultType", bound=IFromProtoWithProtoType) + + +def create_result_wrapper( + result_type: typing.Type[ResultType], +) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]: + def wrapper(rpc_state, response_pb, driver=None): + issues._process_response(response_pb.operation) + msg = result_type.empty_proto_message() + response_pb.operation.result.Unpack(msg) + return result_type.from_proto(msg) + + return wrapper + + +_shared_event_loop_lock = threading.Lock() +_shared_event_loop: Optional[asyncio.AbstractEventLoop] = None + + +def _get_shared_event_loop() -> asyncio.AbstractEventLoop: + global _shared_event_loop + + if _shared_event_loop is not None: + return _shared_event_loop + + with _shared_event_loop_lock: + if _shared_event_loop is not None: + return _shared_event_loop + + event_loop_set_done = concurrent.futures.Future() + + def start_event_loop(): + event_loop = asyncio.new_event_loop() + event_loop_set_done.set_result(event_loop) + asyncio.set_event_loop(event_loop) + event_loop.run_forever() + + t = threading.Thread( + target=start_event_loop, + name="Common ydb topic event loop", + daemon=True, + ) + t.start() + + _shared_event_loop = event_loop_set_done.result() + return _shared_event_loop + + +class CallFromSyncToAsync: + _loop: asyncio.AbstractEventLoop + + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop + + def unsafe_call_with_future( + self, coro: typing.Coroutine + ) -> concurrent.futures.Future: + """ + returned result from coro may be lost + """ + return asyncio.run_coroutine_threadsafe(coro, self._loop) + + def unsafe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType): + """ + returned result from coro may be lost by race future cancel by timeout and return value from coroutine + """ + f = self.unsafe_call_with_future(coro) + try: + return f.result(timeout) + except concurrent.futures.TimeoutError: + raise TimeoutError() + finally: + if not f.done(): + f.cancel() + + def safe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType): + """ + no lost returned value from coro, but may be slower especially timeout latency - it wait coroutine cancelation. + """ + + if timeout is not None and timeout <= 0: + return self._safe_call_fast(coro) + + async def call_coro(): + task = self._loop.create_task(coro) + try: + res = await asyncio.wait_for(task, timeout) + return res + except asyncio.TimeoutError: + try: + res = await task + return res + except asyncio.CancelledError: + pass + + # return builtin TimeoutError instead of asyncio.TimeoutError + raise TimeoutError() + + return asyncio.run_coroutine_threadsafe(call_coro(), self._loop).result() + + def _safe_call_fast(self, coro: typing.Coroutine): + """ + no lost returned value from coro, but may be slower especially timeout latency - it wait coroutine cancelation. + Wait coroutine result only one loop. + """ + res = concurrent.futures.Future() + + async def call_coro(): + try: + res.set_result(await coro) + except asyncio.CancelledError: + res.set_exception(TimeoutError()) + + coro_future = asyncio.run_coroutine_threadsafe(call_coro(), self._loop) + asyncio.run_coroutine_threadsafe(asyncio.sleep(0), self._loop).result() + coro_future.cancel() + return res.result() + + def call_sync(self, callback: typing.Callable[[], typing.Any]) -> typing.Any: + result = concurrent.futures.Future() + + def call_callback(): + try: + res = callback() + result.set_result(res) + except BaseException as err: + result.set_exception(err) + + self._loop.call_soon_threadsafe(call_callback) + + return result.result() diff --git a/ydb/public/sdk/python3/ydb/_topic_common/test_helpers.py b/ydb/public/sdk/python3/ydb/_topic_common/test_helpers.py new file mode 100644 index 00000000000..96a812ab724 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_common/test_helpers.py @@ -0,0 +1,76 @@ +import asyncio +import time +import typing + +from .._grpc.grpcwrapper.common_utils import IToProto, IGrpcWrapperAsyncIO + + +class StreamMock(IGrpcWrapperAsyncIO): + from_server: asyncio.Queue + from_client: asyncio.Queue + _closed: bool + + def __init__(self): + self.from_server = asyncio.Queue() + self.from_client = asyncio.Queue() + self._closed = False + + async def receive(self) -> typing.Any: + if self._closed: + raise Exception("read from closed StreamMock") + + item = await self.from_server.get() + if item is None: + raise StopAsyncIteration() + if isinstance(item, Exception): + raise item + return item + + def write(self, wrap_message: IToProto): + if self._closed: + raise Exception("write to closed StreamMock") + self.from_client.put_nowait(wrap_message) + + def close(self): + if self._closed: + return + + self._closed = True + self.from_server.put_nowait(None) + + +class WaitConditionError(Exception): + pass + + +async def wait_condition( + f: typing.Callable[[], bool], + timeout: typing.Optional[typing.Union[float, int]] = None, +): + """ + timeout default is 1 second + if timeout is 0 - only counter work. It userful if test need fast timeout for condition (without wait full timeout) + """ + if timeout is None: + timeout = 1 + + minimal_loop_count_for_wait = 1000 + + start = time.monotonic() + counter = 0 + while (time.monotonic() - start < timeout) or counter < minimal_loop_count_for_wait: + counter += 1 + if f(): + return + await asyncio.sleep(0) + + raise WaitConditionError("Bad condition in test") + + +async def wait_for_fast( + awaitable: typing.Awaitable, + timeout: typing.Optional[typing.Union[float, int]] = None, +): + fut = asyncio.ensure_future(awaitable) + await wait_condition(lambda: fut.done(), timeout) + return fut.result() diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/__init__.py b/ydb/public/sdk/python3/ydb/_topic_reader/__init__.py new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_reader/__init__.py diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/datatypes.py b/ydb/public/sdk/python3/ydb/_topic_reader/datatypes.py new file mode 100644 index 00000000000..3845995fcfd --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_reader/datatypes.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import abc +import asyncio +import bisect +import enum +from collections import deque +from dataclasses import dataclass, field +import datetime +from typing import Mapping, Union, Any, List, Dict, Deque, Optional + +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange, Codec +from ydb._topic_reader import topic_reader_asyncio + + +class ICommittable(abc.ABC): + @abc.abstractmethod + def _commit_get_partition_session(self) -> PartitionSession: + ... + + @abc.abstractmethod + def _commit_get_offsets_range(self) -> OffsetsRange: + ... + + +class ISessionAlive(abc.ABC): + @property + @abc.abstractmethod + def is_alive(self) -> bool: + pass + + +@dataclass +class PublicMessage(ICommittable, ISessionAlive): + seqno: int + created_at: datetime.datetime + message_group_id: str + session_metadata: Dict[str, str] + offset: int + written_at: datetime.datetime + producer_id: str + data: Union[ + bytes, Any + ] # set as original decompressed bytes or deserialized object if deserializer set in reader + _partition_session: PartitionSession + _commit_start_offset: int + _commit_end_offset: int + + def _commit_get_partition_session(self) -> PartitionSession: + return self._partition_session + + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange(self._commit_start_offset, self._commit_end_offset) + + # ISessionAlive implementation + @property + def is_alive(self) -> bool: + raise NotImplementedError() + + +@dataclass +class PartitionSession: + id: int + state: "PartitionSession.State" + topic_path: str + partition_id: int + committed_offset: int # last commit offset, acked from server. Processed messages up to the field-1 offset. + reader_reconnector_id: int + reader_stream_id: int + _next_message_start_commit_offset: int = field(init=False) + + # todo: check if deque is optimal + _ack_waiters: Deque["PartitionSession.CommitAckWaiter"] = field( + init=False, default_factory=lambda: deque() + ) + + _state_changed: asyncio.Event = field( + init=False, default_factory=lambda: asyncio.Event(), compare=False + ) + _loop: Optional[asyncio.AbstractEventLoop] = field( + init=False + ) # may be None in tests + + def __post_init__(self): + self._next_message_start_commit_offset = self.committed_offset + + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = None + + def add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter": + waiter = PartitionSession.CommitAckWaiter(end_offset, self._create_future()) + if end_offset <= self.committed_offset: + waiter._finish_ok() + return waiter + + # fast way + if len(self._ack_waiters) > 0 and self._ack_waiters[-1].end_offset < end_offset: + self._ack_waiters.append(waiter) + else: + bisect.insort(self._ack_waiters, waiter) + + return waiter + + def _create_future(self) -> asyncio.Future: + if self._loop: + return self._loop.create_future() + else: + return asyncio.Future() + + def ack_notify(self, offset: int): + self._ensure_not_closed() + + self.committed_offset = offset + + if len(self._ack_waiters) == 0: + # todo log warning + # must be never receive ack for not sended request + return + + while len(self._ack_waiters) > 0: + if self._ack_waiters[0].end_offset <= offset: + waiter = self._ack_waiters.popleft() + waiter._finish_ok() + else: + break + + def close(self): + try: + self._ensure_not_closed() + except topic_reader_asyncio.TopicReaderCommitToExpiredPartition: + return + + self.state = PartitionSession.State.Stopped + exception = topic_reader_asyncio.TopicReaderCommitToExpiredPartition() + for waiter in self._ack_waiters: + waiter._finish_error(exception) + + def _ensure_not_closed(self): + if self.state == PartitionSession.State.Stopped: + raise topic_reader_asyncio.TopicReaderCommitToExpiredPartition() + + class State(enum.Enum): + Active = 1 + GracefulShutdown = 2 + Stopped = 3 + + @dataclass(order=True) + class CommitAckWaiter: + end_offset: int + future: asyncio.Future = field(compare=False) + _done: bool = field(default=False, init=False) + _exception: Optional[Exception] = field(default=None, init=False) + + def _finish_ok(self): + self._done = True + self.future.set_result(None) + + def _finish_error(self, error: Exception): + self._exception = error + self.future.set_exception(error) + + +@dataclass +class PublicBatch(ICommittable, ISessionAlive): + session_metadata: Mapping[str, str] + messages: List[PublicMessage] + _partition_session: PartitionSession + _bytes_size: int + _codec: Codec + + def _commit_get_partition_session(self) -> PartitionSession: + return self.messages[0]._commit_get_partition_session() + + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange( + self.messages[0]._commit_get_offsets_range().start, + self.messages[-1]._commit_get_offsets_range().end, + ) + + # ISessionAlive implementation + @property + def is_alive(self) -> bool: + state = self._partition_session.state + return ( + state == PartitionSession.State.Active + or state == PartitionSession.State.GracefulShutdown + ) diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader.py b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader.py new file mode 100644 index 00000000000..148d63b33b5 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader.py @@ -0,0 +1,116 @@ +import concurrent.futures +import enum +import datetime +from dataclasses import dataclass +from typing import ( + Union, + Optional, + List, + Mapping, + Callable, +) + +from ..table import RetrySettings +from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange + + +class Selector: + path: str + partitions: Union[None, int, List[int]] + read_from_timestamp_ms: Optional[int] + max_time_lag_ms: Optional[int] + + def __init__(self, path, *, partitions: Union[None, int, List[int]] = None): + self.path = path + self.partitions = partitions + + +@dataclass +class PublicReaderSettings: + consumer: str + topic: str + buffer_size_bytes: int = 50 * 1024 * 1024 + + decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None + """decoders: map[codec_code] func(encoded_bytes)->decoded_bytes""" + + # decoder_executor, must be set for handle non raw messages + decoder_executor: Optional[concurrent.futures.Executor] = None + + # on_commit: Callable[["Events.OnCommit"], None] = None + # on_get_partition_start_offset: Callable[ + # ["Events.OnPartitionGetStartOffsetRequest"], + # "Events.OnPartitionGetStartOffsetResponse", + # ] = None + # on_partition_session_start: Callable[["StubEvent"], None] = None + # on_partition_session_stop: Callable[["StubEvent"], None] = None + # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? + # deserializer: Union[Callable[[bytes], Any], None] = None + # one_attempt_connection_timeout: Union[float, None] = 1 + # connection_timeout: Union[float, None] = None + # retry_policy: Union["RetryPolicy", None] = None + update_token_interval: Union[int, float] = 3600 + + def _init_message(self) -> StreamReadMessage.InitRequest: + return StreamReadMessage.InitRequest( + topics_read_settings=[ + StreamReadMessage.InitRequest.TopicReadSettings( + path=self.topic, + ) + ], + consumer=self.consumer, + ) + + def _retry_settings(self) -> RetrySettings: + return RetrySettings(idempotent=True) + + +class Events: + class OnCommit: + topic: str + offset: int + + class OnPartitionGetStartOffsetRequest: + topic: str + partition_id: int + + class OnPartitionGetStartOffsetResponse: + start_offset: int + + class OnInitPartition: + pass + + class OnShutdownPatition: + pass + + +class RetryPolicy: + connection_timeout_sec: float + overload_timeout_sec: float + retry_access_denied: bool = False + + +class CommitResult: + topic: str + partition: int + offset: int + state: "CommitResult.State" + details: str # for humans only, content messages may be change in any time + + class State(enum.Enum): + UNSENT = 1 # commit didn't send to the server + SENT = 2 # commit was sent to server, but ack hasn't received + ACKED = 3 # ack from server is received + + +class SessionStat: + path: str + partition_id: str + partition_offsets: OffsetsRange + committed_offset: int + write_time_high_watermark: datetime.datetime + write_time_high_watermark_timestamp_nano: int + + +class StubEvent: + pass diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_asyncio.py new file mode 100644 index 00000000000..bb87d3ccc88 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_asyncio.py @@ -0,0 +1,697 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import gzip +import typing +from asyncio import Task +from collections import deque +from typing import Optional, Set, Dict, Union, Callable + +from .. import _apis, issues +from .._utilities import AtomicCounter +from ..aio import Driver +from ..issues import Error as YdbError, _process_response +from . import datatypes +from . import topic_reader +from .._grpc.grpcwrapper.common_utils import ( + IGrpcWrapperAsyncIO, + SupportedDriverType, + GrpcWrapperAsyncIO, +) +from .._grpc.grpcwrapper.ydb_topic import ( + StreamReadMessage, + UpdateTokenRequest, + UpdateTokenResponse, + Codec, +) +from .._errors import check_retriable_error + + +class TopicReaderError(YdbError): + pass + + +class TopicReaderUnexpectedCodec(YdbError): + pass + + +class TopicReaderCommitToExpiredPartition(TopicReaderError): + """ + Commit message when partition read session are dropped. + It is ok - the message/batch will not commit to server and will receive in other read session + (with this or other reader). + """ + + def __init__(self, message: str = "Topic reader partition session is closed"): + super().__init__(message) + + +class TopicReaderStreamClosedError(TopicReaderError): + def __init__(self): + super().__init__("Topic reader stream is closed") + + +class TopicReaderClosedError(TopicReaderError): + def __init__(self): + super().__init__("Topic reader is closed already") + + +class PublicAsyncIOReader: + _loop: asyncio.AbstractEventLoop + _closed: bool + _reconnector: ReaderReconnector + + def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + self._loop = asyncio.get_running_loop() + self._closed = False + self._reconnector = ReaderReconnector(driver, settings) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + def __del__(self): + if not self._closed: + self._loop.create_task(self.close(), name="close reader") + + async def sessions_stat(self) -> typing.List["topic_reader.SessionStat"]: + """ + Receive stat from the server + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + def messages( + self, *, timeout: typing.Union[float, None] = None + ) -> typing.AsyncIterable[topic_reader.PublicMessage]: + """ + Block until receive new message + + if no new messages in timeout seconds: stop iteration by raise StopAsyncIteration + """ + raise NotImplementedError() + + async def receive_message(self) -> typing.Union[topic_reader.PublicMessage, None]: + """ + Block until receive new message + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + def batches( + self, + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: typing.Union[float, None] = None, + ) -> typing.AsyncIterable[datatypes.PublicBatch]: + """ + Block until receive new batch. + All messages in a batch from same partition. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + """ + raise NotImplementedError() + + async def receive_batch( + self, + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + ) -> typing.Union[datatypes.PublicBatch, None]: + """ + Get one messages batch from reader. + All messages in a batch from same partition. + + use asyncio.wait_for for wait with timeout. + """ + await self._reconnector.wait_message() + return self._reconnector.receive_batch_nowait() + + async def commit_on_exit( + self, mess: datatypes.ICommittable + ) -> typing.AsyncContextManager: + """ + commit the mess match/message if exit from context manager without exceptions + + reader will close if exit from context manager with exception + """ + raise NotImplementedError() + + def commit( + self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): + """ + Write commit message to a buffer. + + For the method no way check the commit result + (for example if lost connection - commits will not re-send and committed messages will receive again) + """ + self._reconnector.commit(batch) + + async def commit_with_ack( + self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): + """ + write commit message to a buffer and wait ack from the server. + + use asyncio.wait_for for wait with timeout. + """ + waiter = self._reconnector.commit(batch) + await waiter.future + + async def flush(self): + """ + force send all commit messages from internal buffers to server and wait acks for all of them. + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + async def close(self): + if self._closed: + raise TopicReaderClosedError() + + self._closed = True + await self._reconnector.close() + + +class ReaderReconnector: + _static_reader_reconnector_counter = AtomicCounter() + + _id: int + _settings: topic_reader.PublicReaderSettings + _driver: Driver + _background_tasks: Set[Task] + + _state_changed: asyncio.Event + _stream_reader: Optional["ReaderStream"] + _first_error: asyncio.Future[YdbError] + + def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + self._id = self._static_reader_reconnector_counter.inc_and_get() + + self._settings = settings + self._driver = driver + self._background_tasks = set() + + self._state_changed = asyncio.Event() + self._stream_reader = None + self._background_tasks.add(asyncio.create_task(self._connection_loop())) + self._first_error = asyncio.get_running_loop().create_future() + + async def _connection_loop(self): + attempt = 0 + while True: + try: + self._stream_reader = await ReaderStream.create( + self._id, self._driver, self._settings + ) + attempt = 0 + self._state_changed.set() + await self._stream_reader.wait_error() + except issues.Error as err: + retry_info = check_retriable_error( + err, self._settings._retry_settings(), attempt + ) + if not retry_info.is_retriable: + self._set_first_error(err) + return + await asyncio.sleep(retry_info.sleep_timeout_seconds) + + attempt += 1 + + async def wait_message(self): + while True: + if self._first_error.done(): + raise self._first_error.result() + + if self._stream_reader: + try: + await self._stream_reader.wait_messages() + return + except YdbError: + pass # handle errors in reconnection loop + + await self._state_changed.wait() + self._state_changed.clear() + + def receive_batch_nowait(self): + return self._stream_reader.receive_batch_nowait() + + def commit( + self, batch: datatypes.ICommittable + ) -> datatypes.PartitionSession.CommitAckWaiter: + return self._stream_reader.commit(batch) + + async def close(self): + if self._stream_reader: + await self._stream_reader.close() + for task in self._background_tasks: + task.cancel() + + await asyncio.wait(self._background_tasks) + + def _set_first_error(self, err: issues.Error): + try: + self._first_error.set_result(err) + self._state_changed.set() + except asyncio.InvalidStateError: + # skip if already has result + pass + + +class ReaderStream: + _static_id_counter = AtomicCounter() + + _loop: asyncio.AbstractEventLoop + _id: int + _reader_reconnector_id: int + _session_id: str + _stream: Optional[IGrpcWrapperAsyncIO] + _started: bool + _background_tasks: Set[asyncio.Task] + _partition_sessions: Dict[int, datatypes.PartitionSession] + _buffer_size_bytes: int # use for init request, then for debug purposes only + _decode_executor: concurrent.futures.Executor + _decoders: Dict[ + int, typing.Callable[[bytes], bytes] + ] # dict[codec_code] func(encoded_bytes)->decoded_bytes + + if typing.TYPE_CHECKING: + _batches_to_decode: asyncio.Queue[datatypes.PublicBatch] + else: + _batches_to_decode: asyncio.Queue + + _state_changed: asyncio.Event + _closed: bool + _message_batches: typing.Deque[datatypes.PublicBatch] + _first_error: asyncio.Future[YdbError] + + _update_token_interval: Union[int, float] + _update_token_event: asyncio.Event + _get_token_function: Callable[[], str] + + def __init__( + self, + reader_reconnector_id: int, + settings: topic_reader.PublicReaderSettings, + get_token_function: Optional[Callable[[], str]] = None, + ): + self._loop = asyncio.get_running_loop() + self._id = ReaderStream._static_id_counter.inc_and_get() + self._reader_reconnector_id = reader_reconnector_id + self._session_id = "not initialized" + self._stream = None + self._started = False + self._background_tasks = set() + self._partition_sessions = dict() + self._buffer_size_bytes = settings.buffer_size_bytes + self._decode_executor = settings.decoder_executor + + self._decoders = {Codec.CODEC_GZIP: gzip.decompress} + if settings.decoders: + self._decoders.update(settings.decoders) + + self._state_changed = asyncio.Event() + self._closed = False + self._first_error = asyncio.get_running_loop().create_future() + self._batches_to_decode = asyncio.Queue() + self._message_batches = deque() + + self._update_token_interval = settings.update_token_interval + self._get_token_function = get_token_function + self._update_token_event = asyncio.Event() + + @staticmethod + async def create( + reader_reconnector_id: int, + driver: SupportedDriverType, + settings: topic_reader.PublicReaderSettings, + ) -> "ReaderStream": + stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto) + + await stream.start( + driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead + ) + + creds = driver._credentials + reader = ReaderStream( + reader_reconnector_id, + settings, + get_token_function=creds.get_auth_token if creds else None, + ) + await reader._start(stream, settings._init_message()) + return reader + + async def _start( + self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest + ): + if self._started: + raise TopicReaderError("Double start ReaderStream") + + self._started = True + self._stream = stream + + stream.write(StreamReadMessage.FromClient(client_message=init_message)) + init_response = await stream.receive() # type: StreamReadMessage.FromServer + if isinstance(init_response.server_message, StreamReadMessage.InitResponse): + self._session_id = init_response.server_message.session_id + else: + raise TopicReaderError( + "Unexpected message after InitRequest: %s", init_response + ) + + self._update_token_event.set() + + self._background_tasks.add( + asyncio.create_task(self._read_messages_loop(), name="read_messages_loop") + ) + self._background_tasks.add(asyncio.create_task(self._decode_batches_loop())) + if self._get_token_function: + self._background_tasks.add( + asyncio.create_task(self._update_token_loop(), name="update_token_loop") + ) + + async def wait_error(self): + raise await self._first_error + + async def wait_messages(self): + while True: + if self._get_first_error(): + raise self._get_first_error() + + if self._message_batches: + return + + await self._state_changed.wait() + self._state_changed.clear() + + def receive_batch_nowait(self): + if self._get_first_error(): + raise self._get_first_error() + + if not self._message_batches: + return + + batch = self._message_batches.popleft() + self._buffer_release_bytes(batch._bytes_size) + return batch + + def commit( + self, batch: datatypes.ICommittable + ) -> datatypes.PartitionSession.CommitAckWaiter: + partition_session = batch._commit_get_partition_session() + + if partition_session.reader_reconnector_id != self._reader_reconnector_id: + raise TopicReaderError("reader can commit only self-produced messages") + + if partition_session.reader_stream_id != self._id: + raise TopicReaderCommitToExpiredPartition( + "commit messages after reconnect to server" + ) + + if partition_session.id not in self._partition_sessions: + raise TopicReaderCommitToExpiredPartition( + "commit messages after server stop the partition read session" + ) + + commit_range = batch._commit_get_offsets_range() + waiter = partition_session.add_waiter(commit_range.end) + + if not waiter.future.done(): + client_message = StreamReadMessage.CommitOffsetRequest( + commit_offsets=[ + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=partition_session.id, + offsets=[commit_range], + ) + ] + ) + self._stream.write( + StreamReadMessage.FromClient(client_message=client_message) + ) + + return waiter + + async def _read_messages_loop(self): + try: + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.ReadRequest( + bytes_size=self._buffer_size_bytes, + ), + ) + ) + while True: + message = ( + await self._stream.receive() + ) # type: StreamReadMessage.FromServer + _process_response(message.server_status) + + if isinstance(message.server_message, StreamReadMessage.ReadResponse): + self._on_read_response(message.server_message) + + elif isinstance( + message.server_message, StreamReadMessage.CommitOffsetResponse + ): + self._on_commit_response(message.server_message) + + elif isinstance( + message.server_message, + StreamReadMessage.StartPartitionSessionRequest, + ): + self._on_start_partition_session(message.server_message) + + elif isinstance( + message.server_message, + StreamReadMessage.StopPartitionSessionRequest, + ): + self._on_partition_session_stop(message.server_message) + + elif isinstance(message.server_message, UpdateTokenResponse): + self._update_token_event.set() + + else: + raise NotImplementedError( + "Unexpected type of StreamReadMessage.FromServer message: %s" + % message.server_message + ) + + self._state_changed.set() + except Exception as e: + self._set_first_error(e) + raise + + async def _update_token_loop(self): + while True: + await asyncio.sleep(self._update_token_interval) + await self._update_token(token=self._get_token_function()) + + async def _update_token(self, token: str): + await self._update_token_event.wait() + try: + msg = StreamReadMessage.FromClient(UpdateTokenRequest(token)) + self._stream.write(msg) + finally: + self._update_token_event.clear() + + def _on_start_partition_session( + self, message: StreamReadMessage.StartPartitionSessionRequest + ): + try: + if ( + message.partition_session.partition_session_id + in self._partition_sessions + ): + raise TopicReaderError( + "Double start partition session: %s" + % message.partition_session.partition_session_id + ) + + self._partition_sessions[ + message.partition_session.partition_session_id + ] = datatypes.PartitionSession( + id=message.partition_session.partition_session_id, + state=datatypes.PartitionSession.State.Active, + topic_path=message.partition_session.path, + partition_id=message.partition_session.partition_id, + committed_offset=message.committed_offset, + reader_reconnector_id=self._reader_reconnector_id, + reader_stream_id=self._id, + ) + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.StartPartitionSessionResponse( + partition_session_id=message.partition_session.partition_session_id, + read_offset=None, + commit_offset=None, + ) + ), + ) + except YdbError as err: + self._set_first_error(err) + + def _on_partition_session_stop( + self, message: StreamReadMessage.StopPartitionSessionRequest + ): + if message.partition_session_id not in self._partition_sessions: + # may if receive stop partition with graceful=false after response on stop partition + # with graceful=true and remove partition from internal dictionary + return + + partition = self._partition_sessions.pop(message.partition_session_id) + partition.close() + + if message.graceful: + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.StopPartitionSessionResponse( + partition_session_id=message.partition_session_id, + ) + ) + ) + + def _on_read_response(self, message: StreamReadMessage.ReadResponse): + self._buffer_consume_bytes(message.bytes_size) + + batches = self._read_response_to_batches(message) + for batch in batches: + self._batches_to_decode.put_nowait(batch) + + def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): + for partition_offset in message.partitions_committed_offsets: + if partition_offset.partition_session_id not in self._partition_sessions: + continue + + session = self._partition_sessions[partition_offset.partition_session_id] + session.ack_notify(partition_offset.committed_offset) + + def _buffer_consume_bytes(self, bytes_size): + self._buffer_size_bytes -= bytes_size + + def _buffer_release_bytes(self, bytes_size): + self._buffer_size_bytes += bytes_size + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.ReadRequest( + bytes_size=bytes_size, + ) + ) + ) + + def _read_response_to_batches( + self, message: StreamReadMessage.ReadResponse + ) -> typing.List[datatypes.PublicBatch]: + batches = [] + + batch_count = sum(len(p.batches) for p in message.partition_data) + if batch_count == 0: + return batches + + bytes_per_batch = message.bytes_size // batch_count + additional_bytes_to_last_batch = ( + message.bytes_size - bytes_per_batch * batch_count + ) + + for partition_data in message.partition_data: + partition_session = self._partition_sessions[ + partition_data.partition_session_id + ] + for server_batch in partition_data.batches: + messages = [] + for message_data in server_batch.message_data: + mess = datatypes.PublicMessage( + seqno=message_data.seq_no, + created_at=message_data.created_at, + message_group_id=message_data.message_group_id, + session_metadata=server_batch.write_session_meta, + offset=message_data.offset, + written_at=server_batch.written_at, + producer_id=server_batch.producer_id, + data=message_data.data, + _partition_session=partition_session, + _commit_start_offset=partition_session._next_message_start_commit_offset, + _commit_end_offset=message_data.offset + 1, + ) + messages.append(mess) + partition_session._next_message_start_commit_offset = ( + mess._commit_end_offset + ) + + if messages: + batch = datatypes.PublicBatch( + session_metadata=server_batch.write_session_meta, + messages=messages, + _partition_session=partition_session, + _bytes_size=bytes_per_batch, + _codec=Codec(server_batch.codec), + ) + batches.append(batch) + + batches[-1]._bytes_size += additional_bytes_to_last_batch + return batches + + async def _decode_batches_loop(self): + while True: + batch = await self._batches_to_decode.get() + await self._decode_batch_inplace(batch) + self._message_batches.append(batch) + self._state_changed.set() + + async def _decode_batch_inplace(self, batch): + if batch._codec == Codec.CODEC_RAW: + return + + try: + decode_func = self._decoders[batch._codec] + except KeyError: + raise TopicReaderUnexpectedCodec( + "Receive message with unexpected codec: %s" % batch._codec + ) + + decode_data_futures = [] + for message in batch.messages: + future = self._loop.run_in_executor( + self._decode_executor, decode_func, message.data + ) + decode_data_futures.append(future) + + decoded_data = await asyncio.gather(*decode_data_futures) + for index, message in enumerate(batch.messages): + message.data = decoded_data[index] + + batch._codec = Codec.CODEC_RAW + + def _set_first_error(self, err: YdbError): + try: + self._first_error.set_result(err) + self._state_changed.set() + except asyncio.InvalidStateError: + # skip later set errors + pass + + def _get_first_error(self) -> Optional[YdbError]: + if self._first_error.done(): + return self._first_error.result() + + async def close(self): + if self._closed: + return + self._closed = True + + self._set_first_error(TopicReaderStreamClosedError()) + self._state_changed.set() + self._stream.close() + + for session in self._partition_sessions.values(): + session.close() + + for task in self._background_tasks: + task.cancel() + await asyncio.wait(self._background_tasks) diff --git a/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_sync.py b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_sync.py new file mode 100644 index 00000000000..30bf92a10e4 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_reader/topic_reader_sync.py @@ -0,0 +1,210 @@ +import asyncio +import concurrent.futures +import typing +from typing import List, Union, Iterable, Optional + +from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType +from ydb._topic_common.common import ( + _get_shared_event_loop, + CallFromSyncToAsync, + TimeoutType, +) +from ydb._topic_reader import datatypes +from ydb._topic_reader.datatypes import PublicMessage, PublicBatch, ICommittable +from ydb._topic_reader.topic_reader import ( + PublicReaderSettings, + SessionStat, + CommitResult, +) +from ydb._topic_reader.topic_reader_asyncio import ( + PublicAsyncIOReader, + TopicReaderClosedError, +) + + +class TopicReaderSync: + _caller: CallFromSyncToAsync + _async_reader: PublicAsyncIOReader + _closed: bool + + def __init__( + self, + driver: SupportedDriverType, + settings: PublicReaderSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, + ): + self._closed = False + + if eventloop: + loop = eventloop + else: + loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) + + async def create_reader(): + return PublicAsyncIOReader(driver, settings) + + self._async_reader = asyncio.run_coroutine_threadsafe( + create_reader(), loop + ).result() + + def __del__(self): + self.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def async_sessions_stat(self) -> concurrent.futures.Future: + """ + Receive stat from the server, return feature. + """ + raise NotImplementedError() + + async def sessions_stat(self) -> List[SessionStat]: + """ + Receive stat from the server + + use async_sessions_stat for set explicit wait timeout + """ + raise NotImplementedError() + + def messages( + self, *, timeout: Union[float, None] = None + ) -> Iterable[PublicMessage]: + """ + todo? + + Block until receive new message + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def receive_message(self, *, timeout: Union[float, None] = None) -> PublicMessage: + """ + Block until receive new message + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def async_wait_message(self) -> concurrent.futures.Future: + """ + Return future, which will completed when the reader has least one message in queue. + If reader already has message - future will return completed. + + Possible situation when receive signal about message available, but no messages when try to receive a message. + If message expired between send event and try to retrieve message (for example connection broken). + """ + raise NotImplementedError() + + def batches( + self, + *, + max_messages: Union[int, None] = None, + max_bytes: Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> Iterable[PublicBatch]: + """ + Block until receive new batch. + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def receive_batch( + self, + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> Union[PublicBatch, None]: + """ + Get one messages batch from reader + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + self._check_closed() + + return self._caller.safe_call_with_result( + self._async_reader.receive_batch( + max_messages=max_messages, max_bytes=max_bytes + ), + timeout, + ) + + def commit( + self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): + """ + Put commit message to internal buffer. + + For the method no way check the commit result + (for example if lost connection - commits will not re-send and committed messages will receive again) + """ + self._check_closed() + + self._caller.call_sync(self._async_reader.commit(mess)) + + def commit_with_ack( + self, mess: ICommittable, timeout: TimeoutType = None + ) -> Union[CommitResult, List[CommitResult]]: + """ + write commit message to a buffer and wait ack from the server. + + if receive in timeout seconds (default - infinite): raise TimeoutError() + """ + self._check_closed() + + return self._caller.unsafe_call_with_result( + self._async_reader.commit_with_ack(mess), timeout + ) + + def async_commit_with_ack( + self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ) -> concurrent.futures.Future: + """ + write commit message to a buffer and return Future for wait result. + """ + self._check_closed() + + return self._caller.unsafe_call_with_future( + self._async_reader.commit_with_ack(mess) + ) + + def async_flush(self) -> concurrent.futures.Future: + """ + force send all commit messages from internal buffers to server and return Future for wait server acks. + """ + raise NotImplementedError() + + def flush(self): + """ + force send all commit messages from internal buffers to server and wait acks for all of them. + """ + raise NotImplementedError() + + def close(self, *, timeout: TimeoutType = None): + if self._closed: + return + + self._closed = True + + self._caller.safe_call_with_result(self._async_reader.close(), timeout) + + def _check_closed(self): + if self._closed: + raise TopicReaderClosedError() diff --git a/ydb/public/sdk/python3/ydb/_topic_writer/__init__.py b/ydb/public/sdk/python3/ydb/_topic_writer/__init__.py new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_writer/__init__.py diff --git a/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer.py b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer.py new file mode 100644 index 00000000000..59ad74ff80d --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer.py @@ -0,0 +1,220 @@ +import concurrent.futures +import datetime +import enum +import uuid +from dataclasses import dataclass +from enum import Enum +from typing import List, Union, Optional, Any, Dict + +import typing + +import ydb.aio +from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage +from .._grpc.grpcwrapper.common_utils import IToProto +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec + +Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] + + +@dataclass +class PublicWriterSettings: + """ + Settings for topic writer. + + order of fields IS NOT stable, use keywords only + """ + + topic: str + producer_id: Optional[str] = None + session_metadata: Optional[Dict[str, str]] = None + partition_id: Optional[int] = None + auto_seqno: bool = True + auto_created_at: bool = True + codec: Optional[PublicCodec] = None # default mean auto-select + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None # default shared client executor pool + encoders: Optional[ + typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]] + ] = None + # get_last_seqno: bool = False + # serializer: Union[Callable[[Any], bytes], None] = None + # send_buffer_count: Optional[int] = 10000 + # send_buffer_bytes: Optional[int] = 100 * 1024 * 1024 + # codec: Optional[int] = None + # codec_autoselect: bool = True + # retry_policy: Optional["RetryPolicy"] = None + update_token_interval: Union[int, float] = 3600 + + def __post_init__(self): + if self.producer_id is None: + self.producer_id = uuid.uuid4().hex + + +@dataclass +class PublicWriteResult: + @dataclass(eq=True) + class Written: + __slots__ = "offset" + offset: int + + @dataclass(eq=True) + class Skipped: + pass + + +PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped] + + +class WriterSettings(PublicWriterSettings): + def __init__(self, settings: PublicWriterSettings): + self.__dict__ = settings.__dict__.copy() + + def create_init_request(self) -> StreamWriteMessage.InitRequest: + return StreamWriteMessage.InitRequest( + path=self.topic, + producer_id=self.producer_id, + write_session_meta=self.session_metadata, + partitioning=self.get_partitioning(), + get_last_seq_no=True, + ) + + def get_partitioning(self) -> StreamWriteMessage.PartitioningType: + if self.partition_id is not None: + return StreamWriteMessage.PartitioningPartitionID(self.partition_id) + return StreamWriteMessage.PartitioningMessageGroupID(self.producer_id) + + +class SendMode(Enum): + ASYNC = 1 + SYNC = 2 + + +@dataclass +class PublicWriterInitInfo: + __slots__ = ("last_seqno", "supported_codecs") + last_seqno: Optional[int] + supported_codecs: List[PublicCodec] + + +class PublicMessage: + seqno: Optional[int] + created_at: Optional[datetime.datetime] + data: "PublicMessage.SimpleMessageSourceType" + + SimpleMessageSourceType = Union[str, bytes] # Will be extend + + def __init__( + self, + data: SimpleMessageSourceType, + *, + seqno: Optional[int] = None, + created_at: Optional[datetime.datetime] = None, + ): + self.seqno = seqno + self.created_at = created_at + self.data = data + + @staticmethod + def _create_message(data: Message) -> "PublicMessage": + if isinstance(data, PublicMessage): + return data + return PublicMessage(data=data) + + +class InternalMessage(StreamWriteMessage.WriteRequest.MessageData, IToProto): + codec: PublicCodec + + def __init__(self, mess: PublicMessage): + super().__init__( + seq_no=mess.seqno, + created_at=mess.created_at, + data=mess.data, + uncompressed_size=len(mess.data), + partitioning=None, + ) + self.codec = PublicCodec.RAW + + def get_bytes(self) -> bytes: + if self.data is None: + return bytes() + if isinstance(self.data, bytes): + return self.data + if isinstance(self.data, str): + return self.data.encode("utf-8") + raise ValueError("Bad data type") + + def to_message_data(self) -> StreamWriteMessage.WriteRequest.MessageData: + data = self.get_bytes() + return StreamWriteMessage.WriteRequest.MessageData( + seq_no=self.seq_no, + created_at=self.created_at, + data=data, + uncompressed_size=len(data), + partitioning=None, # unsupported by server now + ) + + +class MessageSendResult: + offset: Optional[int] + write_status: "MessageWriteStatus" + + +class MessageWriteStatus(enum.Enum): + Written = 1 + AlreadyWritten = 2 + + +class RetryPolicy: + connection_timeout_sec: float + overload_timeout_sec: float + retry_access_denied: bool = False + + +class TopicWriterError(ydb.Error): + def __init__(self, message: str): + super(TopicWriterError, self).__init__(message) + + +class TopicWriterClosedError(ydb.Error): + def __init__(self): + super().__init__("Topic writer already closed") + + +class TopicWriterRepeatableError(TopicWriterError): + pass + + +class TopicWriterStopped(TopicWriterError): + def __init__(self): + super(TopicWriterStopped, self).__init__( + "topic writer was stopped by call close" + ) + + +def default_serializer_message_content(data: Any) -> bytes: + if data is None: + return bytes() + if isinstance(data, bytes): + return data + if isinstance(data, bytearray): + return bytes(data) + if isinstance(data, str): + return data.encode(encoding="utf-8") + raise ValueError("can't serialize type %s to bytes" % type(data)) + + +def messages_to_proto_requests( + messages: List[InternalMessage], +) -> List[StreamWriteMessage.FromClient]: + # todo split by proto message size and codec + res = [] + for msg in messages: + req = StreamWriteMessage.FromClient( + StreamWriteMessage.WriteRequest( + messages=[msg.to_message_data()], + codec=msg.codec, + ) + ) + res.append(req) + return res diff --git a/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_asyncio.py new file mode 100644 index 00000000000..7cb1f1db0b9 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_asyncio.py @@ -0,0 +1,684 @@ +import asyncio +import concurrent.futures +import datetime +import gzip +import typing +from collections import deque +from typing import Deque, AsyncIterator, Union, List, Optional, Dict, Callable + +import ydb +from .topic_writer import ( + PublicWriterSettings, + WriterSettings, + PublicMessage, + PublicWriterInitInfo, + InternalMessage, + TopicWriterStopped, + TopicWriterError, + messages_to_proto_requests, + PublicWriteResultTypes, + Message, +) +from .. import ( + _apis, + issues, + check_retriable_error, + RetrySettings, +) +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec +from .._grpc.grpcwrapper.ydb_topic import ( + UpdateTokenRequest, + UpdateTokenResponse, + StreamWriteMessage, + WriterMessagesFromServerToClient, +) +from .._grpc.grpcwrapper.common_utils import ( + IGrpcWrapperAsyncIO, + SupportedDriverType, + GrpcWrapperAsyncIO, +) + + +class WriterAsyncIO: + _loop: asyncio.AbstractEventLoop + _reconnector: "WriterAsyncIOReconnector" + _closed: bool + + @property + def last_seqno(self) -> int: + raise NotImplementedError() + + def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings): + self._loop = asyncio.get_running_loop() + self._closed = False + self._reconnector = WriterAsyncIOReconnector( + driver=driver, settings=WriterSettings(settings) + ) + + async def __aenter__(self) -> "WriterAsyncIO": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + def __del__(self): + if self._closed or self._loop.is_closed(): + return + + self._loop.call_soon(self.close) + + async def close(self, *, flush: bool = True): + if self._closed: + return + + self._closed = True + + await self._reconnector.close(flush) + + async def write_with_ack( + self, + messages: Union[Message, List[Message]], + ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]: + """ + IT IS SLOWLY WAY. IT IS BAD CHOISE IN MOST CASES. + It is recommended to use write with optionally flush or write_with_ack_futures and receive acks by wait futures. + + send one or number of messages to server and wait acks. + + For wait with timeout use asyncio.wait_for. + """ + futures = await self.write_with_ack_future(messages) + if not isinstance(futures, list): + futures = [futures] + + await asyncio.wait(futures) + results = [f.result() for f in futures] + + return results if isinstance(messages, list) else results[0] + + async def write_with_ack_future( + self, + messages: Union[Message, List[Message]], + ) -> Union[asyncio.Future, List[asyncio.Future]]: + """ + send one or number of messages to server. + return feature, which can be waited for check send result. + + Usually it is fast method, but can wait if internal buffer is full. + + For wait with timeout use asyncio.wait_for. + """ + input_single_message = not isinstance(messages, list) + converted_messages = [] + if isinstance(messages, list): + for m in messages: + converted_messages.append(PublicMessage._create_message(m)) + else: + converted_messages = [PublicMessage._create_message(messages)] + + futures = await self._reconnector.write_with_ack_future(converted_messages) + if input_single_message: + return futures[0] + else: + return futures + + async def write( + self, + messages: Union[Message, List[Message]], + ): + """ + send one or number of messages to server. + it put message to internal buffer + + For wait with timeout use asyncio.wait_for. + """ + await self.write_with_ack_future(messages) + + async def flush(self): + """ + Force send all messages from internal buffer and wait acks from server for all + messages. + + For wait with timeout use asyncio.wait_for. + """ + return await self._reconnector.flush() + + async def wait_init(self) -> PublicWriterInitInfo: + """ + wait while real connection will be established to server. + + For wait with timeout use asyncio.wait_for() + """ + return await self._reconnector.wait_init() + + +class WriterAsyncIOReconnector: + _closed: bool + _loop: asyncio.AbstractEventLoop + _credentials: Union[ydb.credentials.Credentials, None] + _driver: ydb.aio.Driver + _init_message: StreamWriteMessage.InitRequest + _init_info: asyncio.Future + _stream_connected: asyncio.Event + _settings: WriterSettings + _codec: PublicCodec + _codec_functions: Dict[PublicCodec, Callable[[bytes], bytes]] + _encode_executor: Optional[concurrent.futures.Executor] + _codec_selector_batch_num: int + _codec_selector_last_codec: Optional[PublicCodec] + _codec_selector_check_batches_interval: int + + _last_known_seq_no: int + if typing.TYPE_CHECKING: + _messages_for_encode: asyncio.Queue[List[InternalMessage]] + else: + _messages_for_encode: asyncio.Queue + _messages: Deque[InternalMessage] + _messages_future: Deque[asyncio.Future] + _new_messages: asyncio.Queue + _stop_reason: asyncio.Future + _background_tasks: List[asyncio.Task] + + def __init__(self, driver: SupportedDriverType, settings: WriterSettings): + self._closed = False + self._loop = asyncio.get_running_loop() + self._driver = driver + self._credentials = driver._credentials + self._init_message = settings.create_init_request() + self._new_messages = asyncio.Queue() + self._init_info = self._loop.create_future() + self._stream_connected = asyncio.Event() + self._settings = settings + + self._codec_functions = { + PublicCodec.RAW: lambda data: data, + PublicCodec.GZIP: gzip.compress, + } + + if settings.encoders: + self._codec_functions.update(settings.encoders) + + self._encode_executor = settings.encoder_executor + + self._codec_selector_batch_num = 0 + self._codec_selector_last_codec = None + self._codec_selector_check_batches_interval = 10000 + + self._codec = self._settings.codec + if self._codec and self._codec not in self._codec_functions: + known_codecs = sorted(self._codec_functions.keys()) + raise ValueError( + "Unknown codec for writer: %s, supported codecs: %s" + % (self._codec, known_codecs) + ) + + self._last_known_seq_no = 0 + self._messages_for_encode = asyncio.Queue() + self._messages = deque() + self._messages_future = deque() + self._new_messages = asyncio.Queue() + self._stop_reason = self._loop.create_future() + self._background_tasks = [ + asyncio.create_task(self._connection_loop(), name="connection_loop"), + asyncio.create_task(self._encode_loop(), name="encode_loop"), + ] + + async def close(self, flush: bool): + if self._closed: + return + + if flush: + await self.flush() + + self._closed = True + self._stop(TopicWriterStopped()) + + for task in self._background_tasks: + task.cancel() + await asyncio.wait(self._background_tasks) + + # if work was stopped before close by error - raise the error + try: + self._check_stop() + except TopicWriterStopped: + pass + + async def wait_init(self) -> PublicWriterInitInfo: + done, _ = await asyncio.wait( + [self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED + ) + res = done.pop() # type: asyncio.Future + res_val = res.result() + + if isinstance(res_val, BaseException): + raise res_val + + return res_val + + async def wait_stop(self) -> Exception: + return await self._stop_reason + + async def write_with_ack_future( + self, messages: List[PublicMessage] + ) -> List[asyncio.Future]: + # todo check internal buffer limit + self._check_stop() + + if self._settings.auto_seqno: + await self.wait_init() + + internal_messages = self._prepare_internal_messages(messages) + messages_future = [self._loop.create_future() for _ in internal_messages] + + self._messages_future.extend(messages_future) + + if self._codec == PublicCodec.RAW: + self._add_messages_to_send_queue(internal_messages) + else: + self._messages_for_encode.put_nowait(internal_messages) + + return messages_future + + def _add_messages_to_send_queue(self, internal_messages: List[InternalMessage]): + self._messages.extend(internal_messages) + for m in internal_messages: + self._new_messages.put_nowait(m) + + def _prepare_internal_messages( + self, messages: List[PublicMessage] + ) -> List[InternalMessage]: + if self._settings.auto_created_at: + now = datetime.datetime.now() + else: + now = None + + res = [] + for m in messages: + internal_message = InternalMessage(m) + if self._settings.auto_seqno: + if internal_message.seq_no is None: + self._last_known_seq_no += 1 + internal_message.seq_no = self._last_known_seq_no + else: + raise TopicWriterError( + "Explicit seqno and auto_seq setting is mutual exclusive" + ) + else: + if internal_message.seq_no is None or internal_message.seq_no == 0: + raise TopicWriterError( + "Empty seqno and auto_seq setting is disabled" + ) + elif internal_message.seq_no <= self._last_known_seq_no: + raise TopicWriterError( + "Message seqno is duplicated: %s" % internal_message.seq_no + ) + else: + self._last_known_seq_no = internal_message.seq_no + + if self._settings.auto_created_at: + if internal_message.created_at is not None: + raise TopicWriterError( + "Explicit set auto_created_at and setting auto_created_at is mutual exclusive" + ) + else: + internal_message.created_at = now + + res.append(internal_message) + + return res + + def _check_stop(self): + if self._stop_reason.done(): + raise self._stop_reason.result() + + async def _connection_loop(self): + retry_settings = RetrySettings() # todo + + while True: + attempt = 0 # todo calc and reset + tasks = [] + + # noinspection PyBroadException + stream_writer = None + try: + stream_writer = await WriterAsyncIOStream.create( + self._driver, + self._init_message, + self._settings.update_token_interval, + ) + try: + self._last_known_seq_no = stream_writer.last_seqno + self._init_info.set_result( + PublicWriterInitInfo( + last_seqno=stream_writer.last_seqno, + supported_codecs=stream_writer.supported_codecs, + ) + ) + except asyncio.InvalidStateError: + pass + + self._stream_connected.set() + + send_loop = asyncio.create_task( + self._send_loop(stream_writer), name="writer send loop" + ) + receive_loop = asyncio.create_task( + self._read_loop(stream_writer), name="writer receive loop" + ) + + tasks = [send_loop, receive_loop] + done, _ = await asyncio.wait( + [send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED + ) + await stream_writer.close() + done.pop().result() + except issues.Error as err: + # todo log error + print(err) + + err_info = check_retriable_error(err, retry_settings, attempt) + if not err_info.is_retriable: + self._stop(err) + return + + await asyncio.sleep(err_info.sleep_timeout_seconds) + + except (asyncio.CancelledError, Exception) as err: + self._stop(err) + return + finally: + if stream_writer: + await stream_writer.close() + for task in tasks: + task.cancel() + await asyncio.wait(tasks) + + async def _encode_loop(self): + while True: + messages = await self._messages_for_encode.get() + while not self._messages_for_encode.empty(): + messages.extend(self._messages_for_encode.get_nowait()) + + batch_codec = await self._codec_selector(messages) + await self._encode_data_inplace(batch_codec, messages) + self._add_messages_to_send_queue(messages) + + async def _encode_data_inplace( + self, codec: PublicCodec, messages: List[InternalMessage] + ): + if codec == PublicCodec.RAW: + return + + eventloop = asyncio.get_running_loop() + encode_waiters = [] + encoder_function = self._codec_functions[codec] + + for message in messages: + encoded_data_futures = eventloop.run_in_executor( + self._encode_executor, encoder_function, message.get_bytes() + ) + encode_waiters.append(encoded_data_futures) + + encoded_datas = await asyncio.gather(*encode_waiters) + + for index, data in enumerate(encoded_datas): + message = messages[index] + message.codec = codec + message.data = data + + async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec: + if self._codec is not None: + return self._codec + + if self._codec_selector_last_codec is None: + available_codecs = await self._get_available_codecs() + + # use every of available encoders at start for prevent problems + # with rare used encoders (on writer or reader side) + if self._codec_selector_batch_num < len(available_codecs): + codec = available_codecs[self._codec_selector_batch_num] + else: + codec = await self._codec_selector_by_check_compress(messages) + self._codec_selector_last_codec = codec + else: + if ( + self._codec_selector_batch_num + % self._codec_selector_check_batches_interval + == 0 + ): + self._codec_selector_last_codec = ( + await self._codec_selector_by_check_compress(messages) + ) + codec = self._codec_selector_last_codec + self._codec_selector_batch_num += 1 + return codec + + async def _get_available_codecs(self) -> List[PublicCodec]: + info = await self.wait_init() + topic_supported_codecs = info.supported_codecs + if not topic_supported_codecs: + topic_supported_codecs = [PublicCodec.RAW, PublicCodec.GZIP] + + res = [] + for codec in topic_supported_codecs: + if codec in self._codec_functions: + res.append(codec) + + if not res: + raise TopicWriterError("Writer does not support topic's codecs") + + res.sort() + + return res + + async def _codec_selector_by_check_compress( + self, messages: List[InternalMessage] + ) -> PublicCodec: + """ + Try to compress messages and choose codec with the smallest result size. + """ + + test_messages = messages[:10] + + available_codecs = await self._get_available_codecs() + if len(available_codecs) == 1: + return available_codecs[0] + + def get_compressed_size(codec) -> int: + s = 0 + f = self._codec_functions[codec] + + for m in test_messages: + encoded = f(m.get_bytes()) + s += len(encoded) + + return s + + def select_codec() -> PublicCodec: + min_codec = available_codecs[0] + min_size = get_compressed_size(min_codec) + for codec in available_codecs[1:]: + size = get_compressed_size(codec) + if size < min_size: + min_codec = codec + min_size = size + return min_codec + + loop = asyncio.get_running_loop() + codec = await loop.run_in_executor(self._encode_executor, select_codec) + return codec + + async def _read_loop(self, writer: "WriterAsyncIOStream"): + while True: + resp = await writer.receive() + + for ack in resp.acks: + self._handle_receive_ack(ack) + + def _handle_receive_ack(self, ack): + current_message = self._messages.popleft() + message_future = self._messages_future.popleft() + if current_message.seq_no != ack.seq_no: + raise TopicWriterError( + "internal error - receive unexpected ack. Expected seqno: %s, received seqno: %s" + % (current_message.seq_no, ack.seq_no) + ) + message_future.set_result( + None + ) # todo - return result with offset or skip status + + async def _send_loop(self, writer: "WriterAsyncIOStream"): + try: + messages = list(self._messages) + + last_seq_no = 0 + for m in messages: + writer.write([m]) + last_seq_no = m.seq_no + + while True: + m = await self._new_messages.get() # type: InternalMessage + if m.seq_no > last_seq_no: + writer.write([m]) + except Exception as e: + self._stop(e) + finally: + pass + + def _stop(self, reason: Exception): + if reason is None: + raise Exception("writer stop reason can not be None") + + if self._stop_reason.done(): + return + + self._stop_reason.set_result(reason) + + async def flush(self): + self._check_stop() + if not self._messages_future: + return + + # wait last message + await asyncio.wait((self._messages_future[-1],)) + + +class WriterAsyncIOStream: + # todo slots + _closed: bool + + last_seqno: int + supported_codecs: Optional[List[PublicCodec]] + + _stream: IGrpcWrapperAsyncIO + _requests: asyncio.Queue + _responses: AsyncIterator + + _update_token_interval: Optional[Union[int, float]] + _update_token_task: Optional[asyncio.Task] + _update_token_event: asyncio.Event + _get_token_function: Optional[Callable[[], str]] + + def __init__( + self, + update_token_interval: Optional[Union[int, float]] = None, + get_token_function: Optional[Callable[[], str]] = None, + ): + self._closed = False + + self._update_token_interval = update_token_interval + self._get_token_function = get_token_function + self._update_token_event = asyncio.Event() + self._update_token_task = None + + async def close(self): + if self._closed: + return + self._closed = True + + if self._update_token_task: + self._update_token_task.cancel() + await asyncio.wait([self._update_token_task]) + + self._stream.close() + + @staticmethod + async def create( + driver: SupportedDriverType, + init_request: StreamWriteMessage.InitRequest, + update_token_interval: Optional[Union[int, float]] = None, + ) -> "WriterAsyncIOStream": + stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) + + await stream.start( + driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite + ) + + creds = driver._credentials + writer = WriterAsyncIOStream( + update_token_interval=update_token_interval, + get_token_function=creds.get_auth_token if creds else lambda: "", + ) + await writer._start(stream, init_request) + return writer + + async def receive(self) -> StreamWriteMessage.WriteResponse: + while True: + item = await self._stream.receive() + + if isinstance(item, StreamWriteMessage.WriteResponse): + return item + if isinstance(item, UpdateTokenResponse): + self._update_token_event.set() + continue + + # todo log unknown messages instead of raise exception + raise Exception("Unknown message while read writer answers: %s" % item) + + async def _start( + self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest + ): + stream.write(StreamWriteMessage.FromClient(init_message)) + + resp = await stream.receive() + self._ensure_ok(resp) + if not isinstance(resp, StreamWriteMessage.InitResponse): + raise TopicWriterError("Unexpected answer for init request: %s" % resp) + + self.last_seqno = resp.last_seq_no + self.supported_codecs = [PublicCodec(codec) for codec in resp.supported_codecs] + + self._stream = stream + + if self._update_token_interval is not None: + self._update_token_event.set() + self._update_token_task = asyncio.create_task( + self._update_token_loop(), name="update_token_loop" + ) + + @staticmethod + def _ensure_ok(message: WriterMessagesFromServerToClient): + if not message.status.is_success(): + raise TopicWriterError( + "status error from server in writer: %s", message.status + ) + + def write(self, messages: List[InternalMessage]): + if self._closed: + raise RuntimeError("Can not write on closed stream.") + + for request in messages_to_proto_requests(messages): + self._stream.write(request) + + async def _update_token_loop(self): + while True: + await asyncio.sleep(self._update_token_interval) + await self._update_token(token=self._get_token_function()) + + async def _update_token(self, token: str): + await self._update_token_event.wait() + try: + msg = StreamWriteMessage.FromClient(UpdateTokenRequest(token)) + self._stream.write(msg) + finally: + self._update_token_event.clear() diff --git a/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_sync.py b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_sync.py new file mode 100644 index 00000000000..e6b512387ff --- /dev/null +++ b/ydb/public/sdk/python3/ydb/_topic_writer/topic_writer_sync.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import asyncio +from concurrent.futures import Future +from typing import Union, List, Optional + +from .._grpc.grpcwrapper.common_utils import SupportedDriverType +from .topic_writer import ( + PublicWriterSettings, + PublicWriterInitInfo, + PublicWriteResult, + Message, + TopicWriterClosedError, +) + +from .topic_writer_asyncio import WriterAsyncIO +from .._topic_common.common import ( + _get_shared_event_loop, + TimeoutType, + CallFromSyncToAsync, +) + + +class WriterSync: + _caller: CallFromSyncToAsync + _async_writer: WriterAsyncIO + _closed: bool + + def __init__( + self, + driver: SupportedDriverType, + settings: PublicWriterSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, + ): + + self._closed = False + + if eventloop: + loop = eventloop + else: + loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) + + async def create_async_writer(): + return WriterAsyncIO(driver, settings) + + self._async_writer = self._caller.safe_call_with_result( + create_async_writer(), None + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self, *, flush: bool = True, timeout: TimeoutType = None): + if self._closed: + return + + self._closed = True + + self._caller.safe_call_with_result( + self._async_writer.close(flush=flush), timeout + ) + + def _check_closed(self): + if self._closed: + raise TopicWriterClosedError() + + def async_flush(self) -> Future: + self._check_closed() + + return self._caller.unsafe_call_with_future(self._async_writer.flush()) + + def flush(self, *, timeout=None): + self._check_closed() + + return self._caller.unsafe_call_with_result(self._async_writer.flush(), timeout) + + def async_wait_init(self) -> Future[PublicWriterInitInfo]: + self._check_closed() + + return self._caller.unsafe_call_with_future(self._async_writer.wait_init()) + + def wait_init(self, *, timeout: TimeoutType = None) -> PublicWriterInitInfo: + self._check_closed() + + return self._caller.unsafe_call_with_result( + self._async_writer.wait_init(), timeout + ) + + def write( + self, + messages: Union[Message, List[Message]], + timeout: TimeoutType = None, + ): + self._check_closed() + + self._caller.safe_call_with_result(self._async_writer.write(messages), timeout) + + def async_write_with_ack( + self, + messages: Union[Message, List[Message]], + ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: + self._check_closed() + + return self._caller.unsafe_call_with_future( + self._async_writer.write_with_ack(messages) + ) + + def write_with_ack( + self, + messages: Union[Message, List[Message]], + timeout: Union[float, None] = None, + ) -> Union[PublicWriteResult, List[PublicWriteResult]]: + self._check_closed() + + return self._caller.unsafe_call_with_result( + self._async_writer.write_with_ack(messages), timeout=timeout + ) diff --git a/ydb/public/sdk/python3/ydb/_utilities.py b/ydb/public/sdk/python3/ydb/_utilities.py index 32419b1bf97..0b72a198979 100644 --- a/ydb/public/sdk/python3/ydb/_utilities.py +++ b/ydb/public/sdk/python3/ydb/_utilities.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- -import six +import threading import codecs from concurrent import futures import functools import hashlib import collections +import urllib.parse from . import ydb_version try: @@ -55,8 +56,8 @@ def parse_connection_string(connection_string): # default is grpcs cs = _grpcs_protocol + cs - p = six.moves.urllib.parse.urlparse(connection_string) - b = six.moves.urllib.parse.parse_qs(p.query) + p = urllib.parse.urlparse(connection_string) + b = urllib.parse.parse_qs(p.query) database = b.get("database", []) assert len(database) > 0 @@ -77,11 +78,9 @@ def wrap_async_call_exceptions(f): def get_query_hash(yql_text): try: - return hashlib.sha256( - six.text_type(yql_text, "utf-8").encode("utf-8") - ).hexdigest() + return hashlib.sha256(str(yql_text, "utf-8").encode("utf-8")).hexdigest() except TypeError: - return hashlib.sha256(six.text_type(yql_text).encode("utf-8")).hexdigest() + return hashlib.sha256(str(yql_text).encode("utf-8")).hexdigest() class LRUCache(object): @@ -159,3 +158,17 @@ class SyncResponseIterator(object): def __next__(self): return self._next() + + +class AtomicCounter: + _lock: threading.Lock + _value: int + + def __init__(self, initial_value: int = 0): + self._lock = threading.Lock() + self._value = initial_value + + def inc_and_get(self) -> int: + with self._lock: + self._value += 1 + return self._value diff --git a/ydb/public/sdk/python3/ydb/aio/connection.py b/ydb/public/sdk/python3/ydb/aio/connection.py index 88ab738c6ac..fbfcfaaf6c3 100644 --- a/ydb/public/sdk/python3/ydb/aio/connection.py +++ b/ydb/public/sdk/python3/ydb/aio/connection.py @@ -1,5 +1,6 @@ import logging import asyncio +import typing from typing import Any, Tuple, Callable, Iterable import collections import grpc @@ -24,11 +25,19 @@ from ydb.driver import DriverConfig from ydb.settings import BaseRequestSettings from ydb import issues +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ydb._grpc.v4 import ydb_topic_v1_pb2_grpc +else: + from ydb._grpc.common import ydb_topic_v1_pb2_grpc + + _stubs_list = ( _apis.TableService.Stub, _apis.SchemeService.Stub, _apis.DiscoveryService.Stub, _apis.CmsService.Stub, + ydb_topic_v1_pb2_grpc.TopicServiceStub, ) logger = logging.getLogger(__name__) diff --git a/ydb/public/sdk/python3/ydb/aio/credentials.py b/ydb/public/sdk/python3/ydb/aio/credentials.py index e98404407b7..93868b279a5 100644 --- a/ydb/public/sdk/python3/ydb/aio/credentials.py +++ b/ydb/public/sdk/python3/ydb/aio/credentials.py @@ -3,7 +3,6 @@ import time import abc import asyncio import logging -import six from ydb import issues, credentials logger = logging.getLogger(__name__) @@ -55,7 +54,6 @@ class _AtMostOneExecution(object): asyncio.ensure_future(self._wrapped_execution(callback)) [email protected]_metaclass(abc.ABCMeta) class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self): super(AbstractExpiringTokenCredentials, self).__init__() diff --git a/ydb/public/sdk/python3/ydb/aio/driver.py b/ydb/public/sdk/python3/ydb/aio/driver.py index 3bf6cca82e8..1aa3ad27576 100644 --- a/ydb/public/sdk/python3/ydb/aio/driver.py +++ b/ydb/public/sdk/python3/ydb/aio/driver.py @@ -1,42 +1,7 @@ -import os - from . import pool, scheme, table import ydb -from ydb.driver import get_config - - -def default_credentials(credentials=None): - if credentials is not None: - return credentials - - service_account_key_file = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS") - if service_account_key_file is not None: - from .iam import ServiceAccountCredentials - - return ServiceAccountCredentials.from_file(service_account_key_file) - - anonymous_credetials = os.getenv("YDB_ANONYMOUS_CREDENTIALS", "0") == "1" - if anonymous_credetials: - return ydb.credentials.AnonymousCredentials() - - metadata_credentials = os.getenv("YDB_METADATA_CREDENTIALS", "0") == "1" - if metadata_credentials: - from .iam import MetadataUrlCredentials - - return MetadataUrlCredentials() - - access_token = os.getenv("YDB_ACCESS_TOKEN_CREDENTIALS") - if access_token is not None: - return ydb.credentials.AccessTokenCredentials(access_token) - - # (legacy instantiation) - creds = ydb.auth_helpers.construct_credentials_from_environ() - if creds is not None: - return creds - - from .iam import MetadataUrlCredentials - - return MetadataUrlCredentials() +from .. import _utilities +from ydb.driver import get_config, default_credentials class DriverConfig(ydb.DriverConfig): @@ -56,7 +21,7 @@ class DriverConfig(ydb.DriverConfig): def default_from_connection_string( cls, connection_string, root_certificates=None, credentials=None, **kwargs ): - endpoint, database = ydb.parse_connection_string(connection_string) + endpoint, database = _utilities.parse_connection_string(connection_string) return cls( endpoint, database, @@ -67,6 +32,8 @@ class DriverConfig(ydb.DriverConfig): class Driver(pool.ConnectionPool): + _credentials: ydb.Credentials # used for topic clients + def __init__( self, driver_config=None, @@ -77,6 +44,8 @@ class Driver(pool.ConnectionPool): credentials=None, **kwargs ): + from .. import topic # local import for prevent cycle import error + config = get_config( driver_config, connection_string, @@ -89,5 +58,8 @@ class Driver(pool.ConnectionPool): super(Driver, self).__init__(config) + self._credentials = config.credentials + self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, config.table_client_settings) + self.topic_client = topic.TopicClientAsyncIO(self, config.topic_client_settings) diff --git a/ydb/public/sdk/python3/ydb/aio/iam.py b/ydb/public/sdk/python3/ydb/aio/iam.py index 51b650f24bc..b56c066043d 100644 --- a/ydb/public/sdk/python3/ydb/aio/iam.py +++ b/ydb/public/sdk/python3/ydb/aio/iam.py @@ -3,7 +3,6 @@ import time import abc import logging -import six from ydb.iam import auth from .credentials import AbstractExpiringTokenCredentials @@ -24,7 +23,6 @@ except ImportError: aiohttp = None [email protected]_metaclass(abc.ABCMeta) class TokenServiceCredentials(AbstractExpiringTokenCredentials): def __init__(self, iam_endpoint=None, iam_channel_credentials=None): super(TokenServiceCredentials, self).__init__() diff --git a/ydb/public/sdk/python3/ydb/aio/table.py b/ydb/public/sdk/python3/ydb/aio/table.py index f937a9283ca..06f8ca7c81f 100644 --- a/ydb/public/sdk/python3/ydb/aio/table.py +++ b/ydb/public/sdk/python3/ydb/aio/table.py @@ -13,7 +13,6 @@ from ydb.table import ( _scan_query_request_factory, _wrap_scan_query_response, BaseTxContext, - _allow_split_transaction, ) from . import _utilities from ydb import _apis, _session_impl @@ -121,9 +120,7 @@ class Session(BaseSession): set_read_replicas_settings, ) - def transaction( - self, tx_mode=None, *, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, *, allow_split_transactions=None): return TxContext( self._driver, self._state, @@ -194,8 +191,6 @@ class TxContext(BaseTxContext): self, query, parameters=None, commit_tx=False, settings=None ): # pylint: disable=W0236 - self._check_split() - return await super().execute(query, parameters, commit_tx, settings) async def commit(self, settings=None): # pylint: disable=W0236 diff --git a/ydb/public/sdk/python3/ydb/auth_helpers.py b/ydb/public/sdk/python3/ydb/auth_helpers.py index 5d889555dfa..6399c3cfdf0 100644 --- a/ydb/public/sdk/python3/ydb/auth_helpers.py +++ b/ydb/public/sdk/python3/ydb/auth_helpers.py @@ -1,9 +1,6 @@ # -*- coding: utf-8 -*- import os -from . import credentials, tracing -import warnings - def read_bytes(f): with open(f, "rb") as fr: @@ -15,43 +12,3 @@ def load_ydb_root_certificate(): if path is not None and os.path.exists(path): return read_bytes(path) return None - - -def construct_credentials_from_environ(tracer=None): - tracer = tracer if tracer is not None else tracing.Tracer(None) - warnings.warn( - "using construct_credentials_from_environ method for credentials instantiation is deprecated and will be " - "removed in the future major releases. Please instantialize credentials by default or provide correct credentials " - "instance to the Driver." - ) - - # dynamically import required authentication libraries - if ( - os.getenv("USE_METADATA_CREDENTIALS") is not None - and int(os.getenv("USE_METADATA_CREDENTIALS")) == 1 - ): - import ydb.iam - - tracing.trace(tracer, {"credentials.metadata": True}) - return ydb.iam.MetadataUrlCredentials() - - if os.getenv("YDB_TOKEN") is not None: - tracing.trace(tracer, {"credentials.access_token": True}) - return credentials.AuthTokenCredentials(os.getenv("YDB_TOKEN")) - - if os.getenv("SA_KEY_FILE") is not None: - - import ydb.iam - - tracing.trace(tracer, {"credentials.sa_key_file": True}) - root_certificates_file = os.getenv("SSL_ROOT_CERTIFICATES_FILE", None) - iam_channel_credentials = {} - if root_certificates_file is not None: - iam_channel_credentials = { - "root_certificates": read_bytes(root_certificates_file) - } - return ydb.iam.ServiceAccountCredentials.from_file( - os.getenv("SA_KEY_FILE"), - iam_channel_credentials=iam_channel_credentials, - iam_endpoint=os.getenv("IAM_ENDPOINT", "iam.api.cloud.yandex.net:443"), - ) diff --git a/ydb/public/sdk/python3/ydb/convert.py b/ydb/public/sdk/python3/ydb/convert.py index b231bb1091d..81348d311cb 100644 --- a/ydb/public/sdk/python3/ydb/convert.py +++ b/ydb/public/sdk/python3/ydb/convert.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import decimal from google.protobuf import struct_pb2 -import six from . import issues, types, _apis @@ -13,7 +12,7 @@ _DecimalNanRepr = 10**35 + 1 _DecimalInfRepr = 10**35 _DecimalSignedInfRepr = -(10**35) _primitive_type_by_id = {} -_default_allow_truncated_result = True +_default_allow_truncated_result = False def _initialize(): @@ -82,9 +81,7 @@ def _pb_to_list(type_pb, value_pb, table_client_settings): def _pb_to_tuple(type_pb, value_pb, table_client_settings): return tuple( _to_native_value(item_type, item_value, table_client_settings) - for item_type, item_value in six.moves.zip( - type_pb.tuple_type.elements, value_pb.items - ) + for item_type, item_value in zip(type_pb.tuple_type.elements, value_pb.items) ) @@ -107,7 +104,7 @@ class _Struct(_DotDict): def _pb_to_struct(type_pb, value_pb, table_client_settings): result = _Struct() - for member, item in six.moves.zip(type_pb.struct_type.members, value_pb.items): + for member, item in zip(type_pb.struct_type.members, value_pb.items): result[member.name] = _to_native_value(member.type, item, table_client_settings) return result @@ -202,9 +199,7 @@ def _list_to_pb(type_pb, value): def _tuple_to_pb(type_pb, value): value_pb = _apis.ydb_value.Value() - for element_type, element_value in six.moves.zip( - type_pb.tuple_type.elements, value - ): + for element_type, element_value in zip(type_pb.tuple_type.elements, value): value_item_proto = value_pb.items.add() value_item_proto.MergeFrom(_from_native_value(element_type, element_value)) return value_pb @@ -290,7 +285,7 @@ def parameters_to_pb(parameters_types, parameters_values): return {} param_values_pb = {} - for name, type_pb in six.iteritems(parameters_types): + for name, type_pb in parameters_types.items(): result = _apis.ydb_value.TypedValue() ttype = type_pb if isinstance(type_pb, types.AbstractTypeBuilder): @@ -332,7 +327,7 @@ class _ResultSet(object): for row_proto in message.rows: row = _Row(message.columns) - for column, value, column_info in six.moves.zip( + for column, value, column_info in zip( message.columns, row_proto.items, column_parsers ): v_type = value.WhichOneof("value") @@ -400,9 +395,7 @@ class _LazyRow(_DotDict): super(_LazyRow, self).__init__() self._columns = columns self._table_client_settings = table_client_settings - for i, (column, row_item) in enumerate( - six.moves.zip(self._columns, proto_row.items) - ): + for i, (column, row_item) in enumerate(zip(self._columns, proto_row.items)): super(_LazyRow, self).__setitem__( column.name, _LazyRowItem(row_item, column.type, table_client_settings, parsers[i]), diff --git a/ydb/public/sdk/python3/ydb/credentials.py b/ydb/public/sdk/python3/ydb/credentials.py index 8e22fe2a841..2a2dea3b330 100644 --- a/ydb/public/sdk/python3/ydb/credentials.py +++ b/ydb/public/sdk/python3/ydb/credentials.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import abc -import six +import typing + from . import tracing, issues, connection from . import settings as settings_impl import threading @@ -9,8 +10,7 @@ import logging import time # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_auth_pb2 from ._grpc.v4 import ydb_auth_v1_pb2_grpc else: @@ -22,15 +22,13 @@ YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket" logger = logging.getLogger(__name__) [email protected]_metaclass(abc.ABCMeta) -class AbstractCredentials(object): +class AbstractCredentials(abc.ABC): """ An abstract class that provides auth metadata """ [email protected]_metaclass(abc.ABCMeta) -class Credentials(object): +class Credentials(abc.ABC): def __init__(self, tracer=None): self.tracer = tracer if tracer is not None else tracing.Tracer(None) @@ -41,6 +39,12 @@ class Credentials(object): """ pass + def get_auth_token(self) -> str: + for header, token in self.auth_metadata(): + if header == YDB_AUTH_TICKET_HEADER: + return token + return "" + class OneToManyValue(object): def __init__(self): @@ -87,7 +91,6 @@ class AtMostOneExecution(object): self._can_schedule = True [email protected]_metaclass(abc.ABCMeta) class AbstractExpiringTokenCredentials(Credentials): def __init__(self, tracer=None): super(AbstractExpiringTokenCredentials, self).__init__(tracer) diff --git a/ydb/public/sdk/python3/ydb/dbapi/cursor.py b/ydb/public/sdk/python3/ydb/dbapi/cursor.py index 71175abf4e8..eb26dc2b0f8 100644 --- a/ydb/public/sdk/python3/ydb/dbapi/cursor.py +++ b/ydb/public/sdk/python3/ydb/dbapi/cursor.py @@ -5,8 +5,6 @@ import datetime import itertools import logging -import six - import ydb from .errors import DatabaseError @@ -42,7 +40,7 @@ def render_datetime(value): def render(value): if value is None: return "NULL" - if isinstance(value, six.string_types): + if isinstance(value, str): return render_str(value) if isinstance(value, datetime.datetime): return render_datetime(value) diff --git a/ydb/public/sdk/python3/ydb/default_pem.py b/ydb/public/sdk/python3/ydb/default_pem.py index 92286ba237d..b8272efd292 100644 --- a/ydb/public/sdk/python3/ydb/default_pem.py +++ b/ydb/public/sdk/python3/ydb/default_pem.py @@ -1,6 +1,3 @@ -import six - - data = """ # Issuer: CN=GlobalSign Root CA O=GlobalSign nv-sa OU=Root CA # Subject: CN=GlobalSign Root CA O=GlobalSign nv-sa OU=Root CA @@ -4686,6 +4683,4 @@ LpuQKbSbIERsmR+QqQ== def load_default_pem(): global data - if six.PY3: - return data.encode("utf-8") - return data + return data.encode("utf-8") diff --git a/ydb/public/sdk/python3/ydb/driver.py b/ydb/public/sdk/python3/ydb/driver.py index 9b3fa99cfa3..e3274687eaf 100644 --- a/ydb/public/sdk/python3/ydb/driver.py +++ b/ydb/public/sdk/python3/ydb/driver.py @@ -1,15 +1,11 @@ # -*- coding: utf-8 -*- from . import credentials as credentials_impl, table, scheme, pool from . import tracing -import six import os import grpc from . import _utilities -if six.PY2: - Any = None -else: - from typing import Any # noqa +from typing import Any # noqa class RPCCompression: @@ -23,10 +19,17 @@ class RPCCompression: def default_credentials(credentials=None, tracer=None): tracer = tracer if tracer is not None else tracing.Tracer(None) with tracer.trace("Driver.default_credentials") as ctx: - if credentials is not None: + if credentials is None: + ctx.trace({"credentials.anonymous": True}) + return credentials_impl.AnonymousCredentials() + else: ctx.trace({"credentials.prepared": True}) return credentials + +def credentials_from_env_variables(tracer=None): + tracer = tracer if tracer is not None else tracing.Tracer(None) + with tracer.trace("Driver.credentials_from_env_variables") as ctx: service_account_key_file = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS") if service_account_key_file is not None: ctx.trace({"credentials.service_account_key_file": True}) @@ -51,9 +54,7 @@ def default_credentials(credentials=None, tracer=None): ctx.trace({"credentials.access_token": True}) return credentials_impl.AuthTokenCredentials(access_token) - import ydb.iam - - return ydb.iam.MetadataUrlCredentials(tracer=tracer) + return default_credentials(None, tracer) class DriverConfig(object): @@ -70,6 +71,7 @@ class DriverConfig(object): "grpc_keep_alive_timeout", "secure_channel", "table_client_settings", + "topic_client_settings", "endpoints", "primary_user_agent", "tracer", @@ -92,6 +94,7 @@ class DriverConfig(object): private_key=None, grpc_keep_alive_timeout=None, table_client_settings=None, + topic_client_settings=None, endpoints=None, primary_user_agent="python-library", tracer=None, @@ -138,6 +141,7 @@ class DriverConfig(object): self.private_key = private_key self.grpc_keep_alive_timeout = grpc_keep_alive_timeout self.table_client_settings = table_client_settings + self.topic_client_settings = topic_client_settings self.primary_user_agent = primary_user_agent self.tracer = tracer if tracer is not None else tracing.Tracer(None) self.grpc_lb_policy_name = grpc_lb_policy_name @@ -228,6 +232,8 @@ class Driver(pool.ConnectionPool): :param database: A database path :param credentials: A credentials. If not specifed credentials constructed by default. """ + from . import topic # local import for prevent cycle import error + driver_config = get_config( driver_config, connection_string, @@ -238,5 +244,9 @@ class Driver(pool.ConnectionPool): ) super(Driver, self).__init__(driver_config) + + self._credentials = driver_config.credentials + self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, driver_config.table_client_settings) + self.topic_client = topic.TopicClient(self, driver_config.topic_client_settings) diff --git a/ydb/public/sdk/python3/ydb/export.py b/ydb/public/sdk/python3/ydb/export.py index 30898cbb42a..bc35bd284c5 100644 --- a/ydb/public/sdk/python3/ydb/export.py +++ b/ydb/public/sdk/python3/ydb/export.py @@ -1,12 +1,12 @@ import enum +import typing from . import _apis from . import settings_impl as s_impl # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_export_pb2 from ._grpc.v4 import ydb_export_v1_pb2_grpc else: diff --git a/ydb/public/sdk/python3/ydb/global_settings.py b/ydb/public/sdk/python3/ydb/global_settings.py index de8b0b1b063..8edac3f4b44 100644 --- a/ydb/public/sdk/python3/ydb/global_settings.py +++ b/ydb/public/sdk/python3/ydb/global_settings.py @@ -1,16 +1,24 @@ +import warnings + from . import convert from . import table def global_allow_truncated_result(enabled: bool = True): - """ - call global_allow_truncated_result(False) for more safe execution and compatible with future changes - """ + if convert._default_allow_truncated_result == enabled: + return + + if enabled: + warnings.warn("Global allow truncated response is deprecated behaviour.") + convert._default_allow_truncated_result = enabled def global_allow_split_transactions(enabled: bool): - """ - call global_allow_truncated_result(False) for more safe execution and compatible with future changes - """ - table._allow_split_transaction = enabled + if table._default_allow_split_transaction == enabled: + return + + if enabled: + warnings.warn("Global allow split transaction is deprecated behaviour.") + + table._default_allow_split_transaction = enabled diff --git a/ydb/public/sdk/python3/ydb/iam/auth.py b/ydb/public/sdk/python3/ydb/iam/auth.py index 06b07e917e5..50d98b4b4a9 100644 --- a/ydb/public/sdk/python3/ydb/iam/auth.py +++ b/ydb/public/sdk/python3/ydb/iam/auth.py @@ -3,7 +3,6 @@ from ydb import credentials, tracing import grpc import time import abc -import six from datetime import datetime import json import os @@ -45,7 +44,6 @@ def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout): ) [email protected]_metaclass(abc.ABCMeta) class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self, iam_endpoint=None, iam_channel_credentials=None, tracer=None): super(TokenServiceCredentials, self).__init__(tracer) @@ -84,8 +82,7 @@ class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials): return {"access_token": response.iam_token, "expires_in": expires_in} [email protected]_metaclass(abc.ABCMeta) -class BaseJWTCredentials(object): +class BaseJWTCredentials(abc.ABC): def __init__(self, account_id, access_key_id, private_key): self._account_id = account_id self._jwt_expiration_timeout = 60.0 * 60 diff --git a/ydb/public/sdk/python3/ydb/import_client.py b/ydb/public/sdk/python3/ydb/import_client.py index d1ccc99af65..d94294ca7ce 100644 --- a/ydb/public/sdk/python3/ydb/import_client.py +++ b/ydb/public/sdk/python3/ydb/import_client.py @@ -1,12 +1,12 @@ import enum +import typing from . import _apis from . import settings_impl as s_impl # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_import_pb2 from ._grpc.v4 import ydb_import_v1_pb2_grpc else: diff --git a/ydb/public/sdk/python3/ydb/issues.py b/ydb/public/sdk/python3/ydb/issues.py index 0a0d6a907ef..100af01d299 100644 --- a/ydb/public/sdk/python3/ydb/issues.py +++ b/ydb/public/sdk/python3/ydb/issues.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from google.protobuf import text_format import enum -from six.moves import queue +import queue from . import _apis @@ -159,11 +159,18 @@ class SessionPoolEmpty(Error, queue.Empty): def _format_issues(issues): if not issues: return "" + return " ,".join( - [text_format.MessageToString(issue, False, True) for issue in issues] + text_format.MessageToString(issue, as_utf8=False, as_one_line=True) + for issue in issues ) +def _format_response(response): + fmt_issues = _format_issues(response.issues) + return f"{fmt_issues} (server_code: {response.status})" + + _success_status_codes = {StatusCode.STATUS_CODE_UNSPECIFIED, StatusCode.SUCCESS} _server_side_error_map = { StatusCode.BAD_REQUEST: BadRequest, @@ -190,4 +197,4 @@ _server_side_error_map = { def _process_response(response_proto): if response_proto.status not in _success_status_codes: exc_obj = _server_side_error_map.get(response_proto.status) - raise exc_obj(_format_issues(response_proto.issues), response_proto.issues) + raise exc_obj(_format_response(response_proto), response_proto.issues) diff --git a/ydb/public/sdk/python3/ydb/pool.py b/ydb/public/sdk/python3/ydb/pool.py index dfda0adff21..007aa94d33d 100644 --- a/ydb/public/sdk/python3/ydb/pool.py +++ b/ydb/public/sdk/python3/ydb/pool.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- +import abc import threading import logging from concurrent import futures import collections import random -import six - from . import connection as connection_impl, issues, resolver, _utilities, tracing -from abc import abstractmethod, ABCMeta +from abc import abstractmethod +from .connection import Connection logger = logging.getLogger(__name__) @@ -127,7 +127,7 @@ class ConnectionsCache(object): return subscription @tracing.with_trace() - def get(self, preferred_endpoint=None): + def get(self, preferred_endpoint=None) -> Connection: with self.lock: if ( preferred_endpoint is not None @@ -295,8 +295,7 @@ class Discovery(threading.Thread): self.logger.info("Successfully terminated discovery process") [email protected]_metaclass(ABCMeta) -class IConnectionPool: +class IConnectionPool(abc.ABC): @abstractmethod def __init__(self, driver_config): """ diff --git a/ydb/public/sdk/python3/ydb/scheme.py b/ydb/public/sdk/python3/ydb/scheme.py index 88eca78c776..96a6c25f8dc 100644 --- a/ydb/public/sdk/python3/ydb/scheme.py +++ b/ydb/public/sdk/python3/ydb/scheme.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import abc import enum -import six from abc import abstractmethod from . import issues, operation, settings as settings_impl, _apis @@ -347,8 +346,7 @@ def _wrap_describe_path_response(rpc_state, response): return _wrap_scheme_entry(message.self) [email protected]_metaclass(abc.ABCMeta) -class ISchemeClient: +class ISchemeClient(abc.ABC): @abstractmethod def __init__(self, driver): pass diff --git a/ydb/public/sdk/python3/ydb/scheme_test.py b/ydb/public/sdk/python3/ydb/scheme_test.py deleted file mode 100644 index d63850bf930..00000000000 --- a/ydb/public/sdk/python3/ydb/scheme_test.py +++ /dev/null @@ -1,30 +0,0 @@ -from .scheme import ( - SchemeEntryType, - _wrap_scheme_entry, - _wrap_list_directory_response, -) -from ._apis import ydb_scheme - - -def test_wrap_scheme_entry(): - assert ( - _wrap_scheme_entry(ydb_scheme.Entry(type=1)).type is SchemeEntryType.DIRECTORY - ) - assert _wrap_scheme_entry(ydb_scheme.Entry(type=17)).type is SchemeEntryType.TOPIC - - assert ( - _wrap_scheme_entry(ydb_scheme.Entry()).type is SchemeEntryType.TYPE_UNSPECIFIED - ) - assert ( - _wrap_scheme_entry(ydb_scheme.Entry(type=10)).type - is SchemeEntryType.TYPE_UNSPECIFIED - ) - assert ( - _wrap_scheme_entry(ydb_scheme.Entry(type=1001)).type - is SchemeEntryType.TYPE_UNSPECIFIED - ) - - -def test_wrap_list_directory_response(): - d = _wrap_list_directory_response(None, ydb_scheme.ListDirectoryResponse()) - assert d.type is SchemeEntryType.TYPE_UNSPECIFIED diff --git a/ydb/public/sdk/python3/ydb/scripting.py b/ydb/public/sdk/python3/ydb/scripting.py index 9fed037aecf..131324301ce 100644 --- a/ydb/public/sdk/python3/ydb/scripting.py +++ b/ydb/public/sdk/python3/ydb/scripting.py @@ -1,6 +1,7 @@ +import typing + # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_scripting_pb2 from ._grpc.v4 import ydb_scripting_v1_pb2_grpc else: diff --git a/ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py b/ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py index 9e065d8f0e3..aa9b2d006ce 100644 --- a/ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py +++ b/ydb/public/sdk/python3/ydb/sqlalchemy/__init__.py @@ -191,11 +191,16 @@ try: ydb.PrimitiveType.DyNumber: sa.TEXT, } - def _get_column_type(t): - if isinstance(t.item, ydb.DecimalType): - return sa.DECIMAL(precision=t.item.precision, scale=t.item.scale) + def _get_column_info(t): + nullable = False + if isinstance(t, ydb.OptionalType): + nullable = True + t = t.item - return COLUMN_TYPES[t.item] + if isinstance(t, ydb.DecimalType): + return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable + + return COLUMN_TYPES[t], nullable class YqlDialect(DefaultDialect): name = "yql" @@ -250,11 +255,12 @@ try: columns = raw_conn.describe(qt) as_compatible = [] for column in columns: + col_type, nullable = _get_column_info(column.type) as_compatible.append( { "name": column.name, - "type": _get_column_type(column.type), - "nullable": True, + "type": col_type, + "nullable": nullable, } ) diff --git a/ydb/public/sdk/python3/ydb/table.py b/ydb/public/sdk/python3/ydb/table.py index 660959bfc6c..799a5426668 100644 --- a/ydb/public/sdk/python3/ydb/table.py +++ b/ydb/public/sdk/python3/ydb/table.py @@ -7,7 +7,6 @@ import time import random import enum -import six from . import ( issues, convert, @@ -28,7 +27,7 @@ try: except ImportError: interceptor = None -_allow_split_transaction = True +_default_allow_split_transaction = False logger = logging.getLogger(__name__) @@ -770,8 +769,7 @@ class TableDescription(object): return self [email protected]_metaclass(abc.ABCMeta) -class AbstractTransactionModeBuilder(object): +class AbstractTransactionModeBuilder(abc.ABC): @property @abc.abstractmethod def name(self): @@ -949,7 +947,7 @@ def retry_operation_impl(callee, retry_settings=None, *args, **kwargs): retry_settings = RetrySettings() if retry_settings is None else retry_settings status = None - for attempt in six.moves.range(retry_settings.max_retries + 1): + for attempt in range(retry_settings.max_retries + 1): try: result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) yield result @@ -1103,8 +1101,7 @@ def _scan_query_request_factory(query, parameters=None, settings=None): ) [email protected]_metaclass(abc.ABCMeta) -class ISession: +class ISession(abc.ABC): @abstractmethod def __init__(self, driver, table_client_settings): pass @@ -1184,9 +1181,7 @@ class ISession: pass @abstractmethod - def transaction( - self, tx_mode=None, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, allow_split_transactions=None): pass @abstractmethod @@ -1268,8 +1263,7 @@ class ISession: pass [email protected]_metaclass(abc.ABCMeta) -class ITableClient: +class ITableClient(abc.ABC): def __init__(self, driver, table_client_settings=None): pass @@ -1691,9 +1685,7 @@ class BaseSession(ISession): self._state.endpoint, ) - def transaction( - self, tx_mode=None, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, allow_split_transactions=None): return TxContext( self._driver, self._state, @@ -2100,8 +2092,7 @@ class Session(BaseSession): ) [email protected]_metaclass(abc.ABCMeta) -class ITxContext: +class ITxContext(abc.ABC): @abstractmethod def __init__(self, driver, session_state, session, tx_mode=None): """ @@ -2231,7 +2222,7 @@ class BaseTxContext(ITxContext): session, tx_mode=None, *, - allow_split_transactions=_allow_split_transaction + allow_split_transactions=None ): """ An object that provides a simple transaction context manager that allows statements execution @@ -2316,6 +2307,8 @@ class BaseTxContext(ITxContext): """ self._check_split() + if commit_tx: + self._set_finish(self._COMMIT) return self._driver( _tx_ctx_impl.execute_request_factory( @@ -2416,7 +2409,13 @@ class BaseTxContext(ITxContext): Deny all operaions with transaction after commit/rollback. Exception: double commit and double rollbacks, because it is safe """ - if self._allow_split_transactions: + allow_split_transaction = ( + self._allow_split_transactions + if self._allow_split_transactions is not None + else _default_allow_split_transaction + ) + + if allow_split_transaction: return if self._finished != "" and self._finished != allow: diff --git a/ydb/public/sdk/python3/ydb/table_test.py b/ydb/public/sdk/python3/ydb/table_test.py deleted file mode 100644 index 8e93a698f4a..00000000000 --- a/ydb/public/sdk/python3/ydb/table_test.py +++ /dev/null @@ -1,140 +0,0 @@ -from unittest import mock -from . import ( - retry_operation_impl, - YdbRetryOperationFinalResult, - issues, - YdbRetryOperationSleepOpt, - RetrySettings, -) - - -def test_retry_operation_impl(monkeypatch): - monkeypatch.setattr("random.random", lambda: 0.5) - monkeypatch.setattr( - issues.Error, - "__eq__", - lambda self, other: type(self) == type(other) and self.message == other.message, - ) - - retry_once_settings = RetrySettings( - max_retries=1, - on_ydb_error_callback=mock.Mock(), - ) - retry_once_settings.unknown_error_handler = mock.Mock() - - def get_results(callee): - res_generator = retry_operation_impl(callee, retry_settings=retry_once_settings) - results = [] - exc = None - try: - for res in res_generator: - results.append(res) - if isinstance(res, YdbRetryOperationFinalResult): - break - except Exception as e: - exc = e - - return results, exc - - class TestException(Exception): - def __init__(self, message): - super(TestException, self).__init__(message) - self.message = message - - def __eq__(self, other): - return type(self) == type(other) and self.message == other.message - - def check_unretriable_error(err_type, call_ydb_handler): - retry_once_settings.on_ydb_error_callback.reset_mock() - retry_once_settings.unknown_error_handler.reset_mock() - - results = get_results( - mock.Mock(side_effect=[err_type("test1"), err_type("test2")]) - ) - yields = results[0] - exc = results[1] - - assert yields == [] - assert exc == err_type("test1") - - if call_ydb_handler: - assert retry_once_settings.on_ydb_error_callback.call_count == 1 - retry_once_settings.on_ydb_error_callback.assert_called_with( - err_type("test1") - ) - - assert retry_once_settings.unknown_error_handler.call_count == 0 - else: - assert retry_once_settings.on_ydb_error_callback.call_count == 0 - - assert retry_once_settings.unknown_error_handler.call_count == 1 - retry_once_settings.unknown_error_handler.assert_called_with( - err_type("test1") - ) - - def check_retriable_error(err_type, backoff): - retry_once_settings.on_ydb_error_callback.reset_mock() - - results = get_results( - mock.Mock(side_effect=[err_type("test1"), err_type("test2")]) - ) - yields = results[0] - exc = results[1] - - if backoff: - assert [ - YdbRetryOperationSleepOpt(backoff.calc_timeout(0)), - YdbRetryOperationSleepOpt(backoff.calc_timeout(1)), - ] == yields - else: - assert [] == yields - - assert exc == err_type("test2") - - assert retry_once_settings.on_ydb_error_callback.call_count == 2 - retry_once_settings.on_ydb_error_callback.assert_any_call(err_type("test1")) - retry_once_settings.on_ydb_error_callback.assert_called_with(err_type("test2")) - - assert retry_once_settings.unknown_error_handler.call_count == 0 - - # check ok - assert get_results(lambda: True) == ([YdbRetryOperationFinalResult(True)], None) - - # check retry error and return result - assert get_results(mock.Mock(side_effect=[issues.Overloaded("test"), True])) == ( - [ - YdbRetryOperationSleepOpt(retry_once_settings.slow_backoff.calc_timeout(0)), - YdbRetryOperationFinalResult(True), - ], - None, - ) - - # check errors - check_retriable_error(issues.Aborted, None) - check_retriable_error(issues.BadSession, None) - - check_retriable_error(issues.NotFound, None) - with mock.patch.object(retry_once_settings, "retry_not_found", False): - check_unretriable_error(issues.NotFound, True) - - check_retriable_error(issues.InternalError, None) - with mock.patch.object(retry_once_settings, "retry_internal_error", False): - check_unretriable_error(issues.InternalError, True) - - check_retriable_error(issues.Overloaded, retry_once_settings.slow_backoff) - check_retriable_error(issues.SessionPoolEmpty, retry_once_settings.slow_backoff) - check_retriable_error(issues.ConnectionError, retry_once_settings.slow_backoff) - - check_retriable_error(issues.Unavailable, retry_once_settings.fast_backoff) - - check_unretriable_error(issues.Undetermined, True) - with mock.patch.object(retry_once_settings, "idempotent", True): - check_retriable_error(issues.Unavailable, retry_once_settings.fast_backoff) - - check_unretriable_error(issues.Error, True) - with mock.patch.object(retry_once_settings, "idempotent", True): - check_unretriable_error(issues.Error, True) - - check_unretriable_error(TestException, False) - with mock.patch.object(retry_once_settings, "idempotent", True): - check_unretriable_error(TestException, False) diff --git a/ydb/public/sdk/python3/ydb/topic.py b/ydb/public/sdk/python3/ydb/topic.py new file mode 100644 index 00000000000..efe62219cb5 --- /dev/null +++ b/ydb/public/sdk/python3/ydb/topic.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import concurrent.futures +import datetime +from dataclasses import dataclass +from typing import List, Union, Mapping, Optional, Dict, Callable + +from . import aio, Credentials, _apis, issues + +from . import driver + +from ._topic_reader.topic_reader import ( + PublicReaderSettings as TopicReaderSettings, +) + +from ._topic_reader.topic_reader_sync import TopicReaderSync as TopicReader + +from ._topic_reader.topic_reader_asyncio import ( + PublicAsyncIOReader as TopicReaderAsyncIO, +) + +from ._topic_writer.topic_writer import ( # noqa: F401 + PublicWriterSettings as TopicWriterSettings, + PublicMessage as TopicWriterMessage, + RetryPolicy as TopicWriterRetryPolicy, +) + +from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter + +from ._topic_common.common import ( + wrap_operation as _wrap_operation, + create_result_wrapper as _create_result_wrapper, +) + +from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO + +from ._grpc.grpcwrapper import ydb_topic as _ydb_topic +from ._grpc.grpcwrapper import ydb_topic_public_types as _ydb_topic_public_types +from ._grpc.grpcwrapper.ydb_topic_public_types import ( # noqa: F401 + PublicDescribeTopicResult as TopicDescription, + PublicMultipleWindowsStat as TopicStatWindow, + PublicPartitionStats as TopicPartitionStats, + PublicCodec as TopicCodec, + PublicConsumer as TopicConsumer, + PublicMeteringMode as TopicMeteringMode, +) + + +class TopicClientAsyncIO: + _closed: bool + _driver: aio.Driver + _credentials: Union[Credentials, None] + _settings: TopicClientSettings + _executor: concurrent.futures.Executor + + def __init__( + self, driver: aio.Driver, settings: Optional[TopicClientSettings] = None + ): + if not settings: + settings = TopicClientSettings() + self._closed = False + self._driver = driver + self._settings = settings + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=settings.encode_decode_threads_count, + thread_name_prefix="topic_asyncio_executor", + ) + + def __del__(self): + self.close() + + async def create_topic( + self, + path: str, + min_active_partitions: Optional[int] = None, + partition_count_limit: Optional[int] = None, + retention_period: Optional[datetime.timedelta] = None, + retention_storage_mb: Optional[int] = None, + supported_codecs: Optional[List[Union[TopicCodec, int]]] = None, + partition_write_speed_bytes_per_second: Optional[int] = None, + partition_write_burst_bytes: Optional[int] = None, + attributes: Optional[Dict[str, str]] = None, + consumers: Optional[List[Union[TopicConsumer, str]]] = None, + metering_mode: Optional[TopicMeteringMode] = None, + ): + """ + create topic command + + :param path: full path to topic + :param min_active_partitions: Minimum partition count auto merge would stop working at. + :param partition_count_limit: Limit for total partition count, including active (open for write) + and read-only partitions. + :param retention_period: How long data in partition should be stored + :param retention_storage_mb: How much data in partition should be stored + :param supported_codecs: List of allowed codecs for writers. Writes with codec not from this list are forbidden. + Empty list mean disable codec compatibility checks for the topic. + :param partition_write_speed_bytes_per_second: Partition write speed in bytes per second + :param partition_write_burst_bytes: Burst size for write in partition, in bytes + :param attributes: User and server attributes of topic. + Server attributes starts from "_" and will be validated by server. + :param consumers: List of consumers for this topic + :param metering_mode: Metering mode for the topic in a serverless database + """ + args = locals().copy() + del args["self"] + req = _ydb_topic_public_types.CreateTopicRequestParams(**args) + req = _ydb_topic.CreateTopicRequest.from_public(req) + await self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.CreateTopic, + _wrap_operation, + ) + + async def describe_topic( + self, path: str, include_stats: bool = False + ) -> TopicDescription: + args = locals().copy() + del args["self"] + req = _ydb_topic_public_types.DescribeTopicRequestParams(**args) + res = await self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.DescribeTopic, + _create_result_wrapper(_ydb_topic.DescribeTopicResult), + ) # type: _ydb_topic.DescribeTopicResult + return res.to_public() + + async def drop_topic(self, path: str): + req = _ydb_topic_public_types.DropTopicRequestParams(path=path) + await self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.DropTopic, + _wrap_operation, + ) + + def reader( + self, + consumer: str, + topic: str, + buffer_size_bytes: int = 50 * 1024 * 1024, + # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes + decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + decoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool + # on_commit: Callable[["Events.OnCommit"], None] = None + # on_get_partition_start_offset: Callable[ + # ["Events.OnPartitionGetStartOffsetRequest"], + # "Events.OnPartitionGetStartOffsetResponse", + # ] = None + # on_partition_session_start: Callable[["StubEvent"], None] = None + # on_partition_session_stop: Callable[["StubEvent"], None] = None + # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? + # deserializer: Union[Callable[[bytes], Any], None] = None + # one_attempt_connection_timeout: Union[float, None] = 1 + # connection_timeout: Union[float, None] = None + # retry_policy: Union["RetryPolicy", None] = None + ) -> TopicReaderAsyncIO: + + if not decoder_executor: + decoder_executor = self._executor + + args = locals() + del args["self"] + + settings = TopicReaderSettings(**args) + + return TopicReaderAsyncIO(self._driver, settings) + + def writer( + self, + topic, + *, + producer_id: Optional[str] = None, # default - random + session_metadata: Mapping[str, str] = None, + partition_id: Union[int, None] = None, + auto_seqno: bool = True, + auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + encoders: Optional[ + Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]] + ] = None, + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool + ) -> TopicWriterAsyncIO: + args = locals() + del args["self"] + + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + + return TopicWriterAsyncIO(self._driver, settings) + + def close(self): + if self._closed: + return + + self._closed = True + self._executor.shutdown(wait=False, cancel_futures=True) + + def _check_closed(self): + if not self._closed: + return + + raise RuntimeError("Topic client closed") + + +class TopicClient: + _closed: bool + _driver: driver.Driver + _credentials: Union[Credentials, None] + _settings: TopicClientSettings + _executor: concurrent.futures.Executor + + def __init__(self, driver: driver.Driver, settings: Optional[TopicClientSettings]): + if not settings: + settings = TopicClientSettings() + + self._closed = False + self._driver = driver + self._settings = settings + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=settings.encode_decode_threads_count, + thread_name_prefix="topic_asyncio_executor", + ) + + def __del__(self): + self.close() + + def create_topic( + self, + path: str, + min_active_partitions: Optional[int] = None, + partition_count_limit: Optional[int] = None, + retention_period: Optional[datetime.timedelta] = None, + retention_storage_mb: Optional[int] = None, + supported_codecs: Optional[List[Union[TopicCodec, int]]] = None, + partition_write_speed_bytes_per_second: Optional[int] = None, + partition_write_burst_bytes: Optional[int] = None, + attributes: Optional[Dict[str, str]] = None, + consumers: Optional[List[Union[TopicConsumer, str]]] = None, + metering_mode: Optional[TopicMeteringMode] = None, + ): + """ + create topic command + + :param path: full path to topic + :param min_active_partitions: Minimum partition count auto merge would stop working at. + :param partition_count_limit: Limit for total partition count, including active (open for write) + and read-only partitions. + :param retention_period: How long data in partition should be stored + :param retention_storage_mb: How much data in partition should be stored + :param supported_codecs: List of allowed codecs for writers. Writes with codec not from this list are forbidden. + Empty list mean disable codec compatibility checks for the topic. + :param partition_write_speed_bytes_per_second: Partition write speed in bytes per second + :param partition_write_burst_bytes: Burst size for write in partition, in bytes + :param attributes: User and server attributes of topic. + Server attributes starts from "_" and will be validated by server. + :param consumers: List of consumers for this topic + :param metering_mode: Metering mode for the topic in a serverless database + """ + args = locals().copy() + del args["self"] + self._check_closed() + + req = _ydb_topic_public_types.CreateTopicRequestParams(**args) + req = _ydb_topic.CreateTopicRequest.from_public(req) + self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.CreateTopic, + _wrap_operation, + ) + + def describe_topic( + self, path: str, include_stats: bool = False + ) -> TopicDescription: + args = locals().copy() + del args["self"] + self._check_closed() + + req = _ydb_topic_public_types.DescribeTopicRequestParams(**args) + res = self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.DescribeTopic, + _create_result_wrapper(_ydb_topic.DescribeTopicResult), + ) # type: _ydb_topic.DescribeTopicResult + return res.to_public() + + def drop_topic(self, path: str): + self._check_closed() + + req = _ydb_topic_public_types.DropTopicRequestParams(path=path) + self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.DropTopic, + _wrap_operation, + ) + + def reader( + self, + consumer: str, + topic: str, + buffer_size_bytes: int = 50 * 1024 * 1024, + # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes + decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + decoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool + # on_commit: Callable[["Events.OnCommit"], None] = None + # on_get_partition_start_offset: Callable[ + # ["Events.OnPartitionGetStartOffsetRequest"], + # "Events.OnPartitionGetStartOffsetResponse", + # ] = None + # on_partition_session_start: Callable[["StubEvent"], None] = None + # on_partition_session_stop: Callable[["StubEvent"], None] = None + # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? + # deserializer: Union[Callable[[bytes], Any], None] = None + # one_attempt_connection_timeout: Union[float, None] = 1 + # connection_timeout: Union[float, None] = None + # retry_policy: Union["RetryPolicy", None] = None + ) -> TopicReader: + if not decoder_executor: + decoder_executor = self._executor + + args = locals() + del args["self"] + self._check_closed() + + settings = TopicReaderSettings(**args) + + return TopicReader(self._driver, settings) + + def writer( + self, + topic, + *, + producer_id: Optional[str] = None, # default - random + session_metadata: Mapping[str, str] = None, + partition_id: Union[int, None] = None, + auto_seqno: bool = True, + auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + encoders: Optional[ + Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]] + ] = None, + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool + ) -> TopicWriter: + args = locals() + del args["self"] + self._check_closed() + + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + + return TopicWriter(self._driver, settings) + + def close(self): + if self._closed: + return + + self._closed = True + self._executor.shutdown(wait=False, cancel_futures=True) + + def _check_closed(self): + if not self._closed: + return + + raise RuntimeError("Topic client closed") + + +@dataclass +class TopicClientSettings: + encode_decode_threads_count: int = 4 + + +class TopicError(issues.Error): + pass diff --git a/ydb/public/sdk/python3/ydb/types.py b/ydb/public/sdk/python3/ydb/types.py index a62c8a74a0b..5ffa16e6b6f 100644 --- a/ydb/public/sdk/python3/ydb/types.py +++ b/ydb/public/sdk/python3/ydb/types.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import abc import enum -import six import json from . import _utilities, _apis from datetime import date, datetime, timedelta @@ -13,13 +12,6 @@ from google.protobuf import struct_pb2 _SECONDS_IN_DAY = 60 * 60 * 24 _EPOCH = datetime(1970, 1, 1) -if six.PY3: - _from_bytes = None -else: - - def _from_bytes(x, table_client_settings): - return _utilities.from_bytes(x) - def _from_date(x, table_client_settings): if ( @@ -52,8 +44,6 @@ def _from_json(x, table_client_settings): and table_client_settings._native_json_in_result_sets ): return json.loads(x) - if _from_bytes is not None: - return _from_bytes(x, table_client_settings) return x @@ -122,7 +112,7 @@ class PrimitiveType(enum.Enum): Float = _apis.primitive_types.FLOAT, "float_value" String = _apis.primitive_types.STRING, "bytes_value" - Utf8 = _apis.primitive_types.UTF8, "text_value", _from_bytes + Utf8 = _apis.primitive_types.UTF8, "text_value" Yson = _apis.primitive_types.YSON, "bytes_value" Json = _apis.primitive_types.JSON, "text_value", _from_json @@ -152,7 +142,7 @@ class PrimitiveType(enum.Enum): _to_interval, ) - DyNumber = _apis.primitive_types.DYNUMBER, "text_value", _from_bytes + DyNumber = _apis.primitive_types.DYNUMBER, "text_value" def __init__(self, idn, proto_field, to_obj=None, from_obj=None): self._idn_ = idn diff --git a/ydb/public/sdk/python3/ydb/ydb_version.py b/ydb/public/sdk/python3/ydb/ydb_version.py index dde8252c21c..ef5ee52a9bd 100644 --- a/ydb/public/sdk/python3/ydb/ydb_version.py +++ b/ydb/public/sdk/python3/ydb/ydb_version.py @@ -1 +1 @@ -VERSION = "2.13.2" +VERSION = "3.0.1b9" |