Update websocket-client to 0.57.0

This commit is contained in:
JonnyWong16 2020-03-21 12:17:50 -07:00
parent 4ae09774f7
commit 060dff0162
16 changed files with 1153 additions and 5423 deletions

View file

@ -19,7 +19,11 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
from ._core import * from ._abnf import *
from ._app import WebSocketApp from ._app import WebSocketApp
from ._core import *
from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "0.32.0" __version__ = "0.57.0"

View file

@ -19,12 +19,59 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import six
import array import array
import struct
import os import os
import struct
import six
from ._exceptions import * from ._exceptions import *
from ._utils import validate_utf8 from ._utils import validate_utf8
from threading import Lock
try:
if six.PY3:
import numpy
else:
numpy = None
except ImportError:
numpy = None
try:
# If wsaccel is available we use compiled routines to mask data.
if not numpy:
from wsaccel.xormask import XorMaskerSimple
def _mask(_m, _d):
return XorMaskerSimple(_m).process(_d)
except ImportError:
# wsaccel is not available, we rely on python implementations.
def _mask(_m, _d):
for i in range(len(_d)):
_d[i] ^= _m[i % 4]
if six.PY3:
return _d.tobytes()
else:
return _d.tostring()
__all__ = [
'ABNF', 'continuous_frame', 'frame_buffer',
'STATUS_NORMAL',
'STATUS_GOING_AWAY',
'STATUS_PROTOCOL_ERROR',
'STATUS_UNSUPPORTED_DATA_TYPE',
'STATUS_STATUS_NOT_AVAILABLE',
'STATUS_ABNORMAL_CLOSED',
'STATUS_INVALID_PAYLOAD',
'STATUS_POLICY_VIOLATION',
'STATUS_MESSAGE_TOO_BIG',
'STATUS_INVALID_EXTENSION',
'STATUS_UNEXPECTED_CONDITION',
'STATUS_BAD_GATEWAY',
'STATUS_TLS_HANDSHAKE_ERROR',
]
# closing frame status codes. # closing frame status codes.
STATUS_NORMAL = 1000 STATUS_NORMAL = 1000
@ -38,6 +85,7 @@ STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009 STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010 STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011 STATUS_UNEXPECTED_CONDITION = 1011
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015 STATUS_TLS_HANDSHAKE_ERROR = 1015
VALID_CLOSE_STATUS = ( VALID_CLOSE_STATUS = (
@ -50,8 +98,10 @@ VALID_CLOSE_STATUS = (
STATUS_MESSAGE_TOO_BIG, STATUS_MESSAGE_TOO_BIG,
STATUS_INVALID_EXTENSION, STATUS_INVALID_EXTENSION,
STATUS_UNEXPECTED_CONDITION, STATUS_UNEXPECTED_CONDITION,
STATUS_BAD_GATEWAY,
) )
class ABNF(object): class ABNF(object):
""" """
ABNF frame class. ABNF frame class.
@ -81,7 +131,7 @@ class ABNF(object):
OPCODE_PONG: "pong" OPCODE_PONG: "pong"
} }
# data length threashold. # data length threshold.
LENGTH_7 = 0x7e LENGTH_7 = 0x7e
LENGTH_16 = 1 << 16 LENGTH_16 = 1 << 16
LENGTH_63 = 1 << 63 LENGTH_63 = 1 << 63
@ -98,7 +148,7 @@ class ABNF(object):
self.rsv3 = rsv3 self.rsv3 = rsv3
self.opcode = opcode self.opcode = opcode
self.mask = mask self.mask = mask
if data == None: if data is None:
data = "" data = ""
self.data = data self.data = data
self.get_mask_key = os.urandom self.get_mask_key = os.urandom
@ -126,11 +176,13 @@ class ABNF(object):
if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
raise WebSocketProtocolException("Invalid close frame.") raise WebSocketProtocolException("Invalid close frame.")
code = 256*six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2]) code = 256 * \
six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
if not self._is_valid_close_status(code): if not self._is_valid_close_status(code):
raise WebSocketProtocolException("Invalid close opcode.") raise WebSocketProtocolException("Invalid close opcode.")
def _is_valid_close_status(self, code): @staticmethod
def _is_valid_close_status(code):
return code in VALID_CLOSE_STATUS or (3000 <= code < 5000) return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
def __str__(self): def __str__(self):
@ -144,8 +196,8 @@ class ABNF(object):
create frame to send text, binary and other data. create frame to send text, binary and other data.
data: data to send. This is string value(byte array). data: data to send. This is string value(byte array).
if opcode is OPCODE_TEXT and this value is uniocde, if opcode is OPCODE_TEXT and this value is unicode,
data value is conveted into unicode string, automatically. data value is converted into unicode string, automatically.
opcode: operation code. please see OPCODE_XXX. opcode: operation code. please see OPCODE_XXX.
@ -206,28 +258,35 @@ class ABNF(object):
data: data to mask/unmask. data: data to mask/unmask.
""" """
if data == None: if data is None:
data = "" data = ""
if isinstance(mask_key, six.text_type): if isinstance(mask_key, six.text_type):
mask_key = six.b(mask_key) mask_key = six.b(mask_key)
if isinstance(data, six.text_type): if isinstance(data, six.text_type):
data = six.b(data) data = six.b(data)
if numpy:
origlen = len(data)
_mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]
# We need data to be a multiple of four...
data += bytes(" " * (4 - (len(data) % 4)), "us-ascii")
a = numpy.frombuffer(data, dtype="uint32")
masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32")
if len(data) > origlen:
return masked.tobytes()[:origlen]
return masked.tobytes()
else:
_m = array.array("B", mask_key) _m = array.array("B", mask_key)
_d = array.array("B", data) _d = array.array("B", data)
for i in range(len(_d)): return _mask(_m, _d)
_d[i] ^= _m[i % 4]
if six.PY3:
return _d.tobytes()
else:
return _d.tostring()
class frame_buffer(object): class frame_buffer(object):
_HEADER_MASK_INDEX = 5 _HEADER_MASK_INDEX = 5
_HEADER_LENGHT_INDEX = 6 _HEADER_LENGTH_INDEX = 6
def __init__(self, recv_fn, skip_utf8_validation): def __init__(self, recv_fn, skip_utf8_validation):
self.recv = recv_fn self.recv = recv_fn
@ -236,6 +295,7 @@ class frame_buffer(object):
# bytes of bytes are received. # bytes of bytes are received.
self.recv_buffer = [] self.recv_buffer = []
self.clear() self.clear()
self.lock = Lock()
def clear(self): def clear(self):
self.header = None self.header = None
@ -272,12 +332,11 @@ class frame_buffer(object):
return False return False
return self.header[frame_buffer._HEADER_MASK_INDEX] return self.header[frame_buffer._HEADER_MASK_INDEX]
def has_received_length(self): def has_received_length(self):
return self.length is None return self.length is None
def recv_length(self): def recv_length(self):
bits = self.header[frame_buffer._HEADER_LENGHT_INDEX] bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
length_bits = bits & 0x7f length_bits = bits & 0x7f
if length_bits == 0x7e: if length_bits == 0x7e:
v = self.recv_strict(2) v = self.recv_strict(2)
@ -295,6 +354,8 @@ class frame_buffer(object):
self.mask = self.recv_strict(4) if self.has_mask() else "" self.mask = self.recv_strict(4) if self.has_mask() else ""
def recv_frame(self): def recv_frame(self):
with self.lock:
# Header # Header
if self.has_received_header(): if self.has_received_header():
self.recv_header() self.recv_header()
@ -330,10 +391,11 @@ class frame_buffer(object):
# fragmenting the heap -- the number of bytes recv() actually # fragmenting the heap -- the number of bytes recv() actually
# reads is limited by socket buffer and is relatively small, # reads is limited by socket buffer and is relatively small,
# yet passing large numbers repeatedly causes lots of large # yet passing large numbers repeatedly causes lots of large
# buffers allocated and then shrunk, which results in fragmentation. # buffers allocated and then shrunk, which results in
bytes = self.recv(min(16384, shortage)) # fragmentation.
self.recv_buffer.append(bytes) bytes_ = self.recv(min(16384, shortage))
shortage -= len(bytes) self.recv_buffer.append(bytes_)
shortage -= len(bytes_)
unified = six.b("").join(self.recv_buffer) unified = six.b("").join(self.recv_buffer)
@ -346,6 +408,7 @@ class frame_buffer(object):
class continuous_frame(object): class continuous_frame(object):
def __init__(self, fire_cont_frame, skip_utf8_validation): def __init__(self, fire_cont_frame, skip_utf8_validation):
self.fire_cont_frame = fire_cont_frame self.fire_cont_frame = fire_cont_frame
self.skip_utf8_validation = skip_utf8_validation self.skip_utf8_validation = skip_utf8_validation
@ -355,7 +418,8 @@ class continuous_frame(object):
def validate(self, frame): def validate(self, frame):
if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
raise WebSocketProtocolException("Illegal frame") raise WebSocketProtocolException("Illegal frame")
if self.recving_frames and frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): if self.recving_frames and \
frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
raise WebSocketProtocolException("Illegal frame") raise WebSocketProtocolException("Illegal frame")
def add(self, frame): def add(self, frame):
@ -377,6 +441,7 @@ class continuous_frame(object):
self.cont_data = None self.cont_data = None
frame.data = data[1] frame.data = data[1]
if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data): if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
raise WebSocketPayloadException("cannot decode: " + repr(frame.data)) raise WebSocketPayloadException(
"cannot decode: " + repr(frame.data))
return [data[0], frame] return [data[0], frame]

View file

