from twisted.conch.error import ConchError
from twisted.conch.ssh import channel, connection
from twisted.internet import defer, protocol, reactor
from twisted.python import log
from twisted.spread import banana
import os, stat, pickle
import types
class SSHUnixClientFactory(protocol.ClientFactory):
noisy = 1
def __init__(self, d, options, userAuthObject):
self.d = d
self.options = options
self.userAuthObject = userAuthObject
def clientConnectionFailed(self, connector, reason):
try:
os.unlink(connector.transport.addr)
except:
pass
if not self.d: return
d = self.d
self.d = None
d.errback(reason)
def startedConnecting(self, connector):
fd = connector.transport.fileno()
stats = os.fstat(fd)
try:
filestats = os.stat(connector.transport.addr)
except:
connector.stopConnecting()
return
if stat.S_IMODE(filestats[0]) != 0600:
log.msg("socket mode is not 0600: %s" % oct(stat.S_IMODE(stats[0])))
elif filestats[4] != os.getuid():
log.msg("socket not owned by us: %s" % stats[4])
elif filestats[5] != os.getgid():
log.msg("socket not owned by our group: %s" % stats[5])
else:
log.msg('conecting OK')
return
connector.stopConnecting()
def buildProtocol(self, addr):
obj = self.userAuthObject.instance
bases = []
for base in obj.__class__.__bases__:
if base == connection.SSHConnection:
bases.append(SSHUnixClientProtocol)
else:
bases.append(base)
newClass = types.ClassType(obj.__class__.__name__, tuple(bases), obj.__class__.__dict__)
obj.__class__ = newClass
SSHUnixClientProtocol.__init__(obj)
log.msg('returning %s' % obj)
return obj
class SSHUnixServerFactory(protocol.Factory):
def __init__(self, conn):
self.conn = conn
def buildProtocol(self, addr):
return SSHUnixServerProtocol(self.conn)
class SSHUnixProtocol(banana.Banana):
knownDialects = ['none']
def __init__(self):
banana.Banana.__init__(self)
self.deferredQueue = []
self.deferreds = {}
self.deferredID = 0
def connectionMade(self):
log.msg('connection made %s' % self)
banana.Banana.connectionMade(self)
def expressionReceived(self, lst):
vocabName = lst[0]
fn = "msg_%s" % vocabName
func = getattr(self, fn)
func(lst[1:])
def sendMessage(self, vocabName, *tup):
self.sendEncoded([vocabName] + list(tup))
def returnDeferredLocal(self):
d = defer.Deferred()
self.deferredQueue.append(d)
return d
def returnDeferredWire(self, d):
di = self.deferredID
self.deferredID += 1
self.sendMessage('returnDeferred', di)
d.addCallback(self._cbDeferred, di)
d.addErrback(self._ebDeferred, di)
def _cbDeferred(self, result, di):
self.sendMessage('callbackDeferred', di, pickle.dumps(result))
def _ebDeferred(self, reason, di):
self.sendMessage('errbackDeferred', di, pickle.dumps(reason))
def msg_returnDeferred(self, lst):
deferredID = lst[0]
self.deferreds[deferredID] = self.deferredQueue.pop(0)
def msg_callbackDeferred(self, lst):
deferredID, result = lst
d = self.deferreds[deferredID]
del self.deferreds[deferredID]
d.callback(pickle.loads(result))
def msg_errbackDeferred(self, lst):
deferredID, result = lst
d = self.deferreds[deferredID]
del self.deferreds[deferredID]
d.errback(pickle.loads(result))
class SSHUnixClientProtocol(SSHUnixProtocol):
def __init__(self):
SSHUnixProtocol.__init__(self)
self.isClient = 1
self.channelQueue = []
self.channels = {}
def connectionReady(self):
log.msg('connection ready')
self.serviceStarted()
def connectionLost(self, reason):
self.serviceStopped()
def requestRemoteForwarding(self, remotePort, hostport):
self.sendMessage('requestRemoteForwarding', remotePort, hostport)
def cancelRemoteForwarding(self, remotePort):
self.sendMessage('cancelRemoteForwarding', remotePort)
def sendGlobalRequest(self, request, data, wantReply = 0):
self.sendMessage('sendGlobalRequest', request, data, wantReply)
if wantReply:
return self.returnDeferredLocal()
def openChannel(self, channel, extra = ''):
self.channelQueue.append(channel)
channel.conn = self
self.sendMessage('openChannel', channel.name,
channel.localWindowSize,
channel.localMaxPacket, extra)
def sendRequest(self, channel, requestType, data, wantReply = 0):
self.sendMessage('sendRequest', channel.id, requestType, data, wantReply)
if wantReply:
self.returnDeferredLocal()
def adjustWindow(self, channel, bytesToAdd):
self.sendMessage('adjustWindow', channel.id, bytesToAdd)
def sendData(self, channel, data):
self.sendMessage('sendData', channel.id, data)
def sendExtendedData(self, channel, dataType, data):
self.sendMessage('sendExtendedData', channel.id, data)
def sendEOF(self, channel):
self.sendMessage('sendEOF', channel.id)
def sendClose(self, channel):
self.sendMessage('sendClose', channel.id)
def msg_channelID(self, lst):
channelID = lst[0]
self.channels[channelID] = self.channelQueue.pop(0)
self.channels[channelID].id = channelID
def msg_channelOpen(self, lst):
channelID, remoteWindow, remoteMax, specificData = lst
channel = self.channels[channelID]
channel.remoteWindowLeft = remoteWindow
channel.remoteMaxPacket = remoteMax
channel.channelOpen(specificData)
def msg_openFailed(self, lst):
channelID, reason = lst
self.channels[channelID].openFailed(pickle.loads(reason))
del self.channels[channelID]
def msg_addWindowBytes(self, lst):
channelID, bytes = lst
self.channels[channelID].addWindowBytes(bytes)
def msg_requestReceived(self, lst):
channelID, requestType, data = lst
d = defer.maybeDeferred(self.channels[channelID].requestReceived, requestType, data)
self.returnDeferredWire(d)
def msg_dataReceived(self, lst):
channelID, data = lst
self.channels[channelID].dataReceived(data)
def msg_extReceived(self, lst):
channelID, dataType, data = lst
self.channels[channelID].extReceived(dataType, data)
def msg_eofReceived(self, lst):
channelID = lst[0]
self.channels[channelID].eofReceived()
def msg_closed(self, lst):
channelID = lst[0]
self.channels[channelID].closed()
del self.channels[channelID]
class SSHUnixServerProtocol(SSHUnixProtocol):
def __init__(self, conn):
SSHUnixProtocol.__init__(self)
self.isClient = 0
self.conn = conn
def haveChannel(self, channelID):
return self.conn.channels.has_key(channelID)
def getChannel(self, channelID):
channel = self.conn.channels[channelID]
if not isinstance(channel, SSHUnixChannel):
raise ConchError('nice try bub')
return channel
def msg_requestRemoteForwarding(self, lst):
remotePort, hostport = lst
hostport = tuple(hostport)
self.conn.requestRemoteForwarding(remotePort, hostport)
def msg_cancelRemoteForwarding(self, lst):
[remotePort] = lst
self.conn.cancelRemoteForwarding(remotePort)
def msg_sendGlobalRequest(self, lst):
requestName, data, wantReply = lst
d = self.conn.sendGlobalRequest(requestName, data, wantReply)
if wantReply:
self.returnDeferred(d)
def msg_openChannel(self, lst):
name, windowSize, maxPacket, extra = lst
channel = SSHUnixChannel(self, name, windowSize, maxPacket)
self.conn.openChannel(channel, extra)
self.sendMessage('channelID', channel.id)
def msg_sendRequest(self, lst):
cn, requestType, data, wantReply = lst
if not self.haveChannel(cn):
if wantReply:
self.returnDeferred(defer.fail(ConchError("no channel")))
channel = self.getChannel(cn)
d = self.conn.sendRequest(channel, requestType, data, wantReply)
if wantReply:
self.returnDeferredWire(d)
def msg_adjustWindow(self, lst):
cn, bytesToAdd = lst
if not self.haveChannel(cn): return
channel = self.getChannel(cn)
self.conn.adjustWindow(channel, bytesToAdd)
def msg_sendData(self, lst):
cn, data = lst
if not self.haveChannel(cn): return
channel = self.getChannel(cn)
self.conn.sendData(channel, data)
def msg_sendExtended(self, lst):
cn, dataType, data = lst
if not self.haveChannel(cn): return
channel = self.getChannel(cn)
self.conn.sendExtendedData(channel, dataType, data)
def msg_sendEOF(self, lst):
(cn, ) = lst
if not self.haveChannel(cn): return
channel = self.getChannel(cn)
self.conn.sendEOF(channel)
def msg_sendClose(self, lst):
(cn, ) = lst
if not self.haveChannel(cn): return
channel = self.getChannel(cn)
self.conn.sendClose(channel)
class SSHUnixChannel(channel.SSHChannel):
def __init__(self, unix, name, windowSize, maxPacket):
channel.SSHChannel.__init__(self, windowSize, maxPacket, conn = unix.conn)
self.unix = unix
self.name = name
def channelOpen(self, specificData):
self.unix.sendMessage('channelOpen', self.id, self.remoteWindowLeft,
self.remoteMaxPacket, specificData)
def openFailed(self, reason):
self.unix.sendMessage('openFailed', self.id, pickle.dumps(reason))
def addWindowBytes(self, bytes):
self.unix.sendMessage('addWindowBytes', self.id, bytes)
def dataReceived(self, data):
self.unix.sendMessage('dataReceived', self.id, data)
def requestReceived(self, reqType, data):
self.unix.sendMessage('requestReceived', self.id, reqType, data)
return self.unix.returnDeferredLocal()
def extReceived(self, dataType, data):
self.unix.sendMessage('extReceived', self.id, dataType, data)
def eofReceived(self):
self.unix.sendMessage('eofReceived', self.id)
def closed(self):
self.unix.sendMessage('closed', self.id)
def connect(host, port, options, verifyHostKey, userAuthObject):
if options['nocache']:
return defer.fail(ConchError('not using connection caching'))
d = defer.Deferred()
filename = os.path.expanduser("~/.conch-%s-%s-%i" % (userAuthObject.user, host, port))
factory = SSHUnixClientFactory(d, options, userAuthObject)
reactor.connectUNIX(filename, factory, timeout=2, checkPID=1)
return d