# unix.py   [plain text]


# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2004 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

from twisted.conch.error import ConchError
from twisted.conch.ssh import channel, connection
from twisted.internet import defer, protocol, reactor
from twisted.python import log
from twisted.spread import banana

import os, stat, pickle
import types # this is for evil

class SSHUnixClientFactory(protocol.ClientFactory):
    """
    Client factory used when connecting to a cached connection's UNIX socket.

    Before the connection proceeds, the socket file's mode, owner and group
    are checked (see L{startedConnecting}); if the attempt fails, the
    Deferred given to the constructor is errbacked once.
    """
    noisy = 1

    def __init__(self, d, options, userAuthObject):
        """
        @param d: Deferred that is errbacked (at most once) when the
            connection attempt fails.
        @param options: stored on the factory; not read by this class.
        @param userAuthObject: object whose C{instance} attribute is turned
            into the protocol by L{buildProtocol}.
        """
        self.d = d
        self.options = options
        self.userAuthObject = userAuthObject

#    def clientConnectionLost(self, connector, reason):
#        stopConnection()

    def clientConnectionFailed(self, connector, reason):
        """
        Remove the socket file (best effort) and errback the stored Deferred.
        """
        # Deliberately best-effort: failing to remove the socket file must not
        # prevent the errback below.  Only ordinary exceptions are swallowed.
        try:
            os.unlink(connector.transport.addr)
        except Exception:
            pass
        if not self.d: return
        # Clear self.d first so the Deferred can never be fired twice.
        d = self.d
        self.d = None
        d.errback(reason)
        #reactor.connectTCP(options['host'], options['port'], SSHClientFactory())

    def startedConnecting(self, connector):
        """
        Abort the connection unless the socket file is mode 0600 and is owned
        by the current uid and gid.
        """
        try:
            filestats = os.stat(connector.transport.addr)
        except Exception:
            # Cannot inspect the socket file: refuse to use it.
            connector.stopConnecting()
            return
        mode = stat.S_IMODE(filestats[stat.ST_MODE])
        # S_IRUSR | S_IWUSR is rw------- (octal 600).
        if mode != (stat.S_IRUSR | stat.S_IWUSR):
            # Report the values that were actually checked (the socket file's
            # own stat fields), so the log matches the reason for rejection.
            log.msg("socket mode is not 0600: %s" % oct(mode))
        elif filestats[stat.ST_UID] != os.getuid():
            log.msg("socket not owned by us: %s" % filestats[stat.ST_UID])
        elif filestats[stat.ST_GID] != os.getgid():
            log.msg("socket not owned by our group: %s" % filestats[stat.ST_GID])
        # XXX reenable this when i can fix it for cygwin
        # (it needs stats = os.fstat(connector.transport.fileno()) again)
        #elif filestats[-3:] != stats[-3:]:
        #    log.msg("socket doesn't have same create times")
        else:
            log.msg('conecting OK')
            return
        # Any failed check above falls through to here.
        connector.stopConnecting()

    def buildProtocol(self, addr):
        """
        Turn C{userAuthObject.instance} into the protocol for this connection.

        A new class is built with the same name and attributes as the
        instance's class, but with L{SSHUnixClientProtocol} substituted for
        C{connection.SSHConnection} among its direct bases.  The instance is
        switched to that class, initialised as an L{SSHUnixClientProtocol}
        and returned.  C{addr} is not used.
        """
        # here comes the EVIL
        obj = self.userAuthObject.instance
        bases = []
        for base in obj.__class__.__bases__:
            if base == connection.SSHConnection:
                bases.append(SSHUnixClientProtocol)
            else:
                bases.append(base)
        # types.ClassType builds an old-style class (Python 2 only).
        newClass = types.ClassType(obj.__class__.__name__, tuple(bases), obj.__class__.__dict__)
        obj.__class__ = newClass
        SSHUnixClientProtocol.__init__(obj)
        log.msg('returning %s' % obj)
        return obj

class SSHUnixServerFactory(protocol.Factory):
    """
    Factory that builds one L{SSHUnixServerProtocol} per connection, each
    wrapping the same C{conn} object.
    """

    def __init__(self, conn):
        # Handed unchanged to every protocol built below.
        self.conn = conn

    def buildProtocol(self, addr):
        """
        Return a new server protocol bound to C{self.conn}; C{addr} is unused.
        """
        serverProtocol = SSHUnixServerProtocol(self.conn)
        return serverProtocol

