# test_policies.py


# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2002 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

"""Test code for policies."""
from __future__ import nested_scopes
from StringIO import StringIO

from twisted.trial import unittest

import time

from twisted.internet import protocol, reactor
from twisted.protocols import policies


class StringIOWithoutClosing(StringIO):
    """StringIO whose close() does nothing, so its contents stay readable
    after a transport has 'closed' it."""

    def close(self):
        """Deliberately a no-op."""

class SimpleProtocol(protocol.Protocol):
    """Protocol that records its connection state and collects what it reads.

    ``connected`` and ``disconnected`` are set to 1 by connectionMade and
    connectionLost respectively; neither is ever reset to 0.  ``buffer``
    accumulates every chunk handed to dataReceived.
    """

    connected = 0
    disconnected = 0
    buffer = ""

    def connectionMade(self):
        self.connected = 1

    def dataReceived(self, data):
        self.buffer = self.buffer + data

    def connectionLost(self, reason):
        self.disconnected = 1


class SillyFactory(protocol.ClientFactory):
    """Client factory that always hands out the same pre-built protocol.

    The instance is kept in ``p`` so tests can reach it afterwards.
    """

    def __init__(self, p):
        self.p = p

    def buildProtocol(self, addr):
        # The address is ignored: every connection gets the stored instance.
        prebuilt = self.p
        return prebuilt


class EchoProtocol(protocol.Protocol):
    """Protocol that writes back everything it receives and can act as a
    producer.

    When registered as a producer, pauseProducing/resumeProducing store
    the wall-clock time of the call in ``paused``/``resume`` so tests can
    check when throttling happened.
    """

    def dataReceived(self, data):
        self.transport.write(data)

    def pauseProducing(self):
        self.paused = time.time()

    def resumeProducing(self):
        self.resume = time.time()

    def stopProducing(self):
        """Nothing to release."""


class Server(protocol.ServerFactory):
    """Server factory that builds one EchoProtocol per connection."""

    protocol = EchoProtocol


class SimpleSenderProtocol(SimpleProtocol):
    """Protocol that writes 'foo' once a second until its connection goes away.

    Received data is collected in ``data``.  If the connection is lost before
    finish() was called, the owning test case is marked as failed
    (``testcase.failed = 1``).  If finish() was called, the next scheduled
    write stops the reactor with reactor.crash() instead.
    """

    finished = 0
    data = ''

    def __init__(self, testcase):
        self.testcase = testcase

    def connectionMade(self):
        SimpleProtocol.connectionMade(self)
        self.writeSomething()

    def writeSomething(self):
        if not self.disconnected:
            self.transport.write('foo')
            # Keep sending once a second for as long as we are connected.
            reactor.callLater(1, self.writeSomething)
        elif self.finished:
            # Expected shutdown: hand control back to the test.
            reactor.crash()
        else:
            # Dropped before finish() -- e.g. by a timeout.
            self.fail()

    def finish(self):
        self.finished = 1
        self.transport.loseConnection()

    def fail(self):
        self.testcase.failed = 1

    def dataReceived(self, data):
        self.data += data



