# default.py   [plain text]


# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2004 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
from twisted.conch.error import ConchError
from twisted.conch.ssh import common, keys, userauth, agent
from twisted.internet import defer, protocol, reactor
from twisted.python import log

import agent

import os, sys, base64, getpass

def verifyHostKey(transport, host, pubKey, fingerprint):
    goodKey = isInKnownHosts(host, pubKey)
    if goodKey == 1: # good key
        return defer.succeed(1)
    elif goodKey == 2: # AAHHHHH changed
        return defer.fail(ConchError('changed host key'))
    else:
        oldout, oldin = sys.stdout, sys.stdin
        sys.stdin = sys.stdout = open('/dev/tty','r+')
        if host == transport.transport.getPeer().host:
            khHost = host
        else:
            host = '%s (%s)' % (host,
                                transport.transport.getPeer().host)
            khHost = '%s,%s' % (host,
                                transport.transport.getPeer().host)
        keyType = common.getNS(pubKey)[0]
        print """The authenticity of host '%s' can't be extablished.
%s key fingerprint is %s.""" % (host,
                            {'ssh-dss':'DSA', 'ssh-rsa':'RSA'}[keyType],
                            fingerprint)
        try:
            ans = raw_input('Are you sure you want to continue connecting (yes/no)? ')
        except KeyboardInterrupt:
            return defer.fail(ConchError("^C"))
        while ans.lower() not in ('yes', 'no'):
            ans = raw_input("Please type 'yes' or 'no': ")
        sys.stdout,sys.stdin=oldout,oldin
        if ans == 'no':
            print 'Host key verification failed.'
            return defer.fail(ConchError('bad host key'))
        print "Warning: Permanently added '%s' (%s) to the list of known hosts." % (khHost, {'ssh-dss':'DSA', 'ssh-rsa':'RSA'}[keyType])
        known_hosts = open(os.path.expanduser('~/.ssh/known_hosts'), 'a')
        encodedKey = base64.encodestring(pubKey).replace('\n', '')
        known_hosts.write('\n%s %s %s' % (khHost, keyType, encodedKey))
        known_hosts.close()
        return defer.succeed(1)

def isInKnownHosts(host, pubKey):
    """checks to see if host is in the known_hosts file for the user.
    returns 0 if it isn't, 1 if it is and is the same, 2 if it's changed.
    """
    keyType = common.getNS(pubKey)[0]
    retVal = 0
    if not os.path.exists(os.path.expanduser('~/.ssh/')):
        print 'Creating ~/.ssh directory...'
        os.mkdir(os.path.expanduser('~/.ssh'))
    try:
        known_hosts = open(os.path.expanduser('~/.ssh/known_hosts'))
    except IOError:
        return 0
    for line in known_hosts.xreadlines():
        split = line.split()
        if len(split) != 3: # old 4-field known_hosts entry (ssh1?)
            continue
        hosts, hostKeyType, encodedKey = split
        if host not in hosts.split(','): # incorrect host
            continue
        if hostKeyType != keyType: # incorrect type of key
            continue
        try:
            decodedKey = base64.decodestring(encodedKey)
        except:
            continue
        if decodedKey == pubKey:
            return 1
        else:
            retVal = 2
    return retVal

class SSHUserAuthClient(userauth.SSHUserAuthClient):
    usedFiles = []

    def __init__(self, user, options, *args):
        userauth.SSHUserAuthClient.__init__(self, user, *args)
        self.keyAgent = None
        self.options = options

    def serviceStarted(self):
        if 'SSH_AUTH_SOCK' in os.environ:
            log.msg('using agent')
            cc = protocol.ClientCreator(reactor, agent.SSHAgentClient)
            d = cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
            d.addCallback(self._setAgent)
            d.addErrback(self._ebSetAgent)
        else:
            userauth.SSHUserAuthClient.serviceStarted(self)

    def _setAgent(self, a):
        self.keyAgent = a
        d = self.keyAgent.getPublicKeys()
        d.addBoth(self._ebSetAgent)
        return d

    def _ebSetAgent(self, f):
        userauth.SSHUserAuthClient.serviceStarted(self)

    def getPassword(self, prompt = None):
        if not prompt:
            prompt = "%s@%s's password: " % (self.user, self.transport.transport.getPeer().host)
        try:
            oldout, oldin = sys.stdout, sys.stdin
            sys.stdin = sys.stdout = open('/dev/tty','r+')
            p=getpass.getpass(prompt)
            sys.stdout,sys.stdin=oldout,oldin
            return defer.succeed(p)
        except KeyboardInterrupt, e:
            print
            return defer.fail(ConchError('PEBKAC'))

    def getPublicKey(self):
        if self.keyAgent:
            blob = self.keyAgent.getPublicKey()
            if blob:
                return blob
        files = [x for x in self.options.identitys if x not in self.usedFiles]
        log.msg(str(self.options.identitys))
        log.msg(str(files))
        if not files:
            return None
        file = files[0]
        log.msg(file)
        self.usedFiles.append(file)
        file = os.path.expanduser(file)
        file += '.pub'
        if not os.path.exists(file):
            return self.getPublicKey() # try again
        try:
            return keys.getPublicKeyString(file)
        except:
            return self.getPublicKey() # try again

    def signData(self, publicKey, signData):
        if not self.usedFiles: # agent key
            return self.keyAgent.signData(publicKey, signData)
        else:
            return userauth.SSHUserAuthClient.signData(self, publicKey, signData)

    def getPrivateKey(self):
        file = os.path.expanduser(self.usedFiles[-1])
        if not os.path.exists(file):
            return None
        try:
            return defer.succeed(keys.getPrivateKeyObject(file))
        except keys.BadKeyError, e:
            if e.args[0] == 'encrypted key with no passphrase':
                for i in range(3):
                    prompt = "Enter passphrase for key '%s': " % \
                           self.usedFiles[-1]
                    oldout, oldin = sys.stdout, sys.stdin
                    sys.stdin = sys.stdout = open('/dev/tty','r+')
                    p=getpass.getpass(prompt)
                    sys.stdout,sys.stdin=oldout,oldin
                    try:
                        return defer.succeed(keys.getPrivateKeyObject(file, passphrase = p))
                    except keys.BadKeyError:
                        pass
                return defer.fail(ConchError('bad password'))
            raise
        except KeyboardInterrupt:
            print
            reactor.stop()