@ -23,74 +23,124 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
WebSocketApp provides higher level APIs. WebSocketApp provides higher level APIs.
""" """
import inspect
import select
import sys
import threading import threading
import time import time
import traceback import traceback
import sys
import select
import six import six
from ._abnf import ABNF
from ._core import WebSocket, getdefaulttimeout from ._core import WebSocket, getdefaulttimeout
from ._exceptions import * from ._exceptions import *
from ._logging import * from . import _logging
from websocket._abnf import ABNF
__all__ = ["WebSocketApp"] __all__ = ["WebSocketApp"]
class Dispatcher:
def __init__(self, app, ping_timeout):
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, read_callback, check_callback):
while self.app.keep_running:
r, w, e = select.select(
(self.app.sock.sock, ), (), (), self.ping_timeout)
if r:
if not read_callback():
break
check_callback()
class SSLDispatcher:
def __init__(self, app, ping_timeout):
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, read_callback, check_callback):
while self.app.keep_running:
r = self.select()
if r:
if not read_callback():
break
check_callback()
def select(self):
sock = self.app.sock.sock
if sock.pending():
return [sock,]
r, w, e = select.select((sock, ), (), (), self.ping_timeout)
return r
class WebSocketApp(object): class WebSocketApp(object):
""" """
Higher level of APIs are provided. Higher level of APIs are provided.
The interface is like JavaScript WebSocket object. The interface is like JavaScript WebSocket object.
""" """
def __init__(self, url, header=[],
def __init__(self, url, header=None,
on_open=None, on_message=None, on_error=None, on_open=None, on_message=None, on_error=None,
on_close=None, on_ping=None, on_pong=None, on_close=None, on_ping=None, on_pong=None,
on_cont_message=None, on_cont_message=None,
keep_running=True, get_mask_key=None, cookie=None, keep_running=True, get_mask_key=None, cookie=None,
subprotocols=None): subprotocols=None,
on_data=None):
""" """
url: websocket url. url: websocket url.
header: custom header for websocket handshake. header: custom header for websocket handshake.
on_open: callable object which is called at opening websocket. on_open: callable object which is called at opening websocket.
this function has one argument. The arugment is this class object. this function has one argument. The argument is this class object.
on_message: callbale object which is called when recieved data. on_message: callable object which is called when received data.
on_message has 2 arguments. on_message has 2 arguments.
The 1st arugment is this class object. The 1st argument is this class object.
The passing 2nd arugment is utf-8 string which we get from the server. The 2nd argument is utf-8 string which we get from the server.
on_error: callable object which is called when we get error. on_error: callable object which is called when we get error.
on_error has 2 arguments. on_error has 2 arguments.
The 1st arugment is this class object. The 1st argument is this class object.
The passing 2nd arugment is exception object. The 2nd argument is exception object.
on_close: callable object which is called when closed the connection. on_close: callable object which is called when closed the connection.
this function has one argument. The arugment is this class object. this function has one argument. The argument is this class object.
on_cont_message: callback object which is called when recieve continued on_cont_message: callback object which is called when receive continued
frame data. frame data.
on_message has 3 arguments. on_cont_message has 3 arguments.
The 1st arugment is this class object. The 1st argument is this class object.
The passing 2nd arugment is utf-8 string which we get from the server. The 2nd argument is utf-8 string which we get from the server.
The 3rd arugment is continue flag. if 0, the data continue The 3rd argument is continue flag. if 0, the data continue
to next frame data to next frame data
keep_running: a boolean flag indicating whether the app's main loop on_data: callback object which is called when a message received.
should keep running, defaults to True This is called before on_message or on_cont_message,
and then on_message or on_cont_message is called.
on_data has 4 argument.
The 1st argument is this class object.
The 2nd argument is utf-8 string which we get from the server.
The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came.
The 4th argument is continue flag. if 0, the data continue
keep_running: this parameter is obsolete and ignored.
get_mask_key: a callable to produce new mask keys, get_mask_key: a callable to produce new mask keys,
see the WebSocket.set_mask_key's docstring for more information see the WebSocket.set_mask_key's docstring for more information
subprotocols: array of available sub protocols. default is None. subprotocols: array of available sub protocols. default is None.
""" """
self.url = url self.url = url
self.header = header self.header = header if header is not None else []
self.cookie = cookie self.cookie = cookie
self.on_open = on_open self.on_open = on_open
self.on_message = on_message self.on_message = on_message
self.on_data = on_data
self.on_error = on_error self.on_error = on_error
self.on_close = on_close self.on_close = on_close
self.on_ping = on_ping self.on_ping = on_ping
self.on_pong = on_pong self.on_pong = on_pong
self.on_cont_message = on_cont_message self.on_cont_message = on_cont_message
self.keep_running = keep_running self.keep_running = False
self.get_mask_key = get_mask_key self.get_mask_key = get_mask_key
self.sock = None self.sock = None
self.last_ping_tm = 0 self.last_ping_tm = 0
self.last_pong_tm = 0
self.subprotocols = subprotocols self.subprotocols = subprotocols
def send(self, data, opcode=ABNF.OPCODE_TEXT): def send(self, data, opcode=ABNF.OPCODE_TEXT):
@ -102,121 +152,183 @@ class WebSocketApp(object):
""" """
if not self.sock or self.sock.send(data, opcode) == 0: if not self.sock or self.sock.send(data, opcode) == 0:
raise WebSocketConnectionClosedException("Connection is already closed.") raise WebSocketConnectionClosedException(
"Connection is already closed.")
def close(self): def close(self, **kwargs):
""" """
close websocket connection. close websocket connection.
""" """
self.keep_running = False self.keep_running = False
if self.sock: if self.sock:
self.sock.close() self.sock.close(**kwargs)
self.sock = None
def _send_ping(self, interval, event): def _send_ping(self, interval, event):
while not event.wait(interval): while not event.wait(interval):
self.last_ping_tm = time.time() self.last_ping_tm = time.time()
if self.sock: if self.sock:
try:
self.sock.ping() self.sock.ping()
except Exception as ex:
_logging.warning("send_ping routine terminated: {}".format(ex))
break
def run_forever(self, sockopt=None, sslopt=None, def run_forever(self, sockopt=None, sslopt=None,
ping_interval=0, ping_timeout=None, ping_interval=0, ping_timeout=None,
http_proxy_host=None, http_proxy_port=None, http_proxy_host=None, http_proxy_port=None,
http_no_proxy=None, http_proxy_auth=None, http_no_proxy=None, http_proxy_auth=None,
skip_utf8_validation=False, skip_utf8_validation=False,
host=None, origin=None): host=None, origin=None, dispatcher=None,
suppress_origin=False, proxy_type=None):
""" """
run event loop for WebSocket framework. run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available. This loop is infinite loop and is alive during websocket is available.
sockopt: values for socket.setsockopt. sockopt: values for socket.setsockopt.
sockopt must be tuple sockopt must be tuple
and each element is argument of sock.setscokopt. and each element is argument of sock.setsockopt.
sslopt: ssl socket optional dict. sslopt: ssl socket optional dict.
ping_interval: automatically send "ping" command ping_interval: automatically send "ping" command
every specified period(second) every specified period(second)
if set to 0, not send automatically. if set to 0, not send automatically.
ping_timeout: timeout(second) if the pong message is not recieved. ping_timeout: timeout(second) if the pong message is not received.
http_proxy_host: http proxy host name. http_proxy_host: http proxy host name.
http_proxy_port: http proxy port. If not set, set to 80. http_proxy_port: http proxy port. If not set, set to 80.
http_no_proxy: host names, which doesn't use proxy. http_no_proxy: host names, which doesn't use proxy.
skip_utf8_validation: skip utf8 validation. skip_utf8_validation: skip utf8 validation.
host: update host header. host: update host header.
origin: update origin header. origin: update origin header.
dispatcher: customize reading data from socket.
suppress_origin: suppress outputting origin header.
Returns
-------
False if caught KeyboardInterrupt
True if other exception was raised during a loop
""" """
if not ping_timeout or ping_timeout <= 0: if ping_timeout is not None and ping_timeout <= 0:
ping_timeout = None ping_timeout = None
if sockopt is None: if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout")
if not sockopt:
sockopt = [] sockopt = []
if sslopt is None: if not sslopt:
sslopt = {} sslopt = {}
if self.sock: if self.sock:
raise WebSocketException("socket is already opened") raise WebSocketException("socket is already opened")
thread = None thread = None
close_frame = None self.keep_running = True
self.last_ping_tm = 0
self.last_pong_tm = 0
def teardown(close_frame=None):
"""
Tears down the connection.
If close_frame is set, we will invoke the on_close handler with the
statusCode and reason from there.
"""
if thread and thread.isAlive():
event.set()
thread.join()
self.keep_running = False
if self.sock:
self.sock.close()
close_args = self._get_close_args(
close_frame.data if close_frame else None)
self._callback(self.on_close, *close_args)
self.sock = None
try: try:
self.sock = WebSocket(self.get_mask_key, self.sock = WebSocket(
sockopt=sockopt, sslopt=sslopt, self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message and True or False, fire_cont_frame=self.on_cont_message is not None,
skip_utf8_validation=skip_utf8_validation) skip_utf8_validation=skip_utf8_validation,
enable_multithread=True if ping_interval else False)
self.sock.settimeout(getdefaulttimeout()) self.sock.settimeout(getdefaulttimeout())
self.sock.connect(self.url, header=self.header, cookie=self.cookie, self.sock.connect(
self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host, http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port, http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy,
http_no_proxy=http_no_proxy, http_proxy_auth=http_proxy_auth, http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
subprotocols=self.subprotocols, host=host, origin=origin, suppress_origin=suppress_origin,
host=host, origin=origin) proxy_type=proxy_type)
if not dispatcher:
dispatcher = self.create_dispatcher(ping_timeout)
self._callback(self.on_open) self._callback(self.on_open)
if ping_interval: if ping_interval:
event = threading.Event() event = threading.Event()
thread = threading.Thread(target=self._send_ping, args=(ping_interval, event)) thread = threading.Thread(
target=self._send_ping, args=(ping_interval, event))
thread.setDaemon(True) thread.setDaemon(True)
thread.start() thread.start()
while self.sock.connected: def read():
r, w, e = select.select((self.sock.sock, ), (), (), ping_timeout)
if not self.keep_running: if not self.keep_running:
break return teardown()
if ping_timeout and self.last_ping_tm and time.time() - self.last_ping_tm > ping_timeout:
self.last_ping_tm = 0
raise WebSocketTimeoutException("ping timed out")
if r:
op_code, frame = self.sock.recv_data_frame(True) op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE: if op_code == ABNF.OPCODE_CLOSE:
close_frame = frame return teardown(frame)
break
elif op_code == ABNF.OPCODE_PING: elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data) self._callback(self.on_ping, frame.data)
elif op_code == ABNF.OPCODE_PONG: elif op_code == ABNF.OPCODE_PONG:
self.last_pong_tm = time.time()
self._callback(self.on_pong, frame.data) self._callback(self.on_pong, frame.data)
elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: elif op_code == ABNF.OPCODE_CONT and self.on_cont_message:
self._callback(self.on_cont_message, frame.data, frame.fin) self._callback(self.on_data, frame.data,
frame.opcode, frame.fin)
self._callback(self.on_cont_message,
frame.data, frame.fin)
else: else:
data = frame.data data = frame.data
if six.PY3 and frame.opcode == ABNF.OPCODE_TEXT: if six.PY3 and op_code == ABNF.OPCODE_TEXT:
data = data.decode("utf-8") data = data.decode("utf-8")
self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data) self._callback(self.on_message, data)
except Exception as e:
return True
def check():
if (ping_timeout):
has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout
has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0
has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout
if (self.last_ping_tm
and has_timeout_expired
and (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)):
raise WebSocketTimeoutException("ping/pong timed out")
return True
dispatcher.read(self.sock.sock, read, check)
except (Exception, KeyboardInterrupt, SystemExit) as e:
self._callback(self.on_error, e) self._callback(self.on_error, e)
finally: if isinstance(e, SystemExit):
if thread: # propagate SystemExit further
event.set() raise
thread.join() teardown()
self.keep_running = False return not isinstance(e, KeyboardInterrupt)
self.sock.close()
self._callback(self.on_close, def create_dispatcher(self, ping_timeout):
*self._get_close_args(close_frame.data if close_frame else None)) timeout = ping_timeout or 10
self.sock = None if self.sock.is_ssl():
return SSLDispatcher(self, timeout)
return Dispatcher(self, timeout)
def _get_close_args(self, data): def _get_close_args(self, data):
""" this functions extracts the code, reason from the close body """ this functions extracts the code, reason from the close body
if they exists, and if the self.on_close except three arguments """ if they exists, and if the self.on_close except three arguments """
import inspect
# if the on_close callback is "old", just return empty list # if the on_close callback is "old", just return empty list
if sys.version_info < (3, 0):
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3: if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3:
return [] return []
else:
if not self.on_close or len(inspect.getfullargspec(self.on_close).args) != 3:
return []
if data and len(data) >= 2: if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2]) code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
@ -228,9 +340,13 @@ class WebSocketApp(object):
def _callback(self, callback, *args): def _callback(self, callback, *args):
if callback: if callback:
try: try:
if inspect.ismethod(callback):
callback(*args)
else:
callback(self, *args) callback(self, *args)
except Exception as e: except Exception as e:
error(e) _logging.error("error from callback {}: {}".format(callback, e))
if isEnabledForDebug(): if _logging.isEnabledForDebug():
_, _, tb = sys.exc_info() _, _, tb = sys.exc_info()
traceback.print_tb(tb) traceback.print_tb(tb)

