tls2.py   [plain text]



import time
import struct

import crypt

from twisted.python import context
from twisted.python import components

# Record Layer Content Types
RL_CT_CHANGE_CIPHER_SPEC = 20
RL_CT_ALERT = 21
RL_CT_HANDSHAKE = 22
RL_CT_APPLICATION_DATA = 23

def twos(s):
    return zip(*[iter(s)] * 2)

class IEncodable(components.Interface):
    def encode(self):
        """Encode this record fully for transmission.
        """

class Random:
    __implements__ = (IEncodable,)

    def __init__(self, tstamp=None, rbytes=None):
        self.time = tstamp or int(time.time())
        self.bytes = rbytes or crypt.getRandomBytes(28)

    def encode(self):
        return struct.pack('>I', self.time) + self.bytes

def NullCipherMethod(record):
    return record
def NullCompressionMethod(record):
    return record

class SecurityParameters(object):
    macAlgorithm = staticmethod(crypt.HMAC_NULL)
    bulkEncryptionAlgorithm = staticmethod(NullCipherMethod)
    compressionAlgorithm = staticmethod(NullCompressionMethod)

    masterSecret = None
    clientRandom = None
    serverRandom = None

    def getCipherSuites(self):
        return [0]

    def getCompressionMethods(self):
        return [0]

# Handshake Content Types
HS_CT_HELLO_REQUEST = 0
HS_CT_CLIENT_HELLO = 1
HS_CT_SERVER_HELLO = 2
HS_CT_CERTIFICATE = 11
HS_CT_SERVER_KEY_EXCHANGE = 12
HS_CT_CERTIFICATE_REQUEST = 13
HS_CT_SERVER_HELLO_DONE = 14
HS_CT_CERTIFICATE_VERIFY = 15
HS_CT_CLIENT_KEY_EXCHANGE = 16
HS_CT_FINISHED = 20

def logbytes(name, bytes):
    print name, len(bytes), ':', ' '.join(map(''.join, zip(*[iter(bytes.encode('hex'))] * 2)))


class RecordLayer(object):
    MAX_FRAGMENT_SIZE = 2 ** 14

    def __init__(self, record):
        self.record = record

    def encode(self):
        S = self.MAX_FRAGMENT_SIZE
        packets = []
        bytes = self.record.encode()
        hdr = struct.pack('>BBB', self.record.type, *self.record.version)
        while bytes:
            toenc, bytes = bytes[:S], bytes[S:]
            packets.append(hdr + struct.pack('>H', len(toenc)) + toenc)
        logbytes('Record Layer', ''.join(packets))
        return ''.join(packets)

class Plaintext(object):
    def __init__(self, record):
        self.type = record.type
        self.version = record.version
        self.record = record

    def encode(self):
        r = self.record.encode()
        logbytes('Plaintext', r)
        return r

class Handshake(object):

    type = RL_CT_HANDSHAKE
    version = (3, 1)

    handshakeType = None
    def encode(self):
        body = self.handshake_encode()
        assert len(body) < 2 ** 24
        high = len(body) >> 8
        low = len(body) & 0xff
        r = struct.pack('>BHB', self.handshakeType, high, low) + body
        logbytes('Handshake', r)
        return r

class ClientHello(Handshake):
    version = (3, 1)
    handshakeType = HS_CT_CLIENT_HELLO

    def __init__(self, version, random, session_id, ciphers, compressors):
        self.version = version
        self.random = random
        self.sessionID = session_id
        self.ciphers = ciphers
        self.compressors = compressors
    
    def handshake_encode(self):
        ver = ''.join(map(chr, self.version))
        rand = self.random.encode()
        sess = chr(len(self.sessionID)) + self.sessionID
        ciph = ''.join([struct.pack('>H', c) for c in self.ciphers])
        ciph = struct.pack('>H', len(ciph)) + ciph
        comp = chr(len(self.compressors)) + ''.join(map(chr, self.compressors))
        r = ver + rand + sess + ciph + comp
        logbytes('ClientHello', r)
        return r

class ServerHello(ClientHello):
    handshakeType = HS_CT_SERVER_HELLO

class Certificate(Handshake):
    handshakeType = HS_CT_CERTIFICATE

    def handshake_encode(self):
        pass