class SSHUnixProtocol(banana.Banana):
    """
    Banana-based message protocol shared by the client and server sides.

    Every received expression is a list whose first item names the message;
    it is dispatched to the C{msg_<name>} method with the remaining items.
    The class also relays Deferred results over the wire: the side that owns
    a Deferred announces it with C{returnDeferred} and later sends
    C{callbackDeferred} or C{errbackDeferred}; the other side matches those
    against Deferreds it created with L{returnDeferredLocal}.

    NOTE(review): results and failures are unpickled with C{pickle.loads}
    from data supplied by the peer; pickle is unsafe on untrusted input.
    """

    knownDialects = ['none']

    def __init__(self):
        banana.Banana.__init__(self)
        # Counter used to number Deferreds announced by returnDeferredWire.
        self.deferredID = 0
        # Maps a peer-assigned ID to the local Deferred waiting on it.
        self.deferreds = {}
        # Local Deferreds not yet given an ID by the peer, oldest first.
        self.deferredQueue = []

    def connectionMade(self):
        log.msg('connection made %s' % self)
        banana.Banana.connectionMade(self)

    def expressionReceived(self, lst):
        """
        Dispatch a received expression to the matching C{msg_*} method.
        """
        handler = getattr(self, "msg_%s" % lst[0])
        handler(lst[1:])

    def sendMessage(self, vocabName, *tup):
        """
        Send C{vocabName} followed by the given arguments as one list.
        """
        message = [vocabName]
        message.extend(tup)
        self.sendEncoded(message)

    def returnDeferredLocal(self):
        """
        Create, queue and return a Deferred to be fired by the peer.
        """
        pending = defer.Deferred()
        self.deferredQueue.append(pending)
        return pending

    def returnDeferredWire(self, d):
        """
        Announce C{d} to the peer and forward its eventual result or failure.
        """
        wireID = self.deferredID
        self.deferredID = wireID + 1
        self.sendMessage('returnDeferred', wireID)
        d.addCallback(self._cbDeferred, wireID)
        d.addErrback(self._ebDeferred, wireID)

    def _cbDeferred(self, result, di):
        # Success path: ship the pickled result under the announced ID.
        pickled = pickle.dumps(result)
        self.sendMessage('callbackDeferred', di, pickled)

    def _ebDeferred(self, reason, di):
        # Failure path: ship the pickled failure under the announced ID.
        pickled = pickle.dumps(reason)
        self.sendMessage('errbackDeferred', di, pickled)

    def msg_returnDeferred(self, lst):
        """
        Bind the oldest queued local Deferred to the ID chosen by the peer.
        """
        wireID = lst[0]
        oldest = self.deferredQueue.pop(0)
        self.deferreds[wireID] = oldest

    def msg_callbackDeferred(self, lst):
        """
        Fire the Deferred registered under the given ID with the result.
        """
        wireID, pickled = lst
        waiting = self.deferreds.pop(wireID)
        waiting.callback(pickle.loads(pickled))

    def msg_errbackDeferred(self, lst):
        """
        Errback the Deferred registered under the given ID with the failure.
        """
        wireID, pickled = lst
        waiting = self.deferreds.pop(wireID)
        waiting.errback(pickle.loads(pickled))

