diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2024-12-09 18:25:21 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-12-09 19:18:57 +0300 |
commit | 13374e0884578812cda7697d0c5680122db59a37 (patch) | |
tree | 30a022eb841035299deb2b8c393b2902f0c21735 /contrib/python/websocket-client/websocket | |
parent | c7ade6d3bf7cd492235a61b77153351e422a28f3 (diff) | |
download | ydb-13374e0884578812cda7697d0c5680122db59a37.tar.gz |
Intermediate changes
commit_hash:034150f557268506d7bc0cbd8b5becf65f765593
Diffstat (limited to 'contrib/python/websocket-client/websocket')
26 files changed, 5758 insertions, 0 deletions
diff --git a/contrib/python/websocket-client/websocket/__init__.py b/contrib/python/websocket-client/websocket/__init__.py new file mode 100644 index 0000000000..559b38a6b7 --- /dev/null +++ b/contrib/python/websocket-client/websocket/__init__.py @@ -0,0 +1,26 @@ +""" +__init__.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from ._abnf import * +from ._app import WebSocketApp as WebSocketApp, setReconnect as setReconnect +from ._core import * +from ._exceptions import * +from ._logging import * +from ._socket import * + +__version__ = "1.8.0" diff --git a/contrib/python/websocket-client/websocket/_abnf.py b/contrib/python/websocket-client/websocket/_abnf.py new file mode 100644 index 0000000000..d7754e0de2 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_abnf.py @@ -0,0 +1,453 @@ +import array +import os +import struct +import sys +from threading import Lock +from typing import Callable, Optional, Union + +from ._exceptions import WebSocketPayloadException, WebSocketProtocolException +from ._utils import validate_utf8 + +""" +_abnf.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +try: + # If wsaccel is available, use compiled routines to mask data. + # wsaccel only provides around a 10% speed boost compared + # to the websocket-client _mask() implementation. + # Note that wsaccel is unmaintained. + from wsaccel.xormask import XorMaskerSimple + + def _mask(mask_value: array.array, data_value: array.array) -> bytes: + mask_result: bytes = XorMaskerSimple(mask_value).process(data_value) + return mask_result + +except ImportError: + # wsaccel is not available, use websocket-client _mask() + native_byteorder = sys.byteorder + + def _mask(mask_value: array.array, data_value: array.array) -> bytes: + datalen = len(data_value) + int_data_value = int.from_bytes(data_value, native_byteorder) + int_mask_value = int.from_bytes( + mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder + ) + return (int_data_value ^ int_mask_value).to_bytes(datalen, native_byteorder) + + +__all__ = [ + "ABNF", + "continuous_frame", + "frame_buffer", + "STATUS_NORMAL", + "STATUS_GOING_AWAY", + "STATUS_PROTOCOL_ERROR", + "STATUS_UNSUPPORTED_DATA_TYPE", + "STATUS_STATUS_NOT_AVAILABLE", + "STATUS_ABNORMAL_CLOSED", + "STATUS_INVALID_PAYLOAD", + "STATUS_POLICY_VIOLATION", + "STATUS_MESSAGE_TOO_BIG", + "STATUS_INVALID_EXTENSION", + "STATUS_UNEXPECTED_CONDITION", + "STATUS_BAD_GATEWAY", + "STATUS_TLS_HANDSHAKE_ERROR", +] + +# closing frame status codes. +STATUS_NORMAL = 1000 +STATUS_GOING_AWAY = 1001 +STATUS_PROTOCOL_ERROR = 1002 +STATUS_UNSUPPORTED_DATA_TYPE = 1003 +STATUS_STATUS_NOT_AVAILABLE = 1005 +STATUS_ABNORMAL_CLOSED = 1006 +STATUS_INVALID_PAYLOAD = 1007 +STATUS_POLICY_VIOLATION = 1008 +STATUS_MESSAGE_TOO_BIG = 1009 +STATUS_INVALID_EXTENSION = 1010 +STATUS_UNEXPECTED_CONDITION = 1011 +STATUS_SERVICE_RESTART = 1012 +STATUS_TRY_AGAIN_LATER = 1013 +STATUS_BAD_GATEWAY = 1014 +STATUS_TLS_HANDSHAKE_ERROR = 1015 + +VALID_CLOSE_STATUS = ( + STATUS_NORMAL, + STATUS_GOING_AWAY, + STATUS_PROTOCOL_ERROR, + STATUS_UNSUPPORTED_DATA_TYPE, + STATUS_INVALID_PAYLOAD, + STATUS_POLICY_VIOLATION, + STATUS_MESSAGE_TOO_BIG, + STATUS_INVALID_EXTENSION, + STATUS_UNEXPECTED_CONDITION, + STATUS_SERVICE_RESTART, + STATUS_TRY_AGAIN_LATER, + STATUS_BAD_GATEWAY, +) + + +class ABNF: + """ + ABNF frame class. + See http://tools.ietf.org/html/rfc5234 + and http://tools.ietf.org/html/rfc6455#section-5.2 + """ + + # operation code values. + OPCODE_CONT = 0x0 + OPCODE_TEXT = 0x1 + OPCODE_BINARY = 0x2 + OPCODE_CLOSE = 0x8 + OPCODE_PING = 0x9 + OPCODE_PONG = 0xA + + # available operation code value tuple + OPCODES = ( + OPCODE_CONT, + OPCODE_TEXT, + OPCODE_BINARY, + OPCODE_CLOSE, + OPCODE_PING, + OPCODE_PONG, + ) + + # opcode human readable string + OPCODE_MAP = { + OPCODE_CONT: "cont", + OPCODE_TEXT: "text", + OPCODE_BINARY: "binary", + OPCODE_CLOSE: "close", + OPCODE_PING: "ping", + OPCODE_PONG: "pong", + } + + # data length threshold. + LENGTH_7 = 0x7E + LENGTH_16 = 1 << 16 + LENGTH_63 = 1 << 63 + + def __init__( + self, + fin: int = 0, + rsv1: int = 0, + rsv2: int = 0, + rsv3: int = 0, + opcode: int = OPCODE_TEXT, + mask_value: int = 1, + data: Union[str, bytes, None] = "", + ) -> None: + """ + Constructor for ABNF. Please check RFC for arguments. + """ + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_value = mask_value + if data is None: + data = "" + self.data = data + self.get_mask_key = os.urandom + + def validate(self, skip_utf8_validation: bool = False) -> None: + """ + Validate the ABNF frame. + + Parameters + ---------- + skip_utf8_validation: skip utf8 validation. + """ + if self.rsv1 or self.rsv2 or self.rsv3: + raise WebSocketProtocolException("rsv is not implemented, yet") + + if self.opcode not in ABNF.OPCODES: + raise WebSocketProtocolException("Invalid opcode %r", self.opcode) + + if self.opcode == ABNF.OPCODE_PING and not self.fin: + raise WebSocketProtocolException("Invalid ping frame.") + + if self.opcode == ABNF.OPCODE_CLOSE: + l = len(self.data) + if not l: + return + if l == 1 or l >= 126: + raise WebSocketProtocolException("Invalid close frame.") + if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): + raise WebSocketProtocolException("Invalid close frame.") + + code = 256 * int(self.data[0]) + int(self.data[1]) + if not self._is_valid_close_status(code): + raise WebSocketProtocolException("Invalid close opcode %r", code) + + @staticmethod + def _is_valid_close_status(code: int) -> bool: + return code in VALID_CLOSE_STATUS or (3000 <= code < 5000) + + def __str__(self) -> str: + return f"fin={self.fin} opcode={self.opcode} data={self.data}" + + @staticmethod + def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF": + """ + Create frame to send text, binary and other data. + + Parameters + ---------- + data: str + data to send. This is string value(byte array). + If opcode is OPCODE_TEXT and this value is unicode, + data value is converted into unicode string, automatically. + opcode: int + operation code. please see OPCODE_MAP. + fin: int + fin flag. if set to 0, create continue fragmentation. + """ + if opcode == ABNF.OPCODE_TEXT and isinstance(data, str): + data = data.encode("utf-8") + # mask must be set if send data from client + return ABNF(fin, 0, 0, 0, opcode, 1, data) + + def format(self) -> bytes: + """ + Format this object to string(byte array) to send data to server. + """ + if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]): + raise ValueError("not 0 or 1") + if self.opcode not in ABNF.OPCODES: + raise ValueError("Invalid OPCODE") + length = len(self.data) + if length >= ABNF.LENGTH_63: + raise ValueError("data is too long") + + frame_header = chr( + self.fin << 7 + | self.rsv1 << 6 + | self.rsv2 << 5 + | self.rsv3 << 4 + | self.opcode + ).encode("latin-1") + if length < ABNF.LENGTH_7: + frame_header += chr(self.mask_value << 7 | length).encode("latin-1") + elif length < ABNF.LENGTH_16: + frame_header += chr(self.mask_value << 7 | 0x7E).encode("latin-1") + frame_header += struct.pack("!H", length) + else: + frame_header += chr(self.mask_value << 7 | 0x7F).encode("latin-1") + frame_header += struct.pack("!Q", length) + + if not self.mask_value: + if isinstance(self.data, str): + self.data = self.data.encode("utf-8") + return frame_header + self.data + mask_key = self.get_mask_key(4) + return frame_header + self._get_masked(mask_key) + + def _get_masked(self, mask_key: Union[str, bytes]) -> bytes: + s = ABNF.mask(mask_key, self.data) + + if isinstance(mask_key, str): + mask_key = mask_key.encode("utf-8") + + return mask_key + s + + @staticmethod + def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes: + """ + Mask or unmask data. Just do xor for each byte + + Parameters + ---------- + mask_key: bytes or str + 4 byte mask. + data: bytes or str + data to mask/unmask. + """ + if data is None: + data = "" + + if isinstance(mask_key, str): + mask_key = mask_key.encode("latin-1") + + if isinstance(data, str): + data = data.encode("latin-1") + + return _mask(array.array("B", mask_key), array.array("B", data)) + + +class frame_buffer: + _HEADER_MASK_INDEX = 5 + _HEADER_LENGTH_INDEX = 6 + + def __init__( + self, recv_fn: Callable[[int], int], skip_utf8_validation: bool + ) -> None: + self.recv = recv_fn + self.skip_utf8_validation = skip_utf8_validation + # Buffers over the packets from the layer beneath until desired amount + # bytes of bytes are received. + self.recv_buffer: list = [] + self.clear() + self.lock = Lock() + + def clear(self) -> None: + self.header: Optional[tuple] = None + self.length: Optional[int] = None + self.mask_value: Union[bytes, str, None] = None + + def has_received_header(self) -> bool: + return self.header is None + + def recv_header(self) -> None: + header = self.recv_strict(2) + b1 = header[0] + fin = b1 >> 7 & 1 + rsv1 = b1 >> 6 & 1 + rsv2 = b1 >> 5 & 1 + rsv3 = b1 >> 4 & 1 + opcode = b1 & 0xF + b2 = header[1] + has_mask = b2 >> 7 & 1 + length_bits = b2 & 0x7F + + self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits) + + def has_mask(self) -> Union[bool, int]: + if not self.header: + return False + header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX] + return header_val + + def has_received_length(self) -> bool: + return self.length is None + + def recv_length(self) -> None: + bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] + length_bits = bits & 0x7F + if length_bits == 0x7E: + v = self.recv_strict(2) + self.length = struct.unpack("!H", v)[0] + elif length_bits == 0x7F: + v = self.recv_strict(8) + self.length = struct.unpack("!Q", v)[0] + else: + self.length = length_bits + + def has_received_mask(self) -> bool: + return self.mask_value is None + + def recv_mask(self) -> None: + self.mask_value = self.recv_strict(4) if self.has_mask() else "" + + def recv_frame(self) -> ABNF: + with self.lock: + # Header + if self.has_received_header(): + self.recv_header() + (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header + + # Frame length + if self.has_received_length(): + self.recv_length() + length = self.length + + # Mask + if self.has_received_mask(): + self.recv_mask() + mask_value = self.mask_value + + # Payload + payload = self.recv_strict(length) + if has_mask: + payload = ABNF.mask(mask_value, payload) + + # Reset for next frame + self.clear() + + frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) + frame.validate(self.skip_utf8_validation) + + return frame + + def recv_strict(self, bufsize: int) -> bytes: + shortage = bufsize - sum(map(len, self.recv_buffer)) + while shortage > 0: + # Limit buffer size that we pass to socket.recv() to avoid + # fragmenting the heap -- the number of bytes recv() actually + # reads is limited by socket buffer and is relatively small, + # yet passing large numbers repeatedly causes lots of large + # buffers allocated and then shrunk, which results in + # fragmentation. + bytes_ = self.recv(min(16384, shortage)) + self.recv_buffer.append(bytes_) + shortage -= len(bytes_) + + unified = b"".join(self.recv_buffer) + + if shortage == 0: + self.recv_buffer = [] + return unified + else: + self.recv_buffer = [unified[bufsize:]] + return unified[:bufsize] + + +class continuous_frame: + def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None: + self.fire_cont_frame = fire_cont_frame + self.skip_utf8_validation = skip_utf8_validation + self.cont_data: Optional[list] = None + self.recving_frames: Optional[int] = None + + def validate(self, frame: ABNF) -> None: + if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: + raise WebSocketProtocolException("Illegal frame") + if self.recving_frames and frame.opcode in ( + ABNF.OPCODE_TEXT, + ABNF.OPCODE_BINARY, + ): + raise WebSocketProtocolException("Illegal frame") + + def add(self, frame: ABNF) -> None: + if self.cont_data: + self.cont_data[1] += frame.data + else: + if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): + self.recving_frames = frame.opcode + self.cont_data = [frame.opcode, frame.data] + + if frame.fin: + self.recving_frames = None + + def is_fire(self, frame: ABNF) -> Union[bool, int]: + return frame.fin or self.fire_cont_frame + + def extract(self, frame: ABNF) -> tuple: + data = self.cont_data + self.cont_data = None + frame.data = data[1] + if ( + not self.fire_cont_frame + and data[0] == ABNF.OPCODE_TEXT + and not self.skip_utf8_validation + and not validate_utf8(frame.data) + ): + raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}") + return data[0], frame diff --git a/contrib/python/websocket-client/websocket/_app.py b/contrib/python/websocket-client/websocket/_app.py new file mode 100644 index 0000000000..9fee76546b --- /dev/null +++ b/contrib/python/websocket-client/websocket/_app.py @@ -0,0 +1,677 @@ +import inspect +import selectors +import socket +import threading +import time +from typing import Any, Callable, Optional, Union + +from . import _logging +from ._abnf import ABNF +from ._core import WebSocket, getdefaulttimeout +from ._exceptions import ( + WebSocketConnectionClosedException, + WebSocketException, + WebSocketTimeoutException, +) +from ._ssl_compat import SSLEOFError +from ._url import parse_url + +""" +_app.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +__all__ = ["WebSocketApp"] + +RECONNECT = 0 + + +def setReconnect(reconnectInterval: int) -> None: + global RECONNECT + RECONNECT = reconnectInterval + + +class DispatcherBase: + """ + DispatcherBase + """ + + def __init__(self, app: Any, ping_timeout: Union[float, int, None]) -> None: + self.app = app + self.ping_timeout = ping_timeout + + def timeout(self, seconds: Union[float, int, None], callback: Callable) -> None: + time.sleep(seconds) + callback() + + def reconnect(self, seconds: int, reconnector: Callable) -> None: + try: + _logging.info( + f"reconnect() - retrying in {seconds} seconds [{len(inspect.stack())} frames in stack]" + ) + time.sleep(seconds) + reconnector(reconnecting=True) + except KeyboardInterrupt as e: + _logging.info(f"User exited {e}") + raise e + + +class Dispatcher(DispatcherBase): + """ + Dispatcher + """ + + def read( + self, + sock: socket.socket, + read_callback: Callable, + check_callback: Callable, + ) -> None: + sel = selectors.DefaultSelector() + sel.register(self.app.sock.sock, selectors.EVENT_READ) + try: + while self.app.keep_running: + if sel.select(self.ping_timeout): + if not read_callback(): + break + check_callback() + finally: + sel.close() + + +class SSLDispatcher(DispatcherBase): + """ + SSLDispatcher + """ + + def read( + self, + sock: socket.socket, + read_callback: Callable, + check_callback: Callable, + ) -> None: + sock = self.app.sock.sock + sel = selectors.DefaultSelector() + sel.register(sock, selectors.EVENT_READ) + try: + while self.app.keep_running: + if self.select(sock, sel): + if not read_callback(): + break + check_callback() + finally: + sel.close() + + def select(self, sock, sel: selectors.DefaultSelector): + sock = self.app.sock.sock + if sock.pending(): + return [ + sock, + ] + + r = sel.select(self.ping_timeout) + + if len(r) > 0: + return r[0][0] + + +class WrappedDispatcher: + """ + WrappedDispatcher + """ + + def __init__(self, app, ping_timeout: Union[float, int, None], dispatcher) -> None: + self.app = app + self.ping_timeout = ping_timeout + self.dispatcher = dispatcher + dispatcher.signal(2, dispatcher.abort) # keyboard interrupt + + def read( + self, + sock: socket.socket, + read_callback: Callable, + check_callback: Callable, + ) -> None: + self.dispatcher.read(sock, read_callback) + self.ping_timeout and self.timeout(self.ping_timeout, check_callback) + + def timeout(self, seconds: float, callback: Callable) -> None: + self.dispatcher.timeout(seconds, callback) + + def reconnect(self, seconds: int, reconnector: Callable) -> None: + self.timeout(seconds, reconnector) + + +class WebSocketApp: + """ + Higher level of APIs are provided. The interface is like JavaScript WebSocket object. + """ + + def __init__( + self, + url: str, + header: Union[list, dict, Callable, None] = None, + on_open: Optional[Callable[[WebSocket], None]] = None, + on_reconnect: Optional[Callable[[WebSocket], None]] = None, + on_message: Optional[Callable[[WebSocket, Any], None]] = None, + on_error: Optional[Callable[[WebSocket, Any], None]] = None, + on_close: Optional[Callable[[WebSocket, Any, Any], None]] = None, + on_ping: Optional[Callable] = None, + on_pong: Optional[Callable] = None, + on_cont_message: Optional[Callable] = None, + keep_running: bool = True, + get_mask_key: Optional[Callable] = None, + cookie: Optional[str] = None, + subprotocols: Optional[list] = None, + on_data: Optional[Callable] = None, + socket: Optional[socket.socket] = None, + ) -> None: + """ + WebSocketApp initialization + + Parameters + ---------- + url: str + Websocket url. + header: list or dict or Callable + Custom header for websocket handshake. + If the parameter is a callable object, it is called just before the connection attempt. + The returned dict or list is used as custom header value. + This could be useful in order to properly setup timestamp dependent headers. + on_open: function + Callback object which is called at opening websocket. + on_open has one argument. + The 1st argument is this class object. + on_reconnect: function + Callback object which is called at reconnecting websocket. + on_reconnect has one argument. + The 1st argument is this class object. + on_message: function + Callback object which is called when received data. + on_message has 2 arguments. + The 1st argument is this class object. + The 2nd argument is utf-8 data received from the server. + on_error: function + Callback object which is called when we get error. + on_error has 2 arguments. + The 1st argument is this class object. + The 2nd argument is exception object. + on_close: function + Callback object which is called when connection is closed. + on_close has 3 arguments. + The 1st argument is this class object. + The 2nd argument is close_status_code. + The 3rd argument is close_msg. + on_cont_message: function + Callback object which is called when a continuation + frame is received. + on_cont_message has 3 arguments. + The 1st argument is this class object. + The 2nd argument is utf-8 string which we get from the server. + The 3rd argument is continue flag. if 0, the data continue + to next frame data + on_data: function + Callback object which is called when a message received. + This is called before on_message or on_cont_message, + and then on_message or on_cont_message is called. + on_data has 4 argument. + The 1st argument is this class object. + The 2nd argument is utf-8 string which we get from the server. + The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came. + The 4th argument is continue flag. If 0, the data continue + keep_running: bool + This parameter is obsolete and ignored. + get_mask_key: function + A callable function to get new mask keys, see the + WebSocket.set_mask_key's docstring for more information. + cookie: str + Cookie value. + subprotocols: list + List of available sub protocols. Default is None. + socket: socket + Pre-initialized stream socket. + """ + self.url = url + self.header = header if header is not None else [] + self.cookie = cookie + + self.on_open = on_open + self.on_reconnect = on_reconnect + self.on_message = on_message + self.on_data = on_data + self.on_error = on_error + self.on_close = on_close + self.on_ping = on_ping + self.on_pong = on_pong + self.on_cont_message = on_cont_message + self.keep_running = False + self.get_mask_key = get_mask_key + self.sock: Optional[WebSocket] = None + self.last_ping_tm = float(0) + self.last_pong_tm = float(0) + self.ping_thread: Optional[threading.Thread] = None + self.stop_ping: Optional[threading.Event] = None + self.ping_interval = float(0) + self.ping_timeout: Union[float, int, None] = None + self.ping_payload = "" + self.subprotocols = subprotocols + self.prepared_socket = socket + self.has_errored = False + self.has_done_teardown = False + self.has_done_teardown_lock = threading.Lock() + + def send(self, data: Union[bytes, str], opcode: int = ABNF.OPCODE_TEXT) -> None: + """ + send message + + Parameters + ---------- + data: str + Message to send. If you set opcode to OPCODE_TEXT, + data must be utf-8 string or unicode. + opcode: int + Operation code of data. Default is OPCODE_TEXT. + """ + + if not self.sock or self.sock.send(data, opcode) == 0: + raise WebSocketConnectionClosedException("Connection is already closed.") + + def send_text(self, text_data: str) -> None: + """ + Sends UTF-8 encoded text. + """ + if not self.sock or self.sock.send(text_data, ABNF.OPCODE_TEXT) == 0: + raise WebSocketConnectionClosedException("Connection is already closed.") + + def send_bytes(self, data: Union[bytes, bytearray]) -> None: + """ + Sends a sequence of bytes. + """ + if not self.sock or self.sock.send(data, ABNF.OPCODE_BINARY) == 0: + raise WebSocketConnectionClosedException("Connection is already closed.") + + def close(self, **kwargs) -> None: + """ + Close websocket connection. + """ + self.keep_running = False + if self.sock: + self.sock.close(**kwargs) + self.sock = None + + def _start_ping_thread(self) -> None: + self.last_ping_tm = self.last_pong_tm = float(0) + self.stop_ping = threading.Event() + self.ping_thread = threading.Thread(target=self._send_ping) + self.ping_thread.daemon = True + self.ping_thread.start() + + def _stop_ping_thread(self) -> None: + if self.stop_ping: + self.stop_ping.set() + if self.ping_thread and self.ping_thread.is_alive(): + self.ping_thread.join(3) + self.last_ping_tm = self.last_pong_tm = float(0) + + def _send_ping(self) -> None: + if self.stop_ping.wait(self.ping_interval) or self.keep_running is False: + return + while not self.stop_ping.wait(self.ping_interval) and self.keep_running is True: + if self.sock: + self.last_ping_tm = time.time() + try: + _logging.debug("Sending ping") + self.sock.ping(self.ping_payload) + except Exception as e: + _logging.debug(f"Failed to send ping: {e}") + + def run_forever( + self, + sockopt: tuple = None, + sslopt: dict = None, + ping_interval: Union[float, int] = 0, + ping_timeout: Union[float, int, None] = None, + ping_payload: str = "", + http_proxy_host: str = None, + http_proxy_port: Union[int, str] = None, + http_no_proxy: list = None, + http_proxy_auth: tuple = None, + http_proxy_timeout: Optional[float] = None, + skip_utf8_validation: bool = False, + host: str = None, + origin: str = None, + dispatcher=None, + suppress_origin: bool = False, + proxy_type: str = None, + reconnect: int = None, + ) -> bool: + """ + Run event loop for WebSocket framework. + + This loop is an infinite loop and is alive while websocket is available. + + Parameters + ---------- + sockopt: tuple + Values for socket.setsockopt. + sockopt must be tuple + and each element is argument of sock.setsockopt. + sslopt: dict + Optional dict object for ssl socket option. + ping_interval: int or float + Automatically send "ping" command + every specified period (in seconds). + If set to 0, no ping is sent periodically. + ping_timeout: int or float + Timeout (in seconds) if the pong message is not received. + ping_payload: str + Payload message to send with each ping. + http_proxy_host: str + HTTP proxy host name. + http_proxy_port: int or str + HTTP proxy port. If not set, set to 80. + http_no_proxy: list + Whitelisted host names that don't use the proxy. + http_proxy_timeout: int or float + HTTP proxy timeout, default is 60 sec as per python-socks. + http_proxy_auth: tuple + HTTP proxy auth information. tuple of username and password. Default is None. + skip_utf8_validation: bool + skip utf8 validation. + host: str + update host header. + origin: str + update origin header. + dispatcher: Dispatcher object + customize reading data from socket. + suppress_origin: bool + suppress outputting origin header. + proxy_type: str + type of proxy from: http, socks4, socks4a, socks5, socks5h + reconnect: int + delay interval when reconnecting + + Returns + ------- + teardown: bool + False if the `WebSocketApp` is closed or caught KeyboardInterrupt, + True if any other exception was raised during a loop. + """ + + if reconnect is None: + reconnect = RECONNECT + + if ping_timeout is not None and ping_timeout <= 0: + raise WebSocketException("Ensure ping_timeout > 0") + if ping_interval is not None and ping_interval < 0: + raise WebSocketException("Ensure ping_interval >= 0") + if ping_timeout and ping_interval and ping_interval <= ping_timeout: + raise WebSocketException("Ensure ping_interval > ping_timeout") + if not sockopt: + sockopt = () + if not sslopt: + sslopt = {} + if self.sock: + raise WebSocketException("socket is already opened") + + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.ping_payload = ping_payload + self.has_done_teardown = False + self.keep_running = True + + def teardown(close_frame: ABNF = None): + """ + Tears down the connection. + + Parameters + ---------- + close_frame: ABNF frame + If close_frame is set, the on_close handler is invoked + with the statusCode and reason from the provided frame. + """ + + # teardown() is called in many code paths to ensure resources are cleaned up and on_close is fired. + # To ensure the work is only done once, we use this bool and lock. + with self.has_done_teardown_lock: + if self.has_done_teardown: + return + self.has_done_teardown = True + + self._stop_ping_thread() + self.keep_running = False + if self.sock: + self.sock.close() + close_status_code, close_reason = self._get_close_args( + close_frame if close_frame else None + ) + self.sock = None + + # Finally call the callback AFTER all teardown is complete + self._callback(self.on_close, close_status_code, close_reason) + + def setSock(reconnecting: bool = False) -> None: + if reconnecting and self.sock: + self.sock.shutdown() + + self.sock = WebSocket( + self.get_mask_key, + sockopt=sockopt, + sslopt=sslopt, + fire_cont_frame=self.on_cont_message is not None, + skip_utf8_validation=skip_utf8_validation, + enable_multithread=True, + ) + + self.sock.settimeout(getdefaulttimeout()) + try: + header = self.header() if callable(self.header) else self.header + + self.sock.connect( + self.url, + header=header, + cookie=self.cookie, + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, + http_no_proxy=http_no_proxy, + http_proxy_auth=http_proxy_auth, + http_proxy_timeout=http_proxy_timeout, + subprotocols=self.subprotocols, + host=host, + origin=origin, + suppress_origin=suppress_origin, + proxy_type=proxy_type, + socket=self.prepared_socket, + ) + + _logging.info("Websocket connected") + + if self.ping_interval: + self._start_ping_thread() + + if reconnecting and self.on_reconnect: + self._callback(self.on_reconnect) + else: + self._callback(self.on_open) + + dispatcher.read(self.sock.sock, read, check) + except ( + WebSocketConnectionClosedException, + ConnectionRefusedError, + KeyboardInterrupt, + SystemExit, + Exception, + ) as e: + handleDisconnect(e, reconnecting) + + def read() -> bool: + if not self.keep_running: + return teardown() + + try: + op_code, frame = self.sock.recv_data_frame(True) + except ( + WebSocketConnectionClosedException, + KeyboardInterrupt, + SSLEOFError, + ) as e: + if custom_dispatcher: + return handleDisconnect(e, bool(reconnect)) + else: + raise e + + if op_code == ABNF.OPCODE_CLOSE: + return teardown(frame) + elif op_code == ABNF.OPCODE_PING: + self._callback(self.on_ping, frame.data) + elif op_code == ABNF.OPCODE_PONG: + self.last_pong_tm = time.time() + self._callback(self.on_pong, frame.data) + elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: + self._callback(self.on_data, frame.data, frame.opcode, frame.fin) + self._callback(self.on_cont_message, frame.data, frame.fin) + else: + data = frame.data + if op_code == ABNF.OPCODE_TEXT and not skip_utf8_validation: + data = data.decode("utf-8") + self._callback(self.on_data, data, frame.opcode, True) + self._callback(self.on_message, data) + + return True + + def check() -> bool: + if self.ping_timeout: + has_timeout_expired = ( + time.time() - self.last_ping_tm > self.ping_timeout + ) + has_pong_not_arrived_after_last_ping = ( + self.last_pong_tm - self.last_ping_tm < 0 + ) + has_pong_arrived_too_late = ( + self.last_pong_tm - self.last_ping_tm > self.ping_timeout + ) + + if ( + self.last_ping_tm + and has_timeout_expired + and ( + has_pong_not_arrived_after_last_ping + or has_pong_arrived_too_late + ) + ): + raise WebSocketTimeoutException("ping/pong timed out") + return True + + def handleDisconnect( + e: Union[ + WebSocketConnectionClosedException, + ConnectionRefusedError, + KeyboardInterrupt, + SystemExit, + Exception, + ], + reconnecting: bool = False, + ) -> bool: + self.has_errored = True + self._stop_ping_thread() + if not reconnecting: + self._callback(self.on_error, e) + + if isinstance(e, (KeyboardInterrupt, SystemExit)): + teardown() + # Propagate further + raise + + if reconnect: + _logging.info(f"{e} - reconnect") + if custom_dispatcher: + _logging.debug( + f"Calling custom dispatcher reconnect [{len(inspect.stack())} frames in stack]" + ) + dispatcher.reconnect(reconnect, setSock) + else: + _logging.error(f"{e} - goodbye") + teardown() + + custom_dispatcher = bool(dispatcher) + dispatcher = self.create_dispatcher( + ping_timeout, dispatcher, parse_url(self.url)[3] + ) + + try: + setSock() + if not custom_dispatcher and reconnect: + while self.keep_running: + _logging.debug( + f"Calling dispatcher reconnect [{len(inspect.stack())} frames in stack]" + ) + dispatcher.reconnect(reconnect, setSock) + except (KeyboardInterrupt, Exception) as e: + _logging.info(f"tearing down on exception {e}") + teardown() + finally: + if not custom_dispatcher: + # Ensure teardown was called before returning from run_forever + teardown() + + return self.has_errored + + def create_dispatcher( + self, + ping_timeout: Union[float, int, None], + dispatcher: Optional[DispatcherBase] = None, + is_ssl: bool = False, + ) -> Union[Dispatcher, SSLDispatcher, WrappedDispatcher]: + if dispatcher: # If custom dispatcher is set, use WrappedDispatcher + return WrappedDispatcher(self, ping_timeout, dispatcher) + timeout = ping_timeout or 10 + if is_ssl: + return SSLDispatcher(self, timeout) + return Dispatcher(self, timeout) + + def _get_close_args(self, close_frame: ABNF) -> list: + """ + _get_close_args extracts the close code and reason from the close body + if it exists (RFC6455 says WebSocket Connection Close Code is optional) + """ + # Need to catch the case where close_frame is None + # Otherwise the following if statement causes an error + if not self.on_close or not close_frame: + return [None, None] + + # Extract close frame status code + if close_frame.data and len(close_frame.data) >= 2: + close_status_code = 256 * int(close_frame.data[0]) + int( + close_frame.data[1] + ) + reason = close_frame.data[2:] + if isinstance(reason, bytes): + reason = reason.decode("utf-8") + return [close_status_code, reason] + else: + # Most likely reached this because len(close_frame_data.data) < 2 + return [None, None] + + def _callback(self, callback, *args) -> None: + if callback: + try: + callback(self, *args) + + except Exception as e: + _logging.error(f"error from callback {callback}: {e}") + if self.on_error: + self.on_error(self, e) diff --git a/contrib/python/websocket-client/websocket/_cookiejar.py b/contrib/python/websocket-client/websocket/_cookiejar.py new file mode 100644 index 0000000000..7480e5fc21 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_cookiejar.py @@ -0,0 +1,75 @@ +import http.cookies +from typing import Optional + +""" +_cookiejar.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +class SimpleCookieJar: + def __init__(self) -> None: + self.jar: dict = {} + + def add(self, set_cookie: Optional[str]) -> None: + if set_cookie: + simple_cookie = http.cookies.SimpleCookie(set_cookie) + + for v in simple_cookie.values(): + if domain := v.get("domain"): + if not domain.startswith("."): + domain = f".{domain}" + cookie = ( + self.jar.get(domain) + if self.jar.get(domain) + else http.cookies.SimpleCookie() + ) + cookie.update(simple_cookie) + self.jar[domain.lower()] = cookie + + def set(self, set_cookie: str) -> None: + if set_cookie: + simple_cookie = http.cookies.SimpleCookie(set_cookie) + + for v in simple_cookie.values(): + if domain := v.get("domain"): + if not domain.startswith("."): + domain = f".{domain}" + self.jar[domain.lower()] = simple_cookie + + def get(self, host: str) -> str: + if not host: + return "" + + cookies = [] + for domain, _ in self.jar.items(): + host = host.lower() + if host.endswith(domain) or host == domain[1:]: + cookies.append(self.jar.get(domain)) + + return "; ".join( + filter( + None, + sorted( + [ + f"{k}={v.value}" + for cookie in filter(None, cookies) + for k, v in cookie.items() + ] + ), + ) + ) diff --git a/contrib/python/websocket-client/websocket/_core.py b/contrib/python/websocket-client/websocket/_core.py new file mode 100644 index 0000000000..f940ed0573 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_core.py @@ -0,0 +1,647 @@ +import socket +import struct +import threading +import time +from typing import Optional, Union + +# websocket modules +from ._abnf import ABNF, STATUS_NORMAL, continuous_frame, frame_buffer +from ._exceptions import WebSocketProtocolException, WebSocketConnectionClosedException +from ._handshake import SUPPORTED_REDIRECT_STATUSES, handshake +from ._http import connect, proxy_info +from ._logging import debug, error, trace, isEnabledForError, isEnabledForTrace +from ._socket import getdefaulttimeout, recv, send, sock_opt +from ._ssl_compat import ssl +from ._utils import NoLock + +""" +_core.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +__all__ = ["WebSocket", "create_connection"] + + +class WebSocket: + """ + Low level WebSocket interface. + + This class is based on the WebSocket protocol `draft-hixie-thewebsocketprotocol-76 <http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76>`_ + + We can connect to the websocket server and send/receive data. + The following example is an echo client. + + >>> import websocket + >>> ws = websocket.WebSocket() + >>> ws.connect("ws://echo.websocket.events") + >>> ws.recv() + 'echo.websocket.events sponsored by Lob.com' + >>> ws.send("Hello, Server") + 19 + >>> ws.recv() + 'Hello, Server' + >>> ws.close() + + Parameters + ---------- + get_mask_key: func + A callable function to get new mask keys, see the + WebSocket.set_mask_key's docstring for more information. + sockopt: tuple + Values for socket.setsockopt. + sockopt must be tuple and each element is argument of sock.setsockopt. + sslopt: dict + Optional dict object for ssl socket options. See FAQ for details. + fire_cont_frame: bool + Fire recv event for each cont frame. Default is False. + enable_multithread: bool + If set to True, lock send method. + skip_utf8_validation: bool + Skip utf8 validation. + """ + + def __init__( + self, + get_mask_key=None, + sockopt=None, + sslopt=None, + fire_cont_frame: bool = False, + enable_multithread: bool = True, + skip_utf8_validation: bool = False, + **_, + ): + """ + Initialize WebSocket object. + + Parameters + ---------- + sslopt: dict + Optional dict object for ssl socket options. See FAQ for details. + """ + self.sock_opt = sock_opt(sockopt, sslopt) + self.handshake_response = None + self.sock: Optional[socket.socket] = None + + self.connected = False + self.get_mask_key = get_mask_key + # These buffer over the build-up of a single frame. + self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation) + self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation) + + if enable_multithread: + self.lock = threading.Lock() + self.readlock = threading.Lock() + else: + self.lock = NoLock() + self.readlock = NoLock() + + def __iter__(self): + """ + Allow iteration over websocket, implying sequential `recv` executions. + """ + while True: + yield self.recv() + + def __next__(self): + return self.recv() + + def next(self): + return self.__next__() + + def fileno(self): + return self.sock.fileno() + + def set_mask_key(self, func): + """ + Set function to create mask key. You can customize mask key generator. + Mainly, this is for testing purpose. + + Parameters + ---------- + func: func + callable object. the func takes 1 argument as integer. + The argument means length of mask key. + This func must return string(byte array), + which length is argument specified. + """ + self.get_mask_key = func + + def gettimeout(self) -> Union[float, int, None]: + """ + Get the websocket timeout (in seconds) as an int or float + + Returns + ---------- + timeout: int or float + returns timeout value (in seconds). This value could be either float/integer. + """ + return self.sock_opt.timeout + + def settimeout(self, timeout: Union[float, int, None]): + """ + Set the timeout to the websocket. + + Parameters + ---------- + timeout: int or float + timeout time (in seconds). This value could be either float/integer. + """ + self.sock_opt.timeout = timeout + if self.sock: + self.sock.settimeout(timeout) + + timeout = property(gettimeout, settimeout) + + def getsubprotocol(self): + """ + Get subprotocol + """ + if self.handshake_response: + return self.handshake_response.subprotocol + else: + return None + + subprotocol = property(getsubprotocol) + + def getstatus(self): + """ + Get handshake status + """ + if self.handshake_response: + return self.handshake_response.status + else: + return None + + status = property(getstatus) + + def getheaders(self): + """ + Get handshake response header + """ + if self.handshake_response: + return self.handshake_response.headers + else: + return None + + def is_ssl(self): + try: + return isinstance(self.sock, ssl.SSLSocket) + except: + return False + + headers = property(getheaders) + + def connect(self, url, **options): + """ + Connect to url. url is websocket url scheme. + ie. ws://host:port/resource + You can customize using 'options'. + If you set "header" list object, you can set your own custom header. + + >>> ws = WebSocket() + >>> ws.connect("ws://echo.websocket.events", + ... header=["User-Agent: MyProgram", + ... "x-custom: header"]) + + Parameters + ---------- + header: list or dict + Custom http header list or dict. + cookie: str + Cookie value. + origin: str + Custom origin url. + connection: str + Custom connection header value. + Default value "Upgrade" set in _handshake.py + suppress_origin: bool + Suppress outputting origin header. + host: str + Custom host header string. + timeout: int or float + Socket timeout time. This value is an integer or float. + If you set None for this value, it means "use default_timeout value" + http_proxy_host: str + HTTP proxy host name. + http_proxy_port: str or int + HTTP proxy port. Default is 80. + http_no_proxy: list + Whitelisted host names that don't use the proxy. + http_proxy_auth: tuple + HTTP proxy auth information. Tuple of username and password. Default is None. + http_proxy_timeout: int or float + HTTP proxy timeout, default is 60 sec as per python-socks. + redirect_limit: int + Number of redirects to follow. + subprotocols: list + List of available subprotocols. Default is None. + socket: socket + Pre-initialized stream socket. + """ + self.sock_opt.timeout = options.get("timeout", self.sock_opt.timeout) + self.sock, addrs = connect( + url, self.sock_opt, proxy_info(**options), options.pop("socket", None) + ) + + try: + self.handshake_response = handshake(self.sock, url, *addrs, **options) + for _ in range(options.pop("redirect_limit", 3)): + if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES: + url = self.handshake_response.headers["location"] + self.sock.close() + self.sock, addrs = connect( + url, + self.sock_opt, + proxy_info(**options), + options.pop("socket", None), + ) + self.handshake_response = handshake( + self.sock, url, *addrs, **options + ) + self.connected = True + except: + if self.sock: + self.sock.close() + self.sock = None + raise + + def send(self, payload: Union[bytes, str], opcode: int = ABNF.OPCODE_TEXT) -> int: + """ + Send the data as string. + + Parameters + ---------- + payload: str + Payload must be utf-8 string or unicode, + If the opcode is OPCODE_TEXT. + Otherwise, it must be string(byte array). + opcode: int + Operation code (opcode) to send. + """ + + frame = ABNF.create_frame(payload, opcode) + return self.send_frame(frame) + + def send_text(self, text_data: str) -> int: + """ + Sends UTF-8 encoded text. + """ + return self.send(text_data, ABNF.OPCODE_TEXT) + + def send_bytes(self, data: Union[bytes, bytearray]) -> int: + """ + Sends a sequence of bytes. + """ + return self.send(data, ABNF.OPCODE_BINARY) + + def send_frame(self, frame) -> int: + """ + Send the data frame. + + >>> ws = create_connection("ws://echo.websocket.events") + >>> frame = ABNF.create_frame("Hello", ABNF.OPCODE_TEXT) + >>> ws.send_frame(frame) + >>> cont_frame = ABNF.create_frame("My name is ", ABNF.OPCODE_CONT, 0) + >>> ws.send_frame(frame) + >>> cont_frame = ABNF.create_frame("Foo Bar", ABNF.OPCODE_CONT, 1) + >>> ws.send_frame(frame) + + Parameters + ---------- + frame: ABNF frame + frame data created by ABNF.create_frame + """ + if self.get_mask_key: + frame.get_mask_key = self.get_mask_key + data = frame.format() + length = len(data) + if isEnabledForTrace(): + trace(f"++Sent raw: {repr(data)}") + trace(f"++Sent decoded: {frame.__str__()}") + with self.lock: + while data: + l = self._send(data) + data = data[l:] + + return length + + def send_binary(self, payload: bytes) -> int: + """ + Send a binary message (OPCODE_BINARY). + + Parameters + ---------- + payload: bytes + payload of message to send. + """ + return self.send(payload, ABNF.OPCODE_BINARY) + + def ping(self, payload: Union[str, bytes] = ""): + """ + Send ping data. + + Parameters + ---------- + payload: str + data payload to send server. + """ + if isinstance(payload, str): + payload = payload.encode("utf-8") + self.send(payload, ABNF.OPCODE_PING) + + def pong(self, payload: Union[str, bytes] = ""): + """ + Send pong data. + + Parameters + ---------- + payload: str + data payload to send server. + """ + if isinstance(payload, str): + payload = payload.encode("utf-8") + self.send(payload, ABNF.OPCODE_PONG) + + def recv(self) -> Union[str, bytes]: + """ + Receive string data(byte array) from the server. + + Returns + ---------- + data: string (byte array) value. + """ + with self.readlock: + opcode, data = self.recv_data() + if opcode == ABNF.OPCODE_TEXT: + data_received: Union[bytes, str] = data + if isinstance(data_received, bytes): + return data_received.decode("utf-8") + elif isinstance(data_received, str): + return data_received + elif opcode == ABNF.OPCODE_BINARY: + data_binary: bytes = data + return data_binary + else: + return "" + + def recv_data(self, control_frame: bool = False) -> tuple: + """ + Receive data with operation code. + + Parameters + ---------- + control_frame: bool + a boolean flag indicating whether to return control frame + data, defaults to False + + Returns + ------- + opcode, frame.data: tuple + tuple of operation code and string(byte array) value. + """ + opcode, frame = self.recv_data_frame(control_frame) + return opcode, frame.data + + def recv_data_frame(self, control_frame: bool = False) -> tuple: + """ + Receive data with operation code. + + If a valid ping message is received, a pong response is sent. + + Parameters + ---------- + control_frame: bool + a boolean flag indicating whether to return control frame + data, defaults to False + + Returns + ------- + frame.opcode, frame: tuple + tuple of operation code and string(byte array) value. + """ + while True: + frame = self.recv_frame() + if isEnabledForTrace(): + trace(f"++Rcv raw: {repr(frame.format())}") + trace(f"++Rcv decoded: {frame.__str__()}") + if not frame: + # handle error: + # 'NoneType' object has no attribute 'opcode' + raise WebSocketProtocolException(f"Not a valid frame {frame}") + elif frame.opcode in ( + ABNF.OPCODE_TEXT, + ABNF.OPCODE_BINARY, + ABNF.OPCODE_CONT, + ): + self.cont_frame.validate(frame) + self.cont_frame.add(frame) + + if self.cont_frame.is_fire(frame): + return self.cont_frame.extract(frame) + + elif frame.opcode == ABNF.OPCODE_CLOSE: + self.send_close() + return frame.opcode, frame + elif frame.opcode == ABNF.OPCODE_PING: + if len(frame.data) < 126: + self.pong(frame.data) + else: + raise WebSocketProtocolException("Ping message is too long") + if control_frame: + return frame.opcode, frame + elif frame.opcode == ABNF.OPCODE_PONG: + if control_frame: + return frame.opcode, frame + + def recv_frame(self): + """ + Receive data as frame from server. + + Returns + ------- + self.frame_buffer.recv_frame(): ABNF frame object + """ + return self.frame_buffer.recv_frame() + + def send_close(self, status: int = STATUS_NORMAL, reason: bytes = b""): + """ + Send close data to the server. + + Parameters + ---------- + status: int + Status code to send. See STATUS_XXX. + reason: str or bytes + The reason to close. This must be string or UTF-8 bytes. + """ + if status < 0 or status >= ABNF.LENGTH_16: + raise ValueError("code is invalid range") + self.connected = False + self.send(struct.pack("!H", status) + reason, ABNF.OPCODE_CLOSE) + + def close(self, status: int = STATUS_NORMAL, reason: bytes = b"", timeout: int = 3): + """ + Close Websocket object + + Parameters + ---------- + status: int + Status code to send. See VALID_CLOSE_STATUS in ABNF. + reason: bytes + The reason to close in UTF-8. + timeout: int or float + Timeout until receive a close frame. + If None, it will wait forever until receive a close frame. + """ + if not self.connected: + return + if status < 0 or status >= ABNF.LENGTH_16: + raise ValueError("code is invalid range") + + try: + self.connected = False + self.send(struct.pack("!H", status) + reason, ABNF.OPCODE_CLOSE) + sock_timeout = self.sock.gettimeout() + self.sock.settimeout(timeout) + start_time = time.time() + while timeout is None or time.time() - start_time < timeout: + try: + frame = self.recv_frame() + if frame.opcode != ABNF.OPCODE_CLOSE: + continue + if isEnabledForError(): + recv_status = struct.unpack("!H", frame.data[0:2])[0] + if recv_status >= 3000 and recv_status <= 4999: + debug(f"close status: {repr(recv_status)}") + elif recv_status != STATUS_NORMAL: + error(f"close status: {repr(recv_status)}") + break + except: + break + self.sock.settimeout(sock_timeout) + self.sock.shutdown(socket.SHUT_RDWR) + except: + pass + + self.shutdown() + + def abort(self): + """ + Low-level asynchronous abort, wakes up other threads that are waiting in recv_* + """ + if self.connected: + self.sock.shutdown(socket.SHUT_RDWR) + + def shutdown(self): + """ + close socket, immediately. + """ + if self.sock: + self.sock.close() + self.sock = None + self.connected = False + + def _send(self, data: Union[str, bytes]): + return send(self.sock, data) + + def _recv(self, bufsize): + try: + return recv(self.sock, bufsize) + except WebSocketConnectionClosedException: + if self.sock: + self.sock.close() + self.sock = None + self.connected = False + raise + + +def create_connection(url: str, timeout=None, class_=WebSocket, **options): + """ + Connect to url and return websocket object. + + Connect to url and return the WebSocket object. + Passing optional timeout parameter will set the timeout on the socket. + If no timeout is supplied, + the global default timeout setting returned by getdefaulttimeout() is used. + You can customize using 'options'. + If you set "header" list object, you can set your own custom header. + + >>> conn = create_connection("ws://echo.websocket.events", + ... header=["User-Agent: MyProgram", + ... "x-custom: header"]) + + Parameters + ---------- + class_: class + class to instantiate when creating the connection. It has to implement + settimeout and connect. It's __init__ should be compatible with + WebSocket.__init__, i.e. accept all of it's kwargs. + header: list or dict + custom http header list or dict. + cookie: str + Cookie value. + origin: str + custom origin url. + suppress_origin: bool + suppress outputting origin header. + host: str + custom host header string. + timeout: int or float + socket timeout time. This value could be either float/integer. + If set to None, it uses the default_timeout value. + http_proxy_host: str + HTTP proxy host name. + http_proxy_port: str or int + HTTP proxy port. If not set, set to 80. + http_no_proxy: list + Whitelisted host names that don't use the proxy. + http_proxy_auth: tuple + HTTP proxy auth information. tuple of username and password. Default is None. + http_proxy_timeout: int or float + HTTP proxy timeout, default is 60 sec as per python-socks. + enable_multithread: bool + Enable lock for multithread. + redirect_limit: int + Number of redirects to follow. + sockopt: tuple + Values for socket.setsockopt. + sockopt must be a tuple and each element is an argument of sock.setsockopt. + sslopt: dict + Optional dict object for ssl socket options. See FAQ for details. + subprotocols: list + List of available subprotocols. Default is None. + skip_utf8_validation: bool + Skip utf8 validation. + socket: socket + Pre-initialized stream socket. + """ + sockopt = options.pop("sockopt", []) + sslopt = options.pop("sslopt", {}) + fire_cont_frame = options.pop("fire_cont_frame", False) + enable_multithread = options.pop("enable_multithread", True) + skip_utf8_validation = options.pop("skip_utf8_validation", False) + websock = class_( + sockopt=sockopt, + sslopt=sslopt, + fire_cont_frame=fire_cont_frame, + enable_multithread=enable_multithread, + skip_utf8_validation=skip_utf8_validation, + **options, + ) + websock.settimeout(timeout if timeout is not None else getdefaulttimeout()) + websock.connect(url, **options) + return websock diff --git a/contrib/python/websocket-client/websocket/_exceptions.py b/contrib/python/websocket-client/websocket/_exceptions.py new file mode 100644 index 0000000000..cd196e44a3 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_exceptions.py @@ -0,0 +1,94 @@ +""" +_exceptions.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +class WebSocketException(Exception): + """ + WebSocket exception class. + """ + + pass + + +class WebSocketProtocolException(WebSocketException): + """ + If the WebSocket protocol is invalid, this exception will be raised. + """ + + pass + + +class WebSocketPayloadException(WebSocketException): + """ + If the WebSocket payload is invalid, this exception will be raised. + """ + + pass + + +class WebSocketConnectionClosedException(WebSocketException): + """ + If remote host closed the connection or some network error happened, + this exception will be raised. + """ + + pass + + +class WebSocketTimeoutException(WebSocketException): + """ + WebSocketTimeoutException will be raised at socket timeout during read/write data. + """ + + pass + + +class WebSocketProxyException(WebSocketException): + """ + WebSocketProxyException will be raised when proxy error occurred. + """ + + pass + + +class WebSocketBadStatusException(WebSocketException): + """ + WebSocketBadStatusException will be raised when we get bad handshake status code. + """ + + def __init__( + self, + message: str, + status_code: int, + status_message=None, + resp_headers=None, + resp_body=None, + ): + super().__init__(message) + self.status_code = status_code + self.resp_headers = resp_headers + self.resp_body = resp_body + + +class WebSocketAddressException(WebSocketException): + """ + If the websocket address info cannot be found, this exception will be raised. + """ + + pass diff --git a/contrib/python/websocket-client/websocket/_handshake.py b/contrib/python/websocket-client/websocket/_handshake.py new file mode 100644 index 0000000000..7bd61b82f4 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_handshake.py @@ -0,0 +1,202 @@ +""" +_handshake.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import hashlib +import hmac +import os +from base64 import encodebytes as base64encode +from http import HTTPStatus + +from ._cookiejar import SimpleCookieJar +from ._exceptions import WebSocketException, WebSocketBadStatusException +from ._http import read_headers +from ._logging import dump, error +from ._socket import send + +__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"] + +# websocket supported version. +VERSION = 13 + +SUPPORTED_REDIRECT_STATUSES = ( + HTTPStatus.MOVED_PERMANENTLY, + HTTPStatus.FOUND, + HTTPStatus.SEE_OTHER, + HTTPStatus.TEMPORARY_REDIRECT, + HTTPStatus.PERMANENT_REDIRECT, +) +SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,) + +CookieJar = SimpleCookieJar() + + +class handshake_response: + def __init__(self, status: int, headers: dict, subprotocol): + self.status = status + self.headers = headers + self.subprotocol = subprotocol + CookieJar.add(headers.get("set-cookie")) + + +def handshake( + sock, url: str, hostname: str, port: int, resource: str, **options +) -> handshake_response: + headers, key = _get_handshake_headers(resource, url, hostname, port, options) + + header_str = "\r\n".join(headers) + send(sock, header_str) + dump("request header", header_str) + + status, resp = _get_resp_headers(sock) + if status in SUPPORTED_REDIRECT_STATUSES: + return handshake_response(status, resp, None) + success, subproto = _validate(resp, key, options.get("subprotocols")) + if not success: + raise WebSocketException("Invalid WebSocket Header") + + return handshake_response(status, resp, subproto) + + +def _pack_hostname(hostname: str) -> str: + # IPv6 address + if ":" in hostname: + return f"[{hostname}]" + return hostname + + +def _get_handshake_headers( + resource: str, url: str, host: str, port: int, options: dict +) -> tuple: + headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"] + if port in [80, 443]: + hostport = _pack_hostname(host) + else: + hostport = f"{_pack_hostname(host)}:{port}" + if options.get("host"): + headers.append(f'Host: {options["host"]}') + else: + headers.append(f"Host: {hostport}") + + # scheme indicates whether http or https is used in Origin + # The same approach is used in parse_url of _url.py to set default port + scheme, url = url.split(":", 1) + if not options.get("suppress_origin"): + if "origin" in options and options["origin"] is not None: + headers.append(f'Origin: {options["origin"]}') + elif scheme == "wss": + headers.append(f"Origin: https://{hostport}") + else: + headers.append(f"Origin: http://{hostport}") + + key = _create_sec_websocket_key() + + # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified + if not options.get("header") or "Sec-WebSocket-Key" not in options["header"]: + headers.append(f"Sec-WebSocket-Key: {key}") + else: + key = options["header"]["Sec-WebSocket-Key"] + + if not options.get("header") or "Sec-WebSocket-Version" not in options["header"]: + headers.append(f"Sec-WebSocket-Version: {VERSION}") + + if not options.get("connection"): + headers.append("Connection: Upgrade") + else: + headers.append(options["connection"]) + + if subprotocols := options.get("subprotocols"): + headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}') + + if header := options.get("header"): + if isinstance(header, dict): + header = [": ".join([k, v]) for k, v in header.items() if v is not None] + headers.extend(header) + + server_cookie = CookieJar.get(host) + client_cookie = options.get("cookie", None) + + if cookie := "; ".join(filter(None, [server_cookie, client_cookie])): + headers.append(f"Cookie: {cookie}") + + headers.extend(("", "")) + return headers, key + + +def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple: + status, resp_headers, status_message = read_headers(sock) + if status not in success_statuses: + content_len = resp_headers.get("content-length") + if content_len: + response_body = sock.recv( + int(content_len) + ) # read the body of the HTTP error message response and include it in the exception + else: + response_body = None + raise WebSocketBadStatusException( + f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}", + status, + status_message, + resp_headers, + response_body, + ) + return status, resp_headers + + +_HEADERS_TO_CHECK = { + "upgrade": "websocket", + "connection": "upgrade", +} + + +def _validate(headers, key: str, subprotocols) -> tuple: + subproto = None + for k, v in _HEADERS_TO_CHECK.items(): + r = headers.get(k, None) + if not r: + return False, None + r = [x.strip().lower() for x in r.split(",")] + if v not in r: + return False, None + + if subprotocols: + subproto = headers.get("sec-websocket-protocol", None) + if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]: + error(f"Invalid subprotocol: {subprotocols}") + return False, None + subproto = subproto.lower() + + result = headers.get("sec-websocket-accept", None) + if not result: + return False, None + result = result.lower() + + if isinstance(result, str): + result = result.encode("utf-8") + + value = f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8") + hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() + + if hmac.compare_digest(hashed, result): + return True, subproto + else: + return False, None + + +def _create_sec_websocket_key() -> str: + randomness = os.urandom(16) + return base64encode(randomness).decode("utf-8").strip() diff --git a/contrib/python/websocket-client/websocket/_http.py b/contrib/python/websocket-client/websocket/_http.py new file mode 100644 index 0000000000..9b1bf859d9 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_http.py @@ -0,0 +1,373 @@ +""" +_http.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import errno +import os +import socket +from base64 import encodebytes as base64encode + +from ._exceptions import ( + WebSocketAddressException, + WebSocketException, + WebSocketProxyException, +) +from ._logging import debug, dump, trace +from ._socket import DEFAULT_SOCKET_OPTION, recv_line, send +from ._ssl_compat import HAVE_SSL, ssl +from ._url import get_proxy_info, parse_url + +__all__ = ["proxy_info", "connect", "read_headers"] + +try: + from python_socks._errors import * + from python_socks._types import ProxyType + from python_socks.sync import Proxy + + HAVE_PYTHON_SOCKS = True +except: + HAVE_PYTHON_SOCKS = False + + class ProxyError(Exception): + pass + + class ProxyTimeoutError(Exception): + pass + + class ProxyConnectionError(Exception): + pass + + +class proxy_info: + def __init__(self, **options): + self.proxy_host = options.get("http_proxy_host", None) + if self.proxy_host: + self.proxy_port = options.get("http_proxy_port", 0) + self.auth = options.get("http_proxy_auth", None) + self.no_proxy = options.get("http_no_proxy", None) + self.proxy_protocol = options.get("proxy_type", "http") + # Note: If timeout not specified, default python-socks timeout is 60 seconds + self.proxy_timeout = options.get("http_proxy_timeout", None) + if self.proxy_protocol not in [ + "http", + "socks4", + "socks4a", + "socks5", + "socks5h", + ]: + raise ProxyError( + "Only http, socks4, socks5 proxy protocols are supported" + ) + else: + self.proxy_port = 0 + self.auth = None + self.no_proxy = None + self.proxy_protocol = "http" + + +def _start_proxied_socket(url: str, options, proxy) -> tuple: + if not HAVE_PYTHON_SOCKS: + raise WebSocketException( + "Python Socks is needed for SOCKS proxying but is not available" + ) + + hostname, port, resource, is_secure = parse_url(url) + + if proxy.proxy_protocol == "socks4": + rdns = False + proxy_type = ProxyType.SOCKS4 + # socks4a sends DNS through proxy + elif proxy.proxy_protocol == "socks4a": + rdns = True + proxy_type = ProxyType.SOCKS4 + elif proxy.proxy_protocol == "socks5": + rdns = False + proxy_type = ProxyType.SOCKS5 + # socks5h sends DNS through proxy + elif proxy.proxy_protocol == "socks5h": + rdns = True + proxy_type = ProxyType.SOCKS5 + + ws_proxy = Proxy.create( + proxy_type=proxy_type, + host=proxy.proxy_host, + port=int(proxy.proxy_port), + username=proxy.auth[0] if proxy.auth else None, + password=proxy.auth[1] if proxy.auth else None, + rdns=rdns, + ) + + sock = ws_proxy.connect(hostname, port, timeout=proxy.proxy_timeout) + + if is_secure: + if HAVE_SSL: + sock = _ssl_socket(sock, options.sslopt, hostname) + else: + raise WebSocketException("SSL not available.") + + return sock, (hostname, port, resource) + + +def connect(url: str, options, proxy, socket): + # Use _start_proxied_socket() only for socks4 or socks5 proxy + # Use _tunnel() for http proxy + # TODO: Use python-socks for http protocol also, to standardize flow + if proxy.proxy_host and not socket and proxy.proxy_protocol != "http": + return _start_proxied_socket(url, options, proxy) + + hostname, port_from_url, resource, is_secure = parse_url(url) + + if socket: + return socket, (hostname, port_from_url, resource) + + addrinfo_list, need_tunnel, auth = _get_addrinfo_list( + hostname, port_from_url, is_secure, proxy + ) + if not addrinfo_list: + raise WebSocketException(f"Host not found.: {hostname}:{port_from_url}") + + sock = None + try: + sock = _open_socket(addrinfo_list, options.sockopt, options.timeout) + if need_tunnel: + sock = _tunnel(sock, hostname, port_from_url, auth) + + if is_secure: + if HAVE_SSL: + sock = _ssl_socket(sock, options.sslopt, hostname) + else: + raise WebSocketException("SSL not available.") + + return sock, (hostname, port_from_url, resource) + except: + if sock: + sock.close() + raise + + +def _get_addrinfo_list(hostname, port: int, is_secure: bool, proxy) -> tuple: + phost, pport, pauth = get_proxy_info( + hostname, + is_secure, + proxy.proxy_host, + proxy.proxy_port, + proxy.auth, + proxy.no_proxy, + ) + try: + # when running on windows 10, getaddrinfo without socktype returns a socktype 0. + # This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0` + # or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM. + if not phost: + addrinfo_list = socket.getaddrinfo( + hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP + ) + return addrinfo_list, False, None + else: + pport = pport and pport or 80 + # when running on windows 10, the getaddrinfo used above + # returns a socktype 0. This generates an error exception: + # _on_error: exception Socket type must be stream or datagram, not 0 + # Force the socket type to SOCK_STREAM + addrinfo_list = socket.getaddrinfo( + phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP + ) + return addrinfo_list, True, pauth + except socket.gaierror as e: + raise WebSocketAddressException(e) + + +def _open_socket(addrinfo_list, sockopt, timeout): + err = None + for addrinfo in addrinfo_list: + family, socktype, proto = addrinfo[:3] + sock = socket.socket(family, socktype, proto) + sock.settimeout(timeout) + for opts in DEFAULT_SOCKET_OPTION: + sock.setsockopt(*opts) + for opts in sockopt: + sock.setsockopt(*opts) + + address = addrinfo[4] + err = None + while not err: + try: + sock.connect(address) + except socket.error as error: + sock.close() + error.remote_ip = str(address[0]) + try: + eConnRefused = ( + errno.ECONNREFUSED, + errno.WSAECONNREFUSED, + errno.ENETUNREACH, + ) + except AttributeError: + eConnRefused = (errno.ECONNREFUSED, errno.ENETUNREACH) + if error.errno not in eConnRefused: + raise error + err = error + continue + else: + break + else: + continue + break + else: + if err: + raise err + + return sock + + +def _wrap_sni_socket(sock: socket.socket, sslopt: dict, hostname, check_hostname): + context = sslopt.get("context", None) + if not context: + context = ssl.SSLContext(sslopt.get("ssl_version", ssl.PROTOCOL_TLS_CLIENT)) + # Non default context need to manually enable SSLKEYLOGFILE support by setting the keylog_filename attribute. + # For more details see also: + # * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#context-creation + # * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#ssl.SSLContext.keylog_filename + context.keylog_filename = os.environ.get("SSLKEYLOGFILE", None) + + if sslopt.get("cert_reqs", ssl.CERT_NONE) != ssl.CERT_NONE: + cafile = sslopt.get("ca_certs", None) + capath = sslopt.get("ca_cert_path", None) + if cafile or capath: + context.load_verify_locations(cafile=cafile, capath=capath) + elif hasattr(context, "load_default_certs"): + context.load_default_certs(ssl.Purpose.SERVER_AUTH) + if sslopt.get("certfile", None): + context.load_cert_chain( + sslopt["certfile"], + sslopt.get("keyfile", None), + sslopt.get("password", None), + ) + + # Python 3.10 switch to PROTOCOL_TLS_CLIENT defaults to "cert_reqs = ssl.CERT_REQUIRED" and "check_hostname = True" + # If both disabled, set check_hostname before verify_mode + # see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 + if sslopt.get("cert_reqs", ssl.CERT_NONE) == ssl.CERT_NONE and not sslopt.get( + "check_hostname", False + ): + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + else: + context.check_hostname = sslopt.get("check_hostname", True) + context.verify_mode = sslopt.get("cert_reqs", ssl.CERT_REQUIRED) + + if "ciphers" in sslopt: + context.set_ciphers(sslopt["ciphers"]) + if "cert_chain" in sslopt: + certfile, keyfile, password = sslopt["cert_chain"] + context.load_cert_chain(certfile, keyfile, password) + if "ecdh_curve" in sslopt: + context.set_ecdh_curve(sslopt["ecdh_curve"]) + + return context.wrap_socket( + sock, + do_handshake_on_connect=sslopt.get("do_handshake_on_connect", True), + suppress_ragged_eofs=sslopt.get("suppress_ragged_eofs", True), + server_hostname=hostname, + ) + + +def _ssl_socket(sock: socket.socket, user_sslopt: dict, hostname): + sslopt: dict = {"cert_reqs": ssl.CERT_REQUIRED} + sslopt.update(user_sslopt) + + cert_path = os.environ.get("WEBSOCKET_CLIENT_CA_BUNDLE") + if ( + cert_path + and os.path.isfile(cert_path) + and user_sslopt.get("ca_certs", None) is None + ): + sslopt["ca_certs"] = cert_path + elif ( + cert_path + and os.path.isdir(cert_path) + and user_sslopt.get("ca_cert_path", None) is None + ): + sslopt["ca_cert_path"] = cert_path + + if sslopt.get("server_hostname", None): + hostname = sslopt["server_hostname"] + + check_hostname = sslopt.get("check_hostname", True) + sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) + + return sock + + +def _tunnel(sock: socket.socket, host, port: int, auth) -> socket.socket: + debug("Connecting proxy...") + connect_header = f"CONNECT {host}:{port} HTTP/1.1\r\n" + connect_header += f"Host: {host}:{port}\r\n" + + # TODO: support digest auth. + if auth and auth[0]: + auth_str = auth[0] + if auth[1]: + auth_str += f":{auth[1]}" + encoded_str = base64encode(auth_str.encode()).strip().decode().replace("\n", "") + connect_header += f"Proxy-Authorization: Basic {encoded_str}\r\n" + connect_header += "\r\n" + dump("request header", connect_header) + + send(sock, connect_header) + + try: + status, _, _ = read_headers(sock) + except Exception as e: + raise WebSocketProxyException(str(e)) + + if status != 200: + raise WebSocketProxyException(f"failed CONNECT via proxy status: {status}") + + return sock + + +def read_headers(sock: socket.socket) -> tuple: + status = None + status_message = None + headers: dict = {} + trace("--- response header ---") + + while True: + line = recv_line(sock) + line = line.decode("utf-8").strip() + if not line: + break + trace(line) + if not status: + status_info = line.split(" ", 2) + status = int(status_info[1]) + if len(status_info) > 2: + status_message = status_info[2] + else: + kv = line.split(":", 1) + if len(kv) != 2: + raise WebSocketException("Invalid header") + key, value = kv + if key.lower() == "set-cookie" and headers.get("set-cookie"): + headers["set-cookie"] = headers.get("set-cookie") + "; " + value.strip() + else: + headers[key.lower()] = value.strip() + + trace("-----------------------") + + return status, headers, status_message diff --git a/contrib/python/websocket-client/websocket/_logging.py b/contrib/python/websocket-client/websocket/_logging.py new file mode 100644 index 0000000000..0f673d3aff --- /dev/null +++ b/contrib/python/websocket-client/websocket/_logging.py @@ -0,0 +1,106 @@ +import logging + +""" +_logging.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +_logger = logging.getLogger("websocket") +try: + from logging import NullHandler +except ImportError: + + class NullHandler(logging.Handler): + def emit(self, record) -> None: + pass + + +_logger.addHandler(NullHandler()) + +_traceEnabled = False + +__all__ = [ + "enableTrace", + "dump", + "error", + "warning", + "debug", + "trace", + "isEnabledForError", + "isEnabledForDebug", + "isEnabledForTrace", +] + + +def enableTrace( + traceable: bool, + handler: logging.StreamHandler = logging.StreamHandler(), + level: str = "DEBUG", +) -> None: + """ + Turn on/off the traceability. + + Parameters + ---------- + traceable: bool + If set to True, traceability is enabled. + """ + global _traceEnabled + _traceEnabled = traceable + if traceable: + _logger.addHandler(handler) + _logger.setLevel(getattr(logging, level)) + + +def dump(title: str, message: str) -> None: + if _traceEnabled: + _logger.debug(f"--- {title} ---") + _logger.debug(message) + _logger.debug("-----------------------") + + +def error(msg: str) -> None: + _logger.error(msg) + + +def warning(msg: str) -> None: + _logger.warning(msg) + + +def debug(msg: str) -> None: + _logger.debug(msg) + + +def info(msg: str) -> None: + _logger.info(msg) + + +def trace(msg: str) -> None: + if _traceEnabled: + _logger.debug(msg) + + +def isEnabledForError() -> bool: + return _logger.isEnabledFor(logging.ERROR) + + +def isEnabledForDebug() -> bool: + return _logger.isEnabledFor(logging.DEBUG) + + +def isEnabledForTrace() -> bool: + return _traceEnabled diff --git a/contrib/python/websocket-client/websocket/_socket.py b/contrib/python/websocket-client/websocket/_socket.py new file mode 100644 index 0000000000..81094ffc84 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_socket.py @@ -0,0 +1,188 @@ +import errno +import selectors +import socket +from typing import Union + +from ._exceptions import ( + WebSocketConnectionClosedException, + WebSocketTimeoutException, +) +from ._ssl_compat import SSLError, SSLWantReadError, SSLWantWriteError +from ._utils import extract_error_code, extract_err_message + +""" +_socket.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] +if hasattr(socket, "SO_KEEPALIVE"): + DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)) +if hasattr(socket, "TCP_KEEPIDLE"): + DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPIDLE, 30)) +if hasattr(socket, "TCP_KEEPINTVL"): + DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPINTVL, 10)) +if hasattr(socket, "TCP_KEEPCNT"): + DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPCNT, 3)) + +_default_timeout = None + +__all__ = [ + "DEFAULT_SOCKET_OPTION", + "sock_opt", + "setdefaulttimeout", + "getdefaulttimeout", + "recv", + "recv_line", + "send", +] + + +class sock_opt: + def __init__(self, sockopt: list, sslopt: dict) -> None: + if sockopt is None: + sockopt = [] + if sslopt is None: + sslopt = {} + self.sockopt = sockopt + self.sslopt = sslopt + self.timeout = None + + +def setdefaulttimeout(timeout: Union[int, float, None]) -> None: + """ + Set the global timeout setting to connect. + + Parameters + ---------- + timeout: int or float + default socket timeout time (in seconds) + """ + global _default_timeout + _default_timeout = timeout + + +def getdefaulttimeout() -> Union[int, float, None]: + """ + Get default timeout + + Returns + ---------- + _default_timeout: int or float + Return the global timeout setting (in seconds) to connect. + """ + return _default_timeout + + +def recv(sock: socket.socket, bufsize: int) -> bytes: + if not sock: + raise WebSocketConnectionClosedException("socket is already closed.") + + def _recv(): + try: + return sock.recv(bufsize) + except SSLWantReadError: + pass + except socket.error as exc: + error_code = extract_error_code(exc) + if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]: + raise + + sel = selectors.DefaultSelector() + sel.register(sock, selectors.EVENT_READ) + + r = sel.select(sock.gettimeout()) + sel.close() + + if r: + return sock.recv(bufsize) + + try: + if sock.gettimeout() == 0: + bytes_ = sock.recv(bufsize) + else: + bytes_ = _recv() + except TimeoutError: + raise WebSocketTimeoutException("Connection timed out") + except socket.timeout as e: + message = extract_err_message(e) + raise WebSocketTimeoutException(message) + except SSLError as e: + message = extract_err_message(e) + if isinstance(message, str) and "timed out" in message: + raise WebSocketTimeoutException(message) + else: + raise + + if not bytes_: + raise WebSocketConnectionClosedException("Connection to remote host was lost.") + + return bytes_ + + +def recv_line(sock: socket.socket) -> bytes: + line = [] + while True: + c = recv(sock, 1) + line.append(c) + if c == b"\n": + break + return b"".join(line) + + +def send(sock: socket.socket, data: Union[bytes, str]) -> int: + if isinstance(data, str): + data = data.encode("utf-8") + + if not sock: + raise WebSocketConnectionClosedException("socket is already closed.") + + def _send(): + try: + return sock.send(data) + except SSLWantWriteError: + pass + except socket.error as exc: + error_code = extract_error_code(exc) + if error_code is None: + raise + if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]: + raise + + sel = selectors.DefaultSelector() + sel.register(sock, selectors.EVENT_WRITE) + + w = sel.select(sock.gettimeout()) + sel.close() + + if w: + return sock.send(data) + + try: + if sock.gettimeout() == 0: + return sock.send(data) + else: + return _send() + except socket.timeout as e: + message = extract_err_message(e) + raise WebSocketTimeoutException(message) + except Exception as e: + message = extract_err_message(e) + if isinstance(message, str) and "timed out" in message: + raise WebSocketTimeoutException(message) + else: + raise diff --git a/contrib/python/websocket-client/websocket/_ssl_compat.py b/contrib/python/websocket-client/websocket/_ssl_compat.py new file mode 100644 index 0000000000..0a8a32b59b --- /dev/null +++ b/contrib/python/websocket-client/websocket/_ssl_compat.py @@ -0,0 +1,48 @@ +""" +_ssl_compat.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +__all__ = [ + "HAVE_SSL", + "ssl", + "SSLError", + "SSLEOFError", + "SSLWantReadError", + "SSLWantWriteError", +] + +try: + import ssl + from ssl import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError + + HAVE_SSL = True +except ImportError: + # dummy class of SSLError for environment without ssl support + class SSLError(Exception): + pass + + class SSLEOFError(Exception): + pass + + class SSLWantReadError(Exception): + pass + + class SSLWantWriteError(Exception): + pass + + ssl = None + HAVE_SSL = False diff --git a/contrib/python/websocket-client/websocket/_url.py b/contrib/python/websocket-client/websocket/_url.py new file mode 100644 index 0000000000..902131710b --- /dev/null +++ b/contrib/python/websocket-client/websocket/_url.py @@ -0,0 +1,190 @@ +import os +import socket +import struct +from typing import Optional +from urllib.parse import unquote, urlparse +from ._exceptions import WebSocketProxyException + +""" +_url.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +__all__ = ["parse_url", "get_proxy_info"] + + +def parse_url(url: str) -> tuple: + """ + parse url and the result is tuple of + (hostname, port, resource path and the flag of secure mode) + + Parameters + ---------- + url: str + url string. + """ + if ":" not in url: + raise ValueError("url is invalid") + + scheme, url = url.split(":", 1) + + parsed = urlparse(url, scheme="http") + if parsed.hostname: + hostname = parsed.hostname + else: + raise ValueError("hostname is invalid") + port = 0 + if parsed.port: + port = parsed.port + + is_secure = False + if scheme == "ws": + if not port: + port = 80 + elif scheme == "wss": + is_secure = True + if not port: + port = 443 + else: + raise ValueError("scheme %s is invalid" % scheme) + + if parsed.path: + resource = parsed.path + else: + resource = "/" + + if parsed.query: + resource += f"?{parsed.query}" + + return hostname, port, resource, is_secure + + +DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"] + + +def _is_ip_address(addr: str) -> bool: + try: + socket.inet_aton(addr) + except socket.error: + return False + else: + return True + + +def _is_subnet_address(hostname: str) -> bool: + try: + addr, netmask = hostname.split("/") + return _is_ip_address(addr) and 0 <= int(netmask) < 32 + except ValueError: + return False + + +def _is_address_in_network(ip: str, net: str) -> bool: + ipaddr: int = struct.unpack("!I", socket.inet_aton(ip))[0] + netaddr, netmask = net.split("/") + netaddr: int = struct.unpack("!I", socket.inet_aton(netaddr))[0] + + netmask = (0xFFFFFFFF << (32 - int(netmask))) & 0xFFFFFFFF + return ipaddr & netmask == netaddr + + +def _is_no_proxy_host(hostname: str, no_proxy: Optional[list]) -> bool: + if not no_proxy: + if v := os.environ.get("no_proxy", os.environ.get("NO_PROXY", "")).replace( + " ", "" + ): + no_proxy = v.split(",") + if not no_proxy: + no_proxy = DEFAULT_NO_PROXY_HOST + + if "*" in no_proxy: + return True + if hostname in no_proxy: + return True + if _is_ip_address(hostname): + return any( + [ + _is_address_in_network(hostname, subnet) + for subnet in no_proxy + if _is_subnet_address(subnet) + ] + ) + for domain in [domain for domain in no_proxy if domain.startswith(".")]: + if hostname.endswith(domain): + return True + return False + + +def get_proxy_info( + hostname: str, + is_secure: bool, + proxy_host: Optional[str] = None, + proxy_port: int = 0, + proxy_auth: Optional[tuple] = None, + no_proxy: Optional[list] = None, + proxy_type: str = "http", +) -> tuple: + """ + Try to retrieve proxy host and port from environment + if not provided in options. + Result is (proxy_host, proxy_port, proxy_auth). + proxy_auth is tuple of username and password + of proxy authentication information. + + Parameters + ---------- + hostname: str + Websocket server name. + is_secure: bool + Is the connection secure? (wss) looks for "https_proxy" in env + instead of "http_proxy" + proxy_host: str + http proxy host name. + proxy_port: str or int + http proxy port. + no_proxy: list + Whitelisted host names that don't use the proxy. + proxy_auth: tuple + HTTP proxy auth information. Tuple of username and password. Default is None. + proxy_type: str + Specify the proxy protocol (http, socks4, socks4a, socks5, socks5h). Default is "http". + Use socks4a or socks5h if you want to send DNS requests through the proxy. + """ + if _is_no_proxy_host(hostname, no_proxy): + return None, 0, None + + if proxy_host: + if not proxy_port: + raise WebSocketProxyException("Cannot use port 0 when proxy_host specified") + port = proxy_port + auth = proxy_auth + return proxy_host, port, auth + + env_key = "https_proxy" if is_secure else "http_proxy" + value = os.environ.get(env_key, os.environ.get(env_key.upper(), "")).replace( + " ", "" + ) + if value: + proxy = urlparse(value) + auth = ( + (unquote(proxy.username), unquote(proxy.password)) + if proxy.username + else None + ) + return proxy.hostname, proxy.port, auth + + return None, 0, None diff --git a/contrib/python/websocket-client/websocket/_utils.py b/contrib/python/websocket-client/websocket/_utils.py new file mode 100644 index 0000000000..65f3c0daf7 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_utils.py @@ -0,0 +1,459 @@ +from typing import Union + +""" +_url.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"] + + +class NoLock: + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type, exc_value, traceback) -> None: + pass + + +try: + # If wsaccel is available we use compiled routines to validate UTF-8 + # strings. + from wsaccel.utf8validator import Utf8Validator + + def _validate_utf8(utfbytes: Union[str, bytes]) -> bool: + result: bool = Utf8Validator().validate(utfbytes)[0] + return result + +except ImportError: + # UTF-8 validator + # python implementation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ + + _UTF8_ACCEPT = 0 + _UTF8_REJECT = 12 + + _UTF8D = [ + # The first part of the table maps bytes to character classes that + # to reduce the size of the transition table and create bitmasks. + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 8, + 8, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 10, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 3, + 3, + 11, + 6, + 6, + 6, + 5, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + # The second part is a transition table that maps a combination + # of a state of the automaton and a character class to a state. + 0, + 12, + 24, + 36, + 60, + 96, + 84, + 12, + 12, + 12, + 48, + 72, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 0, + 12, + 12, + 12, + 12, + 12, + 0, + 12, + 0, + 12, + 12, + 12, + 24, + 12, + 12, + 12, + 12, + 12, + 24, + 12, + 24, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 24, + 12, + 12, + 12, + 12, + 12, + 24, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 24, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 36, + 12, + 36, + 12, + 12, + 12, + 36, + 12, + 12, + 12, + 12, + 12, + 36, + 12, + 36, + 12, + 12, + 12, + 36, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + ] + + def _decode(state: int, codep: int, ch: int) -> tuple: + tp = _UTF8D[ch] + + codep = ( + (ch & 0x3F) | (codep << 6) if (state != _UTF8_ACCEPT) else (0xFF >> tp) & ch + ) + state = _UTF8D[256 + state + tp] + + return state, codep + + def _validate_utf8(utfbytes: Union[str, bytes]) -> bool: + state = _UTF8_ACCEPT + codep = 0 + for i in utfbytes: + state, codep = _decode(state, codep, int(i)) + if state == _UTF8_REJECT: + return False + + return True + + +def validate_utf8(utfbytes: Union[str, bytes]) -> bool: + """ + validate utf8 byte string. + utfbytes: utf byte string to check. + return value: if valid utf8 string, return true. Otherwise, return false. + """ + return _validate_utf8(utfbytes) + + +def extract_err_message(exception: Exception) -> Union[str, None]: + if exception.args: + exception_message: str = exception.args[0] + return exception_message + else: + return None + + +def extract_error_code(exception: Exception) -> Union[int, None]: + if exception.args and len(exception.args) > 1: + return exception.args[0] if isinstance(exception.args[0], int) else None diff --git a/contrib/python/websocket-client/websocket/_wsdump.py b/contrib/python/websocket-client/websocket/_wsdump.py new file mode 100644 index 0000000000..d4d76dc509 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_wsdump.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 + +""" +wsdump.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import argparse +import code +import gzip +import ssl +import sys +import threading +import time +import zlib +from urllib.parse import urlparse + +import websocket + +try: + import readline +except ImportError: + pass + + +def get_encoding() -> str: + encoding = getattr(sys.stdin, "encoding", "") + if not encoding: + return "utf-8" + else: + return encoding.lower() + + +OPCODE_DATA = (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY) +ENCODING = get_encoding() + + +class VAction(argparse.Action): + def __call__( + self, + parser: argparse.Namespace, + args: tuple, + values: str, + option_string: str = None, + ) -> None: + if values is None: + values = "1" + try: + values = int(values) + except ValueError: + values = values.count("v") + 1 + setattr(args, self.dest, values) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool") + parser.add_argument( + "url", metavar="ws_url", help="websocket url. ex. ws://echo.websocket.events/" + ) + parser.add_argument("-p", "--proxy", help="proxy url. ex. http://127.0.0.1:8080") + parser.add_argument( + "-v", + "--verbose", + default=0, + nargs="?", + action=VAction, + dest="verbose", + help="set verbose mode. If set to 1, show opcode. " + "If set to 2, enable to trace websocket module", + ) + parser.add_argument( + "-n", "--nocert", action="store_true", help="Ignore invalid SSL cert" + ) + parser.add_argument("-r", "--raw", action="store_true", help="raw output") + parser.add_argument("-s", "--subprotocols", nargs="*", help="Set subprotocols") + parser.add_argument("-o", "--origin", help="Set origin") + parser.add_argument( + "--eof-wait", + default=0, + type=int, + help="wait time(second) after 'EOF' received.", + ) + parser.add_argument("-t", "--text", help="Send initial text") + parser.add_argument( + "--timings", action="store_true", help="Print timings in seconds" + ) + parser.add_argument("--headers", help="Set custom headers. Use ',' as separator") + + return parser.parse_args() + + +class RawInput: + def raw_input(self, prompt: str = "") -> str: + line = input(prompt) + + if ENCODING and ENCODING != "utf-8" and not isinstance(line, str): + line = line.decode(ENCODING).encode("utf-8") + elif isinstance(line, str): + line = line.encode("utf-8") + + return line + + +class InteractiveConsole(RawInput, code.InteractiveConsole): + def write(self, data: str) -> None: + sys.stdout.write("\033[2K\033[E") + # sys.stdout.write("\n") + sys.stdout.write("\033[34m< " + data + "\033[39m") + sys.stdout.write("\n> ") + sys.stdout.flush() + + def read(self) -> str: + return self.raw_input("> ") + + +class NonInteractive(RawInput): + def write(self, data: str) -> None: + sys.stdout.write(data) + sys.stdout.write("\n") + sys.stdout.flush() + + def read(self) -> str: + return self.raw_input("") + + +def main() -> None: + start_time = time.time() + args = parse_args() + if args.verbose > 1: + websocket.enableTrace(True) + options = {} + if args.proxy: + p = urlparse(args.proxy) + options["http_proxy_host"] = p.hostname + options["http_proxy_port"] = p.port + if args.origin: + options["origin"] = args.origin + if args.subprotocols: + options["subprotocols"] = args.subprotocols + opts = {} + if args.nocert: + opts = {"cert_reqs": ssl.CERT_NONE, "check_hostname": False} + if args.headers: + options["header"] = list(map(str.strip, args.headers.split(","))) + ws = websocket.create_connection(args.url, sslopt=opts, **options) + if args.raw: + console = NonInteractive() + else: + console = InteractiveConsole() + print("Press Ctrl+C to quit") + + def recv() -> tuple: + try: + frame = ws.recv_frame() + except websocket.WebSocketException: + return websocket.ABNF.OPCODE_CLOSE, "" + if not frame: + raise websocket.WebSocketException(f"Not a valid frame {frame}") + elif frame.opcode in OPCODE_DATA: + return frame.opcode, frame.data + elif frame.opcode == websocket.ABNF.OPCODE_CLOSE: + ws.send_close() + return frame.opcode, "" + elif frame.opcode == websocket.ABNF.OPCODE_PING: + ws.pong(frame.data) + return frame.opcode, frame.data + + return frame.opcode, frame.data + + def recv_ws() -> None: + while True: + opcode, data = recv() + msg = None + if opcode == websocket.ABNF.OPCODE_TEXT and isinstance(data, bytes): + data = str(data, "utf-8") + if ( + isinstance(data, bytes) and len(data) > 2 and data[:2] == b"\037\213" + ): # gzip magick + try: + data = "[gzip] " + str(gzip.decompress(data), "utf-8") + except: + pass + elif isinstance(data, bytes): + try: + data = "[zlib] " + str( + zlib.decompress(data, -zlib.MAX_WBITS), "utf-8" + ) + except: + pass + + if isinstance(data, bytes): + data = repr(data) + + if args.verbose: + msg = f"{websocket.ABNF.OPCODE_MAP.get(opcode)}: {data}" + else: + msg = data + + if msg is not None: + if args.timings: + console.write(f"{time.time() - start_time}: {msg}") + else: + console.write(msg) + + if opcode == websocket.ABNF.OPCODE_CLOSE: + break + + thread = threading.Thread(target=recv_ws) + thread.daemon = True + thread.start() + + if args.text: + ws.send(args.text) + + while True: + try: + message = console.read() + ws.send(message) + except KeyboardInterrupt: + return + except EOFError: + time.sleep(args.eof_wait) + return + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(e) diff --git a/contrib/python/websocket-client/websocket/py.typed b/contrib/python/websocket-client/websocket/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/websocket-client/websocket/py.typed diff --git a/contrib/python/websocket-client/websocket/tests/__init__.py b/contrib/python/websocket-client/websocket/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/__init__.py diff --git a/contrib/python/websocket-client/websocket/tests/data/header01.txt b/contrib/python/websocket-client/websocket/tests/data/header01.txt new file mode 100644 index 0000000000..d44d24c205 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/data/header01.txt @@ -0,0 +1,6 @@ +HTTP/1.1 101 WebSocket Protocol Handshake +Connection: Upgrade +Upgrade: WebSocket +Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= +some_header: something + diff --git a/contrib/python/websocket-client/websocket/tests/data/header02.txt b/contrib/python/websocket-client/websocket/tests/data/header02.txt new file mode 100644 index 0000000000..f481de928a --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/data/header02.txt @@ -0,0 +1,6 @@ +HTTP/1.1 101 WebSocket Protocol Handshake +Connection: Upgrade +Upgrade WebSocket +Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= +some_header: something + diff --git a/contrib/python/websocket-client/websocket/tests/data/header03.txt b/contrib/python/websocket-client/websocket/tests/data/header03.txt new file mode 100644 index 0000000000..1a81dc70ce --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/data/header03.txt @@ -0,0 +1,8 @@ +HTTP/1.1 101 WebSocket Protocol Handshake +Connection: Upgrade, Keep-Alive +Upgrade: WebSocket +Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= +Set-Cookie: Token=ABCDE +Set-Cookie: Token=FGHIJ +some_header: something + diff --git a/contrib/python/websocket-client/websocket/tests/echo-server.py b/contrib/python/websocket-client/websocket/tests/echo-server.py new file mode 100644 index 0000000000..5d1e87087b --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/echo-server.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +# From https://github.com/aaugustin/websockets/blob/main/example/echo.py + +import asyncio +import os + +import websockets + +LOCAL_WS_SERVER_PORT = int(os.environ.get("LOCAL_WS_SERVER_PORT", "8765")) + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +async def main(): + async with websockets.serve(echo, "localhost", LOCAL_WS_SERVER_PORT): + await asyncio.Future() # run forever + + +asyncio.run(main()) diff --git a/contrib/python/websocket-client/websocket/tests/test_abnf.py b/contrib/python/websocket-client/websocket/tests/test_abnf.py new file mode 100644 index 0000000000..a749f13bd5 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_abnf.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# +import unittest + +from websocket._abnf import ABNF, frame_buffer +from websocket._exceptions import WebSocketProtocolException + +""" +test_abnf.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +class ABNFTest(unittest.TestCase): + def test_init(self): + a = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING) + self.assertEqual(a.fin, 0) + self.assertEqual(a.rsv1, 0) + self.assertEqual(a.rsv2, 0) + self.assertEqual(a.rsv3, 0) + self.assertEqual(a.opcode, 9) + self.assertEqual(a.data, "") + a_bad = ABNF(0, 1, 0, 0, opcode=77) + self.assertEqual(a_bad.rsv1, 1) + self.assertEqual(a_bad.opcode, 77) + + def test_validate(self): + a_invalid_ping = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING) + self.assertRaises( + WebSocketProtocolException, + a_invalid_ping.validate, + skip_utf8_validation=False, + ) + a_bad_rsv_value = ABNF(0, 1, 0, 0, opcode=ABNF.OPCODE_TEXT) + self.assertRaises( + WebSocketProtocolException, + a_bad_rsv_value.validate, + skip_utf8_validation=False, + ) + a_bad_opcode = ABNF(0, 0, 0, 0, opcode=77) + self.assertRaises( + WebSocketProtocolException, + a_bad_opcode.validate, + skip_utf8_validation=False, + ) + a_bad_close_frame = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01") + self.assertRaises( + WebSocketProtocolException, + a_bad_close_frame.validate, + skip_utf8_validation=False, + ) + a_bad_close_frame_2 = ABNF( + 0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01\x8a\xaa\xff\xdd" + ) + self.assertRaises( + WebSocketProtocolException, + a_bad_close_frame_2.validate, + skip_utf8_validation=False, + ) + a_bad_close_frame_3 = ABNF( + 0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x03\xe7" + ) + self.assertRaises( + WebSocketProtocolException, + a_bad_close_frame_3.validate, + skip_utf8_validation=True, + ) + + def test_mask(self): + abnf_none_data = ABNF( + 0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data=None + ) + bytes_val = b"aaaa" + self.assertEqual(abnf_none_data._get_masked(bytes_val), bytes_val) + abnf_str_data = ABNF( + 0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data="a" + ) + self.assertEqual(abnf_str_data._get_masked(bytes_val), b"aaaa\x00") + + def test_format(self): + abnf_bad_rsv_bits = ABNF(2, 0, 0, 0, opcode=ABNF.OPCODE_TEXT) + self.assertRaises(ValueError, abnf_bad_rsv_bits.format) + abnf_bad_opcode = ABNF(0, 0, 0, 0, opcode=5) + self.assertRaises(ValueError, abnf_bad_opcode.format) + abnf_length_10 = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, data="abcdefghij") + self.assertEqual(b"\x01", abnf_length_10.format()[0].to_bytes(1, "big")) + self.assertEqual(b"\x8a", abnf_length_10.format()[1].to_bytes(1, "big")) + self.assertEqual("fin=0 opcode=1 data=abcdefghij", abnf_length_10.__str__()) + abnf_length_20 = ABNF( + 0, 0, 0, 0, opcode=ABNF.OPCODE_BINARY, data="abcdefghijabcdefghij" + ) + self.assertEqual(b"\x02", abnf_length_20.format()[0].to_bytes(1, "big")) + self.assertEqual(b"\x94", abnf_length_20.format()[1].to_bytes(1, "big")) + abnf_no_mask = ABNF( + 0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, mask_value=0, data=b"\x01\x8a\xcc" + ) + self.assertEqual(b"\x01\x03\x01\x8a\xcc", abnf_no_mask.format()) + + def test_frame_buffer(self): + fb = frame_buffer(0, True) + self.assertEqual(fb.recv, 0) + self.assertEqual(fb.skip_utf8_validation, True) + fb.clear + self.assertEqual(fb.header, None) + self.assertEqual(fb.length, None) + self.assertEqual(fb.mask_value, None) + self.assertEqual(fb.has_mask(), False) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_app.py b/contrib/python/websocket-client/websocket/tests/test_app.py new file mode 100644 index 0000000000..18eace5442 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_app.py @@ -0,0 +1,352 @@ +# -*- coding: utf-8 -*- +# +import os +import os.path +import ssl +import threading +import unittest + +import websocket as ws + +""" +test_app.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Skip test to access the internet unless TEST_WITH_INTERNET == 1 +TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1" +# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1 +LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1") +TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1" +TRACEABLE = True + + +class WebSocketAppTest(unittest.TestCase): + class NotSetYet: + """A marker class for signalling that a value hasn't been set yet.""" + + def setUp(self): + ws.enableTrace(TRACEABLE) + + WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() + WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() + WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() + WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet() + + def tearDown(self): + WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() + WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() + WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() + WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet() + + def close(self): + pass + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_keep_running(self): + """A WebSocketApp should keep running as long as its self.keep_running + is not False (in the boolean context). + """ + + def on_open(self, *args, **kwargs): + """Set the keep_running flag for later inspection and immediately + close the connection. + """ + self.send("hello!") + WebSocketAppTest.keep_running_open = self.keep_running + self.keep_running = False + + def on_message(_, message): + print(message) + self.close() + + def on_close(self, *args, **kwargs): + """Set the keep_running flag for the test to use.""" + WebSocketAppTest.keep_running_close = self.keep_running + + app = ws.WebSocketApp( + f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", + on_open=on_open, + on_close=on_close, + on_message=on_message, + ) + app.run_forever() + + # @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") + @unittest.skipUnless(False, "Test disabled for now (requires rel)") + def test_run_forever_dispatcher(self): + """A WebSocketApp should keep running as long as its self.keep_running + is not False (in the boolean context). + """ + + def on_open(self, *args, **kwargs): + """Send a message, receive, and send one more""" + self.send("hello!") + self.recv() + self.send("goodbye!") + + def on_message(_, message): + print(message) + self.close() + + app = ws.WebSocketApp( + f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", + on_open=on_open, + on_message=on_message, + ) + app.run_forever(dispatcher="Dispatcher") # doesn't work + + # app.run_forever(dispatcher=rel) # would work + # rel.dispatch() + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_run_forever_teardown_clean_exit(self): + """The WebSocketApp.run_forever() method should return `False` when the application ends gracefully.""" + app = ws.WebSocketApp(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}") + threading.Timer(interval=0.2, function=app.close).start() + teardown = app.run_forever() + self.assertEqual(teardown, False) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_sock_mask_key(self): + """A WebSocketApp should forward the received mask_key function down + to the actual socket. + """ + + def my_mask_key_func(): + return "\x00\x00\x00\x00" + + app = ws.WebSocketApp( + "wss://api-pub.bitfinex.com/ws/1", get_mask_key=my_mask_key_func + ) + + # if numpy is installed, this assertion fail + # Note: We can't use 'is' for comparing the functions directly, need to use 'id'. + self.assertEqual(id(app.get_mask_key), id(my_mask_key_func)) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_invalid_ping_interval_ping_timeout(self): + """Test exception handling if ping_interval < ping_timeout""" + + def on_ping(app, _): + print("Got a ping!") + app.close() + + def on_pong(app, _): + print("Got a pong! No need to respond") + app.close() + + app = ws.WebSocketApp( + "wss://api-pub.bitfinex.com/ws/1", on_ping=on_ping, on_pong=on_pong + ) + self.assertRaises( + ws.WebSocketException, + app.run_forever, + ping_interval=1, + ping_timeout=2, + sslopt={"cert_reqs": ssl.CERT_NONE}, + ) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_ping_interval(self): + """Test WebSocketApp proper ping functionality""" + + def on_ping(app, _): + print("Got a ping!") + app.close() + + def on_pong(app, _): + print("Got a pong! No need to respond") + app.close() + + app = ws.WebSocketApp( + "wss://api-pub.bitfinex.com/ws/1", on_ping=on_ping, on_pong=on_pong + ) + app.run_forever( + ping_interval=2, ping_timeout=1, sslopt={"cert_reqs": ssl.CERT_NONE} + ) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_opcode_close(self): + """Test WebSocketApp close opcode""" + + app = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect") + app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload") + + # This is commented out because the URL no longer responds in the expected way + # @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + # def testOpcodeBinary(self): + # """ Test WebSocketApp binary opcode + # """ + # app = ws.WebSocketApp('wss://streaming.vn.teslamotors.com/streaming/') + # app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload") + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_bad_ping_interval(self): + """A WebSocketApp handling of negative ping_interval""" + app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1") + self.assertRaises( + ws.WebSocketException, + app.run_forever, + ping_interval=-5, + sslopt={"cert_reqs": ssl.CERT_NONE}, + ) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_bad_ping_timeout(self): + """A WebSocketApp handling of negative ping_timeout""" + app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1") + self.assertRaises( + ws.WebSocketException, + app.run_forever, + ping_timeout=-3, + sslopt={"cert_reqs": ssl.CERT_NONE}, + ) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_close_status_code(self): + """Test extraction of close frame status code and close reason in WebSocketApp""" + + def on_close(wsapp, close_status_code, close_msg): + print("on_close reached") + + app = ws.WebSocketApp( + "wss://tsock.us1.twilio.com/v3/wsconnect", on_close=on_close + ) + closeframe = ws.ABNF( + opcode=ws.ABNF.OPCODE_CLOSE, data=b"\x03\xe8no-init-from-client" + ) + self.assertEqual([1000, "no-init-from-client"], app._get_close_args(closeframe)) + + closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b"") + self.assertEqual([None, None], app._get_close_args(closeframe)) + + app2 = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect") + closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b"") + self.assertEqual([None, None], app2._get_close_args(closeframe)) + + self.assertRaises( + ws.WebSocketConnectionClosedException, + app.send, + data="test if connection is closed", + ) + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_callback_function_exception(self): + """Test callback function exception handling""" + + exc = None + passed_app = None + + def on_open(app): + raise RuntimeError("Callback failed") + + def on_error(app, err): + nonlocal passed_app + passed_app = app + nonlocal exc + exc = err + + def on_pong(app, _): + app.close() + + app = ws.WebSocketApp( + f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", + on_open=on_open, + on_error=on_error, + on_pong=on_pong, + ) + app.run_forever(ping_interval=2, ping_timeout=1) + + self.assertEqual(passed_app, app) + self.assertIsInstance(exc, RuntimeError) + self.assertEqual(str(exc), "Callback failed") + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_callback_method_exception(self): + """Test callback method exception handling""" + + class Callbacks: + def __init__(self): + self.exc = None + self.passed_app = None + self.app = ws.WebSocketApp( + f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", + on_open=self.on_open, + on_error=self.on_error, + on_pong=self.on_pong, + ) + self.app.run_forever(ping_interval=2, ping_timeout=1) + + def on_open(self, _): + raise RuntimeError("Callback failed") + + def on_error(self, app, err): + self.passed_app = app + self.exc = err + + def on_pong(self, app, _): + app.close() + + callbacks = Callbacks() + + self.assertEqual(callbacks.passed_app, callbacks.app) + self.assertIsInstance(callbacks.exc, RuntimeError) + self.assertEqual(str(callbacks.exc), "Callback failed") + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_reconnect(self): + """Test reconnect""" + pong_count = 0 + exc = None + + def on_error(_, err): + nonlocal exc + exc = err + + def on_pong(app, _): + nonlocal pong_count + pong_count += 1 + if pong_count == 1: + # First pong, shutdown socket, enforce read error + app.sock.shutdown() + if pong_count >= 2: + # Got second pong after reconnect + app.close() + + app = ws.WebSocketApp( + f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", on_pong=on_pong, on_error=on_error + ) + app.run_forever(ping_interval=2, ping_timeout=1, reconnect=3) + + self.assertEqual(pong_count, 2) + self.assertIsInstance(exc, ws.WebSocketTimeoutException) + self.assertEqual(str(exc), "ping/pong timed out") + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_cookiejar.py b/contrib/python/websocket-client/websocket/tests/test_cookiejar.py new file mode 100644 index 0000000000..67eddb627a --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_cookiejar.py @@ -0,0 +1,123 @@ +import unittest + +from websocket._cookiejar import SimpleCookieJar + +""" +test_cookiejar.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +class CookieJarTest(unittest.TestCase): + def test_add(self): + cookie_jar = SimpleCookieJar() + cookie_jar.add("") + self.assertFalse( + cookie_jar.jar, "Cookie with no domain should not be added to the jar" + ) + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b") + self.assertFalse( + cookie_jar.jar, "Cookie with no domain should not be added to the jar" + ) + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; domain=.abc") + self.assertTrue(".abc" in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; domain=abc") + self.assertTrue(".abc" in cookie_jar.jar) + self.assertTrue("abc" not in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + self.assertEqual(cookie_jar.get("abc"), "a=b; c=d") + self.assertEqual(cookie_jar.get(None), "") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=abc") + self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=.abc") + self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=xyz") + self.assertEqual(cookie_jar.get("abc"), "a=b; c=d") + self.assertEqual(cookie_jar.get("xyz"), "e=f") + self.assertEqual(cookie_jar.get("something"), "") + + def test_set(self): + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b") + self.assertFalse( + cookie_jar.jar, "Cookie with no domain should not be added to the jar" + ) + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; domain=.abc") + self.assertTrue(".abc" in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; domain=abc") + self.assertTrue(".abc" in cookie_jar.jar) + self.assertTrue("abc" not in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + self.assertEqual(cookie_jar.get("abc"), "a=b; c=d") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=abc") + self.assertEqual(cookie_jar.get("abc"), "e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=.abc") + self.assertEqual(cookie_jar.get("abc"), "e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=xyz") + self.assertEqual(cookie_jar.get("abc"), "a=b; c=d") + self.assertEqual(cookie_jar.get("xyz"), "e=f") + self.assertEqual(cookie_jar.get("something"), "") + + def test_get(self): + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc.com") + self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d") + self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d") + self.assertEqual(cookie_jar.get("abc.com.es"), "") + self.assertEqual(cookie_jar.get("xabc.com"), "") + + cookie_jar.set("a=b; c=d; domain=.abc.com") + self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d") + self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d") + self.assertEqual(cookie_jar.get("abc.com.es"), "") + self.assertEqual(cookie_jar.get("xabc.com"), "") + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_http.py b/contrib/python/websocket-client/websocket/tests/test_http.py new file mode 100644 index 0000000000..72465c2205 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_http.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- +# +import os +import os.path +import socket +import ssl +import unittest + +import websocket +from websocket._exceptions import WebSocketProxyException, WebSocketException +from websocket._http import ( + _get_addrinfo_list, + _start_proxied_socket, + _tunnel, + connect, + proxy_info, + read_headers, + HAVE_PYTHON_SOCKS, +) + +""" +test_http.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +try: + from python_socks._errors import ProxyConnectionError, ProxyError, ProxyTimeoutError +except: + from websocket._http import ProxyConnectionError, ProxyError, ProxyTimeoutError + +# Skip test to access the internet unless TEST_WITH_INTERNET == 1 +TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1" +TEST_WITH_PROXY = os.environ.get("TEST_WITH_PROXY", "0") == "1" +# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1 +LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1") +TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1" + + +class SockMock: + def __init__(self): + self.data = [] + self.sent = [] + + def add_packet(self, data): + self.data.append(data) + + def gettimeout(self): + return None + + def recv(self, bufsize): + if self.data: + e = self.data.pop(0) + if isinstance(e, Exception): + raise e + if len(e) > bufsize: + self.data.insert(0, e[bufsize:]) + return e[:bufsize] + + def send(self, data): + self.sent.append(data) + return len(data) + + def close(self): + pass + + +class HeaderSockMock(SockMock): + def __init__(self, fname): + SockMock.__init__(self) + import yatest.common as yc + path = os.path.join(os.path.dirname(yc.source_path(__file__)), fname) + with open(path, "rb") as f: + self.add_packet(f.read()) + + +class OptsList: + def __init__(self): + self.timeout = 1 + self.sockopt = [] + self.sslopt = {"cert_reqs": ssl.CERT_NONE} + + +class HttpTest(unittest.TestCase): + def test_read_header(self): + status, header, _ = read_headers(HeaderSockMock("data/header01.txt")) + self.assertEqual(status, 101) + self.assertEqual(header["connection"], "Upgrade") + # header02.txt is intentionally malformed + self.assertRaises( + WebSocketException, read_headers, HeaderSockMock("data/header02.txt") + ) + + def test_tunnel(self): + self.assertRaises( + WebSocketProxyException, + _tunnel, + HeaderSockMock("data/header01.txt"), + "example.com", + 80, + ("username", "password"), + ) + self.assertRaises( + WebSocketProxyException, + _tunnel, + HeaderSockMock("data/header02.txt"), + "example.com", + 80, + ("username", "password"), + ) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_connect(self): + # Not currently testing an actual proxy connection, so just check whether proxy errors are raised. This requires internet for a DNS lookup + if HAVE_PYTHON_SOCKS: + # Need this check, otherwise case where python_socks is not installed triggers + # websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available + self.assertRaises( + (ProxyTimeoutError, OSError), + _start_proxied_socket, + "wss://example.com", + OptsList(), + proxy_info( + http_proxy_host="example.com", + http_proxy_port="8080", + proxy_type="socks4", + http_proxy_timeout=1, + ), + ) + self.assertRaises( + (ProxyTimeoutError, OSError), + _start_proxied_socket, + "wss://example.com", + OptsList(), + proxy_info( + http_proxy_host="example.com", + http_proxy_port="8080", + proxy_type="socks4a", + http_proxy_timeout=1, + ), + ) + self.assertRaises( + (ProxyTimeoutError, OSError), + _start_proxied_socket, + "wss://example.com", + OptsList(), + proxy_info( + http_proxy_host="example.com", + http_proxy_port="8080", + proxy_type="socks5", + http_proxy_timeout=1, + ), + ) + self.assertRaises( + (ProxyTimeoutError, OSError), + _start_proxied_socket, + "wss://example.com", + OptsList(), + proxy_info( + http_proxy_host="example.com", + http_proxy_port="8080", + proxy_type="socks5h", + http_proxy_timeout=1, + ), + ) + self.assertRaises( + ProxyConnectionError, + connect, + "wss://example.com", + OptsList(), + proxy_info( + http_proxy_host="127.0.0.1", + http_proxy_port=9999, + proxy_type="socks4", + http_proxy_timeout=1, + ), + None, + ) + + self.assertRaises( + TypeError, + _get_addrinfo_list, + None, + 80, + True, + proxy_info( + http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http" + ), + ) + self.assertRaises( + TypeError, + _get_addrinfo_list, + None, + 80, + True, + proxy_info( + http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http" + ), + ) + self.assertRaises( + socket.timeout, + connect, + "wss://google.com", + OptsList(), + proxy_info( + http_proxy_host="8.8.8.8", + http_proxy_port=9999, + proxy_type="http", + http_proxy_timeout=1, + ), + None, + ) + self.assertEqual( + connect( + "wss://google.com", + OptsList(), + proxy_info( + http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http" + ), + True, + ), + (True, ("google.com", 443, "/")), + ) + # The following test fails on Mac OS with a gaierror, not an OverflowError + # self.assertRaises(OverflowError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=99999, proxy_type="socks4", timeout=2), False) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless( + TEST_WITH_PROXY, "This test requires a HTTP proxy to be running on port 8899" + ) + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_proxy_connect(self): + ws = websocket.WebSocket() + ws.connect( + f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", + http_proxy_host="127.0.0.1", + http_proxy_port="8899", + proxy_type="http", + ) + ws.send("Hello, Server") + server_response = ws.recv() + self.assertEqual(server_response, "Hello, Server") + # self.assertEqual(_start_proxied_socket("wss://api.bitfinex.com/ws/2", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http"))[1], ("api.bitfinex.com", 443, '/ws/2')) + self.assertEqual( + _get_addrinfo_list( + "api.bitfinex.com", + 443, + True, + proxy_info( + http_proxy_host="127.0.0.1", + http_proxy_port="8899", + proxy_type="http", + ), + ), + ( + socket.getaddrinfo( + "127.0.0.1", 8899, 0, socket.SOCK_STREAM, socket.SOL_TCP + ), + True, + None, + ), + ) + self.assertEqual( + connect( + "wss://api.bitfinex.com/ws/2", + OptsList(), + proxy_info( + http_proxy_host="127.0.0.1", http_proxy_port=8899, proxy_type="http" + ), + None, + )[1], + ("api.bitfinex.com", 443, "/ws/2"), + ) + # TODO: Test SOCKS4 and SOCK5 proxies with unit tests + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_sslopt(self): + ssloptions = { + "check_hostname": False, + "server_hostname": "ServerName", + "ssl_version": ssl.PROTOCOL_TLS_CLIENT, + "ciphers": "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:\ + TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:\ + ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:\ + ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:\ + DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:\ + ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:\ + ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:\ + DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:\ + ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:\ + ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA", + "ecdh_curve": "prime256v1", + } + ws_ssl1 = websocket.WebSocket(sslopt=ssloptions) + ws_ssl1.connect("wss://api.bitfinex.com/ws/2") + ws_ssl1.send("Hello") + ws_ssl1.close() + + ws_ssl2 = websocket.WebSocket(sslopt={"check_hostname": True}) + ws_ssl2.connect("wss://api.bitfinex.com/ws/2") + ws_ssl2.close + + def test_proxy_info(self): + self.assertEqual( + proxy_info( + http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http" + ).proxy_protocol, + "http", + ) + self.assertRaises( + ProxyError, + proxy_info, + http_proxy_host="127.0.0.1", + http_proxy_port="8080", + proxy_type="badval", + ) + self.assertEqual( + proxy_info( + http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http" + ).proxy_host, + "example.com", + ) + self.assertEqual( + proxy_info( + http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http" + ).proxy_port, + "8080", + ) + self.assertEqual( + proxy_info( + http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http" + ).auth, + None, + ) + self.assertEqual( + proxy_info( + http_proxy_host="127.0.0.1", + http_proxy_port="8080", + proxy_type="http", + http_proxy_auth=("my_username123", "my_pass321"), + ).auth[0], + "my_username123", + ) + self.assertEqual( + proxy_info( + http_proxy_host="127.0.0.1", + http_proxy_port="8080", + proxy_type="http", + http_proxy_auth=("my_username123", "my_pass321"), + ).auth[1], + "my_pass321", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_url.py b/contrib/python/websocket-client/websocket/tests/test_url.py new file mode 100644 index 0000000000..110fdfad70 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_url.py @@ -0,0 +1,464 @@ +# -*- coding: utf-8 -*- +# +import os +import unittest + +from websocket._url import ( + _is_address_in_network, + _is_no_proxy_host, + get_proxy_info, + parse_url, +) +from websocket._exceptions import WebSocketProxyException + +""" +test_url.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +class UrlTest(unittest.TestCase): + def test_address_in_network(self): + self.assertTrue(_is_address_in_network("127.0.0.1", "127.0.0.0/8")) + self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8")) + self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24")) + + def test_parse_url(self): + p = parse_url("ws://www.example.com/r") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com/r/") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r/") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com/") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com:8080/r") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com:8080/") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com:8080") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) + + p = parse_url("wss://www.example.com:8080/r") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) + + p = parse_url("wss://www.example.com:8080/r?key=value") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r?key=value") + self.assertEqual(p[3], True) + + self.assertRaises(ValueError, parse_url, "http://www.example.com/r") + + p = parse_url("ws://[2a03:4000:123:83::3]/r") + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) + + p = parse_url("ws://[2a03:4000:123:83::3]:8080/r") + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) + + p = parse_url("wss://[2a03:4000:123:83::3]/r") + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 443) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) + + p = parse_url("wss://[2a03:4000:123:83::3]:8080/r") + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) + + +class IsNoProxyHostTest(unittest.TestCase): + def setUp(self): + self.no_proxy = os.environ.get("no_proxy", None) + if "no_proxy" in os.environ: + del os.environ["no_proxy"] + + def tearDown(self): + if self.no_proxy: + os.environ["no_proxy"] = self.no_proxy + elif "no_proxy" in os.environ: + del os.environ["no_proxy"] + + def test_match_all(self): + self.assertTrue(_is_no_proxy_host("any.websocket.org", ["*"])) + self.assertTrue(_is_no_proxy_host("192.168.0.1", ["*"])) + self.assertFalse(_is_no_proxy_host("192.168.0.1", ["192.168.1.1"])) + self.assertFalse( + _is_no_proxy_host("any.websocket.org", ["other.websocket.org"]) + ) + self.assertTrue( + _is_no_proxy_host("any.websocket.org", ["other.websocket.org", "*"]) + ) + os.environ["no_proxy"] = "*" + self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) + self.assertTrue(_is_no_proxy_host("192.168.0.1", None)) + os.environ["no_proxy"] = "other.websocket.org, *" + self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) + + def test_ip_address(self): + self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.1"])) + self.assertFalse(_is_no_proxy_host("127.0.0.2", ["127.0.0.1"])) + self.assertTrue( + _is_no_proxy_host("127.0.0.1", ["other.websocket.org", "127.0.0.1"]) + ) + self.assertFalse( + _is_no_proxy_host("127.0.0.2", ["other.websocket.org", "127.0.0.1"]) + ) + os.environ["no_proxy"] = "127.0.0.1" + self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) + self.assertFalse(_is_no_proxy_host("127.0.0.2", None)) + os.environ["no_proxy"] = "other.websocket.org, 127.0.0.1" + self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) + self.assertFalse(_is_no_proxy_host("127.0.0.2", None)) + + def test_ip_address_in_range(self): + self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"])) + self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"])) + self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"])) + os.environ["no_proxy"] = "127.0.0.0/8" + self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) + self.assertTrue(_is_no_proxy_host("127.0.0.2", None)) + os.environ["no_proxy"] = "127.0.0.0/24" + self.assertFalse(_is_no_proxy_host("127.1.0.1", None)) + + def test_hostname_match(self): + self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"])) + self.assertTrue( + _is_no_proxy_host( + "my.websocket.org", ["other.websocket.org", "my.websocket.org"] + ) + ) + self.assertFalse(_is_no_proxy_host("my.websocket.org", ["other.websocket.org"])) + os.environ["no_proxy"] = "my.websocket.org" + self.assertTrue(_is_no_proxy_host("my.websocket.org", None)) + self.assertFalse(_is_no_proxy_host("other.websocket.org", None)) + os.environ["no_proxy"] = "other.websocket.org, my.websocket.org" + self.assertTrue(_is_no_proxy_host("my.websocket.org", None)) + + def test_hostname_match_domain(self): + self.assertTrue(_is_no_proxy_host("any.websocket.org", [".websocket.org"])) + self.assertTrue(_is_no_proxy_host("my.other.websocket.org", [".websocket.org"])) + self.assertTrue( + _is_no_proxy_host( + "any.websocket.org", ["my.websocket.org", ".websocket.org"] + ) + ) + self.assertFalse(_is_no_proxy_host("any.websocket.com", [".websocket.org"])) + os.environ["no_proxy"] = ".websocket.org" + self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) + self.assertTrue(_is_no_proxy_host("my.other.websocket.org", None)) + self.assertFalse(_is_no_proxy_host("any.websocket.com", None)) + os.environ["no_proxy"] = "my.websocket.org, .websocket.org" + self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) + + +class ProxyInfoTest(unittest.TestCase): + def setUp(self): + self.http_proxy = os.environ.get("http_proxy", None) + self.https_proxy = os.environ.get("https_proxy", None) + self.no_proxy = os.environ.get("no_proxy", None) + if "http_proxy" in os.environ: + del os.environ["http_proxy"] + if "https_proxy" in os.environ: + del os.environ["https_proxy"] + if "no_proxy" in os.environ: + del os.environ["no_proxy"] + + def tearDown(self): + if self.http_proxy: + os.environ["http_proxy"] = self.http_proxy + elif "http_proxy" in os.environ: + del os.environ["http_proxy"] + + if self.https_proxy: + os.environ["https_proxy"] = self.https_proxy + elif "https_proxy" in os.environ: + del os.environ["https_proxy"] + + if self.no_proxy: + os.environ["no_proxy"] = self.no_proxy + elif "no_proxy" in os.environ: + del os.environ["no_proxy"] + + def test_proxy_from_args(self): + self.assertRaises( + WebSocketProxyException, + get_proxy_info, + "echo.websocket.events", + False, + proxy_host="localhost", + ) + self.assertEqual( + get_proxy_info( + "echo.websocket.events", False, proxy_host="localhost", proxy_port=3128 + ), + ("localhost", 3128, None), + ) + self.assertEqual( + get_proxy_info( + "echo.websocket.events", True, proxy_host="localhost", proxy_port=3128 + ), + ("localhost", 3128, None), + ) + + self.assertEqual( + get_proxy_info( + "echo.websocket.events", + False, + proxy_host="localhost", + proxy_port=9001, + proxy_auth=("a", "b"), + ), + ("localhost", 9001, ("a", "b")), + ) + self.assertEqual( + get_proxy_info( + "echo.websocket.events", + False, + proxy_host="localhost", + proxy_port=3128, + proxy_auth=("a", "b"), + ), + ("localhost", 3128, ("a", "b")), + ) + self.assertEqual( + get_proxy_info( + "echo.websocket.events", + True, + proxy_host="localhost", + proxy_port=8765, + proxy_auth=("a", "b"), + ), + ("localhost", 8765, ("a", "b")), + ) + self.assertEqual( + get_proxy_info( + "echo.websocket.events", + True, + proxy_host="localhost", + proxy_port=3128, + proxy_auth=("a", "b"), + ), + ("localhost", 3128, ("a", "b")), + ) + + self.assertEqual( + get_proxy_info( + "echo.websocket.events", + True, + proxy_host="localhost", + proxy_port=3128, + no_proxy=["example.com"], + proxy_auth=("a", "b"), + ), + ("localhost", 3128, ("a", "b")), + ) + self.assertEqual( + get_proxy_info( + "echo.websocket.events", + True, + proxy_host="localhost", + proxy_port=3128, + no_proxy=["echo.websocket.events"], + proxy_auth=("a", "b"), + ), + (None, 0, None), + ) + + self.assertEqual( + get_proxy_info( + "echo.websocket.events", + True, + proxy_host="localhost", + proxy_port=3128, + no_proxy=[".websocket.events"], + ), + (None, 0, None), + ) + + def test_proxy_from_env(self): + os.environ["http_proxy"] = "http://localhost/" + self.assertEqual( + get_proxy_info("echo.websocket.events", False), ("localhost", None, None) + ) + os.environ["http_proxy"] = "http://localhost:3128/" + self.assertEqual( + get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None) + ) + + os.environ["http_proxy"] = "http://localhost/" + os.environ["https_proxy"] = "http://localhost2/" + self.assertEqual( + get_proxy_info("echo.websocket.events", False), ("localhost", None, None) + ) + os.environ["http_proxy"] = "http://localhost:3128/" + os.environ["https_proxy"] = "http://localhost2:3128/" + self.assertEqual( + get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None) + ) + + os.environ["http_proxy"] = "http://localhost/" + os.environ["https_proxy"] = "http://localhost2/" + self.assertEqual( + get_proxy_info("echo.websocket.events", True), ("localhost2", None, None) + ) + os.environ["http_proxy"] = "http://localhost:3128/" + os.environ["https_proxy"] = "http://localhost2:3128/" + self.assertEqual( + get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, None) + ) + + os.environ["http_proxy"] = "" + os.environ["https_proxy"] = "http://localhost2/" + self.assertEqual( + get_proxy_info("echo.websocket.events", True), ("localhost2", None, None) + ) + self.assertEqual( + get_proxy_info("echo.websocket.events", False), (None, 0, None) + ) + os.environ["http_proxy"] = "" + os.environ["https_proxy"] = "http://localhost2:3128/" + self.assertEqual( + get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, None) + ) + self.assertEqual( + get_proxy_info("echo.websocket.events", False), (None, 0, None) + ) + + os.environ["http_proxy"] = "http://localhost/" + os.environ["https_proxy"] = "" + self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None)) + self.assertEqual( + get_proxy_info("echo.websocket.events", False), ("localhost", None, None) + ) + os.environ["http_proxy"] = "http://localhost:3128/" + os.environ["https_proxy"] = "" + self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None)) + self.assertEqual( + get_proxy_info("echo.websocket.events", False), ("localhost", 3128, None) + ) + + os.environ["http_proxy"] = "http://a:b@localhost/" + self.assertEqual( + get_proxy_info("echo.websocket.events", False), + ("localhost", None, ("a", "b")), + ) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + self.assertEqual( + get_proxy_info("echo.websocket.events", False), + ("localhost", 3128, ("a", "b")), + ) + + os.environ["http_proxy"] = "http://a:b@localhost/" + os.environ["https_proxy"] = "http://a:b@localhost2/" + self.assertEqual( + get_proxy_info("echo.websocket.events", False), + ("localhost", None, ("a", "b")), + ) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + self.assertEqual( + get_proxy_info("echo.websocket.events", False), + ("localhost", 3128, ("a", "b")), + ) + + os.environ["http_proxy"] = "http://a:b@localhost/" + os.environ["https_proxy"] = "http://a:b@localhost2/" + self.assertEqual( + get_proxy_info("echo.websocket.events", True), + ("localhost2", None, ("a", "b")), + ) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + self.assertEqual( + get_proxy_info("echo.websocket.events", True), + ("localhost2", 3128, ("a", "b")), + ) + + os.environ[ + "http_proxy" + ] = "http://john%40example.com:P%40SSWORD@localhost:3128/" + os.environ[ + "https_proxy" + ] = "http://john%40example.com:P%40SSWORD@localhost2:3128/" + self.assertEqual( + get_proxy_info("echo.websocket.events", True), + ("localhost2", 3128, ("john@example.com", "P@SSWORD")), + ) + + os.environ["http_proxy"] = "http://a:b@localhost/" + os.environ["https_proxy"] = "http://a:b@localhost2/" + os.environ["no_proxy"] = "example1.com,example2.com" + self.assertEqual( + get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b")) + ) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.events" + self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None)) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + os.environ["no_proxy"] = "example1.com,example2.com, .websocket.events" + self.assertEqual(get_proxy_info("echo.websocket.events", True), (None, 0, None)) + + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16" + self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None)) + self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None)) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_websocket.py b/contrib/python/websocket-client/websocket/tests/test_websocket.py new file mode 100644 index 0000000000..892312a2db --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_websocket.py @@ -0,0 +1,498 @@ +# -*- coding: utf-8 -*- +# +import os +import os.path +import socket +import unittest +from base64 import decodebytes as base64decode + +import websocket as ws +from websocket._exceptions import WebSocketBadStatusException, WebSocketAddressException +from websocket._handshake import _create_sec_websocket_key +from websocket._handshake import _validate as _validate_header +from websocket._http import read_headers +from websocket._utils import validate_utf8 + +""" +test_websocket.py +websocket - WebSocket client library for Python + +Copyright 2024 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +try: + import ssl +except ImportError: + # dummy class of SSLError for ssl none-support environment. + class SSLError(Exception): + pass + + +# Skip test to access the internet unless TEST_WITH_INTERNET == 1 +TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1" +# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1 +LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1") +TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1" +TRACEABLE = True + + +def create_mask_key(_): + return "abcd" + + +class SockMock: + def __init__(self): + self.data = [] + self.sent = [] + + def add_packet(self, data): + self.data.append(data) + + def gettimeout(self): + return None + + def recv(self, bufsize): + if self.data: + e = self.data.pop(0) + if isinstance(e, Exception): + raise e + if len(e) > bufsize: + self.data.insert(0, e[bufsize:]) + return e[:bufsize] + + def send(self, data): + self.sent.append(data) + return len(data) + + def close(self): + pass + + +class HeaderSockMock(SockMock): + def __init__(self, fname): + SockMock.__init__(self) + import yatest.common as yc + path = os.path.join(os.path.dirname(yc.source_path(__file__)), fname) + with open(path, "rb") as f: + self.add_packet(f.read()) + + +class WebSocketTest(unittest.TestCase): + def setUp(self): + ws.enableTrace(TRACEABLE) + + def tearDown(self): + pass + + def test_default_timeout(self): + self.assertEqual(ws.getdefaulttimeout(), None) + ws.setdefaulttimeout(10) + self.assertEqual(ws.getdefaulttimeout(), 10) + ws.setdefaulttimeout(None) + + def test_ws_key(self): + key = _create_sec_websocket_key() + self.assertTrue(key != 24) + self.assertTrue("¥n" not in key) + + def test_nonce(self): + """WebSocket key should be a random 16-byte nonce.""" + key = _create_sec_websocket_key() + nonce = base64decode(key.encode("utf-8")) + self.assertEqual(16, len(nonce)) + + def test_ws_utils(self): + key = "c6b8hTg4EeGb2gQMztV1/g==" + required_header = { + "upgrade": "websocket", + "connection": "upgrade", + "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=", + } + self.assertEqual(_validate_header(required_header, key, None), (True, None)) + + header = required_header.copy() + header["upgrade"] = "http" + self.assertEqual(_validate_header(header, key, None), (False, None)) + del header["upgrade"] + self.assertEqual(_validate_header(header, key, None), (False, None)) + + header = required_header.copy() + header["connection"] = "something" + self.assertEqual(_validate_header(header, key, None), (False, None)) + del header["connection"] + self.assertEqual(_validate_header(header, key, None), (False, None)) + + header = required_header.copy() + header["sec-websocket-accept"] = "something" + self.assertEqual(_validate_header(header, key, None), (False, None)) + del header["sec-websocket-accept"] + self.assertEqual(_validate_header(header, key, None), (False, None)) + + header = required_header.copy() + header["sec-websocket-protocol"] = "sub1" + self.assertEqual( + _validate_header(header, key, ["sub1", "sub2"]), (True, "sub1") + ) + # This case will print out a logging error using the error() function, but that is expected + self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None)) + + header = required_header.copy() + header["sec-websocket-protocol"] = "sUb1" + self.assertEqual( + _validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1") + ) + + header = required_header.copy() + # This case will print out a logging error using the error() function, but that is expected + self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None)) + + def test_read_header(self): + status, header, _ = read_headers(HeaderSockMock("data/header01.txt")) + self.assertEqual(status, 101) + self.assertEqual(header["connection"], "Upgrade") + + status, header, _ = read_headers(HeaderSockMock("data/header03.txt")) + self.assertEqual(status, 101) + self.assertEqual(header["connection"], "Upgrade, Keep-Alive") + + HeaderSockMock("data/header02.txt") + self.assertRaises( + ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt") + ) + + def test_send(self): + # TODO: add longer frame data + sock = ws.WebSocket() + sock.set_mask_key(create_mask_key) + s = sock.sock = HeaderSockMock("data/header01.txt") + sock.send("Hello") + self.assertEqual(s.sent[0], b"\x81\x85abcd)\x07\x0f\x08\x0e") + + sock.send("こんにちは") + self.assertEqual( + s.sent[1], + b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc", + ) + + # sock.send("x" * 5000) + # self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc") + + self.assertEqual(sock.send_binary(b"1111111111101"), 19) + + def test_recv(self): + # TODO: add longer frame data + sock = ws.WebSocket() + s = sock.sock = SockMock() + something = ( + b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc" + ) + s.add_packet(something) + data = sock.recv() + self.assertEqual(data, "こんにちは") + + s.add_packet(b"\x81\x85abcd)\x07\x0f\x08\x0e") + data = sock.recv() + self.assertEqual(data, "Hello") + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_iter(self): + count = 2 + s = ws.create_connection("wss://api.bitfinex.com/ws/2") + s.send('{"event": "subscribe", "channel": "ticker"}') + for _ in s: + count -= 1 + if count == 0: + break + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_next(self): + sock = ws.create_connection("wss://api.bitfinex.com/ws/2") + self.assertEqual(str, type(next(sock))) + + def test_internal_recv_strict(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + s.add_packet(b"foo") + s.add_packet(socket.timeout()) + s.add_packet(b"bar") + # s.add_packet(SSLError("The read operation timed out")) + s.add_packet(b"baz") + with self.assertRaises(ws.WebSocketTimeoutException): + sock.frame_buffer.recv_strict(9) + # with self.assertRaises(SSLError): + # data = sock._recv_strict(9) + data = sock.frame_buffer.recv_strict(9) + self.assertEqual(data, b"foobarbaz") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.frame_buffer.recv_strict(1) + + def test_recv_timeout(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + s.add_packet(b"\x81") + s.add_packet(socket.timeout()) + s.add_packet(b"\x8dabcd\x29\x07\x0f\x08\x0e") + s.add_packet(socket.timeout()) + s.add_packet(b"\x4e\x43\x33\x0e\x10\x0f\x00\x40") + with self.assertRaises(ws.WebSocketTimeoutException): + sock.recv() + with self.assertRaises(ws.WebSocketTimeoutException): + sock.recv() + data = sock.recv() + self.assertEqual(data, "Hello, World!") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + + def test_recv_with_simple_fragmentation(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + # OPCODE=TEXT, FIN=0, MSG="Brevity is " + s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C") + # OPCODE=CONT, FIN=1, MSG="the soul of wit" + s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17") + data = sock.recv() + self.assertEqual(data, "Brevity is the soul of wit") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + + def test_recv_with_fire_event_of_fragmentation(self): + sock = ws.WebSocket(fire_cont_frame=True) + s = sock.sock = SockMock() + # OPCODE=TEXT, FIN=0, MSG="Brevity is " + s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C") + # OPCODE=CONT, FIN=0, MSG="Brevity is " + s.add_packet(b"\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C") + # OPCODE=CONT, FIN=1, MSG="the soul of wit" + s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17") + + _, data = sock.recv_data() + self.assertEqual(data, b"Brevity is ") + _, data = sock.recv_data() + self.assertEqual(data, b"Brevity is ") + _, data = sock.recv_data() + self.assertEqual(data, b"the soul of wit") + + # OPCODE=CONT, FIN=0, MSG="Brevity is " + s.add_packet(b"\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C") + + with self.assertRaises(ws.WebSocketException): + sock.recv_data() + + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + + def test_close(self): + sock = ws.WebSocket() + sock.connected = True + sock.close + + sock = ws.WebSocket() + s = sock.sock = SockMock() + sock.connected = True + s.add_packet(b"\x88\x80\x17\x98p\x84") + sock.recv() + self.assertEqual(sock.connected, False) + + def test_recv_cont_fragmentation(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + # OPCODE=CONT, FIN=1, MSG="the soul of wit" + s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17") + self.assertRaises(ws.WebSocketException, sock.recv) + + def test_recv_with_prolonged_fragmentation(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, " + s.add_packet( + b"\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC" + ) + # OPCODE=CONT, FIN=0, MSG="dear friends, " + s.add_packet(b"\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB") + # OPCODE=CONT, FIN=1, MSG="once more" + s.add_packet(b"\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04") + data = sock.recv() + self.assertEqual(data, "Once more unto the breach, dear friends, once more") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + + def test_recv_with_fragmentation_and_control_frame(self): + sock = ws.WebSocket() + sock.set_mask_key(create_mask_key) + s = sock.sock = SockMock() + # OPCODE=TEXT, FIN=0, MSG="Too much " + s.add_packet(b"\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA") + # OPCODE=PING, FIN=1, MSG="Please PONG this" + s.add_packet(b"\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17") + # OPCODE=CONT, FIN=1, MSG="of a good thing" + s.add_packet(b"\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04") + data = sock.recv() + self.assertEqual(data, "Too much of a good thing") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + self.assertEqual( + s.sent[0], b"\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17" + ) + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_websocket(self): + s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}") + self.assertNotEqual(s, None) + s.send("Hello, World") + result = s.next() + s.fileno() + self.assertEqual(result, "Hello, World") + + s.send("こにゃにゃちは、世界") + result = s.recv() + self.assertEqual(result, "こにゃにゃちは、世界") + self.assertRaises(ValueError, s.send_close, -1, "") + s.close() + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_ping_pong(self): + s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}") + self.assertNotEqual(s, None) + s.ping("Hello") + s.pong("Hi") + s.close() + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_support_redirect(self): + s = ws.WebSocket() + self.assertRaises(WebSocketBadStatusException, s.connect, "ws://google.com/") + # Need to find a URL that has a redirect code leading to a websocket + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_secure_websocket(self): + s = ws.create_connection("wss://api.bitfinex.com/ws/2") + self.assertNotEqual(s, None) + self.assertTrue(isinstance(s.sock, ssl.SSLSocket)) + self.assertEqual(s.getstatus(), 101) + self.assertNotEqual(s.getheaders(), None) + s.settimeout(10) + self.assertEqual(s.gettimeout(), 10) + self.assertEqual(s.getsubprotocol(), None) + s.abort() + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_websocket_with_custom_header(self): + s = ws.create_connection( + f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", + headers={"User-Agent": "PythonWebsocketClient"}, + ) + self.assertNotEqual(s, None) + self.assertEqual(s.getsubprotocol(), None) + s.send("Hello, World") + result = s.recv() + self.assertEqual(result, "Hello, World") + self.assertRaises(ValueError, s.close, -1, "") + s.close() + + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_after_close(self): + s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}") + self.assertNotEqual(s, None) + s.close() + self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello") + self.assertRaises(ws.WebSocketConnectionClosedException, s.recv) + + +class SockOptTest(unittest.TestCase): + @unittest.skipUnless( + TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" + ) + def test_sockopt(self): + sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),) + s = ws.create_connection( + f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", sockopt=sockopt + ) + self.assertNotEqual( + s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0 + ) + s.close() + + +class UtilsTest(unittest.TestCase): + def test_utf8_validator(self): + state = validate_utf8(b"\xf0\x90\x80\x80") + self.assertEqual(state, True) + state = validate_utf8( + b"\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited" + ) + self.assertEqual(state, False) + state = validate_utf8(b"") + self.assertEqual(state, True) + + +class HandshakeTest(unittest.TestCase): + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_http_ssl(self): + websock1 = ws.WebSocket( + sslopt={"cert_chain": ssl.get_default_verify_paths().capath}, + enable_multithread=False, + ) + self.assertRaises(ValueError, websock1.connect, "wss://api.bitfinex.com/ws/2") + websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"}) + self.assertRaises( + FileNotFoundError, websock2.connect, "wss://api.bitfinex.com/ws/2" + ) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def test_manual_headers(self): + websock3 = ws.WebSocket( + sslopt={ + "ca_certs": ssl.get_default_verify_paths().cafile, + "ca_cert_path": ssl.get_default_verify_paths().capath, + } + ) + self.assertRaises( + WebSocketBadStatusException, + websock3.connect, + "wss://api.bitfinex.com/ws/2", + cookie="chocolate", + origin="testing_websockets.com", + host="echo.websocket.events/websocket-client-test", + subprotocols=["testproto"], + connection="Upgrade", + header={ + "CustomHeader1": "123", + "Cookie": "TestValue", + "Sec-WebSocket-Key": "k9kFAUWNAMmf5OEMfTlOEA==", + "Sec-WebSocket-Protocol": "newprotocol", + }, + ) + + def test_ipv6(self): + websock2 = ws.WebSocket() + self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888") + + def test_bad_urls(self): + websock3 = ws.WebSocket() + self.assertRaises(ValueError, websock3.connect, "ws//example.com") + self.assertRaises(WebSocketAddressException, websock3.connect, "ws://example") + self.assertRaises(ValueError, websock3.connect, "example.com") + + +if __name__ == "__main__": + unittest.main() |