import sys
sys.path.append('../../pahan/statefulprotocol')
from stateful import StatefulProtocol
class RecordProtocol(StatefulProtocol):
    currentReadSecurity = None
    currentWriteSecurity = None
    
    pendingReadSecurity = None
    pendingWriteSecurity = None

    def connectionMade(self):
        for cp in ('pending', 'current'):
            for rw in ('Read', 'Write'):
                setattr(self, cp + rw + 'Security', SecurityParameters())

    def _write(self, record):
        comp = self.currentWriteSecurity.compressionAlgorithm
        ciph = self.currentWriteSecurity.bulkEncryptionAlgorithm
        bytes = RecordLayer(ciph(comp(record))).encode()
        logbytes('Sending', bytes)
        self.transport.write(bytes)

    def send(self, record):
        ctx = {'SecurityParameters': self.currentWriteSecurity}
        rec = Plaintext(record)
        context.call(ctx, self._write, rec)    

class TLSClient(RecordProtocol):

    buffer = ''

    CONTENT_TYPE_MAP = {chr(20): 'ChangeCipherSpec',
                        chr(21): 'Alert',
                        chr(22): 'Handshake',
                        chr(23): 'ApplicationData'}

    def getInitialState(self):
        m = self.state_RecordType
        return m, m.byteCount

    def dataReceived(self, data):
        print 'Received', repr(data)
        StatefulProtocol.dataReceived(self, data)

    def connectionMade(self):
        RecordProtocol.connectionMade(self)
        sp = getattr(self, cp + rw + 'Security')
        cipherSuites = sp.getCipherSuites()
        compMethods = sp.getCompressionMethods()
        self.send(ClientHello((3, 1), Random(), '', cipherSuites, compMethods))

    def state_RecordType(self, data):
        method = self.CONTENT_TYPE_MAP[data]
        method = getattr(self, 'rt_' + method)
        return method, method.byteCount
    state_RecordType.byteCount = 1

    def rt_ChangeCipherSpec(self, data):
        self.changeCipherSpec()
        return self.state_RecordType, self.state_RecordType.byteCount
    rt_ChangeCipherSpec.byteCount = 1

    def rt_Alert(self, data):
        print 'Alert'

    def rt_Handshake(self, data):
        print 'Handshake'

    def rt_ApplicationData(self, data):
        print 'ApplicationData'

class TLSServerProtocol(RecordProtocol):

    CONTENT_TYPE_MAP = {chr(20): 'ChangeCipherSpec',
                        chr(21): 'Alert',
                        chr(22): 'Handshake',
                        chr(23): 'ApplicationData'}

    HANDSHAKE_TYPE_MAP = {chr(0): 'HelloRequest',
                          chr(1): 'ClientHello',
                          chr(2): 'ServerHello',
                          chr(11): 'Certificate',
                          chr(12): 'ServerKeyExchange',
                          chr(13): 'CertificateRequest',
                          chr(14): 'ServerHelloDone',
                          chr(15): 'CertificateVerify',
                          chr(16): 'ClientKeyExchange',
                          chr(20): 'HandshakeFinished'}

    def dataReceived(self, bytes):
        logbytes("Server Received", bytes)
        StatefulProtocol.dataReceived(self, bytes)

    def getInitialState(self):
        m = self.state_RecordTypeAndVersionAndLength
        return m, m.byteCount

    def state_RecordTypeAndVersionAndLength(self, data):
        logbytes("RTAV", data)
        recordType = data[0]
        self.recordVersion = map(ord, data[1:3])
        recordLength = struct.unpack('>H', data[3:])[0]
        print 'Version is', self.recordVersion
        m = getattr(self, "rt_" + self.CONTENT_TYPE_MAP[recordType])
        print 'Next state is', m, recordLength
        return m, m.byteCount
    state_RecordTypeAndVersionAndLength.byteCount = 5

    def rt_Handshake(self, data):
        logbytes("Handshake", data)
        # Determine the type of handshake record this is
        m = getattr(self, 'hs_' + self.HANDSHAKE_TYPE_MAP[data[0]])
        bytes = struct.unpack('>I', '\0' + data[1:4])[0]
        print 'Next state is', m, bytes
        return m, bytes
    rt_Handshake.byteCount = 4

    def rt_Alert(self, data):
        print 'Alert!', data

    def hs_HelloRequest(self, data):
        logbytes("HelloRequest", data)
        print 'whaaaat'
        pass

    def hs_ClientHello(self, data):
        logbytes("ClientHello", data)
        fmt = '>BBI28spH'
        L = struct.calcsize(fmt)
        front, data = data[:L], data[L:]
        cv1, cv2, time, random, sessionID, nCiphs = struct.unpack(fmt, front)
        logbytes("Ciphers", data)
        ciphs, data = data[:nCiphs], data[nCiphs:]
        ciphers = [ord(a) << 8 | ord(b) for (a, b) in twos(ciphs)]
        nComps = ord(data[0])
        compressors = map(ord, data[1:nComps+1])
        ch = ClientHello((3, 1), Random(time, random), sessionID, ciphers, compressors)
        self.handshakeMessage(ch)
        m = self.state_RecordTypeAndVersionAndLength
        return m, m.byteCount
    

