# coding=utf-8
import os
import errno
import socket
import random
import logging
import platform
import threading
import six
# Largest unsigned 16-bit value (65535); used as the highest port number below.
UI16MAXVAL = 0xFFFF
# Module-level logger.
logger = logging.getLogger(__name__)
class PortManagerException(Exception):
    """Raised when PortManager cannot find or capture a suitable port."""
class PortManager(object):
"""
See documentation here
https://wiki.yandex-team.ru/yatool/test/#python-acquire-ports
"""
def __init__(self, sync_dir=None):
self._sync_dir = sync_dir or os.environ.get('PORT_SYNC_PATH')
if self._sync_dir:
_makedirs(self._sync_dir)
self._valid_range = get_valid_port_range()
self._valid_port_count = self._count_valid_ports()
self._filelocks = {}
self._lock = threading.Lock()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.release()
def get_port(self, port=0):
'''
Gets free TCP port
'''
return self.get_tcp_port(port)
def get_tcp_port(self, port=0):
'''
Gets free TCP port
'''
return self._get_port(port, socket.SOCK_STREAM)
def get_udp_port(self, port=0):
'''
Gets free UDP port
'''
return self._get_port(port, socket.SOCK_DGRAM)
def get_tcp_and_udp_port(self, port=0):
'''
Gets one free port for use in both TCP and UDP protocols
'''
if port and self._no_random_ports():
return port
retries = 20
while retries > 0:
retries -= 1
result_port = self.get_tcp_port()
if not self.is_port_free(result_port, socket.SOCK_DGRAM):
self.release_port(result_port)
# Don't try to _capture_port(), it's already captured in the get_tcp_port()
return result_port
raise Exception('Failed to find port')
def release_port(self, port):
with self._lock:
self._release_port_no_lock(port)
def _release_port_no_lock(self, port):
filelock = self._filelocks.pop(port, None)
if filelock:
filelock.release()
def release(self):
with self._lock:
while self._filelocks:
_, filelock = self._filelocks.popitem()
if filelock:
filelock.release()
def get_port_range(self, start_port, count, random_start=True):
assert count > 0
if start_port and self._no_random_ports():
return start_port
candidates = []
def drop_candidates():
for port in candidates:
self._release_port_no_lock(port)
candidates[:] = []
with self._lock:
for attempts in six.moves.range(128):
for left, right in self._valid_range:
if right - left < count:
continue
if random_start:
start = random.randint(left, right - ((right - left) // 2))
else:
start = left
for probe_port in six.moves.range(start, right):
if self._capture_port_no_lock(probe_port, socket.SOCK_STREAM):
candidates.append(probe_port)
else:
drop_candidates()
if len(candidates) == count:
return candidates[0]
# Can't find required number of ports without gap in the current range
drop_candidates()
raise PortManagerException(
"Failed to find valid port range (start_port: {} count: {}) (range: {} used: {})".format(
start_port, count, self._valid_range, self._filelocks
)
)
def _count_valid_ports(self):
res = 0
for left, right in self._valid_range:
res += right - left
assert res, ('There are no available valid ports', self._valid_range)
return res
def _get_port(self, port, sock_type):
if port and self._no_random_ports():
return port
if len(self._filelocks) >= self._valid_port_count:
raise PortManagerException("All valid ports are taken ({}): {}".format(self._valid_range, self._filelocks))
salt = random.randint(0, UI16MAXVAL)
for attempt in six.moves.range(self._valid_port_count):
probe_port = (salt + attempt) % self._valid_port_count
for left, right in self._valid_range:
if probe_port >= (right - left):
probe_port -= right - left
else:
probe_port += left
break
if not self._capture_port(probe_port, sock_type):
continue
return probe_port
raise PortManagerException(
"Failed to find valid port (range: {} used: {})".format(self._valid_range, self._filelocks)
)
def _capture_port(self, port, sock_type):
with self._lock:
return self._capture_port_no_lock(port, sock_type)
def is_port_free(self, port, sock_type=socket.SOCK_STREAM):
sock = socket.socket(socket.AF_INET6, sock_type)
try:
sock.bind(('::', port))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except socket.error as e:
if e.errno == errno.EADDRINUSE:
return False
raise
finally:
sock.close()
return True
def _capture_port_no_lock(self, port, sock_type):
if port in self._filelocks:
return False
filelock = None
if self._sync_dir:
# yatest.common should try to be hermetic and don't have peerdirs
# otherwise, PYTEST_SCRIPT (aka USE_ARCADIA_PYTHON=no) won't work
import library.python.filelock
filelock = library.python.filelock.FileLock(os.path.join(self._sync_dir, str(port)))
if not filelock.acquire(blocking=False):
return False
if self.is_port_free(port, sock_type):
self._filelocks[port] = filelock
return True
else:
filelock.release()
return False
if self.is_port_free(port, sock_type):
self._filelocks[port] = filelock
return True
if filelock:
filelock.release()
return False
def _no_random_ports(self):
return os.environ.get("NO_RANDOM_PORTS")
def get_valid_port_range():
    """
    Return the port ranges PortManager may hand out.

    If the VALID_PORT_RANGE environment variable contains ':', it must be
    ``"<first>:<last>"`` with integers ``first <= last``, and ``[[first, last]]``
    is returned. Otherwise the ports 1025..65535 minus the OS ephemeral range
    (from get_ephemeral_range()) are returned as a list of ``(first, last)``
    tuples, clamped to 1025..65535.

    :raises ValueError: if VALID_PORT_RANGE is set but malformed
    """
    given_range = os.environ.get('VALID_PORT_RANGE')
    if given_range and ':' in given_range:
        parts = given_range.split(':')
        if len(parts) != 2:
            raise ValueError("VALID_PORT_RANGE must look like '<first>:<last>', got {!r}".format(given_range))
        left, right = (int(x) for x in parts)
        if left > right:
            raise ValueError("VALID_PORT_RANGE bounds are reversed: {!r}".format(given_range))
        return [[left, right]]

    first_valid = 1025
    last_valid = UI16MAXVAL

    first_eph, last_eph = get_ephemeral_range()
    ranges = []
    # Valid ports below the ephemeral range (clamped in case it starts above 65535).
    if first_eph > first_valid:
        ranges.append((first_valid, min(first_eph - 1, last_valid)))
    # Valid ports above the ephemeral range (clamped so privileged ports never leak in).
    if last_eph < last_valid:
        ranges.append((max(last_eph + 1, first_valid), last_valid))
    return ranges
def get_ephemeral_range():
    """
    Return the OS ephemeral port range as a ``(first, last)`` pair.

    Linux: parsed from /proc/sys/net/ipv4/ip_local_port_range.
    Darwin: read from the net.inet.ip.portrange.first/last sysctls.
    Any other system, or a value that cannot be read or parsed, falls back to
    the IANA-suggested range (49152, 65535).
    """
    system = platform.system()
    if system == 'Linux':
        filename = "/proc/sys/net/ipv4/ip_local_port_range"
        if os.path.exists(filename):
            with open(filename) as afile:
                data = afile.read(1024)  # fix for musl
            try:
                port_range = tuple(map(int, data.strip().split()))
            except ValueError:
                # Non-numeric content: treat like any other bad format below.
                port_range = ()
            if len(port_range) == 2:
                return port_range
            logger.warning("Bad ip_local_port_range format: '%s'. Going to use IANA suggestion", data)
    elif system == 'Darwin':
        first = _sysctlbyname_uint("net.inet.ip.portrange.first")
        last = _sysctlbyname_uint("net.inet.ip.portrange.last")
        if first and last:
            return first, last

    # IANA suggestion
    return (1 << 15) + (1 << 14), UI16MAXVAL
def _sysctlbyname_uint(name):
try:
from ctypes import CDLL, c_uint, byref
from ctypes.util import find_library
except ImportError:
return
libc = CDLL(find_library("c"))
size = c_uint(0)
res = c_uint(0)
libc.sysctlbyname(name, None, byref(size), None, 0)
libc.sysctlbyname(name, byref(res), byref(size), None, 0)
return res.value
def _makedirs(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST:
return
raise