View file

@ -0,0 +1,52 @@
try:
import Cookie
except:
import http.cookies as Cookie
class SimpleCookieJar(object):
def __init__(self):
self.jar = dict()
def add(self, set_cookie):
if set_cookie:
try:
simpleCookie = Cookie.SimpleCookie(set_cookie)
except:
simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore'))
for k, v in simpleCookie.items():
domain = v.get("domain")
if domain:
if not domain.startswith("."):
domain = "." + domain
cookie = self.jar.get(domain) if self.jar.get(domain) else Cookie.SimpleCookie()
cookie.update(simpleCookie)
self.jar[domain.lower()] = cookie
def set(self, set_cookie):
if set_cookie:
try:
simpleCookie = Cookie.SimpleCookie(set_cookie)
except:
simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore'))
for k, v in simpleCookie.items():
domain = v.get("domain")
if domain:
if not domain.startswith("."):
domain = "." + domain
self.jar[domain.lower()] = simpleCookie
def get(self, host):
if not host:
return ""
cookies = []
for domain, simpleCookie in self.jar.items():
host = host.lower()
if host.endswith(domain) or host == domain[1:]:
cookies.append(self.jar.get(domain))
return "; ".join(filter(None, ["%s=%s" % (k, v.value) for cookie in filter(None, sorted(cookies)) for k, v in
sorted(cookie.items())]))

View file

