Add python-eventlet 0.16.1
[packages/trusty/python-eventlet.git] / eventlet / eventlet / support / greendns.py
1 #!/usr/bin/env python
2 '''
3     greendns - non-blocking DNS support for Eventlet
4 '''
5
6 # Portions of this code taken from the gogreen project:
7 #   http://github.com/slideinc/gogreen
8 #
9 # Copyright (c) 2005-2010 Slide, Inc.
10 # All rights reserved.
11 #
12 # Redistribution and use in source and binary forms, with or without
13 # modification, are permitted provided that the following conditions are
14 # met:
15 #
16 #     * Redistributions of source code must retain the above copyright
17 #       notice, this list of conditions and the following disclaimer.
18 #     * Redistributions in binary form must reproduce the above
19 #       copyright notice, this list of conditions and the following
20 #       disclaimer in the documentation and/or other materials provided
21 #       with the distribution.
22 #     * Neither the name of the author nor the names of other
23 #       contributors may be used to endorse or promote products derived
24 #       from this software without specific prior written permission.
25 #
26 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
27 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
28 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
29 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
30 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
31 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
32 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
33 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
34 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
35 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 import struct
38 import sys
39
40 from eventlet import patcher
41 from eventlet.green import _socket_nodns
42 from eventlet.green import time
43 from eventlet.green import select
44
45 dns = patcher.import_patched('dns',
46                              socket=_socket_nodns,
47                              time=time,
48                              select=select)
49 for pkg in ('dns.query', 'dns.exception', 'dns.inet', 'dns.message',
50             'dns.rdatatype', 'dns.resolver', 'dns.reversename'):
51     setattr(dns, pkg.split('.')[1], patcher.import_patched(
52         pkg,
53         socket=_socket_nodns,
54         time=time,
55         select=select))
56
57 socket = _socket_nodns
58
59 DNS_QUERY_TIMEOUT = 10.0
60
61
62 #
63 # Resolver instance used to perfrom DNS lookups.
64 #
65 class FakeAnswer(list):
66     expiration = 0
67
68
69 class FakeRecord(object):
70     pass
71
72
73 class ResolverProxy(object):
74     def __init__(self, *args, **kwargs):
75         self._resolver = None
76         self._filename = kwargs.get('filename', '/etc/resolv.conf')
77         self._hosts = {}
78         if kwargs.pop('dev', False):
79             self._load_etc_hosts()
80
81     def _load_etc_hosts(self):
82         try:
83             fd = open('/etc/hosts', 'r')
84             contents = fd.read()
85             fd.close()
86         except (IOError, OSError):
87             return
88         contents = [line for line in contents.split('\n') if line and not line[0] == '#']
89         for line in contents:
90             line = line.replace('\t', ' ')
91             parts = line.split(' ')
92             parts = [p for p in parts if p]
93             if not len(parts):
94                 continue
95             ip = parts[0]
96             for part in parts[1:]:
97                 self._hosts[part] = ip
98
99     def clear(self):
100         self._resolver = None
101
102     def query(self, *args, **kwargs):
103         if self._resolver is None:
104             self._resolver = dns.resolver.Resolver(filename=self._filename)
105             self._resolver.cache = dns.resolver.Cache()
106
107         query = args[0]
108         if query is None:
109             args = list(args)
110             query = args[0] = '0.0.0.0'
111         if self._hosts and self._hosts.get(query):
112             answer = FakeAnswer()
113             record = FakeRecord()
114             setattr(record, 'address', self._hosts[query])
115             answer.append(record)
116             return answer
117         return self._resolver.query(*args, **kwargs)
118 #
119 # cache
120 #
121 resolver = ResolverProxy(dev=True)
122
123
124 def resolve(name):
125     error = None
126     rrset = None
127
128     if rrset is None or time.time() > rrset.expiration:
129         try:
130             rrset = resolver.query(name)
131         except dns.exception.Timeout:
132             error = (socket.EAI_AGAIN, 'Lookup timed out')
133         except dns.exception.DNSException:
134             error = (socket.EAI_NODATA, 'No address associated with hostname')
135         else:
136             pass
137             # responses.insert(name, rrset)
138
139     if error:
140         if rrset is None:
141             raise socket.gaierror(error)
142         else:
143             sys.stderr.write('DNS error: %r %r\n' % (name, error))
144     return rrset
145
146
147 #
148 # methods
149 #
150 def getaliases(host):
151     """Checks for aliases of the given hostname (cname records)
152     returns a list of alias targets
153     will return an empty list if no aliases
154     """
155     cnames = []
156     error = None
157
158     try:
159         answers = dns.resolver.query(host, 'cname')
160     except dns.exception.Timeout:
161         error = (socket.EAI_AGAIN, 'Lookup timed out')
162     except dns.exception.DNSException:
163         error = (socket.EAI_NODATA, 'No address associated with hostname')
164     else:
165         for record in answers:
166             cnames.append(str(answers[0].target))
167
168     if error:
169         sys.stderr.write('DNS error: %r %r\n' % (host, error))
170
171     return cnames
172
173
174 def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0):
175     """Replacement for Python's socket.getaddrinfo.
176
177     Currently only supports IPv4.  At present, flags are not
178     implemented.
179     """
180     socktype = socktype or socket.SOCK_STREAM
181
182     if is_ipv4_addr(host):
183         return [(socket.AF_INET, socktype, proto, '', (host, port))]
184
185     rrset = resolve(host)
186     value = []
187
188     for rr in rrset:
189         value.append((socket.AF_INET, socktype, proto, '', (rr.address, port)))
190     return value
191
192
193 def gethostbyname(hostname):
194     """Replacement for Python's socket.gethostbyname.
195
196     Currently only supports IPv4.
197     """
198     if is_ipv4_addr(hostname):
199         return hostname
200
201     rrset = resolve(hostname)
202     return rrset[0].address
203
204
205 def gethostbyname_ex(hostname):
206     """Replacement for Python's socket.gethostbyname_ex.
207
208     Currently only supports IPv4.
209     """
210     if is_ipv4_addr(hostname):
211         return (hostname, [], [hostname])
212
213     rrset = resolve(hostname)
214     addrs = []
215
216     for rr in rrset:
217         addrs.append(rr.address)
218     return (hostname, [], addrs)
219
220
221 def getnameinfo(sockaddr, flags):
222     """Replacement for Python's socket.getnameinfo.
223
224     Currently only supports IPv4.
225     """
226     try:
227         host, port = sockaddr
228     except (ValueError, TypeError):
229         if not isinstance(sockaddr, tuple):
230             del sockaddr  # to pass a stdlib test that is
231             # hyper-careful about reference counts
232             raise TypeError('getnameinfo() argument 1 must be a tuple')
233         else:
234             # must be ipv6 sockaddr, pretending we don't know how to resolve it
235             raise socket.gaierror(-2, 'name or service not known')
236
237     if (flags & socket.NI_NAMEREQD) and (flags & socket.NI_NUMERICHOST):
238         # Conflicting flags.  Punt.
239         raise socket.gaierror(
240             (socket.EAI_NONAME, 'Name or service not known'))
241
242     if is_ipv4_addr(host):
243         try:
244             rrset = resolver.query(
245                 dns.reversename.from_address(host), dns.rdatatype.PTR)
246             if len(rrset) > 1:
247                 raise socket.error('sockaddr resolved to multiple addresses')
248             host = rrset[0].target.to_text(omit_final_dot=True)
249         except dns.exception.Timeout:
250             if flags & socket.NI_NAMEREQD:
251                 raise socket.gaierror((socket.EAI_AGAIN, 'Lookup timed out'))
252         except dns.exception.DNSException:
253             if flags & socket.NI_NAMEREQD:
254                 raise socket.gaierror(
255                     (socket.EAI_NONAME, 'Name or service not known'))
256     else:
257         try:
258             rrset = resolver.query(host)
259             if len(rrset) > 1:
260                 raise socket.error('sockaddr resolved to multiple addresses')
261             if flags & socket.NI_NUMERICHOST:
262                 host = rrset[0].address
263         except dns.exception.Timeout:
264             raise socket.gaierror((socket.EAI_AGAIN, 'Lookup timed out'))
265         except dns.exception.DNSException:
266             raise socket.gaierror(
267                 (socket.EAI_NODATA, 'No address associated with hostname'))
268
269     if not (flags & socket.NI_NUMERICSERV):
270         proto = (flags & socket.NI_DGRAM) and 'udp' or 'tcp'
271         port = socket.getservbyport(port, proto)
272
273     return (host, port)
274
275
276 def is_ipv4_addr(host):
277     """is_ipv4_addr returns true if host is a valid IPv4 address in
278     dotted quad notation.
279     """
280     try:
281         d1, d2, d3, d4 = map(int, host.split('.'))
282     except (ValueError, AttributeError):
283         return False
284
285     if 0 <= d1 <= 255 and 0 <= d2 <= 255 and 0 <= d3 <= 255 and 0 <= d4 <= 255:
286         return True
287     return False
288
289
290 def _net_read(sock, count, expiration):
291     """coro friendly replacement for dns.query._net_write
292     Read the specified number of bytes from sock.  Keep trying until we
293     either get the desired amount, or we hit EOF.
294     A Timeout exception will be raised if the operation is not completed
295     by the expiration time.
296     """
297     s = ''
298     while count > 0:
299         try:
300             n = sock.recv(count)
301         except socket.timeout:
302             # Q: Do we also need to catch coro.CoroutineSocketWake and pass?
303             if expiration - time.time() <= 0.0:
304                 raise dns.exception.Timeout
305         if n == '':
306             raise EOFError
307         count = count - len(n)
308         s = s + n
309     return s
310
311
312 def _net_write(sock, data, expiration):
313     """coro friendly replacement for dns.query._net_write
314     Write the specified data to the socket.
315     A Timeout exception will be raised if the operation is not completed
316     by the expiration time.
317     """
318     current = 0
319     l = len(data)
320     while current < l:
321         try:
322             current += sock.send(data[current:])
323         except socket.timeout:
324             # Q: Do we also need to catch coro.CoroutineSocketWake and pass?
325             if expiration - time.time() <= 0.0:
326                 raise dns.exception.Timeout
327
328
329 def udp(q, where, timeout=DNS_QUERY_TIMEOUT, port=53, af=None, source=None,
330         source_port=0, ignore_unexpected=False):
331     """coro friendly replacement for dns.query.udp
332     Return the response obtained after sending a query via UDP.
333
334     @param q: the query
335     @type q: dns.message.Message
336     @param where: where to send the message
337     @type where: string containing an IPv4 or IPv6 address
338     @param timeout: The number of seconds to wait before the query times out.
339     If None, the default, wait forever.
340     @type timeout: float
341     @param port: The port to which to send the message.  The default is 53.
342     @type port: int
343     @param af: the address family to use.  The default is None, which
344     causes the address family to use to be inferred from the form of of where.
345     If the inference attempt fails, AF_INET is used.
346     @type af: int
347     @rtype: dns.message.Message object
348     @param source: source address.  The default is the IPv4 wildcard address.
349     @type source: string
350     @param source_port: The port from which to send the message.
351     The default is 0.
352     @type source_port: int
353     @param ignore_unexpected: If True, ignore responses from unexpected
354     sources.  The default is False.
355     @type ignore_unexpected: bool"""
356
357     wire = q.to_wire()
358     if af is None:
359         try:
360             af = dns.inet.af_for_address(where)
361         except:
362             af = dns.inet.AF_INET
363     if af == dns.inet.AF_INET:
364         destination = (where, port)
365         if source is not None:
366             source = (source, source_port)
367     elif af == dns.inet.AF_INET6:
368         destination = (where, port, 0, 0)
369         if source is not None:
370             source = (source, source_port, 0, 0)
371
372     s = socket.socket(af, socket.SOCK_DGRAM)
373     s.settimeout(timeout)
374     try:
375         expiration = dns.query._compute_expiration(timeout)
376         if source is not None:
377             s.bind(source)
378         try:
379             s.sendto(wire, destination)
380         except socket.timeout:
381             # Q: Do we also need to catch coro.CoroutineSocketWake and pass?
382             if expiration - time.time() <= 0.0:
383                 raise dns.exception.Timeout
384         while 1:
385             try:
386                 (wire, from_address) = s.recvfrom(65535)
387             except socket.timeout:
388                 # Q: Do we also need to catch coro.CoroutineSocketWake and pass?
389                 if expiration - time.time() <= 0.0:
390                     raise dns.exception.Timeout
391             if from_address == destination:
392                 break
393             if not ignore_unexpected:
394                 raise dns.query.UnexpectedSource(
395                     'got a response from %s instead of %s'
396                     % (from_address, destination))
397     finally:
398         s.close()
399
400     r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac)
401     if not q.is_response(r):
402         raise dns.query.BadResponse()
403     return r
404
405
406 def tcp(q, where, timeout=DNS_QUERY_TIMEOUT, port=53,
407         af=None, source=None, source_port=0):
408     """coro friendly replacement for dns.query.tcp
409     Return the response obtained after sending a query via TCP.
410
411     @param q: the query
412     @type q: dns.message.Message object
413     @param where: where to send the message
414     @type where: string containing an IPv4 or IPv6 address
415     @param timeout: The number of seconds to wait before the query times out.
416     If None, the default, wait forever.
417     @type timeout: float
418     @param port: The port to which to send the message.  The default is 53.
419     @type port: int
420     @param af: the address family to use.  The default is None, which
421     causes the address family to use to be inferred from the form of of where.
422     If the inference attempt fails, AF_INET is used.
423     @type af: int
424     @rtype: dns.message.Message object
425     @param source: source address.  The default is the IPv4 wildcard address.
426     @type source: string
427     @param source_port: The port from which to send the message.
428     The default is 0.
429     @type source_port: int"""
430
431     wire = q.to_wire()
432     if af is None:
433         try:
434             af = dns.inet.af_for_address(where)
435         except:
436             af = dns.inet.AF_INET
437     if af == dns.inet.AF_INET:
438         destination = (where, port)
439         if source is not None:
440             source = (source, source_port)
441     elif af == dns.inet.AF_INET6:
442         destination = (where, port, 0, 0)
443         if source is not None:
444             source = (source, source_port, 0, 0)
445     s = socket.socket(af, socket.SOCK_STREAM)
446     s.settimeout(timeout)
447     try:
448         expiration = dns.query._compute_expiration(timeout)
449         if source is not None:
450             s.bind(source)
451         try:
452             s.connect(destination)
453         except socket.timeout:
454             # Q: Do we also need to catch coro.CoroutineSocketWake and pass?
455             if expiration - time.time() <= 0.0:
456                 raise dns.exception.Timeout
457
458         l = len(wire)
459         # copying the wire into tcpmsg is inefficient, but lets us
460         # avoid writev() or doing a short write that would get pushed
461         # onto the net
462         tcpmsg = struct.pack("!H", l) + wire
463         _net_write(s, tcpmsg, expiration)
464         ldata = _net_read(s, 2, expiration)
465         (l,) = struct.unpack("!H", ldata)
466         wire = _net_read(s, l, expiration)
467     finally:
468         s.close()
469     r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac)
470     if not q.is_response(r):
471         raise dns.query.BadResponse()
472     return r
473
474
475 def reset():
476     resolver.clear()
477
478 # Install our coro-friendly replacements for the tcp and udp query methods.
479 dns.query.tcp = tcp
480 dns.query.udp = udp