class SSHUnixClientProtocol(SSHUnixProtocol):
    """
    Client side of the UNIX-socket protocol.

    Each outgoing method forwards the operation to the server as a message
    of the same purpose; the C{msg_*} methods apply what the server reports
    to the local channel objects kept in C{self.channels}.
    """

    def __init__(self):
        SSHUnixProtocol.__init__(self)
        self.isClient = 1
        # Channels passed to openChannel that have no ID yet, oldest first.
        self.channelQueue = []
        # Maps the ID reported by the server to the local channel object.
        self.channels = {}

    def connectionReady(self):
        log.msg('connection ready')
        # serviceStarted is not defined in this class; presumably it comes
        # from the class this protocol is mixed into -- TODO confirm.
        self.serviceStarted()

    def connectionLost(self, reason):
        # serviceStopped is likewise expected from the mixed-in class.
        self.serviceStopped()

    def requestRemoteForwarding(self, remotePort, hostport):
        self.sendMessage('requestRemoteForwarding', remotePort, hostport)

    def cancelRemoteForwarding(self, remotePort):
        self.sendMessage('cancelRemoteForwarding', remotePort)

    def sendGlobalRequest(self, request, data, wantReply = 0):
        """
        Forward a global request.

        @return: a Deferred for the reply when C{wantReply} is true,
            otherwise C{None}.
        """
        self.sendMessage('sendGlobalRequest', request, data, wantReply)
        if wantReply:
            return self.returnDeferredLocal()

    def openChannel(self, channel, extra = ''):
        """
        Ask the server to open C{channel}; its ID arrives later through
        L{msg_channelID}.
        """
        self.channelQueue.append(channel)
        channel.conn = self
        self.sendMessage('openChannel', channel.name,
                                        channel.localWindowSize,
                                        channel.localMaxPacket, extra)

    def sendRequest(self, channel, requestType, data, wantReply = 0):
        """
        Forward a channel request.

        @return: a Deferred for the reply when C{wantReply} is true,
            otherwise C{None}.
        """
        self.sendMessage('sendRequest', channel.id, requestType, data, wantReply)
        if wantReply:
            # The Deferred must be handed back to the caller, exactly as
            # sendGlobalRequest does; otherwise the reply is unreachable.
            return self.returnDeferredLocal()

    def adjustWindow(self, channel, bytesToAdd):
        self.sendMessage('adjustWindow', channel.id, bytesToAdd)

    def sendData(self, channel, data):
        self.sendMessage('sendData', channel.id, data)

    def sendExtendedData(self, channel, dataType, data):
        # The server-side handler is msg_sendExtended and unpacks
        # (channel id, dataType, data), so use that message name and include
        # dataType.
        self.sendMessage('sendExtended', channel.id, dataType, data)

    def sendEOF(self, channel):
        self.sendMessage('sendEOF', channel.id)

    def sendClose(self, channel):
        self.sendMessage('sendClose', channel.id)

    def msg_channelID(self, lst):
        """
        Assign the reported ID to the oldest channel still waiting for one.
        """
        channelID = lst[0]
        self.channels[channelID] = self.channelQueue.pop(0)
        self.channels[channelID].id = channelID

    def msg_channelOpen(self, lst):
        """
        Record the remote window and packet sizes, then notify the channel.
        """
        channelID, remoteWindow, remoteMax, specificData = lst
        channel = self.channels[channelID]
        channel.remoteWindowLeft = remoteWindow
        channel.remoteMaxPacket = remoteMax
        channel.channelOpen(specificData)

    def msg_openFailed(self, lst):
        """
        Notify the channel that opening failed and forget it.
        """
        channelID, reason = lst
        # NOTE(review): pickle.loads on data from the peer; unsafe if the
        # peer is not trusted.
        self.channels[channelID].openFailed(pickle.loads(reason))
        del self.channels[channelID]

    def msg_addWindowBytes(self, lst):
        channelID, bytes = lst
        self.channels[channelID].addWindowBytes(bytes)

    def msg_requestReceived(self, lst):
        """
        Pass a request to the channel and relay its result to the server.
        """
        channelID, requestType, data = lst
        d = defer.maybeDeferred(self.channels[channelID].requestReceived, requestType, data)
        self.returnDeferredWire(d)

    def msg_dataReceived(self, lst):
        channelID, data = lst
        self.channels[channelID].dataReceived(data)

    def msg_extReceived(self, lst):
        channelID, dataType, data = lst
        self.channels[channelID].extReceived(dataType, data)

    def msg_eofReceived(self, lst):
        channelID = lst[0]
        self.channels[channelID].eofReceived()

    def msg_closed(self, lst):
        """
        Notify the channel that it is closed and forget it.
        """
        channelID = lst[0]
        self.channels[channelID].closed()
        del self.channels[channelID]