@ -21,28 +21,24 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
from __future__ import print_function from __future__ import print_function
import six
import socket import socket
if six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
import struct import struct
import threading import threading
import time
import six
# websocket modules # websocket modules
from ._exceptions import *
from ._abnf import * from ._abnf import *
from ._socket import * from ._exceptions import *
from ._utils import *
from ._url import *
from ._logging import *
from ._http import *
from ._handshake import * from ._handshake import *
from ._http import *
from ._logging import *
from ._socket import *
from ._ssl_compat import * from ._ssl_compat import *
from ._utils import *
__all__ = ['WebSocket', 'create_connection']
""" """
websocket python client. websocket python client.
@ -53,58 +49,6 @@ Please see http://tools.ietf.org/html/rfc6455 for protocol.
""" """
def create_connection(url, timeout=None, **options):
"""
connect to url and return websocket object.
Connect to url and return the WebSocket object.
Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied,
the global default timeout setting returned by getdefauttimeout() is used.
You can customize using 'options'.
If you set "header" list object, you can set your own custom header.
>>> conn = create_connection("ws://echo.websocket.org/",
... header=["User-Agent: MyProgram",
... "x-custom: header"])
timeout: socket timeout time. This value is integer.
if you set None for this value,
it means "use default_timeout value"
options: "header" -> custom http header list.
"cookie" -> cookie value.
"origin" -> custom origin url.
"host" -> custom host header string.
"http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80.
"http_no_proxy" - host names, which doesn't use proxy.
"http_proxy_auth" - http proxy auth infomation.
tuple of username and password.
default is None
"enable_multithread" -> enable lock for multithread.
"sockopt" -> socket options
"sslopt" -> ssl option
"subprotocols" - array of available sub protocols.
default is None.
"skip_utf8_validation" - skip utf8 validation.
"""
sockopt = options.get("sockopt", [])
sslopt = options.get("sslopt", {})
fire_cont_frame = options.get("fire_cont_frame", False)
enable_multithread = options.get("enable_multithread", False)
skip_utf8_validation = options.get("skip_utf8_validation", False)
websock = WebSocket(sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=fire_cont_frame,
enable_multithread=enable_multithread,
skip_utf8_validation=skip_utf8_validation)
websock.settimeout(timeout if timeout is not None else getdefaulttimeout())
websock.connect(url, **options)
return websock
class WebSocket(object): class WebSocket(object):
""" """
Low level WebSocket interface. Low level WebSocket interface.
@ -112,8 +56,8 @@ class WebSocket(object):
The WebSocket protocol draft-hixie-thewebsocketprotocol-76 The WebSocket protocol draft-hixie-thewebsocketprotocol-76
http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76 http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
We can connect to the websocket server and send/recieve data. We can connect to the websocket server and send/receive data.
The following example is a echo client. The following example is an echo client.
>>> import websocket >>> import websocket
>>> ws = websocket.WebSocket() >>> ws = websocket.WebSocket()
@ -126,7 +70,7 @@ class WebSocket(object):
get_mask_key: a callable to produce new mask keys, see the set_mask_key get_mask_key: a callable to produce new mask keys, see the set_mask_key
function's docstring for more details function's docstring for more details
sockopt: values for socket.setsockopt. sockopt: values for socket.setsockopt.
sockopt must be tuple and each element is argument of sock.setscokopt. sockopt must be tuple and each element is argument of sock.setsockopt.
sslopt: dict object for ssl socket option. sslopt: dict object for ssl socket option.
fire_cont_frame: fire recv event for each cont frame. default is False fire_cont_frame: fire recv event for each cont frame. default is False
enable_multithread: if set to True, lock send method. enable_multithread: if set to True, lock send method.
@ -135,9 +79,9 @@ class WebSocket(object):
def __init__(self, get_mask_key=None, sockopt=None, sslopt=None, def __init__(self, get_mask_key=None, sockopt=None, sslopt=None,
fire_cont_frame=False, enable_multithread=False, fire_cont_frame=False, enable_multithread=False,
skip_utf8_validation=False): skip_utf8_validation=False, **_):
""" """
Initalize WebSocket object. Initialize WebSocket object.
""" """
self.sock_opt = sock_opt(sockopt, sslopt) self.sock_opt = sock_opt(sockopt, sslopt)
self.handshake_response = None self.handshake_response = None
@ -147,12 +91,15 @@ class WebSocket(object):
self.get_mask_key = get_mask_key self.get_mask_key = get_mask_key
# These buffer over the build-up of a single frame. # These buffer over the build-up of a single frame.
self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation) self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation)
self.cont_frame = continuous_frame(fire_cont_frame, skip_utf8_validation) self.cont_frame = continuous_frame(
fire_cont_frame, skip_utf8_validation)
if enable_multithread: if enable_multithread:
self.lock = threading.Lock() self.lock = threading.Lock()
self.readlock = threading.Lock()
else: else:
self.lock = NoLock() self.lock = NoLock()
self.readlock = NoLock()
def __iter__(self): def __iter__(self):
""" """
@ -172,12 +119,12 @@ class WebSocket(object):
def set_mask_key(self, func): def set_mask_key(self, func):
""" """
set function to create musk key. You can custumize mask key generator. set function to create musk key. You can customize mask key generator.
Mainly, this is for testing purpose. Mainly, this is for testing purpose.
func: callable object. the fuct must 1 argument as integer. func: callable object. the func takes 1 argument as integer.
The argument means length of mask key. The argument means length of mask key.
This func must be return string(byte array), This func must return string(byte array),
which length is argument specified. which length is argument specified.
""" """
self.get_mask_key = func self.get_mask_key = func
@ -231,6 +178,9 @@ class WebSocket(object):
else: else:
return None return None
def is_ssl(self):
return isinstance(self.sock, ssl.SSLSocket)
headers = property(getheaders) headers = property(getheaders)
def connect(self, url, **options): def connect(self, url, **options):
@ -249,24 +199,38 @@ class WebSocket(object):
if you set None for this value, if you set None for this value,
it means "use default_timeout value" it means "use default_timeout value"
options: "header" -> custom http header list. options: "header" -> custom http header list or dict.
"cookie" -> cookie value. "cookie" -> cookie value.
"origin" -> custom origin url. "origin" -> custom origin url.
"suppress_origin" -> suppress outputting origin header.
"host" -> custom host header string. "host" -> custom host header string.
"http_proxy_host" - http proxy host name. "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80. "http_proxy_port" - http proxy port. If not set, set to 80.
"http_no_proxy" - host names, which doesn't use proxy. "http_no_proxy" - host names, which doesn't use proxy.
"http_proxy_auth" - http proxy auth infomation. "http_proxy_auth" - http proxy auth information.
tuple of username and password. tuple of username and password.
defualt is None default is None
"redirect_limit" -> number of redirects to follow.
"subprotocols" - array of available sub protocols. "subprotocols" - array of available sub protocols.
default is None. default is None.
"socket" - pre-initialized stream socket.
""" """
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options)) # FIXME: "subprotocols" are getting lost, not passed down
# FIXME: "header", "cookie", "origin" and "host" too
self.sock_opt.timeout = options.get('timeout', self.sock_opt.timeout)
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
try: try:
self.handshake_response = handshake(self.sock, *addrs, **options) self.handshake_response = handshake(self.sock, *addrs, **options)
for attempt in range(options.pop('redirect_limit', 3)):
if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
url = self.handshake_response.headers['location']
self.sock.close()
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
self.handshake_response = handshake(self.sock, *addrs, **options)
self.connected = True self.connected = True
except: except:
if self.sock: if self.sock:
@ -307,6 +271,7 @@ class WebSocket(object):
frame.get_mask_key = self.get_mask_key frame.get_mask_key = self.get_mask_key
data = frame.format() data = frame.format()
length = len(data) length = len(data)
if (isEnabledForTrace()):
trace("send: " + repr(data)) trace("send: " + repr(data))
with self.lock: with self.lock:
@ -345,6 +310,7 @@ class WebSocket(object):
return value: string(byte array) value. return value: string(byte array) value.
""" """
with self.readlock:
opcode, data = self.recv_data() opcode, data = self.recv_data()
if six.PY3 and opcode == ABNF.OPCODE_TEXT: if six.PY3 and opcode == ABNF.OPCODE_TEXT:
return data.decode("utf-8") return data.decode("utf-8")
@ -355,7 +321,7 @@ class WebSocket(object):
def recv_data(self, control_frame=False): def recv_data(self, control_frame=False):
""" """
Recieve data with operation code. Receive data with operation code.
control_frame: a boolean flag indicating whether to return control frame control_frame: a boolean flag indicating whether to return control frame
data, defaults to False data, defaults to False
@ -367,7 +333,7 @@ class WebSocket(object):
def recv_data_frame(self, control_frame=False): def recv_data_frame(self, control_frame=False):
""" """
Recieve data with operation code. Receive data with operation code.
control_frame: a boolean flag indicating whether to return control frame control_frame: a boolean flag indicating whether to return control frame
data, defaults to False data, defaults to False
@ -379,7 +345,8 @@ class WebSocket(object):
if not frame: if not frame:
# handle error: # handle error:
# 'NoneType' object has no attribute 'opcode' # 'NoneType' object has no attribute 'opcode'
raise WebSocketProtocolException("Not a valid frame %s" % frame) raise WebSocketProtocolException(
"Not a valid frame %s" % frame)
elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT): elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT):
self.cont_frame.validate(frame) self.cont_frame.validate(frame)
self.cont_frame.add(frame) self.cont_frame.add(frame)
@ -389,21 +356,22 @@ class WebSocket(object):
elif frame.opcode == ABNF.OPCODE_CLOSE: elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close() self.send_close()
return (frame.opcode, frame) return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PING: elif frame.opcode == ABNF.OPCODE_PING:
if len(frame.data) < 126: if len(frame.data) < 126:
self.pong(frame.data) self.pong(frame.data)
else: else:
raise WebSocketProtocolException("Ping message is too long") raise WebSocketProtocolException(
"Ping message is too long")
if control_frame: if control_frame:
return (frame.opcode, frame) return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PONG: elif frame.opcode == ABNF.OPCODE_PONG:
if control_frame: if control_frame:
return (frame.opcode, frame) return frame.opcode, frame
def recv_frame(self): def recv_frame(self):
""" """
recieve data as frame from server. receive data as frame from server.
return value: ABNF frame object. return value: ABNF frame object.
""" """
@ -422,13 +390,16 @@ class WebSocket(object):
self.connected = False self.connected = False
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status=STATUS_NORMAL, reason=six.b("")): def close(self, status=STATUS_NORMAL, reason=six.b(""), timeout=3):
""" """
Close Websocket object Close Websocket object
status: status code to send. see STATUS_XXX. status: status code to send. see STATUS_XXX.
reason: the reason to close. This must be string. reason: the reason to close. This must be string.
timeout: timeout until receive a close frame.
If None, it will wait forever until receive a close frame.
""" """
if self.connected: if self.connected:
if status < 0 or status >= ABNF.LENGTH_16: if status < 0 or status >= ABNF.LENGTH_16:
@ -436,18 +407,24 @@ class WebSocket(object):
try: try:
self.connected = False self.connected = False
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) self.send(struct.pack('!H', status) +
timeout = self.sock.gettimeout() reason, ABNF.OPCODE_CLOSE)
self.sock.settimeout(3) sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout)
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
try: try:
frame = self.recv_frame() frame = self.recv_frame()
if frame.opcode != ABNF.OPCODE_CLOSE:
continue
if isEnabledForError(): if isEnabledForError():
recv_status = struct.unpack("!H", frame.data)[0] recv_status = struct.unpack("!H", frame.data[0:2])[0]
if recv_status != STATUS_NORMAL: if recv_status != STATUS_NORMAL:
error("close status: " + repr(recv_status)) error("close status: " + repr(recv_status))
break
except: except:
pass break
self.sock.settimeout(timeout) self.sock.settimeout(sock_timeout)
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
except: except:
pass pass
@ -456,13 +433,13 @@ class WebSocket(object):
def abort(self): def abort(self):
""" """
Low-level asynchonous abort, wakes up other threads that are waiting in recv_* Low-level asynchronous abort, wakes up other threads that are waiting in recv_*
""" """
if self.connected: if self.connected:
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
def shutdown(self): def shutdown(self):
"close socket, immediately." """close socket, immediately."""
if self.sock: if self.sock:
self.sock.close() self.sock.close()
self.sock = None self.sock = None
@ -480,3 +457,60 @@ class WebSocket(object):
self.sock = None self.sock = None
self.connected = False self.connected = False
raise raise
def create_connection(url, timeout=None, class_=WebSocket, **options):
"""
connect to url and return websocket object.
Connect to url and return the WebSocket object.
Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied,
the global default timeout setting returned by getdefauttimeout() is used.
You can customize using 'options'.
If you set "header" list object, you can set your own custom header.
>>> conn = create_connection("ws://echo.websocket.org/",
... header=["User-Agent: MyProgram",
... "x-custom: header"])
timeout: socket timeout time. This value is integer.
if you set None for this value,
it means "use default_timeout value"
class_: class to instantiate when creating the connection. It has to implement
settimeout and connect. It's __init__ should be compatible with
WebSocket.__init__, i.e. accept all of it's kwargs.
options: "header" -> custom http header list or dict.
"cookie" -> cookie value.
"origin" -> custom origin url.
"suppress_origin" -> suppress outputting origin header.
"host" -> custom host header string.
"http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80.
"http_no_proxy" - host names, which doesn't use proxy.
"http_proxy_auth" - http proxy auth information.
tuple of username and password.
default is None
"enable_multithread" -> enable lock for multithread.
"redirect_limit" -> number of redirects to follow.
"sockopt" -> socket options
"sslopt" -> ssl option
"subprotocols" - array of available sub protocols.
default is None.
"skip_utf8_validation" - skip utf8 validation.
"socket" - pre-initialized stream socket.
"""
sockopt = options.pop("sockopt", [])
sslopt = options.pop("sslopt", {})
fire_cont_frame = options.pop("fire_cont_frame", False)
enable_multithread = options.pop("enable_multithread", False)
skip_utf8_validation = options.pop("skip_utf8_validation", False)
websock = class_(sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=fire_cont_frame,
enable_multithread=enable_multithread,
skip_utf8_validation=skip_utf8_validation, **options)
websock.settimeout(timeout if timeout is not None else getdefaulttimeout())
websock.connect(url, **options)
return websock

View file

