# capserver.py



import socket

from twisted.internet import defer, protocol
from twisted.protocols import basic
from twisted.python import failure
from twisted.python import log
# Switch on the DEBUG flag of twisted.protocols.basic for the whole process.
# NOTE(review): presumably this enables extra diagnostics in the netstring
# parser -- confirm against the installed Twisted version.
basic.DEBUG = True

# Status words: the values StupidRPC sends as the first string of each reply.
FAILURE = 'Failure'
SUCCESS = 'Success'

# Failure types passed to StupidRPC.sendFailure by the dispatcher itself.
BAD_ARGUMENT = 'Bad argument'
NO_SUCH_METHOD = 'No such method'
DONT_PIPELINE = 'No pipelining allowed'

##
## Basic StupidRPC
##

class StupidRPC(basic.NetstringReceiver):
    """
    A minimal RPC protocol framed as netstrings.

    A call is a sequence of strings: the method name first, then one string
    per argument.  The method is located with L{findMethod} and must carry a
    C{sig} attribute: a sequence of one-argument callables, one per
    argument, each turning the raw string into the value handed to the
    method.  When every argument has been coerced the method is invoked and
    is expected to answer through L{sendSuccess} or L{sendFailure}.

    Only one call may be outstanding at a time; a string that arrives while
    a method is still running is answered with C{DONT_PIPELINE} and the
    connection is dropped.
    """

    # Coercers still owed by the peer for the call being assembled.  A false
    # value means the next string received is a method name.
    waitingForArgs = None
    # True from the moment a method is invoked until a reply is sent.
    dispatching = False
    # The method selected by the call in progress, if any.
    currMeth = None

    def connectionMade(self):
        """
        Start the connection with no collected arguments.
        """
        self.gotArgs = []

    def clean(self):
        """
        Forget the call in progress so the next string starts a new call.
        """
        self.dispatching = False
        # Pending coercers must be dropped too; otherwise the strings that
        # follow a failed call would still be consumed as its arguments.
        self.waitingForArgs = None
        self.gotArgs = []
        self.currMeth = None

    def sendSuccess(self, result):
        """
        Reply to the current call with the success status and C{result}.

        A warning is logged if no call is being dispatched; the reply is
        sent regardless.
        """
        if not self.dispatching:
            log.msg("warning: sending success when not dispatching")
        self.clean()
        self.sendString(SUCCESS)
        self.sendString(result)

    def sendFailure(self, type, result=''):
        """
        Reply with the failure status, the failure C{type} and C{result}.

        A warning is logged if no call is being dispatched; the reply is
        sent regardless.
        """
        if not self.dispatching:
            log.msg("warning: sending failure when not dispatching")
        self.clean()
        self.sendString(FAILURE)
        self.sendString(type)
        self.sendString(result)

    def stringReceived(self, line):
        """
        Handle one netstring: a method name, or the next argument of the
        call currently being assembled.
        """
        if self.dispatching:
            # The peer spoke before the running method replied.  Reject the
            # string, hang up, and do not interpret it any further.
            self.sendFailure(DONT_PIPELINE)
            self.transport.loseConnection()
            return

        if self.waitingForArgs:
            # An argument
            coercer = self.waitingForArgs.pop(0)
            try:
                value = coercer(line)
            except Exception:
                f = failure.Failure()
                # The full traceback goes to the log; the peer only gets the
                # error message.
                log.err(f)
                self.sendFailure(BAD_ARGUMENT, f.getErrorMessage())
                return
            self.gotArgs.append(value)
            if not self.waitingForArgs:
                # that was all the args we needed
                self._dispatch()
            return

        # Beginning of a method call
        meth = self.findMethod(line)
        if not meth:
            self.sendFailure(NO_SUCH_METHOD)
            return
        self.currMeth = meth
        self.waitingForArgs = list(meth.sig)
        if not self.waitingForArgs:
            # A method that takes no arguments can run straight away.
            self._dispatch()

    def _dispatch(self):
        """
        Invoke the selected method with the arguments collected so far.
        """
        args = self.gotArgs
        self.dispatching = True
        self.currMeth(*args)

    def findMethod(self, name):
        """
        Return this protocol's C{remote_<name>} attribute, or C{None} if
        there is none.
        """
        return getattr(self, 'remote_' + name, None)


##
## Cred integration
##

# Result strings AuthServer sends to report the outcome of remote_auth.
AUTH_SUCCESS = 'Authentication Successful'
AUTH_FAILURE = 'Authentication Failure'

from twisted.cred import credentials
from twisted.python import components

class IStupid(components.Interface):
    """
    Marker interface for avatars that expose C{remote_} methods.

    AuthServer asks the portal for this interface when logging in, and
    afterwards looks up C{remote_<name>} attributes on the avatar it got.
    """

