"""
Asynchronous client DNS
API Stability: Unstable
Future plans: Proper nameserver acquisition on Windows/MacOS,
better caching, respect timeouts
@author: U{Jp Calderone<mailto:exarkun@twistedmatrix.com>}
"""
from __future__ import nested_scopes
import socket
import os
import errno
import time
from twisted.python.runtime import platform
from twisted.internet import defer, protocol, interfaces, threads
from twisted.python import log, failure
from twisted.protocols import dns
import common
class Resolver(common.ResolverBase):
__implements__ = (interfaces.IResolver,)
index = 0
timeout = None
factory = None
servers = None
dynServers = ()
pending = None
protocol = None
connections = None
resolv = None
_lastResolvTime = None
_resolvReadInterval = 60
def __init__(self, resolv = None, servers = None, timeout = (1, 3, 11, 45)):
"""
@type servers: C{list} of C{(str, int)} or C{None}
@param servers: If not None, interpreted as a list of addresses of
domain name servers to attempt to use for this lookup. Addresses
should be in dotted-quad form. If specified, overrides C{resolv}.
@type resolv: C{str}
@param resolv: Filename to read and parse as a resolver(5)
configuration file.
@type timeout: Sequence of C{int}
@param timeout: Default number of seconds after which to reissue the query.
When the last timeout expires, the query is considered failed.
@raise ValueError: Raised if no nameserver addresses can be found.
"""
common.ResolverBase.__init__(self)
self.timeout = timeout
if servers is None:
self.servers = []
else:
self.servers = servers
self.resolv = resolv
if not len(self.servers) and not resolv:
raise ValueError, "No nameservers specified"
self.factory = DNSClientFactory(self, timeout)
self.factory.noisy = 0
self.protocol = dns.DNSDatagramProtocol(self)
self.protocol.noisy = 0
self.connections = []
self.pending = []
self.maybeParseConfig()
def __getstate__(self):
d = self.__dict__.copy()
d['connections'] = []
d['_parseCall'] = None
return d
def __setstate__(self, state):
self.__dict__.update(state)
self.maybeParseConfig()
def maybeParseConfig(self):
if self.resolv is None:
return
try:
resolvConf = file(self.resolv)
except IOError, e:
if e.errno == errno.ENOENT:
pass
else:
raise
else:
mtime = os.fstat(resolvConf.fileno()).st_mtime
if mtime != self._lastResolvTime:
log.msg('%s changed, reparsing' % (self.resolv,))
self._lastResolvTime = mtime
self.parseConfig(resolvConf)
from twisted.internet import reactor
self._parseCall = reactor.callLater(self._resolvReadInterval, self.maybeParseConfig)
def parseConfig(self, resolvConf):
servers = []
for L in resolvConf:
L = L.strip()
if L.startswith('nameserver'):
resolver = (L.split()[1], dns.PORT)
servers.append(resolver)
log.msg("Resolver added %r to server list" % (resolver,))
elif L.startswith('domain'):
try:
self.domain = L.split()[1]
except IndexError:
self.domain = ''
self.search = None
elif L.startswith('search'):
try:
self.search = L.split()[1:]
except IndexError:
self.search = ''
self.domain = None
self.dynServers = servers
def pickServer(self):
"""
Return the address of a nameserver.
TODO: Weight servers for response time so faster ones can be
preferred.
"""
if not self.servers and not self.dynServers:
return None
serverL = len(self.servers)
dynL = len(self.dynServers)
self.index += 1
self.index %= (serverL + dynL)
if self.index < serverL:
return self.servers[self.index]
else:
return self.dynServers[self.index - serverL]
def connectionMade(self, protocol):
self.connections.append(protocol)
for (d, q, t) in self.pending:
self.queryTCP(q, t).chainDeferred(d)
del self.pending[:]
def messageReceived(self, message, protocol, address = None):
log.msg("Unexpected message (%d) received from %r" % (message.id, address))
def queryUDP(self, queries, timeout = None):
"""
Make a number of DNS queries via UDP.
@type queries: A C{list} of C{dns.Query} instances
@param queries: The queries to make.
@type timeout: Sequence of C{int}
@param timeout: Number of seconds after which to reissue the query.
When the last timeout expires, the query is considered failed.
@rtype: C{Deferred}
@raise C{twisted.internet.defer.TimeoutError}: When the query times
out.
"""
if timeout is None:
timeout = self.timeout
address = self.pickServer()
if address is None:
return defer.fail(IOError("No domain name servers available"))
return self.protocol.query(address, queries, timeout[0]
).addErrback(self._reissue, address, queries, timeout[1:]
)
def _reissue(self, reason, address, query, timeout):
reason.trap(defer.TimeoutError)
if timeout and self.protocol.transport:
d = self.protocol.query(address, query, timeout[0], reason.value.id)
d.addErrback(self._reissue, address, query, timeout[1:])
return d
try:
del self.protocol.resends[reason.value.id]
except:
pass
return failure.Failure(defer.TimeoutError(query))
def queryTCP(self, queries, timeout = 10):
"""
Make a number of DNS queries via TCP.
@type queries: Any non-zero number of C{dns.Query} instances
@param queries: The queries to make.
@type timeout: C{int}
@param timeout: The number of seconds after which to fail.
@rtype: C{Deferred}
"""
if not len(self.connections):
address = self.pickServer()
if address is None:
return defer.fail(IOError("No domain name servers available"))
host, port = address
from twisted.internet import reactor
reactor.connectTCP(host, port, self.factory)
self.pending.append((defer.Deferred(), queries, timeout))
return self.pending[-1][0]
else:
return self.connections[0].query(queries, timeout)
def filterAnswers(self, message):
if message.trunc:
return self.queryTCP(message.queries).addCallback(self.filterAnswers)
else:
return (message.answers, message.authority, message.additional)
def _lookup(self, name, cls, type, timeout):
return self.queryUDP(
[dns.Query(name, type, cls)], timeout
).addCallback(self.filterAnswers)
def lookupZone(self, name, timeout = 10):
"""
Perform an AXFR request. This is quite different from usual
DNS requests. See http://cr.yp.to/djbdns/axfr-notes.html for
more information.
"""
address = self.pickServer()
if address is None:
return defer.fail(IOError('No domain name servers available'))
host,port = address
from twisted.internet import reactor
d = defer.Deferred()
d.setTimeout(timeout or 10)
controller = AXFRController(name, d)
factory = DNSClientFactory(controller, timeout)
factory.noisy = False reactor.connectTCP(host, port, factory)
return d.addCallback(lambda x: (x, [], []))
class AXFRController:
def __init__(self, name, deferred):
self.name = name
self.deferred = deferred
self.soa = None
self.records = []
def connectionMade(self, protocol):
message = dns.Message(protocol.pickID(), recDes=0)
message.queries = [dns.Query(self.name, dns.AXFR, dns.IN)]
protocol.writeMessage(message)
def messageReceived(self, message, protocol):
self.records.extend(message.answers)
if not self.records:
return
if not self.soa:
if self.records[0].type == dns.SOA:
self.soa = self.records[0]
if len(self.records) > 1 and self.records[-1].type == dns.SOA:
self.deferred.callback(self.records)
class ThreadedResolver:
__implements__ = (interfaces.IResolverSimple,)
def __init__(self):
self.cache = {}
def getHostByName(self, name, timeout = 10):
d = threads.deferToThread(socket.gethostbyname, name)
d.setTimeout(timeout)
return d
class DNSClientFactory(protocol.ClientFactory):
def __init__(self, controller, timeout = 10):
self.controller = controller
self.timeout = timeout
def clientConnectionLost(self, connector, reason):
pass
def buildProtocol(self, addr):
p = dns.DNSProtocol(self.controller)
p.factory = self
return p
def createResolver(servers = None, resolvconf = None, hosts = None):
from twisted.names import resolve, cache, root, hosts as hostsModule
if platform.getType() == 'posix':
if resolvconf is None:
resolvconf = '/etc/resolv.conf'
if hosts is None:
hosts = '/etc/hosts'
theResolver = Resolver(resolvconf, servers)
hostResolver = hostsModule.Resolver(hosts)
else:
if hosts is None:
hosts = r'c:\windows\hosts'
bootstrap = ThreadedResolver()
hostResolver = hostsModule.Resolver(hosts)
theResolver = root.bootstrap(bootstrap)
L = [hostResolver, cache.CacheResolver(), theResolver]
return resolve.ResolverChain(L)
theResolver = None
def _makeLookup(method):
def lookup(*a, **kw):
global theResolver
if theResolver is None:
try:
theResolver = createResolver()
except ValueError:
theResolver = createResolver(servers=[('127.0.0.1', 53)])
return getattr(theResolver, method)(*a, **kw)
return lookup
for method in common.typeToMethod.values():
globals()[method] = _makeLookup(method)
del method
getHostByName = _makeLookup('getHostByName')