@ -25,24 +25,28 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
define websocket exceptions define websocket exceptions
""" """
class WebSocketException(Exception): class WebSocketException(Exception):
""" """
websocket exeception class. websocket exception class.
""" """
pass pass
class WebSocketProtocolException(WebSocketException): class WebSocketProtocolException(WebSocketException):
""" """
If the webscoket protocol is invalid, this exception will be raised. If the websocket protocol is invalid, this exception will be raised.
""" """
pass pass
class WebSocketPayloadException(WebSocketException): class WebSocketPayloadException(WebSocketException):
""" """
If the webscoket payload is invalid, this exception will be raised. If the websocket payload is invalid, this exception will be raised.
""" """
pass pass
class WebSocketConnectionClosedException(WebSocketException): class WebSocketConnectionClosedException(WebSocketException):
""" """
If remote host closed the connection or some network error happened, If remote host closed the connection or some network error happened,
@ -50,16 +54,35 @@ class WebSocketConnectionClosedException(WebSocketException):
""" """
pass pass
class WebSocketTimeoutException(WebSocketException): class WebSocketTimeoutException(WebSocketException):
""" """
WebSocketTimeoutException will be raised at socket timeout during read/write data. WebSocketTimeoutException will be raised at socket timeout during read/write data.
""" """
pass pass
class WebSocketProxyException(WebSocketException): class WebSocketProxyException(WebSocketException):
""" """
WebSocketProxyException will be raised when proxy error occured. WebSocketProxyException will be raised when proxy error occurred.
""" """
pass pass
class WebSocketBadStatusException(WebSocketException):
"""
WebSocketBadStatusException will be raised when we get bad handshake status code.
"""
def __init__(self, message, status_code, status_message=None, resp_headers=None):
msg = message % (status_code, status_message)
super(WebSocketBadStatusException, self).__init__(msg)
self.status_code = status_code
self.resp_headers = resp_headers
class WebSocketAddressException(WebSocketException):
"""
If the websocket address info cannot be found, this exception will be raised.
"""
pass

View file