class AuthServer(StupidRPC):
    """
    A L{StupidRPC} that requires a login before exposing an avatar.

    Until a login succeeds, methods are looked up on the protocol itself
    (which offers C{remote_auth}).  After a successful login, methods are
    looked up on the avatar returned by the factory's portal.
    """

    root = None
    # The logged-in avatar, or None while unauthenticated.
    avatar = None
    # The logout callable handed back by the portal, or None.
    logout = None

    def connectionLost(self, reason):
        """
        Log the avatar out (if any) and drop the reference to it.
        """
        if self.logout is not None:
            self.logout()
            self.logout = None
        self.avatar = None

    def findMethod(self, name):
        """
        Look C{remote_<name>} up on the avatar once logged in, otherwise on
        this protocol.
        """
        if self.avatar is not None:
            return getattr(self.avatar, 'remote_' + name, None)
        return StupidRPC.findMethod(self, name)

    def remote_auth(self, name, password):
        """
        Log in through the factory's portal with a username and password,
        requesting L{IStupid} and passing C{None} as the mind.

        The reply is sent later, from L{_cbGotRoot} or L{_ebNoRoot}.
        """
        d = self.factory.portal.login(
            credentials.UsernamePassword(name, password),
            None,
            IStupid)
        d.addCallback(self._cbGotRoot)
        d.addErrback(self._ebNoRoot)
    remote_auth.sig = (str, str)

    def _cbGotRoot(self, result):
        """
        Login succeeded: remember the avatar and its logout callable, give
        the avatar a reference to this protocol and report success.

        @param result: the C{(interface, avatar, logout)} tuple produced by
            the portal login.
        """
        # Unpacked in the body rather than in the signature: tuple
        # parameters are a syntax error on Python 3.
        iface, avatar, logout = result
        self.avatar = avatar
        self.logout = logout
        self.avatar.proto = self
        self.sendSuccess(AUTH_SUCCESS)

    def _ebNoRoot(self, f):
        """
        Login failed: log the failure and report C{AUTH_FAILURE}.
        """
        log.err(f)
        self.sendFailure(AUTH_FAILURE)

##
## Capability Server
##

def ipify(ipstr):
    """
    Coerce an argument that should look like a dotted-quad address.

    Only the shape is checked: the string must be shorter than 16
    characters and contain exactly three dots.  The individual octets are
    not validated.

    @param ipstr: the raw argument string.
    @return: C{ipstr}, unchanged.
    @raise ValueError: if the string does not have the expected shape.
    """
    # Explicit raises instead of asserts: asserts are removed when Python
    # runs with -O, which would silently disable this check.
    if len(ipstr) >= 16:
        raise ValueError("address too long: %r" % (ipstr,))
    if ipstr.count('.') != 3:
        raise ValueError("address must contain three dots: %r" % (ipstr,))
    return ipstr

# Reply strings for the capability server.  PERMISSION_DENIED is the failure
# type CapServer.remote_bindPort sends for a request that is not allowed.
# NOTE(review): SOCKET_BOUND is not referenced by the code visible in this
# file -- confirm whether clients or other modules rely on it.
SOCKET_BOUND = "Socket bound"
PERMISSION_DENIED = "Permission denied"

def makeCoercerWithDefault(coercer, default):
    """
    Wrap C{coercer} so that a false input (such as the empty string) yields
    C{default} instead of being passed to C{coercer}.

    @return: a one-argument callable.
    """
    def coerce(raw):
        if raw:
            return coercer(raw)
        return default
    return coerce

class CapServer:
    """
    Avatar that binds sockets for a client, subject to an allow-list.

    The protocol serving this avatar is expected to be stored in
    C{self.proto} (AuthServer does so after login); replies are sent
    through it.

    NOTE(review): C{sendSocket} is called by L{remote_bindPort} but is not
    defined in this class -- presumably a subclass or mixin supplies it;
    confirm.
    """

    def __init__(self, allowed=None):
        """
        @param allowed: a container of
            C{(interface, portno, addressFamily, socketType)} tuples that
            may be bound.  Defaults to an empty dict, i.e. nothing allowed.
        """
        if allowed is None:
            allowed = {}
        self.allowed = allowed

    def remote_bindPort(self, interface, portno, addressFamily, socketType):
        """
        Bind a socket and pass it on with C{sendSocket}, if the request is
        on the allow-list; otherwise reply C{PERMISSION_DENIED}.

        If binding or sending raises, the socket is closed and the
        exception propagates to the caller.
        """
        request = (interface, portno, addressFamily, socketType)
        if request not in self.allowed:
            self.proto.sendFailure(PERMISSION_DENIED)
            return

        s = self.socket(addressFamily, socketType)
        handedOver = False
        try:
            s.bind((interface, portno))
            self.sendSocket(s)
            handedOver = True
        finally:
            # Do not leak the descriptor when the bind or hand-over fails.
            if not handedOver:
                s.close()
        self.proto.sendSuccess("<3")
    remote_bindPort.sig = (ipify, int,
                           makeCoercerWithDefault(int, socket.AF_INET),
                           makeCoercerWithDefault(int, socket.SOCK_STREAM))

    # Socket constructor, overridable per class or instance.  This must stay
    # below the C{sig} assignment above: once bound here, the name C{socket}
    # in the class body no longer refers to the socket module.
    socket = staticmethod(socket.socket)


def makeAuthFactory(rootFactory, checker):
    """
    Convenience: build a protocol factory that serves L{AuthServer}.

    The factory gets a C{portal} attribute whose realm creates a fresh
    avatar with C{rootFactory()} for every login and whose credentials are
    verified by C{checker}.
    """
    from twisted.cred import portal

    def noLogout():
        # Avatars made here need no cleanup on logout.
        return None

    class Realm:
        def requestAvatar(self, name, mind, iface):
            assert iface is IStupid, iface
            return iface, rootFactory(), noLogout

    authPortal = portal.Portal(Realm())
    authPortal.registerChecker(checker)

    factory = protocol.Factory()
    factory.protocol = AuthServer
    factory.portal = authPortal
    return factory

def main():
    """
    Serve a CapServer behind AuthServer on TCP port 1025, logging to
    standard output, and run the reactor.
    """
    import sys
    from twisted.internet import reactor
    from twisted.cred import checkers

    # A single hard-coded account: user "radix", password "secret".
    checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(radix='secret')
    factory = makeAuthFactory(CapServer, checker)
    reactor.listenTCP(1025, factory)

    log.startLogging(sys.stdout)
    reactor.run()

# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    main()