# agent.py


# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2003 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
"""
Implements the key agent protocol.

This module is unstable.

Maintainer: U{Paul Swartz<mailto:z3p@twistedmatrix.com>}
"""

import struct
from channel import SSHChannel
from common import NS, getNS
from twisted.conch.error import ConchError
from twisted.internet import defer, protocol

class SSHAgentClient(protocol.Protocol):
    """
    Client side of the key agent protocol.

    Every request is written to the transport as a length-prefixed packet and
    paired with a Deferred.  Replies are matched to requests purely by order:
    each complete reply packet fires the oldest Deferred still pending.

    @ivar buf: bytes received so far that do not yet form a complete packet.
    @ivar deferreds: Deferreds of the requests still awaiting a reply, oldest
        first.
    """

    def __init__(self):
        self.buf = ''
        self.deferreds = []

    def dataReceived(self, data):
        """
        Buffer incoming bytes and dispatch every complete reply packet.

        A packet is a 4-byte big-endian length followed by that many bytes,
        the first of which is the reply type.  AGENT_FAILURE fires the pending
        Deferred's errback with a ConchError, AGENT_SUCCESS fires its callback
        with '', and any other type fires its callback with the whole packet
        (type byte included).

        @param data: the bytes just read from the transport.
        """
        self.buf += data
        while 1:
            # Wait for the length prefix plus at least the type byte.
            if len(self.buf) <= 4:
                return
            packLen = struct.unpack('!L', self.buf[:4])[0]
            if len(self.buf) < 4 + packLen:
                return
            packet, self.buf = self.buf[4:4 + packLen], self.buf[4 + packLen:]
            reqType = ord(packet[0])
            d = self.deferreds.pop(0)
            if reqType == AGENT_FAILURE:
                d.errback(ConchError('agent failure'))
            elif reqType == AGENT_SUCCESS:
                d.callback('')
            else:
                d.callback(packet)

    def sendRequest(self, reqType, data):
        """
        Write one request packet and register a Deferred for its reply.

        @param reqType: the numeric request type, sent as a single byte.
        @param data: the already-encoded request payload.
        @return: a Deferred fired by L{dataReceived} when the matching reply
            arrives.
        """
        # The length field counts the type byte as well as the payload.
        pack = struct.pack('!LB', len(data) + 1, reqType) + data
        self.transport.write(pack)
        d = defer.Deferred()
        self.deferreds.append(d)
        return d

    def requestIdentities(self):
        """
        Ask the agent for the identities it holds.

        @return: a Deferred that fires with a list of (blob, comment) tuples.
        """
        d = self.sendRequest(AGENTC_REQUEST_IDENTITIES, '')
        return d.addCallback(self._cbRequestIdentities)

    def _cbRequestIdentities(self, data):
        """
        Decode an AGENT_IDENTITIES_ANSWER packet.

        After the type byte the packet holds a 4-byte big-endian key count,
        then for each key a length-prefixed blob followed by a length-prefixed
        comment.

        @param data: the reply packet, type byte included.
        @return: a list of (blob, comment) tuples.
        @raise ConchError: if the reply is not an AGENT_IDENTITIES_ANSWER.
        """
        if ord(data[0]) != AGENT_IDENTITIES_ANSWER:
            # Raise rather than return: a returned exception instance would
            # be handed to the caller's callbacks as if it were the result.
            raise ConchError('unexpected response: %i' % ord(data[0]))
        numKeys = struct.unpack('!L', data[1:5])[0]
        keys = []
        data = data[5:]
        for i in range(numKeys):
            blobLen = struct.unpack('!L', data[:4])[0]
            blob, data = data[4:4 + blobLen], data[4 + blobLen:]
            commLen = struct.unpack('!L', data[:4])[0]
            comm, data = data[4:4 + commLen], data[4 + commLen:]
            keys.append((blob, comm))
        return keys

    def addIdentity(self, blob, comment = ''):
        """
        Ask the agent to add a key.

        @param blob: the key data; it is sent as given, without being wrapped
            by NS.
        @param comment: a comment to store with the key; sent NS-encoded.
        @return: the Deferred from L{sendRequest}.
        """
        req = blob
        req += NS(comment)
        return self.sendRequest(AGENTC_ADD_IDENTITY, req)

    def signData(self,blob, data):
        """
        Ask the agent to sign data with the key identified by blob.

        @param blob: the blob identifying the key to sign with.
        @param data: the data to sign.
        @return: a Deferred that fires with the signature.
        """
        req = NS(blob)
        req += NS(data)
        req += '\000\000\000\000' # flags
        d = self.sendRequest(AGENTC_SIGN_REQUEST, req)
        return d.addCallback(self._cbSignData)

    def _cbSignData(self, data):
        """
        Decode an AGENT_SIGN_RESPONSE packet.

        @param data: the reply packet, type byte included.
        @return: the first string getNS extracts from the bytes after the
            type byte.
        @raise ConchError: if the reply is not an AGENT_SIGN_RESPONSE.
        """
        if ord(data[0]) != AGENT_SIGN_RESPONSE:
            # Raise so the failure reaches the caller's errbacks.
            raise ConchError('unexpected data: %i' % ord(data[0]))
        signature = getNS(data[1:])[0]
        return signature

    def removeIdentity(self, blob):
        """
        Ask the agent to remove the key identified by blob.

        @param blob: the blob identifying the key; sent NS-encoded.
        @return: the Deferred from L{sendRequest}.
        """
        req = NS(blob)
        return self.sendRequest(AGENTC_REMOVE_IDENTITY, req)

    def removeAllIdentities(self):
        """
        Ask the agent to remove every key it holds.

        @return: the Deferred from L{sendRequest}.
        """
        return self.sendRequest(AGENTC_REMOVE_ALL_IDENTITIES, '')