@ -19,33 +19,55 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import hashlib
import hmac
import os
import six import six
if six.PY3:
from ._cookiejar import SimpleCookieJar
from ._exceptions import *
from ._http import *
from ._logging import *
from ._socket import *
if hasattr(six, 'PY3') and six.PY3:
from base64 import encodebytes as base64encode from base64 import encodebytes as base64encode
else: else:
from base64 import encodestring as base64encode from base64 import encodestring as base64encode
import uuid if hasattr(six, 'PY3') and six.PY3:
import hashlib if hasattr(six, 'PY34') and six.PY34:
from http import client as HTTPStatus
else:
from http import HTTPStatus
else:
import httplib as HTTPStatus
from ._logging import * __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
from ._url import *
from ._socket import*
from ._http import *
from ._exceptions import *
__all__ = ["handshake_response", "handshake"] if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
def compare_digest(s1, s2):
return s1 == s2
# websocket supported version. # websocket supported version.
VERSION = 13 VERSION = 13
SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,)
SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
CookieJar = SimpleCookieJar()
class handshake_response(object): class handshake_response(object):
def __init__(self, status, headers, subprotocol): def __init__(self, status, headers, subprotocol):
self.status = status self.status = status
self.headers = headers self.headers = headers
self.subprotocol = subprotocol self.subprotocol = subprotocol
CookieJar.add(headers.get("set-cookie"))
def handshake(sock, hostname, port, resource, **options): def handshake(sock, hostname, port, resource, **options):
@ -56,6 +78,8 @@ def handshake(sock, hostname, port, resource, **options):
dump("request header", header_str) dump("request header", header_str)
status, resp = _get_resp_headers(sock) status, resp = _get_resp_headers(sock)
if status in SUPPORTED_REDIRECT_STATUSES:
return handshake_response(status, resp, None)
success, subproto = _validate(resp, key, options.get("subprotocols")) success, subproto = _validate(resp, key, options.get("subprotocols"))
if not success: if not success:
raise WebSocketException("Invalid WebSocket Header") raise WebSocketException("Invalid WebSocket Header")
@ -63,38 +87,68 @@ def handshake(sock, hostname, port, resource, **options):
return handshake_response(status, resp, subproto) return handshake_response(status, resp, subproto)
def _get_handshake_headers(resource, host, port, options): def _pack_hostname(hostname):
headers = [] # IPv6 address
headers.append("GET %s HTTP/1.1" % resource) if ':' in hostname:
headers.append("Upgrade: websocket") return '[' + hostname + ']'
headers.append("Connection: Upgrade")
if port == 80:
hostport = host
else:
hostport = "%s:%d" % (host, port)
if "host" in options and options["host"]: return hostname
def _get_handshake_headers(resource, host, port, options):
headers = [
"GET %s HTTP/1.1" % resource,
"Upgrade: websocket"
]
if port == 80 or port == 443:
hostport = _pack_hostname(host)
else:
hostport = "%s:%d" % (_pack_hostname(host), port)
if "host" in options and options["host"] is not None:
headers.append("Host: %s" % options["host"]) headers.append("Host: %s" % options["host"])
else: else:
headers.append("Host: %s" % hostport) headers.append("Host: %s" % hostport)
if "origin" in options and options["origin"]: if "suppress_origin" not in options or not options["suppress_origin"]:
if "origin" in options and options["origin"] is not None:
headers.append("Origin: %s" % options["origin"]) headers.append("Origin: %s" % options["origin"])
else: else:
headers.append("Origin: http://%s" % hostport) headers.append("Origin: http://%s" % hostport)
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
# Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
if not 'header' in options or 'Sec-WebSocket-Key' not in options['header']:
key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key) headers.append("Sec-WebSocket-Key: %s" % key)
else:
key = options['header']['Sec-WebSocket-Key']
if not 'header' in options or 'Sec-WebSocket-Version' not in options['header']:
headers.append("Sec-WebSocket-Version: %s" % VERSION) headers.append("Sec-WebSocket-Version: %s" % VERSION)
if not 'connection' in options or options['connection'] is None:
headers.append('Connection: upgrade')
else:
headers.append(options['connection'])
subprotocols = options.get("subprotocols") subprotocols = options.get("subprotocols")
if subprotocols: if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols)) headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
if "header" in options: if "header" in options:
headers.extend(options["header"]) header = options["header"]
if isinstance(header, dict):
header = [
": ".join([k, v])
for k, v in header.items()
if v is not None
]
headers.extend(header)
cookie = options.get("cookie", None) server_cookie = CookieJar.get(host)
client_cookie = options.get("cookie", None)
cookie = "; ".join(filter(None, [server_cookie, client_cookie]))
if cookie: if cookie:
headers.append("Cookie: %s" % cookie) headers.append("Cookie: %s" % cookie)
@ -105,12 +159,13 @@ def _get_handshake_headers(resource, host, port, options):
return headers, key return headers, key
def _get_resp_headers(sock, success_status=101): def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES):
status, resp_headers = read_headers(sock) status, resp_headers, status_message = read_headers(sock)
if status != success_status: if status not in success_statuses:
raise WebSocketException("Handshake status %d" % status) raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)
return status, resp_headers return status, resp_headers
_HEADERS_TO_CHECK = { _HEADERS_TO_CHECK = {
"upgrade": "websocket", "upgrade": "websocket",
"connection": "upgrade", "connection": "upgrade",
@ -143,7 +198,8 @@ def _validate(headers, key, subprotocols):
value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8') value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
success = (hashed == result) success = compare_digest(hashed, result)
if success: if success:
return True, subproto return True, subproto
else: else:
@ -151,5 +207,5 @@ def _validate(headers, key, subprotocols):
def _create_sec_websocket_key(): def _create_sec_websocket_key():
uid = uuid.uuid4() randomness = os.urandom(16)
return base64encode(uid.bytes).decode('utf-8').strip() return base64encode(randomness).decode('utf-8').strip()

View file

@ -19,28 +19,41 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import six
import socket
import errno import errno
import os import os
import socket
import sys import sys
import six
from ._exceptions import *
from ._logging import *
from ._socket import*
from ._ssl_compat import *
from ._url import *
if six.PY3: if six.PY3:
from base64 import encodebytes as base64encode from base64 import encodebytes as base64encode
else: else:
from base64 import encodestring as base64encode from base64 import encodestring as base64encode
from ._logging import *
from ._url import *
from ._socket import*
from ._exceptions import *
from ._ssl_compat import *
__all__ = ["proxy_info", "connect", "read_headers"] __all__ = ["proxy_info", "connect", "read_headers"]
try:
import socks
ProxyConnectionError = socks.ProxyConnectionError
HAS_PYSOCKS = True
except:
class ProxyConnectionError(BaseException):
pass
HAS_PYSOCKS = False
class proxy_info(object): class proxy_info(object):
def __init__(self, **options): def __init__(self, **options):
self.type = options.get("proxy_type") or "http"
if not(self.type in ['http', 'socks4', 'socks5', 'socks5h']):
raise ValueError("proxy_type must be 'http', 'socks4', 'socks5' or 'socks5h'")
self.host = options.get("http_proxy_host", None) self.host = options.get("http_proxy_host", None)
if self.host: if self.host:
self.port = options.get("http_proxy_port", 0) self.port = options.get("http_proxy_port", 0)
@ -51,9 +64,54 @@ class proxy_info(object):
self.auth = None self.auth = None
self.no_proxy = None self.no_proxy = None
def connect(url, options, proxy):
def _open_proxied_socket(url, options, proxy):
hostname, port, resource, is_secure = parse_url(url) hostname, port, resource, is_secure = parse_url(url)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(hostname, port, is_secure, proxy)
if not HAS_PYSOCKS:
raise WebSocketException("PySocks module not found.")
ptype = socks.SOCKS5
rdns = False
if proxy.type == "socks4":
ptype = socks.SOCKS4
if proxy.type == "http":
ptype = socks.HTTP
if proxy.type[-1] == "h":
rdns = True
sock = socks.create_connection(
(hostname, port),
proxy_type = ptype,
proxy_addr = proxy.host,
proxy_port = proxy.port,
proxy_rdns = rdns,
proxy_username = proxy.auth[0] if proxy.auth else None,
proxy_password = proxy.auth[1] if proxy.auth else None,
timeout = options.timeout,
socket_options = DEFAULT_SOCKET_OPTION + options.sockopt
)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource)
def connect(url, options, proxy, socket):
if proxy.host and not socket and not (proxy.type == 'http'):
return _open_proxied_socket(url, options, proxy)
hostname, port, resource, is_secure = parse_url(url)
if socket:
return socket, (hostname, port, resource)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(
hostname, port, is_secure, proxy)
if not addrinfo_list: if not addrinfo_list:
raise WebSocketException( raise WebSocketException(
"Host not found.: " + hostname + ":" + str(port)) "Host not found.: " + hostname + ":" + str(port))
@ -78,22 +136,33 @@ def connect(url, options, proxy):
def _get_addrinfo_list(hostname, port, is_secure, proxy): def _get_addrinfo_list(hostname, port, is_secure, proxy):
phost, pport, pauth = get_proxy_info(hostname, is_secure, phost, pport, pauth = get_proxy_info(
proxy.host, proxy.port, proxy.auth, proxy.no_proxy) hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
try:
# when running on windows 10, getaddrinfo without socktype returns a socktype 0.
# This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0`
# or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM.
if not phost: if not phost:
addrinfo_list = socket.getaddrinfo(hostname, port, 0, 0, socket.SOL_TCP) addrinfo_list = socket.getaddrinfo(
hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, False, None return addrinfo_list, False, None
else: else:
pport = pport and pport or 80 pport = pport and pport or 80
addrinfo_list = socket.getaddrinfo(phost, pport, 0, 0, socket.SOL_TCP) # when running on windows 10, the getaddrinfo used above
# returns a socktype 0. This generates an error exception:
# _on_error: exception Socket type must be stream or datagram, not 0
# Force the socket type to SOCK_STREAM
addrinfo_list = socket.getaddrinfo(phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, True, pauth return addrinfo_list, True, pauth
except socket.gaierror as e:
raise WebSocketAddressException(e)
def _open_socket(addrinfo_list, sockopt, timeout): def _open_socket(addrinfo_list, sockopt, timeout):
err = None err = None
for addrinfo in addrinfo_list: for addrinfo in addrinfo_list:
family = addrinfo[0] family, socktype, proto = addrinfo[:3]
sock = socket.socket(family) sock = socket.socket(family, socktype, proto)
sock.settimeout(timeout) sock.settimeout(timeout)
for opts in DEFAULT_SOCKET_OPTION: for opts in DEFAULT_SOCKET_OPTION:
sock.setsockopt(*opts) sock.setsockopt(*opts)
@ -101,37 +170,71 @@ def _open_socket(addrinfo_list, sockopt, timeout):
sock.setsockopt(*opts) sock.setsockopt(*opts)
address = addrinfo[4] address = addrinfo[4]
err = None
while not err:
try: try:
sock.connect(address) sock.connect(address)
except ProxyConnectionError as error:
err = WebSocketProxyException(str(error))
err.remote_ip = str(address[0])
continue
except socket.error as error: except socket.error as error:
error.remote_ip = str(address[0]) error.remote_ip = str(address[0])
if error.errno in (errno.ECONNREFUSED, ): try:
eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED)
except:
eConnRefused = (errno.ECONNREFUSED, )
if error.errno == errno.EINTR:
continue
elif error.errno in eConnRefused:
err = error err = error
continue continue
else: else:
raise raise error
else: else:
break break
else: else:
continue
break
else:
if err:
raise err raise err
return sock return sock
def _can_use_sni(): def _can_use_sni():
return (six.PY2 and sys.version_info[1] >= 7 and sys.version_info[2] >= 9) or (six.PY3 and sys.version_info[2] >= 2) return six.PY2 and sys.version_info >= (2, 7, 9) or sys.version_info >= (3, 2)
def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23)) context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))
context.load_verify_locations(cafile=sslopt.get('ca_certs', None)) if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE:
# see https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 cafile = sslopt.get('ca_certs', None)
capath = sslopt.get('ca_cert_path', None)
if cafile or capath:
context.load_verify_locations(cafile=cafile, capath=capath)
elif hasattr(context, 'load_default_certs'):
context.load_default_certs(ssl.Purpose.SERVER_AUTH)
if sslopt.get('certfile', None):
context.load_cert_chain(
sslopt['certfile'],
sslopt.get('keyfile', None),
sslopt.get('password', None),
)
# see
# https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153
context.verify_mode = sslopt['cert_reqs'] context.verify_mode = sslopt['cert_reqs']
if HAVE_CONTEXT_CHECK_HOSTNAME: if HAVE_CONTEXT_CHECK_HOSTNAME:
context.check_hostname = check_hostname context.check_hostname = check_hostname
if 'ciphers' in sslopt: if 'ciphers' in sslopt:
context.set_ciphers(sslopt['ciphers']) context.set_ciphers(sslopt['ciphers'])
if 'cert_chain' in sslopt:
certfile, keyfile, password = sslopt['cert_chain']
context.load_cert_chain(certfile, keyfile, password)
if 'ecdh_curve' in sslopt:
context.set_ecdh_curve(sslopt['ecdh_curve'])
return context.wrap_socket( return context.wrap_socket(
sock, sock,
@ -143,12 +246,19 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
def _ssl_socket(sock, user_sslopt, hostname): def _ssl_socket(sock, user_sslopt, hostname):
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED) sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
certPath = os.path.join(
os.path.dirname(__file__), "cacert.pem")
if os.path.isfile(certPath):
sslopt['ca_certs'] = certPath
sslopt.update(user_sslopt) sslopt.update(user_sslopt)
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop('check_hostname', True)
certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE')
if certPath and os.path.isfile(certPath) \
and user_sslopt.get('ca_certs', None) is None \
and user_sslopt.get('ca_cert', None) is None:
sslopt['ca_certs'] = certPath
elif certPath and os.path.isdir(certPath) \
and user_sslopt.get('ca_cert_path', None) is None:
sslopt['ca_cert_path'] = certPath
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
'check_hostname', True)
if _can_use_sni(): if _can_use_sni():
sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
@ -161,6 +271,7 @@ def _ssl_socket(sock, user_sslopt, hostname):
return sock return sock
def _tunnel(sock, host, port, auth): def _tunnel(sock, host, port, auth):
debug("Connecting proxy...") debug("Connecting proxy...")
connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port) connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
@ -169,7 +280,7 @@ def _tunnel(sock, host, port, auth):
auth_str = auth[0] auth_str = auth[0]
if auth[1]: if auth[1]:
auth_str += ":" + auth[1] auth_str += ":" + auth[1]
encoded_str = base64encode(auth_str.encode()).strip().decode() encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '')
connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
connect_header += "\r\n" connect_header += "\r\n"
dump("request header", connect_header) dump("request header", connect_header)
@ -177,18 +288,20 @@ def _tunnel(sock, host, port, auth):
send(sock, connect_header) send(sock, connect_header)
try: try:
status, resp_headers = read_headers(sock) status, resp_headers, status_message = read_headers(sock)
except Exception as e: except Exception as e:
raise WebSocketProxyException(str(e)) raise WebSocketProxyException(str(e))
if status != 200: if status != 200:
raise WebSocketProxyException( raise WebSocketProxyException(
"failed CONNECT via proxy status: %r" + status) "failed CONNECT via proxy status: %r" % status)
return sock return sock
def read_headers(sock): def read_headers(sock):
status = None status = None
status_message = None
headers = {} headers = {}
trace("--- response header ---") trace("--- response header ---")
@ -202,14 +315,16 @@ def read_headers(sock):
status_info = line.split(" ", 2) status_info = line.split(" ", 2)
status = int(status_info[1]) status = int(status_info[1])
if len(status_info) > 2:
status_message = status_info[2]
else: else:
kv = line.split(":", 1) kv = line.split(":", 1)
if len(kv) == 2: if len(kv) == 2:
key, value = kv key, value = kv
headers[key.lower()] = value.strip().lower() headers[key.lower()] = value.strip()
else: else:
raise WebSocketException("Invalid header") raise WebSocketException("Invalid header")
trace("-----------------------") trace("-----------------------")
return status, headers return status, headers, status_message

View file

@ -19,30 +19,36 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import logging import logging
_logger = logging.getLogger() _logger = logging.getLogger('websocket')
try:
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
_logger.addHandler(NullHandler())
_traceEnabled = False _traceEnabled = False
__all__ = ["enableTrace", "dump", "error", "debug", "trace", __all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace",
"isEnabledForError", "isEnabledForDebug"] "isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"]
def enableTrace(tracable): def enableTrace(traceable, handler = logging.StreamHandler()):
""" """
turn on/off the tracability. turn on/off the traceability.
tracable: boolean value. if set True, tracability is enabled. traceable: boolean value. if set True, traceability is enabled.
""" """
global _traceEnabled global _traceEnabled
_traceEnabled = tracable _traceEnabled = traceable
if tracable: if traceable:
if not _logger.handlers: _logger.addHandler(handler)
_logger.addHandler(logging.StreamHandler())
_logger.setLevel(logging.DEBUG) _logger.setLevel(logging.DEBUG)
def dump(title, message): def dump(title, message):
if _traceEnabled: if _traceEnabled:
_logger.debug("--- " + title + " ---") _logger.debug("--- " + title + " ---")
@ -54,6 +60,10 @@ def error(msg):
_logger.error(msg) _logger.error(msg)
def warning(msg):
_logger.warning(msg)
def debug(msg): def debug(msg):
_logger.debug(msg) _logger.debug(msg)
@ -69,3 +79,6 @@ def isEnabledForError():
def isEnabledForDebug(): def isEnabledForDebug():
return _logger.isEnabledFor(logging.DEBUG) return _logger.isEnabledFor(logging.DEBUG)
def isEnabledForTrace():
return _traceEnabled

