ded65338c8d59bad2c7d26b7be9f21b3c81d979d
[packages/trusty/python-eventlet.git] / python-eventlet / eventlet / green / ssl.py
1 __ssl = __import__('ssl')
2
3 from eventlet.patcher import slurp_properties
4 slurp_properties(__ssl, globals(), srckeys=dir(__ssl))
5
6 import functools
7 import sys
8 import errno
9 time = __import__('time')
10
11 from eventlet.support import get_errno, PY33, six
12 from eventlet.hubs import trampoline, IOClosed
13 from eventlet.greenio import (
14     set_nonblocking, GreenSocket, SOCKET_CLOSED, CONNECT_ERR, CONNECT_SUCCESS,
15 )
# Handles on the real socket module and its socket class; the methods below
# call straight into these (e.g. socket.send(self, ...)).
orig_socket = __import__('socket')
socket = orig_socket.socket

# From Python 2.7 on, the ciphers argument is passed to sslwrap() and
# timeouts are reported as SSLError; older interpreters use socket.timeout.
has_ciphers = sys.version_info >= (2, 7)
timeout_exc = SSLError if has_ciphers else orig_socket.timeout

__patched__ = ['SSLSocket', 'wrap_socket', 'sslwrap_simple']

# The unmodified SSLSocket class, kept as the base for the green subclass.
_original_sslsocket = __ssl.SSLSocket
28
29
class GreenSSLSocket(_original_sslsocket):
    """ This is a green version of the SSLSocket class from the ssl module added
    in 2.6.  For documentation on it, please see the Python standard
    documentation.

    Python nonblocking ssl objects don't give errors when the other end
    of the socket is closed (they do notice when the other end is shutdown,
    though).  Any write/read operations will simply hang if the socket is
    closed from the other end.  There is no obvious fix for this problem;
    it appears to be a limitation of Python's ssl object implementation.
    A workaround is to set a reasonable timeout on the socket using
    settimeout(), and to close/reopen the connection when a timeout
    occurs at an unexpected juncture in the code.
    """
    # we are inheriting from SSLSocket because its constructor calls
    # do_handshake whose behavior we wish to override

    def __init__(self, sock, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True, *args, **kw):
        """Wrap ``sock`` in a cooperative SSL socket.

        ``sock`` is converted to a GreenSocket when it is not one already.
        The remaining parameters are forwarded to the parent SSLSocket
        constructor; extra positional and keyword arguments are passed
        through untouched.
        """
        if not isinstance(sock, GreenSocket):
            sock = GreenSocket(sock)

        self.act_non_blocking = sock.act_non_blocking

        if six.PY2:
            # On Python 2 SSLSocket constructor queries the timeout, it'd break without
            # this assignment
            self._timeout = sock.gettimeout()

        # nonblocking socket handshaking on connect got disabled so let's pretend it's disabled
        # even when it's on
        super(GreenSSLSocket, self).__init__(
            sock.fd, keyfile, certfile, server_side, cert_reqs, ssl_version,
            ca_certs, do_handshake_on_connect and six.PY2, *args, **kw)

        # the superclass initializer trashes the methods so we remove
        # the local-object versions of them and let the actual class
        # methods shine through
        # Note: This for Python 2
        # Each attribute is removed independently so that one that is already
        # missing does not stop the remaining ones from being cleaned up.
        for fn in getattr(orig_socket, '_delegate_methods', ()):
            try:
                delattr(self, fn)
            except AttributeError:
                pass

        if six.PY3:
            # Python 3 SSLSocket construction process overwrites the timeout so restore it
            self._timeout = sock.gettimeout()

            # it also sets timeout to None internally apparently (tested with 3.4.2)
            _original_sslsocket.settimeout(self, 0.0)
            assert _original_sslsocket.gettimeout(self) == 0.0

            # see note above about handshaking
            self.do_handshake_on_connect = do_handshake_on_connect
            if do_handshake_on_connect and self._connected:
                self.do_handshake()

    def settimeout(self, timeout):
        """Record ``timeout`` as the timeout used for cooperative waits."""
        self._timeout = timeout

    def gettimeout(self):
        """Return the timeout recorded by settimeout()/setblocking()."""
        return self._timeout

    def setblocking(self, flag):
        """Switch between blocking (truthy ``flag``) and non-blocking mode.

        Blocking mode clears the timeout (None); non-blocking mode sets it
        to 0.0 and makes the methods skip their cooperative waits.
        """
        if flag:
            self.act_non_blocking = False
            self._timeout = None
        else:
            self.act_non_blocking = True
            self._timeout = 0.0

    def _wait_timeout(self):
        """Return the timeout to hand to trampoline() for this socket.

        In non-blocking mode this is None, so waits that were historically
        issued without a deadline keep behaving the same; otherwise it is
        the value reported by gettimeout().
        """
        if self.act_non_blocking:
            return None
        return self.gettimeout()

    def _call_trampolining(self, func, *a, **kw):
        """Call ``func(*a, **kw)``, waiting cooperatively while SSL needs I/O.

        In non-blocking mode ``func`` is called exactly once.  Otherwise the
        call is repeated for as long as it raises SSLError with
        SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, trampolining on the
        matching direction in between; any other SSLError propagates.
        """
        if self.act_non_blocking:
            return func(*a, **kw)
        else:
            while True:
                try:
                    return func(*a, **kw)
                except SSLError as exc:
                    if get_errno(exc) == SSL_ERROR_WANT_READ:
                        trampoline(self,
                                   read=True,
                                   timeout=self.gettimeout(),
                                   timeout_exc=timeout_exc('timed out'))
                    elif get_errno(exc) == SSL_ERROR_WANT_WRITE:
                        trampoline(self,
                                   write=True,
                                   timeout=self.gettimeout(),
                                   timeout_exc=timeout_exc('timed out'))
                    else:
                        raise

    def write(self, data):
        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""
        return self._call_trampolining(
            super(GreenSSLSocket, self).write, data)

    def read(self, *args, **kwargs):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF."""
        try:
            return self._call_trampolining(
                super(GreenSSLSocket, self).read, *args, **kwargs)
        except IOClosed:
            return b''

    def send(self, data, flags=0):
        """Send ``data`` over the SSL channel, or over the plain socket when
        no SSL object is attached, and return the underlying call's result."""
        if self._sslobj:
            return self._call_trampolining(
                super(GreenSSLSocket, self).send, data, flags)
        else:
            # Honour the configured timeout instead of waiting indefinitely.
            trampoline(self, write=True, timeout=self._wait_timeout(),
                       timeout_exc=timeout_exc('timed out'))
            return socket.send(self, data, flags)

    def sendto(self, data, addr, flags=0):
        """Send ``data`` to ``addr`` on a plain socket.

        Raises ValueError when an SSL object is attached.
        """
        # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        else:
            # Honour the configured timeout instead of waiting indefinitely.
            trampoline(self, write=True, timeout=self._wait_timeout(),
                       timeout_exc=timeout_exc('timed out'))
            return socket.sendto(self, data, addr, flags)

    def sendall(self, data, flags=0):
        """Send all of ``data``.

        With an SSL object attached, non-zero ``flags`` raise ValueError and
        the number of bytes sent (``len(data)``) is returned.  On a plain
        socket the result is None on success, or '' when the error number
        is one of SOCKET_CLOSED.
        """
        # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to sendall() on %s" %
                    self.__class__)
            amount = len(data)
            count = 0
            data_to_send = data
            while (count < amount):
                v = self.send(data_to_send)
                count += v
                if v == 0:
                    # Nothing was written: wait until the socket is writable,
                    # bounded by the configured timeout.
                    trampoline(self, write=True, timeout=self._wait_timeout(),
                               timeout_exc=timeout_exc('timed out'))
                else:
                    data_to_send = data[count:]
            return amount
        else:
            if self.act_non_blocking:
                # Non-blocking callers get the raw socket behaviour,
                # including any error it raises.
                return socket.sendall(self, data, flags)

            # socket.sendall() may transmit part of the payload before
            # failing with EWOULDBLOCK and does not report how much, so
            # retrying it could duplicate bytes.  Send piecewise and keep
            # track of the offset instead.
            total = len(data)
            sent = 0
            while True:
                try:
                    sent += socket.send(self, data[sent:], flags)
                except orig_socket.error as e:
                    eno = get_errno(e)
                    if eno == errno.EWOULDBLOCK:
                        trampoline(self, write=True,
                                   timeout=self.gettimeout(), timeout_exc=timeout_exc('timed out'))
                        continue
                    if eno in SOCKET_CLOSED:
                        return ''
                    raise
                if sent >= total:
                    return None

    def recv(self, buflen=1024, flags=0):
        """Receive up to ``buflen`` bytes.

        With an SSL object attached, non-zero ``flags`` raise ValueError and
        the data comes from read().  On a plain socket b'' is returned when
        the socket is reported closed.
        """
        # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to recv() on %s" %
                    self.__class__)
            read = self.read(buflen)
            return read
        else:
            while True:
                try:
                    return socket.recv(self, buflen, flags)
                except orig_socket.error as e:
                    if self.act_non_blocking:
                        raise
                    eno = get_errno(e)
                    if eno == errno.EWOULDBLOCK:
                        try:
                            trampoline(
                                self, read=True,
                                timeout=self.gettimeout(), timeout_exc=timeout_exc('timed out'))
                        except IOClosed:
                            return b''
                        # The socket became readable: try the recv again
                        # rather than re-raising the EWOULDBLOCK error.
                        continue
                    if eno in SOCKET_CLOSED:
                        return b''
                    raise

    def recv_into(self, buffer, nbytes=None, flags=0):
        """Wait for readability (unless non-blocking), then delegate to the
        parent class's recv_into()."""
        if not self.act_non_blocking:
            trampoline(self, read=True, timeout=self.gettimeout(),
                       timeout_exc=timeout_exc('timed out'))
        return super(GreenSSLSocket, self).recv_into(buffer, nbytes, flags)

    def recvfrom(self, addr, buflen=1024, flags=0):
        """Wait for readability (unless non-blocking), then delegate to the
        parent class's recvfrom()."""
        # NOTE(review): the leading ``addr`` parameter is forwarded as-is to
        # the parent's recvfrom(); presumably this mirrors an old ssl module
        # signature -- confirm it matches the supported Python versions.
        if not self.act_non_blocking:
            trampoline(self, read=True, timeout=self.gettimeout(),
                       timeout_exc=timeout_exc('timed out'))
        return super(GreenSSLSocket, self).recvfrom(addr, buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):
        """Wait for readability (unless non-blocking), then delegate to the
        parent class's recvfrom_into()."""
        if not self.act_non_blocking:
            trampoline(self, read=True, timeout=self.gettimeout(),
                       timeout_exc=timeout_exc('timed out'))
        return super(GreenSSLSocket, self).recvfrom_into(buffer, nbytes, flags)

    def unwrap(self):
        """Run the parent's unwrap() cooperatively and return the result
        wrapped in a GreenSocket."""
        return GreenSocket(self._call_trampolining(
            super(GreenSSLSocket, self).unwrap))

    def do_handshake(self):
        """Perform a TLS/SSL handshake."""
        return self._call_trampolining(
            super(GreenSSLSocket, self).do_handshake)

    def _socket_connect(self, addr):
        """Connect the underlying socket to ``addr`` without SSL wrapping.

        In non-blocking mode the raw connect is issued once.  Otherwise the
        connect is retried while its error number is in CONNECT_ERR, waiting
        for writability in between; an error number in CONNECT_SUCCESS counts
        as connected.  When a timeout is set, ``timeout_exc('timed out')`` is
        raised once the deadline has passed.
        """
        real_connect = socket.connect
        if self.act_non_blocking:
            return real_connect(self, addr)
        else:
            # *NOTE: gross, copied code from greenio because it's not factored
            # well enough to reuse
            if self.gettimeout() is None:
                while True:
                    try:
                        return real_connect(self, addr)
                    except orig_socket.error as exc:
                        if get_errno(exc) in CONNECT_ERR:
                            trampoline(self, write=True)
                        elif get_errno(exc) in CONNECT_SUCCESS:
                            return
                        else:
                            raise
            else:
                end = time.time() + self.gettimeout()
                while True:
                    try:
                        # Return as soon as the connect succeeds instead of
                        # looping into another connect attempt (or into the
                        # deadline check below).
                        return real_connect(self, addr)
                    except orig_socket.error as exc:
                        if get_errno(exc) in CONNECT_ERR:
                            trampoline(
                                self, write=True,
                                timeout=end - time.time(), timeout_exc=timeout_exc('timed out'))
                        elif get_errno(exc) in CONNECT_SUCCESS:
                            return
                        else:
                            raise
                    if time.time() >= end:
                        raise timeout_exc('timed out')

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        # *NOTE: grrrrr copied this code from ssl.py because of the reference
        # to socket.connect which we don't want to call directly
        if self._sslobj:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        self._socket_connect(addr)
        server_side = False
        try:
            sslwrap = _ssl.sslwrap
        except AttributeError:
            # sslwrap was removed in 3.x and later in 2.7.9
            if six.PY2:
                sslobj = self._context._wrap_socket(self._sock, server_side, ssl_sock=self)
            else:
                context = self.context if PY33 else self._context
                sslobj = context._wrap_socket(self, server_side)
        else:
            sslobj = sslwrap(self._sock, server_side, self.keyfile, self.certfile,
                             self.cert_reqs, self.ssl_version,
                             self.ca_certs, *([self.ciphers] if has_ciphers else []))

        self._sslobj = sslobj
        if self.do_handshake_on_connect:
            self.do_handshake()

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""
        # RDW grr duplication of code from greenio
        if self.act_non_blocking:
            newsock, addr = socket.accept(self)
        else:
            while True:
                try:
                    newsock, addr = socket.accept(self)
                    set_nonblocking(newsock)
                    break
                except orig_socket.error as e:
                    if get_errno(e) != errno.EWOULDBLOCK:
                        raise
                    trampoline(self, read=True, timeout=self.gettimeout(),
                               timeout_exc=timeout_exc('timed out'))

        # The accepted socket is wrapped with the same class and SSL settings
        # as the listening socket, in server-side mode.
        new_ssl = type(self)(
            newsock,
            keyfile=self.keyfile,
            certfile=self.certfile,
            server_side=True,
            cert_reqs=self.cert_reqs,
            ssl_version=self.ssl_version,
            ca_certs=self.ca_certs,
            do_handshake_on_connect=self.do_handshake_on_connect,
            suppress_ragged_eofs=self.suppress_ragged_eofs)
        return (new_ssl, addr)

    def dup(self):
        """Always raises NotImplementedError: ssl objects cannot be dup'ed."""
        raise NotImplementedError("Can't dup an ssl object")