class ThrottlingTestCase(unittest.TestCase):
    """Tests for policies.ThrottlingFactory over real TCP on 127.0.0.1.

    Each test wraps an echoing Server in a ThrottlingFactory and connects
    SimpleProtocol clients through SillyFactory.  The reactor is driven by
    hand with reactor.iterate().
    """

    def doIterations(self, count=5):
        """Run ``count`` reactor iterations so pending events get handled."""
        for i in range(count):
            reactor.iterate()

    def testLimit(self):
        """Check the connection limit.

        With a limit of two, the third client is dropped.  Once an
        existing connection closes, a new client is accepted.
        """
        server = Server()
        c1, c2, c3, c4 = [SimpleProtocol() for i in range(4)]
        # The second argument caps simultaneous connections; the assertions
        # below expect a cap of two.
        tServer = policies.ThrottlingFactory(server, 2)
        p = reactor.listenTCP(0, tServer, interface="127.0.0.1")
        # Port 0 lets the OS choose.  Element 2 of getHost() is used as the
        # port number (presumably a (type, host, port) tuple -- verify).
        n = p.getHost()[2]
        self.doIterations()

        for c in c1, c2, c3:
            reactor.connectTCP("127.0.0.1", n, SillyFactory(c))
            self.doIterations()

        # All three clients get connected at the TCP level, but the third
        # is disconnected because the limit has been reached.
        self.assertEquals([c.connected for c in c1, c2, c3], [1, 1, 1])
        self.assertEquals([c.disconnected for c in c1, c2, c3], [0, 0, 1])
        self.assertEquals(len(tServer.protocols.keys()), 2)

        # disconnect one protocol and now another should be able to connect
        c1.transport.loseConnection()
        self.doIterations()
        reactor.connectTCP("127.0.0.1", n, SillyFactory(c4))
        self.doIterations()

        self.assertEquals(c4.connected, 1)
        self.assertEquals(c4.disconnected, 0)

        # Clean up the remaining connections and the listening port.
        for c in c2, c4: c.transport.loseConnection()
        p.stopListening()
        self.doIterations()

    def testWriteLimit(self):
        """Check the write limit.

        Writing past writeLimit pauses the server-side producers.  They
        are resumed about one second later.
        """
        server = Server()
        c1, c2 = SimpleProtocol(), SimpleProtocol()

        # The throttling factory starts checking bandwidth immediately
        now = time.time()

        tServer = policies.ThrottlingFactory(server, writeLimit=10)
        port = reactor.listenTCP(0, tServer, interface="127.0.0.1")
        n = port.getHost()[2]
        reactor.iterate(); reactor.iterate()
        for c in c1, c2:
            reactor.connectTCP("127.0.0.1", n, SillyFactory(c))
            self.doIterations()

        # Register each server-side EchoProtocol (p is rebound to the
        # wrapped protocol) as a producer on its own transport.  Its
        # pauseProducing/resumeProducing calls then get recorded as
        # timestamps.
        for p in tServer.protocols.keys():
            p = p.wrappedProtocol
            self.assert_(isinstance(p, EchoProtocol))
            p.transport.registerProducer(p, 1)

        c1.transport.write("0123456789")
        c2.transport.write("abcdefghij")
        self.doIterations()

        self.assertEquals(c1.buffer, "0123456789")
        self.assertEquals(c2.buffer, "abcdefghij")
        self.assertEquals(tServer.writtenThisSecond, 20)

        # at this point server should've written 20 bytes, 10 bytes
        # above the limit so writing should be paused around 1 second
        # from 'now', and resumed a second after that

        for p in tServer.protocols.keys():
            self.assert_(not hasattr(p.wrappedProtocol, "paused"))
            self.assert_(not hasattr(p.wrappedProtocol, "resume"))

        # NOTE: `p` is the last wrapper left over from the loop above; spin
        # until that one has been paused.
        while not hasattr(p.wrappedProtocol, "paused"):
            reactor.iterate()

        self.assertEquals(tServer.writtenThisSecond, 0)

        for p in tServer.protocols.keys():
            self.assert_(hasattr(p.wrappedProtocol, "paused"))
            self.assert_(not hasattr(p.wrappedProtocol, "resume"))
            self.assert_(abs(p.wrappedProtocol.paused - now - 1.0) < 0.1)

        # Again relies on `p` from the loop above; wait for the resume.
        while not hasattr(p.wrappedProtocol, "resume"):
            reactor.iterate()

        for p in tServer.protocols.keys():
            self.assert_(hasattr(p.wrappedProtocol, "resume"))
            self.assert_(abs(p.wrappedProtocol.resume -
                             p.wrappedProtocol.paused - 1.0) < 0.1)

        # Close both ends of every connection and the listening port.
        c1.transport.loseConnection()
        c2.transport.loseConnection()
        port.stopListening()
        for p in tServer.protocols.keys():
            p.loseConnection()
        self.doIterations()

    def testReadLimit(self):
        """Check the read limit.

        Reading past readLimit stops the server reading for about a
        second.  During that time, newly sent data is not echoed back.
        """
        server = Server()
        c1, c2 = SimpleProtocol(), SimpleProtocol()
        now = time.time()
        tServer = policies.ThrottlingFactory(server, readLimit=10)
        port = reactor.listenTCP(0, tServer, interface="127.0.0.1")
        n = port.getHost()[2]
        self.doIterations()
        for c in c1, c2:
            reactor.connectTCP("127.0.0.1", n, SillyFactory(c))
            self.doIterations()

        c1.transport.write("0123456789")
        c2.transport.write("abcdefghij")
        self.doIterations()
        self.assertEquals(c1.buffer, "0123456789")
        self.assertEquals(c2.buffer, "abcdefghij")
        self.assertEquals(tServer.readThisSecond, 20)

        # we wrote 20 bytes, so after one second it should stop reading
        # and then a second later start reading again
        while time.time() - now < 1.05:
            reactor.iterate()
        self.assertEquals(tServer.readThisSecond, 0)

        # write some more - the server should not read it (so nothing new
        # is echoed back and readThisSecond stays 0) for another second
        c1.transport.write("0123456789")
        c2.transport.write("abcdefghij")
        self.doIterations()
        self.assertEquals(c1.buffer, "0123456789")
        self.assertEquals(c2.buffer, "abcdefghij")
        self.assertEquals(tServer.readThisSecond, 0)

        # Once reading resumes, the second batch is echoed as well.
        while time.time() - now < 2.05:
            reactor.iterate()
        self.assertEquals(c1.buffer, "01234567890123456789")
        self.assertEquals(c2.buffer, "abcdefghijabcdefghij")
        c1.transport.loseConnection()
        c2.transport.loseConnection()
        port.stopListening()
        for p in tServer.protocols.keys():
            p.loseConnection()
        self.doIterations()

    # These depend on wall-clock timing and fail intermittently, so they
    # are skipped.
    testReadLimit.skip = "Inaccurate tests are worse than no tests."
    testWriteLimit.skip = "Inaccurate tests are worse than no tests."