View file

@ -19,13 +19,16 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import errno
import select
import socket import socket
import six import six
import sys
from ._exceptions import * from ._exceptions import *
from ._utils import *
from ._ssl_compat import * from ._ssl_compat import *
from ._utils import *
DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
if hasattr(socket, "SO_KEEPALIVE"): if hasattr(socket, "SO_KEEPALIVE"):
@ -42,7 +45,9 @@ _default_timeout = None
__all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout", __all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout",
"recv", "recv_line", "send"] "recv", "recv_line", "send"]
class sock_opt(object): class sock_opt(object):
def __init__(self, sockopt, sslopt): def __init__(self, sockopt, sslopt):
if sockopt is None: if sockopt is None:
sockopt = [] sockopt = []
@ -52,6 +57,7 @@ class sock_opt(object):
self.sslopt = sslopt self.sslopt = sslopt
self.timeout = None self.timeout = None
def setdefaulttimeout(timeout): def setdefaulttimeout(timeout):
""" """
Set the global timeout setting to connect. Set the global timeout setting to connect.
@ -73,22 +79,42 @@ def recv(sock, bufsize):
if not sock: if not sock:
raise WebSocketConnectionClosedException("socket is already closed.") raise WebSocketConnectionClosedException("socket is already closed.")
def _recv():
try: try:
bytes = sock.recv(bufsize) return sock.recv(bufsize)
except SSLWantReadError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise
r, w, e = select.select((sock, ), (), (), sock.gettimeout())
if r:
return sock.recv(bufsize)
try:
if sock.gettimeout() == 0:
bytes_ = sock.recv(bufsize)
else:
bytes_ = _recv()
except socket.timeout as e: except socket.timeout as e:
message = extract_err_message(e) message = extract_err_message(e)
raise WebSocketTimeoutException(message) raise WebSocketTimeoutException(message)
except SSLError as e: except SSLError as e:
message = extract_err_message(e) message = extract_err_message(e)
if message == "The read operation timed out": if isinstance(message, str) and 'timed out' in message:
raise WebSocketTimeoutException(message) raise WebSocketTimeoutException(message)
else: else:
raise raise
if not bytes: if not bytes_:
raise WebSocketConnectionClosedException("Connection is already closed.") raise WebSocketConnectionClosedException(
"Connection is already closed.")
return bytes return bytes_
def recv_line(sock): def recv_line(sock):
@ -108,14 +134,33 @@ def send(sock, data):
if not sock: if not sock:
raise WebSocketConnectionClosedException("socket is already closed.") raise WebSocketConnectionClosedException("socket is already closed.")
def _send():
try: try:
return sock.send(data) return sock.send(data)
except SSLWantWriteError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise
r, w, e = select.select((), (sock, ), (), sock.gettimeout())
if w:
return sock.send(data)
try:
if sock.gettimeout() == 0:
return sock.send(data)
else:
return _send()
except socket.timeout as e: except socket.timeout as e:
message = extract_err_message(e) message = extract_err_message(e)
raise WebSocketTimeoutException(message) raise WebSocketTimeoutException(message)
except Exception as e: except Exception as e:
message = extract_err_message(e) message = extract_err_message(e)
if message and "timed out" in message: if isinstance(message, str) and "timed out" in message:
raise WebSocketTimeoutException(message) raise WebSocketTimeoutException(message)
else: else:
raise raise

View file

@ -19,12 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
__all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"]
__all__ = ["HAVE_SSL", "ssl", "SSLError"]
try: try:
import ssl import ssl
from ssl import SSLError from ssl import SSLError
from ssl import SSLWantReadError
from ssl import SSLWantWriteError
if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'):
HAVE_CONTEXT_CHECK_HOSTNAME = True HAVE_CONTEXT_CHECK_HOSTNAME = True
else: else:
@ -42,4 +43,12 @@ except ImportError:
class SSLError(Exception): class SSLError(Exception):
pass pass
class SSLWantReadError(Exception):
pass
class SSLWantWriteError(Exception):
pass
ssl = lambda: None
HAVE_SSL = False HAVE_SSL = False

View file

