diff --git a/lib/websocket/__init__.py b/lib/websocket/__init__.py index 588a8f21..a579342d 100644 --- a/lib/websocket/__init__.py +++ b/lib/websocket/__init__.py @@ -23,4 +23,4 @@ from ._exceptions import * from ._logging import * from ._socket import * -__version__ = "1.4.2" +__version__ = "1.5.1" diff --git a/lib/websocket/_app.py b/lib/websocket/_app.py index e0a88e2f..0a16ddb6 100644 --- a/lib/websocket/_app.py +++ b/lib/websocket/_app.py @@ -4,11 +4,12 @@ import sys import threading import time import traceback + +from . import _logging from ._abnf import ABNF from ._url import parse_url from ._core import WebSocket, getdefaulttimeout from ._exceptions import * -from . import _logging """ _app.py @@ -53,10 +54,9 @@ class DispatcherBase: def reconnect(self, seconds, reconnector): try: - while True: - _logging.info("reconnect() - retrying in %s seconds [%s frames in stack]" % (seconds, len(inspect.stack()))) - time.sleep(seconds) - reconnector(reconnecting=True) + _logging.info("reconnect() - retrying in %s seconds [%s frames in stack]" % (seconds, len(inspect.stack()))) + time.sleep(seconds) + reconnector(reconnecting=True) except KeyboardInterrupt as e: _logging.info("User exited %s" % (e,)) @@ -214,6 +214,11 @@ class WebSocketApp: self.sock = None self.last_ping_tm = 0 self.last_pong_tm = 0 + self.ping_thread = None + self.stop_ping = None + self.ping_interval = 0 + self.ping_timeout = None + self.ping_payload = "" self.subprotocols = subprotocols self.prepared_socket = socket self.has_errored = False @@ -244,15 +249,31 @@ class WebSocketApp: self.sock.close(**kwargs) self.sock = None - def _send_ping(self, interval, event, payload): - while not event.wait(interval): - self.last_ping_tm = time.time() + def _start_ping_thread(self): + self.last_ping_tm = self.last_pong_tm = 0 + self.stop_ping = threading.Event() + self.ping_thread = threading.Thread(target=self._send_ping) + self.ping_thread.daemon = True + self.ping_thread.start() + + def _stop_ping_thread(self): + if self.stop_ping: + self.stop_ping.set() + if self.ping_thread and self.ping_thread.is_alive(): + self.ping_thread.join(3) + self.last_ping_tm = self.last_pong_tm = 0 + + def _send_ping(self): + if self.stop_ping.wait(self.ping_interval): + return + while not self.stop_ping.wait(self.ping_interval): if self.sock: + self.last_ping_tm = time.time() try: - self.sock.ping(payload) + _logging.debug("Sending ping") + self.sock.ping(self.ping_payload) except Exception as ex: - _logging.warning("send_ping routine terminated: {}".format(ex)) - break + _logging.debug("Failed to send ping: %s", ex) def run_forever(self, sockopt=None, sslopt=None, ping_interval=0, ping_timeout=None, @@ -331,10 +352,11 @@ class WebSocketApp: sslopt = {} if self.sock: raise WebSocketException("socket is already opened") - thread = None + + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.ping_payload = ping_payload self.keep_running = True - self.last_ping_tm = 0 - self.last_pong_tm = 0 def teardown(close_frame=None): """ @@ -347,9 +369,7 @@ class WebSocketApp: with the statusCode and reason from the provided frame. """ - if thread and thread.is_alive(): - event.set() - thread.join() + self._stop_ping_thread() self.keep_running = False if self.sock: self.sock.close() @@ -361,11 +381,15 @@ class WebSocketApp: self._callback(self.on_close, close_status_code, close_reason) def setSock(reconnecting=False): + if reconnecting and self.sock: + self.sock.shutdown() + self.sock = WebSocket( self.get_mask_key, sockopt=sockopt, sslopt=sslopt, fire_cont_frame=self.on_cont_message is not None, skip_utf8_validation=skip_utf8_validation, enable_multithread=True) + self.sock.settimeout(getdefaulttimeout()) try: self.sock.connect( @@ -377,13 +401,16 @@ class WebSocketApp: host=host, origin=origin, suppress_origin=suppress_origin, proxy_type=proxy_type, socket=self.prepared_socket) + _logging.info("Websocket connected") + + if self.ping_interval: + self._start_ping_thread() + self._callback(self.on_open) - _logging.warning("websocket connected") dispatcher.read(self.sock.sock, read, check) except (WebSocketConnectionClosedException, ConnectionRefusedError, KeyboardInterrupt, SystemExit, Exception) as e: - _logging.error("%s - %s" % (e, reconnect and "reconnecting" or "goodbye")) - reconnecting or handleDisconnect(e) + handleDisconnect(e, reconnecting) def read(): if not self.keep_running: @@ -396,6 +423,7 @@ class WebSocketApp: return handleDisconnect(e) else: raise e + if op_code == ABNF.OPCODE_CLOSE: return teardown(frame) elif op_code == ABNF.OPCODE_PING: @@ -410,7 +438,7 @@ class WebSocketApp: frame.data, frame.fin) else: data = frame.data - if op_code == ABNF.OPCODE_TEXT: + if op_code == ABNF.OPCODE_TEXT and not skip_utf8_validation: data = data.decode("utf-8") self._callback(self.on_data, data, frame.opcode, True) self._callback(self.on_message, data) @@ -418,10 +446,10 @@ class WebSocketApp: return True def check(): - if (ping_timeout): - has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout + if (self.ping_timeout): + has_timeout_expired = time.time() - self.last_ping_tm > self.ping_timeout has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 - has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout + has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > self.ping_timeout if (self.last_ping_tm and has_timeout_expired and @@ -429,29 +457,35 @@ class WebSocketApp: raise WebSocketTimeoutException("ping/pong timed out") return True - def handleDisconnect(e): + def handleDisconnect(e, reconnecting=False): self.has_errored = True - self._callback(self.on_error, e) - if isinstance(e, SystemExit): - # propagate SystemExit further + self._stop_ping_thread() + if not reconnecting: + self._callback(self.on_error, e) + + if isinstance(e, (KeyboardInterrupt, SystemExit)): + teardown() + # Propagate further raise - if reconnect and not isinstance(e, KeyboardInterrupt): - _logging.info("websocket disconnected (retrying in %s seconds) [%s frames in stack]" % (reconnect, len(inspect.stack()))) - dispatcher.reconnect(reconnect, setSock) + + if reconnect: + _logging.info("%s - reconnect" % e) + if custom_dispatcher: + _logging.debug("Calling custom dispatcher reconnect [%s frames in stack]" % len(inspect.stack())) + dispatcher.reconnect(reconnect, setSock) else: + _logging.error("%s - goodbye" % e) teardown() custom_dispatcher = bool(dispatcher) dispatcher = self.create_dispatcher(ping_timeout, dispatcher, parse_url(self.url)[3]) - if ping_interval: - event = threading.Event() - thread = threading.Thread( - target=self._send_ping, args=(ping_interval, event, ping_payload)) - thread.daemon = True - thread.start() - setSock() + if not custom_dispatcher and reconnect: + while self.keep_running: + _logging.debug("Calling dispatcher reconnect [%s frames in stack]" % len(inspect.stack())) + dispatcher.reconnect(reconnect, setSock) + return self.has_errored def create_dispatcher(self, ping_timeout, dispatcher=None, is_ssl=False): diff --git a/lib/websocket/_socket.py b/lib/websocket/_socket.py index 54e63997..7cc02164 100644 --- a/lib/websocket/_socket.py +++ b/lib/websocket/_socket.py @@ -151,7 +151,7 @@ def send(sock, data): error_code = extract_error_code(exc) if error_code is None: raise - if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: + if error_code != errno.EAGAIN and error_code != errno.EWOULDBLOCK: raise sel = selectors.DefaultSelector() diff --git a/lib/websocket/tests/test_app.py b/lib/websocket/tests/test_app.py index 5526d3ec..ac563c6e 100644 --- a/lib/websocket/tests/test_app.py +++ b/lib/websocket/tests/test_app.py @@ -186,13 +186,13 @@ class WebSocketAppTest(unittest.TestCase): app = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect') app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload") - @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") - def testOpcodeBinary(self): - """ Test WebSocketApp binary opcode - """ - # The lack of wss:// in the URL below is on purpose - app = ws.WebSocketApp('wss://streaming.vn.teslamotors.com/streaming/') - app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload") + # This is commented out because the URL no longer responds in the expected way + # @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + # def testOpcodeBinary(self): + # """ Test WebSocketApp binary opcode + # """ + # app = ws.WebSocketApp('wss://streaming.vn.teslamotors.com/streaming/') + # app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") def testBadPingInterval(self): @@ -228,6 +228,91 @@ class WebSocketAppTest(unittest.TestCase): self.assertRaises(ws.WebSocketConnectionClosedException, app.send, data="test if connection is closed") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") + def testCallbackFunctionException(self): + """ Test callback function exception handling """ + + exc = None + passed_app = None + + def on_open(app): + raise RuntimeError("Callback failed") + + def on_error(app, err): + nonlocal passed_app + passed_app = app + nonlocal exc + exc = err + + def on_pong(app, msg): + app.close() + + app = ws.WebSocketApp('ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, on_open=on_open, on_error=on_error, on_pong=on_pong) + app.run_forever(ping_interval=2, ping_timeout=1) + + self.assertEqual(passed_app, app) + self.assertIsInstance(exc, RuntimeError) + self.assertEqual(str(exc), "Callback failed") + + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") + def testCallbackMethodException(self): + """ Test callback method exception handling """ + + class Callbacks: + def __init__(self): + self.exc = None + self.passed_app = None + self.app = ws.WebSocketApp( + 'ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, + on_open=self.on_open, + on_error=self.on_error, + on_pong=self.on_pong + ) + self.app.run_forever(ping_interval=2, ping_timeout=1) + + def on_open(self, app): + raise RuntimeError("Callback failed") + + def on_error(self, app, err): + self.passed_app = app + self.exc = err + + def on_pong(self, app, msg): + app.close() + + callbacks = Callbacks() + + self.assertEqual(callbacks.passed_app, callbacks.app) + self.assertIsInstance(callbacks.exc, RuntimeError) + self.assertEqual(str(callbacks.exc), "Callback failed") + + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") + def testReconnect(self): + """ Test reconnect """ + pong_count = 0 + exc = None + + def on_error(app, err): + nonlocal exc + exc = err + + def on_pong(app, msg): + nonlocal pong_count + pong_count += 1 + if pong_count == 1: + # First pong, shutdown socket, enforce read error + app.sock.shutdown() + if pong_count >= 2: + # Got second pong after reconnect + app.close() + + app = ws.WebSocketApp('ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, on_pong=on_pong, on_error=on_error) + app.run_forever(ping_interval=2, ping_timeout=1, reconnect=3) + + self.assertEqual(pong_count, 2) + self.assertIsInstance(exc, ValueError) + self.assertEqual(str(exc), "Invalid file object: None") + if __name__ == "__main__": unittest.main() diff --git a/requirements.txt b/requirements.txt index 754a9e9c..2e253ab7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,7 +48,7 @@ tzdata==2022.7 tzlocal==4.2 urllib3==1.26.13 webencodings==0.5.1 -websocket-client==1.4.2 +websocket-client==1.5.1 xmltodict==0.13.0 zipp==3.15.0