class TimeoutTestCase(unittest.TestCase):
    """Connection-timeout tests over real TCP on 127.0.0.1.

    SimpleSenderProtocol reports an unexpected disconnect by setting
    ``self.failed``, which setUp resets to 0.  Every test stops its
    listening port in a ``finally`` clause.  That way the port is not
    leaked into later tests when an assertion fails.
    """

    def setUp(self):
        self.failed = 0

    def testTimeout(self):
        # Create a server which times out connections idle for 3 seconds
        server = policies.TimeoutFactory(Server(), 3)
        port = reactor.listenTCP(0, server, interface="127.0.0.1")

        try:
            # Create a client that sends and receives nothing
            client = SimpleProtocol()
            reactor.connectTCP("127.0.0.1", port.getHost()[2],
                               SillyFactory(client))

            for i in range(10):
                reactor.iterate()
            # `connected` is never reset, so also check `disconnected` to
            # make sure the connection was not dropped early.
            self.assert_(client.connected)
            self.assert_(not client.disconnected)

            time.sleep(3.5)
            for i in range(3):
                reactor.iterate()
            self.assert_(client.disconnected)
        finally:
            port.stopListening()
            for i in range(10):
                reactor.iterate()

    def testThatSendingDataAvoidsTimeout(self):
        # Create a server which times out connections idle for 2 seconds
        server = policies.TimeoutFactory(Server(), 2)
        port = reactor.listenTCP(0, server, interface="127.0.0.1")

        try:
            # Create a client that writes 'foo' once a second, which keeps
            # the server's 2 second timeout from expiring.  finish() closes
            # the connection at 3.5s.  The client then stops the reactor, or
            # marks the test failed if it was dropped earlier.
            client = SimpleSenderProtocol(self)
            reactor.connectTCP("127.0.0.1", port.getHost()[2],
                               SillyFactory(client))
            reactor.callLater(3.5, client.finish)
            reactor.run()
        finally:
            port.stopListening()
            for i in range(10):
                reactor.iterate()

        self.failUnlessEqual(self.failed, 0)
        # Writes at t=0, 1, 2 and 3 were echoed back before finish() ran.
        self.failUnlessEqual(client.data, 'foo'*4)

    def testThatReadingDataAvoidsTimeout(self):
        # Create a server that sends 'foo' once a second
        server = SillyFactory(SimpleSenderProtocol(self))
        port = reactor.listenTCP(0, server, interface='127.0.0.1')

        try:
            # NOTE(review): WrappingFactory imposes no timeout itself, so as
            # written nothing on the client side can time out here;
            # TimeoutFactory was presumably intended -- confirm.
            clientFactory = policies.WrappingFactory(
                SillyFactory(SimpleProtocol()))
            # Do not rebind `port`: the listening port must still be
            # reachable for the cleanup below.
            reactor.connectTCP('127.0.0.1', port.getHost()[2], clientFactory)

            reactor.iterate()
            reactor.iterate()
            reactor.callLater(5, server.p.finish)
            reactor.run()
        finally:
            port.stopListening()
            for i in range(10):
                reactor.iterate()

        self.failUnlessEqual(self.failed, 0)

