diff options
author | Vitalii Gridnev <gridnevvvit@gmail.com> | 2022-03-23 21:15:10 +0300 |
---|---|---|
committer | Vitalii Gridnev <gridnevvvit@gmail.com> | 2022-03-23 21:15:10 +0300 |
commit | d594ed61159ece689653e49e3de93ec994095dca (patch) | |
tree | 5a73b4210b0858b900660218ef499da4eb3202b2 | |
parent | 7afd7916c878f99b1a666ff6b3027595d740275b (diff) | |
download | ydb-d594ed61159ece689653e49e3de93ec994095dca.tar.gz |
use node id to locate endpoint to execute request on session KIKIMR-11464
ref:8fc064e45abb6c6f12e8d3bc0ea1673b98e3155f
-rw-r--r-- | ydb/public/sdk/python/ydb/_session_impl.py | 2 | ||||
-rw-r--r-- | ydb/public/sdk/python/ydb/aio/connection.py | 15 | ||||
-rw-r--r-- | ydb/public/sdk/python/ydb/aio/pool.py | 14 | ||||
-rw-r--r-- | ydb/public/sdk/python/ydb/connection.py | 27 | ||||
-rw-r--r-- | ydb/public/sdk/python/ydb/pool.py | 13 | ||||
-rw-r--r-- | ydb/public/sdk/python/ydb/resolver.py | 18 |
6 files changed, 70 insertions, 19 deletions
diff --git a/ydb/public/sdk/python/ydb/_session_impl.py b/ydb/public/sdk/python/ydb/_session_impl.py index e1906299ec..c320867671 100644 --- a/ydb/public/sdk/python/ydb/_session_impl.py +++ b/ydb/public/sdk/python/ydb/_session_impl.py @@ -237,7 +237,7 @@ def initialize_session(rpc_state, response_pb, session_state, session): issues._process_response(response_pb.operation) message = _apis.ydb_table.CreateSessionResult() response_pb.operation.result.Unpack(message) - session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint) + session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint_key) return session diff --git a/ydb/public/sdk/python/ydb/aio/connection.py b/ydb/public/sdk/python/ydb/aio/connection.py index f660c3eef0..88ab738c6a 100644 --- a/ydb/public/sdk/python/ydb/aio/connection.py +++ b/ydb/public/sdk/python/ydb/aio/connection.py @@ -18,6 +18,7 @@ from ydb.connection import ( YDB_DATABASE_HEADER, YDB_TRACE_ID_HEADER, YDB_REQUEST_TYPE_HEADER, + EndpointKey, ) from ydb.driver import DriverConfig from ydb.settings import BaseRequestSettings @@ -71,8 +72,8 @@ class _RpcState(RpcState): "_trailing_metadata", ) - def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str): - super().__init__(stub_instance, rpc_name, endpoint) + def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str, endpoint_key): + super().__init__(stub_instance, rpc_name, endpoint, endpoint_key) async def __call__(self, *args, **kwargs): resp = self.rpc(*args, **kwargs) @@ -105,6 +106,8 @@ class Connection: "lock", "calls", "closing", + "endpoint_key", + "node_id", ) def __init__( @@ -115,6 +118,10 @@ class Connection: ): global _stubs_list self.endpoint = endpoint + self.endpoint_key = EndpointKey( + self.endpoint, getattr(endpoint_options, "node_id", None) + ) + self.node_id = getattr(endpoint_options, "node_id", None) self._channel = channel_factory( self.endpoint, driver_config, grpc.aio, endpoint_options=endpoint_options ) @@ -141,7 +148,9 @@ class Connection: ) _set_server_timeouts(request, settings, timeout) self._prepare_stub_instance(stub) - rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint) + rpc_state = _RpcState( + self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key + ) logger.debug("%s: creating call state", rpc_state) if self.closing: diff --git a/ydb/public/sdk/python/ydb/aio/pool.py b/ydb/public/sdk/python/ydb/aio/pool.py index b0593bce00..8b76d10d96 100644 --- a/ydb/public/sdk/python/ydb/aio/pool.py +++ b/ydb/public/sdk/python/ydb/aio/pool.py @@ -30,7 +30,16 @@ class ConnectionsCache(_ConnectionsCache): else: await asyncio.wait_for(self._event.wait(), timeout=wait_timeout) - if preferred_endpoint is not None and preferred_endpoint in self.connections: + if ( + preferred_endpoint is not None + and preferred_endpoint.node_id in self.connections_by_node_id + ): + return self.connections_by_node_id[preferred_endpoint.node_id] + + if ( + preferred_endpoint is not None + and preferred_endpoint.endpoint in self.connections + ): return self.connections[preferred_endpoint] for conn_lst in self.conn_lst_order: @@ -52,6 +61,8 @@ class ConnectionsCache(_ConnectionsCache): if preferred: self.preferred[connection.endpoint] = connection + + self.connections_by_node_id[connection.node_id] = connection self.connections[connection.endpoint] = connection self._event.set() @@ -66,6 +77,7 @@ class ConnectionsCache(_ConnectionsCache): self._fast_fail_event.set() def remove(self, connection): + self.connections_by_node_id.pop(connection.node_id, None) self.preferred.pop(connection.endpoint, None) self.connections.pop(connection.endpoint, None) self.outdated.pop(connection.endpoint, None) diff --git a/ydb/public/sdk/python/ydb/connection.py b/ydb/public/sdk/python/ydb/connection.py index 1500f4c8f8..a51736728b 100644 --- a/ydb/public/sdk/python/ydb/connection.py +++ b/ydb/public/sdk/python/ydb/connection.py @@ -163,10 +163,11 @@ def _get_request_timeout(settings): class EndpointOptions(object): - __slots__ = ("ssl_target_name_override",) + __slots__ = ("ssl_target_name_override", "node_id") - def __init__(self, ssl_target_name_override=None): + def __init__(self, ssl_target_name_override=None, node_id=None): self.ssl_target_name_override = ssl_target_name_override + self.node_id = node_id def _construct_channel_options(driver_config, endpoint_options=None): @@ -223,9 +224,10 @@ class _RpcState(object): "endpoint", "rendezvous", "metadata_kv", + "endpoint_key", ) - def __init__(self, stub_instance, rpc_name, endpoint): + def __init__(self, stub_instance, rpc_name, endpoint, endpoint_key): """Stores all RPC related data""" self.rpc_name = rpc_name self.rpc = getattr(stub_instance, rpc_name) @@ -233,6 +235,7 @@ class _RpcState(object): self.endpoint = endpoint self.rendezvous = None self.metadata_kv = None + self.endpoint_key = endpoint_key def __str__(self): return "RpcState(%s, %s, %s)" % (self.rpc_name, self.request_id, self.endpoint) @@ -318,6 +321,14 @@ def channel_factory( ) +class EndpointKey(object): + __slots__ = ("endpoint", "node_id") + + def __init__(self, endpoint, node_id): + self.endpoint = endpoint + self.node_id = node_id + + class Connection(object): __slots__ = ( "endpoint", @@ -330,6 +341,8 @@ class Connection(object): "lock", "calls", "closing", + "endpoint_key", + "node_id", ) def __init__(self, endpoint, driver_config=None, endpoint_options=None): @@ -341,6 +354,10 @@ class Connection(object): """ global _stubs_list self.endpoint = endpoint + self.node_id = getattr(endpoint_options, "node_id", None) + self.endpoint_key = EndpointKey( + endpoint, getattr(endpoint_options, "node_id", None) + ) self._channel = channel_factory( self.endpoint, driver_config, endpoint_options=endpoint_options ) @@ -368,7 +385,9 @@ class Connection(object): ) _set_server_timeouts(request, settings, timeout) self._prepare_stub_instance(stub) - rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint) + rpc_state = _RpcState( + self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key + ) logger.debug("%s: creating call state", rpc_state) with self.lock: if self.closing: diff --git a/ydb/public/sdk/python/ydb/pool.py b/ydb/public/sdk/python/ydb/pool.py index 471a33c87a..dfda0adff2 100644 --- a/ydb/public/sdk/python/ydb/pool.py +++ b/ydb/public/sdk/python/ydb/pool.py @@ -19,6 +19,7 @@ class ConnectionsCache(object): self.tracer = tracer self.lock = threading.RLock() self.connections = collections.OrderedDict() + self.connections_by_node_id = collections.OrderedDict() self.outdated = collections.OrderedDict() self.subscriptions = set() self.preferred = collections.OrderedDict() @@ -39,6 +40,8 @@ class ConnectionsCache(object): with self.lock: if preferred: self.preferred[connection.endpoint] = connection + + self.connections_by_node_id[connection.node_id] = connection self.connections[connection.endpoint] = connection subscriptions = list(self.subscriptions) self.subscriptions.clear() @@ -128,9 +131,14 @@ class ConnectionsCache(object): with self.lock: if ( preferred_endpoint is not None - and preferred_endpoint in self.connections + and preferred_endpoint.node_id in self.connections_by_node_id + ): + return self.connections_by_node_id[preferred_endpoint.node_id] + + if ( + preferred_endpoint is not None + and preferred_endpoint.endpoint in self.connections ): - tracing.trace(self.tracer, {"found_preferred_endpoint": True}) return self.connections[preferred_endpoint] for conn_lst in self.conn_lst_order: @@ -146,6 +154,7 @@ class ConnectionsCache(object): def remove(self, connection): with self.lock: + self.connections_by_node_id.pop(connection.node_id, None) self.preferred.pop(connection.endpoint, None) self.connections.pop(connection.endpoint, None) self.outdated.pop(connection.endpoint, None) diff --git a/ydb/public/sdk/python/ydb/resolver.py b/ydb/public/sdk/python/ydb/resolver.py index 54712c0085..b40ae984dc 100644 --- a/ydb/public/sdk/python/ydb/resolver.py +++ b/ydb/public/sdk/python/ydb/resolver.py @@ -19,6 +19,7 @@ class EndpointInfo(object): "ipv4_addrs", "ipv6_addrs", "ssl_target_name_override", + "node_id", ) def __init__(self, endpoint_info): @@ -30,19 +31,20 @@ class EndpointInfo(object): self.ipv4_addrs = tuple(endpoint_info.ip_v4) self.ipv6_addrs = tuple(endpoint_info.ip_v6) self.ssl_target_name_override = endpoint_info.ssl_target_name_override + self.node_id = endpoint_info.node_id def endpoints_with_options(self): + ssl_target_name_override = None if self.ssl: if self.ssl_target_name_override: - endpoint_options = conn_impl.EndpointOptions( - self.ssl_target_name_override - ) + ssl_target_name_override = self.ssl_target_name_override elif self.ipv6_addrs or self.ipv4_addrs: - endpoint_options = conn_impl.EndpointOptions(self.address) - else: - endpoint_options = None - else: - endpoint_options = None + ssl_target_name_override = self.address + + endpoint_options = conn_impl.EndpointOptions( + ssl_target_name_override=ssl_target_name_override, node_id=self.node_id + ) + if self.ipv6_addrs or self.ipv4_addrs: for ipv6addr in self.ipv6_addrs: yield ("ipv6:[%s]:%s" % (ipv6addr, self.port), endpoint_options) |