aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVitalii Gridnev <gridnevvvit@gmail.com>2022-03-23 21:15:10 +0300
committerVitalii Gridnev <gridnevvvit@gmail.com>2022-03-23 21:15:10 +0300
commitd594ed61159ece689653e49e3de93ec994095dca (patch)
tree5a73b4210b0858b900660218ef499da4eb3202b2
parent7afd7916c878f99b1a666ff6b3027595d740275b (diff)
downloadydb-d594ed61159ece689653e49e3de93ec994095dca.tar.gz
use node id to locate endpoint to execute request on session KIKIMR-11464
ref:8fc064e45abb6c6f12e8d3bc0ea1673b98e3155f
-rw-r--r--ydb/public/sdk/python/ydb/_session_impl.py2
-rw-r--r--ydb/public/sdk/python/ydb/aio/connection.py15
-rw-r--r--ydb/public/sdk/python/ydb/aio/pool.py14
-rw-r--r--ydb/public/sdk/python/ydb/connection.py27
-rw-r--r--ydb/public/sdk/python/ydb/pool.py13
-rw-r--r--ydb/public/sdk/python/ydb/resolver.py18
6 files changed, 70 insertions, 19 deletions
diff --git a/ydb/public/sdk/python/ydb/_session_impl.py b/ydb/public/sdk/python/ydb/_session_impl.py
index e1906299ec..c320867671 100644
--- a/ydb/public/sdk/python/ydb/_session_impl.py
+++ b/ydb/public/sdk/python/ydb/_session_impl.py
@@ -237,7 +237,7 @@ def initialize_session(rpc_state, response_pb, session_state, session):
issues._process_response(response_pb.operation)
message = _apis.ydb_table.CreateSessionResult()
response_pb.operation.result.Unpack(message)
- session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint)
+ session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint_key)
return session
diff --git a/ydb/public/sdk/python/ydb/aio/connection.py b/ydb/public/sdk/python/ydb/aio/connection.py
index f660c3eef0..88ab738c6a 100644
--- a/ydb/public/sdk/python/ydb/aio/connection.py
+++ b/ydb/public/sdk/python/ydb/aio/connection.py
@@ -18,6 +18,7 @@ from ydb.connection import (
YDB_DATABASE_HEADER,
YDB_TRACE_ID_HEADER,
YDB_REQUEST_TYPE_HEADER,
+ EndpointKey,
)
from ydb.driver import DriverConfig
from ydb.settings import BaseRequestSettings
@@ -71,8 +72,8 @@ class _RpcState(RpcState):
"_trailing_metadata",
)
- def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str):
- super().__init__(stub_instance, rpc_name, endpoint)
+ def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str, endpoint_key):
+ super().__init__(stub_instance, rpc_name, endpoint, endpoint_key)
async def __call__(self, *args, **kwargs):
resp = self.rpc(*args, **kwargs)
@@ -105,6 +106,8 @@ class Connection:
"lock",
"calls",
"closing",
+ "endpoint_key",
+ "node_id",
)
def __init__(
@@ -115,6 +118,10 @@ class Connection:
):
global _stubs_list
self.endpoint = endpoint
+ self.endpoint_key = EndpointKey(
+ self.endpoint, getattr(endpoint_options, "node_id", None)
+ )
+ self.node_id = getattr(endpoint_options, "node_id", None)
self._channel = channel_factory(
self.endpoint, driver_config, grpc.aio, endpoint_options=endpoint_options
)
@@ -141,7 +148,9 @@ class Connection:
)
_set_server_timeouts(request, settings, timeout)
self._prepare_stub_instance(stub)
- rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint)
+ rpc_state = _RpcState(
+ self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key
+ )
logger.debug("%s: creating call state", rpc_state)
if self.closing:
diff --git a/ydb/public/sdk/python/ydb/aio/pool.py b/ydb/public/sdk/python/ydb/aio/pool.py
index b0593bce00..8b76d10d96 100644
--- a/ydb/public/sdk/python/ydb/aio/pool.py
+++ b/ydb/public/sdk/python/ydb/aio/pool.py
@@ -30,7 +30,16 @@ class ConnectionsCache(_ConnectionsCache):
else:
await asyncio.wait_for(self._event.wait(), timeout=wait_timeout)
- if preferred_endpoint is not None and preferred_endpoint in self.connections:
+ if (
+ preferred_endpoint is not None
+ and preferred_endpoint.node_id in self.connections_by_node_id
+ ):
+ return self.connections_by_node_id[preferred_endpoint.node_id]
+
+ if (
+ preferred_endpoint is not None
+ and preferred_endpoint.endpoint in self.connections
+ ):
return self.connections[preferred_endpoint]
for conn_lst in self.conn_lst_order:
@@ -52,6 +61,8 @@ class ConnectionsCache(_ConnectionsCache):
if preferred:
self.preferred[connection.endpoint] = connection
+
+ self.connections_by_node_id[connection.node_id] = connection
self.connections[connection.endpoint] = connection
self._event.set()
@@ -66,6 +77,7 @@ class ConnectionsCache(_ConnectionsCache):
self._fast_fail_event.set()
def remove(self, connection):
+ self.connections_by_node_id.pop(connection.node_id, None)
self.preferred.pop(connection.endpoint, None)
self.connections.pop(connection.endpoint, None)
self.outdated.pop(connection.endpoint, None)
diff --git a/ydb/public/sdk/python/ydb/connection.py b/ydb/public/sdk/python/ydb/connection.py
index 1500f4c8f8..a51736728b 100644
--- a/ydb/public/sdk/python/ydb/connection.py
+++ b/ydb/public/sdk/python/ydb/connection.py
@@ -163,10 +163,11 @@ def _get_request_timeout(settings):
class EndpointOptions(object):
- __slots__ = ("ssl_target_name_override",)
+ __slots__ = ("ssl_target_name_override", "node_id")
- def __init__(self, ssl_target_name_override=None):
+ def __init__(self, ssl_target_name_override=None, node_id=None):
self.ssl_target_name_override = ssl_target_name_override
+ self.node_id = node_id
def _construct_channel_options(driver_config, endpoint_options=None):
@@ -223,9 +224,10 @@ class _RpcState(object):
"endpoint",
"rendezvous",
"metadata_kv",
+ "endpoint_key",
)
- def __init__(self, stub_instance, rpc_name, endpoint):
+ def __init__(self, stub_instance, rpc_name, endpoint, endpoint_key):
"""Stores all RPC related data"""
self.rpc_name = rpc_name
self.rpc = getattr(stub_instance, rpc_name)
@@ -233,6 +235,7 @@ class _RpcState(object):
self.endpoint = endpoint
self.rendezvous = None
self.metadata_kv = None
+ self.endpoint_key = endpoint_key
def __str__(self):
return "RpcState(%s, %s, %s)" % (self.rpc_name, self.request_id, self.endpoint)
@@ -318,6 +321,14 @@ def channel_factory(
)
+class EndpointKey(object):
+ __slots__ = ("endpoint", "node_id")
+
+ def __init__(self, endpoint, node_id):
+ self.endpoint = endpoint
+ self.node_id = node_id
+
+
class Connection(object):
__slots__ = (
"endpoint",
@@ -330,6 +341,8 @@ class Connection(object):
"lock",
"calls",
"closing",
+ "endpoint_key",
+ "node_id",
)
def __init__(self, endpoint, driver_config=None, endpoint_options=None):
@@ -341,6 +354,10 @@ class Connection(object):
"""
global _stubs_list
self.endpoint = endpoint
+ self.node_id = getattr(endpoint_options, "node_id", None)
+ self.endpoint_key = EndpointKey(
+ endpoint, getattr(endpoint_options, "node_id", None)
+ )
self._channel = channel_factory(
self.endpoint, driver_config, endpoint_options=endpoint_options
)
@@ -368,7 +385,9 @@ class Connection(object):
)
_set_server_timeouts(request, settings, timeout)
self._prepare_stub_instance(stub)
- rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint)
+ rpc_state = _RpcState(
+ self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key
+ )
logger.debug("%s: creating call state", rpc_state)
with self.lock:
if self.closing:
diff --git a/ydb/public/sdk/python/ydb/pool.py b/ydb/public/sdk/python/ydb/pool.py
index 471a33c87a..dfda0adff2 100644
--- a/ydb/public/sdk/python/ydb/pool.py
+++ b/ydb/public/sdk/python/ydb/pool.py
@@ -19,6 +19,7 @@ class ConnectionsCache(object):
self.tracer = tracer
self.lock = threading.RLock()
self.connections = collections.OrderedDict()
+ self.connections_by_node_id = collections.OrderedDict()
self.outdated = collections.OrderedDict()
self.subscriptions = set()
self.preferred = collections.OrderedDict()
@@ -39,6 +40,8 @@ class ConnectionsCache(object):
with self.lock:
if preferred:
self.preferred[connection.endpoint] = connection
+
+ self.connections_by_node_id[connection.node_id] = connection
self.connections[connection.endpoint] = connection
subscriptions = list(self.subscriptions)
self.subscriptions.clear()
@@ -128,9 +131,14 @@ class ConnectionsCache(object):
with self.lock:
if (
preferred_endpoint is not None
- and preferred_endpoint in self.connections
+ and preferred_endpoint.node_id in self.connections_by_node_id
+ ):
+ return self.connections_by_node_id[preferred_endpoint.node_id]
+
+ if (
+ preferred_endpoint is not None
+ and preferred_endpoint.endpoint in self.connections
):
- tracing.trace(self.tracer, {"found_preferred_endpoint": True})
return self.connections[preferred_endpoint]
for conn_lst in self.conn_lst_order:
@@ -146,6 +154,7 @@ class ConnectionsCache(object):
def remove(self, connection):
with self.lock:
+ self.connections_by_node_id.pop(connection.node_id, None)
self.preferred.pop(connection.endpoint, None)
self.connections.pop(connection.endpoint, None)
self.outdated.pop(connection.endpoint, None)
diff --git a/ydb/public/sdk/python/ydb/resolver.py b/ydb/public/sdk/python/ydb/resolver.py
index 54712c0085..b40ae984dc 100644
--- a/ydb/public/sdk/python/ydb/resolver.py
+++ b/ydb/public/sdk/python/ydb/resolver.py
@@ -19,6 +19,7 @@ class EndpointInfo(object):
"ipv4_addrs",
"ipv6_addrs",
"ssl_target_name_override",
+ "node_id",
)
def __init__(self, endpoint_info):
@@ -30,19 +31,20 @@ class EndpointInfo(object):
self.ipv4_addrs = tuple(endpoint_info.ip_v4)
self.ipv6_addrs = tuple(endpoint_info.ip_v6)
self.ssl_target_name_override = endpoint_info.ssl_target_name_override
+ self.node_id = endpoint_info.node_id
def endpoints_with_options(self):
+ ssl_target_name_override = None
if self.ssl:
if self.ssl_target_name_override:
- endpoint_options = conn_impl.EndpointOptions(
- self.ssl_target_name_override
- )
+ ssl_target_name_override = self.ssl_target_name_override
elif self.ipv6_addrs or self.ipv4_addrs:
- endpoint_options = conn_impl.EndpointOptions(self.address)
- else:
- endpoint_options = None
- else:
- endpoint_options = None
+ ssl_target_name_override = self.address
+
+ endpoint_options = conn_impl.EndpointOptions(
+ ssl_target_name_override=ssl_target_name_override, node_id=self.node_id
+ )
+
if self.ipv6_addrs or self.ipv4_addrs:
for ipv6addr in self.ipv6_addrs:
yield ("ipv6:[%s]:%s" % (ipv6addr, self.port), endpoint_options)