diff options
| author | alexv-smirnov <[email protected]> | 2023-12-01 12:02:50 +0300 |
|---|---|---|
| committer | alexv-smirnov <[email protected]> | 2023-12-01 13:28:10 +0300 |
| commit | 0e578a4c44d4abd539d9838347b9ebafaca41dfb (patch) | |
| tree | a0c1969c37f818c830ebeff9c077eacf30be6ef8 /contrib/python/websocket-client/py3/websocket/_handshake.py | |
| parent | 84f2d3d4cc985e63217cff149bd2e6d67ae6fe22 (diff) | |
Change "ya.make"
Diffstat (limited to 'contrib/python/websocket-client/py3/websocket/_handshake.py')
| -rw-r--r-- | contrib/python/websocket-client/py3/websocket/_handshake.py | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/contrib/python/websocket-client/py3/websocket/_handshake.py b/contrib/python/websocket-client/py3/websocket/_handshake.py new file mode 100644 index 00000000000..a94d3030c39 --- /dev/null +++ b/contrib/python/websocket-client/py3/websocket/_handshake.py @@ -0,0 +1,197 @@ +""" +_handshake.py +websocket - WebSocket client library for Python + +Copyright 2023 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import hashlib +import hmac +import os +from base64 import encodebytes as base64encode +from http import client as HTTPStatus +from ._cookiejar import SimpleCookieJar +from ._exceptions import * +from ._http import * +from ._logging import * +from ._socket import * + +__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"] + +# websocket supported version. +VERSION = 13 + +SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER, HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT) +SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,) + +CookieJar = SimpleCookieJar() + + +class handshake_response: + + def __init__(self, status: int, headers: dict, subprotocol): + self.status = status + self.headers = headers + self.subprotocol = subprotocol + CookieJar.add(headers.get("set-cookie")) + + +def handshake(sock, url: str, hostname: str, port: int, resource: str, **options): + headers, key = _get_handshake_headers(resource, url, hostname, port, options) + + header_str = "\r\n".join(headers) + send(sock, header_str) + dump("request header", header_str) + + status, resp = _get_resp_headers(sock) + if status in SUPPORTED_REDIRECT_STATUSES: + return handshake_response(status, resp, None) + success, subproto = _validate(resp, key, options.get("subprotocols")) + if not success: + raise WebSocketException("Invalid WebSocket Header") + + return handshake_response(status, resp, subproto) + + +def _pack_hostname(hostname: str) -> str: + # IPv6 address + if ':' in hostname: + return '[' + hostname + ']' + + return hostname + + +def _get_handshake_headers(resource: str, url: str, host: str, port: int, options: dict): + headers = [ + "GET {resource} HTTP/1.1".format(resource=resource), + "Upgrade: websocket" + ] + if port == 80 or port == 443: + hostport = _pack_hostname(host) + else: + hostport = "{h}:{p}".format(h=_pack_hostname(host), p=port) + if options.get("host"): + headers.append("Host: {h}".format(h=options["host"])) + else: + headers.append("Host: {hp}".format(hp=hostport)) + + # scheme indicates whether http or https is used in Origin + # The same approach is used in parse_url of _url.py to set default port + scheme, url = url.split(":", 1) + if not options.get("suppress_origin"): + if "origin" in options and options["origin"] is not None: + headers.append("Origin: {origin}".format(origin=options["origin"])) + elif scheme == "wss": + headers.append("Origin: https://{hp}".format(hp=hostport)) + else: + headers.append("Origin: http://{hp}".format(hp=hostport)) + + key = _create_sec_websocket_key() + + # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified + if not options.get('header') or 'Sec-WebSocket-Key' not in options['header']: + headers.append("Sec-WebSocket-Key: {key}".format(key=key)) + else: + key = options['header']['Sec-WebSocket-Key'] + + if not options.get('header') or 'Sec-WebSocket-Version' not in options['header']: + headers.append("Sec-WebSocket-Version: {version}".format(version=VERSION)) + + if not options.get('connection'): + headers.append('Connection: Upgrade') + else: + headers.append(options['connection']) + + subprotocols = options.get("subprotocols") + if subprotocols: + headers.append("Sec-WebSocket-Protocol: {protocols}".format(protocols=",".join(subprotocols))) + + header = options.get("header") + if header: + if isinstance(header, dict): + header = [ + ": ".join([k, v]) + for k, v in header.items() + if v is not None + ] + headers.extend(header) + + server_cookie = CookieJar.get(host) + client_cookie = options.get("cookie", None) + + cookie = "; ".join(filter(None, [server_cookie, client_cookie])) + + if cookie: + headers.append("Cookie: {cookie}".format(cookie=cookie)) + + headers.extend(("", "")) + return headers, key + + +def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple: + status, resp_headers, status_message = read_headers(sock) + if status not in success_statuses: + content_len = resp_headers.get('content-length') + if content_len: + response_body = sock.recv(int(content_len)) # read the body of the HTTP error message response and include it in the exception + else: + response_body = None + raise WebSocketBadStatusException("Handshake status {status} {message} -+-+- {headers} -+-+- {body}".format(status=status, message=status_message, headers=resp_headers, body=response_body), status, status_message, resp_headers, response_body) + return status, resp_headers + + +_HEADERS_TO_CHECK = { + "upgrade": "websocket", + "connection": "upgrade", +} + + +def _validate(headers, key: str, subprotocols): + subproto = None + for k, v in _HEADERS_TO_CHECK.items(): + r = headers.get(k, None) + if not r: + return False, None + r = [x.strip().lower() for x in r.split(',')] + if v not in r: + return False, None + + if subprotocols: + subproto = headers.get("sec-websocket-protocol", None) + if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]: + error("Invalid subprotocol: " + str(subprotocols)) + return False, None + subproto = subproto.lower() + + result = headers.get("sec-websocket-accept", None) + if not result: + return False, None + result = result.lower() + + if isinstance(result, str): + result = result.encode('utf-8') + + value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8') + hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() + success = hmac.compare_digest(hashed, result) + + if success: + return True, subproto + else: + return False, None + + +def _create_sec_websocket_key() -> str: + randomness = os.urandom(16) + return base64encode(randomness).decode('utf-8').strip() |
