# test_protocols.py   [plain text]


# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

"""
Test cases for twisted.protocols package.
"""

from twisted.trial import unittest
from twisted.protocols import basic, wire
from twisted.internet import reactor, protocol

import string, struct
import StringIO

class StringIOWithoutClosing(StringIO.StringIO):
    """
    A StringIO whose close() is a no-op, so the buffer contents stay
    available through getvalue() even after close() has been called on it.
    """

    def close(self):
        """
        Ignore the request to close; the underlying buffer is left untouched.
        """
        return None

class LineTester(basic.LineReceiver):
    """
    LineReceiver used by the tests to record what it is given.

    Every line is appended to self.received.  A line of the form 'len N'
    stores N in self.length; an empty line switches the protocol to raw
    mode, where incoming data is appended to the last recorded entry until
    self.length characters have been consumed.
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def connectionMade(self):
        """
        Start with an empty record of received items.
        """
        self.received = []

    def lineReceived(self, line):
        """
        Record line, then act on it: '' enters raw mode, 'len N' sets
        self.length to N.
        """
        self.received.append(line)
        if not line:
            self.setRawMode()
        elif line.startswith('len '):
            self.length = int(line[4:])

    def rawDataReceived(self, data):
        """
        Consume up to self.length characters of data, appending them to the
        last recorded entry.  Once nothing more is expected, return to line
        mode and hand over whatever was not consumed.
        """
        wanted = self.length
        payload = data[:wanted]
        leftover = data[wanted:]
        self.length = wanted - len(payload)
        self.received[-1] += payload
        if self.length == 0:
            self.setLineMode(leftover)

    def lineLengthExceeded(self, line):
        """
        If line is longer than MAX_LENGTH + 1 characters, resume line mode
        with everything after the first MAX_LENGTH + 1 characters.
        """
        cutoff = self.MAX_LENGTH + 1
        if len(line) <= cutoff:
            return
        self.setLineMode(line[cutoff:])


class LineOnlyTester(basic.LineOnlyReceiver):
    """
    LineOnlyReceiver that collects every received line in self.received.
    """

    MAX_LENGTH = 64
    delimiter = '\n'

    def connectionMade(self):
        """
        Begin each connection with an empty record of lines.
        """
        self.received = list()

    def lineReceived(self, line):
        """
        Store line at the end of self.received.
        """
        self.received += [line]

class WireTestCase(unittest.TestCase):
    """
    Tests for the protocols in twisted.protocols.wire, each connected to a
    FileWrapper transport over an in-memory buffer.
    """

    # Chunks delivered, in order, by the tests that push data at a protocol.
    _WORDS = ("hello", "world", "how", "are", "you")

    def _attach(self, proto):
        """
        Connect proto to a FileWrapper around a fresh buffer and return the
        buffer so the test can inspect what was written to the transport.
        """
        sink = StringIOWithoutClosing()
        proto.makeConnection(protocol.FileWrapper(sink))
        return sink

    def testEcho(self):
        """
        Everything delivered to Echo must appear, concatenated, on the
        transport.
        """
        echo = wire.Echo()
        sink = self._attach(echo)
        for word in self._WORDS:
            echo.dataReceived(word)
        self.failUnlessEqual(sink.getvalue(), "helloworldhowareyou")

    def testWho(self):
        """
        Who must have written "root\\r\\n" once the connection is made.
        """
        sink = self._attach(wire.Who())
        self.failUnlessEqual(sink.getvalue(), "root\r\n")

    def testQOTD(self):
        """
        QOTD must have written its quote once the connection is made.
        """
        sink = self._attach(wire.QOTD())
        self.failUnlessEqual(sink.getvalue(),
                             "An apple a day keeps the doctor away.\r\n")

    def testDiscard(self):
        """
        Nothing delivered to Discard may show up on the transport.
        """
        discard = wire.Discard()
        sink = self._attach(discard)
        for word in self._WORDS:
            discard.dataReceived(word)
        self.failUnlessEqual(sink.getvalue(), "")

class LineReceiverTestCase(unittest.TestCase):
    """
    Feed a fixed input to LineTester in packets of several sizes and check
    that the recorded lines and raw data are the same regardless of how the
    input was split up.
    """

    # Input stream delivered to LineTester (lines mixed with raw sections).
    buffer = '''\
len 10

0123456789len 5

1234
len 20
foo 123

0123456789
012345678len 0
foo 5

1234567890123456789012345678901234567890123456789012345678901234567890
len 1

a'''

    # What LineTester.received must contain after the whole buffer is fed.
    output = ['len 10', '0123456789', 'len 5', '1234\n',
              'len 20', 'foo 123', '0123456789\n012345678',
              'len 0', 'foo 5', '', '67890', 'len 1', 'a']

    def testBuffer(self):
        """
        Deliver self.buffer in packets of 1 to 9 characters and compare the
        protocol's record with self.output for each packet size.
        """
        total = len(self.buffer)
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            a = LineTester()
            a.makeConnection(protocol.FileWrapper(t))
            # Explicit floor division: classic "/" on integers stops
            # producing an int under true division, which range() rejects.
            # The "+ 1" covers the trailing partial packet; when the length
            # is an exact multiple of packet_size the last chunk is empty.
            for i in range(total // packet_size + 1):
                s = self.buffer[i*packet_size:(i+1)*packet_size]
                a.dataReceived(s)
            self.failUnlessEqual(self.output, a.received)

class LineOnlyReceiverTestCase(unittest.TestCase):
    """
    Tests for LineOnlyTester: line splitting and over-long input.
    """

    # Input stream; the leading spaces on the continuation lines are part
    # of the data.
    buffer = """foo
    bleakness
    desolation
    plastic forks
    """

    def _connected(self):
        """
        Return a LineOnlyTester connected to a FileWrapper transport.
        """
        sink = StringIOWithoutClosing()
        tester = LineOnlyTester()
        tester.makeConnection(protocol.FileWrapper(sink))
        return tester

    def testBuffer(self):
        """
        Deliver self.buffer one character at a time; the lines recorded must
        equal the buffer split on newlines, without the final fragment.
        """
        tester = self._connected()
        for ch in self.buffer:
            tester.dataReceived(ch)
        expected = self.buffer.split('\n')[:-1]
        self.failUnlessEqual(tester.received, expected)

    def testLineTooLong(self):
        """
        Delivering 200 characters without a delimiter must make
        dataReceived return something other than None.
        """
        tester = self._connected()
        self.failIfEqual(tester.dataReceived('x'*200), None)
            
                
class TestMixin:
    """
    Mixin that records every string handed to stringReceived and notes when
    connectionLost has been called.
    """

    # Class-level defaults; closed is replaced per instance by connectionLost.
    closed = 0
    MAX_LENGTH = 50

    def connectionMade(self):
        """
        Start with an empty record of received strings.
        """
        self.received = list()

    def stringReceived(self, s):
        """
        Store s at the end of self.received.
        """
        self.received += [s]

    def connectionLost(self, reason):
        """
        Flag this instance as closed; reason is not used.
        """
        self.closed = 1


class TestNetstring(TestMixin, basic.NetstringReceiver):
    """
    NetstringReceiver combined with TestMixin, so that received strings are
    recorded in self.received.
    """


class LPTestCaseMixin:
    """
    Shared test for string-receiving protocols.

    Users set protocol to the class under test and illegal_strings to the
    inputs after which the transport is expected to be closed.
    """

    illegal_strings = []
    protocol = None

    def getProtocol(self):
        """
        Return a new self.protocol instance connected to a FileWrapper
        transport over an in-memory buffer.
        """
        # Inside a method, the bare name "protocol" is the module-level
        # import, not the class attribute of the same name.
        instance = self.protocol()
        sink = StringIOWithoutClosing()
        instance.makeConnection(protocol.FileWrapper(sink))
        return instance

    def testIllegal(self):
        """
        Deliver each illegal string one character at a time to a fresh
        protocol and check that its transport reports closed == 1.
        """
        for bad in self.illegal_strings:
            receiver = self.getProtocol()
            for ch in bad:
                receiver.dataReceived(ch)
            self.assertEquals(receiver.transport.closed, 1)


class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Tests for TestNetstring: a send/receive round trip over several packet
    sizes, plus the illegal-input check inherited from LPTestCaseMixin.
    """

    # Strings sent with sendString and expected back from the receiver.
    strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]

    # Inputs after which LPTestCaseMixin.testIllegal expects a closed transport.
    illegal_strings = ['9999999999999999999999', 'abc', '4:abcde',
                       '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]

    protocol = TestNetstring

    def testBuffer(self):
        """
        Encode self.strings with sendString, feed the encoded output back in
        packets of 1 to 9 characters, and check the strings are recovered.
        """
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            a = TestNetstring()
            # Raise the limit so the 515-character string is accepted.
            a.MAX_LENGTH = 699
            a.makeConnection(protocol.FileWrapper(t))
            for s in self.strings:
                a.sendString(s)
            out = t.getvalue()
            # Step through the output by packet_size rather than computing a
            # packet count with "/", which is not an integer under true
            # division.  Every slice taken this way is non-empty, so these
            # are exactly the chunks that need delivering.
            for start in range(0, len(out), packet_size):
                a.dataReceived(out[start:start + packet_size])
            self.assertEquals(a.received, self.strings)


class TestInt32(TestMixin, basic.Int32StringReceiver):
    """
    Int32StringReceiver combined with TestMixin, so that received strings
    are recorded in self.received.
    """

    MAX_LENGTH = 50


class Int32TestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Tests for TestInt32: complete strings, incomplete input, and the
    illegal-input check inherited from LPTestCaseMixin.
    """

    protocol = TestInt32

    # Payloads delivered in testReceive, each behind a packed "!i" length.
    strings = ["a", "b" * 16]

    # Inputs after which LPTestCaseMixin.testIllegal expects a closed transport.
    illegal_strings = ["\x10\x00\x00\x00aaaaaa"]

    # Inputs that must not cause any string to be delivered.
    partial_strings = ["\x00\x00\x00", "hello there", ""]

    def testPartial(self):
        """
        Deliver each partial string one character at a time, with a very
        large MAX_LENGTH, and check that nothing was received.
        """
        for fragment in self.partial_strings:
            receiver = self.getProtocol()
            receiver.MAX_LENGTH = 99999999
            for ch in fragment:
                receiver.dataReceived(ch)
            self.assertEquals(receiver.received, [])

    def testReceive(self):
        """
        Deliver every string, prefixed with its length packed as "!i", one
        character at a time to a single protocol and check all are received.
        """
        receiver = self.getProtocol()
        for payload in self.strings:
            framed = struct.pack("!i", len(payload)) + payload
            for ch in framed:
                receiver.dataReceived(ch)
        self.assertEquals(receiver.received, self.strings)