diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2024-03-11 17:59:28 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-03-11 18:10:18 +0300 |
commit | a7a431a1c2a6704ba0cf2dbfe9e1c198945e8a9e (patch) | |
tree | 04c270807febe5f7da19a2510b450980f671f365 | |
parent | 7fa2009de5a7f9f102480fab66bdd624aa541755 (diff) | |
download | ydb-a7a431a1c2a6704ba0cf2dbfe9e1c198945e8a9e.tar.gz |
Intermediate changes
44 files changed, 1089 insertions, 239 deletions
diff --git a/contrib/python/hypothesis/py3/.dist-info/METADATA b/contrib/python/hypothesis/py3/.dist-info/METADATA index 95fe6d3510..717ae021d6 100644 --- a/contrib/python/hypothesis/py3/.dist-info/METADATA +++ b/contrib/python/hypothesis/py3/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: hypothesis -Version: 6.98.11 +Version: 6.98.12 Summary: A library for property-based testing Home-page: https://hypothesis.works Author: David R. MacIver and Zac Hatfield-Dodds diff --git a/contrib/python/hypothesis/py3/hypothesis/extra/ghostwriter.py b/contrib/python/hypothesis/py3/hypothesis/extra/ghostwriter.py index 8917d5bd87..2854b48c29 100644 --- a/contrib/python/hypothesis/py3/hypothesis/extra/ghostwriter.py +++ b/contrib/python/hypothesis/py3/hypothesis/extra/ghostwriter.py @@ -122,7 +122,7 @@ from hypothesis.strategies._internal.flatmapped import FlatMapStrategy from hypothesis.strategies._internal.lazy import LazyStrategy, unwrap_strategies from hypothesis.strategies._internal.strategies import ( FilteredStrategy, - MappedSearchStrategy, + MappedStrategy, OneOfStrategy, SampledFromStrategy, ) @@ -627,7 +627,7 @@ def _imports_for_strategy(strategy): strategy = unwrap_strategies(strategy) # Get imports for s.map(f), s.filter(f), s.flatmap(f), including both s and f - if isinstance(strategy, MappedSearchStrategy): + if isinstance(strategy, MappedStrategy): imports |= _imports_for_strategy(strategy.mapped_strategy) imports |= _imports_for_object(strategy.pack) if isinstance(strategy, FilteredStrategy): diff --git a/contrib/python/hypothesis/py3/hypothesis/extra/numpy.py b/contrib/python/hypothesis/py3/hypothesis/extra/numpy.py index 29d73f76be..4cfb1ca8d8 100644 --- a/contrib/python/hypothesis/py3/hypothesis/extra/numpy.py +++ b/contrib/python/hypothesis/py3/hypothesis/extra/numpy.py @@ -50,7 +50,7 @@ from hypothesis.strategies._internal.lazy import unwrap_strategies from hypothesis.strategies._internal.numbers import Real from hypothesis.strategies._internal.strategies import ( Ex, - MappedSearchStrategy, + MappedStrategy, T, check_strategy, ) @@ -516,7 +516,7 @@ def arrays( # If there's a redundant cast to the requested dtype, remove it. This unlocks # optimizations such as fast unique sampled_from, and saves some time directly too. unwrapped = unwrap_strategies(elements) - if isinstance(unwrapped, MappedSearchStrategy) and unwrapped.pack == dtype.type: + if isinstance(unwrapped, MappedStrategy) and unwrapped.pack == dtype.type: elements = unwrapped.mapped_strategy if isinstance(shape, int): shape = (shape,) diff --git a/contrib/python/hypothesis/py3/hypothesis/internal/filtering.py b/contrib/python/hypothesis/py3/hypothesis/internal/filtering.py index 4ba92b1da8..f352e9cb6d 100644 --- a/contrib/python/hypothesis/py3/hypothesis/internal/filtering.py +++ b/contrib/python/hypothesis/py3/hypothesis/internal/filtering.py @@ -33,7 +33,10 @@ from typing import Any, Callable, Collection, Dict, NamedTuple, Optional, TypeVa from hypothesis.internal.compat import ceil, floor from hypothesis.internal.floats import next_down, next_up -from hypothesis.internal.reflection import extract_lambda_source +from hypothesis.internal.reflection import ( + extract_lambda_source, + get_pretty_function_description, +) Ex = TypeVar("Ex") Predicate = Callable[[Ex], bool] @@ -64,6 +67,10 @@ class ConstructivePredicate(NamedTuple): def unchanged(cls, predicate: Predicate) -> "ConstructivePredicate": return cls({}, predicate) + def __repr__(self) -> str: + fn = get_pretty_function_description(self.predicate) + return f"{self.__class__.__name__}(kwargs={self.kwargs!r}, predicate={fn})" + ARG = object() @@ -147,8 +154,8 @@ def merge_preds(*con_predicates: ConstructivePredicate) -> ConstructivePredicate elif kw["max_value"] == base["max_value"]: base["exclude_max"] |= kw.get("exclude_max", False) - has_len = {"len" in kw for kw, _ in con_predicates} - assert len(has_len) == 1, "can't mix numeric with length constraints" + has_len = {"len" in kw for kw, _ in con_predicates if kw} + assert len(has_len) <= 1, "can't mix numeric with length constraints" if has_len == {True}: base["len"] = True diff --git a/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/collections.py b/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/collections.py index e8f8f21ba4..75de4a82ec 100644 --- a/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/collections.py +++ b/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/collections.py @@ -23,7 +23,7 @@ from hypothesis.strategies._internal.strategies import ( T4, T5, Ex, - MappedSearchStrategy, + MappedStrategy, SearchStrategy, T, check_strategy, @@ -211,6 +211,9 @@ class ListStrategy(SearchStrategy): new = copy.copy(self) new.min_size = max(self.min_size, kwargs.get("min_value", self.min_size)) new.max_size = min(self.max_size, kwargs.get("max_value", self.max_size)) + # Unsatisfiable filters are easiest to understand without rewriting. + if new.min_size > new.max_size: + return SearchStrategy.filter(self, condition) # Recompute average size; this is cheaper than making it into a property. new.average_size = min( max(new.min_size * 2, new.min_size + 5), @@ -302,7 +305,7 @@ class UniqueSampledListStrategy(UniqueListStrategy): return result -class FixedKeysDictStrategy(MappedSearchStrategy): +class FixedKeysDictStrategy(MappedStrategy): """A strategy which produces dicts with a fixed set of keys, given a strategy for each of their equivalent values. @@ -311,9 +314,12 @@ class FixedKeysDictStrategy(MappedSearchStrategy): """ def __init__(self, strategy_dict): - self.dict_type = type(strategy_dict) + dict_type = type(strategy_dict) self.keys = tuple(strategy_dict.keys()) - super().__init__(strategy=TupleStrategy(strategy_dict[k] for k in self.keys)) + super().__init__( + strategy=TupleStrategy(strategy_dict[k] for k in self.keys), + pack=lambda value: dict_type(zip(self.keys, value)), + ) def calc_is_empty(self, recur): return recur(self.mapped_strategy) @@ -321,9 +327,6 @@ class FixedKeysDictStrategy(MappedSearchStrategy): def __repr__(self): return f"FixedKeysDictStrategy({self.keys!r}, {self.mapped_strategy!r})" - def pack(self, value): - return self.dict_type(zip(self.keys, value)) - class FixedAndOptionalKeysDictStrategy(SearchStrategy): """A strategy which produces dicts with a fixed set of keys, given a diff --git a/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/lazy.py b/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/lazy.py index d6bb13c7c1..8f887293e6 100644 --- a/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/lazy.py +++ b/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/lazy.py @@ -61,10 +61,6 @@ def unwrap_strategies(s): assert unwrap_depth >= 0 -def _repr_filter(condition): - return f".filter({get_pretty_function_description(condition)})" - - class LazyStrategy(SearchStrategy): """A strategy which is defined purely by conversion to and from another strategy. @@ -72,14 +68,14 @@ class LazyStrategy(SearchStrategy): Its parameter and distribution come from that other strategy. """ - def __init__(self, function, args, kwargs, filters=(), *, force_repr=None): + def __init__(self, function, args, kwargs, *, transforms=(), force_repr=None): super().__init__() self.__wrapped_strategy = None self.__representation = force_repr self.function = function self.__args = args self.__kwargs = kwargs - self.__filters = filters + self._transformations = transforms @property def supports_find(self): @@ -115,23 +111,28 @@ class LazyStrategy(SearchStrategy): self.__wrapped_strategy = self.function( *unwrapped_args, **unwrapped_kwargs ) - for f in self.__filters: - self.__wrapped_strategy = self.__wrapped_strategy.filter(f) + for method, fn in self._transformations: + self.__wrapped_strategy = getattr(self.__wrapped_strategy, method)(fn) return self.__wrapped_strategy - def filter(self, condition): - try: - repr_ = f"{self!r}{_repr_filter(condition)}" - except Exception: - repr_ = None - return LazyStrategy( + def __with_transform(self, method, fn): + repr_ = self.__representation + if repr_: + repr_ = f"{repr_}.{method}({get_pretty_function_description(fn)})" + return type(self)( self.function, self.__args, self.__kwargs, - (*self.__filters, condition), + transforms=(*self._transformations, (method, fn)), force_repr=repr_, ) + def map(self, pack): + return self.__with_transform("map", pack) + + def filter(self, condition): + return self.__with_transform("filter", condition) + def do_validate(self): w = self.wrapped_strategy assert isinstance(w, SearchStrategy), f"{self!r} returned non-strategy {w!r}" @@ -156,7 +157,10 @@ class LazyStrategy(SearchStrategy): } self.__representation = repr_call( self.function, _args, kwargs_for_repr, reorder=False - ) + "".join(map(_repr_filter, self.__filters)) + ) + "".join( + f".{method}({get_pretty_function_description(fn)})" + for method, fn in self._transformations + ) return self.__representation def do_draw(self, data): diff --git a/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/strategies.py b/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/strategies.py index af2fa72937..46d4005cdb 100644 --- a/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/strategies.py +++ b/contrib/python/hypothesis/py3/hypothesis/strategies/_internal/strategies.py @@ -11,6 +11,7 @@ import sys import warnings from collections import abc, defaultdict +from functools import lru_cache from random import shuffle from typing import ( Any, @@ -60,7 +61,7 @@ T5 = TypeVar("T5") calculating = UniqueIdentifier("calculating") MAPPED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name( - "another attempted draw in MappedSearchStrategy" + "another attempted draw in MappedStrategy" ) FILTERED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name( @@ -346,7 +347,7 @@ class SearchStrategy(Generic[Ex]): """ if is_identity_function(pack): return self # type: ignore # Mypy has no way to know that `Ex == T` - return MappedSearchStrategy(pack=pack, strategy=self) + return MappedStrategy(self, pack=pack) def flatmap( self, expand: Callable[[Ex], "SearchStrategy[T]"] @@ -468,9 +469,6 @@ class SampledFromStrategy(SearchStrategy): """A strategy which samples from a set of elements. This is essentially equivalent to using a OneOfStrategy over Just strategies but may be more efficient and convenient. - - The conditional distribution chooses uniformly at random from some - non-empty subset of the elements. """ _MAX_FILTER_CALLS = 10_000 @@ -521,7 +519,10 @@ class SampledFromStrategy(SearchStrategy): # Used in UniqueSampledListStrategy for name, f in self._transformations: if name == "map": - element = f(element) + result = f(element) + if build_context := _current_build_context.value: + build_context.record_call(result, f, [element], {}) + element = result else: assert name == "filter" if not f(element): @@ -794,18 +795,17 @@ def one_of( return OneOfStrategy(args) -class MappedSearchStrategy(SearchStrategy[Ex]): +class MappedStrategy(SearchStrategy[Ex]): """A strategy which is defined purely by conversion to and from another strategy. Its parameter and distribution come from that other strategy. """ - def __init__(self, strategy, pack=None): + def __init__(self, strategy, pack): super().__init__() self.mapped_strategy = strategy - if pack is not None: - self.pack = pack + self.pack = pack def calc_is_empty(self, recur): return recur(self.mapped_strategy) @@ -821,11 +821,6 @@ class MappedSearchStrategy(SearchStrategy[Ex]): def do_validate(self): self.mapped_strategy.validate() - def pack(self, x): - """Take a value produced by the underlying mapped_strategy and turn it - into a value suitable for outputting from this strategy.""" - raise NotImplementedError(f"{self.__class__.__name__}.pack()") - def do_draw(self, data: ConjectureData) -> Any: with warnings.catch_warnings(): if isinstance(self.pack, type) and issubclass( @@ -847,10 +842,67 @@ class MappedSearchStrategy(SearchStrategy[Ex]): @property def branches(self) -> List[SearchStrategy[Ex]]: return [ - MappedSearchStrategy(pack=self.pack, strategy=strategy) + MappedStrategy(strategy, pack=self.pack) for strategy in self.mapped_strategy.branches ] + def filter(self, condition: Callable[[Ex], Any]) -> "SearchStrategy[Ex]": + # Includes a special case so that we can rewrite filters on collection + # lengths, when most collections are `st.lists(...).map(the_type)`. + ListStrategy = _list_strategy_type() + if not isinstance(self.mapped_strategy, ListStrategy) or not ( + (isinstance(self.pack, type) and issubclass(self.pack, abc.Collection)) + or self.pack in _collection_ish_functions() + ): + return super().filter(condition) + + # Check whether our inner list strategy can rewrite this filter condition. + # If not, discard the result and _only_ apply a new outer filter. + new = ListStrategy.filter(self.mapped_strategy, condition) + if getattr(new, "filtered_strategy", None) is self.mapped_strategy: + return super().filter(condition) # didn't rewrite + + # Apply a new outer filter even though we rewrote the inner strategy, + # because some collections can change the list length (dict, set, etc). + return FilteredStrategy(type(self)(new, self.pack), conditions=(condition,)) + + +@lru_cache +def _list_strategy_type(): + from hypothesis.strategies._internal.collections import ListStrategy + + return ListStrategy + + +def _collection_ish_functions(): + funcs = [sorted] + if np := sys.modules.get("numpy"): + # c.f. https://numpy.org/doc/stable/reference/routines.array-creation.html + # Probably only `np.array` and `np.asarray` will be used in practice, + # but why should that stop us when we've already gone this far? + funcs += [ + np.empty_like, + np.eye, + np.identity, + np.ones_like, + np.zeros_like, + np.array, + np.asarray, + np.asanyarray, + np.ascontiguousarray, + np.asmatrix, + np.copy, + np.rec.array, + np.rec.fromarrays, + np.rec.fromrecords, + np.diag, + # bonus undocumented functions from tab-completion: + np.asarray_chkfinite, + np.asfarray, + np.asfortranarray, + ] + return funcs + filter_not_satisfied = UniqueIdentifier("filter not satisfied") diff --git a/contrib/python/hypothesis/py3/hypothesis/version.py b/contrib/python/hypothesis/py3/hypothesis/version.py index da7f74708c..e986241037 100644 --- a/contrib/python/hypothesis/py3/hypothesis/version.py +++ b/contrib/python/hypothesis/py3/hypothesis/version.py @@ -8,5 +8,5 @@ # v. 2.0. If a copy of the MPL was not distributed with this file, You can # obtain one at https://mozilla.org/MPL/2.0/. -__version_info__ = (6, 98, 11) +__version_info__ = (6, 98, 12) __version__ = ".".join(map(str, __version_info__)) diff --git a/contrib/python/hypothesis/py3/ya.make b/contrib/python/hypothesis/py3/ya.make index c71ce1c809..33c8057a99 100644 --- a/contrib/python/hypothesis/py3/ya.make +++ b/contrib/python/hypothesis/py3/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(6.98.11) +VERSION(6.98.12) LICENSE(MPL-2.0) diff --git a/contrib/python/sniffio/.dist-info/METADATA b/contrib/python/sniffio/.dist-info/METADATA index 22520c72af..88968aed16 100644 --- a/contrib/python/sniffio/.dist-info/METADATA +++ b/contrib/python/sniffio/.dist-info/METADATA @@ -1,11 +1,12 @@ Metadata-Version: 2.1 Name: sniffio -Version: 1.3.0 +Version: 1.3.1 Summary: Sniff out which async library your code is running under -Home-page: https://github.com/python-trio/sniffio -Author: Nathaniel J. Smith -Author-email: njs@pobox.com +Author-email: "Nathaniel J. Smith" <njs@pobox.com> License: MIT OR Apache-2.0 +Project-URL: Homepage, https://github.com/python-trio/sniffio +Project-URL: Documentation, https://sniffio.readthedocs.io/ +Project-URL: Changelog, https://sniffio.readthedocs.io/en/latest/history.html Keywords: async,trio,asyncio Classifier: License :: OSI Approved :: MIT License Classifier: License :: OSI Approved :: Apache Software License @@ -20,6 +21,7 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Intended Audience :: Developers Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 +Description-Content-Type: text/x-rst License-File: LICENSE License-File: LICENSE.APACHE2 License-File: LICENSE.MIT diff --git a/contrib/python/sniffio/sniffio/__init__.py b/contrib/python/sniffio/sniffio/__init__.py index fb3364d7f1..63f2f19e40 100644 --- a/contrib/python/sniffio/sniffio/__init__.py +++ b/contrib/python/sniffio/sniffio/__init__.py @@ -1,8 +1,10 @@ """Top-level package for sniffio.""" __all__ = [ - "current_async_library", "AsyncLibraryNotFoundError", - "current_async_library_cvar" + "current_async_library", + "AsyncLibraryNotFoundError", + "current_async_library_cvar", + "thread_local", ] from ._version import __version__ diff --git a/contrib/python/sniffio/sniffio/_version.py b/contrib/python/sniffio/sniffio/_version.py index 5a5f906bbf..0495d10545 100644 --- a/contrib/python/sniffio/sniffio/_version.py +++ b/contrib/python/sniffio/sniffio/_version.py @@ -1,3 +1,3 @@ # This file is imported from __init__.py and exec'd from setup.py -__version__ = "1.3.0" +__version__ = "1.3.1" diff --git a/contrib/python/sniffio/ya.make b/contrib/python/sniffio/ya.make index d0e376d4ca..165b99c587 100644 --- a/contrib/python/sniffio/ya.make +++ b/contrib/python/sniffio/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(1.3.0) +VERSION(1.3.1) LICENSE(Apache-2.0 AND MIT) diff --git a/contrib/python/typing-extensions/py3/.dist-info/METADATA b/contrib/python/typing-extensions/py3/.dist-info/METADATA index 863e977c2f..13d06e24b7 100644 --- a/contrib/python/typing-extensions/py3/.dist-info/METADATA +++ b/contrib/python/typing-extensions/py3/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: typing_extensions -Version: 4.9.0 +Version: 4.10.0 Summary: Backported and Experimental Type Hints for Python 3.8+ Keywords: annotations,backport,checker,checking,function,hinting,hints,type,typechecking,typehinting,typehints,typing Author-email: "Guido van Rossum, Jukka Lehtosalo, Ćukasz Langa, Michael Lee" <levkivskyi@gmail.com> diff --git a/contrib/python/typing-extensions/py3/typing_extensions.py b/contrib/python/typing-extensions/py3/typing_extensions.py index 1666e96b7e..f3132ea4ae 100644 --- a/contrib/python/typing-extensions/py3/typing_extensions.py +++ b/contrib/python/typing-extensions/py3/typing_extensions.py @@ -83,6 +83,7 @@ __all__ = [ 'TypeAlias', 'TypeAliasType', 'TypeGuard', + 'TypeIs', 'TYPE_CHECKING', 'Never', 'NoReturn', @@ -473,7 +474,7 @@ _EXCLUDED_ATTRS = { "_is_runtime_protocol", "__dict__", "__slots__", "__parameters__", "__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__", "__subclasshook__", "__orig_class__", "__init__", "__new__", - "__protocol_attrs__", "__callable_proto_members_only__", + "__protocol_attrs__", "__non_callable_proto_members__", "__match_args__", } @@ -521,6 +522,22 @@ else: if type(self)._is_protocol: raise TypeError('Protocols cannot be instantiated') + def _type_check_issubclass_arg_1(arg): + """Raise TypeError if `arg` is not an instance of `type` + in `issubclass(arg, <protocol>)`. + + In most cases, this is verified by type.__subclasscheck__. + Checking it again unnecessarily would slow down issubclass() checks, + so, we don't perform this check unless we absolutely have to. + + For various error paths, however, + we want to ensure that *this* error message is shown to the user + where relevant, rather than a typing.py-specific error message. + """ + if not isinstance(arg, type): + # Same error message as for issubclass(1, int). + raise TypeError('issubclass() arg 1 must be a class') + # Inheriting from typing._ProtocolMeta isn't actually desirable, # but is necessary to allow typing.Protocol and typing_extensions.Protocol # to mix without getting TypeErrors about "metaclass conflict" @@ -551,11 +568,6 @@ else: abc.ABCMeta.__init__(cls, *args, **kwargs) if getattr(cls, "_is_protocol", False): cls.__protocol_attrs__ = _get_protocol_attrs(cls) - # PEP 544 prohibits using issubclass() - # with protocols that have non-method members. - cls.__callable_proto_members_only__ = all( - callable(getattr(cls, attr, None)) for attr in cls.__protocol_attrs__ - ) def __subclasscheck__(cls, other): if cls is Protocol: @@ -564,26 +576,23 @@ else: getattr(cls, '_is_protocol', False) and not _allow_reckless_class_checks() ): - if not isinstance(other, type): - # Same error message as for issubclass(1, int). - raise TypeError('issubclass() arg 1 must be a class') + if not getattr(cls, '_is_runtime_protocol', False): + _type_check_issubclass_arg_1(other) + raise TypeError( + "Instance and class checks can only be used with " + "@runtime_checkable protocols" + ) if ( - not cls.__callable_proto_members_only__ + # this attribute is set by @runtime_checkable: + cls.__non_callable_proto_members__ and cls.__dict__.get("__subclasshook__") is _proto_hook ): - non_method_attrs = sorted( - attr for attr in cls.__protocol_attrs__ - if not callable(getattr(cls, attr, None)) - ) + _type_check_issubclass_arg_1(other) + non_method_attrs = sorted(cls.__non_callable_proto_members__) raise TypeError( "Protocols with non-method members don't support issubclass()." f" Non-method members: {str(non_method_attrs)[1:-1]}." ) - if not getattr(cls, '_is_runtime_protocol', False): - raise TypeError( - "Instance and class checks can only be used with " - "@runtime_checkable protocols" - ) return abc.ABCMeta.__subclasscheck__(cls, other) def __instancecheck__(cls, instance): @@ -610,7 +619,8 @@ else: val = inspect.getattr_static(instance, attr) except AttributeError: break - if val is None and callable(getattr(cls, attr, None)): + # this attribute is set by @runtime_checkable: + if val is None and attr not in cls.__non_callable_proto_members__: break else: return True @@ -678,8 +688,58 @@ else: cls.__init__ = _no_init +if sys.version_info >= (3, 13): + runtime_checkable = typing.runtime_checkable +else: + def runtime_checkable(cls): + """Mark a protocol class as a runtime protocol. + + Such protocol can be used with isinstance() and issubclass(). + Raise TypeError if applied to a non-protocol class. + This allows a simple-minded structural check very similar to + one trick ponies in collections.abc such as Iterable. + + For example:: + + @runtime_checkable + class Closable(Protocol): + def close(self): ... + + assert isinstance(open('/some/file'), Closable) + + Warning: this will check only the presence of the required methods, + not their type signatures! + """ + if not issubclass(cls, typing.Generic) or not getattr(cls, '_is_protocol', False): + raise TypeError('@runtime_checkable can be only applied to protocol classes,' + ' got %r' % cls) + cls._is_runtime_protocol = True + + # Only execute the following block if it's a typing_extensions.Protocol class. + # typing.Protocol classes don't need it. + if isinstance(cls, _ProtocolMeta): + # PEP 544 prohibits using issubclass() + # with protocols that have non-method members. + # See gh-113320 for why we compute this attribute here, + # rather than in `_ProtocolMeta.__init__` + cls.__non_callable_proto_members__ = set() + for attr in cls.__protocol_attrs__: + try: + is_callable = callable(getattr(cls, attr, None)) + except Exception as e: + raise TypeError( + f"Failed to determine whether protocol member {attr!r} " + "is a method member" + ) from e + else: + if not is_callable: + cls.__non_callable_proto_members__.add(attr) + + return cls + + # The "runtime" alias exists for backwards compatibility. -runtime = runtime_checkable = typing.runtime_checkable +runtime = runtime_checkable # Our version of runtime-checkable protocols is faster on Python 3.8-3.11 @@ -815,7 +875,7 @@ else: break class _TypedDictMeta(type): - def __new__(cls, name, bases, ns, *, total=True): + def __new__(cls, name, bases, ns, *, total=True, closed=False): """Create new typed dict class object. This method is called when TypedDict is subclassed, @@ -860,6 +920,7 @@ else: optional_keys = set() readonly_keys = set() mutable_keys = set() + extra_items_type = None for base in bases: base_dict = base.__dict__ @@ -869,6 +930,26 @@ else: optional_keys.update(base_dict.get('__optional_keys__', ())) readonly_keys.update(base_dict.get('__readonly_keys__', ())) mutable_keys.update(base_dict.get('__mutable_keys__', ())) + base_extra_items_type = base_dict.get('__extra_items__', None) + if base_extra_items_type is not None: + extra_items_type = base_extra_items_type + + if closed and extra_items_type is None: + extra_items_type = Never + if closed and "__extra_items__" in own_annotations: + annotation_type = own_annotations.pop("__extra_items__") + qualifiers = set(_get_typeddict_qualifiers(annotation_type)) + if Required in qualifiers: + raise TypeError( + "Special key __extra_items__ does not support " + "Required" + ) + if NotRequired in qualifiers: + raise TypeError( + "Special key __extra_items__ does not support " + "NotRequired" + ) + extra_items_type = annotation_type annotations.update(own_annotations) for annotation_key, annotation_type in own_annotations.items(): @@ -883,11 +964,7 @@ else: else: optional_keys.add(annotation_key) if ReadOnly in qualifiers: - if annotation_key in mutable_keys: - raise TypeError( - f"Cannot override mutable key {annotation_key!r}" - " with read-only key" - ) + mutable_keys.discard(annotation_key) readonly_keys.add(annotation_key) else: mutable_keys.add(annotation_key) @@ -900,6 +977,8 @@ else: tp_dict.__mutable_keys__ = frozenset(mutable_keys) if not hasattr(tp_dict, '__total__'): tp_dict.__total__ = total + tp_dict.__closed__ = closed + tp_dict.__extra_items__ = extra_items_type return tp_dict __call__ = dict # static method @@ -913,7 +992,7 @@ else: _TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {}) @_ensure_subclassable(lambda bases: (_TypedDict,)) - def TypedDict(typename, fields=_marker, /, *, total=True, **kwargs): + def TypedDict(typename, fields=_marker, /, *, total=True, closed=False, **kwargs): """A simple typed namespace. At runtime it is equivalent to a plain dict. TypedDict creates a dictionary type such that a type checker will expect all @@ -973,6 +1052,9 @@ else: "using the functional syntax, pass an empty dictionary, e.g. " ) + example + "." warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2) + if closed is not False and closed is not True: + kwargs["closed"] = closed + closed = False fields = kwargs elif kwargs: raise TypeError("TypedDict takes either a dict or keyword arguments," @@ -994,7 +1076,7 @@ else: # Setting correct module is necessary to make typed dict classes pickleable. ns['__module__'] = module - td = _TypedDictMeta(typename, (), ns, total=total) + td = _TypedDictMeta(typename, (), ns, total=total, closed=closed) td.__orig_bases__ = (TypedDict,) return td @@ -1768,6 +1850,98 @@ else: PEP 647 (User-Defined Type Guards). """) +# 3.13+ +if hasattr(typing, 'TypeIs'): + TypeIs = typing.TypeIs +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def TypeIs(self, parameters): + """Special typing form used to annotate the return type of a user-defined + type narrower function. ``TypeIs`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeIs`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeIs[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeIs`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the intersection of the type inside ``TypeGuard`` and the argument's + previously known type. + + For example:: + + def is_awaitable(val: object) -> TypeIs[Awaitable[Any]]: + return hasattr(val, '__await__') + + def f(val: Union[int, Awaitable[int]]) -> int: + if is_awaitable(val): + assert_type(val, Awaitable[int]) + else: + assert_type(val, int) + + ``TypeIs`` also works with type variables. For more information, see + PEP 742 (Narrowing types with TypeIs). + """ + item = typing._type_check(parameters, f'{self} accepts only a single type.') + return typing._GenericAlias(self, (item,)) +# 3.8 +else: + class _TypeIsForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type') + return typing._GenericAlias(self, (item,)) + + TypeIs = _TypeIsForm( + 'TypeIs', + doc="""Special typing form used to annotate the return type of a user-defined + type narrower function. ``TypeIs`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeIs`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeIs[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeIs`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the intersection of the type inside ``TypeGuard`` and the argument's + previously known type. + + For example:: + + def is_awaitable(val: object) -> TypeIs[Awaitable[Any]]: + return hasattr(val, '__await__') + + def f(val: Union[int, Awaitable[int]]) -> int: + if is_awaitable(val): + assert_type(val, Awaitable[int]) + else: + assert_type(val, int) + + ``TypeIs`` also works with type variables. For more information, see + PEP 742 (Narrowing types with TypeIs). + """) + # Vendored from cpython typing._SpecialFrom class _SpecialForm(typing._Final, _root=True): diff --git a/contrib/python/typing-extensions/py3/ya.make b/contrib/python/typing-extensions/py3/ya.make index 1e65722a16..6a099000e4 100644 --- a/contrib/python/typing-extensions/py3/ya.make +++ b/contrib/python/typing-extensions/py3/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(4.9.0) +VERSION(4.10.0) LICENSE(PSF-2.0) diff --git a/yt/yt/client/ypath/parser_detail.cpp b/yt/yt/client/ypath/parser_detail.cpp index 996dd52159..d1852f526f 100644 --- a/yt/yt/client/ypath/parser_detail.cpp +++ b/yt/yt/client/ypath/parser_detail.cpp @@ -163,30 +163,6 @@ TString ParseCluster(TString str, const IAttributeDictionaryPtr& attributes) attributes->Set("cluster", clusterName); return remainingString; - - - // NB. If the path had attributes, then the leading spaces must be removed in preceding ParseAttributes() - // call. We can encounter the path with leading spaces here if it didn't have attributes and passed through - // ParseAttributes() unchanged. In this case, the path is most likely incorrect, so it returns unchanged from - // this function. For example, " <> cluster://path" will become "cluster://path" after call to ParseAttributes(), - // and "//path" after this function. But " cluster://path" will pass unchanged throughout ParseAttributes(), and - // will remain " cluster://path" after this function. - size_t index = 0; - size_t clusterStart = index; - while (index < str.size() && (IsAsciiAlnum(str[index]) || str[index] == '-' || str[index] == '_')) { - ++index; - } - if (index >= str.size() || str[index] != ':') { - // Not a cluster name, so return the string as-is. - return str; - } - size_t clusterEnd = index; - if (clusterStart == clusterEnd) { - THROW_ERROR_EXCEPTION("Cluster name cannot be empty"); - } - attributes->Set("cluster", str.substr(clusterStart, clusterEnd - clusterStart)); - ++index; - return str.substr(index); } void ParseColumns(NYson::TTokenizer& tokenizer, IAttributeDictionary* attributes) diff --git a/yt/yt/core/actions/future-inl.h b/yt/yt/core/actions/future-inl.h index 51289611b0..369d6f57fe 100644 --- a/yt/yt/core/actions/future-inl.h +++ b/yt/yt/core/actions/future-inl.h @@ -1580,12 +1580,26 @@ struct TAsyncViaHelper<R(TArgs...)> TArgs... args) { auto promise = NewPromise<TUnderlying>(); + auto makeOnSuccess = [&] <size_t... Indeces> (std::index_sequence<Indeces...>) { + return + [ + promise, + this_ = std::move(this_), + tuple = std::tuple(std::forward<TArgs>(args)...) + ] { + if constexpr (sizeof...(TArgs) == 0) { + Y_UNUSED(tuple); + } + Inner(std::move(this_), promise, std::forward<TArgs>(std::get<Indeces>(tuple))...); + }; + }; + GuardedInvoke( invoker, - BIND_NO_PROPAGATE(&Inner, std::move(this_), promise, WrapToPassed(std::forward<TArgs>(args))...), - BIND_NO_PROPAGATE([promise, cancellationError = std::move(cancellationError)] { + makeOnSuccess(std::make_index_sequence<sizeof...(TArgs)>()), + [promise, cancellationError = std::move(cancellationError)] { promise.Set(std::move(cancellationError)); - })); + }); return promise; } diff --git a/yt/yt/core/actions/invoker_util-inl.h b/yt/yt/core/actions/invoker_util-inl.h new file mode 100644 index 0000000000..ab2a93834a --- /dev/null +++ b/yt/yt/core/actions/invoker_util-inl.h @@ -0,0 +1,61 @@ +#ifndef INVOKER_UTIL_INL_H_ +#error "Direct inclusion of this file is not allowed, include invoker_util.h" +// For the sake of sane code completion. +#include "invoker_util.h" +#endif +#undef INVOKER_UTIL_INL_H_ + +#include <yt/yt/core/misc/finally.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <CInvocable<void()> TOnSuccess, CInvocable<void()> TOnCancel> +void GuardedInvoke( + const IInvokerPtr& invoker, + TOnSuccess onSuccess, + TOnCancel onCancel) +{ + YT_VERIFY(invoker); + + class TGuard + { + public: + TGuard(TOnSuccess onSuccess, TOnCancel onCancel) + : OnSuccess_(std::move(onSuccess)) + , OnCancel_(std::move(onCancel)) + { } + + TGuard(TGuard&& other) + : OnSuccess_(std::move(other.OnSuccess_)) + , OnCancel_(std::move(other.OnCancel_)) + , WasInvoked_(std::exchange(other.WasInvoked_, true)) + { } + + void operator()() + { + WasInvoked_ = true; + OnSuccess_(); + } + + ~TGuard() + { + if (!WasInvoked_) { + OnCancel_(); + } + } + + private: + TOnSuccess OnSuccess_; + TOnCancel OnCancel_; + + bool WasInvoked_ = false; + }; + + invoker->Invoke(BIND_NO_PROPAGATE(TGuard(std::move(onSuccess), std::move(onCancel)))); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/core/actions/invoker_util.cpp b/yt/yt/core/actions/invoker_util.cpp index 1e8cdbb36e..5428053eaa 100644 --- a/yt/yt/core/actions/invoker_util.cpp +++ b/yt/yt/core/actions/invoker_util.cpp @@ -143,51 +143,4 @@ IInvokerPtr GetFinalizerInvoker() //////////////////////////////////////////////////////////////////////////////// -void GuardedInvoke( - const IInvokerPtr& invoker, - TClosure onSuccess, - TClosure onCancel) -{ - YT_ASSERT(invoker); - YT_ASSERT(onSuccess); - YT_ASSERT(onCancel); - - class TGuard - { - public: - explicit TGuard(TClosure onCancel) - : OnCancel_(std::move(onCancel)) - { } - - TGuard(TGuard&& other) = default; - - ~TGuard() - { - if (OnCancel_) { - OnCancel_(); - } - } - - void Release() - { - OnCancel_.Reset(); - } - - private: - TClosure OnCancel_; - }; - - auto doInvoke = [] (TClosure onSuccess, TGuard guard) { - guard.Release(); - onSuccess(); - }; - - invoker->Invoke(BIND_NO_PROPAGATE( - std::move(doInvoke), - Passed(std::move(onSuccess)), - Passed(TGuard(std::move(onCancel))))); -} - -//////////////////////////////////////////////////////////////////////////////// - } // namespace NYT diff --git a/yt/yt/core/actions/invoker_util.h b/yt/yt/core/actions/invoker_util.h index 2136aa5bc8..f74f217123 100644 --- a/yt/yt/core/actions/invoker_util.h +++ b/yt/yt/core/actions/invoker_util.h @@ -35,11 +35,16 @@ IInvokerPtr GetFinalizerInvoker(); //! Tries to invoke #onSuccess via #invoker. //! If the invoker discards the callback without executing it then //! #onCancel is run. +template <CInvocable<void()> TOnSuccess, CInvocable<void()> TOnCancel> void GuardedInvoke( const IInvokerPtr& invoker, - TClosure onSuccess, - TClosure onCancel); + TOnSuccess onSuccess, + TOnCancel onCancel); //////////////////////////////////////////////////////////////////////////////// } // namespace NYT + +#define INVOKER_UTIL_INL_H_ +#include "invoker_util-inl.h" +#undef INVOKER_UTIL_INL_H_ diff --git a/yt/yt/core/bus/tcp/client.cpp b/yt/yt/core/bus/tcp/client.cpp index 6fe8a08856..770375ddfe 100644 --- a/yt/yt/core/bus/tcp/client.cpp +++ b/yt/yt/core/bus/tcp/client.cpp @@ -141,9 +141,11 @@ class TTcpBusClient public: TTcpBusClient( TBusClientConfigPtr config, - IPacketTranscoderFactory* packetTranscoderFactory) + IPacketTranscoderFactory* packetTranscoderFactory, + IMemoryUsageTrackerPtr memoryUsageTracker) : Config_(std::move(config)) , PacketTranscoderFactory_(packetTranscoderFactory) + , MemoryUsageTracker_(std::move(memoryUsageTracker)) { if (Config_->Address) { EndpointDescription_ = *Config_->Address; @@ -204,7 +206,8 @@ public: Config_->UnixDomainSocketPath, std::move(handler), std::move(poller), - PacketTranscoderFactory_); + PacketTranscoderFactory_, + MemoryUsageTracker_); connection->Start(); return New<TTcpClientBusProxy>(std::move(connection)); @@ -215,6 +218,8 @@ private: IPacketTranscoderFactory* const PacketTranscoderFactory_; + const IMemoryUsageTrackerPtr MemoryUsageTracker_; + TString EndpointDescription_; IAttributeDictionaryPtr EndpointAttributes_; }; @@ -223,9 +228,13 @@ private: IBusClientPtr CreateBusClient( TBusClientConfigPtr config, - IPacketTranscoderFactory* packetTranscoderFactory) + IPacketTranscoderFactory* packetTranscoderFactory, + IMemoryUsageTrackerPtr memoryUsageTracker) { - return New<TTcpBusClient>(std::move(config), packetTranscoderFactory); + return New<TTcpBusClient>( + std::move(config), + packetTranscoderFactory, + std::move(memoryUsageTracker)); } //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/bus/tcp/client.h b/yt/yt/core/bus/tcp/client.h index 5b38dd1d51..9cfbe84beb 100644 --- a/yt/yt/core/bus/tcp/client.h +++ b/yt/yt/core/bus/tcp/client.h @@ -11,7 +11,8 @@ namespace NYT::NBus { //! Initializes a new client for communicating with a given address. IBusClientPtr CreateBusClient( TBusClientConfigPtr config, - IPacketTranscoderFactory* packetTranscoderFactory = GetYTPacketTranscoderFactory()); + IPacketTranscoderFactory* packetTranscoderFactory = GetYTPacketTranscoderFactory(), + IMemoryUsageTrackerPtr memoryUsageTracker = GetNullMemoryUsageTracker()); //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/bus/tcp/config.cpp b/yt/yt/core/bus/tcp/config.cpp index c505874930..229507cbda 100644 --- a/yt/yt/core/bus/tcp/config.cpp +++ b/yt/yt/core/bus/tcp/config.cpp @@ -108,6 +108,10 @@ void TBusConfig::Register(TRegistrar registrar) .Default(5); registrar.Parameter("bind_retry_backoff", &TThis::BindRetryBackoff) .Default(TDuration::Seconds(3)); + registrar.Parameter("connection_start_delay", &TThis::ConnectionStartDelay) + .Default(); + registrar.Parameter("packet_decoder_delay", &TThis::PacketDecoderDelay) + .Default(); registrar.Parameter("read_stall_timeout", &TThis::ReadStallTimeout) .Default(TDuration::Minutes(1)); registrar.Parameter("write_stall_timeout", &TThis::WriteStallTimeout) diff --git a/yt/yt/core/bus/tcp/config.h b/yt/yt/core/bus/tcp/config.h index 3695a42093..4cbd422ed6 100644 --- a/yt/yt/core/bus/tcp/config.h +++ b/yt/yt/core/bus/tcp/config.h @@ -92,6 +92,9 @@ public: TDuration ReadStallTimeout; TDuration WriteStallTimeout; + std::optional<TDuration> ConnectionStartDelay; + std::optional<TDuration> PacketDecoderDelay; + bool VerifyChecksums; bool GenerateChecksums; diff --git a/yt/yt/core/bus/tcp/connection.cpp b/yt/yt/core/bus/tcp/connection.cpp index 07904d9fbe..4df9581e07 100644 --- a/yt/yt/core/bus/tcp/connection.cpp +++ b/yt/yt/core/bus/tcp/connection.cpp @@ -63,8 +63,11 @@ static constexpr i64 PendingOutBytesFlushThreshold = 1_MBs; //////////////////////////////////////////////////////////////////////////////// -struct TTcpConnectionReadBufferTag { }; -struct TTcpConnectionWriteBufferTag { }; +struct TTcpServerConnectionReadBufferTag { }; +struct TTcpServerConnectionWriteBufferTag { }; + +struct TTcpClientConnectionReadBufferTag { }; +struct TTcpClientConnectionWriteBufferTag { }; //////////////////////////////////////////////////////////////////////////////// @@ -110,7 +113,8 @@ TTcpConnection::TTcpConnection( const std::optional<TString>& unixDomainSocketPath, IMessageHandlerPtr handler, IPollerPtr poller, - IPacketTranscoderFactory* packetTranscoderFactory) + IPacketTranscoderFactory* packetTranscoderFactory, + IMemoryUsageTrackerPtr memoryUsageTracker) : Config_(std::move(config)) , ConnectionType_(connectionType) , Id_(id) @@ -137,6 +141,7 @@ TTcpConnection::TTcpConnection( , WriteStallTimeout_(NProfiling::DurationToCpuDuration(Config_->WriteStallTimeout)) , EncryptionMode_(Config_->EncryptionMode) , VerificationMode_(Config_->VerificationMode) + , MemoryUsageTracker_(std::move(memoryUsageTracker)) { } TTcpConnection::~TTcpConnection() @@ -211,7 +216,19 @@ void TTcpConnection::Start() } TTcpDispatcher::TImpl::Get()->RegisterConnection(this); - InitBuffers(); + + try { + InitBuffers(); + } catch (const std::exception& ex) { + Abort(TError(NBus::EErrorCode::TransportError, "I/O buffers allocation error") + << ex); + return; + } + + if (Config_->ConnectionStartDelay) { + YT_LOG_WARNING("Delay in opening activation of the test connection (Delay: %v)", Config_->ConnectionStartDelay); + TDelayedExecutor::WaitForDuration(Config_->ConnectionStartDelay.value()); + } switch (ConnectionType_) { case EConnectionType::Client: @@ -225,8 +242,7 @@ void TTcpConnection::Start() YT_VERIFY(Socket_ != INVALID_SOCKET); State_ = EState::Opening; SetupNetwork(EndpointNetworkAddress_); - Open(); - guard.Release(); + Open(guard); break; } @@ -395,7 +411,7 @@ TConnectionId TTcpConnection::GetId() const return Id_; } -void TTcpConnection::Open() +void TTcpConnection::Open(TGuard<NThreading::TSpinLock>& guard) { State_ = EState::Open; @@ -414,6 +430,8 @@ void TTcpConnection::Open() ArmPoller(); + guard.Release(); + // Something might be pending already, for example Terminate. if (Any(previousPendingControl & ~EPollControl::Offline)) { YT_LOG_TRACE("Retrying event processing for Open (PendingControl: %v)", previousPendingControl); @@ -550,10 +568,26 @@ bool TTcpConnection::AbortIfNetworkingDisabled() void TTcpConnection::InitBuffers() { - ReadBuffer_ = TBlob(GetRefCountedTypeCookie<TTcpConnectionReadBufferTag>(), ReadBufferSize, /*initializeStorage*/ false); - - WriteBuffers_.push_back(std::make_unique<TBlob>(GetRefCountedTypeCookie<TTcpConnectionWriteBufferTag>())); - WriteBuffers_[0]->Reserve(WriteBufferSize); + ReadBuffer_ = TMemoryTrackedBlob::Build( + MemoryUsageTracker_, + ConnectionType_ == EConnectionType::Server + ? GetRefCountedTypeCookie<TTcpServerConnectionReadBufferTag>() + : GetRefCountedTypeCookie<TTcpClientConnectionReadBufferTag>()); + ReadBuffer_ + .TryResize( + ReadBufferSize, + /*initializeStorage*/ false) + .ThrowOnError(); + + auto trackedBlob = TMemoryTrackedBlob::Build( + MemoryUsageTracker_, + ConnectionType_ == EConnectionType::Server + ? GetRefCountedTypeCookie<TTcpServerConnectionWriteBufferTag>() + : GetRefCountedTypeCookie<TTcpClientConnectionWriteBufferTag>()); + trackedBlob + .TryReserve(WriteBufferSize) + .ThrowOnError(); + WriteBuffers_.push_back(std::move(trackedBlob)); } int TTcpConnection::GetSocketPort() @@ -619,7 +653,7 @@ void TTcpConnection::OnDialerFinished(const TErrorOr<SOCKET>& socketOrError) InitSocketTosLevel(tosLevel); } - Open(); + Open(guard); } } @@ -782,6 +816,8 @@ void TTcpConnection::Terminate(const TError& error) // Arm calling OnTerminate() from OnEvent(). auto previousPendingControl = static_cast<EPollControl>(PendingControl_.fetch_or(static_cast<ui64>(EPollControl::Terminate))); + guard.Release(); + // To recover from bogus state always retry processing unless socket is offline if (None(previousPendingControl & EPollControl::Offline)) { YT_LOG_TRACE("Retrying event processing for Terminate (PendingControl: %v)", previousPendingControl); @@ -1092,6 +1128,11 @@ bool TTcpConnection::AdvanceDecoder(size_t size) return false; } + if (Config_->PacketDecoderDelay) { + YT_LOG_WARNING("Test delay in tcp connection packet decoder (Delay: %v)", Config_->PacketDecoderDelay); + TDelayedExecutor::WaitForDuration(Config_->PacketDecoderDelay.value()); + } + if (Decoder_->IsFinished()) { bool result = OnPacketReceived(); Decoder_->Restart(); @@ -1428,7 +1469,7 @@ bool TTcpConnection::MaybeEncodeFragments() // Discard all buffers except for a single one. WriteBuffers_.resize(1); - auto* buffer = WriteBuffers_.back().get(); + auto* buffer = &WriteBuffers_.back(); buffer->Clear(); size_t encodedSize = 0; @@ -1445,10 +1486,20 @@ bool TTcpConnection::MaybeEncodeFragments() if (buffer->Size() + fragment.Size() > buffer->Capacity()) { // Make sure we never reallocate. flushCoalesced(); - WriteBuffers_.push_back(std::make_unique<TBlob>(GetRefCountedTypeCookie<TTcpConnectionWriteBufferTag>())); - buffer = WriteBuffers_.back().get(); - buffer->Reserve(std::max(WriteBufferSize, fragment.Size())); + + auto size = std::max(WriteBufferSize, fragment.Size()); + + auto trackedBlob = TMemoryTrackedBlob::Build( + MemoryUsageTracker_, + ConnectionType_ == EConnectionType::Server + ? GetRefCountedTypeCookie<TTcpServerConnectionWriteBufferTag>() + : GetRefCountedTypeCookie<TTcpClientConnectionWriteBufferTag>()); + trackedBlob.Reserve(size); + + WriteBuffers_.push_back(std::move(trackedBlob)); + buffer = &WriteBuffers_.back(); } + buffer->Append(fragment); coalescedSize += fragment.Size(); }; diff --git a/yt/yt/core/bus/tcp/connection.h b/yt/yt/core/bus/tcp/connection.h index 0f5d758152..81e2c4a3a0 100644 --- a/yt/yt/core/bus/tcp/connection.h +++ b/yt/yt/core/bus/tcp/connection.h @@ -20,6 +20,7 @@ #include <yt/yt/core/misc/mpsc_stack.h> #include <yt/yt/core/misc/ring_queue.h> #include <yt/yt/core/misc/atomic_ptr.h> +#include <yt/yt/core/misc/memory_usage_tracker.h> #include <yt/yt/core/net/public.h> @@ -86,7 +87,8 @@ public: const std::optional<TString>& unixDomainSocketPath, IMessageHandlerPtr handler, NConcurrency::IPollerPtr poller, - IPacketTranscoderFactory* packetTranscoderFactory); + IPacketTranscoderFactory* packetTranscoderFactory, + IMemoryUsageTrackerPtr memoryUsageTracker); ~TTcpConnection(); @@ -239,7 +241,7 @@ private: std::unique_ptr<IPacketDecoder> Decoder_; const NProfiling::TCpuDuration ReadStallTimeout_; std::atomic<NProfiling::TCpuInstant> LastIncompleteReadTime_ = std::numeric_limits<NProfiling::TCpuInstant>::max(); - TBlob ReadBuffer_; + TMemoryTrackedBlob ReadBuffer_; TRingQueue<TPacketPtr> QueuedPackets_; TRingQueue<TPacketPtr> EncodedPackets_; @@ -248,7 +250,7 @@ private: std::unique_ptr<IPacketEncoder> Encoder_; const NProfiling::TCpuDuration WriteStallTimeout_; std::atomic<NProfiling::TCpuInstant> LastIncompleteWriteTime_ = std::numeric_limits<NProfiling::TCpuInstant>::max(); - std::vector<std::unique_ptr<TBlob>> WriteBuffers_; + std::vector<TMemoryTrackedBlob> WriteBuffers_; TRingQueue<TRef> EncodedFragments_; TRingQueue<size_t> EncodedPacketSizes_; @@ -277,11 +279,13 @@ private: const EEncryptionMode EncryptionMode_; const EVerificationMode VerificationMode_; + const IMemoryUsageTrackerPtr MemoryUsageTracker_; + NYTree::IAttributeDictionaryPtr PeerAttributes_; size_t MaxFragmentsPerWrite_ = 256; - void Open(); + void Open(TGuard<NThreading::TSpinLock>& guard); void Close(); void CloseSslSession(ESslState newSslState); diff --git a/yt/yt/core/bus/tcp/packet.cpp b/yt/yt/core/bus/tcp/packet.cpp index a6cd02de4f..aed8a6f932 100644 --- a/yt/yt/core/bus/tcp/packet.cpp +++ b/yt/yt/core/bus/tcp/packet.cpp @@ -160,14 +160,20 @@ class TPacketDecoder , public TPacketTranscoderBase<TPacketDecoder> { public: - TPacketDecoder(const NLogging::TLogger& logger, bool verifyChecksum) + TPacketDecoder( + const NLogging::TLogger& logger, + bool verifyChecksum, + IMemoryUsageTrackerPtr memoryUsageTracker) : TPacketTranscoderBase(logger) , Allocator_( PacketDecoderChunkSize, TChunkedMemoryAllocator::DefaultMaxSmallBlockSizeRatio, GetRefCountedTypeCookie<TPacketDecoderTag>()) , VerifyChecksum_(verifyChecksum) + , MemoryUsageTracker_(std::move(memoryUsageTracker)) { + YT_VERIFY(MemoryUsageTracker_); + Restart(); } @@ -201,6 +207,7 @@ public: Phase_ = EPacketPhase::FixedHeader; PacketSize_ = 0; Parts_.clear(); + MemoryGuard_ = TMemoryUsageTrackerGuard::Acquire(MemoryUsageTracker_, 0); PartIndex_ = -1; Message_.Reset(); @@ -243,11 +250,14 @@ private: TChunkedMemoryAllocator Allocator_; std::vector<TSharedRef> Parts_; + TMemoryUsageTrackerGuard MemoryGuard_; size_t PacketSize_ = 0; const bool VerifyChecksum_; + const IMemoryUsageTrackerPtr MemoryUsageTracker_; + bool EndFixedHeaderPhase() { if (FixedHeader_.Signature != PacketSignature) { @@ -351,6 +361,8 @@ private: Parts_.push_back(TSharedRef::MakeEmpty()); } else { auto part = Allocator_.AllocateAligned(partSize); + MemoryGuard_.IncrementSize(part.Size()); + BeginPhase(EPacketPhase::MessagePart, part.Begin(), part.Size()); Parts_.push_back(std::move(part)); break; @@ -494,14 +506,21 @@ private: //////////////////////////////////////////////////////////////////////////////// -struct TPacketTranscoderFactory +class TPacketTranscoderFactory : public IPacketTranscoderFactory { +public: + TPacketTranscoderFactory(IMemoryUsageTrackerPtr memoryUsageTracker) + : MemoryUsageTracker_(std::move(memoryUsageTracker)) + { + YT_VERIFY(MemoryUsageTracker_); + } + std::unique_ptr<IPacketDecoder> CreateDecoder( const NLogging::TLogger& logger, bool verifyChecksum) const override { - return std::make_unique<TPacketDecoder>(logger, verifyChecksum); + return std::make_unique<TPacketDecoder>(logger, verifyChecksum, MemoryUsageTracker_); } std::unique_ptr<IPacketEncoder> CreateEncoder( @@ -509,13 +528,16 @@ struct TPacketTranscoderFactory { return std::make_unique<TPacketEncoder>(logger); } + +private: + const IMemoryUsageTrackerPtr MemoryUsageTracker_; }; //////////////////////////////////////////////////////////////////////////////// -IPacketTranscoderFactory* GetYTPacketTranscoderFactory() +IPacketTranscoderFactory* GetYTPacketTranscoderFactory(IMemoryUsageTrackerPtr memoryUsageTracker) { - return LeakySingleton<TPacketTranscoderFactory>(); + return LeakySingleton<TPacketTranscoderFactory>(std::move(memoryUsageTracker)); } //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/bus/tcp/packet.h b/yt/yt/core/bus/tcp/packet.h index 05e4693128..2e2a729ee1 100644 --- a/yt/yt/core/bus/tcp/packet.h +++ b/yt/yt/core/bus/tcp/packet.h @@ -2,6 +2,8 @@ #include "private.h" +#include <yt/yt/core/misc/memory_usage_tracker.h> + namespace NYT::NBus { //////////////////////////////////////////////////////////////////////////////// @@ -76,7 +78,7 @@ struct IPacketTranscoderFactory //////////////////////////////////////////////////////////////////////////////// -IPacketTranscoderFactory* GetYTPacketTranscoderFactory(); +IPacketTranscoderFactory* GetYTPacketTranscoderFactory(IMemoryUsageTrackerPtr memoryUsageTracker = GetNullMemoryUsageTracker()); //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/bus/tcp/server.cpp b/yt/yt/core/bus/tcp/server.cpp index aa46ee9068..4dadac913f 100644 --- a/yt/yt/core/bus/tcp/server.cpp +++ b/yt/yt/core/bus/tcp/server.cpp @@ -43,15 +43,18 @@ public: TBusServerConfigPtr config, IPollerPtr poller, IMessageHandlerPtr handler, - IPacketTranscoderFactory* packetTranscoderFactory) + IPacketTranscoderFactory* packetTranscoderFactory, + IMemoryUsageTrackerPtr memoryUsageTracker) : Config_(std::move(config)) , Poller_(std::move(poller)) , Handler_(std::move(handler)) , PacketTranscoderFactory_(std::move(packetTranscoderFactory)) + , MemoryUsageTracker_(std::move(memoryUsageTracker)) { YT_VERIFY(Config_); YT_VERIFY(Poller_); YT_VERIFY(Handler_); + YT_VERIFY(MemoryUsageTracker_); if (Config_->Port) { Logger.AddTag("ServerPort: %v", *Config_->Port); @@ -123,6 +126,8 @@ protected: IPacketTranscoderFactory* const PacketTranscoderFactory_; + const IMemoryUsageTrackerPtr MemoryUsageTracker_; + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, ControlSpinLock_); SOCKET ServerSocket_ = INVALID_SOCKET; @@ -253,7 +258,8 @@ protected: std::nullopt, Handler_, std::move(poller), - PacketTranscoderFactory_); + PacketTranscoderFactory_, + MemoryUsageTracker_); { auto guard = WriterGuard(ConnectionsSpinLock_); @@ -347,12 +353,14 @@ public: TBusServerConfigPtr config, IPollerPtr poller, IMessageHandlerPtr handler, - IPacketTranscoderFactory* packetTranscoderFactory) + IPacketTranscoderFactory* packetTranscoderFactory, + IMemoryUsageTrackerPtr memoryUsageTracker) : TTcpBusServerBase( std::move(config), std::move(poller), std::move(handler), - packetTranscoderFactory) + packetTranscoderFactory, + std::move(memoryUsageTracker)) { } private: @@ -390,11 +398,14 @@ class TTcpBusServerProxy public: explicit TTcpBusServerProxy( TBusServerConfigPtr config, - IPacketTranscoderFactory* packetTranscoderFactory) + IPacketTranscoderFactory* packetTranscoderFactory, + IMemoryUsageTrackerPtr memoryUsageTracker) : Config_(std::move(config)) , PacketTranscoderFactory_(packetTranscoderFactory) + , MemoryUsageTracker_(std::move(memoryUsageTracker)) { YT_VERIFY(Config_); + YT_VERIFY(MemoryUsageTracker_); } ~TTcpBusServerProxy() @@ -408,7 +419,8 @@ public: Config_, TTcpDispatcher::TImpl::Get()->GetAcceptorPoller(), std::move(handler), - PacketTranscoderFactory_); + PacketTranscoderFactory_, + MemoryUsageTracker_); Server_.Store(server); server->Start(); @@ -428,6 +440,8 @@ private: IPacketTranscoderFactory* const PacketTranscoderFactory_; + const IMemoryUsageTrackerPtr MemoryUsageTracker_; + TAtomicIntrusivePtr<TServer> Server_; }; @@ -467,7 +481,8 @@ private: IBusServerPtr CreateBusServer( TBusServerConfigPtr config, - IPacketTranscoderFactory* packetTranscoderFactory) + IPacketTranscoderFactory* packetTranscoderFactory, + IMemoryUsageTrackerPtr memoryUsageTracker) { std::vector<IBusServerPtr> servers; @@ -475,14 +490,16 @@ IBusServerPtr CreateBusServer( servers.push_back( New<TTcpBusServerProxy<TRemoteTcpBusServer>>( config, - packetTranscoderFactory)); + packetTranscoderFactory, + memoryUsageTracker)); } #ifdef _linux_ // Abstract unix sockets are supported only on Linux. servers.push_back( New<TTcpBusServerProxy<TLocalTcpBusServer>>( config, - packetTranscoderFactory)); + packetTranscoderFactory, + memoryUsageTracker)); #endif return New<TCompositeBusServer>(std::move(servers)); diff --git a/yt/yt/core/bus/tcp/server.h b/yt/yt/core/bus/tcp/server.h index 7dcb351e10..57b3a0cde9 100644 --- a/yt/yt/core/bus/tcp/server.h +++ b/yt/yt/core/bus/tcp/server.h @@ -10,7 +10,8 @@ namespace NYT::NBus { IBusServerPtr CreateBusServer( TBusServerConfigPtr config, - IPacketTranscoderFactory* packetTranscoderFactory = GetYTPacketTranscoderFactory()); + IPacketTranscoderFactory* packetTranscoderFactory = GetYTPacketTranscoderFactory(), + IMemoryUsageTrackerPtr memoryUsageTracker = GetNullMemoryUsageTracker()); //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/concurrency/periodic_executor_base-inl.h b/yt/yt/core/concurrency/periodic_executor_base-inl.h index 356a828888..33dd40db5b 100644 --- a/yt/yt/core/concurrency/periodic_executor_base-inl.h +++ b/yt/yt/core/concurrency/periodic_executor_base-inl.h @@ -151,8 +151,20 @@ void TPeriodicExecutorBase<TInvocationTimePolicy>::PostCallback() { GuardedInvoke( Invoker_, - BIND_NO_PROPAGATE(&TThis::RunCallback, MakeWeak(this)), - BIND_NO_PROPAGATE(&TThis::OnCallbackCancelled, MakeWeak(this))); + [weakThis = MakeWeak(this)] { + auto strongThis = weakThis.Lock(); + + if (strongThis) { + strongThis->RunCallback(); + } + }, + [weakThis = MakeWeak(this)] { + auto strongThis = weakThis.Lock(); + + if (strongThis) { + strongThis->OnCallbackCancelled(); + } + }); } template <CInvocationTimePolicy TInvocationTimePolicy> diff --git a/yt/yt/core/misc/memory_usage_tracker.cpp b/yt/yt/core/misc/memory_usage_tracker.cpp index d968e25fec..c5cdd80e4d 100644 --- a/yt/yt/core/misc/memory_usage_tracker.cpp +++ b/yt/yt/core/misc/memory_usage_tracker.cpp @@ -70,6 +70,21 @@ void TMemoryUsageTrackerGuard::MoveFrom(TMemoryUsageTrackerGuard&& other) other.Granularity_ = 0; } +TMemoryUsageTrackerGuard TMemoryUsageTrackerGuard::Build( + IMemoryUsageTrackerPtr tracker, + i64 granularity) +{ + if (!tracker) { + return {}; + } + + TMemoryUsageTrackerGuard guard; + guard.Tracker_ = tracker; + guard.Size_ = 0; + guard.Granularity_ = granularity; + return guard; +} + TMemoryUsageTrackerGuard TMemoryUsageTrackerGuard::Acquire( IMemoryUsageTrackerPtr tracker, i64 size, @@ -206,5 +221,104 @@ TMemoryUsageTrackerGuard TMemoryUsageTrackerGuard::TransferMemory(i64 size) //////////////////////////////////////////////////////////////////////////////// +TMemoryTrackedBlob::TMemoryTrackedBlob( + TBlob&& blob, + TMemoryUsageTrackerGuard&& guard) + : Blob_(std::move(blob)) + , Guard_(std::move(guard)) +{ } + +TMemoryTrackedBlob TMemoryTrackedBlob::Build( + IMemoryUsageTrackerPtr tracker, + TRefCountedTypeCookie tagCookie) +{ + YT_VERIFY(tracker); + + return TMemoryTrackedBlob( + TBlob(tagCookie), + TMemoryUsageTrackerGuard::Build(tracker)); +} + +void TMemoryTrackedBlob::Resize( + i64 size, + bool initializeStorage) +{ + YT_VERIFY(size >= 0); + + Blob_.Resize(size, initializeStorage); + Guard_.SetSize(Blob_.Capacity()); +} + +TError TMemoryTrackedBlob::TryResize( + i64 size, + bool initializeStorage) +{ + YT_VERIFY(size >= 0); + auto result = Guard_.TrySetSize(size); + + if (result.IsOK()) { + Blob_.Resize(size, initializeStorage); + return {}; + } else { + return result; + } +} + +void TMemoryTrackedBlob::Reserve(i64 size) +{ + YT_VERIFY(size >= 0); + + Blob_.Reserve(size); + Guard_.SetSize(Blob_.Capacity()); +} + +TError TMemoryTrackedBlob::TryReserve(i64 size) +{ + YT_VERIFY(size >= 0); + + auto result = Guard_.TrySetSize(size); + + if (result.IsOK()) { + Blob_.Reserve(size); + return {}; + } else { + return result; + } +} + +char* TMemoryTrackedBlob::Begin() +{ + return Blob_.Begin(); +} + +char* TMemoryTrackedBlob::End() +{ + return Blob_.End(); +} + +size_t TMemoryTrackedBlob::Capacity() const +{ + return Blob_.Capacity(); +} + +size_t TMemoryTrackedBlob::Size() const +{ + return Blob_.Size(); +} + +void TMemoryTrackedBlob::Append(TRef ref) +{ + Blob_.Append(ref); + Guard_.SetSize(Blob_.Capacity()); +} + +void TMemoryTrackedBlob::Clear() +{ + Blob_.Clear(); + Guard_.SetSize(Blob_.Capacity()); +} + +//////////////////////////////////////////////////////////////////////////////// + } // namespace NYT diff --git a/yt/yt/core/misc/memory_usage_tracker.h b/yt/yt/core/misc/memory_usage_tracker.h index 61248faaa8..417055528d 100644 --- a/yt/yt/core/misc/memory_usage_tracker.h +++ b/yt/yt/core/misc/memory_usage_tracker.h @@ -2,6 +2,8 @@ #include "error.h" +#include <library/cpp/yt/memory/blob.h> + namespace NYT { //////////////////////////////////////////////////////////////////////////////// @@ -35,6 +37,9 @@ public: TMemoryUsageTrackerGuard& operator=(const TMemoryUsageTrackerGuard& other) = delete; TMemoryUsageTrackerGuard& operator=(TMemoryUsageTrackerGuard&& other); + static TMemoryUsageTrackerGuard Build( + IMemoryUsageTrackerPtr tracker, + i64 granularity = 1); static TMemoryUsageTrackerGuard Acquire( IMemoryUsageTrackerPtr tracker, i64 size, @@ -70,4 +75,54 @@ private: //////////////////////////////////////////////////////////////////////////////// +class TMemoryTrackedBlob +{ +public: + static TMemoryTrackedBlob Build( + IMemoryUsageTrackerPtr tracker, + TRefCountedTypeCookie tagCookie = GetRefCountedTypeCookie<TDefaultBlobTag>()); + + TMemoryTrackedBlob() = default; + TMemoryTrackedBlob(const TMemoryTrackedBlob& other) = delete; + TMemoryTrackedBlob(TMemoryTrackedBlob&& other) = default; + ~TMemoryTrackedBlob() = default; + + TMemoryTrackedBlob& operator=(const TMemoryTrackedBlob& other) = delete; + TMemoryTrackedBlob& operator=(TMemoryTrackedBlob&& other) = default; + + void Resize( + i64 size, + bool initializeStorage = true); + + TError TryResize( + i64 size, + bool initializeStorage = true); + + void Reserve(i64 capacity); + + TError TryReserve(i64 capacity); + + char* Begin(); + + char* End(); + + size_t Capacity() const; + + size_t Size() const; + + void Append(TRef ref); + + void Clear(); + +private: + TBlob Blob_; + TMemoryUsageTrackerGuard Guard_; + + TMemoryTrackedBlob( + TBlob&& blob, + TMemoryUsageTrackerGuard&& guard); +}; + +//////////////////////////////////////////////////////////////////////////////// + } // namespace NYT diff --git a/yt/yt/core/net/dialer.cpp b/yt/yt/core/net/dialer.cpp index c27489d23b..2a1511a8b1 100644 --- a/yt/yt/core/net/dialer.cpp +++ b/yt/yt/core/net/dialer.cpp @@ -240,8 +240,10 @@ private: Poller_->Unarm(Socket_, Pollable_); YT_VERIFY(Pollable_); - YT_UNUSED_FUTURE(Poller_->Unregister(Pollable_)); - Pollable_.Reset(); + auto pollable = std::move(Pollable_); + SpinLock_.Release(); + YT_UNUSED_FUTURE(Poller_->Unregister(pollable)); + SpinLock_.Acquire(); } void Connect(TGuard<NThreading::TSpinLock>& guard) diff --git a/yt/yt/core/rpc/bus/channel.cpp b/yt/yt/core/rpc/bus/channel.cpp index e5d04ea83a..7a2369c81f 100644 --- a/yt/yt/core/rpc/bus/channel.cpp +++ b/yt/yt/core/rpc/bus/channel.cpp @@ -1248,25 +1248,36 @@ class TTcpBusChannelFactory : public IChannelFactory { public: - explicit TTcpBusChannelFactory(TBusConfigPtr config) + TTcpBusChannelFactory( + TBusConfigPtr config, + IMemoryUsageTrackerPtr memoryUsageTracker) : Config_(ConvertToNode(std::move(config))) + , MemoryUsageTracker_(std::move(memoryUsageTracker)) { } IChannelPtr CreateChannel(const TString& address) override { auto config = TBusClientConfig::CreateTcp(address); config->Load(Config_, /*postprocess*/ true, /*setDefaults*/ false); - auto client = CreateBusClient(std::move(config)); + auto client = CreateBusClient( + std::move(config), + GetYTPacketTranscoderFactory(MemoryUsageTracker_), + MemoryUsageTracker_); return CreateBusChannel(std::move(client)); } private: const INodePtr Config_; + const IMemoryUsageTrackerPtr MemoryUsageTracker_; }; -IChannelFactoryPtr CreateTcpBusChannelFactory(TBusConfigPtr config) +IChannelFactoryPtr CreateTcpBusChannelFactory( + TBusConfigPtr config, + IMemoryUsageTrackerPtr memoryUsageTracker) { - return New<TTcpBusChannelFactory>(std::move(config)); + return New<TTcpBusChannelFactory>( + std::move(config), + std::move(memoryUsageTracker)); } //////////////////////////////////////////////////////////////////////////////// @@ -1275,25 +1286,38 @@ class TUdsBusChannelFactory : public IChannelFactory { public: - explicit TUdsBusChannelFactory(TBusConfigPtr config) + TUdsBusChannelFactory( + TBusConfigPtr config, + IMemoryUsageTrackerPtr memoryUsageTracker) : Config_(ConvertToNode(std::move(config))) - { } + , MemoryUsageTracker_(std::move(memoryUsageTracker)) + { + YT_VERIFY(MemoryUsageTracker_); + } IChannelPtr CreateChannel(const TString& address) override { auto config = TBusClientConfig::CreateUds(address); config->Load(Config_, /*postprocess*/ true, /*setDefaults*/ false); - auto client = CreateBusClient(std::move(config)); + auto client = CreateBusClient( + std::move(config), + GetYTPacketTranscoderFactory(MemoryUsageTracker_), + MemoryUsageTracker_); return CreateBusChannel(std::move(client)); } private: const INodePtr Config_; + const IMemoryUsageTrackerPtr MemoryUsageTracker_; }; -IChannelFactoryPtr CreateUdsBusChannelFactory(TBusConfigPtr config) +IChannelFactoryPtr CreateUdsBusChannelFactory( + TBusConfigPtr config, + IMemoryUsageTrackerPtr memoryUsageTracker) { - return New<TUdsBusChannelFactory>(std::move(config)); + return New<TUdsBusChannelFactory>( + std::move(config), + std::move(memoryUsageTracker)); } //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/rpc/bus/channel.h b/yt/yt/core/rpc/bus/channel.h index aaf02a8241..9a69342120 100644 --- a/yt/yt/core/rpc/bus/channel.h +++ b/yt/yt/core/rpc/bus/channel.h @@ -3,6 +3,7 @@ #include "public.h" #include <yt/yt/core/bus/tcp/public.h> +#include <yt/yt/core/misc/memory_usage_tracker.h> namespace NYT::NRpc::NBus { @@ -12,10 +13,14 @@ namespace NYT::NRpc::NBus { IChannelPtr CreateBusChannel(NYT::NBus::IBusClientPtr client); //! Creates a factory for creating TCP Bus channels. -IChannelFactoryPtr CreateTcpBusChannelFactory(NYT::NBus::TBusConfigPtr config); +IChannelFactoryPtr CreateTcpBusChannelFactory( + NYT::NBus::TBusConfigPtr config, + IMemoryUsageTrackerPtr memoryUsageTracker = GetNullMemoryUsageTracker()); //! Creates a factory for creating Unix domain socket (UDS) Bus channels. -IChannelFactoryPtr CreateUdsBusChannelFactory(NYT::NBus::TBusConfigPtr config); +IChannelFactoryPtr CreateUdsBusChannelFactory( + NYT::NBus::TBusConfigPtr config, + IMemoryUsageTrackerPtr memoryUsageTracker = GetNullMemoryUsageTracker()); //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp b/yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp index c8eafd445f..fd9165b570 100644 --- a/yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp +++ b/yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp @@ -16,9 +16,9 @@ class THandleChannelFailureTestBase : public ::testing::Test { public: - IServerPtr CreateServer(const TTestServerHost& serverHost) + IServerPtr CreateServer(const TTestServerHost& serverHost, IMemoryUsageTrackerPtr memoryUsageTracker) { - return TImpl::CreateServer(serverHost.GetPort()); + return TImpl::CreateServer(serverHost.GetPort(), memoryUsageTracker); } IChannelPtr CreateChannel(const TString& address) @@ -50,7 +50,7 @@ TYPED_TEST(THandleChannelFailureTest, HandleChannelFailureTest) auto workerPool = NConcurrency::CreateThreadPool(4, "Worker"); outerServer.InitializeServer( - this->CreateServer(outerServer), + this->CreateServer(outerServer, New<TTestNodeMemoryTracker>(32_MB)), workerPool->GetInvoker(), /*secure*/ false, BIND([&] (const TString& address) { @@ -58,7 +58,7 @@ TYPED_TEST(THandleChannelFailureTest, HandleChannelFailureTest) })); innerServer.InitializeServer( - this->CreateServer(innerServer), + this->CreateServer(innerServer, New<TTestNodeMemoryTracker>(32_MB)), workerPool->GetInvoker(), /*secure*/ false, /*createChannel*/ {}); diff --git a/yt/yt/core/rpc/unittests/lib/common.cpp b/yt/yt/core/rpc/unittests/lib/common.cpp index d906d82fd4..91697c38b1 100644 --- a/yt/yt/core/rpc/unittests/lib/common.cpp +++ b/yt/yt/core/rpc/unittests/lib/common.cpp @@ -8,4 +8,106 @@ TString TRpcOverUdsImpl::SocketPath_ = ""; //////////////////////////////////////////////////////////////////////////////// +TTestNodeMemoryTracker::TTestNodeMemoryTracker(size_t limit) + : Usage_(0) + , Limit_(limit) +{ } + +i64 TTestNodeMemoryTracker::GetLimit() const +{ + auto guard = Guard(Lock_); + return Limit_; +} + +i64 TTestNodeMemoryTracker::GetUsed() const +{ + auto guard = Guard(Lock_); + return Usage_; +} + +i64 TTestNodeMemoryTracker::GetFree() const +{ + auto guard = Guard(Lock_); + return Limit_ - Usage_; +} + +bool TTestNodeMemoryTracker::IsExceeded() const +{ + auto guard = Guard(Lock_); + return GetFree() > 0; +} + +TError TTestNodeMemoryTracker::TryAcquire(i64 size) +{ + auto guard = Guard(Lock_); + return DoTryAcquire(size); +} + +TError TTestNodeMemoryTracker::DoTryAcquire(i64 size) +{ + if (Usage_ + size >= Limit_) { + return TError("Memory exceeded"); + } + + Usage_ += size; + TotalUsage_ += size; + + return {}; +} + +TError TTestNodeMemoryTracker::TryChange(i64 size) +{ + auto guard = Guard(Lock_); + + if (size > Usage_) { + return DoTryAcquire(size - Usage_); + } else if (size < Usage_) { + DoRelease(Usage_ - size); + } + + return {}; +} + +bool TTestNodeMemoryTracker::Acquire(i64 size) +{ + auto guard = Guard(Lock_); + DoAcquire(size); + return Usage_ >= Limit_; +} + +void TTestNodeMemoryTracker::Release(i64 size) +{ + auto guard = Guard(Lock_); + DoRelease(size); +} + +void TTestNodeMemoryTracker::SetLimit(i64 size) +{ + auto guard = Guard(Lock_); + Limit_ = size; +} + +void TTestNodeMemoryTracker::DoAcquire(i64 size) +{ + Usage_ += size; + TotalUsage_ += size; +} + +void TTestNodeMemoryTracker::DoRelease(i64 size) +{ + Usage_ -= size; +} + +void TTestNodeMemoryTracker::ClearTotalUsage() +{ + TotalUsage_ = 0; +} + +i64 TTestNodeMemoryTracker::GetTotalUsage() const +{ + return TotalUsage_; +} + +//////////////////////////////////////////////////////////////////////////////// + } // namespace NYT::NRpc diff --git a/yt/yt/core/rpc/unittests/lib/common.h b/yt/yt/core/rpc/unittests/lib/common.h index b3981952ba..57fc7713d4 100644 --- a/yt/yt/core/rpc/unittests/lib/common.h +++ b/yt/yt/core/rpc/unittests/lib/common.h @@ -17,6 +17,7 @@ #include <yt/yt/core/bus/public.h> #include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/memory_usage_tracker.h> #include <yt/yt/core/rpc/bus/channel.h> #include <yt/yt/core/rpc/bus/server.h> @@ -108,6 +109,42 @@ protected: //////////////////////////////////////////////////////////////////////////////// +class TTestNodeMemoryTracker + : public IMemoryUsageTracker +{ +public: + explicit TTestNodeMemoryTracker(size_t limit); + + i64 GetLimit() const; + i64 GetUsed() const; + i64 GetFree() const; + bool IsExceeded() const; + + TError TryAcquire(i64 size) override; + TError TryChange(i64 size) override; + bool Acquire(i64 size) override; + void Release(i64 size) override; + void SetLimit(i64 size) override; + + void ClearTotalUsage(); + i64 GetTotalUsage() const; + +private: + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, Lock_); + i64 Usage_; + i64 Limit_; + i64 TotalUsage_; + + TError DoTryAcquire(i64 size); + void DoAcquire(i64 size); + void DoRelease(i64 size); +}; + +DECLARE_REFCOUNTED_CLASS(TTestNodeMemoryTracker) +DEFINE_REFCOUNTED_TYPE(TTestNodeMemoryTracker) + +//////////////////////////////////////////////////////////////////////////////// + template <class TImpl> class TTestBase : public ::testing::Test @@ -120,9 +157,9 @@ public: WorkerPool_ = NConcurrency::CreateThreadPool(4, "Worker"); bool secure = TImpl::Secure; - + MemoryUsageTracker_ = New<TTestNodeMemoryTracker>(32_MB); TTestServerHost::InitializeServer( - TImpl::CreateServer(Port_), + TImpl::CreateServer(Port_, MemoryUsageTracker_), WorkerPool_->GetInvoker(), secure, /*createChannel*/ {}); @@ -133,6 +170,11 @@ public: TTestServerHost::TearDown(); } + TTestNodeMemoryTrackerPtr GetMemoryUsageTracker() + { + return MemoryUsageTracker_; + } + IChannelPtr CreateChannel( const std::optional<TString>& address = std::nullopt, THashMap<TString, NYTree::INodePtr> grpcArguments = {}) @@ -168,6 +210,7 @@ public: private: NConcurrency::IThreadPoolPtr WorkerPool_; + TTestNodeMemoryTrackerPtr MemoryUsageTracker_; }; //////////////////////////////////////////////////////////////////////////////// @@ -179,9 +222,9 @@ public: static constexpr bool AllowTransportErrors = false; static constexpr bool Secure = false; - static IServerPtr CreateServer(ui16 port) + static IServerPtr CreateServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) { - auto busServer = MakeBusServer(port); + auto busServer = MakeBusServer(port, memoryUsageTracker); return NRpc::NBus::CreateBusServer(busServer); } @@ -193,9 +236,9 @@ public: return TImpl::CreateChannel(address, serverAddress, std::move(grpcArguments)); } - static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port) + static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) { - return TImpl::MakeBusServer(port); + return TImpl::MakeBusServer(port, memoryUsageTracker); } }; @@ -214,10 +257,13 @@ public: return NRpc::NBus::CreateBusChannel(client); } - static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port) + static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) { auto busConfig = NYT::NBus::TBusServerConfig::CreateTcp(port); - return CreateBusServer(busConfig); + return CreateBusServer( + busConfig, + NYT::NBus::GetYTPacketTranscoderFactory(memoryUsageTracker), + memoryUsageTracker); } }; @@ -397,7 +443,9 @@ public: return NGrpc::CreateGrpcChannel(channelConfig); } - static IServerPtr CreateServer(ui16 port) + static IServerPtr CreateServer( + ui16 port, + IMemoryUsageTrackerPtr /*memoryUsageTracker*/) { auto serverAddressConfig = New<NGrpc::TServerAddressConfig>(); if (EnableSsl) { @@ -429,11 +477,14 @@ public: class TRpcOverUdsImpl { public: - static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port) + static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) { SocketPath_ = GetWorkPath() + "/socket_" + ToString(port); auto busConfig = NYT::NBus::TBusServerConfig::CreateUds(SocketPath_); - return CreateBusServer(busConfig); + return CreateBusServer( + busConfig, + NYT::NBus::GetYTPacketTranscoderFactory(memoryUsageTracker), + memoryUsageTracker); } static IChannelPtr CreateChannel( diff --git a/yt/yt/core/rpc/unittests/lib/test_service.cpp b/yt/yt/core/rpc/unittests/lib/test_service.cpp index 601d95c22c..0e51ffe6e7 100644 --- a/yt/yt/core/rpc/unittests/lib/test_service.cpp +++ b/yt/yt/core/rpc/unittests/lib/test_service.cpp @@ -47,7 +47,9 @@ public: RegisterMethod(RPC_SERVICE_METHOD_DESC(SlowCall) .SetCancelable(true) .SetConcurrencyLimit(10) - .SetQueueSizeLimit(20)); + .SetQueueSizeLimit(20) + .SetConcurrencyByteLimit(10_MB) + .SetQueueByteSizeLimit(20_MB)); RegisterMethod(RPC_SERVICE_METHOD_DESC(SlowCanceledCall) .SetCancelable(true)); RegisterMethod(RPC_SERVICE_METHOD_DESC(RequestBytesThrottledCall)); diff --git a/yt/yt/core/rpc/unittests/lib/test_service.proto b/yt/yt/core/rpc/unittests/lib/test_service.proto index 79d5bff3de..8ef02ef1f5 100644 --- a/yt/yt/core/rpc/unittests/lib/test_service.proto +++ b/yt/yt/core/rpc/unittests/lib/test_service.proto @@ -110,6 +110,8 @@ message TRspNotRegistered message TReqSlowCall { + optional int32 request_codec = 1; + optional string message = 2; } message TRspSlowCall diff --git a/yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp b/yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp index d118ce0549..52466ee3a5 100644 --- a/yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp +++ b/yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp @@ -29,6 +29,9 @@ TYPED_TEST_SUITE(TRpcTest, TAllTransports); TYPED_TEST(TRpcTest, ResponseWithAllocationTags) { + auto memoryUsageTracker = this->GetMemoryUsageTracker(); + auto previousLimit = memoryUsageTracker->GetLimit(); + memoryUsageTracker->SetLimit(2_GB); static TMemoryTag testMemoryTag = 1 << 20; testMemoryTag++; @@ -88,6 +91,8 @@ TYPED_TEST(TRpcTest, ResponseWithAllocationTags) << "InitialUsage: " << initialMemoryUsage << std::endl << "MemoryUsage before waiting: " << memoryUsageBefore << std::endl << "MemoryUsage after waiting: " << memoryUsageAfter; + + memoryUsageTracker->SetLimit(previousLimit); } #endif diff --git a/yt/yt/core/rpc/unittests/rpc_ut.cpp b/yt/yt/core/rpc/unittests/rpc_ut.cpp index 3924bd6c20..af3a679024 100644 --- a/yt/yt/core/rpc/unittests/rpc_ut.cpp +++ b/yt/yt/core/rpc/unittests/rpc_ut.cpp @@ -741,17 +741,126 @@ TYPED_TEST(TRpcTest, SlowCall) TYPED_TEST(TRpcTest, RequestQueueSizeLimit) { - TTestProxy proxy(this->CreateChannel()); std::vector<TFuture<void>> futures; + std::vector<TTestProxy> proxies; + + // Concurrency byte limit + queue byte size limit = 10 + 20 = 30 + // First 30 requests must be successful, 31st request must be failed. + for (int i = 0; i < 30; ++i) { + proxies.push_back(TTestProxy(this->CreateChannel())); + proxies[i].SetDefaultTimeout(TDuration::Seconds(60.0)); + } + for (int i = 0; i < 30; ++i) { + auto req = proxies[i].SlowCall(); + futures.push_back(req->Invoke().AsVoid()); + } + + Sleep(TDuration::MilliSeconds(400)); + { + TTestProxy proxy(this->CreateChannel()); + proxy.SetDefaultTimeout(TDuration::Seconds(60.0)); auto req = proxy.SlowCall(); + EXPECT_EQ(NRpc::EErrorCode::RequestQueueSizeLimitExceeded, req->Invoke().Get().GetCode()); + } + + EXPECT_TRUE(AllSucceeded(std::move(futures)).Get().IsOK()); +} + +TYPED_TEST(TNotGrpcTest, MemoryTracking) +{ + TTestProxy proxy(this->CreateChannel()); + auto memoryUsageTracker = this->GetMemoryUsageTracker(); + memoryUsageTracker->ClearTotalUsage(); + proxy.SetDefaultTimeout(TDuration::Seconds(10.0)); + for (int i = 0; i < 300; ++i) { + auto req = proxy.SomeCall(); + req->set_a(42); + WaitFor(req->Invoke().AsVoid()).ThrowOnError(); + } + + { + auto rpcUsage = memoryUsageTracker->GetTotalUsage(); + EXPECT_EQ(rpcUsage, (static_cast<i64>(32_KB))); + } +} + +TYPED_TEST(TNotGrpcTest, MemoryTrackingMultipleConnections) +{ + auto memoryUsageTracker = this->GetMemoryUsageTracker(); + memoryUsageTracker->ClearTotalUsage(); + for (int i = 0; i < 300; ++i) { + TTestProxy proxy(this->CreateChannel()); + proxy.SetDefaultTimeout(TDuration::Seconds(10.0)); + auto req = proxy.SomeCall(); + req->set_a(42); + WaitFor(req->Invoke().AsVoid()).ThrowOnError(); + } + + { + auto rpcUsage = memoryUsageTracker->GetTotalUsage(); + EXPECT_EQ(rpcUsage, (static_cast<i64>(32_KB) * 300)); + } +} + +TYPED_TEST(TNotGrpcTest, MemoryTrackingMultipleConcurrent) +{ + auto memoryUsageTracker = this->GetMemoryUsageTracker(); + memoryUsageTracker->ClearTotalUsage(); + std::vector<TFuture<void>> futures; + std::vector<TTestProxy> proxies; + + for (int i = 0; i < 40; ++i) { + proxies.push_back(TTestProxy(this->CreateChannel())); + proxies[i].SetDefaultTimeout(TDuration::Seconds(60.0)); + } + + for (int j = 0; j < 40; ++j) { + auto req = proxies[j % 40].SlowCall(); futures.push_back(req->Invoke().AsVoid()); } - Sleep(TDuration::MilliSeconds(100)); + + Sleep(TDuration::MilliSeconds(300)); + { + auto rpcUsage = memoryUsageTracker->GetUsed(); + // 20 = concurrency (10) + queue (20) + EXPECT_EQ(rpcUsage, (static_cast<i64>(32_KB) * 40)); + } + EXPECT_TRUE(AllSet(std::move(futures)).Get().IsOK()); +} + +TYPED_TEST(TNotGrpcTest, RequestQueueByteSizeLimit) +{ + const auto requestCodecId = NCompression::ECodec::Zstd_2; + + std::vector<TFuture<void>> futures; + std::vector<TTestProxy> proxies; + + // Every request contains 2 MB, 15 requests contain 30 MB. + // Concurrency byte limit + queue byte size limit = 10 MB + 20 MB = 30 MB + // First 15 requests must be successful, 16th request must be failed. + for (int i = 0; i < 15; ++i) { + proxies.push_back(TTestProxy(this->CreateChannel())); + proxies[i].SetDefaultTimeout(TDuration::Seconds(60.0)); + } + + for (int i = 0; i < 15; ++i) { + auto req = proxies[i].SlowCall(); + req->set_request_codec(static_cast<int>(requestCodecId)); + req->set_message(TString(2_MB, 'x')); + futures.push_back(req->Invoke().AsVoid()); + } + + Sleep(TDuration::MilliSeconds(400)); { + TTestProxy proxy(this->CreateChannel()); + proxy.SetDefaultTimeout(TDuration::Seconds(60.0)); auto req = proxy.SlowCall(); + req->set_request_codec(static_cast<int>(requestCodecId)); + req->set_message(TString(1_MB, 'x')); EXPECT_EQ(NRpc::EErrorCode::RequestQueueSizeLimitExceeded, req->Invoke().Get().GetCode()); } + EXPECT_TRUE(AllSucceeded(std::move(futures)).Get().IsOK()); } @@ -774,7 +883,7 @@ TYPED_TEST(TRpcTest, ConcurrencyLimit) EXPECT_TRUE(AllSucceeded(std::move(futures)).Get().IsOK()); - Sleep(TDuration::MilliSeconds(400)); + Sleep(TDuration::MilliSeconds(200)); EXPECT_FALSE(backlogFuture.IsSet()); EXPECT_TRUE(backlogFuture.Get().IsOK()); |