Add python-eventlet 0.16.1
[packages/trusty/python-eventlet.git] / eventlet / eventlet / websocket.py
1 import base64
2 import codecs
3 import collections
4 import errno
5 from random import Random
6 from socket import error as SocketError
7 import string
8 import struct
9 import sys
10 import time
11
12 try:
13     from hashlib import md5, sha1
14 except ImportError:  # pragma NO COVER
15     from md5 import md5
16     from sha import sha as sha1
17
18 from eventlet import semaphore
19 from eventlet import wsgi
20 from eventlet.green import socket
21 from eventlet.support import get_errno, six
22
23 # Python 2's utf8 decoding is more lenient than we'd like
24 # In order to pass autobahn's testsuite we need stricter validation
25 # if available...
26 for _mod in ('wsaccel.utf8validator', 'autobahn.utf8validator'):
27     # autobahn has it's own python-based validator. in newest versions
28     # this prefers to use wsaccel, a cython based implementation, if available.
29     # wsaccel may also be installed w/out autobahn, or with a earlier version.
30     try:
31         utf8validator = __import__(_mod, {}, {}, [''])
32     except ImportError:
33         utf8validator = None
34     else:
35         break
36
37 ACCEPTABLE_CLIENT_ERRORS = set((errno.ECONNRESET, errno.EPIPE))
38
39 __all__ = ["WebSocketWSGI", "WebSocket"]
40 PROTOCOL_GUID = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
41 VALID_CLOSE_STATUS = set(
42     list(range(1000, 1004)) +
43     list(range(1007, 1012)) +
44     # 3000-3999: reserved for use by libraries, frameworks,
45     # and applications
46     list(range(3000, 4000)) +
47     # 4000-4999: reserved for private use and thus can't
48     # be registered
49     list(range(4000, 5000))
50 )
51
52
53 class BadRequest(Exception):
54     def __init__(self, status='400 Bad Request', body=None, headers=None):
55         super(Exception, self).__init__()
56         self.status = status
57         self.body = body
58         self.headers = headers
59
60
61 class WebSocketWSGI(object):
62     """Wraps a websocket handler function in a WSGI application.
63
64     Use it like this::
65
66       @websocket.WebSocketWSGI
67       def my_handler(ws):
68           from_browser = ws.wait()
69           ws.send("from server")
70
71     The single argument to the function will be an instance of
72     :class:`WebSocket`.  To close the socket, simply return from the
73     function.  Note that the server will log the websocket request at
74     the time of closure.
75     """
76
77     def __init__(self, handler):
78         self.handler = handler
79         self.protocol_version = None
80         self.support_legacy_versions = True
81         self.supported_protocols = []
82         self.origin_checker = None
83
84     @classmethod
85     def configured(cls,
86                    handler=None,
87                    supported_protocols=None,
88                    origin_checker=None,
89                    support_legacy_versions=False):
90         def decorator(handler):
91             inst = cls(handler)
92             inst.support_legacy_versions = support_legacy_versions
93             inst.origin_checker = origin_checker
94             if supported_protocols:
95                 inst.supported_protocols = supported_protocols
96             return inst
97         if handler is None:
98             return decorator
99         return decorator(handler)
100
101     def __call__(self, environ, start_response):
102         http_connection_parts = [
103             part.strip()
104             for part in environ.get('HTTP_CONNECTION', '').lower().split(',')]
105         if not ('upgrade' in http_connection_parts and
106                 environ.get('HTTP_UPGRADE', '').lower() == 'websocket'):
107             # need to check a few more things here for true compliance
108             start_response('400 Bad Request', [('Connection', 'close')])
109             return []
110
111         try:
112             if 'HTTP_SEC_WEBSOCKET_VERSION' in environ:
113                 ws = self._handle_hybi_request(environ)
114             elif self.support_legacy_versions:
115                 ws = self._handle_legacy_request(environ)
116             else:
117                 raise BadRequest()
118         except BadRequest as e:
119             status = e.status
120             body = e.body or b''
121             headers = e.headers or []
122             start_response(status,
123                            [('Connection', 'close'), ] + headers)
124             return [body]
125
126         try:
127             self.handler(ws)
128         except socket.error as e:
129             if get_errno(e) not in ACCEPTABLE_CLIENT_ERRORS:
130                 raise
131         # Make sure we send the closing frame
132         ws._send_closing_frame(True)
133         # use this undocumented feature of eventlet.wsgi to ensure that it
134         # doesn't barf on the fact that we didn't call start_response
135         return wsgi.ALREADY_HANDLED
136
137     def _handle_legacy_request(self, environ):
138         sock = environ['eventlet.input'].get_socket()
139
140         if 'HTTP_SEC_WEBSOCKET_KEY1' in environ:
141             self.protocol_version = 76
142             if 'HTTP_SEC_WEBSOCKET_KEY2' not in environ:
143                 raise BadRequest()
144         else:
145             self.protocol_version = 75
146
147         if self.protocol_version == 76:
148             key1 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY1'])
149             key2 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY2'])
150             # There's no content-length header in the request, but it has 8
151             # bytes of data.
152             environ['wsgi.input'].content_length = 8
153             key3 = environ['wsgi.input'].read(8)
154             key = struct.pack(">II", key1, key2) + key3
155             response = md5(key).digest()
156
157         # Start building the response
158         scheme = 'ws'
159         if environ.get('wsgi.url_scheme') == 'https':
160             scheme = 'wss'
161         location = '%s://%s%s%s' % (
162             scheme,
163             environ.get('HTTP_HOST'),
164             environ.get('SCRIPT_NAME'),
165             environ.get('PATH_INFO')
166         )
167         qs = environ.get('QUERY_STRING')
168         if qs is not None:
169             location += '?' + qs
170         if self.protocol_version == 75:
171             handshake_reply = (
172                 b"HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
173                 b"Upgrade: WebSocket\r\n"
174                 b"Connection: Upgrade\r\n"
175                 b"WebSocket-Origin: " + six.b(environ.get('HTTP_ORIGIN')) + b"\r\n"
176                 b"WebSocket-Location: " + six.b(location) + b"\r\n\r\n"
177             )
178         elif self.protocol_version == 76:
179             handshake_reply = (
180                 b"HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
181                 b"Upgrade: WebSocket\r\n"
182                 b"Connection: Upgrade\r\n"
183                 b"Sec-WebSocket-Origin: " + six.b(environ.get('HTTP_ORIGIN')) + b"\r\n"
184                 b"Sec-WebSocket-Protocol: " +
185                 six.b(environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default')) + b"\r\n"
186                 b"Sec-WebSocket-Location: " + six.b(location) + b"\r\n"
187                 b"\r\n" + response
188             )
189         else:  # pragma NO COVER
190             raise ValueError("Unknown WebSocket protocol version.")
191         sock.sendall(handshake_reply)
192         return WebSocket(sock, environ, self.protocol_version)
193
194     def _handle_hybi_request(self, environ):
195         sock = environ['eventlet.input'].get_socket()
196         hybi_version = environ['HTTP_SEC_WEBSOCKET_VERSION']
197         if hybi_version not in ('8', '13', ):
198             raise BadRequest(status='426 Upgrade Required',
199                              headers=[('Sec-WebSocket-Version', '8, 13')])
200         self.protocol_version = int(hybi_version)
201         if 'HTTP_SEC_WEBSOCKET_KEY' not in environ:
202             # That's bad.
203             raise BadRequest()
204         origin = environ.get(
205             'HTTP_ORIGIN',
206             (environ.get('HTTP_SEC_WEBSOCKET_ORIGIN', '')
207              if self.protocol_version <= 8 else ''))
208         if self.origin_checker is not None:
209             if not self.origin_checker(environ.get('HTTP_HOST'), origin):
210                 raise BadRequest(status='403 Forbidden')
211         protocols = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', None)
212         negotiated_protocol = None
213         if protocols:
214             for p in (i.strip() for i in protocols.split(',')):
215                 if p in self.supported_protocols:
216                     negotiated_protocol = p
217                     break
218         # extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS', None)
219         # if extensions:
220         #    extensions = [i.strip() for i in extensions.split(',')]
221
222         key = environ['HTTP_SEC_WEBSOCKET_KEY']
223         response = base64.b64encode(sha1(six.b(key) + PROTOCOL_GUID).digest())
224         handshake_reply = [b"HTTP/1.1 101 Switching Protocols",
225                            b"Upgrade: websocket",
226                            b"Connection: Upgrade",
227                            b"Sec-WebSocket-Accept: " + response]
228         if negotiated_protocol:
229             handshake_reply.append(b"Sec-WebSocket-Protocol: " + six.b(negotiated_protocol))
230         sock.sendall(b'\r\n'.join(handshake_reply) + b'\r\n\r\n')
231         return RFC6455WebSocket(sock, environ, self.protocol_version,
232                                 protocol=negotiated_protocol)
233
234     def _extract_number(self, value):
235         """
236         Utility function which, given a string like 'g98sd  5[]221@1', will
237         return 9852211. Used to parse the Sec-WebSocket-Key headers.
238         """
239         out = ""
240         spaces = 0
241         for char in value:
242             if char in string.digits:
243                 out += char
244             elif char == " ":
245                 spaces += 1
246         return int(out) // spaces
247
248
249 class WebSocket(object):
250     """A websocket object that handles the details of
251     serialization/deserialization to the socket.
252
253     The primary way to interact with a :class:`WebSocket` object is to
254     call :meth:`send` and :meth:`wait` in order to pass messages back
255     and forth with the browser.  Also available are the following
256     properties:
257
258     path
259         The path value of the request.  This is the same as the WSGI PATH_INFO variable,
260         but more convenient.
261     protocol
262         The value of the Websocket-Protocol header.
263     origin
264         The value of the 'Origin' header.
265     environ
266         The full WSGI environment for this request.
267
268     """
269
270     def __init__(self, sock, environ, version=76):
271         """
272         :param socket: The eventlet socket
273         :type socket: :class:`eventlet.greenio.GreenSocket`
274         :param environ: The wsgi environment
275         :param version: The WebSocket spec version to follow (default is 76)
276         """
277         self.socket = sock
278         self.origin = environ.get('HTTP_ORIGIN')
279         self.protocol = environ.get('HTTP_WEBSOCKET_PROTOCOL')
280         self.path = environ.get('PATH_INFO')
281         self.environ = environ
282         self.version = version
283         self.websocket_closed = False
284         self._buf = b""
285         self._msgs = collections.deque()
286         self._sendlock = semaphore.Semaphore()
287
288     @staticmethod
289     def _pack_message(message):
290         """Pack the message inside ``00`` and ``FF``
291
292         As per the dataframing section (5.3) for the websocket spec
293         """
294         if isinstance(message, six.text_type):
295             message = message.encode('utf-8')
296         elif not isinstance(message, six.binary_type):
297             message = six.b(str(message))
298         packed = b"\x00" + message + b"\xFF"
299         return packed
300
301     def _parse_messages(self):
302         """ Parses for messages in the buffer *buf*.  It is assumed that
303         the buffer contains the start character for a message, but that it
304         may contain only part of the rest of the message.
305
306         Returns an array of messages, and the buffer remainder that
307         didn't contain any full messages."""
308         msgs = []
309         end_idx = 0
310         buf = self._buf
311         while buf:
312             frame_type = six.indexbytes(buf, 0)
313             if frame_type == 0:
314                 # Normal message.
315                 end_idx = buf.find(b"\xFF")
316                 if end_idx == -1:  # pragma NO COVER
317                     break
318                 msgs.append(buf[1:end_idx].decode('utf-8', 'replace'))
319                 buf = buf[end_idx + 1:]
320             elif frame_type == 255:
321                 # Closing handshake.
322                 assert six.indexbytes(buf, 1) == 0, "Unexpected closing handshake: %r" % buf
323                 self.websocket_closed = True
324                 break
325             else:
326                 raise ValueError("Don't understand how to parse this type of message: %r" % buf)
327         self._buf = buf
328         return msgs
329
330     def send(self, message):
331         """Send a message to the browser.
332
333         *message* should be convertable to a string; unicode objects should be
334         encodable as utf-8.  Raises socket.error with errno of 32
335         (broken pipe) if the socket has already been closed by the client."""
336         packed = self._pack_message(message)
337         # if two greenthreads are trying to send at the same time
338         # on the same socket, sendlock prevents interleaving and corruption
339         self._sendlock.acquire()
340         try:
341             self.socket.sendall(packed)
342         finally:
343             self._sendlock.release()
344
345     def wait(self):
346         """Waits for and deserializes messages.
347
348         Returns a single message; the oldest not yet processed. If the client
349         has already closed the connection, returns None.  This is different
350         from normal socket behavior because the empty string is a valid
351         websocket message."""
352         while not self._msgs:
353             # Websocket might be closed already.
354             if self.websocket_closed:
355                 return None
356             # no parsed messages, must mean buf needs more data
357             delta = self.socket.recv(8096)
358             if delta == b'':
359                 return None
360             self._buf += delta
361             msgs = self._parse_messages()
362             self._msgs.extend(msgs)
363         return self._msgs.popleft()
364
365     def _send_closing_frame(self, ignore_send_errors=False):
366         """Sends the closing frame to the client, if required."""
367         if self.version == 76 and not self.websocket_closed:
368             try:
369                 self.socket.sendall(b"\xff\x00")
370             except SocketError:
371                 # Sometimes, like when the remote side cuts off the connection,
372                 # we don't care about this.
373                 if not ignore_send_errors:  # pragma NO COVER
374                     raise
375             self.websocket_closed = True
376
377     def close(self):
378         """Forcibly close the websocket; generally it is preferable to
379         return from the handler method."""
380         self._send_closing_frame()
381         self.socket.shutdown(True)
382         self.socket.close()
383
384
385 class ConnectionClosedError(Exception):
386     pass
387
388
389 class FailedConnectionError(Exception):
390     def __init__(self, status, message):
391         super(FailedConnectionError, self).__init__(status, message)
392         self.message = message
393         self.status = status
394
395
396 class ProtocolError(ValueError):
397     pass
398
399
400 class RFC6455WebSocket(WebSocket):
401     def __init__(self, sock, environ, version=13, protocol=None, client=False):
402         super(RFC6455WebSocket, self).__init__(sock, environ, version)
403         self.iterator = self._iter_frames()
404         self.client = client
405         self.protocol = protocol
406
407     class UTF8Decoder(object):
408         def __init__(self):
409             if utf8validator:
410                 self.validator = utf8validator.Utf8Validator()
411             else:
412                 self.validator = None
413             decoderclass = codecs.getincrementaldecoder('utf8')
414             self.decoder = decoderclass()
415
416         def reset(self):
417             if self.validator:
418                 self.validator.reset()
419             self.decoder.reset()
420
421         def decode(self, data, final=False):
422             if self.validator:
423                 valid, eocp, c_i, t_i = self.validator.validate(data)
424                 if not valid:
425                     raise ValueError('Data is not valid unicode')
426             return self.decoder.decode(data, final)
427
428     def _get_bytes(self, numbytes):
429         data = b''
430         while len(data) < numbytes:
431             d = self.socket.recv(numbytes - len(data))
432             if not d:
433                 raise ConnectionClosedError()
434             data = data + d
435         return data
436
437     class Message(object):
438         def __init__(self, opcode, decoder=None):
439             self.decoder = decoder
440             self.data = []
441             self.finished = False
442             self.opcode = opcode
443
444         def push(self, data, final=False):
445             if self.decoder:
446                 data = self.decoder.decode(data, final=final)
447             self.finished = final
448             self.data.append(data)
449
450         def getvalue(self):
451             return ('' if self.decoder else b'').join(self.data)
452
453     @staticmethod
454     def _apply_mask(data, mask, length=None, offset=0):
455         if length is None:
456             length = len(data)
457         cnt = range(length)
458         return b''.join(six.int2byte(six.indexbytes(data, i) ^ mask[(offset + i) % 4]) for i in cnt)
459
460     def _handle_control_frame(self, opcode, data):
461         if opcode == 8:  # connection close
462             if not data:
463                 status = 1000
464             elif len(data) > 1:
465                 status = struct.unpack_from('!H', data)[0]
466                 if not status or status not in VALID_CLOSE_STATUS:
467                     raise FailedConnectionError(
468                         1002,
469                         "Unexpected close status code.")
470                 try:
471                     data = self.UTF8Decoder().decode(data[2:], True)
472                 except (UnicodeDecodeError, ValueError):
473                     raise FailedConnectionError(
474                         1002,
475                         "Close message data should be valid UTF-8.")
476             else:
477                 status = 1002
478             self.close(close_data=(status, ''))
479             raise ConnectionClosedError()
480         elif opcode == 9:  # ping
481             self.send(data, control_code=0xA)
482         elif opcode == 0xA:  # pong
483             pass
484         else:
485             raise FailedConnectionError(
486                 1002, "Unknown control frame received.")
487
488     def _iter_frames(self):
489         fragmented_message = None
490         try:
491             while True:
492                 message = self._recv_frame(message=fragmented_message)
493                 if message.opcode & 8:
494                     self._handle_control_frame(
495                         message.opcode, message.getvalue())
496                     continue
497                 if fragmented_message and message is not fragmented_message:
498                     raise RuntimeError('Unexpected message change.')
499                 fragmented_message = message
500                 if message.finished:
501                     data = fragmented_message.getvalue()
502                     fragmented_message = None
503                     yield data
504         except FailedConnectionError:
505             exc_typ, exc_val, exc_tb = sys.exc_info()
506             self.close(close_data=(exc_val.status, exc_val.message))
507         except ConnectionClosedError:
508             return
509         except Exception:
510             self.close(close_data=(1011, 'Internal Server Error'))
511             raise
512
513     def _recv_frame(self, message=None):
514         recv = self._get_bytes
515         header = recv(2)
516         a, b = struct.unpack('!BB', header)
517         finished = a >> 7 == 1
518         rsv123 = a >> 4 & 7
519         if rsv123:
520             # must be zero
521             raise FailedConnectionError(
522                 1002,
523                 "RSV1, RSV2, RSV3: MUST be 0 unless an extension is"
524                 " negotiated that defines meanings for non-zero values.")
525         opcode = a & 15
526         if opcode not in (0, 1, 2, 8, 9, 0xA):
527             raise FailedConnectionError(1002, "Unknown opcode received.")
528         masked = b & 128 == 128
529         if not masked and not self.client:
530             raise FailedConnectionError(1002, "A client MUST mask all frames"
531                                         " that it sends to the server")
532         length = b & 127
533         if opcode & 8:
534             if not finished:
535                 raise FailedConnectionError(1002, "Control frames must not"
536                                             " be fragmented.")
537             if length > 125:
538                 raise FailedConnectionError(
539                     1002,
540                     "All control frames MUST have a payload length of 125"
541                     " bytes or less")
542         elif opcode and message:
543             raise FailedConnectionError(
544                 1002,
545                 "Received a non-continuation opcode within"
546                 " fragmented message.")
547         elif not opcode and not message:
548             raise FailedConnectionError(
549                 1002,
550                 "Received continuation opcode with no previous"
551                 " fragments received.")
552         if length == 126:
553             length = struct.unpack('!H', recv(2))[0]
554         elif length == 127:
555             length = struct.unpack('!Q', recv(8))[0]
556         if masked:
557             mask = struct.unpack('!BBBB', recv(4))
558         received = 0
559         if not message or opcode & 8:
560             decoder = self.UTF8Decoder() if opcode == 1 else None
561             message = self.Message(opcode, decoder=decoder)
562         if not length:
563             message.push('', final=finished)
564         else:
565             while received < length:
566                 d = self.socket.recv(length - received)
567                 if not d:
568                     raise ConnectionClosedError()
569                 dlen = len(d)
570                 if masked:
571                     d = self._apply_mask(d, mask, length=dlen, offset=received)
572                 received = received + dlen
573                 try:
574                     message.push(d, final=finished)
575                 except (UnicodeDecodeError, ValueError):
576                     raise FailedConnectionError(
577                         1007, "Text data must be valid utf-8")
578         return message
579
580     @staticmethod
581     def _pack_message(message, masked=False,
582                       continuation=False, final=True, control_code=None):
583         is_text = False
584         if isinstance(message, six.text_type):
585             message = message.encode('utf-8')
586             is_text = True
587         length = len(message)
588         if not length:
589             # no point masking empty data
590             masked = False
591         if control_code:
592             if control_code not in (8, 9, 0xA):
593                 raise ProtocolError('Unknown control opcode.')
594             if continuation or not final:
595                 raise ProtocolError('Control frame cannot be a fragment.')
596             if length > 125:
597                 raise ProtocolError('Control frame data too large (>125).')
598             header = struct.pack('!B', control_code | 1 << 7)
599         else:
600             opcode = 0 if continuation else (1 if is_text else 2)
601             header = struct.pack('!B', opcode | (1 << 7 if final else 0))
602         lengthdata = 1 << 7 if masked else 0
603         if length > 65535:
604             lengthdata = struct.pack('!BQ', lengthdata | 127, length)
605         elif length > 125:
606             lengthdata = struct.pack('!BH', lengthdata | 126, length)
607         else:
608             lengthdata = struct.pack('!B', lengthdata | length)
609         if masked:
610             # NOTE: RFC6455 states:
611             # A server MUST NOT mask any frames that it sends to the client
612             rand = Random(time.time())
613             mask = [rand.getrandbits(8) for _ in six.moves.xrange(4)]
614             message = RFC6455WebSocket._apply_mask(message, mask, length)
615             maskdata = struct.pack('!BBBB', *mask)
616         else:
617             maskdata = b''
618
619         return b''.join((header, lengthdata, maskdata, message))
620
621     def wait(self):
622         for i in self.iterator:
623             return i
624
625     def _send(self, frame):
626         self._sendlock.acquire()
627         try:
628             self.socket.sendall(frame)
629         finally:
630             self._sendlock.release()
631
632     def send(self, message, **kw):
633         kw['masked'] = self.client
634         payload = self._pack_message(message, **kw)
635         self._send(payload)
636
637     def _send_closing_frame(self, ignore_send_errors=False, close_data=None):
638         if self.version in (8, 13) and not self.websocket_closed:
639             if close_data is not None:
640                 status, msg = close_data
641                 if isinstance(msg, six.text_type):
642                     msg = msg.encode('utf-8')
643                 data = struct.pack('!H', status) + msg
644             else:
645                 data = ''
646             try:
647                 self.send(data, control_code=8)
648             except SocketError:
649                 # Sometimes, like when the remote side cuts off the connection,
650                 # we don't care about this.
651                 if not ignore_send_errors:  # pragma NO COVER
652                     raise
653             self.websocket_closed = True
654
655     def close(self, close_data=None):
656         """Forcibly close the websocket; generally it is preferable to
657         return from the handler method."""
658         self._send_closing_frame(close_data=close_data)
659         self.socket.shutdown(socket.SHUT_WR)
660         self.socket.close()