@ -20,8 +20,12 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
from six.moves.urllib.parse import urlparse
import os import os
import socket
import struct
from six.moves.urllib.parse import urlparse
__all__ = ["parse_url", "get_proxy_info"] __all__ = ["parse_url", "get_proxy_info"]
@ -66,24 +70,55 @@ def parse_url(url):
if parsed.query: if parsed.query:
resource += "?" + parsed.query resource += "?" + parsed.query
return (hostname, port, resource, is_secure) return hostname, port, resource, is_secure
DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"] DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"]
def _is_ip_address(addr):
try:
socket.inet_aton(addr)
except socket.error:
return False
else:
return True
def _is_subnet_address(hostname):
try:
addr, netmask = hostname.split("/")
return _is_ip_address(addr) and 0 <= int(netmask) < 32
except ValueError:
return False
def _is_address_in_network(ip, net):
ipaddr = struct.unpack('I', socket.inet_aton(ip))[0]
netaddr, bits = net.split('/')
netmask = struct.unpack('I', socket.inet_aton(netaddr))[0] & ((2 << int(bits) - 1) - 1)
return ipaddr & netmask == netmask
def _is_no_proxy_host(hostname, no_proxy): def _is_no_proxy_host(hostname, no_proxy):
if not no_proxy: if not no_proxy:
v = os.environ.get("no_proxy", "").replace(" ", "") v = os.environ.get("no_proxy", "").replace(" ", "")
if v:
no_proxy = v.split(",") no_proxy = v.split(",")
if not no_proxy: if not no_proxy:
no_proxy = DEFAULT_NO_PROXY_HOST no_proxy = DEFAULT_NO_PROXY_HOST
return hostname in no_proxy if hostname in no_proxy:
return True
elif _is_ip_address(hostname):
return any([_is_address_in_network(hostname, subnet) for subnet in no_proxy if _is_subnet_address(subnet)])
return False
def get_proxy_info(hostname, is_secure, def get_proxy_info(
proxy_host=None, proxy_port=0, proxy_auth=None, no_proxy=None): hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None,
no_proxy=None, proxy_type='http'):
""" """
try to retrieve proxy host and port from environment try to retrieve proxy host and port from environment
if not provided in options. if not provided in options.
@ -100,9 +135,12 @@ def get_proxy_info(hostname, is_secure,
options: "http_proxy_host" - http proxy host name. options: "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. "http_proxy_port" - http proxy port.
"http_no_proxy" - host names, which doesn't use proxy. "http_no_proxy" - host names, which doesn't use proxy.
"http_proxy_auth" - http proxy auth infomation. "http_proxy_auth" - http proxy auth information.
tuple of username and password. tuple of username and password.
defualt is None default is None
"proxy_type" - if set to "socks5" PySocks wrapper
will be used in place of a http proxy.
default is "http"
""" """
if _is_no_proxy_host(hostname, no_proxy): if _is_no_proxy_host(hostname, no_proxy):
return None, 0, None return None, 0, None

View file

@ -19,24 +19,34 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import six import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message"] __all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
class NoLock(object): class NoLock(object):
def __enter__(self): def __enter__(self):
pass pass
def __exit__(self,type, value, traceback): def __exit__(self, exc_type, exc_value, traceback):
pass pass
try:
# If wsaccel is available we use compiled routines to validate UTF-8
# strings.
from wsaccel.utf8validator import Utf8Validator
def _validate_utf8(utfbytes):
return Utf8Validator().validate(utfbytes)[0]
except ImportError:
# UTF-8 validator # UTF-8 validator
# python implementation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ # python implementation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
UTF8_ACCEPT = 0 _UTF8_ACCEPT = 0
UTF8_REJECT=12 _UTF8_REJECT = 12
_UTF8D = [ _UTF8D = [
# The first part of the table maps bytes to character classes that # The first part of the table maps bytes to character classes that
@ -61,10 +71,24 @@ _UTF8D = [
def _decode(state, codep, ch): def _decode(state, codep, ch):
tp = _UTF8D[ch] tp = _UTF8D[ch]
codep = (ch & 0x3f ) | (codep << 6) if (state != UTF8_ACCEPT) else (0xff >> tp) & (ch) codep = (ch & 0x3f) | (codep << 6) if (
state != _UTF8_ACCEPT) else (0xff >> tp) & ch
state = _UTF8D[256 + state + tp] state = _UTF8D[256 + state + tp]
return state, codep; return state, codep
def _validate_utf8(utfbytes):
state = _UTF8_ACCEPT
codep = 0
for i in utfbytes:
if six.PY2:
i = ord(i)
state, codep = _decode(state, codep, i)
if state == _UTF8_REJECT:
return False
return True
def validate_utf8(utfbytes): def validate_utf8(utfbytes):
""" """
@ -72,17 +96,16 @@ def validate_utf8(utfbytes):
utfbytes: utf byte string to check. utfbytes: utf byte string to check.
return value: if valid utf8 string, return true. Otherwise, return false. return value: if valid utf8 string, return true. Otherwise, return false.
""" """
state = UTF8_ACCEPT return _validate_utf8(utfbytes)
codep = 0
for i in utfbytes:
if six.PY2:
i = ord(i)
state, codep = _decode(state, codep, i)
if state == UTF8_REJECT:
return False
return True
def extract_err_message(exception): def extract_err_message(exception):
return getattr(exception, 'strerror', str(exception)) if exception.args:
return exception.args[0]
else:
return None
def extract_error_code(exception):
if exception.args and len(exception.args) > 1:
return exception.args[0] if isinstance(exception.args[0], int) else None

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,98 @@
import unittest
from websocket._cookiejar import SimpleCookieJar
try:
import Cookie
except:
import http.cookies as Cookie
class CookieJarTest(unittest.TestCase):
def testAdd(self):
cookie_jar = SimpleCookieJar()
cookie_jar.add("")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=.abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=xyz")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
self.assertEquals(cookie_jar.get("xyz"), "e=f")
self.assertEquals(cookie_jar.get("something"), "")
def testSet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=.abc")
self.assertEquals(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=xyz")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
self.assertEquals(cookie_jar.get("xyz"), "e=f")
self.assertEquals(cookie_jar.get("something"), "")
def testGet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("abc.com.es"), "")
self.assertEquals(cookie_jar.get("xabc.com"), "")
cookie_jar.set("a=b; c=d; domain=.abc.com")
self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("abc.com.es"), "")
self.assertEquals(cookie_jar.get("xabc.com"), "")

View file

@ -1,14 +1,33 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import six
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import os import os
import os.path import os.path
import base64
import socket import socket
import six
# websocket-client
import websocket as ws
from websocket._handshake import _create_sec_websocket_key, \
_validate as _validate_header
from websocket._http import read_headers
from websocket._url import get_proxy_info, parse_url
from websocket._utils import validate_utf8
if six.PY3:
from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
try: try:
from ssl import SSLError from ssl import SSLError
except ImportError: except ImportError:
@ -16,37 +35,15 @@ except ImportError:
class SSLError(Exception): class SSLError(Exception):
pass pass
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
import uuid
if six.PY3:
from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
# websocket-client
import websocket as ws
from websocket._handshake import _create_sec_websocket_key
from websocket._url import parse_url, get_proxy_info
from websocket._utils import validate_utf8
from websocket._handshake import _validate as _validate_header
from websocket._http import read_headers
# Skip test to access the internet. # Skip test to access the internet.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
# Skip Secure WebSocket test. # Skip Secure WebSocket test.
TEST_SECURE_WS = True TEST_SECURE_WS = True
TRACABLE = False TRACEABLE = True
def create_mask_key(n): def create_mask_key(_):
return "abcd" return "abcd"
@ -58,6 +55,9 @@ class SockMock(object):
def add_packet(self, data): def add_packet(self, data):
self.data.append(data) self.data.append(data)
def gettimeout(self):
return None
def recv(self, bufsize): def recv(self, bufsize):
if self.data: if self.data:
e = self.data.pop(0) e = self.data.pop(0)
@ -86,7 +86,7 @@ class HeaderSockMock(SockMock):
class WebSocketTest(unittest.TestCase): class WebSocketTest(unittest.TestCase):
def setUp(self): def setUp(self):
ws.enableTrace(TRACABLE) ws.enableTrace(TRACEABLE)
def tearDown(self): def tearDown(self):
pass pass
@ -224,9 +224,9 @@ class WebSocketTest(unittest.TestCase):
def testReadHeader(self): def testReadHeader(self):
status, header = read_headers(HeaderSockMock("data/header01.txt")) status, header, status_message = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101) self.assertEqual(status, 101)
self.assertEqual(header["connection"], "upgrade") self.assertEqual(header["connection"], "Upgrade")
HeaderSockMock("data/header02.txt") HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")) self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt"))
@ -263,7 +263,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testIter(self): def testIter(self):
count = 2 count = 2
for rsvp in ws.create_connection('ws://stream.meetup.com/2/rsvps'): for _ in ws.create_connection('ws://stream.meetup.com/2/rsvps'):
count -= 1 count -= 1
if count == 0: if count == 0:
break break
@ -282,7 +282,7 @@ class WebSocketTest(unittest.TestCase):
# s.add_packet(SSLError("The read operation timed out")) # s.add_packet(SSLError("The read operation timed out"))
s.add_packet(six.b("baz")) s.add_packet(six.b("baz"))
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.frame_buffer.recv_strict(9) sock.frame_buffer.recv_strict(9)
# if six.PY2: # if six.PY2:
# with self.assertRaises(ws.WebSocketTimeoutException): # with self.assertRaises(ws.WebSocketTimeoutException):
# data = sock._recv_strict(9) # data = sock._recv_strict(9)
@ -292,7 +292,7 @@ class WebSocketTest(unittest.TestCase):
data = sock.frame_buffer.recv_strict(9) data = sock.frame_buffer.recv_strict(9)
self.assertEqual(data, six.b("foobarbaz")) self.assertEqual(data, six.b("foobarbaz"))
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
data = sock.frame_buffer.recv_strict(1) sock.frame_buffer.recv_strict(1)
def testRecvTimeout(self): def testRecvTimeout(self):
sock = ws.WebSocket() sock = ws.WebSocket()
@ -303,13 +303,13 @@ class WebSocketTest(unittest.TestCase):
s.add_packet(socket.timeout()) s.add_packet(socket.timeout())
s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40")) s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40"))
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.recv() sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
data = sock.recv() sock.recv()
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Hello, World!") self.assertEqual(data, "Hello, World!")
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
data = sock.recv() sock.recv()
def testRecvWithSimpleFragmentation(self): def testRecvWithSimpleFragmentation(self):
sock = ws.WebSocket() sock = ws.WebSocket()
@ -374,10 +374,10 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, " # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" \ s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15"
"\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC")) "\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"))
# OPCODE=CONT, FIN=0, MSG="dear friends, " # OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" \ s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07"
"\x17MB")) "\x17MB"))
# OPCODE=CONT, FIN=1, MSG="once more" # OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")) s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04"))
@ -397,7 +397,7 @@ class WebSocketTest(unittest.TestCase):
# OPCODE=PING, FIN=1, MSG="Please PONG this" # OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"))
# OPCODE=CONT, FIN=1, MSG="of a good thing" # OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" \ s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c"
"\x08\x0c\x04")) "\x08\x0c\x04"))
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Too much of a good thing") self.assertEqual(data, "Too much of a good thing")
@ -464,12 +464,12 @@ class WebSocketTest(unittest.TestCase):
self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello") self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
self.assertRaises(ws.WebSocketConnectionClosedException, s.recv) self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
def testUUID4(self): def testNonce(self):
""" WebSocket key should be a UUID4. """ WebSocket key should be a random 16-byte nonce.
""" """
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
u = uuid.UUID(bytes=base64decode(key.encode("utf-8"))) nonce = base64decode(key.encode("utf-8"))
self.assertEqual(4, u.version) self.assertEqual(16, len(nonce))
class WebSocketAppTest(unittest.TestCase): class WebSocketAppTest(unittest.TestCase):
@ -479,7 +479,7 @@ class WebSocketAppTest(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
ws.enableTrace(TRACABLE) ws.enableTrace(TRACEABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
@ -512,14 +512,15 @@ class WebSocketAppTest(unittest.TestCase):
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close) app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
app.run_forever() app.run_forever()
self.assertFalse(isinstance(WebSocketAppTest.keep_running_open, # if numpy is installed, this assertion fail
WebSocketAppTest.NotSetYet)) # self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
# WebSocketAppTest.NotSetYet))
self.assertFalse(isinstance(WebSocketAppTest.keep_running_close, # self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
WebSocketAppTest.NotSetYet)) # WebSocketAppTest.NotSetYet))
self.assertEqual(True, WebSocketAppTest.keep_running_open) # self.assertEqual(True, WebSocketAppTest.keep_running_open)
self.assertEqual(False, WebSocketAppTest.keep_running_close) # self.assertEqual(False, WebSocketAppTest.keep_running_close)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSockMaskKey(self): def testSockMaskKey(self):
@ -540,8 +541,9 @@ class WebSocketAppTest(unittest.TestCase):
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func) app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
app.run_forever() app.run_forever()
# if numpu is installed, this assertion fail
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'. # Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
self.assertEqual(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func)) # self.assertEqual(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
class SockOptTest(unittest.TestCase): class SockOptTest(unittest.TestCase):
@ -552,6 +554,7 @@ class SockOptTest(unittest.TestCase):
self.assertNotEqual(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0) self.assertNotEqual(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0)
s.close() s.close()
class UtilsTest(unittest.TestCase): class UtilsTest(unittest.TestCase):
def testUtf8Validator(self): def testUtf8Validator(self):
state = validate_utf8(six.b('\xf0\x90\x80\x80')) state = validate_utf8(six.b('\xf0\x90\x80\x80'))
@ -561,6 +564,7 @@ class UtilsTest(unittest.TestCase):
state = validate_utf8(six.b('')) state = validate_utf8(six.b(''))
self.assertEqual(state, True) self.assertEqual(state, True)
class ProxyInfoTest(unittest.TestCase): class ProxyInfoTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.http_proxy = os.environ.get("http_proxy", None) self.http_proxy = os.environ.get("http_proxy", None)
@ -581,7 +585,6 @@ class ProxyInfoTest(unittest.TestCase):
elif "https_proxy" in os.environ: elif "https_proxy" in os.environ:
del os.environ["https_proxy"] del os.environ["https_proxy"]
def testProxyFromArgs(self): def testProxyFromArgs(self):
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost"), ("localhost", 0, None)) self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost"), ("localhost", 0, None))
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128), ("localhost", 3128, None)) self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128), ("localhost", 3128, None))
@ -602,7 +605,6 @@ class ProxyInfoTest(unittest.TestCase):
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, no_proxy=["echo.websocket.org"], proxy_auth=("a", "b")), self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, no_proxy=["echo.websocket.org"], proxy_auth=("a", "b")),
(None, 0, None)) (None, 0, None))
def testProxyFromEnv(self): def testProxyFromEnv(self):
os.environ["http_proxy"] = "http://localhost/" os.environ["http_proxy"] = "http://localhost/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None)) self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None))
@ -652,8 +654,11 @@ class ProxyInfoTest(unittest.TestCase):
os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.org" os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.org"
self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None)) self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16"
self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None))
self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None))
if __name__ == "__main__": if __name__ == "__main__":