diff --git a/lib/websocket/__init__.py b/lib/websocket/__init__.py index a9fa4634..05aae2bd 100644 --- a/lib/websocket/__init__.py +++ b/lib/websocket/__init__.py @@ -23,4 +23,4 @@ from ._exceptions import * from ._logging import * from ._socket import * -__version__ = "1.2.1" +__version__ = "1.2.3" diff --git a/lib/websocket/_abnf.py b/lib/websocket/_abnf.py index 6a4d4907..e9909ff6 100644 --- a/lib/websocket/_abnf.py +++ b/lib/websocket/_abnf.py @@ -96,7 +96,7 @@ VALID_CLOSE_STATUS = ( ) -class ABNF(object): +class ABNF: """ ABNF frame class. See http://tools.ietf.org/html/rfc5234 @@ -268,7 +268,7 @@ class ABNF(object): return _mask(array.array("B", mask_key), array.array("B", data)) -class frame_buffer(object): +class frame_buffer: _HEADER_MASK_INDEX = 5 _HEADER_LENGTH_INDEX = 6 @@ -373,7 +373,7 @@ class frame_buffer(object): self.recv_buffer.append(bytes_) shortage -= len(bytes_) - unified = bytes("", 'utf-8').join(self.recv_buffer) + unified = b"".join(self.recv_buffer) if shortage == 0: self.recv_buffer = [] @@ -383,7 +383,7 @@ class frame_buffer(object): return unified[:bufsize] -class continuous_frame(object): +class continuous_frame: def __init__(self, fire_cont_frame, skip_utf8_validation): self.fire_cont_frame = fire_cont_frame diff --git a/lib/websocket/_app.py b/lib/websocket/_app.py index 61925bad..1afd3d20 100644 --- a/lib/websocket/_app.py +++ b/lib/websocket/_app.py @@ -86,7 +86,7 @@ class SSLDispatcher: return r[0][0] -class WebSocketApp(object): +class WebSocketApp: """ Higher level of APIs are provided. The interface is like JavaScript WebSocket object. """ diff --git a/lib/websocket/_cookiejar.py b/lib/websocket/_cookiejar.py index dcf5031a..87853834 100644 --- a/lib/websocket/_cookiejar.py +++ b/lib/websocket/_cookiejar.py @@ -23,7 +23,7 @@ limitations under the License. import http.cookies -class SimpleCookieJar(object): +class SimpleCookieJar: def __init__(self): self.jar = dict() diff --git a/lib/websocket/_core.py b/lib/websocket/_core.py index f92f8a60..e26c8b11 100644 --- a/lib/websocket/_core.py +++ b/lib/websocket/_core.py @@ -40,7 +40,7 @@ from ._utils import * __all__ = ['WebSocket', 'create_connection'] -class WebSocket(object): +class WebSocket: """ Low level WebSocket interface. @@ -66,7 +66,7 @@ class WebSocket(object): Values for socket.setsockopt. sockopt must be tuple and each element is argument of sock.setsockopt. sslopt: dict - Optional dict object for ssl socket options. + Optional dict object for ssl socket options. See FAQ for details. fire_cont_frame: bool Fire recv event for each cont frame. Default is False. enable_multithread: bool @@ -84,7 +84,7 @@ class WebSocket(object): Parameters ---------- sslopt: dict - Optional dict object for ssl socket options. + Optional dict object for ssl socket options. See FAQ for details. """ self.sock_opt = sock_opt(sockopt, sslopt) self.handshake_response = None @@ -314,6 +314,14 @@ class WebSocket(object): return length def send_binary(self, payload): + """ + Send a binary message (OPCODE_BINARY). + + Parameters + ---------- + payload: bytes + payload of message to send. + """ return self.send(payload, ABNF.OPCODE_BINARY) def ping(self, payload=""): @@ -381,6 +389,8 @@ class WebSocket(object): """ Receive data with operation code. + If a valid ping message is received, a pong response is sent. + Parameters ---------- control_frame: bool @@ -434,7 +444,7 @@ class WebSocket(object): """ return self.frame_buffer.recv_frame() - def send_close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8')): + def send_close(self, status=STATUS_NORMAL, reason=b""): """ Send close data to the server. @@ -443,14 +453,14 @@ class WebSocket(object): status: int Status code to send. See STATUS_XXX. reason: str or bytes - The reason to close. This must be string or bytes. + The reason to close. This must be string or UTF-8 bytes. """ if status < 0 or status >= ABNF.LENGTH_16: raise ValueError("code is invalid range") self.connected = False self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) - def close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8'), timeout=3): + def close(self, status=STATUS_NORMAL, reason=b"", timeout=3): """ Close Websocket object @@ -459,7 +469,7 @@ class WebSocket(object): status: int Status code to send. See STATUS_XXX. reason: bytes - The reason to close. + The reason to close in UTF-8. timeout: int or float Timeout until receive a close frame. If None, it will wait forever until receive a close frame. @@ -575,7 +585,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options): Values for socket.setsockopt. sockopt must be a tuple and each element is an argument of sock.setsockopt. sslopt: dict - Optional dict object for ssl socket options. + Optional dict object for ssl socket options. See FAQ for details. subprotocols: list List of available subprotocols. Default is None. skip_utf8_validation: bool diff --git a/lib/websocket/_exceptions.py b/lib/websocket/_exceptions.py index 2d5b0535..b92b1f40 100644 --- a/lib/websocket/_exceptions.py +++ b/lib/websocket/_exceptions.py @@ -72,7 +72,7 @@ class WebSocketBadStatusException(WebSocketException): def __init__(self, message, status_code, status_message=None, resp_headers=None): msg = message % (status_code, status_message) - super(WebSocketBadStatusException, self).__init__(msg) + super().__init__(msg) self.status_code = status_code self.resp_headers = resp_headers diff --git a/lib/websocket/_handshake.py b/lib/websocket/_handshake.py index da1a8d44..f9dabb57 100644 --- a/lib/websocket/_handshake.py +++ b/lib/websocket/_handshake.py @@ -38,7 +38,7 @@ SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS CookieJar = SimpleCookieJar() -class handshake_response(object): +class handshake_response: def __init__(self, status, headers, subprotocol): self.status = status diff --git a/lib/websocket/_http.py b/lib/websocket/_http.py index 9ddf01d0..603fa00f 100644 --- a/lib/websocket/_http.py +++ b/lib/websocket/_http.py @@ -49,7 +49,7 @@ except: pass -class proxy_info(object): +class proxy_info: def __init__(self, **options): self.proxy_host = options.get("http_proxy_host", None) @@ -211,33 +211,41 @@ def _open_socket(addrinfo_list, sockopt, timeout): def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): - context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_TLS)) + context = sslopt.get('context', None) + if not context: + context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_TLS_CLIENT)) - if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE: - cafile = sslopt.get('ca_certs', None) - capath = sslopt.get('ca_cert_path', None) - if cafile or capath: - context.load_verify_locations(cafile=cafile, capath=capath) - elif hasattr(context, 'load_default_certs'): - context.load_default_certs(ssl.Purpose.SERVER_AUTH) - if sslopt.get('certfile', None): - context.load_cert_chain( - sslopt['certfile'], - sslopt.get('keyfile', None), - sslopt.get('password', None), - ) - # see - # https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 - context.verify_mode = sslopt['cert_reqs'] - if HAVE_CONTEXT_CHECK_HOSTNAME: - context.check_hostname = check_hostname - if 'ciphers' in sslopt: - context.set_ciphers(sslopt['ciphers']) - if 'cert_chain' in sslopt: - certfile, keyfile, password = sslopt['cert_chain'] - context.load_cert_chain(certfile, keyfile, password) - if 'ecdh_curve' in sslopt: - context.set_ecdh_curve(sslopt['ecdh_curve']) + if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE: + cafile = sslopt.get('ca_certs', None) + capath = sslopt.get('ca_cert_path', None) + if cafile or capath: + context.load_verify_locations(cafile=cafile, capath=capath) + elif hasattr(context, 'load_default_certs'): + context.load_default_certs(ssl.Purpose.SERVER_AUTH) + if sslopt.get('certfile', None): + context.load_cert_chain( + sslopt['certfile'], + sslopt.get('keyfile', None), + sslopt.get('password', None), + ) + + # Python 3.10 switch to PROTOCOL_TLS_CLIENT defaults to "cert_reqs = ssl.CERT_REQUIRED" and "check_hostname = True" + # If both disabled, set check_hostname before verify_mode + # see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 + if sslopt.get('cert_reqs', ssl.CERT_NONE) == ssl.CERT_NONE and not sslopt.get('check_hostname', False): + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + else: + context.check_hostname = sslopt.get('check_hostname', True) + context.verify_mode = sslopt.get('cert_reqs', ssl.CERT_REQUIRED) + + if 'ciphers' in sslopt: + context.set_ciphers(sslopt['ciphers']) + if 'cert_chain' in sslopt: + certfile, keyfile, password = sslopt['cert_chain'] + context.load_cert_chain(certfile, keyfile, password) + if 'ecdh_curve' in sslopt: + context.set_ecdh_curve(sslopt['ecdh_curve']) return context.wrap_socket( sock, @@ -262,13 +270,9 @@ def _ssl_socket(sock, user_sslopt, hostname): if sslopt.get('server_hostname', None): hostname = sslopt['server_hostname'] - check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( - 'check_hostname', True) + check_hostname = sslopt.get('check_hostname', True) sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) - if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname: - match_hostname(sock.getpeercert(), hostname) - return sock diff --git a/lib/websocket/_socket.py b/lib/websocket/_socket.py index eb573d4e..4d9cc097 100644 --- a/lib/websocket/_socket.py +++ b/lib/websocket/_socket.py @@ -44,7 +44,7 @@ __all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefault "recv", "recv_line", "send"] -class sock_opt(object): +class sock_opt: def __init__(self, sockopt, sslopt): if sockopt is None: diff --git a/lib/websocket/_ssl_compat.py b/lib/websocket/_ssl_compat.py index 9e5460c2..f4af524e 100644 --- a/lib/websocket/_ssl_compat.py +++ b/lib/websocket/_ssl_compat.py @@ -23,11 +23,6 @@ try: from ssl import SSLError from ssl import SSLWantReadError from ssl import SSLWantWriteError - HAVE_CONTEXT_CHECK_HOSTNAME = False - if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): - HAVE_CONTEXT_CHECK_HOSTNAME = True - - __all__.append("HAVE_CONTEXT_CHECK_HOSTNAME") HAVE_SSL = True except ImportError: # dummy class of SSLError for environment without ssl support diff --git a/lib/websocket/_utils.py b/lib/websocket/_utils.py index feed027e..21fc437c 100644 --- a/lib/websocket/_utils.py +++ b/lib/websocket/_utils.py @@ -19,7 +19,7 @@ limitations under the License. __all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"] -class NoLock(object): +class NoLock: def __enter__(self): pass diff --git a/lib/websocket/_wsdump.py b/lib/websocket/_wsdump.py new file mode 100644 index 00000000..4d15f413 --- /dev/null +++ b/lib/websocket/_wsdump.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 + +""" +wsdump.py +websocket - WebSocket client library for Python + +Copyright 2021 engn33r + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import argparse +import code +import sys +import threading +import time +import ssl +import gzip +import zlib +from urllib.parse import urlparse + +import websocket + +try: + import readline +except ImportError: + pass + + +def get_encoding(): + encoding = getattr(sys.stdin, "encoding", "") + if not encoding: + return "utf-8" + else: + return encoding.lower() + + +OPCODE_DATA = (websocket.ABNF.OPCODE_TEXT, websocket.ABNF.OPCODE_BINARY) +ENCODING = get_encoding() + + +class VAction(argparse.Action): + + def __call__(self, parser, args, values, option_string=None): + if values is None: + values = "1" + try: + values = int(values) + except ValueError: + values = values.count("v") + 1 + setattr(args, self.dest, values) + + +def parse_args(): + parser = argparse.ArgumentParser(description="WebSocket Simple Dump Tool") + parser.add_argument("url", metavar="ws_url", + help="websocket url. ex. ws://echo.websocket.org/") + parser.add_argument("-p", "--proxy", + help="proxy url. ex. http://127.0.0.1:8080") + parser.add_argument("-v", "--verbose", default=0, nargs='?', action=VAction, + dest="verbose", + help="set verbose mode. If set to 1, show opcode. " + "If set to 2, enable to trace websocket module") + parser.add_argument("-n", "--nocert", action='store_true', + help="Ignore invalid SSL cert") + parser.add_argument("-r", "--raw", action="store_true", + help="raw output") + parser.add_argument("-s", "--subprotocols", nargs='*', + help="Set subprotocols") + parser.add_argument("-o", "--origin", + help="Set origin") + parser.add_argument("--eof-wait", default=0, type=int, + help="wait time(second) after 'EOF' received.") + parser.add_argument("-t", "--text", + help="Send initial text") + parser.add_argument("--timings", action="store_true", + help="Print timings in seconds") + parser.add_argument("--headers", + help="Set custom headers. Use ',' as separator") + + return parser.parse_args() + + +class RawInput: + + def raw_input(self, prompt): + line = input(prompt) + + if ENCODING and ENCODING != "utf-8" and not isinstance(line, str): + line = line.decode(ENCODING).encode("utf-8") + elif isinstance(line, str): + line = line.encode("utf-8") + + return line + + +class InteractiveConsole(RawInput, code.InteractiveConsole): + + def write(self, data): + sys.stdout.write("\033[2K\033[E") + # sys.stdout.write("\n") + sys.stdout.write("\033[34m< " + data + "\033[39m") + sys.stdout.write("\n> ") + sys.stdout.flush() + + def read(self): + return self.raw_input("> ") + + +class NonInteractive(RawInput): + + def write(self, data): + sys.stdout.write(data) + sys.stdout.write("\n") + sys.stdout.flush() + + def read(self): + return self.raw_input("") + + +def main(): + start_time = time.time() + args = parse_args() + if args.verbose > 1: + websocket.enableTrace(True) + options = {} + if args.proxy: + p = urlparse(args.proxy) + options["http_proxy_host"] = p.hostname + options["http_proxy_port"] = p.port + if args.origin: + options["origin"] = args.origin + if args.subprotocols: + options["subprotocols"] = args.subprotocols + opts = {} + if args.nocert: + opts = {"cert_reqs": ssl.CERT_NONE, "check_hostname": False} + if args.headers: + options['header'] = list(map(str.strip, args.headers.split(','))) + ws = websocket.create_connection(args.url, sslopt=opts, **options) + if args.raw: + console = NonInteractive() + else: + console = InteractiveConsole() + print("Press Ctrl+C to quit") + + def recv(): + try: + frame = ws.recv_frame() + except websocket.WebSocketException: + return websocket.ABNF.OPCODE_CLOSE, None + if not frame: + raise websocket.WebSocketException("Not a valid frame %s" % frame) + elif frame.opcode in OPCODE_DATA: + return frame.opcode, frame.data + elif frame.opcode == websocket.ABNF.OPCODE_CLOSE: + ws.send_close() + return frame.opcode, None + elif frame.opcode == websocket.ABNF.OPCODE_PING: + ws.pong(frame.data) + return frame.opcode, frame.data + + return frame.opcode, frame.data + + def recv_ws(): + while True: + opcode, data = recv() + msg = None + if opcode == websocket.ABNF.OPCODE_TEXT and isinstance(data, bytes): + data = str(data, "utf-8") + if isinstance(data, bytes) and len(data) > 2 and data[:2] == b'\037\213': # gzip magick + try: + data = "[gzip] " + str(gzip.decompress(data), "utf-8") + except: + pass + elif isinstance(data, bytes): + try: + data = "[zlib] " + str(zlib.decompress(data, -zlib.MAX_WBITS), "utf-8") + except: + pass + + if isinstance(data, bytes): + data = repr(data) + + if args.verbose: + msg = "%s: %s" % (websocket.ABNF.OPCODE_MAP.get(opcode), data) + else: + msg = data + + if msg is not None: + if args.timings: + console.write(str(time.time() - start_time) + ": " + msg) + else: + console.write(msg) + + if opcode == websocket.ABNF.OPCODE_CLOSE: + break + + thread = threading.Thread(target=recv_ws) + thread.daemon = True + thread.start() + + if args.text: + ws.send(args.text) + + while True: + try: + message = console.read() + ws.send(message) + except KeyboardInterrupt: + return + except EOFError: + time.sleep(args.eof_wait) + return + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(e) diff --git a/lib/websocket/tests/data/header01.txt b/lib/websocket/tests/data/header01.txt index d44d24c2..3142b43b 100644 --- a/lib/websocket/tests/data/header01.txt +++ b/lib/websocket/tests/data/header01.txt @@ -1,6 +1,6 @@ -HTTP/1.1 101 WebSocket Protocol Handshake -Connection: Upgrade -Upgrade: WebSocket -Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= -some_header: something - +HTTP/1.1 101 WebSocket Protocol Handshake +Connection: Upgrade +Upgrade: WebSocket +Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= +some_header: something + diff --git a/lib/websocket/tests/data/header02.txt b/lib/websocket/tests/data/header02.txt index f481de92..a9dd2ce3 100644 --- a/lib/websocket/tests/data/header02.txt +++ b/lib/websocket/tests/data/header02.txt @@ -1,6 +1,6 @@ -HTTP/1.1 101 WebSocket Protocol Handshake -Connection: Upgrade -Upgrade WebSocket -Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= -some_header: something - +HTTP/1.1 101 WebSocket Protocol Handshake +Connection: Upgrade +Upgrade WebSocket +Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= +some_header: something + diff --git a/lib/websocket/tests/echo-server.py b/lib/websocket/tests/echo-server.py index 8736def4..08d108ab 100644 --- a/lib/websocket/tests/echo-server.py +++ b/lib/websocket/tests/echo-server.py @@ -4,6 +4,9 @@ import asyncio import websockets +import os + +LOCAL_WS_SERVER_PORT = os.environ.get('LOCAL_WS_SERVER_PORT', '8765') async def echo(websocket, path): @@ -12,7 +15,7 @@ async def echo(websocket, path): async def main(): - async with websockets.serve(echo, "localhost", 8765): + async with websockets.serve(echo, "localhost", LOCAL_WS_SERVER_PORT): await asyncio.Future() # run forever asyncio.run(main()) diff --git a/lib/websocket/tests/test_abnf.py b/lib/websocket/tests/test_abnf.py index 68282fef..7f156dc9 100644 --- a/lib/websocket/tests/test_abnf.py +++ b/lib/websocket/tests/test_abnf.py @@ -19,12 +19,9 @@ See the License for the specific language governing permissions and limitations under the License. """ -import os import websocket as ws from websocket._abnf import * -import sys import unittest -sys.path[0:0] = [""] class ABNFTest(unittest.TestCase): @@ -57,7 +54,7 @@ class ABNFTest(unittest.TestCase): def testMask(self): abnf_none_data = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING, mask=1, data=None) - bytes_val = bytes("aaaa", 'utf-8') + bytes_val = b"aaaa" self.assertEqual(abnf_none_data._get_masked(bytes_val), bytes_val) abnf_str_data = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING, mask=1, data="a") self.assertEqual(abnf_str_data._get_masked(bytes_val), b'aaaa\x00') diff --git a/lib/websocket/tests/test_app.py b/lib/websocket/tests/test_app.py index d81b06f5..cd1146b3 100644 --- a/lib/websocket/tests/test_app.py +++ b/lib/websocket/tests/test_app.py @@ -22,19 +22,20 @@ limitations under the License. import os import os.path import websocket as ws -import sys import ssl import unittest -sys.path[0:0] = [""] -# Skip test to access the internet. +# Skip test to access the internet unless TEST_WITH_INTERNET == 1 TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' +# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1 +LOCAL_WS_SERVER_PORT = os.environ.get('LOCAL_WS_SERVER_PORT', '-1') +TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != '-1' TRACEABLE = True class WebSocketAppTest(unittest.TestCase): - class NotSetYet(object): + class NotSetYet: """ A marker class for signalling that a value hasn't been set yet. """ @@ -50,7 +51,7 @@ class WebSocketAppTest(unittest.TestCase): WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() - @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") def testKeepRunning(self): """ A WebSocketApp should keep running as long as its self.keep_running is not False (in the boolean context). @@ -73,7 +74,7 @@ class WebSocketAppTest(unittest.TestCase): """ WebSocketAppTest.keep_running_close = self.keep_running - app = ws.WebSocketApp('ws://127.0.0.1:8765', on_open=on_open, on_close=on_close, on_message=on_message) + app = ws.WebSocketApp('ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, on_open=on_open, on_close=on_close, on_message=on_message) app.run_forever() @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") diff --git a/lib/websocket/tests/test_cookiejar.py b/lib/websocket/tests/test_cookiejar.py index 69258ac0..5bf1fcae 100644 --- a/lib/websocket/tests/test_cookiejar.py +++ b/lib/websocket/tests/test_cookiejar.py @@ -21,7 +21,6 @@ See the License for the specific language governing permissions and limitations under the License. """ import unittest - from websocket._cookiejar import SimpleCookieJar @@ -114,3 +113,7 @@ class CookieJarTest(unittest.TestCase): self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d") self.assertEqual(cookie_jar.get("abc.com.es"), "") self.assertEqual(cookie_jar.get("xabc.com"), "") + + +if __name__ == "__main__": + unittest.main() diff --git a/lib/websocket/tests/test_http.py b/lib/websocket/tests/test_http.py index e978bdd8..fda467d7 100644 --- a/lib/websocket/tests/test_http.py +++ b/lib/websocket/tests/test_http.py @@ -23,26 +23,25 @@ import os import os.path import websocket as ws from websocket._http import proxy_info, read_headers, _start_proxied_socket, _tunnel, _get_addrinfo_list, connect -import sys import unittest import ssl import websocket import socket try: - from python_socks.sync import Proxy - from python_socks._errors import * + from python_socks._errors import ProxyError, ProxyTimeoutError, ProxyConnectionError except: from websocket._http import ProxyError, ProxyTimeoutError, ProxyConnectionError -sys.path[0:0] = [""] - -# Skip test to access the internet. +# Skip test to access the internet unless TEST_WITH_INTERNET == 1 TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' TEST_WITH_PROXY = os.environ.get('TEST_WITH_PROXY', '0') == '1' +# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1 +LOCAL_WS_SERVER_PORT = os.environ.get('LOCAL_WS_SERVER_PORT', '-1') +TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != '-1' -class SockMock(object): +class SockMock: def __init__(self): self.data = [] self.sent = [] @@ -106,10 +105,10 @@ class HttpTest(unittest.TestCase): if ws._http.HAVE_PYTHON_SOCKS: # Need this check, otherwise case where python_socks is not installed triggers # websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available - self.assertRaises(ProxyTimeoutError, _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4", timeout=1)) - self.assertRaises(ProxyTimeoutError, _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4a", timeout=1)) - self.assertRaises(ProxyTimeoutError, _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5", timeout=1)) - self.assertRaises(ProxyTimeoutError, _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5h", timeout=1)) + self.assertRaises((ProxyTimeoutError, OSError), _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4", timeout=1)) + self.assertRaises((ProxyTimeoutError, OSError), _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4a", timeout=1)) + self.assertRaises((ProxyTimeoutError, OSError), _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5", timeout=1)) + self.assertRaises((ProxyTimeoutError, OSError), _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5h", timeout=1)) self.assertRaises(ProxyConnectionError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=9999, proxy_type="socks4", timeout=1), None) self.assertRaises(TypeError, _get_addrinfo_list, None, 80, True, proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http")) @@ -123,9 +122,10 @@ class HttpTest(unittest.TestCase): @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_PROXY, "This test requires a HTTP proxy to be running on port 8899") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") def testProxyConnect(self): ws = websocket.WebSocket() - ws.connect("ws://127.0.0.1:8765", http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http") + ws.connect("ws://127.0.0.1:" + LOCAL_WS_SERVER_PORT, http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http") ws.send("Hello, Server") server_response = ws.recv() self.assertEqual(server_response, "Hello, Server") @@ -138,10 +138,9 @@ class HttpTest(unittest.TestCase): @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testSSLopt(self): ssloptions = { - "cert_reqs": ssl.CERT_NONE, "check_hostname": False, "server_hostname": "ServerName", - "ssl_version": ssl.PROTOCOL_TLS, + "ssl_version": ssl.PROTOCOL_TLS_CLIENT, "ciphers": "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:\ TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:\ ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:\ diff --git a/lib/websocket/tests/test_url.py b/lib/websocket/tests/test_url.py index dc24bb8b..ad3a3b1b 100644 --- a/lib/websocket/tests/test_url.py +++ b/lib/websocket/tests/test_url.py @@ -19,10 +19,8 @@ See the License for the specific language governing permissions and limitations under the License. """ -import sys import os import unittest -sys.path[0:0] = [""] from websocket._url import get_proxy_info, parse_url, _is_address_in_network, _is_no_proxy_host diff --git a/lib/websocket/tests/test_websocket.py b/lib/websocket/tests/test_websocket.py index f0c38ee4..8b34aa51 100644 --- a/lib/websocket/tests/test_websocket.py +++ b/lib/websocket/tests/test_websocket.py @@ -23,8 +23,6 @@ See the License for the specific language governing permissions and limitations under the License. """ -import sys -sys.path[0:0] = [""] import os import os.path import socket @@ -45,8 +43,11 @@ except ImportError: class SSLError(Exception): pass -# Skip test to access the internet. +# Skip test to access the internet unless TEST_WITH_INTERNET == 1 TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' +# Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1 +LOCAL_WS_SERVER_PORT = os.environ.get('LOCAL_WS_SERVER_PORT', '-1') +TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != '-1' TRACEABLE = True @@ -54,7 +55,7 @@ def create_mask_key(_): return "abcd" -class SockMock(object): +class SockMock: def __init__(self): self.data = [] self.sent = [] @@ -335,9 +336,9 @@ class WebSocketTest(unittest.TestCase): s.sent[0], b'\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17') - @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") def testWebSocket(self): - s = ws.create_connection("ws://127.0.0.1:8765") + s = ws.create_connection("ws://127.0.0.1:" + LOCAL_WS_SERVER_PORT) self.assertNotEqual(s, None) s.send("Hello, World") result = s.next() @@ -350,9 +351,9 @@ class WebSocketTest(unittest.TestCase): self.assertRaises(ValueError, s.send_close, -1, "") s.close() - @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") def testPingPong(self): - s = ws.create_connection("ws://127.0.0.1:8765") + s = ws.create_connection("ws://127.0.0.1:" + LOCAL_WS_SERVER_PORT) self.assertNotEqual(s, None) s.ping("Hello") s.pong("Hi") @@ -377,9 +378,9 @@ class WebSocketTest(unittest.TestCase): self.assertEqual(s.getsubprotocol(), None) s.abort() - @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") def testWebSocketWithCustomHeader(self): - s = ws.create_connection("ws://127.0.0.1:8765", + s = ws.create_connection("ws://127.0.0.1:" + LOCAL_WS_SERVER_PORT, headers={"User-Agent": "PythonWebsocketClient"}) self.assertNotEqual(s, None) s.send("Hello, World") @@ -388,9 +389,9 @@ class WebSocketTest(unittest.TestCase): self.assertRaises(ValueError, s.close, -1, "") s.close() - @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") def testAfterClose(self): - s = ws.create_connection("ws://127.0.0.1:8765") + s = ws.create_connection("ws://127.0.0.1:" + LOCAL_WS_SERVER_PORT) self.assertNotEqual(s, None) s.close() self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello") @@ -398,10 +399,10 @@ class WebSocketTest(unittest.TestCase): class SockOptTest(unittest.TestCase): - @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") def testSockOpt(self): sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),) - s = ws.create_connection("ws://127.0.0.1:8765", sockopt=sockopt) + s = ws.create_connection("ws://127.0.0.1:" + LOCAL_WS_SERVER_PORT, sockopt=sockopt) self.assertNotEqual(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0) s.close() @@ -428,9 +429,8 @@ class HandshakeTest(unittest.TestCase): @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testManualHeaders(self): - websock3 = ws.WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE, - "ca_certs": ssl.get_default_verify_paths().capath, - "ca_cert_path": ssl.get_default_verify_paths().openssl_cafile}) + websock3 = ws.WebSocket(sslopt={"ca_certs": ssl.get_default_verify_paths().cafile, + "ca_cert_path": ssl.get_default_verify_paths().capath}) self.assertRaises(ws._exceptions.WebSocketBadStatusException, websock3.connect, "wss://api.bitfinex.com/ws/2", cookie="chocolate", origin="testing_websockets.com", diff --git a/requirements.txt b/requirements.txt index dc0f5c95..71486686 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,7 +47,7 @@ tzdata==2021.5 tzlocal==2.1 # apscheduler==3.8.0 requires tzlocal~=2.0 urllib3==1.26.7 webencodings==0.5.1 -websocket-client==1.2.1 +websocket-client==1.2.3 xmltodict==0.12.0 zipp==3.6.0