diff options
Diffstat (limited to 'contrib/python')
33 files changed, 2790 insertions, 284 deletions
diff --git a/contrib/python/websocket-client/.dist-info/METADATA b/contrib/python/websocket-client/.dist-info/METADATA index 563e5c05268..dfb92c15dc5 100644 --- a/contrib/python/websocket-client/.dist-info/METADATA +++ b/contrib/python/websocket-client/.dist-info/METADATA @@ -1,6 +1,6 @@ -Metadata-Version: 2.1 +Metadata-Version: 2.4 Name: websocket-client -Version: 1.8.0 +Version: 1.9.0 Summary: WebSocket client for Python with low level API options Home-page: https://github.com/websocket-client/websocket-client.git Download-URL: https://github.com/websocket-client/websocket-client/releases @@ -15,29 +15,46 @@ Keywords: websockets client Classifier: Development Status :: 4 - Beta Classifier: License :: OSI Approved :: Apache Software License Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 Classifier: Operating System :: MacOS :: MacOS X Classifier: Operating System :: POSIX Classifier: Operating System :: Microsoft :: Windows Classifier: Topic :: Internet Classifier: Topic :: Software Development :: Libraries :: Python Modules Classifier: Intended Audience :: Developers -Requires-Python: >=3.8 +Requires-Python: >=3.9 Description-Content-Type: text/markdown License-File: LICENSE -Provides-Extra: docs -Requires-Dist: Sphinx >=6.0 ; extra == 'docs' -Requires-Dist: sphinx-rtd-theme >=1.1.0 ; extra == 'docs' -Requires-Dist: myst-parser >=2.0.0 ; extra == 'docs' -Provides-Extra: optional -Requires-Dist: python-socks ; extra == 'optional' -Requires-Dist: wsaccel ; extra == 'optional' Provides-Extra: test -Requires-Dist: websockets ; extra == 'test' +Requires-Dist: pytest; extra == "test" +Requires-Dist: websockets; extra == "test" +Provides-Extra: optional +Requires-Dist: python-socks; extra == "optional" +Requires-Dist: wsaccel; extra == "optional" +Provides-Extra: docs +Requires-Dist: Sphinx>=6.0; extra == "docs" +Requires-Dist: sphinx_rtd_theme>=1.1.0; extra == "docs" +Requires-Dist: myst-parser>=2.0.0; extra == "docs" +Dynamic: author +Dynamic: author-email +Dynamic: classifier +Dynamic: description +Dynamic: description-content-type +Dynamic: download-url +Dynamic: home-page +Dynamic: keywords +Dynamic: license +Dynamic: license-file +Dynamic: maintainer +Dynamic: maintainer-email +Dynamic: project-url +Dynamic: provides-extra +Dynamic: requires-python +Dynamic: summary [](https://websocket-client.readthedocs.io/) [](https://github.com/websocket-client/websocket-client/actions/workflows/build.yml) @@ -68,7 +85,7 @@ Please see the [contribution guidelines](https://github.com/websocket-client/web ## Installation You can use `pip install websocket-client` to install, or `pip install -e .` -to install from a local copy of the code. This module is tested on Python 3.8+. +to install from a local copy of the code. This module is tested on Python 3.9+. There are several optional dependencies that can be installed to enable specific websocket-client features. diff --git a/contrib/python/websocket-client/LICENSE b/contrib/python/websocket-client/LICENSE index 62a54ca4991..9ef2a65dc55 100644 --- a/contrib/python/websocket-client/LICENSE +++ b/contrib/python/websocket-client/LICENSE @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2024 engn33r + Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/contrib/python/websocket-client/README.md b/contrib/python/websocket-client/README.md index 14b6b506cea..860d63cb75d 100644 --- a/contrib/python/websocket-client/README.md +++ b/contrib/python/websocket-client/README.md @@ -27,7 +27,7 @@ Please see the [contribution guidelines](https://github.com/websocket-client/web ## Installation You can use `pip install websocket-client` to install, or `pip install -e .` -to install from a local copy of the code. This module is tested on Python 3.8+. +to install from a local copy of the code. This module is tested on Python 3.9+. There are several optional dependencies that can be installed to enable specific websocket-client features. diff --git a/contrib/python/websocket-client/websocket/__init__.py b/contrib/python/websocket-client/websocket/__init__.py index 559b38a6b7d..05f565d4c99 100644 --- a/contrib/python/websocket-client/websocket/__init__.py +++ b/contrib/python/websocket-client/websocket/__init__.py @@ -2,7 +2,7 @@ __init__.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ -from ._abnf import * -from ._app import WebSocketApp as WebSocketApp, setReconnect as setReconnect -from ._core import * -from ._exceptions import * -from ._logging import * -from ._socket import * -__version__ = "1.8.0" +from ._abnf import * # noqa: F401,F403 +from ._app import ( # noqa: F401 + WebSocketApp as WebSocketApp, + set_reconnect as set_reconnect, +) +from ._core import * # noqa: F401,F403 +from ._exceptions import * # noqa: F401,F403 +from ._logging import * # noqa: F401,F403 +from ._socket import * # noqa: F401,F403 + +__version__ = "1.9.0" diff --git a/contrib/python/websocket-client/websocket/_abnf.py b/contrib/python/websocket-client/websocket/_abnf.py index d7754e0de2e..31a8d7ed8af 100644 --- a/contrib/python/websocket-client/websocket/_abnf.py +++ b/contrib/python/websocket-client/websocket/_abnf.py @@ -3,7 +3,7 @@ import os import struct import sys from threading import Lock -from typing import Callable, Optional, Union +from typing import Callable, Optional, Union, Any from ._exceptions import WebSocketPayloadException, WebSocketProtocolException from ._utils import validate_utf8 @@ -12,7 +12,7 @@ from ._utils import validate_utf8 _abnf.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -151,7 +151,7 @@ class ABNF: rsv3: int = 0, opcode: int = OPCODE_TEXT, mask_value: int = 1, - data: Union[str, bytes, None] = "", + data: Optional[Union[str, bytes]] = "", ) -> None: """ Constructor for ABNF. Please check RFC for arguments. @@ -185,15 +185,24 @@ class ABNF: raise WebSocketProtocolException("Invalid ping frame.") if self.opcode == ABNF.OPCODE_CLOSE: - l = len(self.data) - if not l: + data_length = len(self.data) + if not data_length: return - if l == 1 or l >= 126: + if data_length == 1 or data_length >= 126: raise WebSocketProtocolException("Invalid close frame.") - if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): + if ( + data_length > 2 + and not skip_utf8_validation + and not validate_utf8(self.data[2:]) + ): raise WebSocketProtocolException("Invalid close frame.") - code = 256 * int(self.data[0]) + int(self.data[1]) + data_bytes = ( + self.data[:2] + if isinstance(self.data, bytes) + else self.data[:2].encode("utf-8") + ) + code = struct.unpack("!H", data_bytes)[0] if not self._is_valid_close_status(code): raise WebSocketProtocolException("Invalid close opcode %r", code) @@ -202,7 +211,8 @@ class ABNF: return code in VALID_CLOSE_STATUS or (3000 <= code < 5000) def __str__(self) -> str: - return f"fin={self.fin} opcode={self.opcode} data={self.data}" + data_repr = self.data if isinstance(self.data, str) else repr(self.data) + return f"fin={self.fin} opcode={self.opcode} data={data_repr}" @staticmethod def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF": @@ -310,9 +320,9 @@ class frame_buffer: def clear(self) -> None: self.header: Optional[tuple] = None self.length: Optional[int] = None - self.mask_value: Union[bytes, str, None] = None + self.mask_value: Optional[Union[bytes, str]] = None - def has_received_header(self) -> bool: + def needs_header(self) -> bool: return self.header is None def recv_header(self) -> None: @@ -335,10 +345,12 @@ class frame_buffer: header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX] return header_val - def has_received_length(self) -> bool: + def needs_length(self) -> bool: return self.length is None def recv_length(self) -> None: + if self.header is None: + raise WebSocketProtocolException("Header not received") bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] length_bits = bits & 0x7F if length_bits == 0x7E: @@ -350,7 +362,7 @@ class frame_buffer: else: self.length = length_bits - def has_received_mask(self) -> bool: + def needs_mask(self) -> bool: return self.mask_value is None def recv_mask(self) -> None: @@ -359,23 +371,29 @@ class frame_buffer: def recv_frame(self) -> ABNF: with self.lock: # Header - if self.has_received_header(): + if self.needs_header(): self.recv_header() + if self.header is None: + raise WebSocketProtocolException("Header not received") (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header # Frame length - if self.has_received_length(): + if self.needs_length(): self.recv_length() length = self.length # Mask - if self.has_received_mask(): + if self.needs_mask(): self.recv_mask() mask_value = self.mask_value # Payload + if length is None: + raise WebSocketProtocolException("Length not received") payload = self.recv_strict(length) if has_mask: + if mask_value is None: + raise WebSocketProtocolException("Mask not received") payload = ABNF.mask(mask_value, payload) # Reset for next frame @@ -387,7 +405,9 @@ class frame_buffer: return frame def recv_strict(self, bufsize: int) -> bytes: - shortage = bufsize - sum(map(len, self.recv_buffer)) + if not isinstance(bufsize, int): + raise ValueError("bufsize must be an integer") + shortage = bufsize - sum(len(buf) for buf in self.recv_buffer) while shortage > 0: # Limit buffer size that we pass to socket.recv() to avoid # fragmenting the heap -- the number of bytes recv() actually @@ -396,8 +416,12 @@ class frame_buffer: # buffers allocated and then shrunk, which results in # fragmentation. bytes_ = self.recv(min(16384, shortage)) - self.recv_buffer.append(bytes_) - shortage -= len(bytes_) + if isinstance(bytes_, bytes): + self.recv_buffer.append(bytes_) + shortage -= len(bytes_) + else: + # Handle case where recv returns int or other type + break unified = b"".join(self.recv_buffer) @@ -413,7 +437,7 @@ class continuous_frame: def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None: self.fire_cont_frame = fire_cont_frame self.skip_utf8_validation = skip_utf8_validation - self.cont_data: Optional[list] = None + self.cont_data: Optional[list[Any]] = None self.recving_frames: Optional[int] = None def validate(self, frame: ABNF) -> None: @@ -441,13 +465,18 @@ class continuous_frame: def extract(self, frame: ABNF) -> tuple: data = self.cont_data + if data is None: + raise WebSocketProtocolException("No continuation data available") self.cont_data = None frame.data = data[1] if ( not self.fire_cont_frame + and data is not None and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data) ): raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}") + if data is None: + raise WebSocketProtocolException("No continuation data available") return data[0], frame diff --git a/contrib/python/websocket-client/websocket/_app.py b/contrib/python/websocket-client/websocket/_app.py index 9fee76546b2..2e125eaeba6 100644 --- a/contrib/python/websocket-client/websocket/_app.py +++ b/contrib/python/websocket-client/websocket/_app.py @@ -1,5 +1,4 @@ import inspect -import selectors import socket import threading import time @@ -15,12 +14,13 @@ from ._exceptions import ( ) from ._ssl_compat import SSLEOFError from ._url import parse_url +from ._dispatcher import Dispatcher, DispatcherBase, SSLDispatcher, WrappedDispatcher """ _app.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,122 +40,11 @@ __all__ = ["WebSocketApp"] RECONNECT = 0 -def setReconnect(reconnectInterval: int) -> None: +def set_reconnect(reconnectInterval: int) -> None: global RECONNECT RECONNECT = reconnectInterval -class DispatcherBase: - """ - DispatcherBase - """ - - def __init__(self, app: Any, ping_timeout: Union[float, int, None]) -> None: - self.app = app - self.ping_timeout = ping_timeout - - def timeout(self, seconds: Union[float, int, None], callback: Callable) -> None: - time.sleep(seconds) - callback() - - def reconnect(self, seconds: int, reconnector: Callable) -> None: - try: - _logging.info( - f"reconnect() - retrying in {seconds} seconds [{len(inspect.stack())} frames in stack]" - ) - time.sleep(seconds) - reconnector(reconnecting=True) - except KeyboardInterrupt as e: - _logging.info(f"User exited {e}") - raise e - - -class Dispatcher(DispatcherBase): - """ - Dispatcher - """ - - def read( - self, - sock: socket.socket, - read_callback: Callable, - check_callback: Callable, - ) -> None: - sel = selectors.DefaultSelector() - sel.register(self.app.sock.sock, selectors.EVENT_READ) - try: - while self.app.keep_running: - if sel.select(self.ping_timeout): - if not read_callback(): - break - check_callback() - finally: - sel.close() - - -class SSLDispatcher(DispatcherBase): - """ - SSLDispatcher - """ - - def read( - self, - sock: socket.socket, - read_callback: Callable, - check_callback: Callable, - ) -> None: - sock = self.app.sock.sock - sel = selectors.DefaultSelector() - sel.register(sock, selectors.EVENT_READ) - try: - while self.app.keep_running: - if self.select(sock, sel): - if not read_callback(): - break - check_callback() - finally: - sel.close() - - def select(self, sock, sel: selectors.DefaultSelector): - sock = self.app.sock.sock - if sock.pending(): - return [ - sock, - ] - - r = sel.select(self.ping_timeout) - - if len(r) > 0: - return r[0][0] - - -class WrappedDispatcher: - """ - WrappedDispatcher - """ - - def __init__(self, app, ping_timeout: Union[float, int, None], dispatcher) -> None: - self.app = app - self.ping_timeout = ping_timeout - self.dispatcher = dispatcher - dispatcher.signal(2, dispatcher.abort) # keyboard interrupt - - def read( - self, - sock: socket.socket, - read_callback: Callable, - check_callback: Callable, - ) -> None: - self.dispatcher.read(sock, read_callback) - self.ping_timeout and self.timeout(self.ping_timeout, check_callback) - - def timeout(self, seconds: float, callback: Callable) -> None: - self.dispatcher.timeout(seconds, callback) - - def reconnect(self, seconds: int, reconnector: Callable) -> None: - self.timeout(seconds, reconnector) - - class WebSocketApp: """ Higher level of APIs are provided. The interface is like JavaScript WebSocket object. @@ -164,19 +53,25 @@ class WebSocketApp: def __init__( self, url: str, - header: Union[list, dict, Callable, None] = None, - on_open: Optional[Callable[[WebSocket], None]] = None, - on_reconnect: Optional[Callable[[WebSocket], None]] = None, - on_message: Optional[Callable[[WebSocket, Any], None]] = None, - on_error: Optional[Callable[[WebSocket, Any], None]] = None, - on_close: Optional[Callable[[WebSocket, Any, Any], None]] = None, + header: Optional[ + Union[ + list[str], + dict[str, str], + Callable[[], Union[list[str], dict[str, str]]], + ] + ] = None, + on_open: Optional[Callable[["WebSocketApp"], None]] = None, + on_reconnect: Optional[Callable[["WebSocketApp"], None]] = None, + on_message: Optional[Callable[["WebSocketApp", Any], None]] = None, + on_error: Optional[Callable[["WebSocketApp", Any], None]] = None, + on_close: Optional[Callable[["WebSocketApp", Any, Any], None]] = None, on_ping: Optional[Callable] = None, on_pong: Optional[Callable] = None, on_cont_message: Optional[Callable] = None, keep_running: bool = True, get_mask_key: Optional[Callable] = None, cookie: Optional[str] = None, - subprotocols: Optional[list] = None, + subprotocols: Optional[list[str]] = None, on_data: Optional[Callable] = None, socket: Optional[socket.socket] = None, ) -> None: @@ -266,7 +161,7 @@ class WebSocketApp: self.ping_thread: Optional[threading.Thread] = None self.stop_ping: Optional[threading.Event] = None self.ping_interval = float(0) - self.ping_timeout: Union[float, int, None] = None + self.ping_timeout: Optional[Union[float, int]] = None self.ping_payload = "" self.subprotocols = subprotocols self.prepared_socket = socket @@ -325,9 +220,25 @@ class WebSocketApp: self.stop_ping.set() if self.ping_thread and self.ping_thread.is_alive(): self.ping_thread.join(3) + # Handle thread leak - if thread doesn't terminate within timeout, + # force cleanup and log warning instead of abandoning the thread + if self.ping_thread.is_alive(): + _logging.warning( + "Ping thread failed to terminate within 3 seconds, " + "forcing cleanup. Thread may be blocked." + ) + # Force cleanup by clearing references even if thread is still alive + # The daemon thread will eventually be cleaned up by Python's GC + # but we prevent resource leaks by not holding references + + # Always clean up references regardless of thread state + self.ping_thread = None + self.stop_ping = None self.last_ping_tm = self.last_pong_tm = float(0) def _send_ping(self) -> None: + if self.stop_ping is None: + return if self.stop_ping.wait(self.ping_interval) or self.keep_running is False: return while not self.stop_ping.wait(self.ping_interval) and self.keep_running is True: @@ -339,12 +250,15 @@ class WebSocketApp: except Exception as e: _logging.debug(f"Failed to send ping: {e}") + def ready(self): + return self.sock and self.sock.connected + def run_forever( self, sockopt: tuple = None, sslopt: dict = None, ping_interval: Union[float, int] = 0, - ping_timeout: Union[float, int, None] = None, + ping_timeout: Optional[Union[float, int]] = None, ping_payload: str = "", http_proxy_host: str = None, http_proxy_port: Union[int, str] = None, @@ -454,17 +368,23 @@ class WebSocketApp: self._stop_ping_thread() self.keep_running = False + if self.sock: - self.sock.close() + # in cases like handleDisconnect, the "on_error" callback is called first. If the WebSocketApp + # is being used in a multithreaded application, we nee to make sure that "self.sock" is cleared + # before calling close, otherwise logic built around the sock being set can cause issues - + # specifically calling "run_forever" again, since is checks if "self.sock" is set. + current_sock = self.sock + self.sock = None + current_sock.close() + close_status_code, close_reason = self._get_close_args( close_frame if close_frame else None ) - self.sock = None - # Finally call the callback AFTER all teardown is complete self._callback(self.on_close, close_status_code, close_reason) - def setSock(reconnecting: bool = False) -> None: + def initialize_socket(reconnecting: bool = False) -> None: if reconnecting and self.sock: self.sock.shutdown() @@ -475,6 +395,7 @@ class WebSocketApp: fire_cont_frame=self.on_cont_message is not None, skip_utf8_validation=skip_utf8_validation, enable_multithread=True, + dispatcher=dispatcher, ) self.sock.settimeout(getdefaulttimeout()) @@ -520,7 +441,11 @@ class WebSocketApp: def read() -> bool: if not self.keep_running: - return teardown() + teardown() + return False + + if self.sock is None: + return False try: op_code, frame = self.sock.recv_data_frame(True) @@ -530,12 +455,12 @@ class WebSocketApp: SSLEOFError, ) as e: if custom_dispatcher: - return handleDisconnect(e, bool(reconnect)) + return closed(e) else: raise e if op_code == ABNF.OPCODE_CLOSE: - return teardown(frame) + return closed(frame) elif op_code == ABNF.OPCODE_PING: self._callback(self.on_ping, frame.data) elif op_code == ABNF.OPCODE_PONG: @@ -576,6 +501,20 @@ class WebSocketApp: raise WebSocketTimeoutException("ping/pong timed out") return True + def closed( + e: Union[ + WebSocketConnectionClosedException, + ConnectionRefusedError, + KeyboardInterrupt, + SystemExit, + Exception, + str, + ] = "closed unexpectedly", + ) -> bool: + if type(e) is str: + e = WebSocketConnectionClosedException(e) + return handleDisconnect(e, bool(reconnect)) # type: ignore[arg-type] + def handleDisconnect( e: Union[ WebSocketConnectionClosedException, @@ -602,24 +541,25 @@ class WebSocketApp: _logging.debug( f"Calling custom dispatcher reconnect [{len(inspect.stack())} frames in stack]" ) - dispatcher.reconnect(reconnect, setSock) + dispatcher.reconnect(reconnect, initialize_socket) else: _logging.error(f"{e} - goodbye") teardown() + return self.has_errored custom_dispatcher = bool(dispatcher) dispatcher = self.create_dispatcher( - ping_timeout, dispatcher, parse_url(self.url)[3] + ping_timeout, dispatcher, parse_url(self.url)[3], closed ) try: - setSock() + initialize_socket() if not custom_dispatcher and reconnect: while self.keep_running: _logging.debug( f"Calling dispatcher reconnect [{len(inspect.stack())} frames in stack]" ) - dispatcher.reconnect(reconnect, setSock) + dispatcher.reconnect(reconnect, initialize_socket) except (KeyboardInterrupt, Exception) as e: _logging.info(f"tearing down on exception {e}") teardown() @@ -632,12 +572,13 @@ class WebSocketApp: def create_dispatcher( self, - ping_timeout: Union[float, int, None], + ping_timeout: Optional[Union[float, int]], dispatcher: Optional[DispatcherBase] = None, is_ssl: bool = False, + handleDisconnect: Callable = None, ) -> Union[Dispatcher, SSLDispatcher, WrappedDispatcher]: if dispatcher: # If custom dispatcher is set, use WrappedDispatcher - return WrappedDispatcher(self, ping_timeout, dispatcher) + return WrappedDispatcher(self, ping_timeout, dispatcher, handleDisconnect) timeout = ping_timeout or 10 if is_ssl: return SSLDispatcher(self, timeout) @@ -673,5 +614,7 @@ class WebSocketApp: except Exception as e: _logging.error(f"error from callback {callback}: {e}") - if self.on_error: + # Bug fix: Prevent infinite recursion by not calling on_error + # when the failing callback IS on_error itself + if self.on_error and callback is not self.on_error: self.on_error(self, e) diff --git a/contrib/python/websocket-client/websocket/_cookiejar.py b/contrib/python/websocket-client/websocket/_cookiejar.py index 7480e5fc21c..0e5d4b169ad 100644 --- a/contrib/python/websocket-client/websocket/_cookiejar.py +++ b/contrib/python/websocket-client/websocket/_cookiejar.py @@ -5,7 +5,7 @@ from typing import Optional _cookiejar.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,11 +33,9 @@ class SimpleCookieJar: if domain := v.get("domain"): if not domain.startswith("."): domain = f".{domain}" - cookie = ( - self.jar.get(domain) - if self.jar.get(domain) - else http.cookies.SimpleCookie() - ) + cookie = self.jar.get(domain) + if cookie is None: + cookie = http.cookies.SimpleCookie() cookie.update(simple_cookie) self.jar[domain.lower()] = cookie diff --git a/contrib/python/websocket-client/websocket/_core.py b/contrib/python/websocket-client/websocket/_core.py index f940ed0573d..333015a7922 100644 --- a/contrib/python/websocket-client/websocket/_core.py +++ b/contrib/python/websocket-client/websocket/_core.py @@ -6,19 +6,24 @@ from typing import Optional, Union # websocket modules from ._abnf import ABNF, STATUS_NORMAL, continuous_frame, frame_buffer -from ._exceptions import WebSocketProtocolException, WebSocketConnectionClosedException +from ._exceptions import ( + WebSocketProtocolException, + WebSocketConnectionClosedException, + WebSocketTimeoutException, +) from ._handshake import SUPPORTED_REDIRECT_STATUSES, handshake from ._http import connect, proxy_info from ._logging import debug, error, trace, isEnabledForError, isEnabledForTrace from ._socket import getdefaulttimeout, recv, send, sock_opt from ._ssl_compat import ssl from ._utils import NoLock +from ._dispatcher import DispatcherBase, WrappedDispatcher """ _core.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -82,6 +87,7 @@ class WebSocket: fire_cont_frame: bool = False, enable_multithread: bool = True, skip_utf8_validation: bool = False, + dispatcher: Union[DispatcherBase, WrappedDispatcher] = None, **_, ): """ @@ -101,13 +107,14 @@ class WebSocket: # These buffer over the build-up of a single frame. self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation) self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation) + self.dispatcher = dispatcher if enable_multithread: self.lock = threading.Lock() self.readlock = threading.Lock() else: - self.lock = NoLock() - self.readlock = NoLock() + self.lock = NoLock() # type: ignore[assignment] + self.readlock = NoLock() # type: ignore[assignment] def __iter__(self): """ @@ -140,7 +147,7 @@ class WebSocket: """ self.get_mask_key = func - def gettimeout(self) -> Union[float, int, None]: + def gettimeout(self) -> Optional[Union[float, int]]: """ Get the websocket timeout (in seconds) as an int or float @@ -151,7 +158,7 @@ class WebSocket: """ return self.sock_opt.timeout - def settimeout(self, timeout: Union[float, int, None]): + def settimeout(self, timeout: Optional[Union[float, int]]): """ Set the timeout to the websocket. @@ -200,7 +207,7 @@ class WebSocket: def is_ssl(self): try: return isinstance(self.sock, ssl.SSLSocket) - except: + except (AttributeError, NameError): return False headers = property(getheaders) @@ -334,8 +341,8 @@ class WebSocket: trace(f"++Sent decoded: {frame.__str__()}") with self.lock: while data: - l = self._send(data) - data = data[l:] + bytes_sent = self._send(data) + data = data[bytes_sent:] return length @@ -515,6 +522,8 @@ class WebSocket: try: self.connected = False self.send(struct.pack("!H", status) + reason, ABNF.OPCODE_CLOSE) + if self.sock is None: + return sock_timeout = self.sock.gettimeout() self.sock.settimeout(timeout) start_time = time.time() @@ -530,10 +539,15 @@ class WebSocket: elif recv_status != STATUS_NORMAL: error(f"close status: {repr(recv_status)}") break - except: + except ( + WebSocketConnectionClosedException, + WebSocketTimeoutException, + struct.error, + ): break - self.sock.settimeout(sock_timeout) - self.sock.shutdown(socket.SHUT_RDWR) + if self.sock is not None: + self.sock.settimeout(sock_timeout) + self.sock.shutdown(socket.SHUT_RDWR) except: pass @@ -556,6 +570,10 @@ class WebSocket: self.connected = False def _send(self, data: Union[str, bytes]): + if self.sock is None: + raise WebSocketConnectionClosedException("socket is already closed.") + if self.dispatcher: + return self.dispatcher.send(self.sock, data) return send(self.sock, data) def _recv(self, bufsize): diff --git a/contrib/python/websocket-client/websocket/_dispatcher.py b/contrib/python/websocket-client/websocket/_dispatcher.py new file mode 100644 index 00000000000..7303620ee01 --- /dev/null +++ b/contrib/python/websocket-client/websocket/_dispatcher.py @@ -0,0 +1,164 @@ +import time +import socket +import inspect +import selectors +from typing import TYPE_CHECKING, Callable, Optional, Union + +if TYPE_CHECKING: + from ._app import WebSocketApp +from . import _logging +from ._socket import send + +""" +_dispatcher.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class DispatcherBase: + """ + DispatcherBase + """ + + def __init__( + self, app: "WebSocketApp", ping_timeout: Optional[Union[float, int]] + ) -> None: + self.app = app + self.ping_timeout = ping_timeout + + def timeout(self, seconds: Optional[Union[float, int]], callback: Callable) -> None: + if seconds is not None: + time.sleep(seconds) + callback() + + def reconnect(self, seconds: int, reconnector: Callable) -> None: + try: + _logging.info( + f"reconnect() - retrying in {seconds} seconds [{len(inspect.stack())} frames in stack]" + ) + time.sleep(seconds) + reconnector(reconnecting=True) + except KeyboardInterrupt as e: + _logging.info(f"User exited {e}") + raise e + + def send(self, sock: socket.socket, data: Union[str, bytes]) -> int: + return send(sock, data) + + +class Dispatcher(DispatcherBase): + """ + Dispatcher + """ + + def read( + self, + sock: socket.socket, + read_callback: Callable, + check_callback: Callable, + ) -> None: + if self.app.sock is None or self.app.sock.sock is None: + return + sel = selectors.DefaultSelector() + sel.register(self.app.sock.sock, selectors.EVENT_READ) + try: + while self.app.keep_running: + if sel.select(self.ping_timeout): + if not read_callback(): + break + check_callback() + finally: + sel.close() + + +class SSLDispatcher(DispatcherBase): + """ + SSLDispatcher + """ + + def read( + self, + sock: socket.socket, + read_callback: Callable, + check_callback: Callable, + ) -> None: + if self.app.sock is None or self.app.sock.sock is None: + return + sock = self.app.sock.sock + sel = selectors.DefaultSelector() + sel.register(sock, selectors.EVENT_READ) + try: + while self.app.keep_running: + if self.select(sock, sel): + if not read_callback(): + break + check_callback() + finally: + sel.close() + + def select(self, sock, sel: selectors.DefaultSelector): + if self.app.sock is None: + return None + sock = self.app.sock.sock + if sock.pending(): + return [ + sock, + ] + + r = sel.select(self.ping_timeout) + + if len(r) > 0: + return r[0][0] + return None + + +class WrappedDispatcher: + """ + WrappedDispatcher + """ + + def __init__( + self, + app: "WebSocketApp", + ping_timeout: Optional[Union[float, int]], + dispatcher, + handleDisconnect, + ) -> None: + self.app = app + self.ping_timeout = ping_timeout + self.dispatcher = dispatcher + self.handleDisconnect = handleDisconnect + dispatcher.signal(2, dispatcher.abort) # keyboard interrupt + + def read( + self, + sock: socket.socket, + read_callback: Callable, + check_callback: Callable, + ) -> None: + self.dispatcher.read(sock, read_callback) + if self.ping_timeout: + self.timeout(self.ping_timeout, check_callback) + + def send(self, sock: socket.socket, data: Union[str, bytes]) -> int: + self.dispatcher.buffwrite(sock, data, send, self.handleDisconnect) + return len(data) + + def timeout(self, seconds: float, callback: Callable, *args) -> None: + self.dispatcher.timeout(seconds, callback, *args) + + def reconnect(self, seconds: int, reconnector: Callable) -> None: + self.timeout(seconds, reconnector, True) diff --git a/contrib/python/websocket-client/websocket/_exceptions.py b/contrib/python/websocket-client/websocket/_exceptions.py index cd196e44a38..71c067bc532 100644 --- a/contrib/python/websocket-client/websocket/_exceptions.py +++ b/contrib/python/websocket-client/websocket/_exceptions.py @@ -2,7 +2,7 @@ _exceptions.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/contrib/python/websocket-client/websocket/_handshake.py b/contrib/python/websocket-client/websocket/_handshake.py index 7bd61b82f44..ef3f7d53dd5 100644 --- a/contrib/python/websocket-client/websocket/_handshake.py +++ b/contrib/python/websocket-client/websocket/_handshake.py @@ -2,7 +2,7 @@ _handshake.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ + import hashlib import hmac import os @@ -142,9 +143,16 @@ def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple if status not in success_statuses: content_len = resp_headers.get("content-length") if content_len: - response_body = sock.recv( - int(content_len) - ) # read the body of the HTTP error message response and include it in the exception + # Use chunked reading to avoid SSL BAD_LENGTH error on large responses + from ._socket import recv + + response_body = b"" + remaining = int(content_len) + while remaining > 0: + chunk_size = min(remaining, 16384) # Read in 16KB chunks + chunk = recv(sock, chunk_size) + response_body += chunk + remaining -= len(chunk) else: response_body = None raise WebSocketBadStatusException( diff --git a/contrib/python/websocket-client/websocket/_http.py b/contrib/python/websocket-client/websocket/_http.py index 9b1bf859d91..6a23164d575 100644 --- a/contrib/python/websocket-client/websocket/_http.py +++ b/contrib/python/websocket-client/websocket/_http.py @@ -2,7 +2,7 @@ _http.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ + import errno import os import socket @@ -34,7 +35,7 @@ from ._url import get_proxy_info, parse_url __all__ = ["proxy_info", "connect", "read_headers"] try: - from python_socks._errors import * + from python_socks._errors import ProxyConnectionError, ProxyError, ProxyTimeoutError from python_socks._types import ProxyType from python_socks.sync import Proxy @@ -242,21 +243,36 @@ def _wrap_sni_socket(sock: socket.socket, sslopt: dict, hostname, check_hostname # For more details see also: # * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#context-creation # * https://docs.python.org/3.8/library/ssl.html?highlight=sslkeylogfile#ssl.SSLContext.keylog_filename - context.keylog_filename = os.environ.get("SSLKEYLOGFILE", None) + keylog_file = os.environ.get("SSLKEYLOGFILE") + if keylog_file is not None: + context.keylog_filename = keylog_file if sslopt.get("cert_reqs", ssl.CERT_NONE) != ssl.CERT_NONE: cafile = sslopt.get("ca_certs", None) capath = sslopt.get("ca_cert_path", None) if cafile or capath: - context.load_verify_locations(cafile=cafile, capath=capath) + try: + context.load_verify_locations(cafile=cafile, capath=capath) + except (FileNotFoundError, ssl.SSLError, ValueError) as e: + raise WebSocketException(f"SSL CA certificate loading failed: {e}") elif hasattr(context, "load_default_certs"): - context.load_default_certs(ssl.Purpose.SERVER_AUTH) + try: + context.load_default_certs(ssl.Purpose.SERVER_AUTH) + except ssl.SSLError as e: + raise WebSocketException( + f"SSL default certificate loading failed: {e}" + ) if sslopt.get("certfile", None): - context.load_cert_chain( - sslopt["certfile"], - sslopt.get("keyfile", None), - sslopt.get("password", None), - ) + try: + context.load_cert_chain( + sslopt["certfile"], + sslopt.get("keyfile", None), + sslopt.get("password", None), + ) + except (FileNotFoundError, ValueError) as e: + raise WebSocketException(f"SSL client certificate loading failed: {e}") + except ssl.SSLError as e: + raise WebSocketException(f"SSL client certificate loading failed: {e}") # Python 3.10 switch to PROTOCOL_TLS_CLIENT defaults to "cert_reqs = ssl.CERT_REQUIRED" and "check_hostname = True" # If both disabled, set check_hostname before verify_mode @@ -271,12 +287,30 @@ def _wrap_sni_socket(sock: socket.socket, sslopt: dict, hostname, check_hostname context.verify_mode = sslopt.get("cert_reqs", ssl.CERT_REQUIRED) if "ciphers" in sslopt: - context.set_ciphers(sslopt["ciphers"]) + try: + context.set_ciphers(sslopt["ciphers"]) + except ssl.SSLError as e: + raise WebSocketException(f"SSL cipher configuration failed: {e}") if "cert_chain" in sslopt: - certfile, keyfile, password = sslopt["cert_chain"] - context.load_cert_chain(certfile, keyfile, password) + try: + cert_chain = sslopt["cert_chain"] + if not isinstance(cert_chain, (tuple, list)) or len(cert_chain) != 3: + raise ValueError( + "cert_chain must be a tuple/list of (certfile, keyfile, password)" + ) + certfile, keyfile, password = cert_chain + context.load_cert_chain(certfile, keyfile, password) + except ValueError: + raise + except (FileNotFoundError, ssl.SSLError) as e: + raise WebSocketException( + f"SSL client certificate configuration failed: {e}" + ) if "ecdh_curve" in sslopt: - context.set_ecdh_curve(sslopt["ecdh_curve"]) + try: + context.set_ecdh_curve(sslopt["ecdh_curve"]) + except ValueError as e: + raise WebSocketException(f"SSL ECDH curve configuration failed: {e}") return context.wrap_socket( sock, @@ -332,7 +366,7 @@ def _tunnel(sock: socket.socket, host, port: int, auth) -> socket.socket: try: status, _, _ = read_headers(sock) - except Exception as e: + except (socket.error, WebSocketException) as e: raise WebSocketProxyException(str(e)) if status != 200: @@ -364,7 +398,11 @@ def read_headers(sock: socket.socket) -> tuple: raise WebSocketException("Invalid header") key, value = kv if key.lower() == "set-cookie" and headers.get("set-cookie"): - headers["set-cookie"] = headers.get("set-cookie") + "; " + value.strip() + existing_cookie = headers.get("set-cookie") + if existing_cookie is not None: + headers["set-cookie"] = existing_cookie + "; " + value.strip() + else: + headers["set-cookie"] = value.strip() else: headers[key.lower()] = value.strip() diff --git a/contrib/python/websocket-client/websocket/_logging.py b/contrib/python/websocket-client/websocket/_logging.py index 0f673d3aff1..65074b98761 100644 --- a/contrib/python/websocket-client/websocket/_logging.py +++ b/contrib/python/websocket-client/websocket/_logging.py @@ -4,7 +4,7 @@ import logging _logging.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ try: from logging import NullHandler except ImportError: - class NullHandler(logging.Handler): + class NullHandler(logging.Handler): # type: ignore[no-redef] def emit(self, record) -> None: pass diff --git a/contrib/python/websocket-client/websocket/_socket.py b/contrib/python/websocket-client/websocket/_socket.py index 81094ffc84b..7a2dff17bef 100644 --- a/contrib/python/websocket-client/websocket/_socket.py +++ b/contrib/python/websocket-client/websocket/_socket.py @@ -1,20 +1,20 @@ import errno import selectors import socket -from typing import Union +from typing import Optional, Union, Any from ._exceptions import ( WebSocketConnectionClosedException, WebSocketTimeoutException, ) -from ._ssl_compat import SSLError, SSLWantReadError, SSLWantWriteError +from ._ssl_compat import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError from ._utils import extract_error_code, extract_err_message """ _socket.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,17 +53,19 @@ __all__ = [ class sock_opt: - def __init__(self, sockopt: list, sslopt: dict) -> None: + def __init__( + self, sockopt: Optional[list[tuple]], sslopt: Optional[dict[str, Any]] + ) -> None: if sockopt is None: sockopt = [] if sslopt is None: sslopt = {} self.sockopt = sockopt self.sslopt = sslopt - self.timeout = None + self.timeout: Optional[Union[int, float]] = None -def setdefaulttimeout(timeout: Union[int, float, None]) -> None: +def setdefaulttimeout(timeout: Optional[Union[int, float]]) -> None: """ Set the global timeout setting to connect. @@ -76,7 +78,7 @@ def setdefaulttimeout(timeout: Union[int, float, None]) -> None: _default_timeout = timeout -def getdefaulttimeout() -> Union[int, float, None]: +def getdefaulttimeout() -> Optional[Union[int, float]]: """ Get default timeout @@ -96,12 +98,15 @@ def recv(sock: socket.socket, bufsize: int) -> bytes: try: return sock.recv(bufsize) except SSLWantReadError: + # Don't return None implicitly - fall through to retry logic pass except socket.error as exc: error_code = extract_error_code(exc) if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]: raise + # Don't return None implicitly - fall through to retry logic + # Retry logic using selector for both SSLWantReadError and EAGAIN/EWOULDBLOCK sel = selectors.DefaultSelector() sel.register(sock, selectors.EVENT_READ) @@ -110,6 +115,10 @@ def recv(sock: socket.socket, bufsize: int) -> bytes: if r: return sock.recv(bufsize) + else: + # Selector timeout should raise WebSocketTimeoutException + # not return None which gets misclassified as connection closed + raise WebSocketTimeoutException("Connection timed out waiting for data") try: if sock.gettimeout() == 0: @@ -128,6 +137,8 @@ def recv(sock: socket.socket, bufsize: int) -> bytes: else: raise + if bytes_ is None: + raise WebSocketConnectionClosedException("Connection to remote host was lost.") if not bytes_: raise WebSocketConnectionClosedException("Connection to remote host was lost.") @@ -151,9 +162,11 @@ def send(sock: socket.socket, data: Union[bytes, str]) -> int: if not sock: raise WebSocketConnectionClosedException("socket is already closed.") - def _send(): + def _send() -> int: try: return sock.send(data) + except SSLEOFError: + raise WebSocketConnectionClosedException("socket is already closed.") except SSLWantWriteError: pass except socket.error as exc: @@ -171,6 +184,7 @@ def send(sock: socket.socket, data: Union[bytes, str]) -> int: if w: return sock.send(data) + return 0 try: if sock.gettimeout() == 0: @@ -180,7 +194,7 @@ def send(sock: socket.socket, data: Union[bytes, str]) -> int: except socket.timeout as e: message = extract_err_message(e) raise WebSocketTimeoutException(message) - except Exception as e: + except (OSError, SSLError) as e: message = extract_err_message(e) if isinstance(message, str) and "timed out" in message: raise WebSocketTimeoutException(message) diff --git a/contrib/python/websocket-client/websocket/_ssl_compat.py b/contrib/python/websocket-client/websocket/_ssl_compat.py index 0a8a32b59b3..02f049f66ac 100644 --- a/contrib/python/websocket-client/websocket/_ssl_compat.py +++ b/contrib/python/websocket-client/websocket/_ssl_compat.py @@ -2,7 +2,7 @@ _ssl_compat.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import ssl as _ssl_module + from ssl import ( + SSLError as _SSLErrorType, + SSLEOFError as _SSLEOFErrorType, + SSLWantReadError as _SSLWantReadErrorType, + SSLWantWriteError as _SSLWantWriteErrorType, + ) +else: + _ssl_module = None + _SSLErrorType = None + _SSLEOFErrorType = None + _SSLWantReadErrorType = None + _SSLWantWriteErrorType = None + __all__ = [ "HAVE_SSL", "ssl", @@ -27,22 +45,22 @@ __all__ = [ try: import ssl - from ssl import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError + from ssl import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError # type: ignore[attr-defined] HAVE_SSL = True except ImportError: # dummy class of SSLError for environment without ssl support - class SSLError(Exception): + class SSLError(Exception): # type: ignore[no-redef] pass - class SSLEOFError(Exception): + class SSLEOFError(Exception): # type: ignore[no-redef] pass - class SSLWantReadError(Exception): + class SSLWantReadError(Exception): # type: ignore[no-redef] pass - class SSLWantWriteError(Exception): + class SSLWantWriteError(Exception): # type: ignore[no-redef] pass - ssl = None + ssl = None # type: ignore[assignment,no-redef] HAVE_SSL = False diff --git a/contrib/python/websocket-client/websocket/_url.py b/contrib/python/websocket-client/websocket/_url.py index 902131710ba..3b0d6b072f5 100644 --- a/contrib/python/websocket-client/websocket/_url.py +++ b/contrib/python/websocket-client/websocket/_url.py @@ -1,6 +1,5 @@ +import ipaddress import os -import socket -import struct from typing import Optional from urllib.parse import unquote, urlparse from ._exceptions import WebSocketProxyException @@ -9,7 +8,7 @@ from ._exceptions import WebSocketProxyException _url.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -73,13 +72,12 @@ def parse_url(url: str) -> tuple: return hostname, port, resource, is_secure -DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"] - - def _is_ip_address(addr: str) -> bool: + if not isinstance(addr, str): + raise TypeError("_is_ip_address() argument 1 must be str") try: - socket.inet_aton(addr) - except socket.error: + ipaddress.ip_address(addr) + except ValueError: return False else: return True @@ -87,29 +85,29 @@ def _is_ip_address(addr: str) -> bool: def _is_subnet_address(hostname: str) -> bool: try: - addr, netmask = hostname.split("/") - return _is_ip_address(addr) and 0 <= int(netmask) < 32 + ipaddress.ip_network(hostname) except ValueError: return False + else: + return True def _is_address_in_network(ip: str, net: str) -> bool: - ipaddr: int = struct.unpack("!I", socket.inet_aton(ip))[0] - netaddr, netmask = net.split("/") - netaddr: int = struct.unpack("!I", socket.inet_aton(netaddr))[0] - - netmask = (0xFFFFFFFF << (32 - int(netmask))) & 0xFFFFFFFF - return ipaddr & netmask == netaddr + try: + return ipaddress.ip_network(ip).subnet_of(ipaddress.ip_network(net)) + except TypeError: + return False -def _is_no_proxy_host(hostname: str, no_proxy: Optional[list]) -> bool: +def _is_no_proxy_host(hostname: str, no_proxy: Optional[list[str]]) -> bool: if not no_proxy: if v := os.environ.get("no_proxy", os.environ.get("NO_PROXY", "")).replace( " ", "" ): no_proxy = v.split(",") + if not no_proxy: - no_proxy = DEFAULT_NO_PROXY_HOST + no_proxy = [] if "*" in no_proxy: return True @@ -124,7 +122,8 @@ def _is_no_proxy_host(hostname: str, no_proxy: Optional[list]) -> bool: ] ) for domain in [domain for domain in no_proxy if domain.startswith(".")]: - if hostname.endswith(domain): + endDomain = domain.lstrip(".") + if hostname.endswith(endDomain): return True return False @@ -135,7 +134,7 @@ def get_proxy_info( proxy_host: Optional[str] = None, proxy_port: int = 0, proxy_auth: Optional[tuple] = None, - no_proxy: Optional[list] = None, + no_proxy: Optional[list[str]] = None, proxy_type: str = "http", ) -> tuple: """ @@ -181,7 +180,7 @@ def get_proxy_info( if value: proxy = urlparse(value) auth = ( - (unquote(proxy.username), unquote(proxy.password)) + (unquote(proxy.username or ""), unquote(proxy.password or "")) if proxy.username else None ) diff --git a/contrib/python/websocket-client/websocket/_utils.py b/contrib/python/websocket-client/websocket/_utils.py index 65f3c0daf7c..726a2f3e583 100644 --- a/contrib/python/websocket-client/websocket/_utils.py +++ b/contrib/python/websocket-client/websocket/_utils.py @@ -1,10 +1,10 @@ -from typing import Union +from typing import Union, Optional """ -_url.py +_utils.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -446,7 +446,7 @@ def validate_utf8(utfbytes: Union[str, bytes]) -> bool: return _validate_utf8(utfbytes) -def extract_err_message(exception: Exception) -> Union[str, None]: +def extract_err_message(exception: Exception) -> Optional[str]: if exception.args: exception_message: str = exception.args[0] return exception_message @@ -454,6 +454,7 @@ def extract_err_message(exception: Exception) -> Union[str, None]: return None -def extract_error_code(exception: Exception) -> Union[int, None]: +def extract_error_code(exception: Exception) -> Optional[int]: if exception.args and len(exception.args) > 1: return exception.args[0] if isinstance(exception.args[0], int) else None + return None diff --git a/contrib/python/websocket-client/websocket/_wsdump.py b/contrib/python/websocket-client/websocket/_wsdump.py index d4d76dc509e..81be2521561 100644 --- a/contrib/python/websocket-client/websocket/_wsdump.py +++ b/contrib/python/websocket-client/websocket/_wsdump.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 """ -wsdump.py +_wsdump.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ from urllib.parse import urlparse import websocket try: - import readline + import readline # noqa: F401 except ImportError: pass diff --git a/contrib/python/websocket-client/websocket/tests/test_abnf.py b/contrib/python/websocket-client/websocket/tests/test_abnf.py index a749f13bd54..664ea3b314c 100644 --- a/contrib/python/websocket-client/websocket/tests/test_abnf.py +++ b/contrib/python/websocket-client/websocket/tests/test_abnf.py @@ -9,7 +9,7 @@ from websocket._exceptions import WebSocketProtocolException test_abnf.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/contrib/python/websocket-client/websocket/tests/test_app.py b/contrib/python/websocket-client/websocket/tests/test_app.py index 18eace54427..c127e5c9e7e 100644 --- a/contrib/python/websocket-client/websocket/tests/test_app.py +++ b/contrib/python/websocket-client/websocket/tests/test_app.py @@ -12,7 +12,7 @@ import websocket as ws test_app.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -347,6 +347,49 @@ class WebSocketAppTest(unittest.TestCase): self.assertIsInstance(exc, ws.WebSocketTimeoutException) self.assertEqual(str(exc), "ping/pong timed out") + def test_dispatcher_selection_default(self): + """Test default dispatcher selection""" + app = ws.WebSocketApp("ws://example.com") + + # Test default dispatcher (non-SSL) + dispatcher = app.create_dispatcher(ping_timeout=10, is_ssl=False) + self.assertIsInstance(dispatcher, ws._dispatcher.Dispatcher) + + def test_dispatcher_selection_ssl(self): + """Test SSL dispatcher selection""" + app = ws.WebSocketApp("wss://example.com") + + # Test SSL dispatcher + dispatcher = app.create_dispatcher(ping_timeout=10, is_ssl=True) + self.assertIsInstance(dispatcher, ws._dispatcher.SSLDispatcher) + + def test_dispatcher_selection_custom(self): + """Test custom dispatcher selection""" + from unittest.mock import Mock + + app = ws.WebSocketApp("ws://example.com") + custom_dispatcher = Mock() + handle_disconnect = Mock() + + # Test wrapped dispatcher with custom dispatcher + dispatcher = app.create_dispatcher( + ping_timeout=10, + dispatcher=custom_dispatcher, + handleDisconnect=handle_disconnect, + ) + self.assertIsInstance(dispatcher, ws._dispatcher.WrappedDispatcher) + self.assertEqual(dispatcher.dispatcher, custom_dispatcher) + self.assertEqual(dispatcher.handleDisconnect, handle_disconnect) + + def test_dispatcher_selection_no_ping_timeout(self): + """Test dispatcher selection without ping timeout""" + app = ws.WebSocketApp("ws://example.com") + + # Test with None ping_timeout (should default to 10) + dispatcher = app.create_dispatcher(ping_timeout=None, is_ssl=False) + self.assertIsInstance(dispatcher, ws._dispatcher.Dispatcher) + self.assertEqual(dispatcher.ping_timeout, 10) + if __name__ == "__main__": unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_cookiejar.py b/contrib/python/websocket-client/websocket/tests/test_cookiejar.py index 67eddb627ae..7590f0caa73 100644 --- a/contrib/python/websocket-client/websocket/tests/test_cookiejar.py +++ b/contrib/python/websocket-client/websocket/tests/test_cookiejar.py @@ -6,7 +6,7 @@ from websocket._cookiejar import SimpleCookieJar test_cookiejar.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/contrib/python/websocket-client/websocket/tests/test_dispatcher.py b/contrib/python/websocket-client/websocket/tests/test_dispatcher.py new file mode 100644 index 00000000000..457bed6cb46 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_dispatcher.py @@ -0,0 +1,385 @@ +# -*- coding: utf-8 -*- +import socket +import unittest +from unittest.mock import Mock, patch, MagicMock +import threading +import time + +import websocket +from websocket._dispatcher import ( + Dispatcher, + DispatcherBase, + SSLDispatcher, + WrappedDispatcher, +) + +""" +test_dispatcher.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class MockApp: + """Mock WebSocketApp for testing""" + + def __init__(self): + self.keep_running = True + self.sock = Mock() + self.sock.sock = Mock() + + +class MockSocket: + """Mock socket for testing""" + + def __init__(self): + self.pending_return = False + + def pending(self): + return self.pending_return + + +class MockDispatcher: + """Mock external dispatcher for WrappedDispatcher testing""" + + def __init__(self): + self.signal_calls = [] + self.abort_calls = [] + self.read_calls = [] + self.buffwrite_calls = [] + self.timeout_calls = [] + + def signal(self, sig, handler): + self.signal_calls.append((sig, handler)) + + def abort(self): + self.abort_calls.append(True) + + def read(self, sock, callback): + self.read_calls.append((sock, callback)) + + def buffwrite(self, sock, data, send_func, disconnect_handler): + self.buffwrite_calls.append((sock, data, send_func, disconnect_handler)) + + def timeout(self, seconds, callback, *args): + self.timeout_calls.append((seconds, callback, args)) + + +class DispatcherTest(unittest.TestCase): + def setUp(self): + self.app = MockApp() + + def test_dispatcher_base_init(self): + """Test DispatcherBase initialization""" + dispatcher = DispatcherBase(self.app, 30.0) + + self.assertEqual(dispatcher.app, self.app) + self.assertEqual(dispatcher.ping_timeout, 30.0) + + def test_dispatcher_base_timeout(self): + """Test DispatcherBase timeout method""" + dispatcher = DispatcherBase(self.app, 30.0) + callback = Mock() + + # Test with seconds=None (should call callback immediately) + dispatcher.timeout(None, callback) + callback.assert_called_once() + + # Test with seconds > 0 (would sleep in real implementation) + callback.reset_mock() + start_time = time.time() + dispatcher.timeout(0.1, callback) + elapsed = time.time() - start_time + + callback.assert_called_once() + self.assertGreaterEqual(elapsed, 0.05) # Allow some tolerance + + def test_dispatcher_base_reconnect(self): + """Test DispatcherBase reconnect method""" + dispatcher = DispatcherBase(self.app, 30.0) + reconnector = Mock() + + # Test normal reconnect + dispatcher.reconnect(1, reconnector) + reconnector.assert_called_once_with(reconnecting=True) + + # Test reconnect with KeyboardInterrupt + reconnector.reset_mock() + reconnector.side_effect = KeyboardInterrupt("User interrupted") + + with self.assertRaises(KeyboardInterrupt): + dispatcher.reconnect(1, reconnector) + + def test_dispatcher_base_send(self): + """Test DispatcherBase send method""" + dispatcher = DispatcherBase(self.app, 30.0) + mock_sock = Mock() + test_data = b"test data" + + with patch("websocket._dispatcher.send") as mock_send: + mock_send.return_value = len(test_data) + result = dispatcher.send(mock_sock, test_data) + + mock_send.assert_called_once_with(mock_sock, test_data) + self.assertEqual(result, len(test_data)) + + def test_dispatcher_read(self): + """Test Dispatcher read method""" + dispatcher = Dispatcher(self.app, 5.0) + read_callback = Mock(return_value=True) + check_callback = Mock() + mock_sock = Mock() + + # Mock the selector to control the loop + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + + # Make select return immediately (timeout) + mock_selector.select.return_value = [] + + # Stop after first iteration + def side_effect(*args): + self.app.keep_running = False + return [] + + mock_selector.select.side_effect = side_effect + + dispatcher.read(mock_sock, read_callback, check_callback) + + # Verify selector was used correctly + mock_selector.register.assert_called() + mock_selector.select.assert_called_with(5.0) + mock_selector.close.assert_called() + check_callback.assert_called() + + def test_dispatcher_read_with_data(self): + """Test Dispatcher read method when data is available""" + dispatcher = Dispatcher(self.app, 5.0) + read_callback = Mock(return_value=True) + check_callback = Mock() + mock_sock = Mock() + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + + # First call returns data, second call stops the loop + call_count = 0 + + def select_side_effect(*args): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [True] # Data available + else: + self.app.keep_running = False + return [] + + mock_selector.select.side_effect = select_side_effect + + dispatcher.read(mock_sock, read_callback, check_callback) + + read_callback.assert_called() + check_callback.assert_called() + + def test_ssl_dispatcher_read(self): + """Test SSLDispatcher read method""" + dispatcher = SSLDispatcher(self.app, 5.0) + read_callback = Mock(return_value=True) + check_callback = Mock() + + # Mock socket with pending data + mock_ssl_sock = MockSocket() + self.app.sock.sock = mock_ssl_sock + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] + + # Stop after first iteration + def side_effect(*args): + self.app.keep_running = False + return [] + + mock_selector.select.side_effect = side_effect + + dispatcher.read(None, read_callback, check_callback) + + mock_selector.register.assert_called() + check_callback.assert_called() + + def test_ssl_dispatcher_select_with_pending(self): + """Test SSLDispatcher select method with pending data""" + dispatcher = SSLDispatcher(self.app, 5.0) + mock_ssl_sock = MockSocket() + mock_ssl_sock.pending_return = True + self.app.sock.sock = mock_ssl_sock + mock_selector = Mock() + + result = dispatcher.select(None, mock_selector) + + # When pending() returns True, should return [sock] + self.assertEqual(result, [mock_ssl_sock]) + + def test_ssl_dispatcher_select_without_pending(self): + """Test SSLDispatcher select method without pending data""" + dispatcher = SSLDispatcher(self.app, 5.0) + mock_ssl_sock = MockSocket() + mock_ssl_sock.pending_return = False + self.app.sock.sock = mock_ssl_sock + mock_selector = Mock() + mock_selector.select.return_value = [(mock_ssl_sock, None)] + + result = dispatcher.select(None, mock_selector) + + # Should return the first element of first result tuple + self.assertEqual(result, mock_ssl_sock) + mock_selector.select.assert_called_with(5.0) + + def test_ssl_dispatcher_select_no_results(self): + """Test SSLDispatcher select method with no results""" + dispatcher = SSLDispatcher(self.app, 5.0) + mock_ssl_sock = MockSocket() + mock_ssl_sock.pending_return = False + self.app.sock.sock = mock_ssl_sock + mock_selector = Mock() + mock_selector.select.return_value = [] + + result = dispatcher.select(None, mock_selector) + + # Should return None when no results (function doesn't return anything when len(r) == 0) + self.assertIsNone(result) + + def test_wrapped_dispatcher_init(self): + """Test WrappedDispatcher initialization""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + self.assertEqual(wrapped.app, self.app) + self.assertEqual(wrapped.ping_timeout, 10.0) + self.assertEqual(wrapped.dispatcher, mock_dispatcher) + self.assertEqual(wrapped.handleDisconnect, handle_disconnect) + + # Should have set up signal handler + self.assertEqual(len(mock_dispatcher.signal_calls), 1) + sig, handler = mock_dispatcher.signal_calls[0] + self.assertEqual(sig, 2) # SIGINT + self.assertEqual(handler, mock_dispatcher.abort) + + def test_wrapped_dispatcher_read(self): + """Test WrappedDispatcher read method""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + mock_sock = Mock() + read_callback = Mock() + check_callback = Mock() + + wrapped.read(mock_sock, read_callback, check_callback) + + # Should delegate to wrapped dispatcher + self.assertEqual(len(mock_dispatcher.read_calls), 1) + self.assertEqual(mock_dispatcher.read_calls[0], (mock_sock, read_callback)) + + # Should call timeout for ping_timeout + self.assertEqual(len(mock_dispatcher.timeout_calls), 1) + timeout_call = mock_dispatcher.timeout_calls[0] + self.assertEqual(timeout_call[0], 10.0) # timeout seconds + self.assertEqual(timeout_call[1], check_callback) # callback + + def test_wrapped_dispatcher_read_no_ping_timeout(self): + """Test WrappedDispatcher read method without ping timeout""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, None, mock_dispatcher, handle_disconnect) + + mock_sock = Mock() + read_callback = Mock() + check_callback = Mock() + + wrapped.read(mock_sock, read_callback, check_callback) + + # Should delegate to wrapped dispatcher + self.assertEqual(len(mock_dispatcher.read_calls), 1) + + # Should NOT call timeout when ping_timeout is None + self.assertEqual(len(mock_dispatcher.timeout_calls), 0) + + def test_wrapped_dispatcher_send(self): + """Test WrappedDispatcher send method""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + mock_sock = Mock() + test_data = b"test data" + + with patch("websocket._dispatcher.send") as mock_send: + result = wrapped.send(mock_sock, test_data) + + # Should delegate to dispatcher.buffwrite + self.assertEqual(len(mock_dispatcher.buffwrite_calls), 1) + call = mock_dispatcher.buffwrite_calls[0] + self.assertEqual(call[0], mock_sock) + self.assertEqual(call[1], test_data) + self.assertEqual(call[2], mock_send) + self.assertEqual(call[3], handle_disconnect) + + # Should return data length + self.assertEqual(result, len(test_data)) + + def test_wrapped_dispatcher_timeout(self): + """Test WrappedDispatcher timeout method""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + callback = Mock() + args = ("arg1", "arg2") + + wrapped.timeout(5.0, callback, *args) + + # Should delegate to wrapped dispatcher + self.assertEqual(len(mock_dispatcher.timeout_calls), 1) + call = mock_dispatcher.timeout_calls[0] + self.assertEqual(call[0], 5.0) + self.assertEqual(call[1], callback) + self.assertEqual(call[2], args) + + def test_wrapped_dispatcher_reconnect(self): + """Test WrappedDispatcher reconnect method""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + reconnector = Mock() + + wrapped.reconnect(3, reconnector) + + # Should delegate to timeout method with reconnect=True + self.assertEqual(len(mock_dispatcher.timeout_calls), 1) + call = mock_dispatcher.timeout_calls[0] + self.assertEqual(call[0], 3) + self.assertEqual(call[1], reconnector) + self.assertEqual(call[2], (True,)) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_handshake_large_response.py b/contrib/python/websocket-client/websocket/tests/test_handshake_large_response.py new file mode 100644 index 00000000000..3ca415a09bb --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_handshake_large_response.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +import unittest +from unittest.mock import Mock, patch + +from websocket._handshake import _get_resp_headers +from websocket._exceptions import WebSocketBadStatusException +from websocket._ssl_compat import SSLError + +""" +test_handshake_large_response.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class HandshakeLargeResponseTest(unittest.TestCase): + def test_large_error_response_chunked_reading(self): + """Test that large HTTP error responses during handshake are read in chunks""" + + # Mock socket + mock_sock = Mock() + + # Create a large error response body (> 16KB) + large_response = b"Error details: " + b"A" * 20000 # 20KB+ response + + # Track recv calls to ensure chunking + recv_calls = [] + + def mock_recv(sock, bufsize): + recv_calls.append(bufsize) + # Simulate SSL error if trying to read > 16KB at once + if bufsize > 16384: + raise SSLError("[SSL: BAD_LENGTH] unknown error") + return large_response[:bufsize] + + # Mock read_headers to return error status with large content-length + with patch("websocket._handshake.read_headers") as mock_read_headers: + mock_read_headers.return_value = ( + 400, # Bad request status + {"content-length": str(len(large_response))}, + "Bad Request", + ) + + # Mock the recv function to track calls + with patch("websocket._socket.recv", side_effect=mock_recv): + # This should not raise SSLError, but should raise WebSocketBadStatusException + with self.assertRaises(WebSocketBadStatusException) as cm: + _get_resp_headers(mock_sock) + + # Verify the response body was included in the exception + self.assertIn( + b"Error details:", + ( + cm.exception.args[0].encode() + if isinstance(cm.exception.args[0], str) + else cm.exception.args[0] + ), + ) + + # Verify chunked reading was used (multiple recv calls, none > 16KB) + self.assertGreater(len(recv_calls), 1) + self.assertTrue(all(call <= 16384 for call in recv_calls)) + + def test_handshake_ssl_large_response_protection(self): + """Test that the fix prevents SSL BAD_LENGTH errors during handshake""" + + mock_sock = Mock() + + # Large content that would trigger SSL error if read all at once + large_content = b"X" * 32768 # 32KB + + chunks_returned = 0 + + def mock_recv_chunked(sock, bufsize): + nonlocal chunks_returned + # Return data in chunks, simulating successful chunked reading + chunk_start = chunks_returned * 16384 + chunk_end = min(chunk_start + bufsize, len(large_content)) + result = large_content[chunk_start:chunk_end] + chunks_returned += 1 if result else 0 + return result + + with patch("websocket._handshake.read_headers") as mock_read_headers: + mock_read_headers.return_value = ( + 500, # Server error + {"content-length": str(len(large_content))}, + "Internal Server Error", + ) + + with patch("websocket._socket.recv", side_effect=mock_recv_chunked): + # Should handle large response without SSL errors + with self.assertRaises(WebSocketBadStatusException) as cm: + _get_resp_headers(mock_sock) + + # Verify the complete response was captured + exception_str = str(cm.exception) + # Response body should be in the exception message + self.assertIn("XXXXX", exception_str) # Part of the large content + + def test_handshake_normal_small_response(self): + """Test that normal small responses still work correctly""" + + mock_sock = Mock() + small_response = b"Small error message" + + def mock_recv(sock, bufsize): + return small_response + + with patch("websocket._handshake.read_headers") as mock_read_headers: + mock_read_headers.return_value = ( + 404, # Not found + {"content-length": str(len(small_response))}, + "Not Found", + ) + + with patch("websocket._socket.recv", side_effect=mock_recv): + with self.assertRaises(WebSocketBadStatusException) as cm: + _get_resp_headers(mock_sock) + + # Verify small response is handled correctly + self.assertIn("Small error message", str(cm.exception)) + + def test_handshake_no_content_length(self): + """Test handshake error response without content-length header""" + + mock_sock = Mock() + + with patch("websocket._handshake.read_headers") as mock_read_headers: + mock_read_headers.return_value = ( + 403, # Forbidden + {}, # No content-length header + "Forbidden", + ) + + # Should raise exception without trying to read response body + with self.assertRaises(WebSocketBadStatusException) as cm: + _get_resp_headers(mock_sock) + + # Should mention status but not have response body + exception_str = str(cm.exception) + self.assertIn("403", exception_str) + self.assertIn("Forbidden", exception_str) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_http.py b/contrib/python/websocket-client/websocket/tests/test_http.py index 72465c22057..55a9a9cc0d5 100644 --- a/contrib/python/websocket-client/websocket/tests/test_http.py +++ b/contrib/python/websocket-client/websocket/tests/test_http.py @@ -22,7 +22,7 @@ from websocket._http import ( test_http.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/contrib/python/websocket-client/websocket/tests/test_large_payloads.py b/contrib/python/websocket-client/websocket/tests/test_large_payloads.py new file mode 100644 index 00000000000..4d69c635f11 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_large_payloads.py @@ -0,0 +1,273 @@ +# -*- coding: utf-8 -*- +import unittest +import struct +from unittest.mock import Mock, patch, MagicMock + +from websocket._abnf import ABNF +from websocket._core import WebSocket +from websocket._exceptions import WebSocketProtocolException, WebSocketPayloadException +from websocket._ssl_compat import SSLError + +""" +test_large_payloads.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class LargePayloadTest(unittest.TestCase): + def test_frame_length_encoding_boundaries(self): + """Test WebSocket frame length encoding at various boundaries""" + + # Test length encoding boundaries as per RFC 6455 + test_cases = [ + (125, "Single byte length"), # Max for 7-bit length + (126, "Two byte length start"), # Start of 16-bit length + (127, "Two byte length"), + (65535, "Two byte length max"), # Max for 16-bit length + (65536, "Eight byte length start"), # Start of 64-bit length + (16384, "16KB boundary"), # The problematic size + (16385, "Just over 16KB"), + (32768, "32KB"), + (131072, "128KB"), + ] + + for length, description in test_cases: + with self.subTest(length=length, description=description): + # Create payload of specified length + payload = b"A" * length + + # Create frame + frame = ABNF.create_frame(payload, ABNF.OPCODE_BINARY) + + # Verify frame can be formatted without error + formatted = frame.format() + + # Verify the frame header is correctly structured + self.assertIsInstance(formatted, bytes) + self.assertTrue(len(formatted) >= length) # Header + payload + + # Verify payload length is preserved + self.assertEqual(len(frame.data), length) + + def test_recv_large_payload_chunked(self): + """Test receiving large payloads in chunks (simulating the 16KB recv issue)""" + + # Create a large payload that would trigger chunked reading + large_payload = b"B" * 32768 # 32KB + + # Mock recv function that returns data in 16KB chunks + chunks = [] + chunk_size = 16384 + for i in range(0, len(large_payload), chunk_size): + chunks.append(large_payload[i : i + chunk_size]) + + call_count = 0 + + def mock_recv(bufsize): + nonlocal call_count + if call_count >= len(chunks): + return b"" + result = chunks[call_count] + call_count += 1 + return result + + # Test the frame buffer's recv_strict method + from websocket._abnf import frame_buffer + + fb = frame_buffer(mock_recv, skip_utf8_validation=True) + + # This should handle large payloads by chunking + result = fb.recv_strict(len(large_payload)) + + self.assertEqual(result, large_payload) + # Verify multiple recv calls were made + self.assertGreater(call_count, 1) + + def test_ssl_large_payload_simulation(self): + """Simulate SSL BAD_LENGTH error scenario""" + + # This test demonstrates that the 16KB limit in frame buffer protects against SSL issues + payload_size = 16385 + + recv_calls = [] + + def mock_recv_with_ssl_limit(bufsize): + recv_calls.append(bufsize) + # This simulates the SSL issue: BAD_LENGTH when trying to recv > 16KB + if bufsize > 16384: + raise SSLError("[SSL: BAD_LENGTH] unknown error") + return b"C" * min(bufsize, 16384) + + from websocket._abnf import frame_buffer + + fb = frame_buffer(mock_recv_with_ssl_limit, skip_utf8_validation=True) + + # The frame buffer handles this correctly by chunking recv calls + result = fb.recv_strict(payload_size) + + # Verify it worked and chunked the calls properly + self.assertEqual(len(result), payload_size) + # Verify no single recv call was > 16KB + self.assertTrue(all(call <= 16384 for call in recv_calls)) + # Verify multiple calls were made + self.assertGreater(len(recv_calls), 1) + + def test_frame_format_large_payloads(self): + """Test frame formatting with various large payload sizes""" + + # Test sizes around potential problem areas + test_sizes = [16383, 16384, 16385, 32768, 65535, 65536] + + for size in test_sizes: + with self.subTest(size=size): + payload = b"D" * size + frame = ABNF.create_frame(payload, ABNF.OPCODE_BINARY) + + # Should not raise any exceptions + formatted = frame.format() + + # Verify structure + self.assertIsInstance(formatted, bytes) + self.assertEqual(len(frame.data), size) + + # Verify length encoding is correct based on size + # Note: frames from create_frame() include masking by default (4 extra bytes) + mask_size = 4 # WebSocket frames are masked by default + if size < ABNF.LENGTH_7: # < 126 + # Length should be encoded in single byte + expected_header_size = ( + 2 + mask_size + ) # 1 byte opcode + 1 byte length + 4 byte mask + elif size < ABNF.LENGTH_16: # < 65536 + # Length should be encoded in 2 bytes + expected_header_size = ( + 4 + mask_size + ) # 1 byte opcode + 1 byte marker + 2 bytes length + 4 byte mask + else: + # Length should be encoded in 8 bytes + expected_header_size = ( + 10 + mask_size + ) # 1 byte opcode + 1 byte marker + 8 bytes length + 4 byte mask + + self.assertEqual(len(formatted), expected_header_size + size) + + def test_send_large_payload_chunking(self): + """Test that large payloads are sent in chunks to avoid SSL issues""" + + mock_sock = Mock() + + # Track how data is sent + sent_chunks = [] + + def mock_send(data): + sent_chunks.append(len(data)) + return len(data) + + mock_sock.send = mock_send + mock_sock.gettimeout.return_value = 30.0 + + # Create WebSocket with mocked socket + ws = WebSocket() + ws.sock = mock_sock + ws.connected = True + + # Create large payload + large_payload = b"E" * 32768 # 32KB + + # Send the payload + with patch("websocket._core.send") as mock_send_func: + mock_send_func.side_effect = lambda sock, data: len(data) + + # This should work without SSL errors + result = ws.send_binary(large_payload) + + # Verify payload was accepted + self.assertGreater(result, 0) + + def test_utf8_validation_large_text(self): + """Test UTF-8 validation with large text payloads""" + + # Create large valid UTF-8 text + large_text = "Hello 世界! " * 2000 # About 26KB with Unicode + + # Test frame creation + frame = ABNF.create_frame(large_text, ABNF.OPCODE_TEXT) + + # Should not raise validation errors + formatted = frame.format() + self.assertIsInstance(formatted, bytes) + + # Test with close frame that has invalid UTF-8 (this is what validate() actually checks) + invalid_utf8_close_data = struct.pack("!H", 1000) + b"\xff\xfe invalid utf8" + + # Create close frame with invalid UTF-8 data + frame = ABNF(1, 0, 0, 0, ABNF.OPCODE_CLOSE, 1, invalid_utf8_close_data) + + # Validation should catch the invalid UTF-8 in close frame reason + with self.assertRaises(WebSocketProtocolException): + frame.validate(skip_utf8_validation=False) + + def test_frame_buffer_edge_cases(self): + """Test frame buffer with edge cases that could trigger bugs""" + + # Test scenario: exactly 16KB payload split across recv calls + payload_16k = b"F" * 16384 + + # Simulate receiving in smaller chunks + chunks = [payload_16k[i : i + 4096] for i in range(0, len(payload_16k), 4096)] + + call_count = 0 + + def mock_recv(bufsize): + nonlocal call_count + if call_count >= len(chunks): + return b"" + result = chunks[call_count] + call_count += 1 + return result + + from websocket._abnf import frame_buffer + + fb = frame_buffer(mock_recv, skip_utf8_validation=True) + result = fb.recv_strict(16384) + + self.assertEqual(result, payload_16k) + # Verify multiple recv calls were made + self.assertEqual(call_count, 4) # 16KB / 4KB = 4 chunks + + def test_max_frame_size_limits(self): + """Test behavior at WebSocket maximum frame size limits""" + + # Test just under the maximum theoretical frame size + # (This is a very large test, so we'll use a smaller representative size) + + # Test with a reasonably large payload that represents the issue + large_size = 1024 * 1024 # 1MB + payload = b"G" * large_size + + # This should work without issues + frame = ABNF.create_frame(payload, ABNF.OPCODE_BINARY) + + # Verify the frame can be formatted + formatted = frame.format() + self.assertIsInstance(formatted, bytes) + + # Verify payload is preserved + self.assertEqual(len(frame.data), large_size) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_socket.py b/contrib/python/websocket-client/websocket/tests/test_socket.py new file mode 100644 index 00000000000..5b8b65bd6b5 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_socket.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- +import errno +import socket +import unittest +from unittest.mock import Mock, patch, MagicMock +import time + +from websocket._socket import recv, recv_line, send, DEFAULT_SOCKET_OPTION +from websocket._ssl_compat import ( + SSLError, + SSLEOFError, + SSLWantWriteError, + SSLWantReadError, +) +from websocket._exceptions import ( + WebSocketTimeoutException, + WebSocketConnectionClosedException, +) + +""" +test_socket.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class SocketTest(unittest.TestCase): + def test_default_socket_option(self): + """Test DEFAULT_SOCKET_OPTION contains expected options""" + self.assertIsInstance(DEFAULT_SOCKET_OPTION, list) + self.assertGreater(len(DEFAULT_SOCKET_OPTION), 0) + + # Should contain TCP_NODELAY option + tcp_nodelay_found = any( + opt[1] == socket.TCP_NODELAY for opt in DEFAULT_SOCKET_OPTION + ) + self.assertTrue(tcp_nodelay_found) + + def test_recv_normal(self): + """Test normal recv operation""" + mock_sock = Mock() + mock_sock.recv.return_value = b"test data" + + result = recv(mock_sock, 9) + + self.assertEqual(result, b"test data") + mock_sock.recv.assert_called_once_with(9) + + def test_recv_timeout_error(self): + """Test recv with TimeoutError""" + mock_sock = Mock() + mock_sock.recv.side_effect = TimeoutError("Connection timed out") + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 9) + + self.assertEqual(str(cm.exception), "Connection timed out") + + def test_recv_socket_timeout(self): + """Test recv with socket.timeout""" + mock_sock = Mock() + timeout_exc = socket.timeout("Socket timed out") + timeout_exc.args = ("Socket timed out",) + mock_sock.recv.side_effect = timeout_exc + mock_sock.gettimeout.return_value = 30.0 + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 9) + + # In Python 3.10+, socket.timeout is a subclass of TimeoutError + # so it's caught by the TimeoutError handler with hardcoded message + # In Python 3.9, socket.timeout is caught by socket.timeout handler + # which preserves the original message + import sys + + if sys.version_info >= (3, 10): + self.assertEqual(str(cm.exception), "Connection timed out") + else: + self.assertEqual(str(cm.exception), "Socket timed out") + + def test_recv_ssl_timeout(self): + """Test recv with SSL timeout error""" + mock_sock = Mock() + ssl_exc = SSLError("The operation timed out") + ssl_exc.args = ("The operation timed out",) + mock_sock.recv.side_effect = ssl_exc + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 9) + + self.assertEqual(str(cm.exception), "The operation timed out") + + def test_recv_ssl_non_timeout_error(self): + """Test recv with SSL non-timeout error""" + mock_sock = Mock() + ssl_exc = SSLError("SSL certificate error") + ssl_exc.args = ("SSL certificate error",) + mock_sock.recv.side_effect = ssl_exc + + # Should re-raise the original SSL error + with self.assertRaises(SSLError): + recv(mock_sock, 9) + + def test_recv_empty_response(self): + """Test recv with empty response (connection closed)""" + mock_sock = Mock() + mock_sock.recv.return_value = b"" + + with self.assertRaises(WebSocketConnectionClosedException) as cm: + recv(mock_sock, 9) + + self.assertEqual(str(cm.exception), "Connection to remote host was lost.") + + def test_recv_ssl_want_read_error(self): + """Test recv with SSLWantReadError (should retry)""" + mock_sock = Mock() + + # First call raises SSLWantReadError, second call succeeds + mock_sock.recv.side_effect = [SSLWantReadError(), b"data after retry"] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Ready to read + + result = recv(mock_sock, 100) + + self.assertEqual(result, b"data after retry") + mock_selector.register.assert_called() + mock_selector.close.assert_called() + + def test_recv_ssl_want_read_timeout(self): + """Test recv with SSLWantReadError that times out""" + mock_sock = Mock() + mock_sock.recv.side_effect = SSLWantReadError() + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # Timeout + + with self.assertRaises(WebSocketTimeoutException): + recv(mock_sock, 100) + + def test_recv_line(self): + """Test recv_line functionality""" + mock_sock = Mock() + + # Mock recv to return one character at a time + recv_calls = [b"H", b"e", b"l", b"l", b"o", b"\n"] + + with patch("websocket._socket.recv", side_effect=recv_calls) as mock_recv: + result = recv_line(mock_sock) + + self.assertEqual(result, b"Hello\n") + self.assertEqual(mock_recv.call_count, 6) + + def test_send_normal(self): + """Test normal send operation""" + mock_sock = Mock() + mock_sock.send.return_value = 9 + mock_sock.gettimeout.return_value = 30.0 + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + mock_sock.send.assert_called_with(b"test data") + + def test_send_zero_timeout(self): + """Test send with zero timeout (non-blocking)""" + mock_sock = Mock() + mock_sock.send.return_value = 9 + mock_sock.gettimeout.return_value = 0 + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + mock_sock.send.assert_called_once_with(b"test data") + + def test_send_ssl_eof_error(self): + """Test send with SSLEOFError""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + mock_sock.send.side_effect = SSLEOFError("Connection closed") + + with self.assertRaises(WebSocketConnectionClosedException) as cm: + send(mock_sock, b"test data") + + self.assertEqual(str(cm.exception), "socket is already closed.") + + def test_send_ssl_want_write_error(self): + """Test send with SSLWantWriteError (should retry)""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # First call raises SSLWantWriteError, second call succeeds + mock_sock.send.side_effect = [SSLWantWriteError(), 9] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Ready to write + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + mock_selector.register.assert_called() + mock_selector.close.assert_called() + + def test_send_socket_eagain_error(self): + """Test send with EAGAIN error (should retry)""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create socket error with EAGAIN + eagain_error = socket.error("Resource temporarily unavailable") + eagain_error.errno = errno.EAGAIN + eagain_error.args = (errno.EAGAIN, "Resource temporarily unavailable") + + # First call raises EAGAIN, second call succeeds + mock_sock.send.side_effect = [eagain_error, 9] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Ready to write + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + + def test_send_socket_ewouldblock_error(self): + """Test send with EWOULDBLOCK error (should retry)""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create socket error with EWOULDBLOCK + ewouldblock_error = socket.error("Operation would block") + ewouldblock_error.errno = errno.EWOULDBLOCK + ewouldblock_error.args = (errno.EWOULDBLOCK, "Operation would block") + + # First call raises EWOULDBLOCK, second call succeeds + mock_sock.send.side_effect = [ewouldblock_error, 9] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Ready to write + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + + def test_send_socket_other_error(self): + """Test send with other socket error (should raise)""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create socket error with different errno + other_error = socket.error("Connection reset by peer") + other_error.errno = errno.ECONNRESET + other_error.args = (errno.ECONNRESET, "Connection reset by peer") + + mock_sock.send.side_effect = other_error + + with self.assertRaises(socket.error): + send(mock_sock, b"test data") + + def test_send_socket_error_no_errno(self): + """Test send with socket error that has no errno""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create socket error without errno attribute + no_errno_error = socket.error("Generic socket error") + no_errno_error.args = ("Generic socket error",) + + mock_sock.send.side_effect = no_errno_error + + with self.assertRaises(socket.error): + send(mock_sock, b"test data") + + def test_send_write_timeout(self): + """Test send write operation timeout""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # First call raises EAGAIN + eagain_error = socket.error("Resource temporarily unavailable") + eagain_error.errno = errno.EAGAIN + eagain_error.args = (errno.EAGAIN, "Resource temporarily unavailable") + + mock_sock.send.side_effect = eagain_error + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # Timeout - nothing ready + + result = send(mock_sock, b"test data") + + # Should return 0 when write times out + self.assertEqual(result, 0) + + def test_send_string_data(self): + """Test send with string data (should be encoded)""" + mock_sock = Mock() + mock_sock.send.return_value = 9 + mock_sock.gettimeout.return_value = 30.0 + + result = send(mock_sock, "test data") + + self.assertEqual(result, 9) + mock_sock.send.assert_called_with(b"test data") + + def test_send_partial_send_retry(self): + """Test send retry mechanism""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create a scenario where send succeeds after selector retry + eagain_error = socket.error("Resource temporarily unavailable") + eagain_error.errno = errno.EAGAIN + eagain_error.args = (errno.EAGAIN, "Resource temporarily unavailable") + + # Mock the internal _send function behavior + mock_sock.send.side_effect = [eagain_error, 9] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Socket ready for writing + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + # Verify selector was used for retry mechanism + mock_selector.register.assert_called() + mock_selector.select.assert_called() + mock_selector.close.assert_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_socket_bugs.py b/contrib/python/websocket-client/websocket/tests/test_socket_bugs.py new file mode 100644 index 00000000000..72f222f5c4c --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_socket_bugs.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +import errno +import socket +import unittest +from unittest.mock import Mock, patch + +from websocket._socket import recv +from websocket._ssl_compat import SSLWantReadError +from websocket._exceptions import ( + WebSocketTimeoutException, + WebSocketConnectionClosedException, +) + +""" +test_socket_bugs.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class SocketBugsTest(unittest.TestCase): + """Test bugs found in socket handling logic""" + + def test_bug_implicit_none_return_from_ssl_want_read_fixed(self): + """ + BUG #5 FIX VERIFICATION: Test SSLWantReadError timeout now raises correct exception + + Bug was in _socket.py:100-101 - SSLWantReadError except block returned None implicitly + Fixed: Now properly handles timeout with WebSocketTimeoutException + """ + mock_sock = Mock() + mock_sock.recv.side_effect = SSLWantReadError() + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # Timeout - no data ready + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 100) + + # Verify correct timeout exception and message + self.assertIn("Connection timed out waiting for data", str(cm.exception)) + + def test_bug_implicit_none_return_from_socket_error_fixed(self): + """ + BUG #5 FIX VERIFICATION: Test that socket.error with EAGAIN now handles timeout correctly + + Bug was in _socket.py:102-105 - socket.error except block returned None implicitly + Fixed: Now properly handles timeout with WebSocketTimeoutException + """ + mock_sock = Mock() + + # Create socket error with EAGAIN (should be retried) + eagain_error = OSError(errno.EAGAIN, "Resource temporarily unavailable") + + # First call raises EAGAIN, selector times out on retry + mock_sock.recv.side_effect = eagain_error + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # Timeout - no data ready + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 100) + + # Verify correct timeout exception and message + self.assertIn("Connection timed out waiting for data", str(cm.exception)) + + def test_bug_wrong_exception_for_selector_timeout_fixed(self): + """ + BUG #6 FIX VERIFICATION: Test that selector timeout now raises correct exception type + + Bug was in _socket.py:115 returning None for timeout, treated as connection error + Fixed: Now raises WebSocketTimeoutException directly + """ + mock_sock = Mock() + mock_sock.recv.side_effect = SSLWantReadError() # Trigger retry path + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # TIMEOUT - this is key! + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 100) + + # Verify it's the correct timeout exception with proper message + self.assertIn("Connection timed out waiting for data", str(cm.exception)) + + # This proves the fix works: + # 1. selector.select() returns [] (timeout) + # 2. _recv() now raises WebSocketTimeoutException directly + # 3. No more misclassification as connection closed error! + + def test_socket_timeout_exception_handling(self): + """ + Test that socket.timeout exceptions are properly handled + """ + mock_sock = Mock() + mock_sock.gettimeout.return_value = 1.0 + + # Simulate a real socket.timeout scenario + mock_sock.recv.side_effect = socket.timeout("Operation timed out") + + # This works correctly - socket.timeout raises WebSocketTimeoutException + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 100) + + # In Python 3.10+, socket.timeout is a subclass of TimeoutError + # so it's caught by the TimeoutError handler with hardcoded message + # In Python 3.9, socket.timeout is caught by socket.timeout handler + # which preserves the original message + import sys + + if sys.version_info >= (3, 10): + self.assertIn("Connection timed out", str(cm.exception)) + else: + self.assertIn("Operation timed out", str(cm.exception)) + + def test_correct_ssl_want_read_retry_behavior(self): + """Test the correct behavior when SSLWantReadError is properly handled""" + mock_sock = Mock() + + # First call raises SSLWantReadError, second call succeeds + mock_sock.recv.side_effect = [SSLWantReadError(), b"data after retry"] + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Data ready after wait + + # This should work correctly + result = recv(mock_sock, 100) + self.assertEqual(result, b"data after retry") + + # Selector should be used for retry + mock_selector.register.assert_called() + mock_selector.select.assert_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_ssl_compat.py b/contrib/python/websocket-client/websocket/tests/test_ssl_compat.py new file mode 100644 index 00000000000..9dcd674b0f0 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_ssl_compat.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +import sys +import unittest +from unittest.mock import patch + +""" +test_ssl_compat.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class SSLCompatTest(unittest.TestCase): + def test_ssl_available(self): + """Test that SSL is available in normal conditions""" + import websocket._ssl_compat as ssl_compat + + # In normal conditions, SSL should be available + self.assertTrue(ssl_compat.HAVE_SSL) + self.assertIsNotNone(ssl_compat.ssl) + + # SSL exception classes should be available + self.assertTrue(hasattr(ssl_compat, "SSLError")) + self.assertTrue(hasattr(ssl_compat, "SSLEOFError")) + self.assertTrue(hasattr(ssl_compat, "SSLWantReadError")) + self.assertTrue(hasattr(ssl_compat, "SSLWantWriteError")) + + def test_ssl_not_available(self): + """Test fallback behavior when SSL is not available""" + # Remove ssl_compat from modules to force reimport + if "websocket._ssl_compat" in sys.modules: + del sys.modules["websocket._ssl_compat"] + + # Mock the ssl module to not be available + import builtins + + original_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "ssl": + raise ImportError("No module named 'ssl'") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + import websocket._ssl_compat as ssl_compat + + # SSL should not be available + self.assertFalse(ssl_compat.HAVE_SSL) + self.assertIsNone(ssl_compat.ssl) + + # Fallback exception classes should be available and functional + self.assertTrue(issubclass(ssl_compat.SSLError, Exception)) + self.assertTrue(issubclass(ssl_compat.SSLEOFError, Exception)) + self.assertTrue(issubclass(ssl_compat.SSLWantReadError, Exception)) + self.assertTrue(issubclass(ssl_compat.SSLWantWriteError, Exception)) + + # Test that exceptions can be instantiated + ssl_error = ssl_compat.SSLError("test error") + self.assertIsInstance(ssl_error, Exception) + self.assertEqual(str(ssl_error), "test error") + + ssl_eof_error = ssl_compat.SSLEOFError("test eof") + self.assertIsInstance(ssl_eof_error, Exception) + + ssl_want_read = ssl_compat.SSLWantReadError("test read") + self.assertIsInstance(ssl_want_read, Exception) + + ssl_want_write = ssl_compat.SSLWantWriteError("test write") + self.assertIsInstance(ssl_want_write, Exception) + + def tearDown(self): + """Clean up after tests""" + # Ensure ssl_compat is reimported fresh for next test + if "websocket._ssl_compat" in sys.modules: + del sys.modules["websocket._ssl_compat"] + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_ssl_edge_cases.py b/contrib/python/websocket-client/websocket/tests/test_ssl_edge_cases.py new file mode 100644 index 00000000000..a8e14d3f4ed --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_ssl_edge_cases.py @@ -0,0 +1,638 @@ +# -*- coding: utf-8 -*- +import unittest +import socket +import ssl +from unittest.mock import Mock, patch, MagicMock + +from websocket._ssl_compat import ( + SSLError, + SSLEOFError, + SSLWantReadError, + SSLWantWriteError, + HAVE_SSL, +) +from websocket._http import _ssl_socket, _wrap_sni_socket +from websocket._exceptions import WebSocketException +from websocket._socket import recv, send + +""" +test_ssl_edge_cases.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class SSLEdgeCasesTest(unittest.TestCase): + + def setUp(self): + if not HAVE_SSL: + self.skipTest("SSL not available") + + def test_ssl_handshake_failure(self): + """Test SSL handshake failure scenarios""" + mock_sock = Mock() + + # Test SSL handshake timeout + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.side_effect = socket.timeout( + "SSL handshake timeout" + ) + + sslopt = {"cert_reqs": ssl.CERT_REQUIRED} + + with self.assertRaises(socket.timeout): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_certificate_verification_failures(self): + """Test various SSL certificate verification failure scenarios""" + mock_sock = Mock() + + # Test certificate verification failure + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.side_effect = ssl.SSLCertVerificationError( + "Certificate verification failed" + ) + + sslopt = {"cert_reqs": ssl.CERT_REQUIRED, "check_hostname": True} + + with self.assertRaises(ssl.SSLCertVerificationError): + _ssl_socket(mock_sock, sslopt, "badssl.example") + + def test_ssl_context_configuration_edge_cases(self): + """Test SSL context configuration with various edge cases""" + mock_sock = Mock() + + # Test with pre-created SSL context + with patch("ssl.SSLContext") as mock_ssl_context: + existing_context = Mock() + existing_context.wrap_socket.return_value = Mock() + mock_ssl_context.return_value = existing_context + + sslopt = {"context": existing_context} + + # Call _ssl_socket which should use the existing context + _ssl_socket(mock_sock, sslopt, "example.com") + + # Should use the provided context, not create a new one + existing_context.wrap_socket.assert_called_once() + + def test_ssl_ca_bundle_environment_edge_cases(self): + """Test CA bundle environment variable edge cases""" + mock_sock = Mock() + + # Test with non-existent CA bundle file + with patch.dict( + "os.environ", {"WEBSOCKET_CLIENT_CA_BUNDLE": "/nonexistent/ca-bundle.crt"} + ): + with patch("os.path.isfile", return_value=False): + with patch("os.path.isdir", return_value=False): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + sslopt = {} + _ssl_socket(mock_sock, sslopt, "example.com") + + # Should not try to load non-existent CA bundle + mock_context.load_verify_locations.assert_not_called() + + # Test with CA bundle directory + with patch.dict("os.environ", {"WEBSOCKET_CLIENT_CA_BUNDLE": "/etc/ssl/certs"}): + with patch("os.path.isfile", return_value=False): + with patch("os.path.isdir", return_value=True): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + sslopt = {} + _ssl_socket(mock_sock, sslopt, "example.com") + + # Should load CA directory + mock_context.load_verify_locations.assert_called_with( + cafile=None, capath="/etc/ssl/certs" + ) + + def test_ssl_cipher_configuration_edge_cases(self): + """Test SSL cipher configuration edge cases""" + mock_sock = Mock() + + # Test with invalid cipher suite + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.set_ciphers.side_effect = ssl.SSLError( + "No cipher can be selected" + ) + mock_context.wrap_socket.return_value = Mock() + + sslopt = {"ciphers": "INVALID_CIPHER"} + + with self.assertRaises(WebSocketException): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_ecdh_curve_edge_cases(self): + """Test ECDH curve configuration edge cases""" + mock_sock = Mock() + + # Test with invalid ECDH curve + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.set_ecdh_curve.side_effect = ValueError("unknown curve name") + mock_context.wrap_socket.return_value = Mock() + + sslopt = {"ecdh_curve": "invalid_curve"} + + with self.assertRaises(WebSocketException): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_client_certificate_edge_cases(self): + """Test client certificate configuration edge cases""" + mock_sock = Mock() + + # Test with non-existent client certificate + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.load_cert_chain.side_effect = FileNotFoundError("No such file") + mock_context.wrap_socket.return_value = Mock() + + sslopt = {"certfile": "/nonexistent/client.crt"} + + with self.assertRaises(WebSocketException): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_want_read_write_retry_edge_cases(self): + """Test SSL want read/write retry edge cases""" + mock_sock = Mock() + + # Test SSLWantReadError with multiple retries before success + read_attempts = [0] # Use list for mutable reference + + def mock_recv(bufsize): + read_attempts[0] += 1 + if read_attempts[0] == 1: + raise SSLWantReadError("The operation did not complete") + elif read_attempts[0] == 2: + return b"data after retries" + else: + return b"" + + mock_sock.recv.side_effect = mock_recv + mock_sock.gettimeout.return_value = 30.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Always ready + + result = recv(mock_sock, 100) + + self.assertEqual(result, b"data after retries") + self.assertEqual(read_attempts[0], 2) + # Should have used selector for retry + mock_selector.register.assert_called() + mock_selector.select.assert_called() + + def test_ssl_want_write_retry_edge_cases(self): + """Test SSL want write retry edge cases""" + mock_sock = Mock() + + # Test SSLWantWriteError with multiple retries before success + write_attempts = [0] # Use list for mutable reference + + def mock_send(data): + write_attempts[0] += 1 + if write_attempts[0] == 1: + raise SSLWantWriteError("The operation did not complete") + elif write_attempts[0] == 2: + return len(data) + else: + return 0 + + mock_sock.send.side_effect = mock_send + mock_sock.gettimeout.return_value = 30.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Always ready + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) # len("test data") + self.assertEqual(write_attempts[0], 2) + + def test_ssl_eof_error_edge_cases(self): + """Test SSL EOF error edge cases""" + mock_sock = Mock() + + # Test SSLEOFError during send + mock_sock.send.side_effect = SSLEOFError("SSL connection has been closed") + mock_sock.gettimeout.return_value = 30.0 + + from websocket._exceptions import WebSocketConnectionClosedException + + with self.assertRaises(WebSocketConnectionClosedException): + send(mock_sock, b"test data") + + def test_ssl_pending_data_edge_cases(self): + """Test SSL pending data scenarios""" + from websocket._dispatcher import SSLDispatcher + from websocket._app import WebSocketApp + + # Mock SSL socket with pending data + mock_ssl_sock = Mock() + mock_ssl_sock.pending.return_value = 1024 # Simulates pending SSL data + + # Mock WebSocketApp + mock_app = Mock(spec=WebSocketApp) + mock_app.sock = Mock() + mock_app.sock.sock = mock_ssl_sock + + dispatcher = SSLDispatcher(mock_app, 5.0) + + # When there's pending data, should return immediately without selector + result = dispatcher.select(mock_ssl_sock, Mock()) + + # Should return the socket list when there's pending data + self.assertEqual(result, [mock_ssl_sock]) + mock_ssl_sock.pending.assert_called_once() + + def test_ssl_renegotiation_edge_cases(self): + """Test SSL renegotiation scenarios""" + mock_sock = Mock() + + # Simulate SSL renegotiation during read + call_count = 0 + + def mock_recv(bufsize): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise SSLWantReadError("SSL renegotiation required") + return b"data after renegotiation" + + mock_sock.recv.side_effect = mock_recv + mock_sock.gettimeout.return_value = 30.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] + + result = recv(mock_sock, 100) + + self.assertEqual(result, b"data after renegotiation") + self.assertEqual(call_count, 2) + + def test_ssl_server_hostname_override(self): + """Test SSL server hostname override scenarios""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Test server_hostname override + sslopt = {"server_hostname": "override.example.com"} + _ssl_socket(mock_sock, sslopt, "original.example.com") + + # Should use override hostname in wrap_socket call + mock_context.wrap_socket.assert_called_with( + mock_sock, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname="override.example.com", + ) + + def test_ssl_protocol_version_edge_cases(self): + """Test SSL protocol version edge cases""" + mock_sock = Mock() + + # Test with deprecated SSL version + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Test that deprecated ssl_version is still handled + if hasattr(ssl, "PROTOCOL_TLS"): + sslopt = {"ssl_version": ssl.PROTOCOL_TLS} + _ssl_socket(mock_sock, sslopt, "example.com") + + mock_ssl_context.assert_called_with(ssl.PROTOCOL_TLS) + + def test_ssl_keylog_file_edge_cases(self): + """Test SSL keylog file configuration edge cases""" + mock_sock = Mock() + + # Test with SSLKEYLOGFILE environment variable + with patch.dict("os.environ", {"SSLKEYLOGFILE": "/tmp/ssl_keys.log"}): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + sslopt = {} + _ssl_socket(mock_sock, sslopt, "example.com") + + # Should set keylog_filename + self.assertEqual(mock_context.keylog_filename, "/tmp/ssl_keys.log") + + def test_ssl_context_verification_modes(self): + """Test different SSL verification mode combinations""" + mock_sock = Mock() + + test_cases = [ + # (cert_reqs, check_hostname, expected_verify_mode, expected_check_hostname) + (ssl.CERT_NONE, False, ssl.CERT_NONE, False), + (ssl.CERT_REQUIRED, False, ssl.CERT_REQUIRED, False), + (ssl.CERT_REQUIRED, True, ssl.CERT_REQUIRED, True), + ] + + for cert_reqs, check_hostname, expected_verify, expected_check in test_cases: + with self.subTest(cert_reqs=cert_reqs, check_hostname=check_hostname): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + sslopt = {"cert_reqs": cert_reqs, "check_hostname": check_hostname} + _ssl_socket(mock_sock, sslopt, "example.com") + + self.assertEqual(mock_context.verify_mode, expected_verify) + self.assertEqual(mock_context.check_hostname, expected_check) + + def test_ssl_socket_shutdown_edge_cases(self): + """Test SSL socket shutdown edge cases""" + from websocket._core import WebSocket + + mock_ssl_sock = Mock() + mock_ssl_sock.shutdown.side_effect = SSLError("SSL shutdown failed") + + ws = WebSocket() + ws.sock = mock_ssl_sock + ws.connected = True + + # Should handle SSL shutdown errors gracefully + try: + ws.close() + except SSLError: + self.fail("SSL shutdown error should be handled gracefully") + + def test_ssl_socket_close_during_operation(self): + """Test SSL socket being closed during ongoing operations""" + mock_sock = Mock() + + # Simulate SSL socket being closed during recv + mock_sock.recv.side_effect = SSLError( + "SSL connection has been closed unexpectedly" + ) + mock_sock.gettimeout.return_value = 30.0 + + from websocket._exceptions import WebSocketConnectionClosedException + + # Should handle unexpected SSL closure + with self.assertRaises((SSLError, WebSocketConnectionClosedException)): + recv(mock_sock, 100) + + def test_ssl_compression_edge_cases(self): + """Test SSL compression configuration edge cases""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Test SSL compression options (if available) + sslopt = {"compression": False} # Some SSL contexts support this + + try: + _ssl_socket(mock_sock, sslopt, "example.com") + # Should not fail even if compression option is not supported + except AttributeError: + # Expected if SSL context doesn't support compression option + pass + + def test_ssl_session_reuse_edge_cases(self): + """Test SSL session reuse scenarios""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_ssl_sock = Mock() + mock_context.wrap_socket.return_value = mock_ssl_sock + + # Test session reuse + mock_ssl_sock.session = "mock_session" + mock_ssl_sock.session_reused = True + + result = _ssl_socket(mock_sock, {}, "example.com") + + # Should handle session reuse without issues + self.assertIsNotNone(result) + + def test_ssl_alpn_protocol_edge_cases(self): + """Test SSL ALPN (Application Layer Protocol Negotiation) edge cases""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Test ALPN configuration + sslopt = {"alpn_protocols": ["http/1.1", "h2"]} + + # ALPN protocols are not currently supported in the SSL wrapper + # but the test should not fail + result = _ssl_socket(mock_sock, sslopt, "example.com") + self.assertIsNotNone(result) + # ALPN would need to be implemented in _wrap_sni_socket function + + def test_ssl_sni_edge_cases(self): + """Test SSL SNI (Server Name Indication) edge cases""" + mock_sock = Mock() + + # Test with IPv6 address (should not use SNI) + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # IPv6 addresses should not be used for SNI + ipv6_hostname = "2001:db8::1" + _ssl_socket(mock_sock, {}, ipv6_hostname) + + # Should use IPv6 address as server_hostname + mock_context.wrap_socket.assert_called_with( + mock_sock, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname=ipv6_hostname, + ) + + def test_ssl_buffer_size_edge_cases(self): + """Test SSL buffer size related edge cases""" + mock_sock = Mock() + + def mock_recv(bufsize): + # SSL should never try to read more than 16KB at once + if bufsize > 16384: + raise SSLError("[SSL: BAD_LENGTH] buffer too large") + return b"A" * min(bufsize, 1024) # Return smaller chunks + + mock_sock.recv.side_effect = mock_recv + mock_sock.gettimeout.return_value = 30.0 + + from websocket._abnf import frame_buffer + + # Frame buffer should handle large requests by chunking + fb = frame_buffer(lambda size: recv(mock_sock, size), skip_utf8_validation=True) + + # This should work even with large size due to chunking + result = fb.recv_strict(16384) # Exactly 16KB + + self.assertGreater(len(result), 0) + + def test_ssl_protocol_downgrade_protection(self): + """Test SSL protocol downgrade protection""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.side_effect = ssl.SSLError( + "SSLV3_ALERT_HANDSHAKE_FAILURE" + ) + + sslopt = {"ssl_version": ssl.PROTOCOL_TLS_CLIENT} + + # Should propagate SSL protocol errors + with self.assertRaises(ssl.SSLError): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_certificate_chain_validation(self): + """Test SSL certificate chain validation edge cases""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + + # Test certificate chain validation failure + mock_context.wrap_socket.side_effect = ssl.SSLCertVerificationError( + "certificate verify failed: certificate has expired" + ) + + sslopt = {"cert_reqs": ssl.CERT_REQUIRED, "check_hostname": True} + + with self.assertRaises(ssl.SSLCertVerificationError): + _ssl_socket(mock_sock, sslopt, "expired.badssl.com") + + def test_ssl_weak_cipher_rejection(self): + """Test SSL weak cipher rejection scenarios""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.side_effect = ssl.SSLError("no shared cipher") + + sslopt = {"ciphers": "RC4-MD5"} # Intentionally weak cipher + + # Should fail with weak ciphers (SSL error is not wrapped by our code) + with self.assertRaises(ssl.SSLError): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_hostname_verification_edge_cases(self): + """Test SSL hostname verification edge cases""" + mock_sock = Mock() + + # Test with wildcard certificate scenarios + test_cases = [ + ("*.example.com", "subdomain.example.com"), # Valid wildcard + ("*.example.com", "sub.subdomain.example.com"), # Invalid wildcard depth + ("example.com", "www.example.com"), # Hostname mismatch + ] + + for cert_hostname, connect_hostname in test_cases: + with self.subTest(cert=cert_hostname, hostname=connect_hostname): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + + if ( + cert_hostname != connect_hostname + and "sub.subdomain" in connect_hostname + ): + # Simulate hostname verification failure for invalid wildcard + mock_context.wrap_socket.side_effect = ssl.SSLCertVerificationError( + f"hostname '{connect_hostname}' doesn't match '{cert_hostname}'" + ) + + sslopt = { + "cert_reqs": ssl.CERT_REQUIRED, + "check_hostname": True, + } + + with self.assertRaises(ssl.SSLCertVerificationError): + _ssl_socket(mock_sock, sslopt, connect_hostname) + else: + mock_context.wrap_socket.return_value = Mock() + sslopt = { + "cert_reqs": ssl.CERT_REQUIRED, + "check_hostname": True, + } + + # Should succeed for valid cases + result = _ssl_socket(mock_sock, sslopt, connect_hostname) + self.assertIsNotNone(result) + + def test_ssl_memory_bio_edge_cases(self): + """Test SSL memory BIO edge cases""" + mock_sock = Mock() + + # Test SSL memory BIO scenarios (if available) + try: + import ssl + + if hasattr(ssl, "MemoryBIO"): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Memory BIO should work if available + _ssl_socket(mock_sock, {}, "example.com") + + # Standard socket wrapping should still work + mock_context.wrap_socket.assert_called_once() + except (ImportError, AttributeError): + self.skipTest("SSL MemoryBIO not available") + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_url.py b/contrib/python/websocket-client/websocket/tests/test_url.py index 110fdfad70a..bbb39b0f3f7 100644 --- a/contrib/python/websocket-client/websocket/tests/test_url.py +++ b/contrib/python/websocket-client/websocket/tests/test_url.py @@ -15,7 +15,7 @@ from websocket._exceptions import WebSocketProxyException test_url.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,8 @@ class UrlTest(unittest.TestCase): self.assertTrue(_is_address_in_network("127.0.0.1", "127.0.0.0/8")) self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8")) self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24")) + self.assertTrue(_is_address_in_network("2001:db8::1", "2001:db8::/64")) + self.assertFalse(_is_address_in_network("2001:db8:1::1", "2001:db8::/64")) def test_parse_url(self): p = parse_url("ws://www.example.com/r") @@ -167,11 +169,16 @@ class IsNoProxyHostTest(unittest.TestCase): self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"])) self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"])) self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"])) - os.environ["no_proxy"] = "127.0.0.0/8" + self.assertTrue(_is_no_proxy_host("2001:db8::1", ["2001:db8::/64"])) + self.assertFalse(_is_no_proxy_host("2001:db8:1::1", ["2001:db8::/64"])) + os.environ["no_proxy"] = "127.0.0.0/8,2001:db8::/64" self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) self.assertTrue(_is_no_proxy_host("127.0.0.2", None)) - os.environ["no_proxy"] = "127.0.0.0/24" + self.assertTrue(_is_no_proxy_host("2001:db8::1", None)) + self.assertFalse(_is_no_proxy_host("2001:db8:1::1", None)) + os.environ["no_proxy"] = "127.0.0.0/24,2001:db8::/64" self.assertFalse(_is_no_proxy_host("127.1.0.1", None)) + self.assertFalse(_is_no_proxy_host("2001:db8:1::1", None)) def test_hostname_match(self): self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"])) @@ -427,12 +434,12 @@ class ProxyInfoTest(unittest.TestCase): ("localhost2", 3128, ("a", "b")), ) - os.environ[ - "http_proxy" - ] = "http://john%40example.com:P%40SSWORD@localhost:3128/" - os.environ[ - "https_proxy" - ] = "http://john%40example.com:P%40SSWORD@localhost2:3128/" + os.environ["http_proxy"] = ( + "http://john%40example.com:P%40SSWORD@localhost:3128/" + ) + os.environ["https_proxy"] = ( + "http://john%40example.com:P%40SSWORD@localhost2:3128/" + ) self.assertEqual( get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, ("[email protected]", "P@SSWORD")), diff --git a/contrib/python/websocket-client/websocket/tests/test_utils.py b/contrib/python/websocket-client/websocket/tests/test_utils.py new file mode 100644 index 00000000000..deb9751bd16 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_utils.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +import sys +import unittest +from unittest.mock import patch + +""" +test_utils.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class UtilsTest(unittest.TestCase): + def test_nolock(self): + """Test NoLock context manager""" + from websocket._utils import NoLock + + lock = NoLock() + + # Test that it can be used as context manager + with lock: + pass # Should not raise any exception + + # Test enter/exit methods directly + self.assertIsNone(lock.__enter__()) + self.assertIsNone(lock.__exit__(None, None, None)) + + def test_utf8_validation_with_wsaccel(self): + """Test UTF-8 validation when wsaccel is available""" + # Import normally (wsaccel should be available in test environment) + from websocket._utils import validate_utf8 + + # Test valid UTF-8 strings (convert to bytes for wsaccel) + self.assertTrue(validate_utf8("Hello, World!".encode("utf-8"))) + self.assertTrue(validate_utf8("🌟 Unicode test".encode("utf-8"))) + self.assertTrue(validate_utf8(b"Hello, bytes")) + self.assertTrue(validate_utf8("Héllo with accénts".encode("utf-8"))) + + # Test invalid UTF-8 sequences + self.assertFalse(validate_utf8(b"\xff\xfe")) # Invalid UTF-8 + self.assertFalse(validate_utf8(b"\x80\x80")) # Invalid continuation + + def test_utf8_validation_fallback(self): + """Test UTF-8 validation fallback when wsaccel is not available""" + # Remove _utils from modules to force reimport + if "websocket._utils" in sys.modules: + del sys.modules["websocket._utils"] + + # Mock wsaccel import to raise ImportError + import builtins + + original_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if "wsaccel" in name: + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + import websocket._utils as utils + + # Test valid UTF-8 strings with fallback implementation (convert strings to bytes) + self.assertTrue(utils.validate_utf8("Hello, World!".encode("utf-8"))) + self.assertTrue(utils.validate_utf8(b"Hello, bytes")) + self.assertTrue(utils.validate_utf8("ASCII text".encode("utf-8"))) + + # Test Unicode strings (convert to bytes) + self.assertTrue(utils.validate_utf8("🌟 Unicode test".encode("utf-8"))) + self.assertTrue(utils.validate_utf8("Héllo with accénts".encode("utf-8"))) + + # Test empty string/bytes + self.assertTrue(utils.validate_utf8("".encode("utf-8"))) + self.assertTrue(utils.validate_utf8(b"")) + + # Test invalid UTF-8 sequences (should return False) + self.assertFalse(utils.validate_utf8(b"\xff\xfe")) + self.assertFalse(utils.validate_utf8(b"\x80\x80")) + + # Note: The fallback implementation may have different validation behavior + # than wsaccel, so we focus on clearly invalid sequences + + def test_extract_err_message(self): + """Test extract_err_message function""" + from websocket._utils import extract_err_message + + # Test with exception that has args + exc_with_args = Exception("Test error message") + self.assertEqual(extract_err_message(exc_with_args), "Test error message") + + # Test with exception that has multiple args + exc_multi_args = Exception("First arg", "Second arg") + self.assertEqual(extract_err_message(exc_multi_args), "First arg") + + # Test with exception that has no args + exc_no_args = Exception() + self.assertIsNone(extract_err_message(exc_no_args)) + + def test_extract_error_code(self): + """Test extract_error_code function""" + from websocket._utils import extract_error_code + + # Test with exception that has integer as first arg + exc_with_code = Exception(404, "Not found") + self.assertEqual(extract_error_code(exc_with_code), 404) + + # Test with exception that has string as first arg + exc_with_string = Exception("Error message", "Second arg") + self.assertIsNone(extract_error_code(exc_with_string)) + + # Test with exception that has only one arg + exc_single_arg = Exception("Single arg") + self.assertIsNone(extract_error_code(exc_single_arg)) + + # Test with exception that has no args + exc_no_args = Exception() + self.assertIsNone(extract_error_code(exc_no_args)) + + def tearDown(self): + """Clean up after tests""" + # Ensure _utils is reimported fresh for next test + if "websocket._utils" in sys.modules: + del sys.modules["websocket._utils"] + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_websocket.py b/contrib/python/websocket-client/websocket/tests/test_websocket.py index 892312a2dbd..9e36df7c011 100644 --- a/contrib/python/websocket-client/websocket/tests/test_websocket.py +++ b/contrib/python/websocket-client/websocket/tests/test_websocket.py @@ -7,7 +7,11 @@ import unittest from base64 import decodebytes as base64decode import websocket as ws -from websocket._exceptions import WebSocketBadStatusException, WebSocketAddressException +from websocket._exceptions import ( + WebSocketBadStatusException, + WebSocketAddressException, + WebSocketException, +) from websocket._handshake import _create_sec_websocket_key from websocket._handshake import _validate as _validate_header from websocket._http import read_headers @@ -17,7 +21,7 @@ from websocket._utils import validate_utf8 test_websocket.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -296,7 +300,7 @@ class WebSocketTest(unittest.TestCase): def test_close(self): sock = ws.WebSocket() sock.connected = True - sock.close + sock.close() sock = ws.WebSocket() s = sock.sock = SockMock() @@ -455,7 +459,7 @@ class HandshakeTest(unittest.TestCase): self.assertRaises(ValueError, websock1.connect, "wss://api.bitfinex.com/ws/2") websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"}) self.assertRaises( - FileNotFoundError, websock2.connect, "wss://api.bitfinex.com/ws/2" + WebSocketException, websock2.connect, "wss://api.bitfinex.com/ws/2" ) @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") diff --git a/contrib/python/websocket-client/ya.make b/contrib/python/websocket-client/ya.make index 3c7e7f401d0..effb2ac24bf 100644 --- a/contrib/python/websocket-client/ya.make +++ b/contrib/python/websocket-client/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(1.8.0) +VERSION(1.9.0) LICENSE(Apache-2.0) @@ -15,6 +15,7 @@ PY_SRCS( websocket/_app.py websocket/_cookiejar.py websocket/_core.py + websocket/_dispatcher.py websocket/_exceptions.py websocket/_handshake.py websocket/_http.py |