class SSHUnixServerProtocol(SSHUnixProtocol):
    """
    Server side of the UNIX-socket protocol.

    Each C{msg_*} method applies a client message to the wrapped connection
    object C{self.conn}.
    """

    def __init__(self, conn):
        """
        @param conn: the connection object that client messages are applied
            to.
        """
        SSHUnixProtocol.__init__(self)
        self.isClient = 0
        self.conn = conn

    def haveChannel(self, channelID):
        """
        Return true if C{self.conn} has a channel with this ID.
        """
        return channelID in self.conn.channels

    def getChannel(self, channelID):
        """
        Return the channel with this ID.

        @raise KeyError: if there is no such channel.
        @raise ConchError: if the channel is not an L{SSHUnixChannel}, i.e.
            was not opened through this protocol.
        """
        channel = self.conn.channels[channelID]
        if not isinstance(channel, SSHUnixChannel):
            raise ConchError('nice try bub')
        return channel

    def msg_requestRemoteForwarding(self, lst):
        remotePort, hostport = lst
        # hostport arrives as a list; the connection is given a tuple.
        hostport = tuple(hostport)
        self.conn.requestRemoteForwarding(remotePort, hostport)

    def msg_cancelRemoteForwarding(self, lst):
        [remotePort] = lst
        self.conn.cancelRemoteForwarding(remotePort)

    def msg_sendGlobalRequest(self, lst):
        """
        Issue a global request and, if asked, relay its reply to the client.
        """
        requestName, data, wantReply = lst
        d = self.conn.sendGlobalRequest(requestName, data, wantReply)
        if wantReply:
            # returnDeferredWire is the method that announces a Deferred to
            # the peer; no method named returnDeferred exists.
            self.returnDeferredWire(d)

    def msg_openChannel(self, lst):
        """
        Open a proxy channel on the connection and report its ID.
        """
        name, windowSize, maxPacket, extra = lst
        channel = SSHUnixChannel(self, name, windowSize, maxPacket)
        self.conn.openChannel(channel, extra)
        self.sendMessage('channelID', channel.id)

    def msg_sendRequest(self, lst):
        """
        Send a channel request and, if asked, relay its reply to the client.

        When the channel does not exist, a failure is relayed instead (only
        if a reply was requested) and nothing is sent.
        """
        cn, requestType, data, wantReply = lst
        if not self.haveChannel(cn):
            if wantReply:
                self.returnDeferredWire(defer.fail(ConchError("no channel")))
            # Stop here: looking the channel up below would raise KeyError.
            return
        channel = self.getChannel(cn)
        d = self.conn.sendRequest(channel, requestType, data, wantReply)
        if wantReply:
            self.returnDeferredWire(d)

    def msg_adjustWindow(self, lst):
        cn, bytesToAdd = lst
        if not self.haveChannel(cn): return
        channel = self.getChannel(cn)
        self.conn.adjustWindow(channel, bytesToAdd)

    def msg_sendData(self, lst):
        cn, data = lst
        if not self.haveChannel(cn): return
        channel = self.getChannel(cn)
        self.conn.sendData(channel, data)

    def msg_sendExtended(self, lst):
        cn, dataType, data = lst
        if not self.haveChannel(cn): return
        channel = self.getChannel(cn)
        self.conn.sendExtendedData(channel, dataType, data)

    def msg_sendEOF(self, lst):
        (cn, ) = lst
        if not self.haveChannel(cn): return
        channel = self.getChannel(cn)
        self.conn.sendEOF(channel)

    def msg_sendClose(self, lst):
        (cn, ) = lst
        if not self.haveChannel(cn): return
        channel = self.getChannel(cn)
        self.conn.sendClose(channel)

class SSHUnixChannel(channel.SSHChannel):
    """
    Channel that relays every channel event to the peer of a UNIX-socket
    protocol instead of handling it locally.
    """

    def __init__(self, unix, name, windowSize, maxPacket):
        """
        @param unix: protocol whose C{sendMessage} carries the events and
            whose C{conn} becomes this channel's connection.
        @param name: stored as the channel's C{name}.
        @param windowSize: passed through to C{SSHChannel.__init__}.
        @param maxPacket: passed through to C{SSHChannel.__init__}.
        """
        realConn = unix.conn
        channel.SSHChannel.__init__(self, windowSize, maxPacket, conn=realConn)
        self.unix = unix
        self.name = name

    def channelOpen(self, specificData):
        relay = self.unix.sendMessage
        relay('channelOpen', self.id, self.remoteWindowLeft,
              self.remoteMaxPacket, specificData)

    def openFailed(self, reason):
        # The failure is pickled so it can cross the wire.
        relay = self.unix.sendMessage
        relay('openFailed', self.id, pickle.dumps(reason))

    def addWindowBytes(self, bytes):
        relay = self.unix.sendMessage
        relay('addWindowBytes', self.id, bytes)

    def dataReceived(self, data):
        relay = self.unix.sendMessage
        relay('dataReceived', self.id, data)

    def requestReceived(self, reqType, data):
        """
        Relay the request and return a Deferred for the peer's answer.
        """
        proto = self.unix
        proto.sendMessage('requestReceived', self.id, reqType, data)
        return proto.returnDeferredLocal()

    def extReceived(self, dataType, data):
        relay = self.unix.sendMessage
        relay('extReceived', self.id, dataType, data)

    def eofReceived(self):
        relay = self.unix.sendMessage
        relay('eofReceived', self.id)

    def closed(self):
        relay = self.unix.sendMessage
        relay('closed', self.id)

def connect(host, port, options, verifyHostKey, userAuthObject):
    """
    Try to attach to a cached connection through its UNIX socket.

    The socket path is C{~/.conch-<user>-<host>-<port>} with C{~} expanded.
    C{verifyHostKey} is accepted but not used here.

    @return: a Deferred.  It is already failed with a L{ConchError} when
        C{options['nocache']} is true; otherwise it is the Deferred handed
        to the L{SSHUnixClientFactory} used for the connection attempt.
    """
    if not options['nocache']:
        result = defer.Deferred()
        socketName = "~/.conch-%s-%s-%i" % (userAuthObject.user, host, port)
        socketPath = os.path.expanduser(socketName)
        clientFactory = SSHUnixClientFactory(result, options, userAuthObject)
        reactor.connectUNIX(socketPath, clientFactory, timeout=2, checkPID=1)
        return result
    return defer.fail(ConchError('not using connection caching'))