class TimeoutTester(protocol.Protocol, policies.TimeoutMixin):
    """Protocol that uses TimeoutMixin and records whether it timed out.

    The timeout is started on connection using ``timeOut`` (3 seconds
    unless a test overrides it).  Incoming data resets it, and losing the
    connection cancels it.  ``timedOut`` becomes 1 once
    timeoutConnection runs.
    """

    timeOut = 3
    timedOut = 0

    def timeoutConnection(self):
        self.timedOut = 1

    def connectionMade(self):
        initialPeriod = self.timeOut
        self.setTimeout(initialPeriod)

    def dataReceived(self, data):
        self.resetTimeout()
        protocol.Protocol.dataReceived(self, data)

    def connectionLost(self, reason=None):
        # A period of None cancels any pending timeout.
        self.setTimeout(None)


class TestTimeout(unittest.TestCase):
    """Tests for policies.TimeoutMixin, exercised through TimeoutTester.

    Each tester is attached to an in-memory FileWrapper transport.  Every
    test calls connectionLost() in a ``finally`` clause.  That calls
    setTimeout(None), which cancels any still-pending timeout even when an
    assertion fails part-way through.
    """

    def _connect(self, p):
        """Attach *p* to a FileWrapper around a StringIO that never closes."""
        p.makeConnection(protocol.FileWrapper(StringIOWithoutClosing()))

    def testTimeout(self):
        p = TimeoutTester()
        self._connect(p)
        try:
            for i in range(10):
                reactor.iterate()
            self.failIf(p.timedOut)

            # Sleep past the 3 second default, then let the timeout fire.
            time.sleep(3.5)
            reactor.iterate()
            self.failUnless(p.timedOut)
        finally:
            p.connectionLost()

    def testNoTimeout(self):
        p = TimeoutTester()
        self._connect(p)
        try:
            for i in range(10):
                reactor.iterate()
            self.failIf(p.timedOut)

            # Data arriving at t=2 resets the 3 second timer, so at t=3.5
            # the connection has not timed out yet...
            time.sleep(2)
            p.dataReceived('hello there')
            time.sleep(1.5)

            for i in range(10):
                reactor.iterate()
            self.failIf(p.timedOut)

            # ...but it does time out 3.5 seconds after the reset.
            time.sleep(2)
            for i in range(10):
                reactor.iterate()
            self.failUnless(p.timedOut)
        finally:
            p.connectionLost()

    def testResetTimeout(self):
        # Start with no timeout, then install a 1 second one by hand.
        p = TimeoutTester()
        p.timeOut = None
        self._connect(p)
        try:
            p.setTimeout(1)
            self.assertEquals(p.timeOut, 1)

            for i in range(10):
                reactor.iterate()
            self.failIf(p.timedOut)

            time.sleep(1.1)
            reactor.iterate()
            self.failUnless(p.timedOut)
        finally:
            p.connectionLost()

    def testCancelTimeout(self):
        # A 5 second timeout cancelled with setTimeout(None) has not fired
        # after the iterations below.
        p = TimeoutTester()
        p.timeOut = 5
        self._connect(p)
        try:
            p.setTimeout(None)
            self.assertEquals(p.timeOut, None)

            for i in range(10):
                reactor.iterate()
            self.failIf(p.timedOut)
        finally:
            p.connectionLost()

    def testReturn(self):
        # setTimeout() returns the previous timeout value.
        p = TimeoutTester()
        p.timeOut = 5
        try:
            self.assertEquals(p.setTimeout(10), 5)
            self.assertEquals(p.setTimeout(None), 10)
            self.assertEquals(p.setTimeout(1), None)
            self.assertEquals(p.timeOut, 1)
        finally:
            p.connectionLost()