class TLSServer(TLSServerProtocol):
    sessions = {}

    def generateSessionID(self):
        return 'sessionID'

    def handshakeMessage(self, msg):
        f = self.HANDSHAKE_TYPE_MAP[chr(msg.handshakeType)]
        return getattr(self, 'handshake_' + f)(msg)

    def handshake_ClientHello(self, msg):
        if msg.sessionID in self.sessions:
            return self.resumeSession(msg)
        # XXX Check timestamp
        self.pendingWriteSecurity.clientRandom = msg.random.bytes
        self.pendingReadSecurity.clientRandom = msg.random.bytes

        random = crypt.getRandomBytes(28)
        self.pendingWriteSecurity.serverRandom = random
        self.pendingReadSecurity.serverRandom = random

        sp = self.currentWriteSecurity

        sessionID = self.generateSessionID()
        cs = sp.getCipherSuites()
        cm = sp.getCompressionMethods()
        h = ServerHello((3, 1), Random(rbytes=random), sessionID, cs, cm)
        self.send(h)
        c = Certificate()
        self.send(c)
        

if __name__ == '__main__':
    from twisted.internet import ssl
    from twisted.internet import reactor
    from twisted.internet import protocol

    from twisted.python import log
    import sys
    log.startLogging(sys.stdout)

    class ClientContextFactory(ssl.ClientContextFactory):
        method = ssl.SSL.TLSv1_METHOD

    class HexBytePrintingProtocol(protocol.Protocol):
        def dataReceived(self, bytes):
            logbytes("Received", bytes)

    OpenSSLServerFactory = protocol.ServerFactory()
    OpenSSLServerFactory.protocol = protocol.Protocol

    OpenSSLClientFactory = protocol.ClientFactory()
    OpenSSLClientFactory.protocol = protocol.Protocol

    PlainServerFactory = protocol.ServerFactory()
    PlainServerFactory.protocol = HexBytePrintingProtocol

    PlainClientFactory = protocol.ClientFactory()
    PlainClientFactory.protocol = HexBytePrintingProtocol

    PythonTLSServerFactory = protocol.ServerFactory()
    PythonTLSServerFactory.protocol = TLSServer

    PythonTLSClientFactory = protocol.ClientFactory()
    PythonTLSClientFactory.protocol = TLSClient

    pem = '/home/exarkun/projects/python/Twisted/twisted/test/server.pem'
    OpenSSLPort = reactor.listenSSL(0, OpenSSLServerFactory, ssl.DefaultOpenSSLContextFactory(pem, pem), interface='127.0.0.1')
    PlainPort = reactor.listenTCP(0, PlainServerFactory)
    PythonTLSPort = reactor.listenTCP(0, PythonTLSServerFactory)

    OpenSSLConn = reactor.connectSSL('127.0.0.1', PythonTLSPort.getHost()[2], OpenSSLClientFactory, ClientContextFactory())

    # PythonTLSConn = reactor.connectTCP('127.0.0.1', OpenSSLPort.getHost()[2], PythonTLSClientFactory)

    reactor.callLater(1, reactor.stop)
    reactor.run()