diff options
| author | robot-piglet <[email protected]> | 2025-10-22 11:36:42 +0300 |
|---|---|---|
| committer | robot-piglet <[email protected]> | 2025-10-22 12:14:27 +0300 |
| commit | 6a490d481992dac77fa8785bb4d6e6cafea36fa3 (patch) | |
| tree | 28dac4c8ea239eedadc726b73a16514223e56389 /contrib/python/websocket-client/websocket/tests | |
| parent | d924ab94175835dc15b389ee8969ff0ddfd35930 (diff) | |
Intermediate changes
commit_hash:6bfda3fd45ff19cb21e3edc6e8b7dad337978a7e
Diffstat (limited to 'contrib/python/websocket-client/websocket/tests')
14 files changed, 2271 insertions, 17 deletions
diff --git a/contrib/python/websocket-client/websocket/tests/test_abnf.py b/contrib/python/websocket-client/websocket/tests/test_abnf.py index a749f13bd54..664ea3b314c 100644 --- a/contrib/python/websocket-client/websocket/tests/test_abnf.py +++ b/contrib/python/websocket-client/websocket/tests/test_abnf.py @@ -9,7 +9,7 @@ from websocket._exceptions import WebSocketProtocolException test_abnf.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/contrib/python/websocket-client/websocket/tests/test_app.py b/contrib/python/websocket-client/websocket/tests/test_app.py index 18eace54427..c127e5c9e7e 100644 --- a/contrib/python/websocket-client/websocket/tests/test_app.py +++ b/contrib/python/websocket-client/websocket/tests/test_app.py @@ -12,7 +12,7 @@ import websocket as ws test_app.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -347,6 +347,49 @@ class WebSocketAppTest(unittest.TestCase): self.assertIsInstance(exc, ws.WebSocketTimeoutException) self.assertEqual(str(exc), "ping/pong timed out") + def test_dispatcher_selection_default(self): + """Test default dispatcher selection""" + app = ws.WebSocketApp("ws://example.com") + + # Test default dispatcher (non-SSL) + dispatcher = app.create_dispatcher(ping_timeout=10, is_ssl=False) + self.assertIsInstance(dispatcher, ws._dispatcher.Dispatcher) + + def test_dispatcher_selection_ssl(self): + """Test SSL dispatcher selection""" + app = ws.WebSocketApp("wss://example.com") + + # Test SSL dispatcher + dispatcher = app.create_dispatcher(ping_timeout=10, is_ssl=True) + self.assertIsInstance(dispatcher, ws._dispatcher.SSLDispatcher) + + def test_dispatcher_selection_custom(self): + """Test custom dispatcher selection""" + from unittest.mock import Mock + + app = ws.WebSocketApp("ws://example.com") + custom_dispatcher = Mock() + handle_disconnect = Mock() + + # Test wrapped dispatcher with custom dispatcher + dispatcher = app.create_dispatcher( + ping_timeout=10, + dispatcher=custom_dispatcher, + handleDisconnect=handle_disconnect, + ) + self.assertIsInstance(dispatcher, ws._dispatcher.WrappedDispatcher) + self.assertEqual(dispatcher.dispatcher, custom_dispatcher) + self.assertEqual(dispatcher.handleDisconnect, handle_disconnect) + + def test_dispatcher_selection_no_ping_timeout(self): + """Test dispatcher selection without ping timeout""" + app = ws.WebSocketApp("ws://example.com") + + # Test with None ping_timeout (should default to 10) + dispatcher = app.create_dispatcher(ping_timeout=None, is_ssl=False) + self.assertIsInstance(dispatcher, ws._dispatcher.Dispatcher) + self.assertEqual(dispatcher.ping_timeout, 10) + if __name__ == "__main__": unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_cookiejar.py b/contrib/python/websocket-client/websocket/tests/test_cookiejar.py index 67eddb627ae..7590f0caa73 100644 --- a/contrib/python/websocket-client/websocket/tests/test_cookiejar.py +++ b/contrib/python/websocket-client/websocket/tests/test_cookiejar.py @@ -6,7 +6,7 @@ from websocket._cookiejar import SimpleCookieJar test_cookiejar.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/contrib/python/websocket-client/websocket/tests/test_dispatcher.py b/contrib/python/websocket-client/websocket/tests/test_dispatcher.py new file mode 100644 index 00000000000..457bed6cb46 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_dispatcher.py @@ -0,0 +1,385 @@ +# -*- coding: utf-8 -*- +import socket +import unittest +from unittest.mock import Mock, patch, MagicMock +import threading +import time + +import websocket +from websocket._dispatcher import ( + Dispatcher, + DispatcherBase, + SSLDispatcher, + WrappedDispatcher, +) + +""" +test_dispatcher.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class MockApp: + """Mock WebSocketApp for testing""" + + def __init__(self): + self.keep_running = True + self.sock = Mock() + self.sock.sock = Mock() + + +class MockSocket: + """Mock socket for testing""" + + def __init__(self): + self.pending_return = False + + def pending(self): + return self.pending_return + + +class MockDispatcher: + """Mock external dispatcher for WrappedDispatcher testing""" + + def __init__(self): + self.signal_calls = [] + self.abort_calls = [] + self.read_calls = [] + self.buffwrite_calls = [] + self.timeout_calls = [] + + def signal(self, sig, handler): + self.signal_calls.append((sig, handler)) + + def abort(self): + self.abort_calls.append(True) + + def read(self, sock, callback): + self.read_calls.append((sock, callback)) + + def buffwrite(self, sock, data, send_func, disconnect_handler): + self.buffwrite_calls.append((sock, data, send_func, disconnect_handler)) + + def timeout(self, seconds, callback, *args): + self.timeout_calls.append((seconds, callback, args)) + + +class DispatcherTest(unittest.TestCase): + def setUp(self): + self.app = MockApp() + + def test_dispatcher_base_init(self): + """Test DispatcherBase initialization""" + dispatcher = DispatcherBase(self.app, 30.0) + + self.assertEqual(dispatcher.app, self.app) + self.assertEqual(dispatcher.ping_timeout, 30.0) + + def test_dispatcher_base_timeout(self): + """Test DispatcherBase timeout method""" + dispatcher = DispatcherBase(self.app, 30.0) + callback = Mock() + + # Test with seconds=None (should call callback immediately) + dispatcher.timeout(None, callback) + callback.assert_called_once() + + # Test with seconds > 0 (would sleep in real implementation) + callback.reset_mock() + start_time = time.time() + dispatcher.timeout(0.1, callback) + elapsed = time.time() - start_time + + callback.assert_called_once() + self.assertGreaterEqual(elapsed, 0.05) # Allow some tolerance + + def test_dispatcher_base_reconnect(self): + """Test DispatcherBase reconnect method""" + dispatcher = DispatcherBase(self.app, 30.0) + reconnector = Mock() + + # Test normal reconnect + dispatcher.reconnect(1, reconnector) + reconnector.assert_called_once_with(reconnecting=True) + + # Test reconnect with KeyboardInterrupt + reconnector.reset_mock() + reconnector.side_effect = KeyboardInterrupt("User interrupted") + + with self.assertRaises(KeyboardInterrupt): + dispatcher.reconnect(1, reconnector) + + def test_dispatcher_base_send(self): + """Test DispatcherBase send method""" + dispatcher = DispatcherBase(self.app, 30.0) + mock_sock = Mock() + test_data = b"test data" + + with patch("websocket._dispatcher.send") as mock_send: + mock_send.return_value = len(test_data) + result = dispatcher.send(mock_sock, test_data) + + mock_send.assert_called_once_with(mock_sock, test_data) + self.assertEqual(result, len(test_data)) + + def test_dispatcher_read(self): + """Test Dispatcher read method""" + dispatcher = Dispatcher(self.app, 5.0) + read_callback = Mock(return_value=True) + check_callback = Mock() + mock_sock = Mock() + + # Mock the selector to control the loop + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + + # Make select return immediately (timeout) + mock_selector.select.return_value = [] + + # Stop after first iteration + def side_effect(*args): + self.app.keep_running = False + return [] + + mock_selector.select.side_effect = side_effect + + dispatcher.read(mock_sock, read_callback, check_callback) + + # Verify selector was used correctly + mock_selector.register.assert_called() + mock_selector.select.assert_called_with(5.0) + mock_selector.close.assert_called() + check_callback.assert_called() + + def test_dispatcher_read_with_data(self): + """Test Dispatcher read method when data is available""" + dispatcher = Dispatcher(self.app, 5.0) + read_callback = Mock(return_value=True) + check_callback = Mock() + mock_sock = Mock() + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + + # First call returns data, second call stops the loop + call_count = 0 + + def select_side_effect(*args): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [True] # Data available + else: + self.app.keep_running = False + return [] + + mock_selector.select.side_effect = select_side_effect + + dispatcher.read(mock_sock, read_callback, check_callback) + + read_callback.assert_called() + check_callback.assert_called() + + def test_ssl_dispatcher_read(self): + """Test SSLDispatcher read method""" + dispatcher = SSLDispatcher(self.app, 5.0) + read_callback = Mock(return_value=True) + check_callback = Mock() + + # Mock socket with pending data + mock_ssl_sock = MockSocket() + self.app.sock.sock = mock_ssl_sock + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] + + # Stop after first iteration + def side_effect(*args): + self.app.keep_running = False + return [] + + mock_selector.select.side_effect = side_effect + + dispatcher.read(None, read_callback, check_callback) + + mock_selector.register.assert_called() + check_callback.assert_called() + + def test_ssl_dispatcher_select_with_pending(self): + """Test SSLDispatcher select method with pending data""" + dispatcher = SSLDispatcher(self.app, 5.0) + mock_ssl_sock = MockSocket() + mock_ssl_sock.pending_return = True + self.app.sock.sock = mock_ssl_sock + mock_selector = Mock() + + result = dispatcher.select(None, mock_selector) + + # When pending() returns True, should return [sock] + self.assertEqual(result, [mock_ssl_sock]) + + def test_ssl_dispatcher_select_without_pending(self): + """Test SSLDispatcher select method without pending data""" + dispatcher = SSLDispatcher(self.app, 5.0) + mock_ssl_sock = MockSocket() + mock_ssl_sock.pending_return = False + self.app.sock.sock = mock_ssl_sock + mock_selector = Mock() + mock_selector.select.return_value = [(mock_ssl_sock, None)] + + result = dispatcher.select(None, mock_selector) + + # Should return the first element of first result tuple + self.assertEqual(result, mock_ssl_sock) + mock_selector.select.assert_called_with(5.0) + + def test_ssl_dispatcher_select_no_results(self): + """Test SSLDispatcher select method with no results""" + dispatcher = SSLDispatcher(self.app, 5.0) + mock_ssl_sock = MockSocket() + mock_ssl_sock.pending_return = False + self.app.sock.sock = mock_ssl_sock + mock_selector = Mock() + mock_selector.select.return_value = [] + + result = dispatcher.select(None, mock_selector) + + # Should return None when no results (function doesn't return anything when len(r) == 0) + self.assertIsNone(result) + + def test_wrapped_dispatcher_init(self): + """Test WrappedDispatcher initialization""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + self.assertEqual(wrapped.app, self.app) + self.assertEqual(wrapped.ping_timeout, 10.0) + self.assertEqual(wrapped.dispatcher, mock_dispatcher) + self.assertEqual(wrapped.handleDisconnect, handle_disconnect) + + # Should have set up signal handler + self.assertEqual(len(mock_dispatcher.signal_calls), 1) + sig, handler = mock_dispatcher.signal_calls[0] + self.assertEqual(sig, 2) # SIGINT + self.assertEqual(handler, mock_dispatcher.abort) + + def test_wrapped_dispatcher_read(self): + """Test WrappedDispatcher read method""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + mock_sock = Mock() + read_callback = Mock() + check_callback = Mock() + + wrapped.read(mock_sock, read_callback, check_callback) + + # Should delegate to wrapped dispatcher + self.assertEqual(len(mock_dispatcher.read_calls), 1) + self.assertEqual(mock_dispatcher.read_calls[0], (mock_sock, read_callback)) + + # Should call timeout for ping_timeout + self.assertEqual(len(mock_dispatcher.timeout_calls), 1) + timeout_call = mock_dispatcher.timeout_calls[0] + self.assertEqual(timeout_call[0], 10.0) # timeout seconds + self.assertEqual(timeout_call[1], check_callback) # callback + + def test_wrapped_dispatcher_read_no_ping_timeout(self): + """Test WrappedDispatcher read method without ping timeout""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, None, mock_dispatcher, handle_disconnect) + + mock_sock = Mock() + read_callback = Mock() + check_callback = Mock() + + wrapped.read(mock_sock, read_callback, check_callback) + + # Should delegate to wrapped dispatcher + self.assertEqual(len(mock_dispatcher.read_calls), 1) + + # Should NOT call timeout when ping_timeout is None + self.assertEqual(len(mock_dispatcher.timeout_calls), 0) + + def test_wrapped_dispatcher_send(self): + """Test WrappedDispatcher send method""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + mock_sock = Mock() + test_data = b"test data" + + with patch("websocket._dispatcher.send") as mock_send: + result = wrapped.send(mock_sock, test_data) + + # Should delegate to dispatcher.buffwrite + self.assertEqual(len(mock_dispatcher.buffwrite_calls), 1) + call = mock_dispatcher.buffwrite_calls[0] + self.assertEqual(call[0], mock_sock) + self.assertEqual(call[1], test_data) + self.assertEqual(call[2], mock_send) + self.assertEqual(call[3], handle_disconnect) + + # Should return data length + self.assertEqual(result, len(test_data)) + + def test_wrapped_dispatcher_timeout(self): + """Test WrappedDispatcher timeout method""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + callback = Mock() + args = ("arg1", "arg2") + + wrapped.timeout(5.0, callback, *args) + + # Should delegate to wrapped dispatcher + self.assertEqual(len(mock_dispatcher.timeout_calls), 1) + call = mock_dispatcher.timeout_calls[0] + self.assertEqual(call[0], 5.0) + self.assertEqual(call[1], callback) + self.assertEqual(call[2], args) + + def test_wrapped_dispatcher_reconnect(self): + """Test WrappedDispatcher reconnect method""" + mock_dispatcher = MockDispatcher() + handle_disconnect = Mock() + wrapped = WrappedDispatcher(self.app, 10.0, mock_dispatcher, handle_disconnect) + + reconnector = Mock() + + wrapped.reconnect(3, reconnector) + + # Should delegate to timeout method with reconnect=True + self.assertEqual(len(mock_dispatcher.timeout_calls), 1) + call = mock_dispatcher.timeout_calls[0] + self.assertEqual(call[0], 3) + self.assertEqual(call[1], reconnector) + self.assertEqual(call[2], (True,)) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_handshake_large_response.py b/contrib/python/websocket-client/websocket/tests/test_handshake_large_response.py new file mode 100644 index 00000000000..3ca415a09bb --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_handshake_large_response.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +import unittest +from unittest.mock import Mock, patch + +from websocket._handshake import _get_resp_headers +from websocket._exceptions import WebSocketBadStatusException +from websocket._ssl_compat import SSLError + +""" +test_handshake_large_response.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class HandshakeLargeResponseTest(unittest.TestCase): + def test_large_error_response_chunked_reading(self): + """Test that large HTTP error responses during handshake are read in chunks""" + + # Mock socket + mock_sock = Mock() + + # Create a large error response body (> 16KB) + large_response = b"Error details: " + b"A" * 20000 # 20KB+ response + + # Track recv calls to ensure chunking + recv_calls = [] + + def mock_recv(sock, bufsize): + recv_calls.append(bufsize) + # Simulate SSL error if trying to read > 16KB at once + if bufsize > 16384: + raise SSLError("[SSL: BAD_LENGTH] unknown error") + return large_response[:bufsize] + + # Mock read_headers to return error status with large content-length + with patch("websocket._handshake.read_headers") as mock_read_headers: + mock_read_headers.return_value = ( + 400, # Bad request status + {"content-length": str(len(large_response))}, + "Bad Request", + ) + + # Mock the recv function to track calls + with patch("websocket._socket.recv", side_effect=mock_recv): + # This should not raise SSLError, but should raise WebSocketBadStatusException + with self.assertRaises(WebSocketBadStatusException) as cm: + _get_resp_headers(mock_sock) + + # Verify the response body was included in the exception + self.assertIn( + b"Error details:", + ( + cm.exception.args[0].encode() + if isinstance(cm.exception.args[0], str) + else cm.exception.args[0] + ), + ) + + # Verify chunked reading was used (multiple recv calls, none > 16KB) + self.assertGreater(len(recv_calls), 1) + self.assertTrue(all(call <= 16384 for call in recv_calls)) + + def test_handshake_ssl_large_response_protection(self): + """Test that the fix prevents SSL BAD_LENGTH errors during handshake""" + + mock_sock = Mock() + + # Large content that would trigger SSL error if read all at once + large_content = b"X" * 32768 # 32KB + + chunks_returned = 0 + + def mock_recv_chunked(sock, bufsize): + nonlocal chunks_returned + # Return data in chunks, simulating successful chunked reading + chunk_start = chunks_returned * 16384 + chunk_end = min(chunk_start + bufsize, len(large_content)) + result = large_content[chunk_start:chunk_end] + chunks_returned += 1 if result else 0 + return result + + with patch("websocket._handshake.read_headers") as mock_read_headers: + mock_read_headers.return_value = ( + 500, # Server error + {"content-length": str(len(large_content))}, + "Internal Server Error", + ) + + with patch("websocket._socket.recv", side_effect=mock_recv_chunked): + # Should handle large response without SSL errors + with self.assertRaises(WebSocketBadStatusException) as cm: + _get_resp_headers(mock_sock) + + # Verify the complete response was captured + exception_str = str(cm.exception) + # Response body should be in the exception message + self.assertIn("XXXXX", exception_str) # Part of the large content + + def test_handshake_normal_small_response(self): + """Test that normal small responses still work correctly""" + + mock_sock = Mock() + small_response = b"Small error message" + + def mock_recv(sock, bufsize): + return small_response + + with patch("websocket._handshake.read_headers") as mock_read_headers: + mock_read_headers.return_value = ( + 404, # Not found + {"content-length": str(len(small_response))}, + "Not Found", + ) + + with patch("websocket._socket.recv", side_effect=mock_recv): + with self.assertRaises(WebSocketBadStatusException) as cm: + _get_resp_headers(mock_sock) + + # Verify small response is handled correctly + self.assertIn("Small error message", str(cm.exception)) + + def test_handshake_no_content_length(self): + """Test handshake error response without content-length header""" + + mock_sock = Mock() + + with patch("websocket._handshake.read_headers") as mock_read_headers: + mock_read_headers.return_value = ( + 403, # Forbidden + {}, # No content-length header + "Forbidden", + ) + + # Should raise exception without trying to read response body + with self.assertRaises(WebSocketBadStatusException) as cm: + _get_resp_headers(mock_sock) + + # Should mention status but not have response body + exception_str = str(cm.exception) + self.assertIn("403", exception_str) + self.assertIn("Forbidden", exception_str) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_http.py b/contrib/python/websocket-client/websocket/tests/test_http.py index 72465c22057..55a9a9cc0d5 100644 --- a/contrib/python/websocket-client/websocket/tests/test_http.py +++ b/contrib/python/websocket-client/websocket/tests/test_http.py @@ -22,7 +22,7 @@ from websocket._http import ( test_http.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/contrib/python/websocket-client/websocket/tests/test_large_payloads.py b/contrib/python/websocket-client/websocket/tests/test_large_payloads.py new file mode 100644 index 00000000000..4d69c635f11 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_large_payloads.py @@ -0,0 +1,273 @@ +# -*- coding: utf-8 -*- +import unittest +import struct +from unittest.mock import Mock, patch, MagicMock + +from websocket._abnf import ABNF +from websocket._core import WebSocket +from websocket._exceptions import WebSocketProtocolException, WebSocketPayloadException +from websocket._ssl_compat import SSLError + +""" +test_large_payloads.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class LargePayloadTest(unittest.TestCase): + def test_frame_length_encoding_boundaries(self): + """Test WebSocket frame length encoding at various boundaries""" + + # Test length encoding boundaries as per RFC 6455 + test_cases = [ + (125, "Single byte length"), # Max for 7-bit length + (126, "Two byte length start"), # Start of 16-bit length + (127, "Two byte length"), + (65535, "Two byte length max"), # Max for 16-bit length + (65536, "Eight byte length start"), # Start of 64-bit length + (16384, "16KB boundary"), # The problematic size + (16385, "Just over 16KB"), + (32768, "32KB"), + (131072, "128KB"), + ] + + for length, description in test_cases: + with self.subTest(length=length, description=description): + # Create payload of specified length + payload = b"A" * length + + # Create frame + frame = ABNF.create_frame(payload, ABNF.OPCODE_BINARY) + + # Verify frame can be formatted without error + formatted = frame.format() + + # Verify the frame header is correctly structured + self.assertIsInstance(formatted, bytes) + self.assertTrue(len(formatted) >= length) # Header + payload + + # Verify payload length is preserved + self.assertEqual(len(frame.data), length) + + def test_recv_large_payload_chunked(self): + """Test receiving large payloads in chunks (simulating the 16KB recv issue)""" + + # Create a large payload that would trigger chunked reading + large_payload = b"B" * 32768 # 32KB + + # Mock recv function that returns data in 16KB chunks + chunks = [] + chunk_size = 16384 + for i in range(0, len(large_payload), chunk_size): + chunks.append(large_payload[i : i + chunk_size]) + + call_count = 0 + + def mock_recv(bufsize): + nonlocal call_count + if call_count >= len(chunks): + return b"" + result = chunks[call_count] + call_count += 1 + return result + + # Test the frame buffer's recv_strict method + from websocket._abnf import frame_buffer + + fb = frame_buffer(mock_recv, skip_utf8_validation=True) + + # This should handle large payloads by chunking + result = fb.recv_strict(len(large_payload)) + + self.assertEqual(result, large_payload) + # Verify multiple recv calls were made + self.assertGreater(call_count, 1) + + def test_ssl_large_payload_simulation(self): + """Simulate SSL BAD_LENGTH error scenario""" + + # This test demonstrates that the 16KB limit in frame buffer protects against SSL issues + payload_size = 16385 + + recv_calls = [] + + def mock_recv_with_ssl_limit(bufsize): + recv_calls.append(bufsize) + # This simulates the SSL issue: BAD_LENGTH when trying to recv > 16KB + if bufsize > 16384: + raise SSLError("[SSL: BAD_LENGTH] unknown error") + return b"C" * min(bufsize, 16384) + + from websocket._abnf import frame_buffer + + fb = frame_buffer(mock_recv_with_ssl_limit, skip_utf8_validation=True) + + # The frame buffer handles this correctly by chunking recv calls + result = fb.recv_strict(payload_size) + + # Verify it worked and chunked the calls properly + self.assertEqual(len(result), payload_size) + # Verify no single recv call was > 16KB + self.assertTrue(all(call <= 16384 for call in recv_calls)) + # Verify multiple calls were made + self.assertGreater(len(recv_calls), 1) + + def test_frame_format_large_payloads(self): + """Test frame formatting with various large payload sizes""" + + # Test sizes around potential problem areas + test_sizes = [16383, 16384, 16385, 32768, 65535, 65536] + + for size in test_sizes: + with self.subTest(size=size): + payload = b"D" * size + frame = ABNF.create_frame(payload, ABNF.OPCODE_BINARY) + + # Should not raise any exceptions + formatted = frame.format() + + # Verify structure + self.assertIsInstance(formatted, bytes) + self.assertEqual(len(frame.data), size) + + # Verify length encoding is correct based on size + # Note: frames from create_frame() include masking by default (4 extra bytes) + mask_size = 4 # WebSocket frames are masked by default + if size < ABNF.LENGTH_7: # < 126 + # Length should be encoded in single byte + expected_header_size = ( + 2 + mask_size + ) # 1 byte opcode + 1 byte length + 4 byte mask + elif size < ABNF.LENGTH_16: # < 65536 + # Length should be encoded in 2 bytes + expected_header_size = ( + 4 + mask_size + ) # 1 byte opcode + 1 byte marker + 2 bytes length + 4 byte mask + else: + # Length should be encoded in 8 bytes + expected_header_size = ( + 10 + mask_size + ) # 1 byte opcode + 1 byte marker + 8 bytes length + 4 byte mask + + self.assertEqual(len(formatted), expected_header_size + size) + + def test_send_large_payload_chunking(self): + """Test that large payloads are sent in chunks to avoid SSL issues""" + + mock_sock = Mock() + + # Track how data is sent + sent_chunks = [] + + def mock_send(data): + sent_chunks.append(len(data)) + return len(data) + + mock_sock.send = mock_send + mock_sock.gettimeout.return_value = 30.0 + + # Create WebSocket with mocked socket + ws = WebSocket() + ws.sock = mock_sock + ws.connected = True + + # Create large payload + large_payload = b"E" * 32768 # 32KB + + # Send the payload + with patch("websocket._core.send") as mock_send_func: + mock_send_func.side_effect = lambda sock, data: len(data) + + # This should work without SSL errors + result = ws.send_binary(large_payload) + + # Verify payload was accepted + self.assertGreater(result, 0) + + def test_utf8_validation_large_text(self): + """Test UTF-8 validation with large text payloads""" + + # Create large valid UTF-8 text + large_text = "Hello 世界! " * 2000 # About 26KB with Unicode + + # Test frame creation + frame = ABNF.create_frame(large_text, ABNF.OPCODE_TEXT) + + # Should not raise validation errors + formatted = frame.format() + self.assertIsInstance(formatted, bytes) + + # Test with close frame that has invalid UTF-8 (this is what validate() actually checks) + invalid_utf8_close_data = struct.pack("!H", 1000) + b"\xff\xfe invalid utf8" + + # Create close frame with invalid UTF-8 data + frame = ABNF(1, 0, 0, 0, ABNF.OPCODE_CLOSE, 1, invalid_utf8_close_data) + + # Validation should catch the invalid UTF-8 in close frame reason + with self.assertRaises(WebSocketProtocolException): + frame.validate(skip_utf8_validation=False) + + def test_frame_buffer_edge_cases(self): + """Test frame buffer with edge cases that could trigger bugs""" + + # Test scenario: exactly 16KB payload split across recv calls + payload_16k = b"F" * 16384 + + # Simulate receiving in smaller chunks + chunks = [payload_16k[i : i + 4096] for i in range(0, len(payload_16k), 4096)] + + call_count = 0 + + def mock_recv(bufsize): + nonlocal call_count + if call_count >= len(chunks): + return b"" + result = chunks[call_count] + call_count += 1 + return result + + from websocket._abnf import frame_buffer + + fb = frame_buffer(mock_recv, skip_utf8_validation=True) + result = fb.recv_strict(16384) + + self.assertEqual(result, payload_16k) + # Verify multiple recv calls were made + self.assertEqual(call_count, 4) # 16KB / 4KB = 4 chunks + + def test_max_frame_size_limits(self): + """Test behavior at WebSocket maximum frame size limits""" + + # Test just under the maximum theoretical frame size + # (This is a very large test, so we'll use a smaller representative size) + + # Test with a reasonably large payload that represents the issue + large_size = 1024 * 1024 # 1MB + payload = b"G" * large_size + + # This should work without issues + frame = ABNF.create_frame(payload, ABNF.OPCODE_BINARY) + + # Verify the frame can be formatted + formatted = frame.format() + self.assertIsInstance(formatted, bytes) + + # Verify payload is preserved + self.assertEqual(len(frame.data), large_size) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_socket.py b/contrib/python/websocket-client/websocket/tests/test_socket.py new file mode 100644 index 00000000000..5b8b65bd6b5 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_socket.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- +import errno +import socket +import unittest +from unittest.mock import Mock, patch, MagicMock +import time + +from websocket._socket import recv, recv_line, send, DEFAULT_SOCKET_OPTION +from websocket._ssl_compat import ( + SSLError, + SSLEOFError, + SSLWantWriteError, + SSLWantReadError, +) +from websocket._exceptions import ( + WebSocketTimeoutException, + WebSocketConnectionClosedException, +) + +""" +test_socket.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class SocketTest(unittest.TestCase): + def test_default_socket_option(self): + """Test DEFAULT_SOCKET_OPTION contains expected options""" + self.assertIsInstance(DEFAULT_SOCKET_OPTION, list) + self.assertGreater(len(DEFAULT_SOCKET_OPTION), 0) + + # Should contain TCP_NODELAY option + tcp_nodelay_found = any( + opt[1] == socket.TCP_NODELAY for opt in DEFAULT_SOCKET_OPTION + ) + self.assertTrue(tcp_nodelay_found) + + def test_recv_normal(self): + """Test normal recv operation""" + mock_sock = Mock() + mock_sock.recv.return_value = b"test data" + + result = recv(mock_sock, 9) + + self.assertEqual(result, b"test data") + mock_sock.recv.assert_called_once_with(9) + + def test_recv_timeout_error(self): + """Test recv with TimeoutError""" + mock_sock = Mock() + mock_sock.recv.side_effect = TimeoutError("Connection timed out") + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 9) + + self.assertEqual(str(cm.exception), "Connection timed out") + + def test_recv_socket_timeout(self): + """Test recv with socket.timeout""" + mock_sock = Mock() + timeout_exc = socket.timeout("Socket timed out") + timeout_exc.args = ("Socket timed out",) + mock_sock.recv.side_effect = timeout_exc + mock_sock.gettimeout.return_value = 30.0 + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 9) + + # In Python 3.10+, socket.timeout is a subclass of TimeoutError + # so it's caught by the TimeoutError handler with hardcoded message + # In Python 3.9, socket.timeout is caught by socket.timeout handler + # which preserves the original message + import sys + + if sys.version_info >= (3, 10): + self.assertEqual(str(cm.exception), "Connection timed out") + else: + self.assertEqual(str(cm.exception), "Socket timed out") + + def test_recv_ssl_timeout(self): + """Test recv with SSL timeout error""" + mock_sock = Mock() + ssl_exc = SSLError("The operation timed out") + ssl_exc.args = ("The operation timed out",) + mock_sock.recv.side_effect = ssl_exc + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 9) + + self.assertEqual(str(cm.exception), "The operation timed out") + + def test_recv_ssl_non_timeout_error(self): + """Test recv with SSL non-timeout error""" + mock_sock = Mock() + ssl_exc = SSLError("SSL certificate error") + ssl_exc.args = ("SSL certificate error",) + mock_sock.recv.side_effect = ssl_exc + + # Should re-raise the original SSL error + with self.assertRaises(SSLError): + recv(mock_sock, 9) + + def test_recv_empty_response(self): + """Test recv with empty response (connection closed)""" + mock_sock = Mock() + mock_sock.recv.return_value = b"" + + with self.assertRaises(WebSocketConnectionClosedException) as cm: + recv(mock_sock, 9) + + self.assertEqual(str(cm.exception), "Connection to remote host was lost.") + + def test_recv_ssl_want_read_error(self): + """Test recv with SSLWantReadError (should retry)""" + mock_sock = Mock() + + # First call raises SSLWantReadError, second call succeeds + mock_sock.recv.side_effect = [SSLWantReadError(), b"data after retry"] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Ready to read + + result = recv(mock_sock, 100) + + self.assertEqual(result, b"data after retry") + mock_selector.register.assert_called() + mock_selector.close.assert_called() + + def test_recv_ssl_want_read_timeout(self): + """Test recv with SSLWantReadError that times out""" + mock_sock = Mock() + mock_sock.recv.side_effect = SSLWantReadError() + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # Timeout + + with self.assertRaises(WebSocketTimeoutException): + recv(mock_sock, 100) + + def test_recv_line(self): + """Test recv_line functionality""" + mock_sock = Mock() + + # Mock recv to return one character at a time + recv_calls = [b"H", b"e", b"l", b"l", b"o", b"\n"] + + with patch("websocket._socket.recv", side_effect=recv_calls) as mock_recv: + result = recv_line(mock_sock) + + self.assertEqual(result, b"Hello\n") + self.assertEqual(mock_recv.call_count, 6) + + def test_send_normal(self): + """Test normal send operation""" + mock_sock = Mock() + mock_sock.send.return_value = 9 + mock_sock.gettimeout.return_value = 30.0 + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + mock_sock.send.assert_called_with(b"test data") + + def test_send_zero_timeout(self): + """Test send with zero timeout (non-blocking)""" + mock_sock = Mock() + mock_sock.send.return_value = 9 + mock_sock.gettimeout.return_value = 0 + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + mock_sock.send.assert_called_once_with(b"test data") + + def test_send_ssl_eof_error(self): + """Test send with SSLEOFError""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + mock_sock.send.side_effect = SSLEOFError("Connection closed") + + with self.assertRaises(WebSocketConnectionClosedException) as cm: + send(mock_sock, b"test data") + + self.assertEqual(str(cm.exception), "socket is already closed.") + + def test_send_ssl_want_write_error(self): + """Test send with SSLWantWriteError (should retry)""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # First call raises SSLWantWriteError, second call succeeds + mock_sock.send.side_effect = [SSLWantWriteError(), 9] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Ready to write + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + mock_selector.register.assert_called() + mock_selector.close.assert_called() + + def test_send_socket_eagain_error(self): + """Test send with EAGAIN error (should retry)""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create socket error with EAGAIN + eagain_error = socket.error("Resource temporarily unavailable") + eagain_error.errno = errno.EAGAIN + eagain_error.args = (errno.EAGAIN, "Resource temporarily unavailable") + + # First call raises EAGAIN, second call succeeds + mock_sock.send.side_effect = [eagain_error, 9] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Ready to write + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + + def test_send_socket_ewouldblock_error(self): + """Test send with EWOULDBLOCK error (should retry)""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create socket error with EWOULDBLOCK + ewouldblock_error = socket.error("Operation would block") + ewouldblock_error.errno = errno.EWOULDBLOCK + ewouldblock_error.args = (errno.EWOULDBLOCK, "Operation would block") + + # First call raises EWOULDBLOCK, second call succeeds + mock_sock.send.side_effect = [ewouldblock_error, 9] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Ready to write + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + + def test_send_socket_other_error(self): + """Test send with other socket error (should raise)""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create socket error with different errno + other_error = socket.error("Connection reset by peer") + other_error.errno = errno.ECONNRESET + other_error.args = (errno.ECONNRESET, "Connection reset by peer") + + mock_sock.send.side_effect = other_error + + with self.assertRaises(socket.error): + send(mock_sock, b"test data") + + def test_send_socket_error_no_errno(self): + """Test send with socket error that has no errno""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create socket error without errno attribute + no_errno_error = socket.error("Generic socket error") + no_errno_error.args = ("Generic socket error",) + + mock_sock.send.side_effect = no_errno_error + + with self.assertRaises(socket.error): + send(mock_sock, b"test data") + + def test_send_write_timeout(self): + """Test send write operation timeout""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # First call raises EAGAIN + eagain_error = socket.error("Resource temporarily unavailable") + eagain_error.errno = errno.EAGAIN + eagain_error.args = (errno.EAGAIN, "Resource temporarily unavailable") + + mock_sock.send.side_effect = eagain_error + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # Timeout - nothing ready + + result = send(mock_sock, b"test data") + + # Should return 0 when write times out + self.assertEqual(result, 0) + + def test_send_string_data(self): + """Test send with string data (should be encoded)""" + mock_sock = Mock() + mock_sock.send.return_value = 9 + mock_sock.gettimeout.return_value = 30.0 + + result = send(mock_sock, "test data") + + self.assertEqual(result, 9) + mock_sock.send.assert_called_with(b"test data") + + def test_send_partial_send_retry(self): + """Test send retry mechanism""" + mock_sock = Mock() + mock_sock.gettimeout.return_value = 30.0 + + # Create a scenario where send succeeds after selector retry + eagain_error = socket.error("Resource temporarily unavailable") + eagain_error.errno = errno.EAGAIN + eagain_error.args = (errno.EAGAIN, "Resource temporarily unavailable") + + # Mock the internal _send function behavior + mock_sock.send.side_effect = [eagain_error, 9] + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Socket ready for writing + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) + # Verify selector was used for retry mechanism + mock_selector.register.assert_called() + mock_selector.select.assert_called() + mock_selector.close.assert_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_socket_bugs.py b/contrib/python/websocket-client/websocket/tests/test_socket_bugs.py new file mode 100644 index 00000000000..72f222f5c4c --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_socket_bugs.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +import errno +import socket +import unittest +from unittest.mock import Mock, patch + +from websocket._socket import recv +from websocket._ssl_compat import SSLWantReadError +from websocket._exceptions import ( + WebSocketTimeoutException, + WebSocketConnectionClosedException, +) + +""" +test_socket_bugs.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class SocketBugsTest(unittest.TestCase): + """Test bugs found in socket handling logic""" + + def test_bug_implicit_none_return_from_ssl_want_read_fixed(self): + """ + BUG #5 FIX VERIFICATION: Test SSLWantReadError timeout now raises correct exception + + Bug was in _socket.py:100-101 - SSLWantReadError except block returned None implicitly + Fixed: Now properly handles timeout with WebSocketTimeoutException + """ + mock_sock = Mock() + mock_sock.recv.side_effect = SSLWantReadError() + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # Timeout - no data ready + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 100) + + # Verify correct timeout exception and message + self.assertIn("Connection timed out waiting for data", str(cm.exception)) + + def test_bug_implicit_none_return_from_socket_error_fixed(self): + """ + BUG #5 FIX VERIFICATION: Test that socket.error with EAGAIN now handles timeout correctly + + Bug was in _socket.py:102-105 - socket.error except block returned None implicitly + Fixed: Now properly handles timeout with WebSocketTimeoutException + """ + mock_sock = Mock() + + # Create socket error with EAGAIN (should be retried) + eagain_error = OSError(errno.EAGAIN, "Resource temporarily unavailable") + + # First call raises EAGAIN, selector times out on retry + mock_sock.recv.side_effect = eagain_error + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # Timeout - no data ready + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 100) + + # Verify correct timeout exception and message + self.assertIn("Connection timed out waiting for data", str(cm.exception)) + + def test_bug_wrong_exception_for_selector_timeout_fixed(self): + """ + BUG #6 FIX VERIFICATION: Test that selector timeout now raises correct exception type + + Bug was in _socket.py:115 returning None for timeout, treated as connection error + Fixed: Now raises WebSocketTimeoutException directly + """ + mock_sock = Mock() + mock_sock.recv.side_effect = SSLWantReadError() # Trigger retry path + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [] # TIMEOUT - this is key! + + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 100) + + # Verify it's the correct timeout exception with proper message + self.assertIn("Connection timed out waiting for data", str(cm.exception)) + + # This proves the fix works: + # 1. selector.select() returns [] (timeout) + # 2. _recv() now raises WebSocketTimeoutException directly + # 3. No more misclassification as connection closed error! + + def test_socket_timeout_exception_handling(self): + """ + Test that socket.timeout exceptions are properly handled + """ + mock_sock = Mock() + mock_sock.gettimeout.return_value = 1.0 + + # Simulate a real socket.timeout scenario + mock_sock.recv.side_effect = socket.timeout("Operation timed out") + + # This works correctly - socket.timeout raises WebSocketTimeoutException + with self.assertRaises(WebSocketTimeoutException) as cm: + recv(mock_sock, 100) + + # In Python 3.10+, socket.timeout is a subclass of TimeoutError + # so it's caught by the TimeoutError handler with hardcoded message + # In Python 3.9, socket.timeout is caught by socket.timeout handler + # which preserves the original message + import sys + + if sys.version_info >= (3, 10): + self.assertIn("Connection timed out", str(cm.exception)) + else: + self.assertIn("Operation timed out", str(cm.exception)) + + def test_correct_ssl_want_read_retry_behavior(self): + """Test the correct behavior when SSLWantReadError is properly handled""" + mock_sock = Mock() + + # First call raises SSLWantReadError, second call succeeds + mock_sock.recv.side_effect = [SSLWantReadError(), b"data after retry"] + mock_sock.gettimeout.return_value = 1.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Data ready after wait + + # This should work correctly + result = recv(mock_sock, 100) + self.assertEqual(result, b"data after retry") + + # Selector should be used for retry + mock_selector.register.assert_called() + mock_selector.select.assert_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_ssl_compat.py b/contrib/python/websocket-client/websocket/tests/test_ssl_compat.py new file mode 100644 index 00000000000..9dcd674b0f0 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_ssl_compat.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +import sys +import unittest +from unittest.mock import patch + +""" +test_ssl_compat.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class SSLCompatTest(unittest.TestCase): + def test_ssl_available(self): + """Test that SSL is available in normal conditions""" + import websocket._ssl_compat as ssl_compat + + # In normal conditions, SSL should be available + self.assertTrue(ssl_compat.HAVE_SSL) + self.assertIsNotNone(ssl_compat.ssl) + + # SSL exception classes should be available + self.assertTrue(hasattr(ssl_compat, "SSLError")) + self.assertTrue(hasattr(ssl_compat, "SSLEOFError")) + self.assertTrue(hasattr(ssl_compat, "SSLWantReadError")) + self.assertTrue(hasattr(ssl_compat, "SSLWantWriteError")) + + def test_ssl_not_available(self): + """Test fallback behavior when SSL is not available""" + # Remove ssl_compat from modules to force reimport + if "websocket._ssl_compat" in sys.modules: + del sys.modules["websocket._ssl_compat"] + + # Mock the ssl module to not be available + import builtins + + original_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "ssl": + raise ImportError("No module named 'ssl'") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + import websocket._ssl_compat as ssl_compat + + # SSL should not be available + self.assertFalse(ssl_compat.HAVE_SSL) + self.assertIsNone(ssl_compat.ssl) + + # Fallback exception classes should be available and functional + self.assertTrue(issubclass(ssl_compat.SSLError, Exception)) + self.assertTrue(issubclass(ssl_compat.SSLEOFError, Exception)) + self.assertTrue(issubclass(ssl_compat.SSLWantReadError, Exception)) + self.assertTrue(issubclass(ssl_compat.SSLWantWriteError, Exception)) + + # Test that exceptions can be instantiated + ssl_error = ssl_compat.SSLError("test error") + self.assertIsInstance(ssl_error, Exception) + self.assertEqual(str(ssl_error), "test error") + + ssl_eof_error = ssl_compat.SSLEOFError("test eof") + self.assertIsInstance(ssl_eof_error, Exception) + + ssl_want_read = ssl_compat.SSLWantReadError("test read") + self.assertIsInstance(ssl_want_read, Exception) + + ssl_want_write = ssl_compat.SSLWantWriteError("test write") + self.assertIsInstance(ssl_want_write, Exception) + + def tearDown(self): + """Clean up after tests""" + # Ensure ssl_compat is reimported fresh for next test + if "websocket._ssl_compat" in sys.modules: + del sys.modules["websocket._ssl_compat"] + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_ssl_edge_cases.py b/contrib/python/websocket-client/websocket/tests/test_ssl_edge_cases.py new file mode 100644 index 00000000000..a8e14d3f4ed --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_ssl_edge_cases.py @@ -0,0 +1,638 @@ +# -*- coding: utf-8 -*- +import unittest +import socket +import ssl +from unittest.mock import Mock, patch, MagicMock + +from websocket._ssl_compat import ( + SSLError, + SSLEOFError, + SSLWantReadError, + SSLWantWriteError, + HAVE_SSL, +) +from websocket._http import _ssl_socket, _wrap_sni_socket +from websocket._exceptions import WebSocketException +from websocket._socket import recv, send + +""" +test_ssl_edge_cases.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class SSLEdgeCasesTest(unittest.TestCase): + + def setUp(self): + if not HAVE_SSL: + self.skipTest("SSL not available") + + def test_ssl_handshake_failure(self): + """Test SSL handshake failure scenarios""" + mock_sock = Mock() + + # Test SSL handshake timeout + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.side_effect = socket.timeout( + "SSL handshake timeout" + ) + + sslopt = {"cert_reqs": ssl.CERT_REQUIRED} + + with self.assertRaises(socket.timeout): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_certificate_verification_failures(self): + """Test various SSL certificate verification failure scenarios""" + mock_sock = Mock() + + # Test certificate verification failure + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.side_effect = ssl.SSLCertVerificationError( + "Certificate verification failed" + ) + + sslopt = {"cert_reqs": ssl.CERT_REQUIRED, "check_hostname": True} + + with self.assertRaises(ssl.SSLCertVerificationError): + _ssl_socket(mock_sock, sslopt, "badssl.example") + + def test_ssl_context_configuration_edge_cases(self): + """Test SSL context configuration with various edge cases""" + mock_sock = Mock() + + # Test with pre-created SSL context + with patch("ssl.SSLContext") as mock_ssl_context: + existing_context = Mock() + existing_context.wrap_socket.return_value = Mock() + mock_ssl_context.return_value = existing_context + + sslopt = {"context": existing_context} + + # Call _ssl_socket which should use the existing context + _ssl_socket(mock_sock, sslopt, "example.com") + + # Should use the provided context, not create a new one + existing_context.wrap_socket.assert_called_once() + + def test_ssl_ca_bundle_environment_edge_cases(self): + """Test CA bundle environment variable edge cases""" + mock_sock = Mock() + + # Test with non-existent CA bundle file + with patch.dict( + "os.environ", {"WEBSOCKET_CLIENT_CA_BUNDLE": "/nonexistent/ca-bundle.crt"} + ): + with patch("os.path.isfile", return_value=False): + with patch("os.path.isdir", return_value=False): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + sslopt = {} + _ssl_socket(mock_sock, sslopt, "example.com") + + # Should not try to load non-existent CA bundle + mock_context.load_verify_locations.assert_not_called() + + # Test with CA bundle directory + with patch.dict("os.environ", {"WEBSOCKET_CLIENT_CA_BUNDLE": "/etc/ssl/certs"}): + with patch("os.path.isfile", return_value=False): + with patch("os.path.isdir", return_value=True): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + sslopt = {} + _ssl_socket(mock_sock, sslopt, "example.com") + + # Should load CA directory + mock_context.load_verify_locations.assert_called_with( + cafile=None, capath="/etc/ssl/certs" + ) + + def test_ssl_cipher_configuration_edge_cases(self): + """Test SSL cipher configuration edge cases""" + mock_sock = Mock() + + # Test with invalid cipher suite + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.set_ciphers.side_effect = ssl.SSLError( + "No cipher can be selected" + ) + mock_context.wrap_socket.return_value = Mock() + + sslopt = {"ciphers": "INVALID_CIPHER"} + + with self.assertRaises(WebSocketException): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_ecdh_curve_edge_cases(self): + """Test ECDH curve configuration edge cases""" + mock_sock = Mock() + + # Test with invalid ECDH curve + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.set_ecdh_curve.side_effect = ValueError("unknown curve name") + mock_context.wrap_socket.return_value = Mock() + + sslopt = {"ecdh_curve": "invalid_curve"} + + with self.assertRaises(WebSocketException): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_client_certificate_edge_cases(self): + """Test client certificate configuration edge cases""" + mock_sock = Mock() + + # Test with non-existent client certificate + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.load_cert_chain.side_effect = FileNotFoundError("No such file") + mock_context.wrap_socket.return_value = Mock() + + sslopt = {"certfile": "/nonexistent/client.crt"} + + with self.assertRaises(WebSocketException): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_want_read_write_retry_edge_cases(self): + """Test SSL want read/write retry edge cases""" + mock_sock = Mock() + + # Test SSLWantReadError with multiple retries before success + read_attempts = [0] # Use list for mutable reference + + def mock_recv(bufsize): + read_attempts[0] += 1 + if read_attempts[0] == 1: + raise SSLWantReadError("The operation did not complete") + elif read_attempts[0] == 2: + return b"data after retries" + else: + return b"" + + mock_sock.recv.side_effect = mock_recv + mock_sock.gettimeout.return_value = 30.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Always ready + + result = recv(mock_sock, 100) + + self.assertEqual(result, b"data after retries") + self.assertEqual(read_attempts[0], 2) + # Should have used selector for retry + mock_selector.register.assert_called() + mock_selector.select.assert_called() + + def test_ssl_want_write_retry_edge_cases(self): + """Test SSL want write retry edge cases""" + mock_sock = Mock() + + # Test SSLWantWriteError with multiple retries before success + write_attempts = [0] # Use list for mutable reference + + def mock_send(data): + write_attempts[0] += 1 + if write_attempts[0] == 1: + raise SSLWantWriteError("The operation did not complete") + elif write_attempts[0] == 2: + return len(data) + else: + return 0 + + mock_sock.send.side_effect = mock_send + mock_sock.gettimeout.return_value = 30.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] # Always ready + + result = send(mock_sock, b"test data") + + self.assertEqual(result, 9) # len("test data") + self.assertEqual(write_attempts[0], 2) + + def test_ssl_eof_error_edge_cases(self): + """Test SSL EOF error edge cases""" + mock_sock = Mock() + + # Test SSLEOFError during send + mock_sock.send.side_effect = SSLEOFError("SSL connection has been closed") + mock_sock.gettimeout.return_value = 30.0 + + from websocket._exceptions import WebSocketConnectionClosedException + + with self.assertRaises(WebSocketConnectionClosedException): + send(mock_sock, b"test data") + + def test_ssl_pending_data_edge_cases(self): + """Test SSL pending data scenarios""" + from websocket._dispatcher import SSLDispatcher + from websocket._app import WebSocketApp + + # Mock SSL socket with pending data + mock_ssl_sock = Mock() + mock_ssl_sock.pending.return_value = 1024 # Simulates pending SSL data + + # Mock WebSocketApp + mock_app = Mock(spec=WebSocketApp) + mock_app.sock = Mock() + mock_app.sock.sock = mock_ssl_sock + + dispatcher = SSLDispatcher(mock_app, 5.0) + + # When there's pending data, should return immediately without selector + result = dispatcher.select(mock_ssl_sock, Mock()) + + # Should return the socket list when there's pending data + self.assertEqual(result, [mock_ssl_sock]) + mock_ssl_sock.pending.assert_called_once() + + def test_ssl_renegotiation_edge_cases(self): + """Test SSL renegotiation scenarios""" + mock_sock = Mock() + + # Simulate SSL renegotiation during read + call_count = 0 + + def mock_recv(bufsize): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise SSLWantReadError("SSL renegotiation required") + return b"data after renegotiation" + + mock_sock.recv.side_effect = mock_recv + mock_sock.gettimeout.return_value = 30.0 + + with patch("selectors.DefaultSelector") as mock_selector_class: + mock_selector = Mock() + mock_selector_class.return_value = mock_selector + mock_selector.select.return_value = [True] + + result = recv(mock_sock, 100) + + self.assertEqual(result, b"data after renegotiation") + self.assertEqual(call_count, 2) + + def test_ssl_server_hostname_override(self): + """Test SSL server hostname override scenarios""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Test server_hostname override + sslopt = {"server_hostname": "override.example.com"} + _ssl_socket(mock_sock, sslopt, "original.example.com") + + # Should use override hostname in wrap_socket call + mock_context.wrap_socket.assert_called_with( + mock_sock, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname="override.example.com", + ) + + def test_ssl_protocol_version_edge_cases(self): + """Test SSL protocol version edge cases""" + mock_sock = Mock() + + # Test with deprecated SSL version + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Test that deprecated ssl_version is still handled + if hasattr(ssl, "PROTOCOL_TLS"): + sslopt = {"ssl_version": ssl.PROTOCOL_TLS} + _ssl_socket(mock_sock, sslopt, "example.com") + + mock_ssl_context.assert_called_with(ssl.PROTOCOL_TLS) + + def test_ssl_keylog_file_edge_cases(self): + """Test SSL keylog file configuration edge cases""" + mock_sock = Mock() + + # Test with SSLKEYLOGFILE environment variable + with patch.dict("os.environ", {"SSLKEYLOGFILE": "/tmp/ssl_keys.log"}): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + sslopt = {} + _ssl_socket(mock_sock, sslopt, "example.com") + + # Should set keylog_filename + self.assertEqual(mock_context.keylog_filename, "/tmp/ssl_keys.log") + + def test_ssl_context_verification_modes(self): + """Test different SSL verification mode combinations""" + mock_sock = Mock() + + test_cases = [ + # (cert_reqs, check_hostname, expected_verify_mode, expected_check_hostname) + (ssl.CERT_NONE, False, ssl.CERT_NONE, False), + (ssl.CERT_REQUIRED, False, ssl.CERT_REQUIRED, False), + (ssl.CERT_REQUIRED, True, ssl.CERT_REQUIRED, True), + ] + + for cert_reqs, check_hostname, expected_verify, expected_check in test_cases: + with self.subTest(cert_reqs=cert_reqs, check_hostname=check_hostname): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + sslopt = {"cert_reqs": cert_reqs, "check_hostname": check_hostname} + _ssl_socket(mock_sock, sslopt, "example.com") + + self.assertEqual(mock_context.verify_mode, expected_verify) + self.assertEqual(mock_context.check_hostname, expected_check) + + def test_ssl_socket_shutdown_edge_cases(self): + """Test SSL socket shutdown edge cases""" + from websocket._core import WebSocket + + mock_ssl_sock = Mock() + mock_ssl_sock.shutdown.side_effect = SSLError("SSL shutdown failed") + + ws = WebSocket() + ws.sock = mock_ssl_sock + ws.connected = True + + # Should handle SSL shutdown errors gracefully + try: + ws.close() + except SSLError: + self.fail("SSL shutdown error should be handled gracefully") + + def test_ssl_socket_close_during_operation(self): + """Test SSL socket being closed during ongoing operations""" + mock_sock = Mock() + + # Simulate SSL socket being closed during recv + mock_sock.recv.side_effect = SSLError( + "SSL connection has been closed unexpectedly" + ) + mock_sock.gettimeout.return_value = 30.0 + + from websocket._exceptions import WebSocketConnectionClosedException + + # Should handle unexpected SSL closure + with self.assertRaises((SSLError, WebSocketConnectionClosedException)): + recv(mock_sock, 100) + + def test_ssl_compression_edge_cases(self): + """Test SSL compression configuration edge cases""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Test SSL compression options (if available) + sslopt = {"compression": False} # Some SSL contexts support this + + try: + _ssl_socket(mock_sock, sslopt, "example.com") + # Should not fail even if compression option is not supported + except AttributeError: + # Expected if SSL context doesn't support compression option + pass + + def test_ssl_session_reuse_edge_cases(self): + """Test SSL session reuse scenarios""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_ssl_sock = Mock() + mock_context.wrap_socket.return_value = mock_ssl_sock + + # Test session reuse + mock_ssl_sock.session = "mock_session" + mock_ssl_sock.session_reused = True + + result = _ssl_socket(mock_sock, {}, "example.com") + + # Should handle session reuse without issues + self.assertIsNotNone(result) + + def test_ssl_alpn_protocol_edge_cases(self): + """Test SSL ALPN (Application Layer Protocol Negotiation) edge cases""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Test ALPN configuration + sslopt = {"alpn_protocols": ["http/1.1", "h2"]} + + # ALPN protocols are not currently supported in the SSL wrapper + # but the test should not fail + result = _ssl_socket(mock_sock, sslopt, "example.com") + self.assertIsNotNone(result) + # ALPN would need to be implemented in _wrap_sni_socket function + + def test_ssl_sni_edge_cases(self): + """Test SSL SNI (Server Name Indication) edge cases""" + mock_sock = Mock() + + # Test with IPv6 address (should not use SNI) + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # IPv6 addresses should not be used for SNI + ipv6_hostname = "2001:db8::1" + _ssl_socket(mock_sock, {}, ipv6_hostname) + + # Should use IPv6 address as server_hostname + mock_context.wrap_socket.assert_called_with( + mock_sock, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname=ipv6_hostname, + ) + + def test_ssl_buffer_size_edge_cases(self): + """Test SSL buffer size related edge cases""" + mock_sock = Mock() + + def mock_recv(bufsize): + # SSL should never try to read more than 16KB at once + if bufsize > 16384: + raise SSLError("[SSL: BAD_LENGTH] buffer too large") + return b"A" * min(bufsize, 1024) # Return smaller chunks + + mock_sock.recv.side_effect = mock_recv + mock_sock.gettimeout.return_value = 30.0 + + from websocket._abnf import frame_buffer + + # Frame buffer should handle large requests by chunking + fb = frame_buffer(lambda size: recv(mock_sock, size), skip_utf8_validation=True) + + # This should work even with large size due to chunking + result = fb.recv_strict(16384) # Exactly 16KB + + self.assertGreater(len(result), 0) + + def test_ssl_protocol_downgrade_protection(self): + """Test SSL protocol downgrade protection""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.side_effect = ssl.SSLError( + "SSLV3_ALERT_HANDSHAKE_FAILURE" + ) + + sslopt = {"ssl_version": ssl.PROTOCOL_TLS_CLIENT} + + # Should propagate SSL protocol errors + with self.assertRaises(ssl.SSLError): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_certificate_chain_validation(self): + """Test SSL certificate chain validation edge cases""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + + # Test certificate chain validation failure + mock_context.wrap_socket.side_effect = ssl.SSLCertVerificationError( + "certificate verify failed: certificate has expired" + ) + + sslopt = {"cert_reqs": ssl.CERT_REQUIRED, "check_hostname": True} + + with self.assertRaises(ssl.SSLCertVerificationError): + _ssl_socket(mock_sock, sslopt, "expired.badssl.com") + + def test_ssl_weak_cipher_rejection(self): + """Test SSL weak cipher rejection scenarios""" + mock_sock = Mock() + + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.side_effect = ssl.SSLError("no shared cipher") + + sslopt = {"ciphers": "RC4-MD5"} # Intentionally weak cipher + + # Should fail with weak ciphers (SSL error is not wrapped by our code) + with self.assertRaises(ssl.SSLError): + _ssl_socket(mock_sock, sslopt, "example.com") + + def test_ssl_hostname_verification_edge_cases(self): + """Test SSL hostname verification edge cases""" + mock_sock = Mock() + + # Test with wildcard certificate scenarios + test_cases = [ + ("*.example.com", "subdomain.example.com"), # Valid wildcard + ("*.example.com", "sub.subdomain.example.com"), # Invalid wildcard depth + ("example.com", "www.example.com"), # Hostname mismatch + ] + + for cert_hostname, connect_hostname in test_cases: + with self.subTest(cert=cert_hostname, hostname=connect_hostname): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + + if ( + cert_hostname != connect_hostname + and "sub.subdomain" in connect_hostname + ): + # Simulate hostname verification failure for invalid wildcard + mock_context.wrap_socket.side_effect = ssl.SSLCertVerificationError( + f"hostname '{connect_hostname}' doesn't match '{cert_hostname}'" + ) + + sslopt = { + "cert_reqs": ssl.CERT_REQUIRED, + "check_hostname": True, + } + + with self.assertRaises(ssl.SSLCertVerificationError): + _ssl_socket(mock_sock, sslopt, connect_hostname) + else: + mock_context.wrap_socket.return_value = Mock() + sslopt = { + "cert_reqs": ssl.CERT_REQUIRED, + "check_hostname": True, + } + + # Should succeed for valid cases + result = _ssl_socket(mock_sock, sslopt, connect_hostname) + self.assertIsNotNone(result) + + def test_ssl_memory_bio_edge_cases(self): + """Test SSL memory BIO edge cases""" + mock_sock = Mock() + + # Test SSL memory BIO scenarios (if available) + try: + import ssl + + if hasattr(ssl, "MemoryBIO"): + with patch("ssl.SSLContext") as mock_ssl_context: + mock_context = Mock() + mock_ssl_context.return_value = mock_context + mock_context.wrap_socket.return_value = Mock() + + # Memory BIO should work if available + _ssl_socket(mock_sock, {}, "example.com") + + # Standard socket wrapping should still work + mock_context.wrap_socket.assert_called_once() + except (ImportError, AttributeError): + self.skipTest("SSL MemoryBIO not available") + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_url.py b/contrib/python/websocket-client/websocket/tests/test_url.py index 110fdfad70a..bbb39b0f3f7 100644 --- a/contrib/python/websocket-client/websocket/tests/test_url.py +++ b/contrib/python/websocket-client/websocket/tests/test_url.py @@ -15,7 +15,7 @@ from websocket._exceptions import WebSocketProxyException test_url.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,8 @@ class UrlTest(unittest.TestCase): self.assertTrue(_is_address_in_network("127.0.0.1", "127.0.0.0/8")) self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8")) self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24")) + self.assertTrue(_is_address_in_network("2001:db8::1", "2001:db8::/64")) + self.assertFalse(_is_address_in_network("2001:db8:1::1", "2001:db8::/64")) def test_parse_url(self): p = parse_url("ws://www.example.com/r") @@ -167,11 +169,16 @@ class IsNoProxyHostTest(unittest.TestCase): self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"])) self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"])) self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"])) - os.environ["no_proxy"] = "127.0.0.0/8" + self.assertTrue(_is_no_proxy_host("2001:db8::1", ["2001:db8::/64"])) + self.assertFalse(_is_no_proxy_host("2001:db8:1::1", ["2001:db8::/64"])) + os.environ["no_proxy"] = "127.0.0.0/8,2001:db8::/64" self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) self.assertTrue(_is_no_proxy_host("127.0.0.2", None)) - os.environ["no_proxy"] = "127.0.0.0/24" + self.assertTrue(_is_no_proxy_host("2001:db8::1", None)) + self.assertFalse(_is_no_proxy_host("2001:db8:1::1", None)) + os.environ["no_proxy"] = "127.0.0.0/24,2001:db8::/64" self.assertFalse(_is_no_proxy_host("127.1.0.1", None)) + self.assertFalse(_is_no_proxy_host("2001:db8:1::1", None)) def test_hostname_match(self): self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"])) @@ -427,12 +434,12 @@ class ProxyInfoTest(unittest.TestCase): ("localhost2", 3128, ("a", "b")), ) - os.environ[ - "http_proxy" - ] = "http://john%40example.com:P%40SSWORD@localhost:3128/" - os.environ[ - "https_proxy" - ] = "http://john%40example.com:P%40SSWORD@localhost2:3128/" + os.environ["http_proxy"] = ( + "http://john%40example.com:P%40SSWORD@localhost:3128/" + ) + os.environ["https_proxy"] = ( + "http://john%40example.com:P%40SSWORD@localhost2:3128/" + ) self.assertEqual( get_proxy_info("echo.websocket.events", True), ("localhost2", 3128, ("[email protected]", "P@SSWORD")), diff --git a/contrib/python/websocket-client/websocket/tests/test_utils.py b/contrib/python/websocket-client/websocket/tests/test_utils.py new file mode 100644 index 00000000000..deb9751bd16 --- /dev/null +++ b/contrib/python/websocket-client/websocket/tests/test_utils.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +import sys +import unittest +from unittest.mock import patch + +""" +test_utils.py +websocket - WebSocket client library for Python + +Copyright 2025 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +class UtilsTest(unittest.TestCase): + def test_nolock(self): + """Test NoLock context manager""" + from websocket._utils import NoLock + + lock = NoLock() + + # Test that it can be used as context manager + with lock: + pass # Should not raise any exception + + # Test enter/exit methods directly + self.assertIsNone(lock.__enter__()) + self.assertIsNone(lock.__exit__(None, None, None)) + + def test_utf8_validation_with_wsaccel(self): + """Test UTF-8 validation when wsaccel is available""" + # Import normally (wsaccel should be available in test environment) + from websocket._utils import validate_utf8 + + # Test valid UTF-8 strings (convert to bytes for wsaccel) + self.assertTrue(validate_utf8("Hello, World!".encode("utf-8"))) + self.assertTrue(validate_utf8("🌟 Unicode test".encode("utf-8"))) + self.assertTrue(validate_utf8(b"Hello, bytes")) + self.assertTrue(validate_utf8("Héllo with accénts".encode("utf-8"))) + + # Test invalid UTF-8 sequences + self.assertFalse(validate_utf8(b"\xff\xfe")) # Invalid UTF-8 + self.assertFalse(validate_utf8(b"\x80\x80")) # Invalid continuation + + def test_utf8_validation_fallback(self): + """Test UTF-8 validation fallback when wsaccel is not available""" + # Remove _utils from modules to force reimport + if "websocket._utils" in sys.modules: + del sys.modules["websocket._utils"] + + # Mock wsaccel import to raise ImportError + import builtins + + original_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if "wsaccel" in name: + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + import websocket._utils as utils + + # Test valid UTF-8 strings with fallback implementation (convert strings to bytes) + self.assertTrue(utils.validate_utf8("Hello, World!".encode("utf-8"))) + self.assertTrue(utils.validate_utf8(b"Hello, bytes")) + self.assertTrue(utils.validate_utf8("ASCII text".encode("utf-8"))) + + # Test Unicode strings (convert to bytes) + self.assertTrue(utils.validate_utf8("🌟 Unicode test".encode("utf-8"))) + self.assertTrue(utils.validate_utf8("Héllo with accénts".encode("utf-8"))) + + # Test empty string/bytes + self.assertTrue(utils.validate_utf8("".encode("utf-8"))) + self.assertTrue(utils.validate_utf8(b"")) + + # Test invalid UTF-8 sequences (should return False) + self.assertFalse(utils.validate_utf8(b"\xff\xfe")) + self.assertFalse(utils.validate_utf8(b"\x80\x80")) + + # Note: The fallback implementation may have different validation behavior + # than wsaccel, so we focus on clearly invalid sequences + + def test_extract_err_message(self): + """Test extract_err_message function""" + from websocket._utils import extract_err_message + + # Test with exception that has args + exc_with_args = Exception("Test error message") + self.assertEqual(extract_err_message(exc_with_args), "Test error message") + + # Test with exception that has multiple args + exc_multi_args = Exception("First arg", "Second arg") + self.assertEqual(extract_err_message(exc_multi_args), "First arg") + + # Test with exception that has no args + exc_no_args = Exception() + self.assertIsNone(extract_err_message(exc_no_args)) + + def test_extract_error_code(self): + """Test extract_error_code function""" + from websocket._utils import extract_error_code + + # Test with exception that has integer as first arg + exc_with_code = Exception(404, "Not found") + self.assertEqual(extract_error_code(exc_with_code), 404) + + # Test with exception that has string as first arg + exc_with_string = Exception("Error message", "Second arg") + self.assertIsNone(extract_error_code(exc_with_string)) + + # Test with exception that has only one arg + exc_single_arg = Exception("Single arg") + self.assertIsNone(extract_error_code(exc_single_arg)) + + # Test with exception that has no args + exc_no_args = Exception() + self.assertIsNone(extract_error_code(exc_no_args)) + + def tearDown(self): + """Clean up after tests""" + # Ensure _utils is reimported fresh for next test + if "websocket._utils" in sys.modules: + del sys.modules["websocket._utils"] + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/python/websocket-client/websocket/tests/test_websocket.py b/contrib/python/websocket-client/websocket/tests/test_websocket.py index 892312a2dbd..9e36df7c011 100644 --- a/contrib/python/websocket-client/websocket/tests/test_websocket.py +++ b/contrib/python/websocket-client/websocket/tests/test_websocket.py @@ -7,7 +7,11 @@ import unittest from base64 import decodebytes as base64decode import websocket as ws -from websocket._exceptions import WebSocketBadStatusException, WebSocketAddressException +from websocket._exceptions import ( + WebSocketBadStatusException, + WebSocketAddressException, + WebSocketException, +) from websocket._handshake import _create_sec_websocket_key from websocket._handshake import _validate as _validate_header from websocket._http import read_headers @@ -17,7 +21,7 @@ from websocket._utils import validate_utf8 test_websocket.py websocket - WebSocket client library for Python -Copyright 2024 engn33r +Copyright 2025 engn33r Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -296,7 +300,7 @@ class WebSocketTest(unittest.TestCase): def test_close(self): sock = ws.WebSocket() sock.connected = True - sock.close + sock.close() sock = ws.WebSocket() s = sock.sock = SockMock() @@ -455,7 +459,7 @@ class HandshakeTest(unittest.TestCase): self.assertRaises(ValueError, websock1.connect, "wss://api.bitfinex.com/ws/2") websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"}) self.assertRaises( - FileNotFoundError, websock2.connect, "wss://api.bitfinex.com/ws/2" + WebSocketException, websock2.connect, "wss://api.bitfinex.com/ws/2" ) @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") |
