diff options
| author | Alexander Smirnov <[email protected]> | 2025-04-01 00:52:07 +0000 |
|---|---|---|
| committer | Alexander Smirnov <[email protected]> | 2025-04-01 00:52:07 +0000 |
| commit | 7b1dd0572a50d5e1d1dc0d973dd7b686ec5d1cfa (patch) | |
| tree | 77b99773cdbad1af81f9e310eebc0d9f361dafba | |
| parent | 4ae39ac0db24e7564f13a3d0545a540bd647283c (diff) | |
| parent | 0afdf94708e28c1ce387764b1a93c32da0cd2350 (diff) | |
Merge branch 'rightlib' into merge-libs-250401-0050
90 files changed, 1988 insertions, 368 deletions
diff --git a/build/ymake.core.conf b/build/ymake.core.conf index b5173c0ebb4..a71f3268083 100644 --- a/build/ymake.core.conf +++ b/build/ymake.core.conf @@ -727,10 +727,12 @@ module _BASE_UNIT: _BARE_UNIT { when ($CLANG) { when ($PGO_ADD == "yes") { CFLAGS+=-fprofile-instr-generate + NO_PGO_CFLAGS=-fno-profile-instr-generate LDFLAGS+=-fprofile-instr-generate } when ($PGO_PATH) { CFLAGS+=-fprofile-instr-use=$PGO_PATH -Wno-profile-instr-unprofiled -Wno-profile-instr-out-of-date + NO_PGO_CFLAGS=-fno-profile-instr-use LDFLAGS+=-fprofile-instr-use=$PGO_PATH } } @@ -2828,6 +2830,7 @@ SSE4_CFLAGS= XOP_CFLAGS= NO_LTO_CFLAGS= +NO_PGO_CFLAGS= # tag:cpu when (($ARCH_X86_64 || $ARCH_I386) && $DISABLE_INSTRUCTION_SETS != "yes") { @@ -4722,8 +4725,8 @@ macro CLANG_EMIT_AST_CXX(Input, Output, Opts...) { ### Emit LLVM bytecode from .cpp file. BC_CXXFLAGS, LLVM_OPTS and C_FLAGS_PLATFORM are passed in, while CFLAGS are not. ### Note: Output name is used as is, no extension added. macro LLVM_COMPILE_CXX(Input, Output, Opts...) { - .CMD=$YMAKE_PYTHON ${input:"build/scripts/clang_wrapper.py"} $WINDOWS ${CLANG_BC_ROOT}/bin/clang++ ${pre=-I:_C__INCLUDE} $BC_CXXFLAGS $C_FLAGS_PLATFORM -Wno-unknown-warning-option $LLVM_OPTS ${NO_LTO_CFLAGS} -emit-llvm -c ${input:Input} -o ${noauto;output:Output} $Opts ${hide;kv:"p BC"} ${hide;kv:"pc light-green"} - .SEM=target_macroses-ITEM && target_macroses-macro llvm_compile_cxx && target_macroses-args ${input:Input} ${noauto;output:Output} ${"${CLANGPLUSPLUS}"} -Wno-unknown-warning-option $LLVM_OPTS ${NO_LTO_CFLAGS} -emit-llvm ${Opts} + .CMD=$YMAKE_PYTHON ${input:"build/scripts/clang_wrapper.py"} $WINDOWS ${CLANG_BC_ROOT}/bin/clang++ ${pre=-I:_C__INCLUDE} $BC_CXXFLAGS $C_FLAGS_PLATFORM -Wno-unknown-warning-option $LLVM_OPTS ${NO_LTO_CFLAGS} $NO_PGO_CFLAGS -emit-llvm -c ${input:Input} -o ${noauto;output:Output} $Opts ${hide;kv:"p BC"} ${hide;kv:"pc light-green"} + .SEM=target_macroses-ITEM && target_macroses-macro llvm_compile_cxx && target_macroses-args ${input:Input} ${noauto;output:Output} ${"${CLANGPLUSPLUS}"} -Wno-unknown-warning-option $LLVM_OPTS ${NO_LTO_CFLAGS} $NO_PGO_CFLAGS -emit-llvm ${Opts} .STRUCT_CMD=yes when ($CLANG_BC_ROOT == "") { _OK = no @@ -4737,8 +4740,8 @@ macro LLVM_COMPILE_CXX(Input, Output, Opts...) { ### Emit LLVM bytecode from .c file. BC_CFLAGS, LLVM_OPTS and C_FLAGS_PLATFORM are passed in, while CFLAGS are not. ### Note: Output name is used as is, no extension added. macro LLVM_COMPILE_C(Input, Output, Opts...) { - .CMD=$YMAKE_PYTHON ${input:"build/scripts/clang_wrapper.py"} $WINDOWS ${CLANG_BC_ROOT}/bin/clang ${pre=-I:_C__INCLUDE} $BC_CFLAGS $C_FLAGS_PLATFORM $LLVM_OPTS ${NO_LTO_CFLAGS} -emit-llvm -c ${input:Input} -o ${noauto;output:Output} $Opts ${hide;kv:"p BC"} ${hide;kv:"pc light-green"} - .SEM=target_macroses-ITEM && target_macroses-macro llvm_compile_c && target_macroses-args ${input:Input} ${noauto;output:Output} ${"${CLANGC}"} -Wno-unknown-warning-option $LLVM_OPTS ${NO_LTO_CFLAGS} -emit-llvm ${Opts} + .CMD=$YMAKE_PYTHON ${input:"build/scripts/clang_wrapper.py"} $WINDOWS ${CLANG_BC_ROOT}/bin/clang ${pre=-I:_C__INCLUDE} $BC_CFLAGS $C_FLAGS_PLATFORM $LLVM_OPTS ${NO_LTO_CFLAGS} $NO_PGO_CFLAGS -emit-llvm -c ${input:Input} -o ${noauto;output:Output} $Opts ${hide;kv:"p BC"} ${hide;kv:"pc light-green"} + .SEM=target_macroses-ITEM && target_macroses-macro llvm_compile_c && target_macroses-args ${input:Input} ${noauto;output:Output} ${"${CLANGC}"} -Wno-unknown-warning-option $LLVM_OPTS ${NO_LTO_CFLAGS} $NO_PGO_CFLAGS -emit-llvm ${Opts} .STRUCT_CMD=yes when ($CLANG_BC_ROOT == "") { _OK = no diff --git a/contrib/python/ydb/py3/.dist-info/METADATA b/contrib/python/ydb/py3/.dist-info/METADATA index b6911ce75e8..1f52419b882 100644 --- a/contrib/python/ydb/py3/.dist-info/METADATA +++ b/contrib/python/ydb/py3/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: ydb -Version: 3.19.3 +Version: 3.20.0 Summary: YDB Python SDK Home-page: http://github.com/ydb-platform/ydb-python-sdk Author: Yandex LLC diff --git a/contrib/python/ydb/py3/ya.make b/contrib/python/ydb/py3/ya.make index 71cfb8fa720..847d0def900 100644 --- a/contrib/python/ydb/py3/ya.make +++ b/contrib/python/ydb/py3/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(3.19.3) +VERSION(3.20.0) LICENSE(Apache-2.0) diff --git a/contrib/python/ydb/py3/ydb/_apis.py b/contrib/python/ydb/py3/ydb/_apis.py index fc28d0ceb29..fc6f16e287c 100644 --- a/contrib/python/ydb/py3/ydb/_apis.py +++ b/contrib/python/ydb/py3/ydb/_apis.py @@ -115,6 +115,7 @@ class TopicService(object): DropTopic = "DropTopic" StreamRead = "StreamRead" StreamWrite = "StreamWrite" + UpdateOffsetsInTransaction = "UpdateOffsetsInTransaction" class QueryService(object): diff --git a/contrib/python/ydb/py3/ydb/_errors.py b/contrib/python/ydb/py3/ydb/_errors.py index 17002d25749..1e2308ef394 100644 --- a/contrib/python/ydb/py3/ydb/_errors.py +++ b/contrib/python/ydb/py3/ydb/_errors.py @@ -5,6 +5,7 @@ from . import issues _errors_retriable_fast_backoff_types = [ issues.Unavailable, + issues.ClientInternalError, ] _errors_retriable_slow_backoff_types = [ issues.Aborted, diff --git a/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py b/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py index 95a5744313e..10d98918c30 100644 --- a/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py +++ b/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/common_utils.py @@ -160,9 +160,6 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): self._stream_call = None self._wait_executor = None - def __del__(self): - self._clean_executor(wait=False) - async def start(self, driver: SupportedDriverType, stub, method): if asyncio.iscoroutinefunction(driver.__call__): await self._start_asyncio_driver(driver, stub, method) diff --git a/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py b/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py index 5b22c7cf862..0f8a0f03a7a 100644 --- a/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -141,6 +141,18 @@ class UpdateTokenResponse(IFromProto): ######################################################################################################################## +@dataclass +class TransactionIdentity(IToProto): + tx_id: str + session_id: str + + def to_proto(self) -> ydb_topic_pb2.TransactionIdentity: + return ydb_topic_pb2.TransactionIdentity( + id=self.tx_id, + session=self.session_id, + ) + + class StreamWriteMessage: @dataclass() class InitRequest(IToProto): @@ -199,6 +211,7 @@ class StreamWriteMessage: class WriteRequest(IToProto): messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"] codec: int + tx_identity: Optional[TransactionIdentity] @dataclass class MessageData(IToProto): @@ -237,6 +250,9 @@ class StreamWriteMessage: proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest() proto.codec = self.codec + if self.tx_identity is not None: + proto.tx.CopyFrom(self.tx_identity.to_proto()) + for message in self.messages: proto_mess = proto.messages.add() proto_mess.CopyFrom(message.to_proto()) @@ -297,6 +313,8 @@ class StreamWriteMessage: ) except ValueError: message_write_status = reason + elif proto_ack.HasField("written_in_tx"): + message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusWrittenInTx() else: raise NotImplementedError("unexpected ack status") @@ -309,6 +327,9 @@ class StreamWriteMessage: class StatusWritten: offset: int + class StatusWrittenInTx: + pass + @dataclass class StatusSkipped: reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason" @@ -1197,6 +1218,52 @@ class MeteringMode(int, IFromProto, IFromPublic, IToPublic): @dataclass +class UpdateOffsetsInTransactionRequest(IToProto): + tx: TransactionIdentity + topics: List[UpdateOffsetsInTransactionRequest.TopicOffsets] + consumer: str + + def to_proto(self): + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest( + tx=self.tx.to_proto(), + consumer=self.consumer, + topics=list( + map( + UpdateOffsetsInTransactionRequest.TopicOffsets.to_proto, + self.topics, + ) + ), + ) + + @dataclass + class TopicOffsets(IToProto): + path: str + partitions: List[UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets] + + def to_proto(self): + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets( + path=self.path, + partitions=list( + map( + UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets.to_proto, + self.partitions, + ) + ), + ) + + @dataclass + class PartitionOffsets(IToProto): + partition_id: int + partition_offsets: List[OffsetsRange] + + def to_proto(self) -> ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets: + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets( + partition_id=self.partition_id, + partition_offsets=list(map(OffsetsRange.to_proto, self.partition_offsets)), + ) + + +@dataclass class CreateTopicRequest(IToProto, IFromPublic): path: str partitioning_settings: "PartitioningSettings" diff --git a/contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py b/contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py index b48501aff2f..74f06a086fc 100644 --- a/contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py +++ b/contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py @@ -108,6 +108,9 @@ class PartitionSession: waiter = self._ack_waiters.popleft() waiter._finish_ok() + def _update_last_commited_offset_if_needed(self, offset: int): + self.committed_offset = max(self.committed_offset, offset) + def close(self): if self.closed: return @@ -211,3 +214,9 @@ class PublicBatch(ICommittable, ISessionAlive): self._bytes_size = self._bytes_size - new_batch._bytes_size return new_batch + + def _update_partition_offsets(self, tx, exc=None): + if exc is not None: + return + offsets = self._commit_get_offsets_range() + self._partition_session._update_last_commited_offset_if_needed(offsets.end) diff --git a/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py b/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py index 7061b4e449c..c9704d5542a 100644 --- a/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py +++ b/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py @@ -5,7 +5,7 @@ import concurrent.futures import gzip import typing from asyncio import Task -from collections import OrderedDict +from collections import defaultdict, OrderedDict from typing import Optional, Set, Dict, Union, Callable import ydb @@ -19,17 +19,24 @@ from . import topic_reader from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, + to_thread, GrpcWrapperAsyncIO, ) from .._grpc.grpcwrapper.ydb_topic import ( StreamReadMessage, UpdateTokenRequest, UpdateTokenResponse, + UpdateOffsetsInTransactionRequest, Codec, ) from .._errors import check_retriable_error import logging +from ..query.base import TxEvent + +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + logger = logging.getLogger(__name__) @@ -77,7 +84,7 @@ class PublicAsyncIOReader: ): self._loop = asyncio.get_running_loop() self._closed = False - self._reconnector = ReaderReconnector(driver, settings) + self._reconnector = ReaderReconnector(driver, settings, self._loop) self._parent = _parent async def __aenter__(self): @@ -88,8 +95,7 @@ class PublicAsyncIOReader: def __del__(self): if not self._closed: - task = self._loop.create_task(self.close(flush=False)) - topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader") + logger.warning("Topic reader was not closed properly. Consider using method close().") async def wait_message(self): """ @@ -112,6 +118,23 @@ class PublicAsyncIOReader: max_messages=max_messages, ) + async def receive_batch_with_tx( + self, + tx: "BaseQueryTxContext", + max_messages: typing.Union[int, None] = None, + ) -> typing.Union[datatypes.PublicBatch, None]: + """ + Get one messages batch with tx from reader. + All messages in a batch from same partition. + + use asyncio.wait_for for wait with timeout. + """ + await self._reconnector.wait_message() + return self._reconnector.receive_batch_with_tx_nowait( + tx=tx, + max_messages=max_messages, + ) + async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: """ Block until receive new message @@ -165,11 +188,18 @@ class ReaderReconnector: _state_changed: asyncio.Event _stream_reader: Optional["ReaderStream"] _first_error: asyncio.Future[YdbError] + _tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]] - def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + def __init__( + self, + driver: Driver, + settings: topic_reader.PublicReaderSettings, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): self._id = self._static_reader_reconnector_counter.inc_and_get() self._settings = settings self._driver = driver + self._loop = loop if loop is not None else asyncio.get_running_loop() self._background_tasks = set() self._state_changed = asyncio.Event() @@ -177,6 +207,8 @@ class ReaderReconnector: self._background_tasks.add(asyncio.create_task(self._connection_loop())) self._first_error = asyncio.get_running_loop().create_future() + self._tx_to_batches_map = dict() + async def _connection_loop(self): attempt = 0 while True: @@ -190,6 +222,7 @@ class ReaderReconnector: if not retry_info.is_retriable: self._set_first_error(err) return + await asyncio.sleep(retry_info.sleep_timeout_seconds) attempt += 1 @@ -222,9 +255,87 @@ class ReaderReconnector: max_messages=max_messages, ) + def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None): + batch = self._stream_reader.receive_batch_nowait( + max_messages=max_messages, + ) + + self._init_tx(tx) + + self._tx_to_batches_map[tx.tx_id].append(batch) + + tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, self._loop) + + return batch + def receive_message_nowait(self): return self._stream_reader.receive_message_nowait() + def _init_tx(self, tx: "BaseQueryTxContext"): + if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks + self._tx_to_batches_map[tx.tx_id] = [] + tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, self._loop) + tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, self._loop) + tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, self._loop) + + async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"): + grouped_batches = defaultdict(lambda: defaultdict(list)) + for batch in self._tx_to_batches_map[tx.tx_id]: + grouped_batches[batch._partition_session.topic_path][batch._partition_session.partition_id].append(batch) + + request = UpdateOffsetsInTransactionRequest(tx=tx._tx_identity(), consumer=self._settings.consumer, topics=[]) + + for topic_path in grouped_batches: + topic_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets(path=topic_path, partitions=[]) + for partition_id in grouped_batches[topic_path]: + partition_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets( + partition_id=partition_id, + partition_offsets=[ + batch._commit_get_offsets_range() for batch in grouped_batches[topic_path][partition_id] + ], + ) + topic_offsets.partitions.append(partition_offsets) + request.topics.append(topic_offsets) + + try: + return await self._do_commit_batches_with_tx_call(request) + except BaseException: + exc = issues.ClientInternalError("Failed to update offsets in tx.") + tx._set_external_error(exc) + self._stream_reader._set_first_error(exc) + finally: + del self._tx_to_batches_map[tx.tx_id] + + async def _do_commit_batches_with_tx_call(self, request: UpdateOffsetsInTransactionRequest): + args = [ + request.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.UpdateOffsetsInTransaction, + topic_common.wrap_operation, + ] + + if asyncio.iscoroutinefunction(self._driver.__call__): + res = await self._driver(*args) + else: + res = await to_thread(self._driver, *args, executor=None) + + return res + + async def _handle_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + if tx.tx_id in self._tx_to_batches_map: + del self._tx_to_batches_map[tx.tx_id] + exc = issues.ClientInternalError("Reconnect due to transaction rollback") + self._stream_reader._set_first_error(exc) + + async def _handle_after_tx_commit(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + if tx.tx_id in self._tx_to_batches_map: + del self._tx_to_batches_map[tx.tx_id] + + if exc is not None: + self._stream_reader._set_first_error( + issues.ClientInternalError("Reconnect due to transaction commit failed") + ) + def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter: return self._stream_reader.commit(batch) diff --git a/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py b/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py index eda1d374fc3..3e6806d06fe 100644 --- a/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py +++ b/contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_sync.py @@ -1,5 +1,6 @@ import asyncio import concurrent.futures +import logging import typing from typing import List, Union, Optional @@ -20,6 +21,11 @@ from ydb._topic_reader.topic_reader_asyncio import ( TopicReaderClosedError, ) +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + +logger = logging.getLogger(__name__) + class TopicReaderSync: _caller: CallFromSyncToAsync @@ -52,7 +58,8 @@ class TopicReaderSync: self._parent = _parent def __del__(self): - self.close(flush=False) + if not self._closed: + logger.warning("Topic reader was not closed properly. Consider using method close().") def __enter__(self): return self @@ -109,6 +116,31 @@ class TopicReaderSync: timeout, ) + def receive_batch_with_tx( + self, + tx: "BaseQueryTxContext", + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> Union[PublicBatch, None]: + """ + Get one messages batch with tx from reader + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. + """ + self._check_closed() + + return self._caller.safe_call_with_result( + self._async_reader.receive_batch_with_tx( + tx=tx, + max_messages=max_messages, + ), + timeout, + ) + def commit(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]): """ Put commit message to internal buffer. diff --git a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py index aa5fe9749a7..a3e407ed86d 100644 --- a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py +++ b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer.py @@ -11,6 +11,7 @@ import typing import ydb.aio from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage +from .._grpc.grpcwrapper.ydb_topic import TransactionIdentity from .._grpc.grpcwrapper.common_utils import IToProto from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .. import connection @@ -53,8 +54,12 @@ class PublicWriteResult: class Skipped: pass + @dataclass(eq=True) + class WrittenInTx: + pass + -PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped] +PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped, PublicWriteResult.WrittenInTx] class WriterSettings(PublicWriterSettings): @@ -205,6 +210,7 @@ def default_serializer_message_content(data: Any) -> bytes: def messages_to_proto_requests( messages: List[InternalMessage], + tx_identity: Optional[TransactionIdentity], ) -> List[StreamWriteMessage.FromClient]: gropus = _slit_messages_for_send(messages) @@ -215,6 +221,7 @@ def messages_to_proto_requests( StreamWriteMessage.WriteRequest( messages=list(map(InternalMessage.to_message_data, group)), codec=group[0].codec, + tx_identity=tx_identity, ) ) res.append(req) @@ -239,6 +246,7 @@ _message_data_overhead = ( ), ], codec=20000, + tx_identity=None, ) ) .to_proto() diff --git a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py index 32d8fefe51c..1ea6c25028b 100644 --- a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py +++ b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_asyncio.py @@ -1,7 +1,6 @@ import asyncio import concurrent.futures import datetime -import functools import gzip import typing from collections import deque @@ -35,6 +34,7 @@ from .._grpc.grpcwrapper.ydb_topic import ( UpdateTokenRequest, UpdateTokenResponse, StreamWriteMessage, + TransactionIdentity, WriterMessagesFromServerToClient, ) from .._grpc.grpcwrapper.common_utils import ( @@ -43,6 +43,11 @@ from .._grpc.grpcwrapper.common_utils import ( GrpcWrapperAsyncIO, ) +from ..query.base import TxEvent + +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + logger = logging.getLogger(__name__) @@ -74,10 +79,8 @@ class WriterAsyncIO: raise def __del__(self): - if self._closed or self._loop.is_closed(): - return - - self._loop.call_soon(functools.partial(self.close, flush=False)) + if not self._closed: + logger.warning("Topic writer was not closed properly. Consider using method close().") async def close(self, *, flush: bool = True): if self._closed: @@ -164,6 +167,57 @@ class WriterAsyncIO: return await self._reconnector.wait_init() +class TxWriterAsyncIO(WriterAsyncIO): + _tx: "BaseQueryTxContext" + + def __init__( + self, + tx: "BaseQueryTxContext", + driver: SupportedDriverType, + settings: PublicWriterSettings, + _client=None, + _is_implicit=False, + ): + self._tx = tx + self._loop = asyncio.get_running_loop() + self._closed = False + self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings), tx=self._tx) + self._parent = _client + self._is_implicit = _is_implicit + + # For some reason, creating partition could conflict with other session operations. + # Could be removed later. + self._first_write = True + + tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, self._loop) + tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, self._loop) + + async def write( + self, + messages: Union[Message, List[Message]], + ): + """ + send one or number of messages to server. + it put message to internal buffer + + For wait with timeout use asyncio.wait_for. + """ + if self._first_write: + self._first_write = False + return await super().write_with_ack(messages) + return await super().write(messages) + + async def _on_before_commit(self, tx: "BaseQueryTxContext"): + if self._is_implicit: + return + await self.close() + + async def _on_before_rollback(self, tx: "BaseQueryTxContext"): + if self._is_implicit: + return + await self.close(flush=False) + + class WriterAsyncIOReconnector: _closed: bool _loop: asyncio.AbstractEventLoop @@ -178,6 +232,7 @@ class WriterAsyncIOReconnector: _codec_selector_batch_num: int _codec_selector_last_codec: Optional[PublicCodec] _codec_selector_check_batches_interval: int + _tx: Optional["BaseQueryTxContext"] if typing.TYPE_CHECKING: _messages_for_encode: asyncio.Queue[List[InternalMessage]] @@ -195,7 +250,9 @@ class WriterAsyncIOReconnector: _stop_reason: asyncio.Future _init_info: Optional[PublicWriterInitInfo] - def __init__(self, driver: SupportedDriverType, settings: WriterSettings): + def __init__( + self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None + ): self._closed = False self._loop = asyncio.get_running_loop() self._driver = driver @@ -205,6 +262,7 @@ class WriterAsyncIOReconnector: self._init_info = None self._stream_connected = asyncio.Event() self._settings = settings + self._tx = tx self._codec_functions = { PublicCodec.RAW: lambda data: data, @@ -354,10 +412,12 @@ class WriterAsyncIOReconnector: # noinspection PyBroadException stream_writer = None try: + tx_identity = None if self._tx is None else self._tx._tx_identity() stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, self._settings.update_token_interval, + tx_identity=tx_identity, ) try: if self._init_info is None: @@ -387,7 +447,7 @@ class WriterAsyncIOReconnector: done.pop().result() # need for raise exception - reason of stop task except issues.Error as err: err_info = check_retriable_error(err, retry_settings, attempt) - if not err_info.is_retriable: + if not err_info.is_retriable or self._tx is not None: # no retries in tx writer self._stop(err) return @@ -533,6 +593,8 @@ class WriterAsyncIOReconnector: result = PublicWriteResult.Skipped() elif isinstance(status, write_ack_msg.StatusWritten): result = PublicWriteResult.Written(offset=status.offset) + elif isinstance(status, write_ack_msg.StatusWrittenInTx): + result = PublicWriteResult.WrittenInTx() else: raise TopicWriterError("internal error - receive unexpected ack message.") message_future.set_result(result) @@ -597,10 +659,13 @@ class WriterAsyncIOStream: _update_token_event: asyncio.Event _get_token_function: Optional[Callable[[], str]] + _tx_identity: Optional[TransactionIdentity] + def __init__( self, update_token_interval: Optional[Union[int, float]] = None, get_token_function: Optional[Callable[[], str]] = None, + tx_identity: Optional[TransactionIdentity] = None, ): self._closed = False @@ -609,6 +674,8 @@ class WriterAsyncIOStream: self._update_token_event = asyncio.Event() self._update_token_task = None + self._tx_identity = tx_identity + async def close(self): if self._closed: return @@ -625,6 +692,7 @@ class WriterAsyncIOStream: driver: SupportedDriverType, init_request: StreamWriteMessage.InitRequest, update_token_interval: Optional[Union[int, float]] = None, + tx_identity: Optional[TransactionIdentity] = None, ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) @@ -634,6 +702,7 @@ class WriterAsyncIOStream: writer = WriterAsyncIOStream( update_token_interval=update_token_interval, get_token_function=creds.get_auth_token if creds else lambda: "", + tx_identity=tx_identity, ) await writer._start(stream, init_request) return writer @@ -680,7 +749,7 @@ class WriterAsyncIOStream: if self._closed: raise RuntimeError("Can not write on closed stream.") - for request in messages_to_proto_requests(messages): + for request in messages_to_proto_requests(messages, self._tx_identity): self._stream.write(request) async def _update_token_loop(self): diff --git a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py index a5193caf7c5..4796d7ac2d6 100644 --- a/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py +++ b/contrib/python/ydb/py3/ydb/_topic_writer/topic_writer_sync.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import typing from concurrent.futures import Future from typing import Union, List, Optional @@ -14,13 +15,23 @@ from .topic_writer import ( TopicWriterClosedError, ) -from .topic_writer_asyncio import WriterAsyncIO +from ..query.base import TxEvent + +from .topic_writer_asyncio import ( + TxWriterAsyncIO, + WriterAsyncIO, +) from .._topic_common.common import ( _get_shared_event_loop, TimeoutType, CallFromSyncToAsync, ) +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + +logger = logging.getLogger(__name__) + class WriterSync: _caller: CallFromSyncToAsync @@ -63,7 +74,8 @@ class WriterSync: raise def __del__(self): - self.close(flush=False) + if not self._closed: + logger.warning("Topic writer was not closed properly. Consider using method close().") def close(self, *, flush: bool = True, timeout: TimeoutType = None): if self._closed: @@ -122,3 +134,39 @@ class WriterSync: self._check_closed() return self._caller.unsafe_call_with_result(self._async_writer.write_with_ack(messages), timeout=timeout) + + +class TxWriterSync(WriterSync): + def __init__( + self, + tx: "BaseQueryTxContext", + driver: SupportedDriverType, + settings: PublicWriterSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, + _parent=None, + ): + + self._closed = False + + if eventloop: + loop = eventloop + else: + loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) + + async def create_async_writer(): + return TxWriterAsyncIO(tx, driver, settings, _is_implicit=True) + + self._async_writer = self._caller.safe_call_with_result(create_async_writer(), None) + self._parent = _parent + + tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, None) + tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, None) + + def _on_before_commit(self, tx: "BaseQueryTxContext"): + self.close() + + def _on_before_rollback(self, tx: "BaseQueryTxContext"): + self.close(flush=False) diff --git a/contrib/python/ydb/py3/ydb/aio/driver.py b/contrib/python/ydb/py3/ydb/aio/driver.py index 9cd6fd2b74d..267997fbcc3 100644 --- a/contrib/python/ydb/py3/ydb/aio/driver.py +++ b/contrib/python/ydb/py3/ydb/aio/driver.py @@ -62,4 +62,5 @@ class Driver(pool.ConnectionPool): async def stop(self, timeout=10): await self.table_client._stop_pool_if_needed(timeout=timeout) + self.topic_client.close() await super().stop(timeout=timeout) diff --git a/contrib/python/ydb/py3/ydb/aio/query/pool.py b/contrib/python/ydb/py3/ydb/aio/query/pool.py index 947db658726..f1ca68d1cf0 100644 --- a/contrib/python/ydb/py3/ydb/aio/query/pool.py +++ b/contrib/python/ydb/py3/ydb/aio/query/pool.py @@ -158,6 +158,8 @@ class QuerySessionPool: async def wrapped_callee(): async with self.checkout() as session: async with session.transaction(tx_mode=tx_mode) as tx: + if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]: + await tx.begin() result = await callee(tx, *args, **kwargs) await tx.commit() return result @@ -213,12 +215,6 @@ class QuerySessionPool: async def __aexit__(self, exc_type, exc_val, exc_tb): await self.stop() - def __del__(self): - if self._should_stop.is_set() or self._loop.is_closed(): - return - - self._loop.call_soon(self.stop) - class SimpleQuerySessionCheckoutAsync: def __init__(self, pool: QuerySessionPool): diff --git a/contrib/python/ydb/py3/ydb/aio/query/transaction.py b/contrib/python/ydb/py3/ydb/aio/query/transaction.py index 5b63a32b489..f0547e5f01f 100644 --- a/contrib/python/ydb/py3/ydb/aio/query/transaction.py +++ b/contrib/python/ydb/py3/ydb/aio/query/transaction.py @@ -16,6 +16,28 @@ logger = logging.getLogger(__name__) class QueryTxContext(BaseQueryTxContext): + def __init__(self, driver, session_state, session, tx_mode): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ + super().__init__(driver, session_state, session, tx_mode) + self._init_callback_handler(base.CallbackHandlerMode.ASYNC) + async def __aenter__(self) -> "QueryTxContext": """ Enters a context manager and returns a transaction @@ -30,7 +52,7 @@ class QueryTxContext(BaseQueryTxContext): it is not finished explicitly """ await self._ensure_prev_stream_finished() - if self._tx_state._state == QueryTxStateEnum.BEGINED: + if self._tx_state._state == QueryTxStateEnum.BEGINED and self._external_error is None: # It's strictly recommended to close transactions directly # by using commit_tx=True flag while executing statement or by # .commit() or .rollback() methods, but here we trying to do best @@ -65,7 +87,9 @@ class QueryTxContext(BaseQueryTxContext): :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + self._check_external_error_set() + + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -74,7 +98,13 @@ class QueryTxContext(BaseQueryTxContext): await self._ensure_prev_stream_finished() - await self._commit_call(settings) + try: + await self._execute_callbacks_async(base.TxEvent.BEFORE_COMMIT) + await self._commit_call(settings) + await self._execute_callbacks_async(base.TxEvent.AFTER_COMMIT, exc=None) + except BaseException as e: + await self._execute_callbacks_async(base.TxEvent.AFTER_COMMIT, exc=e) + raise e async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution @@ -84,7 +114,9 @@ class QueryTxContext(BaseQueryTxContext): :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + self._check_external_error_set() + + if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -93,7 +125,13 @@ class QueryTxContext(BaseQueryTxContext): await self._ensure_prev_stream_finished() - await self._rollback_call(settings) + try: + await self._execute_callbacks_async(base.TxEvent.BEFORE_ROLLBACK) + await self._rollback_call(settings) + await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=None) + except BaseException as e: + await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=e) + raise e async def execute( self, diff --git a/contrib/python/ydb/py3/ydb/driver.py b/contrib/python/ydb/py3/ydb/driver.py index 49bd223c901..3998aeeef5f 100644 --- a/contrib/python/ydb/py3/ydb/driver.py +++ b/contrib/python/ydb/py3/ydb/driver.py @@ -288,4 +288,5 @@ class Driver(pool.ConnectionPool): def stop(self, timeout=10): self.table_client._stop_pool_if_needed(timeout=timeout) + self.topic_client.close() super().stop(timeout=timeout) diff --git a/contrib/python/ydb/py3/ydb/issues.py b/contrib/python/ydb/py3/ydb/issues.py index f38f99f9257..4e76f5ed2b0 100644 --- a/contrib/python/ydb/py3/ydb/issues.py +++ b/contrib/python/ydb/py3/ydb/issues.py @@ -178,6 +178,10 @@ class SessionPoolEmpty(Error, queue.Empty): status = StatusCode.SESSION_POOL_EMPTY +class ClientInternalError(Error): + status = StatusCode.CLIENT_INTERNAL_ERROR + + class UnexpectedGrpcMessage(Error): def __init__(self, message: str): super().__init__(message) diff --git a/contrib/python/ydb/py3/ydb/query/base.py b/contrib/python/ydb/py3/ydb/query/base.py index 57a769bb1a1..a5ebedd95b3 100644 --- a/contrib/python/ydb/py3/ydb/query/base.py +++ b/contrib/python/ydb/py3/ydb/query/base.py @@ -1,6 +1,8 @@ import abc +import asyncio import enum import functools +from collections import defaultdict import typing from typing import ( @@ -17,6 +19,10 @@ from .. import issues from .. import _utilities from .. import _apis +from ydb._topic_common.common import CallFromSyncToAsync, _get_shared_event_loop +from ydb._grpc.grpcwrapper.common_utils import to_thread + + if typing.TYPE_CHECKING: from .transaction import BaseQueryTxContext @@ -196,3 +202,64 @@ def wrap_execute_query_response( return convert.ResultSet.from_message(response_pb.result_set, settings) return None + + +class TxEvent(enum.Enum): + BEFORE_COMMIT = "BEFORE_COMMIT" + AFTER_COMMIT = "AFTER_COMMIT" + BEFORE_ROLLBACK = "BEFORE_ROLLBACK" + AFTER_ROLLBACK = "AFTER_ROLLBACK" + + +class CallbackHandlerMode(enum.Enum): + SYNC = "SYNC" + ASYNC = "ASYNC" + + +def _get_sync_callback(method: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]): + if asyncio.iscoroutinefunction(method): + if loop is None: + loop = _get_shared_event_loop() + + def async_to_sync_callback(*args, **kwargs): + caller = CallFromSyncToAsync(loop) + return caller.safe_call_with_result(method(*args, **kwargs), 10) + + return async_to_sync_callback + return method + + +def _get_async_callback(method: typing.Callable): + if asyncio.iscoroutinefunction(method): + return method + + async def sync_to_async_callback(*args, **kwargs): + return await to_thread(method, *args, **kwargs, executor=None) + + return sync_to_async_callback + + +class CallbackHandler: + def _init_callback_handler(self, mode: CallbackHandlerMode) -> None: + self._callbacks = defaultdict(list) + self._callback_mode = mode + + def _execute_callbacks_sync(self, event_name: str, *args, **kwargs) -> None: + for callback in self._callbacks[event_name]: + callback(self, *args, **kwargs) + + async def _execute_callbacks_async(self, event_name: str, *args, **kwargs) -> None: + tasks = [asyncio.create_task(callback(self, *args, **kwargs)) for callback in self._callbacks[event_name]] + if not tasks: + return + await asyncio.gather(*tasks) + + def _prepare_callback( + self, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop] + ) -> typing.Callable: + if self._callback_mode == CallbackHandlerMode.SYNC: + return _get_sync_callback(callback, loop) + return _get_async_callback(callback) + + def _add_callback(self, event_name: str, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]): + self._callbacks[event_name].append(self._prepare_callback(callback, loop)) diff --git a/contrib/python/ydb/py3/ydb/query/pool.py b/contrib/python/ydb/py3/ydb/query/pool.py index e3775c4dd12..b25f7db855c 100644 --- a/contrib/python/ydb/py3/ydb/query/pool.py +++ b/contrib/python/ydb/py3/ydb/query/pool.py @@ -167,6 +167,8 @@ class QuerySessionPool: def wrapped_callee(): with self.checkout(timeout=retry_settings.max_session_acquire_timeout) as session: with session.transaction(tx_mode=tx_mode) as tx: + if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]: + tx.begin() result = callee(tx, *args, **kwargs) tx.commit() return result @@ -224,9 +226,6 @@ class QuerySessionPool: def __exit__(self, exc_type, exc_val, exc_tb): self.stop() - def __del__(self): - self.stop() - class SimpleQuerySessionCheckout: def __init__(self, pool: QuerySessionPool, timeout: Optional[float]): diff --git a/contrib/python/ydb/py3/ydb/query/transaction.py b/contrib/python/ydb/py3/ydb/query/transaction.py index 414401da4d1..ae7642dbe21 100644 --- a/contrib/python/ydb/py3/ydb/query/transaction.py +++ b/contrib/python/ydb/py3/ydb/query/transaction.py @@ -11,6 +11,7 @@ from .. import ( _apis, issues, ) +from .._grpc.grpcwrapper import ydb_topic as _ydb_topic from .._grpc.grpcwrapper import ydb_query as _ydb_query from ..connection import _RpcState as RpcState @@ -42,11 +43,23 @@ class QueryTxStateHelper(abc.ABC): QueryTxStateEnum.DEAD: [], } + _SKIP_TRANSITIONS = { + QueryTxStateEnum.NOT_INITIALIZED: [], + QueryTxStateEnum.BEGINED: [], + QueryTxStateEnum.COMMITTED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED], + QueryTxStateEnum.ROLLBACKED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED], + QueryTxStateEnum.DEAD: [], + } + @classmethod def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: return after in cls._VALID_TRANSITIONS[before] @classmethod + def should_skip(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: + return after in cls._SKIP_TRANSITIONS[before] + + @classmethod def terminal(cls, state: QueryTxStateEnum) -> bool: return len(cls._VALID_TRANSITIONS[state]) == 0 @@ -88,8 +101,8 @@ class QueryTxState: if QueryTxStateHelper.terminal(self._state): raise RuntimeError(f"Transaction is in terminal state: {self._state.value}") - def _already_in(self, target: QueryTxStateEnum) -> bool: - return self._state == target + def _should_skip(self, target: QueryTxStateEnum) -> bool: + return QueryTxStateHelper.should_skip(self._state, target) def _construct_tx_settings(tx_state: QueryTxState) -> _ydb_query.TransactionSettings: @@ -170,7 +183,7 @@ def wrap_tx_rollback_response( return tx -class BaseQueryTxContext: +class BaseQueryTxContext(base.CallbackHandler): def __init__(self, driver, session_state, session, tx_mode): """ An object that provides a simple transaction context manager that allows statements execution @@ -196,6 +209,7 @@ class BaseQueryTxContext: self._session_state = session_state self.session = session self._prev_stream = None + self._external_error = None @property def session_id(self) -> str: @@ -215,6 +229,19 @@ class BaseQueryTxContext: """ return self._tx_state.tx_id + def _tx_identity(self) -> _ydb_topic.TransactionIdentity: + if not self.tx_id: + raise RuntimeError("Unable to get tx identity without started tx.") + return _ydb_topic.TransactionIdentity(self.tx_id, self.session_id) + + def _set_external_error(self, exc: BaseException) -> None: + self._external_error = exc + + def _check_external_error_set(self): + if self._external_error is None: + return + raise issues.ClientInternalError("Transaction was failed by external error.") from self._external_error + def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) @@ -228,6 +255,7 @@ class BaseQueryTxContext: ) def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": + self._check_external_error_set() self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) return self._driver( @@ -240,6 +268,7 @@ class BaseQueryTxContext: ) def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": + self._check_external_error_set() self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) return self._driver( @@ -262,6 +291,7 @@ class BaseQueryTxContext: settings: Optional[BaseRequestSettings], ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: self._tx_state._check_tx_ready_to_use() + self._check_external_error_set() request = base.create_execute_query_request( query=query, @@ -283,18 +313,41 @@ class BaseQueryTxContext: ) def _move_to_beginned(self, tx_id: str) -> None: - if self._tx_state._already_in(QueryTxStateEnum.BEGINED) or not tx_id: + if self._tx_state._should_skip(QueryTxStateEnum.BEGINED) or not tx_id: return self._tx_state._change_state(QueryTxStateEnum.BEGINED) self._tx_state.tx_id = tx_id def _move_to_commited(self) -> None: - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return self._tx_state._change_state(QueryTxStateEnum.COMMITTED) class QueryTxContext(BaseQueryTxContext): + def __init__(self, driver, session_state, session, tx_mode): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ + + super().__init__(driver, session_state, session, tx_mode) + self._init_callback_handler(base.CallbackHandlerMode.SYNC) + def __enter__(self) -> "BaseQueryTxContext": """ Enters a context manager and returns a transaction @@ -309,7 +362,7 @@ class QueryTxContext(BaseQueryTxContext): it is not finished explicitly """ self._ensure_prev_stream_finished() - if self._tx_state._state == QueryTxStateEnum.BEGINED: + if self._tx_state._state == QueryTxStateEnum.BEGINED and self._external_error is None: # It's strictly recommended to close transactions directly # by using commit_tx=True flag while executing statement or by # .commit() or .rollback() methods, but here we trying to do best @@ -345,7 +398,8 @@ class QueryTxContext(BaseQueryTxContext): :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + self._check_external_error_set() + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -354,7 +408,13 @@ class QueryTxContext(BaseQueryTxContext): self._ensure_prev_stream_finished() - self._commit_call(settings) + try: + self._execute_callbacks_sync(base.TxEvent.BEFORE_COMMIT) + self._commit_call(settings) + self._execute_callbacks_sync(base.TxEvent.AFTER_COMMIT, exc=None) + except BaseException as e: # TODO: probably should be less wide + self._execute_callbacks_sync(base.TxEvent.AFTER_COMMIT, exc=e) + raise e def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution @@ -364,7 +424,8 @@ class QueryTxContext(BaseQueryTxContext): :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + self._check_external_error_set() + if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -373,7 +434,13 @@ class QueryTxContext(BaseQueryTxContext): self._ensure_prev_stream_finished() - self._rollback_call(settings) + try: + self._execute_callbacks_sync(base.TxEvent.BEFORE_ROLLBACK) + self._rollback_call(settings) + self._execute_callbacks_sync(base.TxEvent.AFTER_ROLLBACK, exc=None) + except BaseException as e: # TODO: probably should be less wide + self._execute_callbacks_sync(base.TxEvent.AFTER_ROLLBACK, exc=e) + raise e def execute( self, diff --git a/contrib/python/ydb/py3/ydb/topic.py b/contrib/python/ydb/py3/ydb/topic.py index 55f4ea04c5c..52f98e61d85 100644 --- a/contrib/python/ydb/py3/ydb/topic.py +++ b/contrib/python/ydb/py3/ydb/topic.py @@ -25,6 +25,8 @@ __all__ = [ "TopicWriteResult", "TopicWriter", "TopicWriterAsyncIO", + "TopicTxWriter", + "TopicTxWriterAsyncIO", "TopicWriterInitInfo", "TopicWriterMessage", "TopicWriterSettings", @@ -33,6 +35,7 @@ __all__ = [ import concurrent.futures import datetime from dataclasses import dataclass +import logging from typing import List, Union, Mapping, Optional, Dict, Callable from . import aio, Credentials, _apis, issues @@ -65,8 +68,10 @@ from ._topic_writer.topic_writer import ( # noqa: F401 PublicWriteResult as TopicWriteResult, ) +from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter +from ._topic_writer.topic_writer_sync import TxWriterSync as TopicTxWriter from ._topic_common.common import ( wrap_operation as _wrap_operation, @@ -88,6 +93,8 @@ from ._grpc.grpcwrapper.ydb_topic_public_types import ( # noqa: F401 PublicAlterAutoPartitioningSettings as TopicAlterAutoPartitioningSettings, ) +logger = logging.getLogger(__name__) + class TopicClientAsyncIO: _closed: bool @@ -108,7 +115,8 @@ class TopicClientAsyncIO: ) def __del__(self): - self.close() + if not self._closed: + logger.warning("Topic client was not closed properly. Consider using method close().") async def create_topic( self, @@ -276,6 +284,35 @@ class TopicClientAsyncIO: return TopicWriterAsyncIO(self._driver, settings, _client=self) + def tx_writer( + self, + tx, + topic, + *, + producer_id: Optional[str] = None, # default - random + session_metadata: Mapping[str, str] = None, + partition_id: Union[int, None] = None, + auto_seqno: bool = True, + auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + # encoders: map[codec_code] func(encoded_bytes)->decoded_bytes + # the func will be called from multiply threads in parallel. + encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None, + # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. + # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. + encoder_executor: Optional[concurrent.futures.Executor] = None, + ) -> TopicTxWriterAsyncIO: + args = locals().copy() + del args["self"] + del args["tx"] + + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + + return TopicTxWriterAsyncIO(tx=tx, driver=self._driver, settings=settings, _client=self) + def close(self): if self._closed: return @@ -287,7 +324,7 @@ class TopicClientAsyncIO: if not self._closed: return - raise RuntimeError("Topic client closed") + raise issues.Error("Topic client closed") class TopicClient: @@ -310,7 +347,8 @@ class TopicClient: ) def __del__(self): - self.close() + if not self._closed: + logger.warning("Topic client was not closed properly. Consider using method close().") def create_topic( self, @@ -487,6 +525,36 @@ class TopicClient: return TopicWriter(self._driver, settings, _parent=self) + def tx_writer( + self, + tx, + topic, + *, + producer_id: Optional[str] = None, # default - random + session_metadata: Mapping[str, str] = None, + partition_id: Union[int, None] = None, + auto_seqno: bool = True, + auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + # encoders: map[codec_code] func(encoded_bytes)->decoded_bytes + # the func will be called from multiply threads in parallel. + encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None, + # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. + # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. + encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool + ) -> TopicWriter: + args = locals().copy() + del args["self"] + del args["tx"] + self._check_closed() + + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + + return TopicTxWriter(tx, self._driver, settings, _parent=self) + def close(self): if self._closed: return @@ -498,7 +566,7 @@ class TopicClient: if not self._closed: return - raise RuntimeError("Topic client closed") + raise issues.Error("Topic client closed") @dataclass diff --git a/contrib/python/ydb/py3/ydb/ydb_version.py b/contrib/python/ydb/py3/ydb/ydb_version.py index 8bd658d49e4..070a2455ef3 100644 --- a/contrib/python/ydb/py3/ydb/ydb_version.py +++ b/contrib/python/ydb/py3/ydb/ydb_version.py @@ -1 +1 @@ -VERSION = "3.19.3" +VERSION = "3.20.0" diff --git a/library/cpp/http/misc/httpreqdata.cpp b/library/cpp/http/misc/httpreqdata.cpp index ed5e8872c92..028418bc4a0 100644 --- a/library/cpp/http/misc/httpreqdata.cpp +++ b/library/cpp/http/misc/httpreqdata.cpp @@ -81,6 +81,8 @@ TStringBuf TBaseServerRequestData::Environment(TStringBuf key) const { return ip ? *ip : RemoteAddr(); } else if (ciKey == "QUERY_STRING") { return Query(); + } else if (ciKey == "BODY") { + return Body(); } else if (ciKey == "SERVER_NAME") { return ServerName(); } else if (ciKey == "SERVER_PORT") { diff --git a/library/cpp/http/misc/httpreqdata.h b/library/cpp/http/misc/httpreqdata.h index b5c9e446a99..48854923d12 100644 --- a/library/cpp/http/misc/httpreqdata.h +++ b/library/cpp/http/misc/httpreqdata.h @@ -7,6 +7,7 @@ #include <util/system/defaults.h> #include <util/string/cast.h> #include <library/cpp/cgiparam/cgiparam.h> +#include <util/memory/blob.h> #include <util/network/address.h> #include <util/network/socket.h> #include <util/generic/hash.h> @@ -52,6 +53,10 @@ public: return OrigQuery_; } + TStringBuf Body() const { + return Body_.AsStringBuf(); + } + void AppendQueryString(TStringBuf str); TStringBuf RemoteAddr() const; void SetRemoteAddr(TStringBuf addr); @@ -77,6 +82,10 @@ public: Socket_ = s; } + void SetBody(const TBlob& body) noexcept { + Body_ = body; + } + ui64 RequestBeginTime() const noexcept { return BeginTime_; } @@ -93,6 +102,7 @@ private: TString Path_; TStringBuf Query_; TStringBuf OrigQuery_; + TBlob Body_; THttpHeadersContainer HeadersIn_; SOCKET Socket_; ui64 BeginTime_; diff --git a/library/cpp/http/server/http_ex.cpp b/library/cpp/http/server/http_ex.cpp index 0681da10ff9..ead37da56d5 100644 --- a/library/cpp/http/server/http_ex.cpp +++ b/library/cpp/http/server/http_ex.cpp @@ -93,6 +93,7 @@ bool THttpClientRequestExtension::ProcessHeaders(TBaseServerRequestData& rd, TBl } else { postData = TBlob::FromStream(Input()); } + rd.SetBody(postData); } catch (...) { Output() << "HTTP/1.1 400 Bad request\r\n\r\n"; return false; diff --git a/library/cpp/tld/tlds-alpha-by-domain.txt b/library/cpp/tld/tlds-alpha-by-domain.txt index edda68377ff..d55b476c701 100644 --- a/library/cpp/tld/tlds-alpha-by-domain.txt +++ b/library/cpp/tld/tlds-alpha-by-domain.txt @@ -1,4 +1,4 @@ -# Version 2025032800, Last Updated Fri Mar 28 07:07:01 2025 UTC +# Version 2025033100, Last Updated Mon Mar 31 07:07:02 2025 UTC AAA AARP ABB diff --git a/yql/essentials/core/expr_nodes/yql_expr_nodes.json b/yql/essentials/core/expr_nodes/yql_expr_nodes.json index 7eb01b6114c..c051517eee9 100644 --- a/yql/essentials/core/expr_nodes/yql_expr_nodes.json +++ b/yql/essentials/core/expr_nodes/yql_expr_nodes.json @@ -2557,6 +2557,11 @@ "Match": {"Type": "Callable", "Name": "WideFromBlocks"} }, { + "Name": "TCoListFromBlocks", + "Base": "TCoInputBase", + "Match": {"Type": "Callable", "Name": "ListFromBlocks"} + }, + { "Name": "TCoReplicateScalars", "Base": "TCoInputBase", "Match": {"Type": "Callable", "Name": "ReplicateScalars"}, @@ -2570,6 +2575,11 @@ "Match": {"Type": "Callable", "Name": "WideToBlocks"} }, { + "Name": "TCoListToBlocks", + "Base": "TCoInputBase", + "Match": {"Type": "Callable", "Name": "ListToBlocks"} + }, + { "Name": "TCoPgSelect", "Base": "TCallable", "Match": {"Type": "Callable", "Name": "PgSelect"}, diff --git a/yql/essentials/core/histogram/eq_width_histogram.cpp b/yql/essentials/core/histogram/eq_width_histogram.cpp new file mode 100644 index 00000000000..3c5a452fdbd --- /dev/null +++ b/yql/essentials/core/histogram/eq_width_histogram.cpp @@ -0,0 +1,73 @@ +#include "eq_width_histogram.h" + +namespace NKikimr { + +TEqWidthHistogram::TEqWidthHistogram(ui32 numBuckets, EHistogramValueType valueType) + : ValueType(valueType), Buckets(numBuckets) { + // Exptected at least one bucket for histogram. + Y_ASSERT(numBuckets >= 1); +} + +TEqWidthHistogram::TEqWidthHistogram(const char *str, ui64 size) { + Y_ASSERT(str && size); + const ui32 numBuckets = *reinterpret_cast<const ui32 *>(str); + Y_ABORT_UNLESS(GetBinarySize(numBuckets) == size); + ui32 offset = sizeof(ui32); + ValueType = *reinterpret_cast<const EHistogramValueType *>(str + offset); + offset += sizeof(EHistogramValueType); + Buckets = TVector<TBucket>(numBuckets); + for (ui32 i = 0; i < numBuckets; ++i) { + std::memcpy(&Buckets[i], reinterpret_cast<const char *>(str + offset), sizeof(TBucket)); + offset += sizeof(TBucket); + } +} + +ui64 TEqWidthHistogram::GetBinarySize(ui32 nBuckets) const { + return sizeof(ui32) + sizeof(EHistogramValueType) + sizeof(TBucket) * nBuckets; +} + +// Binary layout: +// [4 byte: number of buckets][1 byte: value type] +// [sizeof(Bucket)[0]... sizeof(Bucket)[n]]. +std::unique_ptr<char> TEqWidthHistogram::Serialize(ui64 &binarySize) const { + binarySize = GetBinarySize(GetNumBuckets()); + std::unique_ptr<char> binaryData(new char[binarySize]); + ui32 offset = 0; + const ui32 numBuckets = GetNumBuckets(); + // 4 byte - number of buckets. + std::memcpy(binaryData.get(), &numBuckets, sizeof(ui32)); + offset += sizeof(ui32); + // 1 byte - values type. + std::memcpy(binaryData.get() + offset, &ValueType, sizeof(EHistogramValueType)); + offset += sizeof(EHistogramValueType); + // Buckets. + for (ui32 i = 0; i < numBuckets; ++i) { + std::memcpy(binaryData.get() + offset, &Buckets[i], sizeof(TBucket)); + offset += sizeof(TBucket); + } + return binaryData; +} + +TEqWidthHistogramEstimator::TEqWidthHistogramEstimator(std::shared_ptr<TEqWidthHistogram> histogram) + : Histogram(histogram) { + const auto numBuckets = Histogram->GetNumBuckets(); + PrefixSum = TVector<ui64>(numBuckets); + SuffixSum = TVector<ui64>(numBuckets); + CreatePrefixSum(numBuckets); + CreateSuffixSum(numBuckets); +} + +void TEqWidthHistogramEstimator::CreatePrefixSum(ui32 numBuckets) { + PrefixSum[0] = Histogram->GetNumElementsInBucket(0); + for (ui32 i = 1; i < numBuckets; ++i) { + PrefixSum[i] = PrefixSum[i - 1] + Histogram->GetNumElementsInBucket(i); + } +} + +void TEqWidthHistogramEstimator::CreateSuffixSum(ui32 numBuckets) { + SuffixSum[numBuckets - 1] = Histogram->GetNumElementsInBucket(numBuckets - 1); + for (i32 i = static_cast<i32>(numBuckets) - 2; i >= 0; --i) { + SuffixSum[i] = SuffixSum[i + 1] + Histogram->GetNumElementsInBucket(i); + } +} +} // namespace NKikimr diff --git a/yql/essentials/core/histogram/eq_width_histogram.h b/yql/essentials/core/histogram/eq_width_histogram.h new file mode 100644 index 00000000000..97c660af76b --- /dev/null +++ b/yql/essentials/core/histogram/eq_width_histogram.h @@ -0,0 +1,228 @@ +#pragma once + +#include <util/generic/strbuf.h> +#include <util/generic/vector.h> +#include <util/stream/output.h> +#include <util/system/types.h> +#include <cmath> + +namespace NKikimr { + + // Helper functions to work with histogram values. +template <typename T> +inline T LoadFrom(const ui8 *storage) { + T val; + std::memcpy(&val, storage, sizeof(T)); + return val; +} +template <typename T> +inline void StoreTo(ui8 *storage, T value) { + std::memcpy(storage, &value, sizeof(T)); +} +template <typename T> +inline bool CmpEqual(T left, T right) { + return left == right; +} +template <> +inline bool CmpEqual(double left, double right) { + return std::fabs(left - right) < std::numeric_limits<double>::epsilon(); +} +template <typename T> +inline bool CmpLess(T left, T right) { + return left < right; +} + +// Represents value types supported by histogram. +enum class EHistogramValueType : ui8 { Int16, Int32, Int64, Uint16, Uint32, Uint64, Double, NotSupported }; + +// Bucket storage size for Equal width histogram. +constexpr const ui32 EqWidthHistogramBucketStorageSize = 8; + +// This class represents an `Equal-width` histogram. +// Each bucket represents a range of contiguous values of equal width, and the +// aggregate summary stored in the bucket is the number of rows whose value lies +// within that range. +class TEqWidthHistogram { + public: +#pragma pack(push, 1) + struct TBucket { + // The number of values in a bucket. + ui64 Count{0}; + // The `start` value of a bucket, the `end` of the bucket is a next start. + // [start = start[i], end = start[i + 1]) + ui8 Start[EqWidthHistogramBucketStorageSize]; + }; + struct TBucketRange { + ui8 Start[EqWidthHistogramBucketStorageSize]; + ui8 End[EqWidthHistogramBucketStorageSize]; + }; +#pragma pack(pop) + + // Have to specify the number of buckets and type of the values. + TEqWidthHistogram(ui32 numBuckets = 1, EHistogramValueType type = EHistogramValueType::Int32); + // From serialized data. + TEqWidthHistogram(const char *str, ui64 size); + + // Adds the given `val` to a histogram. + template <typename T> + void AddElement(T val) { + const auto index = FindBucketIndex(val); + // The given `index` in range [0, numBuckets - 1]. + const T bucketValue = LoadFrom<T>(Buckets[index].Start); + if (!index || ((CmpEqual<T>(bucketValue, val) || CmpLess<T>(bucketValue, val)))) { + Buckets[index].Count++; + } else { + Buckets[index - 1].Count++; + } + } + + // Returns an index of the bucket which stores the given `val`. + // Returned index in range [0, numBuckets - 1]. + // Not using `std::lower_bound()` here because need an index to map to `suffix` and `prefix` sum. + template <typename T> + ui32 FindBucketIndex(T val) const { + ui32 start = 0; + ui32 end = GetNumBuckets() - 1; + while (start < end) { + auto it = start + (end - start) / 2; + if (CmpLess<T>(LoadFrom<T>(Buckets[it].Start), val)) { + start = it + 1; + } else { + end = it; + } + } + return start; + } + + // Returns a number of buckets in a histogram. + ui32 GetNumBuckets() const { return Buckets.size(); } + + template <typename T> + ui32 GetBucketWidth() const { + Y_ASSERT(GetNumBuckets()); + if (GetNumBuckets() == 1) { + return std::max(static_cast<ui32>(LoadFrom<T>(Buckets.front().Start)), 1U); + } else { + return std::max(static_cast<ui32>(LoadFrom<T>(Buckets[1].Start) - LoadFrom<T>(Buckets[0].Start)), 1U); + } + } + + template <> + ui32 GetBucketWidth<double>() const { + return 1; + } + + // Returns histogram type. + EHistogramValueType GetType() const { return ValueType; } + // Returns a number of elements in a bucket by the given `index`. + ui64 GetNumElementsInBucket(ui32 index) const { return Buckets[index].Count; } + + // Initializes buckets with a given `range`. + template <typename T> + void InitializeBuckets(const TBucketRange &range) { + Y_ASSERT(CmpLess<T>(LoadFrom<T>(range.Start), LoadFrom<T>(range.End))); + T rangeLen = LoadFrom<T>(range.End) - LoadFrom<T>(range.Start); + std::memcpy(Buckets[0].Start, range.Start, sizeof(range.Start)); + for (ui32 i = 1; i < GetNumBuckets(); ++i) { + const T prevStart = LoadFrom<T>(Buckets[i - 1].Start); + StoreTo<T>(Buckets[i].Start, prevStart + rangeLen); + } + } + + // Seriailizes to a binary representation + std::unique_ptr<char> Serialize(ui64 &binSize) const; + // Returns buckets. + const TVector<TBucket> &GetBuckets() const { return Buckets; } + + template <typename T> + void Aggregate(const TEqWidthHistogram &other) { + if ((this->ValueType != other.GetType()) || (!BucketsEqual<T>(other))) { + // Should we fail? + return; + } + for (ui32 i = 0; i < Buckets.size(); ++i) { + Buckets[i].Count += other.GetBuckets()[i].Count; + } + } + + private: + template <typename T> + bool BucketsEqual(const TEqWidthHistogram &other) { + if (Buckets.size() != other.GetNumBuckets()) { + return false; + } + for (ui32 i = 0; i < Buckets.size(); ++i) { + if (!CmpEqual<T>(LoadFrom<T>(Buckets[i].Start), LoadFrom<T>(GetBuckets()[i].Start))) { + return false; + } + } + return true; + } + + // Returns binary size of the histogram. + ui64 GetBinarySize(ui32 nBuckets) const; + EHistogramValueType ValueType; + TVector<TBucket> Buckets; +}; + +// This class represents a machinery to estimate a value in a histogram. +class TEqWidthHistogramEstimator { + public: + TEqWidthHistogramEstimator(std::shared_ptr<TEqWidthHistogram> histogram); + + // Methods to estimate values. + template <typename T> + ui64 EstimateLessOrEqual(T val) const { + return EstimateOrEqual<T>(val, PrefixSum); + } + + template <typename T> + ui64 EstimateGreaterOrEqual(T val) const { + return EstimateOrEqual<T>(val, SuffixSum); + } + + template <typename T> + ui64 EstimateLess(T val) const { + return EstimateNotEqual<T>(val, PrefixSum); + } + + template <typename T> + ui64 EstimateGreater(T val) const { + return EstimateNotEqual<T>(val, SuffixSum); + } + + template <typename T> + ui64 EstimateEqual(T val) const { + const auto index = Histogram->FindBucketIndex(val); + // Assuming uniform distribution. + return std::max(1U, static_cast<ui32>(Histogram->GetNumElementsInBucket(index) / Histogram->template GetBucketWidth<T>())); + } + + // Returns the total number elements in histogram. + // Could be used to adjust scale. + ui64 GetNumElements() const { return PrefixSum.back(); } + + private: + template <typename T> + ui64 EstimateOrEqual(T val, const TVector<ui64> &sumArray) const { + const auto index = Histogram->FindBucketIndex(val); + return sumArray[index]; + } + + template <typename T> + ui64 EstimateNotEqual(T val, const TVector<ui64> &sumArray) const { + const auto index = Histogram->FindBucketIndex(val); + // Take the previous backet if it's not the first one. + if (!index) { + return sumArray[index]; + } + return sumArray[index - 1]; + } + + void CreatePrefixSum(ui32 numBuckets); + void CreateSuffixSum(ui32 numBuckets); + std::shared_ptr<TEqWidthHistogram> Histogram; + TVector<ui64> PrefixSum; + TVector<ui64> SuffixSum; +}; +} // namespace NKikimr diff --git a/yql/essentials/core/histogram/ut/eq_width_histogram_ut.cpp b/yql/essentials/core/histogram/ut/eq_width_histogram_ut.cpp new file mode 100644 index 00000000000..9c1b1d969fe --- /dev/null +++ b/yql/essentials/core/histogram/ut/eq_width_histogram_ut.cpp @@ -0,0 +1,127 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "eq_width_histogram.h" + +namespace NKikimr { + +template <typename T> +bool EqualHistograms(const std::shared_ptr<TEqWidthHistogram> &left, const std::shared_ptr<TEqWidthHistogram> &right) { + // Not expecting any nullptr. + if (!left || !right) return false; + + if (left->GetNumBuckets() != right->GetNumBuckets()) { + return false; + } + if (left->GetType() != right->GetType()) { + return false; + } + + for (ui32 i = 0; i < left->GetNumBuckets(); ++i) { + const auto &leftBucket = left->GetBuckets()[i]; + const auto &rightBucket = right->GetBuckets()[i]; + if (leftBucket.Count != rightBucket.Count) { + return false; + } + if (!CmpEqual<T>(LoadFrom<T>(leftBucket.Start), LoadFrom<T>(rightBucket.Start))) { + return false; + } + } + + return true; +} + +template <typename T> +std::shared_ptr<TEqWidthHistogram> CreateHistogram(ui32 numBuckets, T start, T range, EHistogramValueType valueType) { + std::shared_ptr<TEqWidthHistogram> histogram(std::make_shared<TEqWidthHistogram>(numBuckets, valueType)); + TEqWidthHistogram::TBucketRange bucketRange; + StoreTo<T>(bucketRange.Start, start); + StoreTo<T>(bucketRange.End, range); + histogram->InitializeBuckets<T>(bucketRange); + return histogram; +} + +template <typename T> +void PopulateHistogram(std::shared_ptr<TEqWidthHistogram> histogram, const std::pair<ui32, ui32> &range) { + for (ui32 i = range.first; i < range.second; ++i) { + histogram->AddElement<T>(i); + } +} + +template <typename T> +void TestHistogramBasic(ui32 numBuckets, std::pair<ui32, ui32> range, std::pair<T, T> bucketRange, + EHistogramValueType valueType, std::pair<T, ui64> less, std::pair<T, ui64> greater) { + auto histogram = CreateHistogram<T>(numBuckets, bucketRange.first, bucketRange.second, valueType); + UNIT_ASSERT_VALUES_EQUAL(histogram->GetNumBuckets(), numBuckets); + PopulateHistogram<T>(histogram, range); + TEqWidthHistogramEstimator estimator(histogram); + UNIT_ASSERT_VALUES_EQUAL(estimator.EstimateLessOrEqual<T>(less.first), less.second); + UNIT_ASSERT_VALUES_EQUAL(estimator.EstimateGreaterOrEqual<T>(greater.first), greater.second); +} + +template <typename T> +void TestHistogramSerialization(ui32 numBuckets, std::pair<ui32, ui32> range, std::pair<T, T> bucketRange, + EHistogramValueType valueType) { + auto histogram = CreateHistogram<T>(numBuckets, bucketRange.first, bucketRange.second, valueType); + UNIT_ASSERT(histogram); + PopulateHistogram<T>(histogram, range); + ui64 binarySize = 0; + auto binaryData = histogram->Serialize(binarySize); + UNIT_ASSERT(binaryData && binarySize); + TString hString(binaryData.get(), binarySize); + auto histogramFromString = std::make_shared<TEqWidthHistogram>(hString.data(), hString.size()); + UNIT_ASSERT(histogramFromString); + UNIT_ASSERT(EqualHistograms<T>(histogram, histogramFromString)); +} + +template <typename T> +void TestHistogramAggregate(ui32 numBuckets, std::pair<ui32, ui32> range, std::pair<T, T> bucketRange, + EHistogramValueType valueType, ui32 numCombine, const TVector<ui64> &resultCount) { + auto histogram = CreateHistogram<T>(numBuckets, bucketRange.first, bucketRange.second, valueType); + UNIT_ASSERT(histogram); + PopulateHistogram<T>(histogram, range); + auto histogramToAdd = CreateHistogram<T>(numBuckets, bucketRange.first, bucketRange.second, valueType); + PopulateHistogram<T>(histogramToAdd, range); + UNIT_ASSERT(histogram); + for (ui32 i = 0; i < numCombine; ++i) histogram->template Aggregate<T>(*histogramToAdd); + for (ui32 i = 0; i < histogram->GetNumBuckets(); ++i) { + UNIT_ASSERT(histogram->GetBuckets()[i].Count == resultCount[i]); + } +} + +Y_UNIT_TEST_SUITE(EqWidthHistogram) { + Y_UNIT_TEST(Basic) { + TestHistogramBasic<ui32>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, EHistogramValueType::Uint32, + /*{value, result}=*/{9, 10}, + /*{value, result}=*/{10, 0}); + TestHistogramBasic<ui64>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, EHistogramValueType::Uint64, + /*{value, result}=*/{9, 10}, + /*{value, result}=*/{10, 0}); + TestHistogramBasic<i32>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, EHistogramValueType::Int32, + /*{value, result}=*/{9, 10}, + /*{value, result}=*/{10, 0}); + TestHistogramBasic<i64>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, EHistogramValueType::Int64, + /*{value, result}=*/{9, 10}, + /*{value, result}=*/{10, 0}); + TestHistogramBasic<double>(10, /*values range=*/{0.0, 10.0}, /*bucket range=*/{0.0, 2.0}, + EHistogramValueType::Double, + /*{value, result}=*/{9.0, 10}, + /*{value, result}=*/{10.0, 0}); + } + + Y_UNIT_TEST(Serialization) { + TestHistogramSerialization<ui32>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, + EHistogramValueType::Uint32); + TestHistogramSerialization<ui64>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, + EHistogramValueType::Uint64); + TestHistogramSerialization<i32>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, EHistogramValueType::Int32); + TestHistogramSerialization<i64>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, EHistogramValueType::Int64); + TestHistogramSerialization<double>(10, /*values range=*/{0.0, 10.0}, /*bucket range=*/{0.0, 2.0}, + EHistogramValueType::Double); + } + Y_UNIT_TEST(AggregateHistogram) { + TVector<ui64> resultCount{20, 20, 20, 20, 20, 0, 0, 0, 0, 0}; + TestHistogramAggregate<ui32>(10, /*values range=*/{0, 10}, /*bucket range=*/{0, 2}, EHistogramValueType::Uint32, 9, + resultCount); + } +} +} // namespace NKikimr diff --git a/yql/essentials/core/histogram/ut/ya.make b/yql/essentials/core/histogram/ut/ya.make new file mode 100644 index 00000000000..17e420ad072 --- /dev/null +++ b/yql/essentials/core/histogram/ut/ya.make @@ -0,0 +1,8 @@ +UNITTEST_FOR(yql/essentials/core/histogram) + +SIZE(MEDIUM) +SRCS( + eq_width_histogram_ut.cpp +) + +END() diff --git a/yql/essentials/core/histogram/ya.make b/yql/essentials/core/histogram/ya.make new file mode 100644 index 00000000000..bcc309c3798 --- /dev/null +++ b/yql/essentials/core/histogram/ya.make @@ -0,0 +1,12 @@ +LIBRARY() + +SRCS( + eq_width_histogram.h + eq_width_histogram.cpp +) + +END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/core/minsketch/ut/ya.make b/yql/essentials/core/minsketch/ut/ya.make index 30f60c7229a..d951550facc 100644 --- a/yql/essentials/core/minsketch/ut/ya.make +++ b/yql/essentials/core/minsketch/ut/ya.make @@ -1,15 +1,6 @@ UNITTEST_FOR(yql/essentials/core/minsketch) -FORK_SUBTESTS() -IF (WITH_VALGRIND) - SPLIT_FACTOR(30) - TIMEOUT(1200) - SIZE(LARGE) - TAG(ya:fat) -ELSE() - TIMEOUT(600) - SIZE(MEDIUM) -ENDIF() +SIZE(MEDIUM) SRCS( count_min_sketch_ut.cpp diff --git a/yql/essentials/core/type_ann/type_ann_blocks.cpp b/yql/essentials/core/type_ann/type_ann_blocks.cpp index c674d61ed5c..20f09c46961 100644 --- a/yql/essentials/core/type_ann/type_ann_blocks.cpp +++ b/yql/essentials/core/type_ann/type_ann_blocks.cpp @@ -906,6 +906,43 @@ IGraphTransformer::TStatus WideToBlocksWrapper(const TExprNode::TPtr& input, TEx return IGraphTransformer::TStatus::Ok; } +IGraphTransformer::TStatus ListToBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 1U, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureListType(input->Head(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + const auto listItemType = input->Head().GetTypeAnn()->Cast<TListExprType>()->GetItemType(); + if (!EnsureStructType(input->Head().Pos(), *listItemType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + const auto structType = listItemType->Cast<TStructExprType>(); + + TVector<const TItemExprType*> outputStructItems; + for (auto item : structType->GetItems()) { + auto itemType = item->GetItemType(); + if (itemType->IsBlockOrScalar()) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Input type should not be a block or scalar")); + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureSupportedAsBlockType(input->Pos(), *itemType, ctx.Expr, ctx.Types)) { + return IGraphTransformer::TStatus::Error; + } + + outputStructItems.push_back(ctx.Expr.MakeType<TItemExprType>(item->GetName(), ctx.Expr.MakeType<TBlockExprType>(itemType))); + } + outputStructItems.push_back(ctx.Expr.MakeType<TItemExprType>(BlockLengthColumnName, ctx.Expr.MakeType<TScalarExprType>(ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64)))); + + auto outputStructType = ctx.Expr.MakeType<TStructExprType>(outputStructItems); + input->SetTypeAnn(ctx.Expr.MakeType<TListExprType>(outputStructType)); + return IGraphTransformer::TStatus::Ok; +} + IGraphTransformer::TStatus WideFromBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); if (!EnsureArgsCount(*input, 1U, ctx.Expr)) { @@ -924,6 +961,22 @@ IGraphTransformer::TStatus WideFromBlocksWrapper(const TExprNode::TPtr& input, T return IGraphTransformer::TStatus::Ok; } +IGraphTransformer::TStatus ListFromBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 1U, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + TVector<const TItemExprType*> outputStructItems; + if (!EnsureBlockListType(input->Head(), outputStructItems, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto outputStructType = ctx.Expr.MakeType<TStructExprType>(outputStructItems); + input->SetTypeAnn(ctx.Expr.MakeType<TListExprType>(outputStructType)); + return IGraphTransformer::TStatus::Ok; +} + IGraphTransformer::TStatus WideSkipTakeBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { if (!EnsureArgsCount(*input, 2U, ctx.Expr)) { return IGraphTransformer::TStatus::Error; diff --git a/yql/essentials/core/type_ann/type_ann_blocks.h b/yql/essentials/core/type_ann/type_ann_blocks.h index 8a16376f9b7..223758ef7fc 100644 --- a/yql/essentials/core/type_ann/type_ann_blocks.h +++ b/yql/essentials/core/type_ann/type_ann_blocks.h @@ -28,7 +28,9 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus BlockCombineHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus BlockMergeFinalizeHashedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus WideToBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); + IGraphTransformer::TStatus ListToBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx); IGraphTransformer::TStatus WideFromBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus ListFromBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideSkipTakeBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideTopBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideSortBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); diff --git a/yql/essentials/core/type_ann/type_ann_core.cpp b/yql/essentials/core/type_ann/type_ann_core.cpp index 75e07ab3a72..cd6bd71608b 100644 --- a/yql/essentials/core/type_ann/type_ann_core.cpp +++ b/yql/essentials/core/type_ann/type_ann_core.cpp @@ -12983,6 +12983,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["NarrowMultiMap"] = &NarrowMultiMapWrapper; Functions["WideFromBlocks"] = &WideFromBlocksWrapper; + Functions["ListFromBlocks"] = &ListFromBlocksWrapper; Functions["WideSkipBlocks"] = &WideSkipTakeBlocksWrapper; Functions["WideTakeBlocks"] = &WideSkipTakeBlocksWrapper; Functions["BlockCompress"] = &BlockCompressWrapper; @@ -13018,6 +13019,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> ExtFunctions["AsScalar"] = &AsScalarWrapper; ExtFunctions["WideToBlocks"] = &WideToBlocksWrapper; + ExtFunctions["ListToBlocks"] = &ListToBlocksWrapper; ExtFunctions["BlockCombineAll"] = &BlockCombineAllWrapper; ExtFunctions["BlockCombineHashed"] = &BlockCombineHashedWrapper; ExtFunctions["BlockMergeFinalizeHashed"] = &BlockMergeFinalizeHashedWrapper; diff --git a/yql/essentials/core/ya.make b/yql/essentials/core/ya.make index d1142fe083b..8626132b6f4 100644 --- a/yql/essentials/core/ya.make +++ b/yql/essentials/core/ya.make @@ -71,6 +71,7 @@ PEERDIR( yql/essentials/minikql yql/essentials/minikql/jsonpath/parser yql/essentials/core/minsketch + yql/essentials/core/histogram yql/essentials/protos yql/essentials/public/udf yql/essentials/public/udf/tz diff --git a/yql/essentials/core/yql_expr_constraint.cpp b/yql/essentials/core/yql_expr_constraint.cpp index 32cba8f9d28..fdfc21946be 100644 --- a/yql/essentials/core/yql_expr_constraint.cpp +++ b/yql/essentials/core/yql_expr_constraint.cpp @@ -245,7 +245,9 @@ public: Functions["WideTopSortBlocks"] = &TCallableConstraintTransformer::WideTopWrap<true>; Functions["WideSortBlocks"] = &TCallableConstraintTransformer::WideTopWrap<true>; Functions["WideToBlocks"] = &TCallableConstraintTransformer::CopyAllFrom<0>; + Functions["ListToBlocks"] = &TCallableConstraintTransformer::CopyAllFrom<0>; Functions["WideFromBlocks"] = &TCallableConstraintTransformer::CopyAllFrom<0>; + Functions["ListFromBlocks"] = &TCallableConstraintTransformer::CopyAllFrom<0>; Functions["ReplicateScalars"] = &TCallableConstraintTransformer::CopyAllFrom<0>; Functions["BlockMergeFinalizeHashed"] = &TCallableConstraintTransformer::AggregateWrap<true>; Functions["BlockMergeManyFinalizeHashed"] = &TCallableConstraintTransformer::AggregateWrap<true>; diff --git a/yql/essentials/core/yql_expr_type_annotation.cpp b/yql/essentials/core/yql_expr_type_annotation.cpp index 80207d35bc2..16e29a60c8e 100644 --- a/yql/essentials/core/yql_expr_type_annotation.cpp +++ b/yql/essentials/core/yql_expr_type_annotation.cpp @@ -10,6 +10,7 @@ #include <yql/essentials/minikql/dom/json.h> #include <yql/essentials/minikql/dom/yson.h> #include <yql/essentials/minikql/jsonpath/parser/parser.h> +#include <yql/essentials/core/sql_types/block.h> #include <yql/essentials/core/sql_types/simple_types.h> #include "yql/essentials/parser/pg_catalog/catalog.h" #include <yql/essentials/parser/pg_wrapper/interface/utils.h> @@ -3269,6 +3270,52 @@ bool EnsureWideBlockType(TPositionHandle position, const TTypeAnnotationNode& ty return true; } +bool EnsureBlockStructType(TPositionHandle position, const TTypeAnnotationNode& type, TVector<const TItemExprType*>& structItems, TExprContext& ctx) { + if (HasError(&type, ctx)) { + return false; + } + + if (type.GetKind() != ETypeAnnotationKind::Struct) { + ctx.AddError(TIssue(ctx.GetPosition(position), TStringBuilder() << "Expected struct, but got: " << type)); + return false; + } + + auto& items = type.Cast<TStructExprType>()->GetItems(); + if (items.empty()) { + ctx.AddError(TIssue(ctx.GetPosition(position), "Expected at least one column")); + return false; + } + + bool hasBlockLengthColumn = false; + for (auto item : items) { + auto blockType = item->GetItemType(); + if (!EnsureBlockOrScalarType(position, *blockType, ctx)) { + return false; + } + + bool isScalar = false; + auto itemType = GetBlockItemType(*blockType, isScalar); + + if (item->GetName() == BlockLengthColumnName) { + if (!isScalar) { + ctx.AddError(TIssue(ctx.GetPosition(position), "Block length column should be a scalar")); + return false; + } + if (!EnsureSpecificDataType(position, *itemType, EDataSlot::Uint64, ctx)) { + return false; + } + hasBlockLengthColumn = true; + } else { + structItems.push_back(ctx.MakeType<TItemExprType>(item->GetName(), itemType)); + } + } + if (!hasBlockLengthColumn) { + ctx.AddError(TIssue(ctx.GetPosition(position), "Block struct must contain block length column")); + return false; + } + return true; +} + bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowScalar) { if (!EnsureWideFlowType(node, ctx)) { return false; @@ -3285,6 +3332,14 @@ bool EnsureWideStreamBlockType(const TExprNode& node, TTypeAnnotationNode::TList return EnsureWideBlockType(node.Pos(), *node.GetTypeAnn()->Cast<TStreamExprType>()->GetItemType(), blockItemTypes, ctx, allowScalar); } +bool EnsureBlockListType(const TExprNode& node, TVector<const TItemExprType*>& structItems, TExprContext& ctx) { + if (!EnsureListType(node, ctx)) { + return false; + } + + return EnsureBlockStructType(node.Pos(), *node.GetTypeAnn()->Cast<TListExprType>()->GetItemType(), structItems, ctx); +} + bool EnsureOptionalType(const TExprNode& node, TExprContext& ctx) { if (!node.GetTypeAnn()) { YQL_ENSURE(node.Type() == TExprNode::Lambda); diff --git a/yql/essentials/core/yql_expr_type_annotation.h b/yql/essentials/core/yql_expr_type_annotation.h index b429f46f394..8538c45ab18 100644 --- a/yql/essentials/core/yql_expr_type_annotation.h +++ b/yql/essentials/core/yql_expr_type_annotation.h @@ -134,8 +134,10 @@ bool IsWideSequenceBlockType(const TTypeAnnotationNode& type); bool IsSupportedAsBlockType(TPositionHandle pos, const TTypeAnnotationNode& type, TExprContext& ctx, TTypeAnnotationContext& types, bool reportUnspported = false); bool EnsureSupportedAsBlockType(TPositionHandle pos, const TTypeAnnotationNode& type, TExprContext& ctx, TTypeAnnotationContext& types); bool EnsureWideBlockType(TPositionHandle position, const TTypeAnnotationNode& type, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowScalar = true); +bool EnsureBlockStructType(TPositionHandle position, const TTypeAnnotationNode& type, TVector<const TItemExprType*>& structItems, TExprContext& ctx); bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowScalar = true); bool EnsureWideStreamBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx, bool allowScalar = true); +bool EnsureBlockListType(const TExprNode& node, TVector<const TItemExprType*>& structItems, TExprContext& ctx); bool EnsureOptionalType(const TExprNode& node, TExprContext& ctx); bool EnsureOptionalType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx); bool EnsureType(const TExprNode& node, TExprContext& ctx); diff --git a/yql/essentials/core/yql_statistics.cpp b/yql/essentials/core/yql_statistics.cpp index f0958b8976e..cf70ad2cf55 100644 --- a/yql/essentials/core/yql_statistics.cpp +++ b/yql/essentials/core/yql_statistics.cpp @@ -189,6 +189,16 @@ std::shared_ptr<TOptimizerStatistics> NYql::OverrideStatistics(const NYql::TOpti Base64StrictDecode(countMinBase64, countMinRaw); cStat.CountMinSketch.reset(NKikimr::TCountMinSketch::FromString(countMinRaw.data(), countMinRaw.size())); } + if (auto eqWidthHistogram = colMap.find("histogram"); eqWidthHistogram != colMap.end()) { + TString histogramBase64 = eqWidthHistogram->second.GetStringSafe(); + + TString histogramBinary{}; + Base64StrictDecode(histogramBase64, histogramBinary); + auto histogram = std::make_shared<NKikimr::TEqWidthHistogram>( + histogramBinary.data(), histogramBinary.size()); + cStat.EqWidthHistogramEstimator = + std::make_shared<NKikimr::TEqWidthHistogramEstimator>(histogram); + } res->ColumnStatistics->Data[columnName] = cStat; } diff --git a/yql/essentials/core/yql_statistics.h b/yql/essentials/core/yql_statistics.h index f3875138f30..e96e8b9f50a 100644 --- a/yql/essentials/core/yql_statistics.h +++ b/yql/essentials/core/yql_statistics.h @@ -2,6 +2,7 @@ #include "yql_cost_function.h" #include <yql/essentials/core/minsketch/count_min_sketch.h> +#include <yql/essentials/core/histogram/eq_width_histogram.h> #include <library/cpp/json/json_reader.h> @@ -36,6 +37,7 @@ struct TColumnStatistics { std::optional<double> NumUniqueVals; std::optional<double> HyperLogLog; std::shared_ptr<NKikimr::TCountMinSketch> CountMinSketch; + std::shared_ptr<NKikimr::TEqWidthHistogramEstimator> EqWidthHistogramEstimator; TString Type; TColumnStatistics() {} diff --git a/yql/essentials/core/yql_type_annotation.cpp b/yql/essentials/core/yql_type_annotation.cpp index 8934949d3fe..6de6ecf8bdf 100644 --- a/yql/essentials/core/yql_type_annotation.cpp +++ b/yql/essentials/core/yql_type_annotation.cpp @@ -303,7 +303,7 @@ IGraphTransformer::TStatus TTypeAnnotationContext::SetColumnOrder(const TExprNod allColumns.erase(it); } - if (!allColumns.empty()) { + if (!allColumns.empty() && !(allColumns.size() == 1 && *allColumns.begin() == BlockLengthColumnName)) { ctx.AddError(TIssue(ctx.GetPosition(node.Pos()), TStringBuilder() << "Some columns are left unordered with column order " << FormatColumnOrder(columnOrder) << " for node " << node.Content() << " with type: " << *node.GetTypeAnn())); diff --git a/yql/essentials/minikql/comp_nodes/mkql_unwrap.cpp b/yql/essentials/minikql/comp_nodes/mkql_unwrap.cpp index 307609aa082..75d0f4ba581 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_unwrap.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_unwrap.cpp @@ -20,6 +20,10 @@ public: } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { + return DoCalculateImpl(compCtx).Release(); + } + + NUdf::TUnboxedValue DoCalculateImpl(TComputationContext& compCtx) const { auto value = Optional()->GetValue(compCtx); if (value) { return value.GetOptionalValue(); diff --git a/yql/essentials/minikql/comp_nodes/mkql_weakmember.cpp b/yql/essentials/minikql/comp_nodes/mkql_weakmember.cpp index 2d535459104..1c058b21a36 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_weakmember.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_weakmember.cpp @@ -26,6 +26,11 @@ public: } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + auto result = DoCalculateImpl(ctx); + return result.Release(); + } + + NUdf::TUnboxedValue DoCalculateImpl(TComputationContext& ctx) const { if (const auto& restDict = RestDict->GetValue(ctx)) { if (const auto& tryMember = restDict.Lookup(MemberName)) { return SimpleValueFromYson(SchemeType, tryMember.AsStringRef()); @@ -46,7 +51,7 @@ public: stringStream.DoWrite(ref.Data(), size); return stringStream.Value(); } else if (SchemeType == NUdf::EDataSlot::String) { - return tryMember.Release(); + return tryMember; } else { return {}; } diff --git a/yql/essentials/minikql/computation/mkql_method_address_helper.h b/yql/essentials/minikql/computation/mkql_method_address_helper.h index 058e4098976..8935a74d529 100644 --- a/yql/essentials/minikql/computation/mkql_method_address_helper.h +++ b/yql/essentials/minikql/computation/mkql_method_address_helper.h @@ -1,3 +1,5 @@ +#pragma once + #include <yql/essentials/public/udf/udf_value.h> #if defined(_msan_enabled_) && defined(__linux__) diff --git a/yql/essentials/minikql/mkql_alloc.cpp b/yql/essentials/minikql/mkql_alloc.cpp index 299e691f168..53e73d696ab 100644 --- a/yql/essentials/minikql/mkql_alloc.cpp +++ b/yql/essentials/minikql/mkql_alloc.cpp @@ -27,9 +27,15 @@ void TAllocState::TListEntry::Unlink() noexcept { TAllocState::TAllocState(const TSourceLocation& location, const NKikimr::TAlignedPagePoolCounters &counters, bool supportsSizedAllocators) : TAlignedPagePool(location, counters) +#ifndef NDEBUG + , DefaultMemInfo(MakeIntrusive<TMemoryUsageInfo>("default")) +#endif , SupportsSizedAllocators(supportsSizedAllocators) , CurrentPAllocList(&GlobalPAllocList) { +#ifndef NDEBUG + ActiveMemInfo.emplace(DefaultMemInfo.Get(), DefaultMemInfo); +#endif GetRoot()->InitLinks(); OffloadedBlocksRoot.InitLinks(); GlobalPAllocList.InitLinks(); diff --git a/yql/essentials/minikql/mkql_alloc.h b/yql/essentials/minikql/mkql_alloc.h index 24bbbb8e9ed..4bf567ca8f0 100644 --- a/yql/essentials/minikql/mkql_alloc.h +++ b/yql/essentials/minikql/mkql_alloc.h @@ -13,6 +13,7 @@ #include <unordered_map> #include <atomic> #include <memory> +#include <source_location> namespace NKikimr { @@ -43,6 +44,25 @@ static_assert(sizeof(TAllocPageHeader) % MKQL_ALIGNMENT == 0, "Incorrect size of struct TMkqlArrowHeader; +#ifndef NDEBUG +using TAllocLocation = std::source_location; +#else +struct TAllocLocation +{ + const char* file_name() const { + return ""; + } + + std::uint_least32_t line() const { + return 0; + } + + static TAllocLocation current() { + return {}; + } +}; +#endif + struct TAllocState : public TAlignedPagePool { struct TListEntry { @@ -57,31 +77,11 @@ struct TAllocState : public TAlignedPagePool }; #ifndef NDEBUG + TIntrusivePtr<TMemoryUsageInfo> DefaultMemInfo; std::unordered_map<TMemoryUsageInfo*, TIntrusivePtr<TMemoryUsageInfo>> ActiveMemInfo; #endif bool SupportsSizedAllocators = false; - void* LargeAlloc(size_t size) { -#if defined(ALLOW_DEFAULT_ALLOCATOR) - if (Y_UNLIKELY(IsDefaultAllocatorUsed())) { - return malloc(size); - } -#endif - - return Alloc(size); - } - - void LargeFree(void* ptr, size_t size) noexcept { -#if defined(ALLOW_DEFAULT_ALLOCATOR) - if (Y_UNLIKELY(IsDefaultAllocatorUsed())) { - free(ptr); - return; - } -#endif - - Free(ptr, size); - } - using TCurrentPages = std::array<TAllocPageHeader*, (TMemorySubPoolIdx)EMemorySubPool::Count>; static TAllocPageHeader EmptyPageHeader; @@ -307,7 +307,10 @@ private: void* MKQLAllocSlow(size_t sz, TAllocState* state, const EMemorySubPool mPool); -inline void* MKQLAllocFastDeprecated(size_t sz, TAllocState* state, const EMemorySubPool mPool) { +inline void* MKQLAllocFastDeprecated(size_t sz, TAllocState* state, const EMemorySubPool mPool, const TAllocLocation& location = TAllocLocation::current()) { +#ifdef NDEBUG + Y_UNUSED(location); +#endif Y_DEBUG_ABORT_UNLESS(state); #if defined(ALLOW_DEFAULT_ALLOCATOR) @@ -318,6 +321,9 @@ inline void* MKQLAllocFastDeprecated(size_t sz, TAllocState* state, const EMemor } ret->Link(&state->OffloadedBlocksRoot); +#ifndef NDEBUG + state->DefaultMemInfo->Take(ret + 1, sz, { location.file_name(), (int)location.line() }); +#endif return ret + 1; } #endif @@ -327,13 +333,23 @@ inline void* MKQLAllocFastDeprecated(size_t sz, TAllocState* state, const EMemor void* ret = (char*)currPage + currPage->Offset; currPage->Offset = AlignUp(currPage->Offset + sz, MKQL_ALIGNMENT); ++currPage->UseCount; +#ifndef NDEBUG + state->DefaultMemInfo->Take(ret, sz, { location.file_name(), (int)location.line() }); +#endif return ret; } - return MKQLAllocSlow(sz, state, mPool); + auto ret = MKQLAllocSlow(sz, state, mPool); +#ifndef NDEBUG + state->DefaultMemInfo->Take(ret, sz, { location.file_name(), (int)location.line() }); +#endif + return ret; } -inline void* MKQLAllocFastWithSize(size_t sz, TAllocState* state, const EMemorySubPool mPool) { +inline void* MKQLAllocFastWithSize(size_t sz, TAllocState* state, const EMemorySubPool mPool, const TAllocLocation& location = TAllocLocation::current()) { +#ifdef NDEBUG + Y_UNUSED(location); +#endif Y_DEBUG_ABORT_UNLESS(state); bool useMalloc = state->SupportsSizedAllocators && sz > MaxPageUserData; @@ -349,6 +365,9 @@ inline void* MKQLAllocFastWithSize(size_t sz, TAllocState* state, const EMemoryS } ret->Link(&state->OffloadedBlocksRoot); +#ifndef NDEBUG + state->DefaultMemInfo->Take(ret + 1, sz, { location.file_name(), (int)location.line() }); +#endif return ret + 1; } @@ -357,10 +376,17 @@ inline void* MKQLAllocFastWithSize(size_t sz, TAllocState* state, const EMemoryS void* ret = (char*)currPage + currPage->Offset; currPage->Offset = AlignUp(currPage->Offset + sz, MKQL_ALIGNMENT); ++currPage->UseCount; +#ifndef NDEBUG + state->DefaultMemInfo->Take(ret, sz, { location.file_name(), (int)location.line() }); +#endif return ret; } - return MKQLAllocSlow(sz, state, mPool); + auto ret = MKQLAllocSlow(sz, state, mPool); +#ifndef NDEBUG + state->DefaultMemInfo->Take(ret, sz, { location.file_name(), (int)location.line() }); +#endif + return ret; } void MKQLFreeSlow(TAllocPageHeader* header, TAllocState *state, const EMemorySubPool mPool) noexcept; @@ -370,6 +396,10 @@ inline void MKQLFreeDeprecated(const void* mem, const EMemorySubPool mPool) noex return; } +#ifndef NDEBUG + TlsAllocState->DefaultMemInfo->Return(mem); +#endif + #if defined(ALLOW_DEFAULT_ALLOCATOR) if (Y_UNLIKELY(TAllocState::IsDefaultAllocatorUsed())) { TAllocState *state = TlsAllocState; @@ -398,6 +428,9 @@ inline void MKQLFreeFastWithSize(const void* mem, size_t sz, TAllocState* state, } Y_DEBUG_ABORT_UNLESS(state); +#ifndef NDEBUG + state->DefaultMemInfo->Return(mem, sz); +#endif bool useFree = state->SupportsSizedAllocators && sz > MaxPageUserData; #if defined(ALLOW_DEFAULT_ALLOCATOR) @@ -423,12 +456,12 @@ inline void MKQLFreeFastWithSize(const void* mem, size_t sz, TAllocState* state, MKQLFreeSlow(header, state, mPool); } -inline void* MKQLAllocDeprecated(size_t sz, const EMemorySubPool mPool) { - return MKQLAllocFastDeprecated(sz, TlsAllocState, mPool); +inline void* MKQLAllocDeprecated(size_t sz, const EMemorySubPool mPool, const TAllocLocation& location = TAllocLocation::current()) { + return MKQLAllocFastDeprecated(sz, TlsAllocState, mPool, location); } -inline void* MKQLAllocWithSize(size_t sz, const EMemorySubPool mPool) { - return MKQLAllocFastWithSize(sz, TlsAllocState, mPool); +inline void* MKQLAllocWithSize(size_t sz, const EMemorySubPool mPool, const TAllocLocation& location = TAllocLocation::current()) { + return MKQLAllocFastWithSize(sz, TlsAllocState, mPool, location); } inline void MKQLFreeWithSize(const void* mem, size_t sz, const EMemorySubPool mPool) noexcept { @@ -478,6 +511,14 @@ struct TWithMiniKQLAlloc { }; template <typename T, typename... Args> +T* AllocateOn(const TAllocLocation& location, TAllocState* state, Args&&... args) +{ + void* addr = MKQLAllocFastWithSize(sizeof(T), state, T::MemoryPool, location); + return ::new(addr) T(std::forward<Args>(args)...); + static_assert(std::is_base_of<TWithMiniKQLAlloc<T::MemoryPool>, T>::value, "Class must inherit TWithMiniKQLAlloc."); +} + +template <typename T, typename... Args> T* AllocateOn(TAllocState* state, Args&&... args) { void* addr = MKQLAllocFastWithSize(sizeof(T), state, T::MemoryPool); diff --git a/yql/essentials/minikql/mkql_string_util_ut.cpp b/yql/essentials/minikql/mkql_string_util_ut.cpp index 9826ee0ee10..f0d5545ab73 100644 --- a/yql/essentials/minikql/mkql_string_util_ut.cpp +++ b/yql/essentials/minikql/mkql_string_util_ut.cpp @@ -9,10 +9,10 @@ using namespace NKikimr::NMiniKQL; Y_UNIT_TEST_SUITE(TMiniKQLStringUtils) { Y_UNIT_TEST(SubstringWithLargeOffset) { TScopedAlloc alloc(__LOCATION__); - const auto big = MakeStringNotFilled(NUdf::TUnboxedValuePod::OffsetLimit << 1U); - const auto sub0 = SubString(big, 1U, 42U); - const auto sub1 = SubString(big, NUdf::TUnboxedValuePod::OffsetLimit - 1U, 42U); - const auto sub2 = SubString(big, NUdf::TUnboxedValuePod::OffsetLimit, 42U); + const auto big = MakeStringNotFilled(/*size=*/NUdf::TUnboxedValuePod::OffsetLimit << 1U); + const auto sub0 = NUdf::TUnboxedValue(SubString(big, 1U, 42U)); + const auto sub1 = NUdf::TUnboxedValue(SubString(big, NUdf::TUnboxedValuePod::OffsetLimit - 1U, 42U)); + const auto sub2 = NUdf::TUnboxedValue(SubString(big, NUdf::TUnboxedValuePod::OffsetLimit, 42U)); UNIT_ASSERT(sub0.AsStringValue().Data() == sub1.AsStringValue().Data()); UNIT_ASSERT(sub1.AsStringValue().Data() != sub2.AsStringValue().Data()); diff --git a/yql/essentials/minikql/mkql_type_builder.cpp b/yql/essentials/minikql/mkql_type_builder.cpp index c9d6e363b4f..41502b144ae 100644 --- a/yql/essentials/minikql/mkql_type_builder.cpp +++ b/yql/essentials/minikql/mkql_type_builder.cpp @@ -2820,7 +2820,6 @@ TType* TTypeBuilder::ValidateBlockStructType(const TStructType* structType) cons MKQL_ENSURE(isScalar, "Block length column should be scalar"); MKQL_ENSURE(AS_TYPE(TDataType, itemType)->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64"); - MKQL_ENSURE(!hasBlockLengthColumn, "Block struct must contain only one block length column"); hasBlockLengthColumn = true; } else { outStructItems.emplace_back(structType->GetMemberName(i), itemType); diff --git a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp index 983bf855424..abe2f9c1e9a 100644 --- a/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp +++ b/yql/essentials/providers/common/mkql/yql_provider_mkql.cpp @@ -448,7 +448,9 @@ TMkqlCommonCallableCompiler::TShared::TShared() { {"FromFlow", &TProgramBuilder::FromFlow}, {"WideToBlocks", &TProgramBuilder::WideToBlocks}, + {"ListToBlocks", &TProgramBuilder::ListToBlocks}, {"WideFromBlocks", &TProgramBuilder::WideFromBlocks}, + {"ListFromBlocks", &TProgramBuilder::ListFromBlocks}, {"AsScalar", &TProgramBuilder::AsScalar}, {"Just", &TProgramBuilder::NewOptional}, diff --git a/yql/essentials/sql/v1/SQLv1.g.in b/yql/essentials/sql/v1/SQLv1.g.in index dbec6828030..6ba91358316 100644 --- a/yql/essentials/sql/v1/SQLv1.g.in +++ b/yql/essentials/sql/v1/SQLv1.g.in @@ -1100,7 +1100,7 @@ alter_sequence_action: | INCREMENT BY? integer ; -show_create_table_stmt: SHOW CREATE TABLE simple_table_ref; +show_create_table_stmt: SHOW CREATE (TABLE | VIEW) simple_table_ref; // Special rules that allow to use certain keywords as identifiers. identifier: ID_PLAIN | ID_QUOTED; diff --git a/yql/essentials/sql/v1/SQLv1Antlr4.g.in b/yql/essentials/sql/v1/SQLv1Antlr4.g.in index 5c59ab61ea4..66be65e2aee 100644 --- a/yql/essentials/sql/v1/SQLv1Antlr4.g.in +++ b/yql/essentials/sql/v1/SQLv1Antlr4.g.in @@ -1100,7 +1100,7 @@ alter_sequence_action: | INCREMENT BY? integer ; -show_create_table_stmt: SHOW CREATE TABLE simple_table_ref; +show_create_table_stmt: SHOW CREATE (TABLE | VIEW) simple_table_ref; // Special rules that allow to use certain keywords as identifiers. identifier: ID_PLAIN | ID_QUOTED; diff --git a/yql/essentials/sql/v1/format/sql_format_ut.h b/yql/essentials/sql/v1/format/sql_format_ut.h index 2747ce28714..8bb2af27939 100644 --- a/yql/essentials/sql/v1/format/sql_format_ut.h +++ b/yql/essentials/sql/v1/format/sql_format_ut.h @@ -153,6 +153,15 @@ Y_UNIT_TEST(ShowCreateTable) { setup.Run(cases); } +Y_UNIT_TEST(ShowCreateView) { + TCases cases = { + {"use plato;show create view user;","USE plato;\n\nSHOW CREATE VIEW user;\n"}, + }; + + TSetup setup; + setup.Run(cases); +} + Y_UNIT_TEST(Use) { TCases cases = { {"use user;","USE user;\n"}, diff --git a/yql/essentials/sql/v1/lexer/lexer_ut.cpp b/yql/essentials/sql/v1/lexer/lexer_ut.cpp index 53cff6ffdc7..1ddfd04b507 100644 --- a/yql/essentials/sql/v1/lexer/lexer_ut.cpp +++ b/yql/essentials/sql/v1/lexer/lexer_ut.cpp @@ -1,36 +1,60 @@ #include "lexer.h" +#include "lexer_ut.h" #include <yql/essentials/core/issue/yql_issue.h> #include <yql/essentials/sql/settings/translation_settings.h> #include <yql/essentials/sql/v1/lexer/antlr3/lexer.h> +#include <yql/essentials/sql/v1/lexer/antlr3_ansi/lexer.h> #include <yql/essentials/sql/v1/lexer/antlr4/lexer.h> +#include <yql/essentials/sql/v1/lexer/antlr4_ansi/lexer.h> #include <yql/essentials/sql/v1/lexer/antlr4_pure/lexer.h> +#include <yql/essentials/sql/v1/lexer/antlr4_pure_ansi/lexer.h> #include <yql/essentials/sql/v1/lexer/regex/lexer.h> #include <library/cpp/testing/unittest/registar.h> +#include <util/string/ascii.h> +#include <util/random/random.h> + +#define UNIT_ASSERT_TOKENIZED(LEXER, QUERY, TOKENS) \ + do { \ + auto tokens = Tokenized((LEXER), (QUERY)); \ + UNIT_ASSERT_VALUES_EQUAL(tokens, (TOKENS)); \ + } while (false) + using namespace NSQLTranslation; using namespace NSQLTranslationV1; -std::pair<TParsedTokenList, NYql::TIssues> Tokenize(ILexer::TPtr& lexer, TString queryUtf8) { +TLexers Lexers = { + .Antlr3 = MakeAntlr3LexerFactory(), + .Antlr3Ansi = MakeAntlr4AnsiLexerFactory(), + .Antlr4 = MakeAntlr4LexerFactory(), + .Antlr4Ansi = MakeAntlr4AnsiLexerFactory(), + .Antlr4Pure = MakeAntlr4PureLexerFactory(), + .Antlr4PureAnsi = MakeAntlr4PureAnsiLexerFactory(), + .Regex = MakeRegexLexerFactory(/* ansi = */ false), + .RegexAnsi = MakeRegexLexerFactory(/* ansi = */ true), +}; + +std::pair<TParsedTokenList, NYql::TIssues> Tokenize(ILexer::TPtr& lexer, const TString& query) { TParsedTokenList tokens; NYql::TIssues issues; - Tokenize(*lexer, queryUtf8, "", tokens, issues, SQL_MAX_PARSER_ERRORS); + Tokenize(*lexer, query, "", tokens, issues, SQL_MAX_PARSER_ERRORS); return {tokens, issues}; } -TVector<TString> GetIssueMessages(ILexer::TPtr& lexer, TString queryUtf8) { +TVector<TString> GetIssueMessages(ILexer::TPtr& lexer, const TString& query) { TVector<TString> messages; - for (const auto& issue : Tokenize(lexer, queryUtf8).second) { + for (const auto& issue : Tokenize(lexer, query).second) { messages.emplace_back(issue.ToString(/* oneLine = */ true)); } return messages; } -TVector<TString> GetTokenViews(ILexer::TPtr& lexer, TString queryUtf8) { +TVector<TString> GetTokenViews(ILexer::TPtr& lexer, const TString& query) { TVector<TString> names; - for (auto& token : Tokenize(lexer, queryUtf8).first) { + for (auto& token : Tokenize(lexer, query).first) { TString view = std::move(token.Name); if (view == "ID_PLAIN" || view == "STRING_VALUE") { view.append(" ("); @@ -42,28 +66,58 @@ TVector<TString> GetTokenViews(ILexer::TPtr& lexer, TString queryUtf8) { return names; } -void AssertEquivialent(const TParsedToken& lhs, const TParsedToken& rhs) { - if (lhs.Name == "EOF" && rhs.Name == "EOF") { - return; +TString ToString(TParsedToken token) { + TString& string = token.Name; + if (!AsciiEqualsIgnoreCase(token.Name, token.Content) && token.Name != "EOF") { + string += "("; + string += token.Content; + string += ")"; + } + return string; +} + +TString Tokenized(ILexer::TPtr& lexer, const TString& query) { + TParsedTokenList tokens; + NYql::TIssues issues; + bool ok = Tokenize(*lexer, query, "Test", tokens, issues, SQL_MAX_PARSER_ERRORS); + + TString out; + if (!ok) { + out = "[INVALID] "; } - UNIT_ASSERT_VALUES_EQUAL(lhs.Name, rhs.Name); - UNIT_ASSERT_VALUES_EQUAL(lhs.Content, rhs.Content); - UNIT_ASSERT_VALUES_EQUAL(lhs.Line, rhs.Line); + for (auto& token : tokens) { + out += ToString(std::move(token)); + out += " "; + } + if (!out.empty()) { + out.pop_back(); + } + return out; } -void AssertEquivialent(const TParsedTokenList& lhs, const TParsedTokenList& rhs) { - UNIT_ASSERT_VALUES_EQUAL(lhs.size(), rhs.size()); - for (size_t i = 0; i < lhs.size(); ++i) { - AssertEquivialent(lhs.at(i), rhs.at(i)); +TString RandomMultilineCommentLikeText(size_t maxSize) { + auto size = RandomNumber<size_t>(maxSize); + TString comment; + for (size_t i = 0; i < size; ++i) { + if (auto /* isOpen */ _ = RandomNumber<bool>()) { + comment += "/*"; + } else { + comment += "*/"; + } + + for (int gap = RandomNumber<size_t>(2); gap > 0; --gap) { + comment += " "; + } } + return comment; } Y_UNIT_TEST_SUITE(SQLv1Lexer) { Y_UNIT_TEST(UnsupportedIssues) { NSQLTranslationV1::TLexers factories; - TVector<ILexer::TPtr> lexers; + TVector<ILexer::TPtr> lexers; for (auto ansi : {false, true}) { for (auto antlr4 : {false, true}) { for (auto flavor : {ELexerFlavor::Default, ELexerFlavor::Pure, ELexerFlavor::Regex}) { @@ -96,8 +150,8 @@ Y_UNIT_TEST_SUITE(SQLv1Lexer) { UNIT_ASSERT_VALUES_EQUAL(actual, expected); } - Y_UNIT_TEST(AntlrVersionIndependent) { - const TVector<TString> queriesUtf8 = { + Y_UNIT_TEST_ON_EACH_LEXER(AntlrAndFlavorIndependent) { + static const TVector<TString> queries = { "", " ", "SELECT", @@ -115,35 +169,31 @@ Y_UNIT_TEST_SUITE(SQLv1Lexer) { "\"select\"select", }; - NSQLTranslationV1::TLexers lexers; - lexers.Antlr3 = NSQLTranslationV1::MakeAntlr3LexerFactory(); - lexers.Antlr4 = NSQLTranslationV1::MakeAntlr4LexerFactory(); - lexers.Antlr4Pure = NSQLTranslationV1::MakeAntlr4PureLexerFactory(); + static TVector<TString> expectations(queries.size()); + + if (ANSI) { + return; + } + + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + + for (size_t i = 0; i < queries.size(); ++i) { + const auto& query = queries[i]; + auto& expected = expectations[i]; - auto lexer3 = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ false); - auto lexer4 = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ true); - auto lexer4p = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ true, ELexerFlavor::Pure); + if (expected.empty()) { + expected = Tokenized(lexer, query); + return; + } - for (const auto& query : queriesUtf8) { - auto [tokens3, issues3] = Tokenize(lexer3, query); - auto [tokens4, issues4] = Tokenize(lexer4, query); - auto [tokens4p, issues4p] = Tokenize(lexer4p, query); - AssertEquivialent(tokens3, tokens4); - AssertEquivialent(tokens3, tokens4p); - UNIT_ASSERT(issues3.Empty()); - UNIT_ASSERT(issues4.Empty()); - UNIT_ASSERT(issues4p.Empty()); + UNIT_ASSERT_TOKENIZED(lexer, query, expected); } } TVector<TString> InvalidQueries(); void TestInvalidTokensSkipped(bool antlr4, const TVector<TVector<TString>>& expected) { - NSQLTranslationV1::TLexers lexers; - lexers.Antlr3 = NSQLTranslationV1::MakeAntlr3LexerFactory(); - lexers.Antlr4 = NSQLTranslationV1::MakeAntlr4LexerFactory(); - - auto lexer = MakeLexer(lexers, /* ansi = */ false, antlr4); + auto lexer = MakeLexer(Lexers, /* ansi = */ false, antlr4); auto input = InvalidQueries(); UNIT_ASSERT_VALUES_EQUAL(input.size(), expected.size()); @@ -198,16 +248,10 @@ Y_UNIT_TEST_SUITE(SQLv1Lexer) { } Y_UNIT_TEST(IssuesCollected) { - NSQLTranslationV1::TLexers lexers; - lexers.Antlr3 = NSQLTranslationV1::MakeAntlr3LexerFactory(); - lexers.Antlr4 = NSQLTranslationV1::MakeAntlr4LexerFactory(); - lexers.Antlr4Pure = NSQLTranslationV1::MakeAntlr4PureLexerFactory(); - lexers.Regex = NSQLTranslationV1::MakeRegexLexerFactory(/* ansi = */ false); - - auto lexer3 = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ false); - auto lexer4 = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ true); - auto lexer4p = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ true, ELexerFlavor::Pure); - auto lexerR = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ false, ELexerFlavor::Regex); + auto lexer3 = MakeLexer(Lexers, /* ansi = */ false, /* antlr4 = */ false); + auto lexer4 = MakeLexer(Lexers, /* ansi = */ false, /* antlr4 = */ true); + auto lexer4p = MakeLexer(Lexers, /* ansi = */ false, /* antlr4 = */ true, ELexerFlavor::Pure); + auto lexerR = MakeLexer(Lexers, /* ansi = */ false, /* antlr4 = */ false, ELexerFlavor::Regex); for (const auto& query : InvalidQueries()) { auto issues3 = GetIssueMessages(lexer3, query); @@ -223,9 +267,7 @@ Y_UNIT_TEST_SUITE(SQLv1Lexer) { } Y_UNIT_TEST(IssueMessagesAntlr3) { - NSQLTranslationV1::TLexers lexers; - lexers.Antlr3 = NSQLTranslationV1::MakeAntlr3LexerFactory(); - auto lexer3 = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ false); + auto lexer3 = MakeLexer(Lexers, /* ansi = */ false, /* antlr4 = */ false); auto actual = GetIssueMessages(lexer3, "\xF0\x9F\x98\x8A SELECT * FR"); @@ -240,10 +282,7 @@ Y_UNIT_TEST_SUITE(SQLv1Lexer) { } Y_UNIT_TEST(IssueMessagesAntlr4) { - NSQLTranslationV1::TLexers lexers; - lexers.Antlr4 = NSQLTranslationV1::MakeAntlr4LexerFactory(); - - auto lexer4 = MakeLexer(lexers, /* ansi = */ false, /* antlr4 = */ true); + auto lexer4 = MakeLexer(Lexers, /* ansi = */ false, /* antlr4 = */ true); auto actual = GetIssueMessages(lexer4, "\xF0\x9F\x98\x8A SELECT * FR"); @@ -253,4 +292,164 @@ Y_UNIT_TEST_SUITE(SQLv1Lexer) { UNIT_ASSERT_VALUES_EQUAL(actual, expected); } -} + + Y_UNIT_TEST_ON_EACH_LEXER(Whitespace) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "", "EOF"); + UNIT_ASSERT_TOKENIZED(lexer, " ", "WS( ) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, " ", "WS( ) WS( ) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "\n", "WS(\n) EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(Keyword) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "SELECT", "SELECT EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "INSERT", "INSERT EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "FROM", "FROM EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(Punctuation) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED( + lexer, + "* / + - <|", + "ASTERISK(*) WS( ) SLASH(/) WS( ) " + "PLUS(+) WS( ) MINUS(-) WS( ) STRUCT_OPEN(<|) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "SELECT*FROM", "SELECT ASTERISK(*) FROM EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(IdPlain) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "variable my_table", "ID_PLAIN(variable) WS( ) ID_PLAIN(my_table) EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(IdQuoted) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "``", "ID_QUOTED(``) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "` `", "ID_QUOTED(` `) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "`local/table`", "ID_QUOTED(`local/table`) EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(Number) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "1", "DIGITS(1) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "123", "DIGITS(123) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "123u", "INTEGER_VALUE(123u) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "123ui", "INTEGER_VALUE(123ui) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "123.45", "REAL(123.45) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "123.45E10", "REAL(123.45E10) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "123.45E+10", "REAL(123.45E+10) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "1E+10", "REAL(1E+10) EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(SingleLineString) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "\"\"", "STRING_VALUE(\"\") EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "\' \'", "STRING_VALUE(\' \') EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "\" \"", "STRING_VALUE(\" \") EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "\"test\"", "STRING_VALUE(\"test\") EOF"); + + if (!ANSI) { + UNIT_ASSERT_TOKENIZED(lexer, "\"\\\"\"", "STRING_VALUE(\"\\\"\") EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "\"\"\"\"", "STRING_VALUE(\"\") STRING_VALUE(\"\") EOF"); + } else { + UNIT_ASSERT_TOKENIZED(lexer, "\"\\\"\"", "[INVALID] STRING_VALUE(\"\\\") EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "\"\"\"\"", "STRING_VALUE(\"\"\"\") EOF"); + } + } + + Y_UNIT_TEST_ON_EACH_LEXER(MultiLineString) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "@@@@", "STRING_VALUE(@@@@) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "@@ @@@", "STRING_VALUE(@@ @@@) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "@@test@@", "STRING_VALUE(@@test@@) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "@@line1\nline2@@", "STRING_VALUE(@@line1\nline2@@) EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(SingleLineComment) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "--yql", "COMMENT(--yql) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "-- yql ", "COMMENT(-- yql ) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "-- yql\nSELECT", "COMMENT(-- yql\n) SELECT EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "-- yql --", "COMMENT(-- yql --) EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(MultiLineComment) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "/* yql */", "COMMENT(/* yql */) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/* yql */ */", "COMMENT(/* yql */) WS( ) ASTERISK(*) SLASH(/) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/* yql\n * yql\n */", "COMMENT(/* yql\n * yql\n */) EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(RecursiveMultiLineComment) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + if (!ANSI) { + UNIT_ASSERT_TOKENIZED(lexer, "/* /* yql */", "COMMENT(/* /* yql */) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/* /* yql */ */", "COMMENT(/* /* yql */) WS( ) ASTERISK(*) SLASH(/) EOF"); + } else { + UNIT_ASSERT_TOKENIZED(lexer, "/* /* yql */", "COMMENT(/* /* yql */) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/* yql */ */", "COMMENT(/* yql */) WS( ) ASTERISK(*) SLASH(/) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/* /* /* yql */ */", "COMMENT(/* /* /* yql */ */) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/* /* yql */ */ */", "COMMENT(/* /* yql */ */) WS( ) ASTERISK(*) SLASH(/) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/* /* yql */ */", "COMMENT(/* /* yql */ */) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/*/*/*/", "COMMENT(/*/*/) ASTERISK(*) SLASH(/) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/*/**/*/*/*/", "COMMENT(/*/**/*/) ASTERISK(*) SLASH(/) ASTERISK(*) SLASH(/) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "/* /* */ a /* /* */", "COMMENT(/* /* */ a /* /* */) EOF"); + } + } + + Y_UNIT_TEST_ON_EACH_LEXER(RandomRecursiveMultiLineComment) { + if (!ANTLR4 && FLAVOR != ELexerFlavor::Regex || FLAVOR != ELexerFlavor::Pure) { + return; + } + + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + auto reference = MakeLexer(Lexers, ANSI, /* antlr4 = */ true, ELexerFlavor::Pure); + + SetRandomSeed(100); + for (size_t i = 0; i < 512; ++i) { + auto input = RandomMultilineCommentLikeText(/* maxSize = */ 32); + TString actual = Tokenized(lexer, input); + TString expected = Tokenized(reference, input); + + UNIT_ASSERT_VALUES_EQUAL_C(actual, expected, "Input: " << input); + } + } + + Y_UNIT_TEST_ON_EACH_LEXER(SimpleQuery) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + UNIT_ASSERT_TOKENIZED(lexer, "select 1", "SELECT WS( ) DIGITS(1) EOF"); + UNIT_ASSERT_TOKENIZED(lexer, "SELect 1", "SELECT WS( ) DIGITS(1) EOF"); + } + + Y_UNIT_TEST_ON_EACH_LEXER(ComplexQuery) { + auto lexer = MakeLexer(Lexers, ANSI, ANTLR4, FLAVOR); + + TString query = + "SELECT\n" + " 123467,\n" + " \"Hello, {name}!\",\n" + " (1 + (5U * 1 / 0)),\n" + " MIN(identifier),\n" + " Bool(field),\n" + " Math::Sin(var)\n" + "FROM `local/test/space/table`\n" + "JOIN test;"; + + TString expected = + "SELECT WS(\n) " + "WS( ) WS( ) DIGITS(123467) COMMA(,) WS(\n) " + "WS( ) WS( ) STRING_VALUE(\"Hello, {name}!\") COMMA(,) WS(\n) " + "WS( ) WS( ) LPAREN(() DIGITS(1) WS( ) PLUS(+) WS( ) LPAREN(() INTEGER_VALUE(5U) WS( ) " + "ASTERISK(*) WS( ) DIGITS(1) WS( ) SLASH(/) WS( ) DIGITS(0) RPAREN()) " + "RPAREN()) COMMA(,) WS(\n) " + "WS( ) WS( ) ID_PLAIN(MIN) LPAREN(() ID_PLAIN(identifier) RPAREN()) COMMA(,) WS(\n) " + "WS( ) WS( ) ID_PLAIN(Bool) LPAREN(() ID_PLAIN(field) RPAREN()) COMMA(,) WS(\n) " + "WS( ) WS( ) ID_PLAIN(Math) NAMESPACE(::) ID_PLAIN(Sin) LPAREN(() ID_PLAIN(var) RPAREN()) WS(\n) " + "FROM WS( ) ID_QUOTED(`local/test/space/table`) WS(\n) " + "JOIN WS( ) ID_PLAIN(test) SEMICOLON(;) EOF"; + + UNIT_ASSERT_TOKENIZED(lexer, query, expected); + } + +} // Y_UNIT_TEST_SUITE(SQLv1Lexer) diff --git a/yql/essentials/sql/v1/lexer/lexer_ut.h b/yql/essentials/sql/v1/lexer/lexer_ut.h new file mode 100644 index 00000000000..b4304eb7070 --- /dev/null +++ b/yql/essentials/sql/v1/lexer/lexer_ut.h @@ -0,0 +1,37 @@ +#pragma once + +#include "lexer.h" + +#define LEXER_NAME_ANSI_false_ANTLR4_false_FLAVOR_Default "antlr3" +#define LEXER_NAME_ANSI_false_ANTLR4_true_FLAVOR_Default "antlr4" +#define LEXER_NAME_ANSI_true_ANTLR4_false_FLAVOR_Default "antlr3_ansi" +#define LEXER_NAME_ANSI_true_ANTLR4_true_FLAVOR_Default "antlr4_ansi" +#define LEXER_NAME_ANSI_false_ANTLR4_true_FLAVOR_Pure "antlr4_pure" +#define LEXER_NAME_ANSI_true_ANTLR4_true_FLAVOR_Pure "antlr4_pure_ansi" +#define LEXER_NAME_ANSI_false_ANTLR4_false_FLAVOR_Regex "regex" +#define LEXER_NAME_ANSI_true_ANTLR4_false_FLAVOR_Regex "regex_ansi" + +#define Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, ANSI, ANTLR4, FLAVOR) \ + TCurrentTest::AddTest( \ + #N "::" LEXER_NAME_ANSI_##ANSI##_ANTLR4_##ANTLR4##_FLAVOR_##FLAVOR, \ + static_cast<void (*)(NUnitTest::TTestContext&)>(&N<ANSI, ANTLR4, ELexerFlavor::FLAVOR>), \ + /* forceFork = */ false) + +#define Y_UNIT_TEST_ON_EACH_LEXER(N) \ + template <bool ANSI, bool ANTLR4, ELexerFlavor FLAVOR> \ + void N(NUnitTest::TTestContext&); \ + struct TTestRegistration##N { \ + TTestRegistration##N() { \ + Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, false, false, Default); \ + Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, false, true, Default); \ + Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, true, false, Default); \ + Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, true, true, Default); \ + Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, false, true, Pure); \ + Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, true, true, Pure); \ + Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, false, false, Regex); \ + Y_UNIT_TEST_ON_EACH_LEXER_ADD_TEST(N, true, false, Regex); \ + } \ + }; \ + static TTestRegistration##N testRegistration##N; \ + template <bool ANSI, bool ANTLR4, ELexerFlavor FLAVOR> \ + void N(NUnitTest::TTestContext&) diff --git a/yql/essentials/sql/v1/lexer/regex/lexer.cpp b/yql/essentials/sql/v1/lexer/regex/lexer.cpp index 1c8f2104a48..b0b5c2dad44 100644 --- a/yql/essentials/sql/v1/lexer/regex/lexer.cpp +++ b/yql/essentials/sql/v1/lexer/regex/lexer.cpp @@ -10,6 +10,7 @@ #include <util/generic/algorithm.h> #include <util/generic/string.h> #include <util/string/subst.h> +#include <util/string/ascii.h> namespace NSQLTranslationV1 { @@ -23,15 +24,15 @@ namespace NSQLTranslationV1 { TRegexLexer( bool ansi, NSQLReflect::TLexerGrammar grammar, - const THashMap<TString, TString>& RegexByOtherNameMap) + const TVector<std::tuple<TString, TString>>& RegexByOtherName) : Grammar_(std::move(grammar)) , Ansi_(ansi) { - for (auto& [token, regex] : RegexByOtherNameMap) { + for (const auto& [token, regex] : RegexByOtherName) { if (token == CommentTokenName) { CommentRegex_.Reset(new RE2(regex)); } else { - OtherRegexes_.emplace(std::move(token), std::move(regex)); + OtherRegexes_.emplace_back(token, new RE2(regex)); } } } @@ -71,27 +72,27 @@ namespace NSQLTranslationV1 { size_t keywordCount = MatchKeyword(prefix, matches); MatchPunctuation(prefix, matches); - size_t otherCount = MatchRegex(prefix, matches); + MatchRegex(prefix, matches); MatchComment(prefix, matches); - auto max = MaxElementBy(matches, [](const TParsedToken& m) { - return m.Content.length(); - }); - - if (max == std::end(matches)) { + if (matches.empty()) { return {}; } + auto maxLength = MaxElementBy(matches, [](const TParsedToken& m) { + return m.Content.length(); + })->Content.length(); + + auto max = FindIf(matches, [&](const TParsedToken& m) { + return m.Content.length() == maxLength; + }); + auto isMatched = [&](const TStringBuf name) { return std::end(matches) != FindIf(matches, [&](const auto& m) { return m.Name == name; }); }; - Y_ENSURE( - otherCount <= 1 || - (otherCount == 2 && isMatched("DIGITS") && isMatched("INTEGER_VALUE"))); - size_t conflicts = CountIf(matches, [&](const TParsedToken& m) { return m.Content.length() == max->Content.length(); }); @@ -108,7 +109,7 @@ namespace NSQLTranslationV1 { bool MatchKeyword(const TStringBuf prefix, TParsedTokenList& matches) { size_t count = 0; for (const auto& keyword : Grammar_.KeywordNames) { - if (prefix.substr(0, keyword.length()) == keyword) { + if (AsciiEqualsIgnoreCase(prefix.substr(0, keyword.length()), keyword)) { matches.emplace_back(keyword, keyword); count += 1; } @@ -131,7 +132,7 @@ namespace NSQLTranslationV1 { size_t MatchRegex(const TStringBuf prefix, TParsedTokenList& matches) { size_t count = 0; for (const auto& [token, regex] : OtherRegexes_) { - if (const TStringBuf match = TryMatchRegex(prefix, regex); !match.empty()) { + if (const TStringBuf match = TryMatchRegex(prefix, *regex); !match.empty()) { matches.emplace_back(token, TString(match)); count += 1; } @@ -216,7 +217,7 @@ namespace NSQLTranslationV1 { } NSQLReflect::TLexerGrammar Grammar_; - THashMap<TString, RE2> OtherRegexes_; + TVector<std::tuple<TString, THolder<RE2>>> OtherRegexes_; THolder<RE2> CommentRegex_; bool Ansi_; }; @@ -228,19 +229,19 @@ namespace NSQLTranslationV1 { explicit TFactory(bool ansi) : Ansi_(ansi) , Grammar_(NSQLReflect::LoadLexerGrammar()) - , RegexByOtherNameMap_(MakeRegexByOtherNameMap(Grammar_, Ansi_)) + , RegexByOtherName_(MakeRegexByOtherName(Grammar_, Ansi_)) { } NSQLTranslation::ILexer::TPtr MakeLexer() const override { return NSQLTranslation::ILexer::TPtr( - new TRegexLexer(Ansi_, Grammar_, RegexByOtherNameMap_)); + new TRegexLexer(Ansi_, Grammar_, RegexByOtherName_)); } private: bool Ansi_; NSQLReflect::TLexerGrammar Grammar_; - THashMap<TString, TString> RegexByOtherNameMap_; + TVector<std::tuple<TString, TString>> RegexByOtherName_; }; } // namespace diff --git a/yql/essentials/sql/v1/lexer/regex/regex.cpp b/yql/essentials/sql/v1/lexer/regex/regex.cpp index a8aca8a1318..937d21572fc 100644 --- a/yql/essentials/sql/v1/lexer/regex/regex.cpp +++ b/yql/essentials/sql/v1/lexer/regex/regex.cpp @@ -227,12 +227,12 @@ namespace NSQLTranslationV1 { TRewriteRule UnwrapQuotedSpace_; }; - THashMap<TString, TString> MakeRegexByOtherNameMap(const NSQLReflect::TLexerGrammar& grammar, bool ansi) { + TVector<std::tuple<TString, TString>> MakeRegexByOtherName(const NSQLReflect::TLexerGrammar& grammar, bool ansi) { TLexerGrammarToRegexTranslator translator(grammar, ansi); - THashMap<TString, TString> regexes; + TVector<std::tuple<TString, TString>> regexes; for (const auto& token : grammar.OtherNames) { - regexes.emplace(token, translator.ToRegex(token)); + regexes.emplace_back(token, translator.ToRegex(token)); } return regexes; } diff --git a/yql/essentials/sql/v1/lexer/regex/regex.h b/yql/essentials/sql/v1/lexer/regex/regex.h index 9e29c3df25b..1e9d92b6535 100644 --- a/yql/essentials/sql/v1/lexer/regex/regex.h +++ b/yql/essentials/sql/v1/lexer/regex/regex.h @@ -8,7 +8,7 @@ namespace NSQLTranslationV1 { // Makes regexes only for tokens from OtherNames, // as keywords and punctuation are trivially matched. - THashMap<TString, TString> MakeRegexByOtherNameMap( + TVector<std::tuple<TString, TString>> MakeRegexByOtherName( const NSQLReflect::TLexerGrammar& grammar, bool ansi); } // namespace NSQLTranslationV1 diff --git a/yql/essentials/sql/v1/lexer/ut/ya.make b/yql/essentials/sql/v1/lexer/ut/ya.make index 7e62fb50c85..87cb156cd93 100644 --- a/yql/essentials/sql/v1/lexer/ut/ya.make +++ b/yql/essentials/sql/v1/lexer/ut/ya.make @@ -4,8 +4,11 @@ PEERDIR( yql/essentials/core/issue yql/essentials/parser/lexer_common yql/essentials/sql/v1/lexer/antlr3 + yql/essentials/sql/v1/lexer/antlr3_ansi yql/essentials/sql/v1/lexer/antlr4 + yql/essentials/sql/v1/lexer/antlr4_ansi yql/essentials/sql/v1/lexer/antlr4_pure + yql/essentials/sql/v1/lexer/antlr4_pure_ansi yql/essentials/sql/v1/lexer/regex ) diff --git a/yql/essentials/sql/v1/query.cpp b/yql/essentials/sql/v1/query.cpp index 8e886064784..fc18200b513 100644 --- a/yql/essentials/sql/v1/query.cpp +++ b/yql/essentials/sql/v1/query.cpp @@ -3677,9 +3677,10 @@ TNodePtr BuildAnalyze(TPosition pos, const TString& service, const TDeferredAtom class TShowCreateNode final : public TAstListNode { public: - TShowCreateNode(TPosition pos, const TTableRef& tr, TScopedStatePtr scoped) + TShowCreateNode(TPosition pos, const TTableRef& tr, const TString& type, TScopedStatePtr scoped) : TAstListNode(pos) , Table(tr) + , Type(type) , Scoped(scoped) , FakeSource(BuildFakeSource(pos)) { @@ -3691,9 +3692,9 @@ public: if (!Table.Options->Init(ctx, src)) { return false; } - Table.Options = L(Table.Options, Q(Y(Q("showCreateTable")))); + Table.Options = L(Table.Options, Q(Y(Q(Type)))); } else { - Table.Options = Y(Q(Y(Q("showCreateTable")))); + Table.Options = Y(Q(Y(Q(Type)))); } bool asRef = ctx.PragmaRefSelect; @@ -3741,12 +3742,14 @@ public: } private: TTableRef Table; + // showCreateTable, showCreateView, ... + TString Type; TScopedStatePtr Scoped; TSourcePtr FakeSource; }; -TNodePtr BuildShowCreate(TPosition pos, const TTableRef& tr, TScopedStatePtr scoped) { - return new TShowCreateNode(pos, tr, scoped); +TNodePtr BuildShowCreate(TPosition pos, const TTableRef& tr, const TString& type, TScopedStatePtr scoped) { + return new TShowCreateNode(pos, tr, type, scoped); } class TBaseBackupCollectionNode diff --git a/yql/essentials/sql/v1/reflect/sql_reflect.cpp b/yql/essentials/sql/v1/reflect/sql_reflect.cpp index f47f35cb9de..c0af06e0b46 100644 --- a/yql/essentials/sql/v1/reflect/sql_reflect.cpp +++ b/yql/essentials/sql/v1/reflect/sql_reflect.cpp @@ -134,7 +134,7 @@ namespace NSQLReflect { auto [name, block] = ParseLexerRule(std::move(line)); if (!name.StartsWith(FragmentPrefix)) { - grammar.OtherNames.emplace(name); + grammar.OtherNames.emplace_back(name); } SubstGlobal(name, FragmentPrefix, ""); diff --git a/yql/essentials/sql/v1/reflect/sql_reflect.h b/yql/essentials/sql/v1/reflect/sql_reflect.h index 5225a3c996b..ca398706873 100644 --- a/yql/essentials/sql/v1/reflect/sql_reflect.h +++ b/yql/essentials/sql/v1/reflect/sql_reflect.h @@ -1,15 +1,16 @@ #pragma once #include <util/generic/string.h> -#include <util/generic/hash_set.h> #include <util/generic/hash.h> +#include <util/generic/hash_set.h> +#include <util/generic/vector.h> namespace NSQLReflect { struct TLexerGrammar { THashSet<TString> KeywordNames; THashSet<TString> PunctuationNames; - THashSet<TString> OtherNames; + TVector<TString> OtherNames; THashMap<TString, TString> BlockByName; }; diff --git a/yql/essentials/sql/v1/source.h b/yql/essentials/sql/v1/source.h index 6eb040f2e42..3048b2d5847 100644 --- a/yql/essentials/sql/v1/source.h +++ b/yql/essentials/sql/v1/source.h @@ -318,7 +318,7 @@ namespace NSQLTranslationV1 { TNodePtr BuildWriteTable(TPosition pos, const TString& label, const TTableRef& table, EWriteColumnMode mode, TNodePtr options, TScopedStatePtr scoped); TNodePtr BuildAnalyze(TPosition pos, const TString& service, const TDeferredAtom& cluster, const TAnalyzeParams& params, TScopedStatePtr scoped); - TNodePtr BuildShowCreate(TPosition pos, const TTableRef& table, TScopedStatePtr scoped); + TNodePtr BuildShowCreate(TPosition pos, const TTableRef& table, const TString& type, TScopedStatePtr scoped); TNodePtr BuildAlterSequence(TPosition pos, const TString& service, const TDeferredAtom& cluster, const TString& id, const TSequenceParameters& params, TScopedStatePtr scoped); TSourcePtr TryMakeSourceFromExpression(TPosition pos, TContext& ctx, const TString& currService, const TDeferredAtom& currCluster, TNodePtr node, const TString& view = {}); diff --git a/yql/essentials/sql/v1/sql_query.cpp b/yql/essentials/sql/v1/sql_query.cpp index b59ae88c4b6..627f1e0ae33 100644 --- a/yql/essentials/sql/v1/sql_query.cpp +++ b/yql/essentials/sql/v1/sql_query.cpp @@ -191,7 +191,7 @@ static bool TransferSettingsEntry(std::map<TString, TNodePtr>& out, ctx.Context().Error() << key.Name << " is not supported in ALTER"; return false; } - + if (!out.emplace(keyName, value).second) { ctx.Context().Error() << "Duplicate transfer setting: " << key.Name; } @@ -1961,7 +1961,7 @@ bool TSqlQuery::Statement(TVector<TNodePtr>& blocks, const TRule_sql_stmt_core& break; } case TRule_sql_stmt_core::kAltSqlStmtCore62: { - // show_create_table_stmt: SHOW CREATE TABLE table_ref + // show_create_table_stmt: SHOW CREATE (TABLE | VIEW) table_ref Ctx.BodyPart(); const auto& rule = core.GetAlt_sql_stmt_core62().GetRule_show_create_table_stmt1(); @@ -1969,8 +1969,16 @@ bool TSqlQuery::Statement(TVector<TNodePtr>& blocks, const TRule_sql_stmt_core& if (!SimpleTableRefImpl(rule.GetRule_simple_table_ref4(), tr)) { return false; } + TString type; + if (auto typeToken = to_lower(rule.GetToken3().GetValue()); typeToken == "table") { + type = "showCreateTable"; + } else if (typeToken == "view") { + type = "showCreateView"; + } else { + YQL_ENSURE(false, "Unsupported SHOW CREATE statement type: " << typeToken); + } - AddStatementToBlocks(blocks, BuildShowCreate(Ctx.Pos(), tr, Ctx.Scoped)); + AddStatementToBlocks(blocks, BuildShowCreate(Ctx.Pos(), tr, type, Ctx.Scoped)); break; } case TRule_sql_stmt_core::ALT_NOT_SET: @@ -3521,14 +3529,20 @@ TNodePtr TSqlQuery::Build(const TRule_delete_stmt& stmt) { TSourcePtr source = BuildTableSource(Ctx.Pos(), table); + const bool isBatch = stmt.HasBlock1(); TNodePtr options = nullptr; + if (stmt.HasBlock6()) { + if (isBatch) { + Ctx.Error(GetPos(stmt.GetToken2())) + << "BATCH DELETE is unsupported with RETURNING"; + return nullptr; + } + options = ReturningList(stmt.GetBlock6().GetRule_returning_columns_list1()); options = options->Y(options); } - const bool isBatch = stmt.HasBlock1(); - if (stmt.HasBlock5()) { switch (stmt.GetBlock5().Alt_case()) { case TRule_delete_stmt_TBlock5::kAlt1: { @@ -3585,14 +3599,20 @@ TNodePtr TSqlQuery::Build(const TRule_update_stmt& stmt) { return nullptr; } + const bool isBatch = stmt.HasBlock1(); TNodePtr options = nullptr; + if (stmt.HasBlock5()) { + if (isBatch) { + Ctx.Error(GetPos(stmt.GetToken2())) + << "BATCH UPDATE is unsupported with RETURNING"; + return nullptr; + } + options = ReturningList(stmt.GetBlock5().GetRule_returning_columns_list1()); options = options->Y(options); } - const bool isBatch = stmt.HasBlock1(); - switch (stmt.GetBlock4().Alt_case()) { case TRule_update_stmt_TBlock4::kAlt1: { const auto& alt = stmt.GetBlock4().GetAlt1(); diff --git a/yql/essentials/sql/v1/sql_ut_common.h b/yql/essentials/sql/v1/sql_ut_common.h index 36fe641ba63..564885f8c55 100644 --- a/yql/essentials/sql/v1/sql_ut_common.h +++ b/yql/essentials/sql/v1/sql_ut_common.h @@ -1475,6 +1475,12 @@ Y_UNIT_TEST_SUITE(SqlParsingOnly) { UNIT_ASSERT_VALUES_EQUAL(1, elementStat["Write"]); } + Y_UNIT_TEST(DeleteFromTableBatchReturning) { + NYql::TAstParseResult res = SqlToYql("batch delete from plato.Input returning *;", 10, "kikimr"); + UNIT_ASSERT(!res.Root); + UNIT_ASSERT_NO_DIFF(Err2Str(res), "<main>:1:6: Error: BATCH DELETE is unsupported with RETURNING\n"); + } + Y_UNIT_TEST(DeleteFromTableOnValues) { NYql::TAstParseResult res = SqlToYql("delete from plato.Input on (key) values (1);", 10, "kikimr"); @@ -1559,6 +1565,12 @@ Y_UNIT_TEST_SUITE(SqlParsingOnly) { UNIT_ASSERT_VALUES_EQUAL(1, elementStat["Write"]); } + Y_UNIT_TEST(UpdateByValuesBatchReturning) { + NYql::TAstParseResult res = SqlToYql("batch update plato.Input set value = 'cool' where key = 200 returning key;", 10, "kikimr"); + UNIT_ASSERT(!res.Root); + UNIT_ASSERT_NO_DIFF(Err2Str(res), "<main>:1:6: Error: BATCH UPDATE is unsupported with RETURNING\n"); + } + Y_UNIT_TEST(UpdateByMultiValues) { NYql::TAstParseResult res = SqlToYql("update plato.Input set (key, value, subkey) = ('2','ddd',':') where key = 200;", 10, "kikimr"); UNIT_ASSERT(res.Root); @@ -3197,6 +3209,26 @@ Y_UNIT_TEST_SUITE(SqlParsingOnly) { UNIT_ASSERT_VALUES_EQUAL(1, elementStat["showCreateTable"]); } + Y_UNIT_TEST(ShowCreateView) { + NYql::TAstParseResult res = SqlToYql(R"( + USE plato; + SHOW CREATE VIEW user; + )"); + UNIT_ASSERT(res.Root); + + TVerifyLineFunc verifyLine = [](const TString& word, const TString& line) { + if (word == "Read") { + UNIT_ASSERT_STRING_CONTAINS(line, "showCreateView"); + } + }; + + TWordCountHive elementStat = {{"Read"}, {"showCreateView"}}; + VerifyProgram(res, elementStat, verifyLine); + + UNIT_ASSERT_VALUES_EQUAL(elementStat["Read"], 1); + UNIT_ASSERT_VALUES_EQUAL(elementStat["showCreateView"], 1); + } + Y_UNIT_TEST(OptionalAliases) { UNIT_ASSERT(SqlToYql("USE plato; SELECT foo FROM (SELECT key foo FROM Input);").IsOk()); UNIT_ASSERT(SqlToYql("USE plato; SELECT a.x FROM Input1 a JOIN Input2 b ON a.key = b.key;").IsOk()); diff --git a/yt/cpp/mapreduce/common/retry_lib.cpp b/yt/cpp/mapreduce/common/retry_lib.cpp index 53216bd3f86..e898bfc3814 100644 --- a/yt/cpp/mapreduce/common/retry_lib.cpp +++ b/yt/cpp/mapreduce/common/retry_lib.cpp @@ -227,6 +227,7 @@ static TMaybe<TDuration> TryGetBackoffDuration(const TErrorResponse& errorRespon NSequoiaClient::SequoiaRetriableError, NRpc::TransientFailure, Canceled, + Timeout, }) { if (allCodes.contains(code)) { return config->RetryInterval; diff --git a/yt/yql/providers/yt/codec/yt_codec.cpp b/yt/yql/providers/yt/codec/yt_codec.cpp index 6e64136937f..14e317df97f 100644 --- a/yt/yql/providers/yt/codec/yt_codec.cpp +++ b/yt/yql/providers/yt/codec/yt_codec.cpp @@ -305,6 +305,13 @@ void TMkqlIOSpecs::InitDecoder(NCommon::TCodecContext& codecCtx, } } + if (InputBlockRepresentation_ == EBlockRepresentation::BlockStruct) { + if (auto pos = rowType->FindMemberIndex(BlockLengthColumnName)) { + virtualColumns.insert(*pos); + decoder.FillBlockStructSize = pos; + } + } + THashSet<ui32> usedPos; for (ui32 index = 0; index < rowType->GetMembersCount(); ++index) { auto name = rowType->GetMemberNameStr(index); @@ -444,6 +451,7 @@ void TMkqlIOSpecs::InitInput(NCommon::TCodecContext& codecCtx, TSpecInfo localSpecInfo; TSpecInfo* specInfo = &localSpecInfo; TString decoderRefName = TStringBuilder() << "_internal" << inputIndex; + bool newSpec = false; if (inputSpecs[inputIndex].IsString()) { auto refName = inputSpecs[inputIndex].AsString(); decoderRefName = refName; @@ -453,9 +461,14 @@ void TMkqlIOSpecs::InitInput(NCommon::TCodecContext& codecCtx, Y_ENSURE(inAttrs.HasKey(YqlIOSpecRegistry) && inAttrs[YqlIOSpecRegistry].HasKey(refName), "Bad input registry reference: " << refName); specInfo = &specInfoRegistry[refName]; LoadSpecInfo(true, inAttrs[YqlIOSpecRegistry][refName], codecCtx, *specInfo); + newSpec = true; } } else { LoadSpecInfo(true, inputSpecs[inputIndex], codecCtx, localSpecInfo); + newSpec = true; + } + if (InputBlockRepresentation_ == EBlockRepresentation::BlockStruct && newSpec) { + specInfo->Type = codecCtx.Builder.NewStructType(specInfo->Type, BlockLengthColumnName, TDataType::Create(NUdf::TDataType<ui64>::Id, codecCtx.Env)); } TStructType* inStruct = AS_TYPE(TStructType, specInfo->Type); diff --git a/yt/yql/providers/yt/codec/yt_codec.h b/yt/yql/providers/yt/codec/yt_codec.h index 4e6ef543a5f..ed309564799 100644 --- a/yt/yql/providers/yt/codec/yt_codec.h +++ b/yt/yql/providers/yt/codec/yt_codec.h @@ -31,6 +31,12 @@ public: Y_DECLARE_FLAGS(TSystemFields, ESystemField); + enum class EBlockRepresentation { + None, + WideBlock, + BlockStruct, + }; + struct TSpecInfo { NKikimr::NMiniKQL::TType* Type = nullptr; bool StrictSchema = true; @@ -65,6 +71,7 @@ public: TMaybe<ui32> FillSysColumnIndex; TMaybe<ui32> FillSysColumnNum; TMaybe<ui32> FillSysColumnKeySwitch; + TMaybe<ui32> FillBlockStructSize; }; struct TEncoderSpec { @@ -137,6 +144,10 @@ public: IsTableContent_ = true; } + void SetInputBlockRepresentation(EBlockRepresentation type) { + InputBlockRepresentation_ = type; + } + void SetTableOffsets(const TVector<ui64>& offsets); void Clear(); @@ -156,6 +167,8 @@ public: TString OptLLVM_; TSystemFields SystemFields_; + EBlockRepresentation InputBlockRepresentation_ = EBlockRepresentation::None; + NKikimr::NMiniKQL::IStatsRegistry* JobStats_ = nullptr; THashMap<TString, TDecoderSpec> Decoders; TVector<const TDecoderSpec*> Inputs; diff --git a/yt/yql/providers/yt/codec/yt_codec_io.cpp b/yt/yql/providers/yt/codec/yt_codec_io.cpp index 0a6b31f4e77..a46ecc0f680 100644 --- a/yt/yql/providers/yt/codec/yt_codec_io.cpp +++ b/yt/yql/providers/yt/codec/yt_codec_io.cpp @@ -651,7 +651,7 @@ struct TMkqlReaderImpl::TDecoder { KeySwitch_ = false; } - void Reset(bool hasRangeIndices, ui32 tableIndex, bool ignoreStreamTableIndex) { + virtual void Reset(bool hasRangeIndices, ui32 tableIndex, bool ignoreStreamTableIndex) { HasRangeIndices_ = hasRangeIndices; TableIndex_ = tableIndex; AtStart_ = true; @@ -1463,7 +1463,7 @@ public: , Pool_(pool) { InputStream_ = std::make_unique<TInputBufArrowInputStream>(buf, pool); - ResetColumnConverters(); + HandleTableSwitch(); HandlesSysColumns_ = true; } @@ -1482,14 +1482,19 @@ public: YQL_ENSURE(!Chunks_.empty()); } + bool isWideBlock = (Specs_.InputBlockRepresentation_ == TMkqlIOSpecs::EBlockRepresentation::WideBlock); + auto& decoder = *Specs_.Inputs[TableIndex_]; - Row_ = SpecsCache_.NewRow(TableIndex_, items, true); + Row_ = SpecsCache_.NewRow(TableIndex_, items, isWideBlock); auto& [chunkRowIndex, chunkLen, chunk] = Chunks_.front(); for (size_t i = 0; i < decoder.StructSize; i++) { + if (i == decoder.FillBlockStructSize) { + continue; + } items[i] = SpecsCache_.GetHolderFactory().CreateArrowBlock(std::move(chunk[i])); } - items[decoder.StructSize] = SpecsCache_.GetHolderFactory().CreateArrowBlock(arrow::Datum(static_cast<uint64_t>(chunkLen))); + items[BlockSizeStructIndex_] = SpecsCache_.GetHolderFactory().CreateArrowBlock(arrow::Datum(static_cast<uint64_t>(chunkLen))); RowIndex_ = chunkRowIndex; Chunks_.pop_front(); @@ -1505,17 +1510,17 @@ public: } StreamReader_ = ARROW_RESULT(streamReaderResult); - auto oldTableIndex = TableIndex_; if (!IgnoreStreamTableIndex) { + auto oldTableIndex = TableIndex_; auto tableIdKey = StreamReader_->schema()->metadata()->Get("TableId"); if (tableIdKey.ok()) { TableIndex_ = std::stoi(tableIdKey.ValueOrDie()); YQL_ENSURE(TableIndex_ < Specs_.Inputs.size()); } - } - if (TableIndex_ != oldTableIndex) { - ResetColumnConverters(); + if (TableIndex_ != oldTableIndex) { + HandleTableSwitch(); + } } } @@ -1523,6 +1528,8 @@ public: ARROW_OK(StreamReader_->ReadNext(&batch)); if (!batch) { if (InputStream_->EOSReached()) { + // Prepare for possible table switch + StreamReader_.reset(); return false; } @@ -1565,6 +1572,9 @@ public: } } else if (decoder.FillSysColumnIndex == inputFields[i].StructIndex) { convertedColumn = ARROW_RESULT(arrow::MakeArrayFromScalar(arrow::UInt32Scalar(TableIndex_), batch->num_rows())); + } else if (decoder.FillBlockStructSize == inputFields[i].StructIndex) { + // Actual value will be specified later + convertedColumn = arrow::Datum(static_cast<uint64_t>(0)); } else if (inputFields[i].StructIndex == Max<ui32>()) { // Input field won't appear in the result continue; @@ -1593,14 +1603,22 @@ public: return true; } - void ResetColumnConverters() { - auto& fields = Specs_.Inputs[TableIndex_]->FieldsVec; + void HandleTableSwitch() { + auto& decoder = Specs_.Inputs[TableIndex_]; + ColumnConverters_.clear(); - ColumnConverters_.reserve(fields.size()); - for (auto& field: fields) { + ColumnConverters_.reserve(decoder->FieldsVec.size()); + for (auto& field: decoder->FieldsVec) { YQL_ENSURE(!field.Type->IsPg()); ColumnConverters_.emplace_back(MakeYtColumnConverter(field.Type, nullptr, *Pool_, Specs_.Inputs[TableIndex_]->NativeYtTypeFlags)); } + + BlockSizeStructIndex_ = GetBlockSizeStructIndex(Specs_, TableIndex_); + } + + void Reset(bool hasRangeIndices, ui32 tableIndex, bool ignoreStreamTableIndex) override { + TDecoder::Reset(hasRangeIndices, tableIndex, ignoreStreamTableIndex); + HandleTableSwitch(); } private: @@ -1610,6 +1628,8 @@ private: TDeque<std::tuple<ui64, ui64, std::vector<arrow::Datum>>> Chunks_; + size_t BlockSizeStructIndex_ = 0; + const TMkqlIOSpecs& Specs_; arrow::MemoryPool* Pool_; }; @@ -2517,6 +2537,27 @@ void DecodeToYson(TMkqlIOCache& specsCache, size_t tableIndex, const NUdf::TUnbo WriteRowItems(specsCache, tableIndex, items, {}, ysonOut); } +ui32 GetBlockSizeStructIndex(const TMkqlIOSpecs& specs, size_t tableIndex) { + auto& decoder = specs.Inputs[tableIndex]; + + ui32 blockSizeStructIndex = 0; + switch (specs.InputBlockRepresentation_) { + case TMkqlIOSpecs::EBlockRepresentation::WideBlock: + blockSizeStructIndex = decoder->StructSize; + break; + + case TMkqlIOSpecs::EBlockRepresentation::BlockStruct: + YQL_ENSURE(decoder->FillBlockStructSize.Defined()); + blockSizeStructIndex = *decoder->FillBlockStructSize; + break; + + default: + YQL_ENSURE(false, "unknown block representation"); + } + + return blockSizeStructIndex; +} + ////////////////////////////////////////////////////////////////////////////////////////////////////////// } // NYql diff --git a/yt/yql/providers/yt/codec/yt_codec_io.h b/yt/yql/providers/yt/codec/yt_codec_io.h index 47f8d098635..3a8c4212954 100644 --- a/yt/yql/providers/yt/codec/yt_codec_io.h +++ b/yt/yql/providers/yt/codec/yt_codec_io.h @@ -164,4 +164,6 @@ void DecodeToYson(TMkqlIOCache& specsCache, size_t tableIndex, const NKikimr::NU THolder<NCommon::IBlockReader> MakeBlockReader(NYT::TRawTableReader& source, size_t blockCount, size_t blockSize); +ui32 GetBlockSizeStructIndex(const TMkqlIOSpecs& specs, size_t tableIndex); + } // NYql diff --git a/yt/yql/providers/yt/comp_nodes/yql_mkql_block_table_content.cpp b/yt/yql/providers/yt/comp_nodes/yql_mkql_block_table_content.cpp index d935da20041..d4b65187b4f 100644 --- a/yt/yql/providers/yt/comp_nodes/yql_mkql_block_table_content.cpp +++ b/yt/yql/providers/yt/comp_nodes/yql_mkql_block_table_content.cpp @@ -1,5 +1,5 @@ #include "yql_mkql_block_table_content.h" -#include "yql_mkql_file_block_stream.h" +#include "yql_mkql_file_list.h" #include <yql/essentials/minikql/computation/mkql_computation_node_impl.h> #include <yql/essentials/minikql/mkql_node_cast.h> @@ -20,19 +20,20 @@ class TYtBlockTableContentWrapper : public TMutableComputationNode<TYtBlockTable typedef TMutableComputationNode<TYtBlockTableContentWrapper> TBaseComputation; public: TYtBlockTableContentWrapper(TComputationMutables& mutables, NCommon::TCodecContext& codecCtx, - TVector<TString>&& files, const TString& inputSpec, TStructType* origStructType, bool decompress, std::optional<ui64> expectedRowCount) + TVector<TString>&& files, const TString& inputSpec, TType* listType, bool decompress, std::optional<ui64> expectedRowCount) : TBaseComputation(mutables) , Files_(std::move(files)) , Decompress_(decompress) , ExpectedRowCount_(std::move(expectedRowCount)) { Spec_.SetUseBlockInput(); + Spec_.SetInputBlockRepresentation(TMkqlIOSpecs::EBlockRepresentation::BlockStruct); Spec_.SetIsTableContent(); - Spec_.Init(codecCtx, inputSpec, {}, {}, origStructType, {}, TString()); + Spec_.Init(codecCtx, inputSpec, {}, {}, AS_TYPE(TListType, listType)->GetItemType(), {}, TString()); } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - return ctx.HolderFactory.Create<TFileWideBlockStreamValue>(Spec_, ctx.HolderFactory, Files_, Decompress_, 4, 1_MB, ExpectedRowCount_); + return ctx.HolderFactory.Create<TFileListValue>(Spec_, ctx.HolderFactory, Files_, Decompress_, 4, 1_MB, ExpectedRowCount_); } private: @@ -47,15 +48,14 @@ private: IComputationNode* WrapYtBlockTableContent(NCommon::TCodecContext& codecCtx, TComputationMutables& mutables, TCallable& callable, TStringBuf pathPrefix) { - MKQL_ENSURE(callable.GetInputsCount() == 6, "Expected 6 arguments"); + MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 arguments"); TString uniqueId(AS_VALUE(TDataLiteral, callable.GetInput(0))->AsValue().AsStringRef()); - auto origStructType = AS_TYPE(TStructType, AS_VALUE(TTypeType, callable.GetInput(1))); - const ui32 tablesCount = AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().Get<ui32>(); - TString inputSpec(AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().AsStringRef()); - const bool decompress = AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().Get<bool>(); + const ui32 tablesCount = AS_VALUE(TDataLiteral, callable.GetInput(1))->AsValue().Get<ui32>(); + TString inputSpec(AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().AsStringRef()); + const bool decompress = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<bool>(); std::optional<ui64> length; - TTupleLiteral* lengthTuple = AS_VALUE(TTupleLiteral, callable.GetInput(5)); + TTupleLiteral* lengthTuple = AS_VALUE(TTupleLiteral, callable.GetInput(4)); if (lengthTuple->GetValuesCount() > 0) { MKQL_ENSURE(lengthTuple->GetValuesCount() == 1, "Expect 1 element in the length tuple"); length = AS_VALUE(TDataLiteral, lengthTuple->GetValue(0))->AsValue().Get<ui64>(); @@ -67,7 +67,7 @@ IComputationNode* WrapYtBlockTableContent(NCommon::TCodecContext& codecCtx, } return new TYtBlockTableContentWrapper(mutables, codecCtx, std::move(files), inputSpec, - origStructType, decompress, length); + callable.GetType()->GetReturnType(), decompress, length); } } // NYql diff --git a/yt/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.cpp b/yt/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.cpp index d814246f76e..ae32ee6bb54 100644 --- a/yt/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.cpp +++ b/yt/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.cpp @@ -57,7 +57,8 @@ bool TFileInputState::NextValue() { MkqlReader_.Next(); if (Spec_->UseBlockInput_) { - auto blockCountValue = CurrentValue_.GetElement(Spec_->Inputs[CurrentInput_]->StructSize); + auto blockSizeStructIndex = GetBlockSizeStructIndex(*Spec_, CurrentInput_); + auto blockCountValue = CurrentValue_.GetElement(blockSizeStructIndex); CurrentRecord_ += GetBlockCount(blockCountValue); } else { ++CurrentRecord_; diff --git a/yt/yql/providers/yt/comp_nodes/yql_mkql_file_list.cpp b/yt/yql/providers/yt/comp_nodes/yql_mkql_file_list.cpp index 7d720cbbd5d..410abc6ca93 100644 --- a/yt/yql/providers/yt/comp_nodes/yql_mkql_file_list.cpp +++ b/yt/yql/providers/yt/comp_nodes/yql_mkql_file_list.cpp @@ -1,14 +1,16 @@ #include "yql_mkql_file_list.h" -#include "yql_mkql_file_input_state.h" + +#include <yql/essentials/minikql/computation/mkql_block_impl.h> namespace NYql { using namespace NKikimr::NMiniKQL; -TFileListValueBase::TIterator::TIterator(TMemoryUsageInfo* memInfo, THolder<IInputState>&& state, std::optional<ui64> length) +TFileListValueBase::TIterator::TIterator(TMemoryUsageInfo* memInfo, const TMkqlIOSpecs& spec, THolder<TFileInputState>&& state, std::optional<ui64> length) : TComputationValue(memInfo) , State_(std::move(state)) , ExpectedLength_(std::move(length)) + , Spec_(spec) { } @@ -22,19 +24,25 @@ bool TFileListValueBase::TIterator::Next(NUdf::TUnboxedValue& value) { return false; } + value = State_->GetCurrent(); if (ExpectedLength_) { MKQL_ENSURE(*ExpectedLength_ > 0, "Invalid file length. State: " << State_->DebugInfo()); - --(*ExpectedLength_); + if (Spec_.UseBlockInput_) { + auto blockSizeStructIndex = GetBlockSizeStructIndex(Spec_, State_->GetTableIndex()); + auto blockCountValue = value.GetElement(blockSizeStructIndex); + (*ExpectedLength_) -= GetBlockCount(blockCountValue); + } else { + --(*ExpectedLength_); + } } - value = State_->GetCurrent(); return true; } NUdf::TUnboxedValue TFileListValueBase::GetListIterator() const { - return NUdf::TUnboxedValuePod(new TIterator(GetMemInfo(), MakeState(), Length)); + return NUdf::TUnboxedValuePod(new TIterator(GetMemInfo(), Spec, MakeState(), Length)); } -THolder<IInputState> TFileListValue::MakeState() const { +THolder<TFileInputState> TFileListValue::MakeState() const { return MakeHolder<TFileInputState>(Spec, HolderFactory, MakeMkqlFileInputs(FilePaths, Decompress), BlockCount, BlockSize); } diff --git a/yt/yql/providers/yt/comp_nodes/yql_mkql_file_list.h b/yt/yql/providers/yt/comp_nodes/yql_mkql_file_list.h index 912a083efe5..aa0b8d184c9 100644 --- a/yt/yql/providers/yt/comp_nodes/yql_mkql_file_list.h +++ b/yt/yql/providers/yt/comp_nodes/yql_mkql_file_list.h @@ -3,6 +3,7 @@ #include "yql_mkql_input_stream.h" #include <yt/yql/providers/yt/codec/yt_codec.h> +#include <yt/yql/providers/yt/comp_nodes/yql_mkql_file_input_state.h> #include <yql/essentials/minikql/computation/mkql_computation_node.h> #include <yql/essentials/minikql/computation/mkql_custom_list.h> @@ -28,19 +29,21 @@ public: protected: class TIterator : public NKikimr::NMiniKQL::TComputationValue<TIterator> { public: - TIterator(NKikimr::NMiniKQL::TMemoryUsageInfo* memInfo, THolder<IInputState>&& state, std::optional<ui64> length); + TIterator(NKikimr::NMiniKQL::TMemoryUsageInfo* memInfo, const TMkqlIOSpecs& spec, THolder<TFileInputState>&& state, std::optional<ui64> length); private: bool Next(NUdf::TUnboxedValue& value) override; bool AtStart_ = true; - THolder<IInputState> State_; + THolder<TFileInputState> State_; std::optional<ui64> ExpectedLength_; + + const TMkqlIOSpecs& Spec_; }; NUdf::TUnboxedValue GetListIterator() const override; - virtual THolder<IInputState> MakeState() const = 0; + virtual THolder<TFileInputState> MakeState() const = 0; protected: const TMkqlIOSpecs& Spec; @@ -66,7 +69,7 @@ public: } protected: - THolder<IInputState> MakeState() const override; + THolder<TFileInputState> MakeState() const override; private: const TVector<TString> FilePaths; diff --git a/yt/yql/providers/yt/gateway/file/yql_yt_file_comp_nodes.cpp b/yt/yql/providers/yt/gateway/file/yql_yt_file_comp_nodes.cpp index 65453c8137e..889c06436f7 100644 --- a/yt/yql/providers/yt/gateway/file/yql_yt_file_comp_nodes.cpp +++ b/yt/yql/providers/yt/gateway/file/yql_yt_file_comp_nodes.cpp @@ -102,7 +102,7 @@ public: } protected: - THolder<IInputState> MakeState() const override { + THolder<TFileInputState> MakeState() const override { return MakeHolder<TFileInputStateWithTableState>(Spec, HolderFactory, MakeTextYsonInputs(TablePaths_), 0u, 1_MB, TTableState(TableState_)); } diff --git a/yt/yql/providers/yt/gateway/file/yql_yt_file_mkql_compiler.cpp b/yt/yql/providers/yt/gateway/file/yql_yt_file_mkql_compiler.cpp index f1ffba0d200..2296595dc7e 100644 --- a/yt/yql/providers/yt/gateway/file/yql_yt_file_mkql_compiler.cpp +++ b/yt/yql/providers/yt/gateway/file/yql_yt_file_mkql_compiler.cpp @@ -639,13 +639,7 @@ void RegisterYtFileMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler) { output.Ref(), itemsCount, ctx, true); } - return ctx.ProgramBuilder.WideToBlocks(ctx.ProgramBuilder.FromFlow(ctx.ProgramBuilder.ExpandMap(ctx.ProgramBuilder.ToFlow(values), [&](TRuntimeNode item) -> TRuntimeNode::TList { - TRuntimeNode::TList result; - for (auto& origItem : origItemStructType->GetItems()) { - result.push_back(ctx.ProgramBuilder.Member(item, origItem->GetName())); - } - return result; - }))); + return ctx.ProgramBuilder.ListToBlocks(values); }); compiler.AddCallable({TYtSort::CallableName(), TYtCopy::CallableName(), TYtMerge::CallableName()}, diff --git a/yt/yql/providers/yt/gateway/lib/yt_helpers.cpp b/yt/yql/providers/yt/gateway/lib/yt_helpers.cpp index 6ad53ec6d99..dd4ba61d7c5 100644 --- a/yt/yql/providers/yt/gateway/lib/yt_helpers.cpp +++ b/yt/yql/providers/yt/gateway/lib/yt_helpers.cpp @@ -369,7 +369,7 @@ static bool IterateRows(NYT::ITransactionPtr tx, } else { auto format = specsCache.GetSpecs().MakeInputFormat(tableIndex); auto rawReader = tx->CreateRawReader(path, format, readerOptions); - TMkqlReaderImpl reader(*rawReader, 0, 4 << 10, tableIndex); + TMkqlReaderImpl reader(*rawReader, 0, 4 << 10, tableIndex, true); reader.SetSpecs(specsCache.GetSpecs(), specsCache.GetHolderFactory()); for (reader.Next(); reader.IsValid(); reader.Next()) { diff --git a/yt/yql/providers/yt/gateway/native/yql_yt_transform.cpp b/yt/yql/providers/yt/gateway/native/yql_yt_transform.cpp index 26af2460bf8..18e6204a58d 100644 --- a/yt/yql/providers/yt/gateway/native/yql_yt_transform.cpp +++ b/yt/yql/providers/yt/gateway/native/yql_yt_transform.cpp @@ -82,11 +82,7 @@ TCallableVisitFunc TGatewayTransformer::operator()(TInternName internName) { if (EPhase::Content == Phase_ || EPhase::All == Phase_) { return [&, name, useBlocks](NMiniKQL::TCallable& callable, const TTypeEnvironment& env) { - if (useBlocks) { - YQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args"); - } else { - YQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args"); - } + YQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args"); const TString cluster = ExecCtx_.Cluster_; const TString tmpFolder = GetTablesTmpFolder(*Settings_); @@ -331,7 +327,6 @@ TCallableVisitFunc TGatewayTransformer::operator()(TInternName internName) { callable.GetType()->GetReturnType()); if (useBlocks) { call.Add(PgmBuilder_.NewDataLiteral<NUdf::EDataSlot::String>(uniqueId)); - call.Add(callable.GetInput(3)); // orig struct type call.Add(PgmBuilder_.NewDataLiteral(tableList->GetItemsCount())); call.Add(PgmBuilder_.NewDataLiteral<NUdf::EDataSlot::String>(NYT::NodeToYsonString(specNode))); call.Add(PgmBuilder_.NewDataLiteral(ETableContentDeliveryMode::File == deliveryMode)); // use compression diff --git a/yt/yql/providers/yt/job/yql_job_user.cpp b/yt/yql/providers/yt/job/yql_job_user.cpp index f9e923ccffc..0b355e6609d 100644 --- a/yt/yql/providers/yt/job/yql_job_user.cpp +++ b/yt/yql/providers/yt/job/yql_job_user.cpp @@ -171,6 +171,7 @@ void TYqlUserJob::DoImpl(const TFile& inHandle, const TVector<TFile>& outHandles } if (UseBlockInput) { MkqlIOSpecs->SetUseBlockInput(); + MkqlIOSpecs->SetInputBlockRepresentation(TMkqlIOSpecs::EBlockRepresentation::WideBlock); } if (UseBlockOutput) { MkqlIOSpecs->SetUseBlockOutput(); diff --git a/yt/yql/providers/yt/lib/expr_traits/yql_expr_traits.cpp b/yt/yql/providers/yt/lib/expr_traits/yql_expr_traits.cpp index a55d9baf659..8d6b31a1571 100644 --- a/yt/yql/providers/yt/lib/expr_traits/yql_expr_traits.cpp +++ b/yt/yql/providers/yt/lib/expr_traits/yql_expr_traits.cpp @@ -372,6 +372,7 @@ namespace NYql { TStringBuf("Last"), TStringBuf("ToDict"), TStringBuf("SqueezeToDict"), + TStringBuf("BlockStorage"), TStringBuf("Iterator"), // Why? TStringBuf("Collect"), TStringBuf("Length"), diff --git a/yt/yql/providers/yt/provider/yql_yt_block_input.cpp b/yt/yql/providers/yt/provider/yql_yt_block_input.cpp index 5f0f3e07396..37313115425 100644 --- a/yt/yql/providers/yt/provider/yql_yt_block_input.cpp +++ b/yt/yql/providers/yt/provider/yql_yt_block_input.cpp @@ -68,46 +68,19 @@ private: return EnsureWideFlowType(mapLambda.Cast().Args().Arg(0).Ref(), ctx); } - TMaybeNode<TExprBase> TryTransformTableContent(TExprBase node, TExprContext& ctx, const TGetParents& getParents) const { + TMaybeNode<TExprBase> TryTransformTableContent(TExprBase node, TExprContext& ctx) const { auto tableContent = node.Cast<TYtTableContent>(); if (!NYql::HasSetting(tableContent.Settings().Ref(), EYtSettingType::BlockInputReady)) { return tableContent; } - const TParentsMap* parentsMap = getParents(); - if (auto it = parentsMap->find(tableContent.Raw()); it != parentsMap->end() && it->second.size() > 1) { - return tableContent; - } - YQL_CLOG(INFO, ProviderYt) << "Rewrite YtTableContent with block input"; - auto inputStructType = GetSeqItemType(tableContent.Ref().GetTypeAnn())->Cast<TStructExprType>(); - auto asStructBuilder = Build<TCoAsStruct>(ctx, tableContent.Pos()); - TExprNode::TListType narrowMapArgs; - for (auto& item : inputStructType->GetItems()) { - auto arg = ctx.NewArgument(tableContent.Pos(), item->GetName()); - asStructBuilder.Add<TCoNameValueTuple>() - .Name().Build(item->GetName()) - .Value(arg) - .Build(); - narrowMapArgs.push_back(std::move(arg)); - } - auto settings = RemoveSetting(tableContent.Settings().Ref(), EYtSettingType::BlockInputReady, ctx); - return Build<TCoForwardList>(ctx, tableContent.Pos()) - .Stream<TCoNarrowMap>() - .Input<TCoToFlow>() - .Input<TCoWideFromBlocks>() - .Input<TYtBlockTableContent>() - .Input(tableContent.Input()) - .Settings(settings) - .Build() - .Build() - .Build() - .Lambda() - .Args(narrowMapArgs) - .Body(asStructBuilder.Done()) - .Build() + return Build<TCoListFromBlocks>(ctx, tableContent.Pos()) + .Input<TYtBlockTableContent>() + .Input(tableContent.Input()) + .Settings(settings) .Build() .Done(); } diff --git a/yt/yql/providers/yt/provider/yql_yt_datasource_constraints.cpp b/yt/yql/providers/yt/provider/yql_yt_datasource_constraints.cpp index fcad1669af7..9e3bd1fbc2c 100644 --- a/yt/yql/providers/yt/provider/yql_yt_datasource_constraints.cpp +++ b/yt/yql/providers/yt/provider/yql_yt_datasource_constraints.cpp @@ -189,42 +189,9 @@ public: return TStatus::Ok; } - TStatus HandleBlockTableContent(TExprBase input, TExprContext& ctx) { + TStatus HandleBlockTableContent(TExprBase input, TExprContext& /*ctx*/) { TYtBlockTableContent tableContent = input.Cast<TYtBlockTableContent>(); - - auto listType = tableContent.Input().Maybe<TYtOutput>() - ? tableContent.Input().Ref().GetTypeAnn() - : tableContent.Input().Ref().GetTypeAnn()->Cast<TTupleExprType>()->GetItems().back(); - auto itemStructType = listType->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>(); - - auto pathRename = [&](TPartOfConstraintBase::TPathType path) -> std::vector<TPartOfConstraintBase::TPathType> { - YQL_ENSURE(!path.empty()); - - auto fieldIndex = itemStructType->FindItem(path[0]); - YQL_ENSURE(fieldIndex.Defined()); - - path[0] = ctx.GetIndexAsString(*fieldIndex); - return { path }; - }; - - TConstraintSet wideConstraints; - for (auto constraint : tableContent.Input().Ref().GetAllConstraints()) { - if (auto empty = dynamic_cast<const TEmptyConstraintNode*>(constraint)) { - wideConstraints.AddConstraint(ctx.MakeConstraint<TEmptyConstraintNode>()); - } else if (auto sorted = dynamic_cast<const TSortedConstraintNode*>(constraint)) { - wideConstraints.AddConstraint(sorted->RenameFields(ctx, pathRename)); - } else if (auto chopped = dynamic_cast<const TChoppedConstraintNode*>(constraint)) { - wideConstraints.AddConstraint(chopped->RenameFields(ctx, pathRename)); - } else if (auto unique = dynamic_cast<const TUniqueConstraintNode*>(constraint)) { - wideConstraints.AddConstraint(unique->RenameFields(ctx, pathRename)); - } else if (auto distinct = dynamic_cast<const TDistinctConstraintNode*>(constraint)) { - wideConstraints.AddConstraint(distinct->RenameFields(ctx, pathRename)); - } else { - YQL_ENSURE(false, "unexpected constraint"); - } - } - - input.Ptr()->SetConstraints(wideConstraints); + input.Ptr()->CopyConstraints(tableContent.Input().Ref()); return TStatus::Ok; } diff --git a/yt/yql/providers/yt/provider/yql_yt_datasource_type_ann.cpp b/yt/yql/providers/yt/provider/yql_yt_datasource_type_ann.cpp index 69dec1f558d..84c811cc816 100644 --- a/yt/yql/providers/yt/provider/yql_yt_datasource_type_ann.cpp +++ b/yt/yql/providers/yt/provider/yql_yt_datasource_type_ann.cpp @@ -898,14 +898,26 @@ public: auto listType = tableContent.Input().Maybe<TYtOutput>() ? tableContent.Input().Ref().GetTypeAnn() : tableContent.Input().Ref().GetTypeAnn()->Cast<TTupleExprType>()->GetItems().back(); - auto itemStructType = listType->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>(); + auto tableStructType = listType->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>(); - TTypeAnnotationNode::TListType multiTypeItems; - for (auto& item: itemStructType->GetItems()) { - multiTypeItems.emplace_back(ctx.MakeType<TBlockExprType>(item->GetItemType())); + TVector<const TItemExprType*> outputStructItems; + for (auto item : tableStructType->GetItems()) { + auto itemType = item->GetItemType(); + if (itemType->IsBlockOrScalar()) { + ctx.AddError(TIssue(ctx.GetPosition(input.Pos()), "Input type should not be a block or scalar")); + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureSupportedAsBlockType(input.Pos(), *itemType, ctx, *State_->Types)) { + return IGraphTransformer::TStatus::Error; + } + + outputStructItems.push_back(ctx.MakeType<TItemExprType>(item->GetName(), ctx.MakeType<TBlockExprType>(itemType))); } - multiTypeItems.push_back(ctx.MakeType<TScalarExprType>(ctx.MakeType<TDataExprType>(EDataSlot::Uint64))); - input.Ptr()->SetTypeAnn(ctx.MakeType<TStreamExprType>(ctx.MakeType<TMultiExprType>(multiTypeItems))); + outputStructItems.push_back(ctx.MakeType<TItemExprType>(BlockLengthColumnName, ctx.MakeType<TScalarExprType>(ctx.MakeType<TDataExprType>(EDataSlot::Uint64)))); + + auto outputStructType = ctx.MakeType<TStructExprType>(outputStructItems); + input.Ptr()->SetTypeAnn(ctx.MakeType<TListExprType>(outputStructType)); if (auto columnOrder = State_->Types->LookupColumnOrder(tableContent.Input().Ref())) { return State_->Types->SetColumnOrder(input.Ref(), *columnOrder, ctx); diff --git a/yt/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp b/yt/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp index 0c3a1a208ba..a6c34dc8155 100644 --- a/yt/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp +++ b/yt/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp @@ -232,23 +232,13 @@ TRuntimeNode BuildTableContentCall(TStringBuf callName, samplingTupleItems.push_back(ctx.ProgramBuilder.NewDataLiteral(isSystemSampling)); } - TType* outType = nullptr; if (useBlocks) { - auto structType = AS_TYPE(TStructType, outItemType); - - std::vector<TType*> outputItems; - outputItems.reserve(structType->GetMembersCount()); - for (size_t i = 0; i < structType->GetMembersCount(); i++) { - outputItems.push_back(ctx.ProgramBuilder.NewBlockType(structType->GetMemberType(i), TBlockType::EShape::Many)); - } - outputItems.push_back(ctx.ProgramBuilder.NewBlockType(ctx.ProgramBuilder.NewDataType(NUdf::TDataType<ui64>::Id), TBlockType::EShape::Scalar)); - outType = ctx.ProgramBuilder.NewStreamType(ctx.ProgramBuilder.NewMultiType(outputItems)); - - } else { - outType = ctx.ProgramBuilder.NewListType(outItemType); + outItemType = ctx.ProgramBuilder.BuildBlockStructType(AS_TYPE(TStructType, outItemType)); } - TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), callName, outType); + auto outListType = ctx.ProgramBuilder.NewListType(outItemType); + + TCallableBuilder call(ctx.ProgramBuilder.GetTypeEnvironment(), callName, outListType); call.Add(ctx.ProgramBuilder.NewList(listTypeGroup, groups)); call.Add(ctx.ProgramBuilder.NewTuple(samplingTupleItems)); @@ -259,10 +249,6 @@ TRuntimeNode BuildTableContentCall(TStringBuf callName, call.Add(ctx.ProgramBuilder.NewEmptyTuple()); } - if (useBlocks) { - call.Add(TRuntimeNode(outItemType, true)); - } - auto res = TRuntimeNode(call.Build(), false); if (settings) { @@ -505,8 +491,8 @@ void RegisterYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler) { [](const TExprNode& node, NCommon::TMkqlBuildContext& ctx) { TYtBlockTableContent tableContent(&node); if (node.GetConstraint<TEmptyConstraintNode>()) { - const auto streamType = ctx.BuildType(node, *node.GetTypeAnn()); - return ctx.ProgramBuilder.EmptyIterator(streamType); + const auto itemType = ctx.BuildType(node, GetSeqItemType(*node.GetTypeAnn())); + return ctx.ProgramBuilder.NewEmptyList(itemType); } auto origItemStructType = ( diff --git a/yt/yt/client/chaos_client/replication_card.cpp b/yt/yt/client/chaos_client/replication_card.cpp index 78b2be376ca..9aa1561afe9 100644 --- a/yt/yt/client/chaos_client/replication_card.cpp +++ b/yt/yt/client/chaos_client/replication_card.cpp @@ -743,19 +743,31 @@ TReplicationProgress BuildMaxProgress( if (otherIt == otherEnd) { cmpResult = -1; - if (!upperKeySelected && CompareRows(progressIt->LowerKey, other.UpperKey) >= 0) { - upperKeySelected = true; - otherTimestamp = NullTimestamp; - tryAppendSegment(other.UpperKey, progressTimestamp); - continue; + if (!upperKeySelected) { + int upperKeyCmpResult = CompareRows(progressIt->LowerKey, other.UpperKey); + if (upperKeyCmpResult >= 0) { + upperKeySelected = true; + otherTimestamp = NullTimestamp; + if (upperKeyCmpResult > 0) { + // UpperKey is smaller than progressIt->LowerKey so there's a gap to fill with progressTimestamp. + tryAppendSegment(other.UpperKey, progressTimestamp); + continue; + } + } } } else if (progressIt == progressEnd) { cmpResult = 1; - if (!upperKeySelected && CompareRows(otherIt->LowerKey, progress.UpperKey) >= 0) { - upperKeySelected = true; - progressTimestamp = NullTimestamp; - tryAppendSegment(progress.UpperKey, otherTimestamp); - continue; + if (!upperKeySelected) { + int upperKeyCmpResult = CompareRows(otherIt->LowerKey, progress.UpperKey); + if (upperKeyCmpResult >= 0) { + upperKeySelected = true; + progressTimestamp = NullTimestamp; + if (upperKeyCmpResult > 0) { + // UpperKey is smaller than otherIt->LowerKey so there's a gap to fill with otherTimestamp. + tryAppendSegment(progress.UpperKey, otherTimestamp); + continue; + } + } } } else { cmpResult = CompareRows(progressIt->LowerKey, otherIt->LowerKey); @@ -902,4 +914,3 @@ THashMap<TReplicaId, TDuration> ComputeReplicasLag(const THashMap<TReplicaId, TR //////////////////////////////////////////////////////////////////////////////// } // namespace NYT::NChaosClient - diff --git a/yt/yt/client/tablet_client/config.cpp b/yt/yt/client/tablet_client/config.cpp index 6e30181969f..dc6a909d223 100644 --- a/yt/yt/client/tablet_client/config.cpp +++ b/yt/yt/client/tablet_client/config.cpp @@ -82,6 +82,12 @@ void TReplicatedTableOptions::Register(TRegistrar registrar) .Optional(); registrar.Parameter("min_sync_replica_count", &TThis::MinSyncReplicaCount) .Optional(); + registrar.Parameter("max_sync_queue_replica_count", &TThis::MaxSyncQueueReplicaCount) + .Optional() + .DontSerializeDefault(); + registrar.Parameter("min_sync_queue_replica_count", &TThis::MinSyncQueueReplicaCount) + .Optional() + .DontSerializeDefault(); registrar.Parameter("enable_replicated_table_tracker", &TThis::EnableReplicatedTableTracker) .Default(false); @@ -100,26 +106,60 @@ void TReplicatedTableOptions::Register(TRegistrar registrar) .Default(TDuration::Minutes(5)); registrar.Postprocessor([] (TThis* config) { - if (config->MaxSyncReplicaCount && config->MinSyncReplicaCount && *config->MinSyncReplicaCount > *config->MaxSyncReplicaCount) { + if (config->MaxSyncReplicaCount && + config->MinSyncReplicaCount && + config->MinSyncReplicaCount > config->MaxSyncReplicaCount) + { THROW_ERROR_EXCEPTION("\"min_sync_replica_count\" must be less or equal to \"max_sync_replica_count\""); } + + if (config->MaxSyncQueueReplicaCount && config->MaxSyncQueueReplicaCount < 2) { + THROW_ERROR_EXCEPTION("\"max_sync_queue_replica_count\" canot be less than 2, actual: %v", + config->MaxSyncQueueReplicaCount); + } + + if (config->MaxSyncQueueReplicaCount && + config->MinSyncQueueReplicaCount && + config->MinSyncQueueReplicaCount > config->MaxSyncQueueReplicaCount) + { + THROW_ERROR_EXCEPTION("\"min_sync_queue_replica_count\" must be less or equal to \"max_sync_queue_replica_count\""); + } }); } -std::tuple<int, int> TReplicatedTableOptions::GetEffectiveMinMaxReplicaCount(int replicaCount) const +std::tuple<int, int> TReplicatedTableOptions::GetEffectiveMinMaxReplicaCount( + ETableReplicaContentType contentType, + int replicaCount) const { - int maxSyncReplicas = 0; - int minSyncReplicas = 0; + auto getResult = [&] (auto minSyncReplicaCount, auto maxSyncReplicaCount) { + int maxSyncReplicas = 0; + int minSyncReplicas = 0; - if (!MaxSyncReplicaCount && !MinSyncReplicaCount) { - maxSyncReplicas = 1; - } else { - maxSyncReplicas = MaxSyncReplicaCount.value_or(replicaCount); - } + if (!maxSyncReplicaCount && !minSyncReplicaCount) { + maxSyncReplicas = 1; + } else { + maxSyncReplicas = maxSyncReplicaCount.value_or(replicaCount); + } + + minSyncReplicas = minSyncReplicaCount.value_or(maxSyncReplicas); - minSyncReplicas = MinSyncReplicaCount.value_or(maxSyncReplicas); + return std::tuple(minSyncReplicas, maxSyncReplicas); + }; - return std::tuple(minSyncReplicas, maxSyncReplicas); + if (contentType == ETableReplicaContentType::Queue) { + int minSyncReplicas; + int maxSyncReplicas; + if (MinSyncQueueReplicaCount || MaxSyncQueueReplicaCount) { + std::tie(minSyncReplicas, maxSyncReplicas) = getResult(MinSyncQueueReplicaCount, MaxSyncQueueReplicaCount); + } else { + std::tie(minSyncReplicas, maxSyncReplicas) = getResult(MinSyncReplicaCount, MaxSyncReplicaCount); + } + return std::tuple( + std::max(minSyncReplicas, 1), + std::max(maxSyncReplicas, 2)); + } else { + return getResult(MinSyncReplicaCount, MaxSyncReplicaCount); + } } //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/client/tablet_client/config.h b/yt/yt/client/tablet_client/config.h index 1e5a0531f4f..eaf16aa37c6 100644 --- a/yt/yt/client/tablet_client/config.h +++ b/yt/yt/client/tablet_client/config.h @@ -104,6 +104,9 @@ public: std::optional<int> MaxSyncReplicaCount; std::optional<int> MinSyncReplicaCount; + std::optional<int> MaxSyncQueueReplicaCount; + std::optional<int> MinSyncQueueReplicaCount; + TDuration SyncReplicaLagThreshold; // TODO(akozhikhov): We probably do not need these in this per-table config. @@ -113,7 +116,7 @@ public: bool EnablePreloadStateCheck; TDuration IncompletePreloadGracePeriod; - std::tuple<int, int> GetEffectiveMinMaxReplicaCount(int replicaCount) const; + std::tuple<int, int> GetEffectiveMinMaxReplicaCount(ETableReplicaContentType contentType, int replicaCount) const; REGISTER_YSON_STRUCT(TReplicatedTableOptions); diff --git a/yt/yt/client/unittests/replication_progress_ut.cpp b/yt/yt/client/unittests/replication_progress_ut.cpp index 5ef0cf52652..41746f544ce 100644 --- a/yt/yt/client/unittests/replication_progress_ut.cpp +++ b/yt/yt/client/unittests/replication_progress_ut.cpp @@ -735,8 +735,21 @@ INSTANTIATE_TEST_SUITE_P( "{segments=[{lower_key=[1];timestamp=1073741824};{lower_key=[2];timestamp=3221225472}];" "upper_key=[<type=max>#]}", "{segments=[{lower_key=[];timestamp=1073741824};{lower_key=[2];timestamp=3221225472}];" - "upper_key=[<type=max>#]}") - + "upper_key=[<type=max>#]}"), + std::tuple( + "{segments=[{lower_key=[2];timestamp=1}];upper_key=[<type=max>#]}", + "{segments=[{lower_key=[1];timestamp=1}];upper_key=[2]}", + "{segments=[{lower_key=[1];timestamp=1}];upper_key=[<type=max>#]}"), + std::tuple( + "{segments=[{lower_key=[3];timestamp=1}];upper_key=[<type=max>#]}", + "{segments=[{lower_key=[1];timestamp=1}];upper_key=[2]}", + "{segments=[{lower_key=[1];timestamp=1};{lower_key=[2];timestamp=0};{lower_key=[3];timestamp=1}];" + "upper_key=[<type=max>#]}"), + std::tuple( + "{segments=[{lower_key=[1];timestamp=1};{lower_key=[2];timestamp=0};{lower_key=[3];timestamp=1}];" + "upper_key=[<type=max>#]}", + "{segments=[{lower_key=[2];timestamp=1}];upper_key=[3]}", + "{segments=[{lower_key=[1];timestamp=1}];upper_key=[<type=max>#]}") )); //////////////////////////////////////////////////////////////////////////////// |