class SSHAgentServer(protocol.Protocol):
    """
    Server side of the key agent protocol.

    Incoming request packets are dispatched by type to the matching
    C{agentc_<NAME>} method, where the name comes from the module-level
    C{messages} table.

    @ivar buf: bytes received so far that do not yet form a complete packet.
    @ivar keys: maps a public key blob to a (private object, comment) tuple.
    """

    def __init__(self):
        self.buf = ''
        self.keys = {} # public blob -> (private object, comment)

    def dataReceived(self, data):
        """
        Buffer incoming bytes and dispatch every complete request packet.

        A packet is a 4-byte big-endian length followed by that many bytes,
        the first of which is the request type.  The handler receives the
        packet without its type byte.  A request type that is not in
        C{messages} is reported and answered with AGENT_FAILURE.

        @param data: the bytes just read from the transport.
        """
        self.buf += data
        while 1:
            # Wait for the length prefix plus at least the type byte.
            if len(self.buf) <= 4:
                return
            packLen = struct.unpack('!L', self.buf[:4])[0]
            if len(self.buf) < 4 + packLen:
                return
            packet, self.buf = self.buf[4:4 + packLen], self.buf[4 + packLen:]
            reqType = ord(packet[0])
            reqName = messages.get(reqType, None)
            if not reqName:
                # There is no handler to look up for an unknown type; tell
                # the peer the request failed and carry on with whatever
                # else is buffered.
                print('bad request %d' % reqType)
                self.sendResponse(AGENT_FAILURE, '')
                continue
            f = getattr(self, 'agentc_%s' % reqName)
            f(packet[1:])

    def sendResponse(self, reqType, data):
        """
        Write one reply packet to the transport.

        @param reqType: the numeric reply type, sent as a single byte.
        @param data: the already-encoded reply payload.
        """
        # The length field counts the type byte as well as the payload.
        pack = struct.pack('!LB', len(data) + 1, reqType) + data
        self.transport.write(pack)

    def agentc_REQUEST_IDENTITIES(self, data):
        """
        Reply with every key in C{self.keys}.

        The reply payload is a 4-byte big-endian key count followed, for each
        key, by the length-prefixed public blob and the length-prefixed
        comment.

        @param data: the request payload; must be empty.
        """
        assert data == ''
        parts = [struct.pack('!L', len(self.keys))]
        for blob in self.keys:
            comment = self.keys[blob][1]
            parts.append(struct.pack('!L', len(blob)) + blob)
            parts.append(struct.pack('!L', len(comment)) + comment)
        self.sendResponse(AGENT_IDENTITIES_ANSWER, ''.join(parts))

    def agentc_SIGN_REQUEST(self, data):
        """
        Sign data with the key a blob identifies and reply with the result.

        The payload holds the key blob, then the data to sign, then four zero
        bytes of flags.  A blob that is not in C{self.keys} is answered with
        AGENT_FAILURE.

        @param data: the request payload.
        """
        blob, data = getNS(data)
        if blob not in self.keys:
            return self.sendResponse(AGENT_FAILURE, '')
        signData, data = getNS(data)
        assert data == '\000\000\000\000'
        # NOTE(review): the module-level name `keys` is not imported in the
        # visible part of this file; confirm it is brought into scope (it is
        # presumably the package's key-handling module) or this call fails.
        signature = keys.signData(self.keys[blob][0], signData)
        self.sendResponse(AGENT_SIGN_RESPONSE, NS(signature))

    # The remaining requests are accepted but not implemented: the payload is
    # ignored and no reply is sent.
    def agentc_ADD_IDENTITY(self, data): pass
    def agentc_REMOVE_IDENTITY(self, data): pass
    def agentc_REMOVE_ALL_IDENTITIES(self, data): pass

# Packet type codes.  The AGENTC_* values are the request types the client
# sends; the AGENT_* values are the reply types the server sends back.
AGENT_FAILURE                   = 5
AGENT_SUCCESS                   = 6
AGENTC_REQUEST_IDENTITIES       = 11
AGENT_IDENTITIES_ANSWER         = 12
AGENTC_SIGN_REQUEST             = 13
AGENT_SIGN_RESPONSE             = 14
AGENTC_ADD_IDENTITY             = 17
AGENTC_REMOVE_IDENTITY          = 18
AGENTC_REMOVE_ALL_IDENTITIES    = 19

# Maps each AGENTC_* request code to its name without the 'AGENTC_' prefix
# (e.g. 11 -> 'REQUEST_IDENTITIES'); SSHAgentServer.dataReceived uses it to
# find the agentc_<NAME> handler.  The table is built from this module's own
# namespace, so it does not depend on the module being importable under the
# bare name `agent`.
messages = {}
# Iterate over a snapshot: the loop variables are themselves module globals.
for _name, _value in list(globals().items()):
    if _name.startswith('AGENTC_'):
        messages[_value] = _name[len('AGENTC_'):]
del _name, _value