# coding=utf-8

import os
import errno
import socket
import random
import logging
import platform
import threading

import six

UI16MAXVAL = (1 << 16) - 1
logger = logging.getLogger(__name__)


class PortManagerException(Exception):
    pass


class PortManager(object):
    """
    See documentation here

    https://wiki.yandex-team.ru/yatool/test/#python-acquire-ports
    """

    def __init__(self, sync_dir=None):
        self._sync_dir = sync_dir or os.environ.get('PORT_SYNC_PATH')
        if self._sync_dir:
            _makedirs(self._sync_dir)

        self._valid_range = get_valid_port_range()
        self._valid_port_count = self._count_valid_ports()
        self._filelocks = {}
        self._lock = threading.Lock()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.release()

    def get_port(self, port=0):
        '''
        Gets free TCP port
        '''
        return self.get_tcp_port(port)

    def get_tcp_port(self, port=0):
        '''
        Gets free TCP port
        '''
        return self._get_port(port, socket.SOCK_STREAM)

    def get_udp_port(self, port=0):
        '''
        Gets free UDP port
        '''
        return self._get_port(port, socket.SOCK_DGRAM)

    def get_tcp_and_udp_port(self, port=0):
        '''
        Gets one free port for use in both TCP and UDP protocols
        '''
        if port and self._no_random_ports():
            return port

        retries = 20
        while retries > 0:
            retries -= 1

            result_port = self.get_tcp_port()
            if not self.is_port_free(result_port, socket.SOCK_DGRAM):
                self.release_port(result_port)
            # Don't try to _capture_port(), it's already captured in the get_tcp_port()
            return result_port
        raise Exception('Failed to find port')

    def release_port(self, port):
        with self._lock:
            self._release_port_no_lock(port)

    def _release_port_no_lock(self, port):
        filelock = self._filelocks.pop(port, None)
        if filelock:
            filelock.release()

    def release(self):
        with self._lock:
            while self._filelocks:
                _, filelock = self._filelocks.popitem()
                if filelock:
                    filelock.release()

    def get_port_range(self, start_port, count, random_start=True):
        assert count > 0
        if start_port and self._no_random_ports():
            return start_port

        candidates = []

        def drop_candidates():
            for port in candidates:
                self._release_port_no_lock(port)
            candidates[:] = []

        with self._lock:
            for attempts in six.moves.range(128):
                for left, right in self._valid_range:
                    if right - left < count:
                        continue

                    if random_start:
                        start = random.randint(left, right - ((right - left) // 2))
                    else:
                        start = left
                    for probe_port in six.moves.range(start, right):
                        if self._capture_port_no_lock(probe_port, socket.SOCK_STREAM):
                            candidates.append(probe_port)
                        else:
                            drop_candidates()

                        if len(candidates) == count:
                            return candidates[0]
                    # Can't find required number of ports without gap in the current range
                    drop_candidates()

            raise PortManagerException(
                "Failed to find valid port range (start_port: {} count: {}) (range: {} used: {})".format(
                    start_port, count, self._valid_range, self._filelocks
                )
            )

    def _count_valid_ports(self):
        res = 0
        for left, right in self._valid_range:
            res += right - left
        assert res, ('There are no available valid ports', self._valid_range)
        return res

    def _get_port(self, port, sock_type):
        if port and self._no_random_ports():
            return port

        if len(self._filelocks) >= self._valid_port_count:
            raise PortManagerException("All valid ports are taken ({}): {}".format(self._valid_range, self._filelocks))

        salt = random.randint(0, UI16MAXVAL)
        for attempt in six.moves.range(self._valid_port_count):
            probe_port = (salt + attempt) % self._valid_port_count

            for left, right in self._valid_range:
                if probe_port >= (right - left):
                    probe_port -= right - left
                else:
                    probe_port += left
                    break
            if not self._capture_port(probe_port, sock_type):
                continue
            return probe_port

        raise PortManagerException(
            "Failed to find valid port (range: {} used: {})".format(self._valid_range, self._filelocks)
        )

    def _capture_port(self, port, sock_type):
        with self._lock:
            return self._capture_port_no_lock(port, sock_type)

    def is_port_free(self, port, sock_type=socket.SOCK_STREAM):
        sock = socket.socket(socket.AF_INET6, sock_type)
        try:
            sock.bind(('::', port))
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except socket.error as e:
            if e.errno == errno.EADDRINUSE:
                return False
            raise
        finally:
            sock.close()
        return True

    def _capture_port_no_lock(self, port, sock_type):
        if port in self._filelocks:
            return False

        filelock = None
        if self._sync_dir:
            # yatest.common should try to be hermetic and don't have peerdirs
            # otherwise, PYTEST_SCRIPT (aka USE_ARCADIA_PYTHON=no) won't work
            import library.python.filelock

            filelock = library.python.filelock.FileLock(os.path.join(self._sync_dir, str(port)))
            if not filelock.acquire(blocking=False):
                return False
            if self.is_port_free(port, sock_type):
                self._filelocks[port] = filelock
                return True
            else:
                filelock.release()
                return False

        if self.is_port_free(port, sock_type):
            self._filelocks[port] = filelock
            return True
        if filelock:
            filelock.release()
        return False

    def _no_random_ports(self):
        return os.environ.get("NO_RANDOM_PORTS")


def get_valid_port_range():
    first_valid = 1025
    last_valid = UI16MAXVAL

    given_range = os.environ.get('VALID_PORT_RANGE')
    if given_range and ':' in given_range:
        return [list(int(x) for x in given_range.split(':', 2))]

    first_eph, last_eph = get_ephemeral_range()
    first_invalid = max(first_eph, first_valid)
    last_invalid = min(last_eph, last_valid)

    ranges = []
    if first_invalid > first_valid:
        ranges.append((first_valid, first_invalid - 1))
    if last_invalid < last_valid:
        ranges.append((last_invalid + 1, last_valid))
    return ranges


def get_ephemeral_range():
    if platform.system() == 'Linux':
        filename = "/proc/sys/net/ipv4/ip_local_port_range"
        if os.path.exists(filename):
            with open(filename) as afile:
                data = afile.read(1024)  # fix for musl
            port_range = tuple(map(int, data.strip().split()))
            if len(port_range) == 2:
                return port_range
            else:
                logger.warning("Bad ip_local_port_range format: '%s'. Going to use IANA suggestion", data)
    elif platform.system() == 'Darwin':
        first = _sysctlbyname_uint("net.inet.ip.portrange.first")
        last = _sysctlbyname_uint("net.inet.ip.portrange.last")
        if first and last:
            return first, last
    # IANA suggestion
    return (1 << 15) + (1 << 14), UI16MAXVAL


def _sysctlbyname_uint(name):
    try:
        from ctypes import CDLL, c_uint, byref
        from ctypes.util import find_library
    except ImportError:
        return

    libc = CDLL(find_library("c"))
    size = c_uint(0)
    res = c_uint(0)
    libc.sysctlbyname(name, None, byref(size), None, 0)
    libc.sysctlbyname(name, byref(res), byref(size), None, 0)
    return res.value


def _makedirs(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno == errno.EEXIST:
            return
        raise