338
# Publish the green class under the standard library's name.
SSLSocket = GreenSSLSocket


def wrap_socket(sock, *a, **kw):
    """Return ``sock`` wrapped in a GreenSSLSocket.

    All extra positional and keyword arguments are handed to the
    GreenSSLSocket constructor unchanged.
    """
    green_sock = GreenSSLSocket(sock, *a, **kw)
    return green_sock
344
345
# Only offered when the underlying ssl module itself provides sslwrap_simple.
if hasattr(__ssl, 'sslwrap_simple'):
    def sslwrap_simple(sock, keyfile=None, certfile=None):
        """A replacement for the old socket.ssl function.  Designed
        for compatibility with Python 2.5 and earlier.  Will disappear in
        Python 3.0.

        Returns a client-side GreenSSLSocket around ``sock`` with no
        certificate requirement and no CA certificates.
        """
        return GreenSSLSocket(
            sock,
            keyfile=keyfile,
            certfile=certfile,
            server_side=False,
            cert_reqs=CERT_NONE,
            ssl_version=PROTOCOL_SSLv23,
            ca_certs=None,
        )
357
358
# When the ssl module has SSLContext, make SSLContext.wrap_socket() produce
# green sockets by replacing the method on the class itself.
if hasattr(__ssl, 'SSLContext'):
    @functools.wraps(__ssl.SSLContext.wrap_socket)
    def _green_sslcontext_wrap_socket(self, sock, *a, **kw):
        # The context is handed over through the ``_context`` keyword.
        # NOTE(review): positional extras follow the GreenSSLSocket
        # constructor order, which may differ from SSLContext.wrap_socket's
        # own parameter order -- confirm callers use keywords.
        green_sock = GreenSSLSocket(sock, *a, _context=self, **kw)
        return green_sock

    # FIXME:
    # * GreenSSLContext akin to GreenSSLSocket
    # * make ssl.create_default_context() use modified SSLContext from globals as usual
    __ssl.SSLContext.wrap_socket = _green_sslcontext_wrap_socket