eb09f9ad6c44509c8fdc8c423d9030ffa92d10b4
[packages/trusty/python-eventlet.git] / python-eventlet / eventlet / patcher.py
1 import imp
2 import sys
3
4 from eventlet.support import six
5
6
7 __all__ = ['inject', 'import_patched', 'monkey_patch', 'is_monkey_patched']
8
9 __exclude = set(('__builtins__', '__file__', '__name__'))
10
11
12 class SysModulesSaver(object):
13     """Class that captures some subset of the current state of
14     sys.modules.  Pass in an iterator of module names to the
15     constructor."""
16
17     def __init__(self, module_names=()):
18         self._saved = {}
19         imp.acquire_lock()
20         self.save(*module_names)
21
22     def save(self, *module_names):
23         """Saves the named modules to the object."""
24         for modname in module_names:
25             self._saved[modname] = sys.modules.get(modname, None)
26
27     def restore(self):
28         """Restores the modules that the saver knows about into
29         sys.modules.
30         """
31         try:
32             for modname, mod in six.iteritems(self._saved):
33                 if mod is not None:
34                     sys.modules[modname] = mod
35                 else:
36                     try:
37                         del sys.modules[modname]
38                     except KeyError:
39                         pass
40         finally:
41             imp.release_lock()
42
43
44 def inject(module_name, new_globals, *additional_modules):
45     """Base method for "injecting" greened modules into an imported module.  It
46     imports the module specified in *module_name*, arranging things so
47     that the already-imported modules in *additional_modules* are used when
48     *module_name* makes its imports.
49
50     **Note:** This function does not create or change any sys.modules item, so
51     if your greened module use code like 'sys.modules["your_module_name"]', you
52     need to update sys.modules by yourself.
53
54     *new_globals* is either None or a globals dictionary that gets populated
55     with the contents of the *module_name* module.  This is useful when creating
56     a "green" version of some other module.
57
58     *additional_modules* should be a collection of two-element tuples, of the
59     form (<name>, <module>).  If it's not specified, a default selection of
60     name/module pairs is used, which should cover all use cases but may be
61     slower because there are inevitably redundant or unnecessary imports.
62     """
63     patched_name = '__patched_module_' + module_name
64     if patched_name in sys.modules:
65         # returning already-patched module so as not to destroy existing
66         # references to patched modules
67         return sys.modules[patched_name]
68
69     if not additional_modules:
70         # supply some defaults
71         additional_modules = (
72             _green_os_modules() +
73             _green_select_modules() +
74             _green_socket_modules() +
75             _green_thread_modules() +
76             _green_time_modules())
77         # _green_MySQLdb()) # enable this after a short baking-in period
78
79     # after this we are gonna screw with sys.modules, so capture the
80     # state of all the modules we're going to mess with, and lock
81     saver = SysModulesSaver([name for name, m in additional_modules])
82     saver.save(module_name)
83
84     # Cover the target modules so that when you import the module it
85     # sees only the patched versions
86     for name, mod in additional_modules:
87         sys.modules[name] = mod
88
89     # Remove the old module from sys.modules and reimport it while
90     # the specified modules are in place
91     sys.modules.pop(module_name, None)
92     try:
93         module = __import__(module_name, {}, {}, module_name.split('.')[:-1])
94
95         if new_globals is not None:
96             # Update the given globals dictionary with everything from this new module
97             for name in dir(module):
98                 if name not in __exclude:
99                     new_globals[name] = getattr(module, name)
100
101         # Keep a reference to the new module to prevent it from dying
102         sys.modules[patched_name] = module
103     finally:
104         saver.restore()  # Put the original modules back
105
106     return module
107
108
109 def import_patched(module_name, *additional_modules, **kw_additional_modules):
110     """Imports a module in a way that ensures that the module uses "green"
111     versions of the standard library modules, so that everything works
112     nonblockingly.
113
114     The only required argument is the name of the module to be imported.
115     """
116     return inject(
117         module_name,
118         None,
119         *additional_modules + tuple(kw_additional_modules.items()))
120
121
122 def patch_function(func, *additional_modules):
123     """Decorator that returns a version of the function that patches
124     some modules for the duration of the function call.  This is
125     deeply gross and should only be used for functions that import
126     network libraries within their function bodies that there is no
127     way of getting around."""
128     if not additional_modules:
129         # supply some defaults
130         additional_modules = (
131             _green_os_modules() +
132             _green_select_modules() +
133             _green_socket_modules() +
134             _green_thread_modules() +
135             _green_time_modules())
136
137     def patched(*args, **kw):
138         saver = SysModulesSaver()
139         for name, mod in additional_modules:
140             saver.save(name)
141             sys.modules[name] = mod
142         try:
143             return func(*args, **kw)
144         finally:
145             saver.restore()
146     return patched
147
148
149 def _original_patch_function(func, *module_names):
150     """Kind of the contrapositive of patch_function: decorates a
151     function such that when it's called, sys.modules is populated only
152     with the unpatched versions of the specified modules.  Unlike
153     patch_function, only the names of the modules need be supplied,
154     and there are no defaults.  This is a gross hack; tell your kids not
155     to import inside function bodies!"""
156     def patched(*args, **kw):
157         saver = SysModulesSaver(module_names)
158         for name in module_names:
159             sys.modules[name] = original(name)
160         try:
161             return func(*args, **kw)
162         finally:
163             saver.restore()
164     return patched
165
166
167 def original(modname):
168     """ This returns an unpatched version of a module; this is useful for
169     Eventlet itself (i.e. tpool)."""
170     # note that it's not necessary to temporarily install unpatched
171     # versions of all patchable modules during the import of the
172     # module; this is because none of them import each other, except
173     # for threading which imports thread
174     original_name = '__original_module_' + modname
175     if original_name in sys.modules:
176         return sys.modules.get(original_name)
177
178     # re-import the "pure" module and store it in the global _originals
179     # dict; be sure to restore whatever module had that name already
180     saver = SysModulesSaver((modname,))
181     sys.modules.pop(modname, None)
182     # some rudimentary dependency checking -- fortunately the modules
183     # we're working on don't have many dependencies so we can just do
184     # some special-casing here
185     if six.PY2:
186         deps = {'threading': 'thread', 'Queue': 'threading'}
187     if six.PY3:
188         deps = {'threading': '_thread', 'queue': 'threading'}
189     if modname in deps:
190         dependency = deps[modname]
191         saver.save(dependency)
192         sys.modules[dependency] = original(dependency)
193     try:
194         real_mod = __import__(modname, {}, {}, modname.split('.')[:-1])
195         if modname in ('Queue', 'queue') and not hasattr(real_mod, '_threading'):
196             # tricky hack: Queue's constructor in <2.7 imports
197             # threading on every instantiation; therefore we wrap
198             # it so that it always gets the original threading
199             real_mod.Queue.__init__ = _original_patch_function(
200                 real_mod.Queue.__init__,
201                 'threading')
202         # save a reference to the unpatched module so it doesn't get lost
203         sys.modules[original_name] = real_mod
204     finally:
205         saver.restore()
206
207     return sys.modules[original_name]
208
209 already_patched = {}
210
211
212 def monkey_patch(**on):
213     """Globally patches certain system modules to be greenthread-friendly.
214
215     The keyword arguments afford some control over which modules are patched.
216     If no keyword arguments are supplied, all possible modules are patched.
217     If keywords are set to True, only the specified modules are patched.  E.g.,
218     ``monkey_patch(socket=True, select=True)`` patches only the select and
219     socket modules.  Most arguments patch the single module of the same name
220     (os, time, select).  The exceptions are socket, which also patches the ssl
221     module if present; and thread, which patches thread, threading, and Queue.
222
223     It's safe to call monkey_patch multiple times.
224     """
225     accepted_args = set(('os', 'select', 'socket',
226                          'thread', 'time', 'psycopg', 'MySQLdb',
227                          'builtins'))
228     # To make sure only one of them is passed here
229     assert not ('__builtin__' in on and 'builtins' in on)
230     try:
231         b = on.pop('__builtin__')
232     except KeyError:
233         pass
234     else:
235         on['builtins'] = b
236
237     default_on = on.pop("all", None)
238
239     for k in six.iterkeys(on):
240         if k not in accepted_args:
241             raise TypeError("monkey_patch() got an unexpected "
242                             "keyword argument %r" % k)
243     if default_on is None:
244         default_on = not (True in on.values())
245     for modname in accepted_args:
246         if modname == 'MySQLdb':
247             # MySQLdb is only on when explicitly patched for the moment
248             on.setdefault(modname, False)
249         if modname == 'builtins':
250             on.setdefault(modname, False)
251         on.setdefault(modname, default_on)
252
253     modules_to_patch = []
254     if on['os'] and not already_patched.get('os'):
255         modules_to_patch += _green_os_modules()
256         already_patched['os'] = True
257     if on['select'] and not already_patched.get('select'):
258         modules_to_patch += _green_select_modules()
259         already_patched['select'] = True
260     if on['socket'] and not already_patched.get('socket'):
261         modules_to_patch += _green_socket_modules()
262         already_patched['socket'] = True
263     if on['thread'] and not already_patched.get('thread'):
264         modules_to_patch += _green_thread_modules()
265         already_patched['thread'] = True
266     if on['time'] and not already_patched.get('time'):
267         modules_to_patch += _green_time_modules()
268         already_patched['time'] = True
269     if on.get('MySQLdb') and not already_patched.get('MySQLdb'):
270         modules_to_patch += _green_MySQLdb()
271         already_patched['MySQLdb'] = True
272     if on.get('builtins') and not already_patched.get('builtins'):
273         modules_to_patch += _green_builtins()
274         already_patched['builtins'] = True
275     if on['psycopg'] and not already_patched.get('psycopg'):
276         try:
277             from eventlet.support import psycopg2_patcher
278             psycopg2_patcher.make_psycopg_green()
279             already_patched['psycopg'] = True
280         except ImportError:
281             # note that if we get an importerror from trying to
282             # monkeypatch psycopg, we will continually retry it
283             # whenever monkey_patch is called; this should not be a
284             # performance problem but it allows is_monkey_patched to
285             # tell us whether or not we succeeded
286             pass
287
288     imp.acquire_lock()
289     try:
290         for name, mod in modules_to_patch:
291             orig_mod = sys.modules.get(name)
292             if orig_mod is None:
293                 orig_mod = __import__(name)
294             for attr_name in mod.__patched__:
295                 patched_attr = getattr(mod, attr_name, None)
296                 if patched_attr is not None:
297                     setattr(orig_mod, attr_name, patched_attr)
298     finally:
299         imp.release_lock()
300
301     if sys.version_info >= (3, 3):
302         import importlib._bootstrap
303         thread = original('_thread')
304         # importlib must use real thread locks, not eventlet.Semaphore
305         importlib._bootstrap._thread = thread
306
307         # Issue #185: Since Python 3.3, threading.RLock is implemented in C and
308         # so call a C function to get the thread identifier, instead of calling
309         # threading.get_ident(). Force the Python implementation of RLock which
310         # calls threading.get_ident() and so is compatible with eventlet.
311         import threading
312         threading.RLock = threading._PyRLock
313
314
315 def is_monkey_patched(module):
316     """Returns True if the given module is monkeypatched currently, False if
317     not.  *module* can be either the module itself or its name.
318
319     Based entirely off the name of the module, so if you import a
320     module some other way than with the import keyword (including
321     import_patched), this might not be correct about that particular
322     module."""
323     return module in already_patched or \
324         getattr(module, '__name__', None) in already_patched
325
326
327 def _green_os_modules():
328     from eventlet.green import os
329     return [('os', os)]
330
331
332 def _green_select_modules():
333     from eventlet.green import select
334     return [('select', select)]
335
336
337 def _green_socket_modules():
338     from eventlet.green import socket
339     try:
340         from eventlet.green import ssl
341         return [('socket', socket), ('ssl', ssl)]
342     except ImportError:
343         return [('socket', socket)]
344
345
346 def _green_thread_modules():
347     from eventlet.green import Queue
348     from eventlet.green import thread
349     from eventlet.green import threading
350     if six.PY2:
351         return [('Queue', Queue), ('thread', thread), ('threading', threading)]
352     if six.PY3:
353         return [('queue', Queue), ('_thread', thread), ('threading', threading)]
354
355
356 def _green_time_modules():
357     from eventlet.green import time
358     return [('time', time)]
359
360
361 def _green_MySQLdb():
362     try:
363         from eventlet.green import MySQLdb
364         return [('MySQLdb', MySQLdb)]
365     except ImportError:
366         return []
367
368
369 def _green_builtins():
370     try:
371         from eventlet.green import builtin
372         return [('__builtin__' if six.PY2 else 'builtins', builtin)]
373     except ImportError:
374         return []
375
376
377 def slurp_properties(source, destination, ignore=[], srckeys=None):
378     """Copy properties from *source* (assumed to be a module) to
379     *destination* (assumed to be a dict).
380
381     *ignore* lists properties that should not be thusly copied.
382     *srckeys* is a list of keys to copy, if the source's __all__ is
383     untrustworthy.
384     """
385     if srckeys is None:
386         srckeys = source.__all__
387     destination.update(dict([
388         (name, getattr(source, name))
389         for name in srckeys
390         if not (name.startswith('__') or name in ignore)
391     ]))
392
393
394 if __name__ == "__main__":
395     sys.argv.pop(0)
396     monkey_patch()
397     with open(sys.argv[0]) as f:
398         code = compile(f.read(), sys.argv[0], 'exec')
399         exec(code)