# win32_iocp.py


"""
XXX: THIS CODE IS INCORRECT AND BROKEN.
An IOCP-based event loop.

This requires win32all (the Python for Windows extensions, now pywin32) to be installed.

TODO:
1. Pass tests.
2. Switch everyone to a decent OS so we don't have to deal with insane APIs.
3. Process support, SSL, UDP.
"""

# Win32 imports
from win32file import WSAEventSelect, FD_READ, FD_WRITE, FD_CLOSE, \
                      FD_ACCEPT, FD_CONNECT
from win32event import CreateEvent, WaitForMultipleObjects, \
                       WAIT_OBJECT_0, WAIT_TIMEOUT, INFINITE
import win32api
import win32con
import win32event
import win32file
import win32pipe
import win32process
import win32security
import pywintypes
import msvcrt

# Twisted imports
from twisted.internet import default, abstract
from twisted.internet.interfaces import IReactorCore, IReactorTime, IReactorTCP
from twisted.python import log, threadable
from twisted.protocols import protocol
from twisted.persisted import styles

# System imports
import os
import threading
import Queue
import string
import time
import socket
import sys
import struct

# Module-level registry: maps a handle's integer value (the completion key
# passed to CreateIoCompletionPort) to the object whose methods handle that
# handle's I/O completions.
files = dict()


class Win32Reactor(default.PosixReactorBase):
    """Reactor that uses a Win32 I/O completion port.

    Although it plugs into the reactor API, this implements the Proactor
    pattern: overlapped operations are started on registered handles, and
    their completions are dequeued from a single completion port in
    doIteration, which dispatches each one to the method named by the
    completing OVERLAPPED's C{object} attribute.
    """

    __implements__ = IReactorCore, IReactorTime, IReactorTCP

    def __init__(self, handleSignals=1):
        default.PosixReactorBase.__init__(self, handleSignals)
        # A completion port not yet associated with any handle; handles are
        # attached to it later by registerFile().
        self.iocp = win32file.CreateIoCompletionPort(win32file.INVALID_HANDLE_VALUE, None, 0, 1)

    def installWaker(self):
        """Create the OVERLAPPED that wakeUp() posts to the completion port."""
        self.wakeupOverlapped = pywintypes.OVERLAPPED()

    def wakeUp(self):
        """Wake up the event loop.

        Only acts when called from a thread other than the I/O thread. The
        packet is posted with completion key 0, which doIteration treats as
        "nothing to dispatch".
        """
        if not threadable.isInIOThread():
            win32file.PostQueuedCompletionStatus(self.iocp, 0, 0, self.wakeupOverlapped)

    def removeAll(self):
        """Nothing is tracked for removal here; always return an empty list."""
        return []

    def registerFile(self, file, wrapper):
        """Associate a handle with the completion port.

        @param file: the handle to register; its int() value is used both as
            the completion key and as the key into the module-level C{files}
            registry.
        @param wrapper: the object whose methods are called when overlapped
            operations on C{file} complete.
        """
        key = int(file)
        files[key] = wrapper
        # Associating a handle with an existing port returns that same port.
        self.iocp = win32file.CreateIoCompletionPort(file, self.iocp, key, 1)

    def doIteration(self, timeout):
        """Wait for and dispatch at most one I/O completion.

        @param timeout: seconds to wait, or None.  None is mapped to a
            10-second wait rather than an infinite one; negative values are
            treated as 0.
        """
        if timeout is None:
            timeout = 10000
        else:
            timeout = max(0, int(1000 * timeout))
        rc, numBytes, key, overlapped = win32file.GetQueuedCompletionStatus(self.iocp, timeout)
        # On a timeout (or any failure that dequeued nothing) overlapped is
        # None.  Key 0 marks a wake-up packet posted by wakeUp().
        if overlapped is None or key == 0:
            return

        obj = files.get(key)
        if obj is None:
            # Completion for a handle that is no longer registered.
            return
        action = getattr(obj, overlapped.object)
        # Top-level dispatch boundary: log anything the handler raises and
        # tear the connection down rather than letting it kill the loop.
        try:
            action()
        except:
            log.deferr()
            try:
                obj.connectionLost()
            except:
                log.deferr()


    # IReactorTCP

    def listenTCP(self, port, factory, backlog=5, interface=''):
        """See twisted.internet.interfaces.IReactorTCP.listenTCP
        """
        p = Port(port, factory, backlog, interface, reactor=self)
        p.startListening()
        return p

    def clientTCP(self, host, port, protocol, timeout=30):
        """Start a TCP connection to host:port driven by this reactor."""
        return Client(host, port, protocol, timeout, reactor=self)


def install():
    """Install a Win32Reactor as the global reactor.

    Enables thread support via threadable, creates the reactor, installs it
    through main.installReactor, and then starts threadtask's dispatcher.
    """
    threadable.init(1)
    newReactor = Win32Reactor()
    import main
    main.installReactor(newReactor)
    import threadtask
    dispatcher = threadtask.theDispatcher
    dispatcher.start()


class Connection(protocol.Transport, styles.Ephemeral):
    """A TCP connection for the Proactor pattern.

    Reads and writes are started as overlapped operations on the socket
    handle.  The handle is registered with the reactor's completion port,
    and each OVERLAPPED carries (in its C{object} attribute) the name of the
    method the reactor calls when that operation completes: doRead for
    reads, finishedWriting for writes.
    """

    connected = 0
    producerPaused = 0
    streamingProducer = 0
    unsent = ""
    producer = None
    disconnected = 0
    disconnecting = 0
    # 2**2**2**2 == 2**(2**4) == 65536: the size of the read buffer and the
    # largest chunk handed to a single WriteFile call.
    bufferSize = 2**2**2**2
    writing = 0
    # The chunk passed to the WriteFile currently in progress, kept so that a
    # short write can be completed by resending the remainder.
    inFlight = ""

    def __init__(self, skt, protocol):
        """Wrap a connected socket and register it with the reactor.

        @param skt: the connected socket; it is switched to non-blocking mode.
        @param protocol: the protocol that receives data and connection events.
        """
        self.socket = skt
        self.socket.setblocking(0)
        self.protocol = protocol

        # setup win32 objects
        self.winSocket = skt.fileno()
        self.outOverlapped = pywintypes.OVERLAPPED()
        self.outOverlapped.hEvent = win32event.CreateEvent(None, 0, 0, None)
        self.inOverlapped = pywintypes.OVERLAPPED()
        self.inOverlapped.hEvent = win32event.CreateEvent(None, 0, 0, None)
        self.readData = win32file.AllocateReadBuffer(self.bufferSize)
        from twisted.internet import reactor
        reactor.registerFile(self.winSocket, self)
        # Names of the methods the reactor dispatches to on completion.
        self.outOverlapped.object = "finishedWriting"
        self.inOverlapped.object = "doRead"

    def registerProducer(self, producer, streaming):
        """Register to receive data from a producer.

        This makes this connection a consumer for C{producer}.  A
        non-streaming producer is asked to resumeProducing() immediately and
        again whenever this connection's write buffer drains; a streaming
        producer is paused when the buffer grows past bufferSize and resumed
        once it drains.  A producer should implement the IProducer interface.
        """

        self.producer = producer
        self.streamingProducer = streaming
        if not streaming:
            producer.resumeProducing()

    def unregisterProducer(self):
        """Stop consuming data from a producer, without disconnecting.
        """
        self.producer = None

    def stopConsuming(self):
        """Stop consuming data.

        This is called when a producer has lost its connection, to tell the
        consumer to go lose its connection (and break potential circular
        references).
        """
        self.unregisterProducer()
        self.loseConnection()

    def connectionLost(self):
        """The connection was lost.

        Called whether the connection was closed explicitly, an exception
        occurred in an event handler, or the other end closed it first.
        Stops any producer, shuts the socket down and notifies the protocol.
        Calls after the first one do nothing.
        """
        if self.disconnected:
            return
        self.disconnected = 1
        self.connected = 0
        if self.producer is not None:
            self.producer.stopProducing()
            self.producer = None
        try:
            self.socket.shutdown(2)
        except socket.error:
            pass
        protocol = self.protocol
        del self.protocol
        del self.socket
        protocol.connectionLost()

    def write(self, data):
        """Queue data for sending, starting a write if none is in progress.

        Empty writes and writes after the connection was lost are ignored.
        A streaming producer is paused once more than bufferSize bytes are
        queued.
        """
        if not data or self.disconnected:
            return
        self.unsent = self.unsent + data
        if not self.writing:
            self.startWriting()

        if self.producer is not None:
            if len(self.unsent) > self.bufferSize:
                self.producerPaused = 1
                self.producer.pauseProducing()

    def loseConnection(self):
        """Close the connection, after any in-progress write has finished."""
        if self.connected:
            if self.writing:
                self.disconnecting = 1
            else:
                self.connectionLost()

    def startWriting(self):
        """Start an overlapped WriteFile for up to bufferSize unsent bytes."""
        self.writing = 1
        size = min(len(self.unsent), self.bufferSize)
        data, self.unsent = self.unsent[:size], self.unsent[size:]
        self.inFlight = data
        try:
            win32file.WriteFile(self.winSocket, data, self.outOverlapped)
        except win32api.error:
            self.connectionLost()

    def finishedWriting(self):
        """Called by the reactor when the overlapped write completes.

        Requeues any bytes a short write left unsent, then continues writing,
        asks the producer for more data, or completes a pending
        loseConnection().
        """
        if self.disconnected:
            return
        try:
            written = win32file.GetOverlappedResult(self.winSocket, self.outOverlapped, 0)
        except win32api.error:
            # The write failed; the socket is no longer usable.
            self.connectionLost()
            return
        if written < len(self.inFlight):
            # Resend the remainder ahead of anything queued since.
            self.unsent = self.inFlight[written:] + self.unsent
        self.inFlight = ""
        if self.unsent:
            self.startWriting()
            return
        else:
            if self.producer is not None and ((not self.streamingProducer)
                                              or self.producerPaused):
                # tell them to supply some more.  writing is cleared first so
                # that a write() from resumeProducing starts a new WriteFile.
                self.writing = 0
                self.producer.resumeProducing()
                self.producerPaused = 0
                return
            elif self.disconnecting:
                # But if I was previously asked to let the connection die, do
                # so.
                self.connectionLost()
                return
            self.writing = 0

    def startReading(self):
        """Start an overlapped ReadFile into self.readData.

        If the read cannot be started at all, no completion will arrive for
        it.  In that case the connection is lost on the next reactor
        iteration (deferred so that callers such as Server.__init__ can
        finish setting the transport up first).
        """
        try:
            result, readData = win32file.ReadFile(self.winSocket, self.readData, self.inOverlapped)
        except win32api.error:
            from twisted.internet import reactor
            reactor.callLater(0, self.connectionLost)
            return
        assert self.readData is readData

    def doRead(self):
        """Called by the reactor when the overlapped read completes.

        Delivers the received bytes to the protocol and starts the next read.
        A zero-length read (end of stream) or a failed read loses the
        connection.
        """
        if self.disconnected:
            return
        try:
            length = win32file.GetOverlappedResult(self.winSocket, self.inOverlapped, 0)
        except win32api.error:
            self.connectionLost()
            return
        if not length:
            self.connectionLost()
            return
        self.protocol.dataReceived(self.readData[:length])
        # dataReceived may itself have lost the connection.
        if not self.disconnected:
            self.startReading()


# this code is identical to tcp.Connection:

class Server(Connection):
    """Server-side socket-stream connection transport.

    One of these is created (through Port's C{transport} attribute) for
    every connection accepted on a listening Port; programmers should not
    normally need to instantiate it themselves.  For the transport interface
    itself, see twisted.protocols.protocol.Transport.
    """

    def __init__(self, sock, protocol, client, server, sessionno):
        """Server(sock, protocol, client, server, sessionno)

        @param sock: the accepted socket.
        @param protocol: the protocol instance for this connection.
        @param client: the peer's address tuple; its first element is used
            as the hostname.
        @param server: the Port that accepted the connection.
        @param sessionno: the session number that Port assigned.
        """
        protocolName = protocol.__class__.__name__
        # Set before Connection.__init__, which registers this object with
        # the reactor and may therefore need its repr.
        self.repstr = "<%s #%s on %s>" % (protocolName, sessionno, server.port)
        Connection.__init__(self, sock, protocol)
        self.server = server
        self.client = client
        self.sessionno = sessionno
        self.hostname = client[0]
        self.logstr = "%s,%s,%s" % (protocolName, sessionno, self.hostname)
        # Issue the first overlapped read, then mark the transport connected.
        self.startReading()
        self.connected = 1

    def __repr__(self):
        """Return the description built in __init__."""
        text = self.repstr
        return text

    def getHost(self):
        """Return ('INET', host, port) for the local (server) end."""
        localAddress = self.socket.getsockname()
        return ('INET',) + localAddress

    def getPeer(self):
        """Return ('INET', host, port) for the connected client."""
        peerAddress = self.client
        return ('INET',) + peerAddress


class Port:
    """I am a TCP server port, listening for connections.

    When a connection is accepted, I will call my factory's buildProtocol with
    the peer's address as an argument, according to the specification
    described in twisted.protocols.protocol.Factory.

    If you wish to change the sort of transport that will be used, my
    `transport' attribute will be called with the signature expected for
    Server.__init__, so it can be replaced.
    """

    transport = Server
    sessionno = 0
    interface = ''
    backlog = 5
    # Connection state.  These are class-level defaults so that
    # loseConnection() and connectionLost() are safe before startListening().
    connected = 0
    disconnected = 0
    disconnecting = 0

    def __init__(self, port, factory, backlog=5, interface='', reactor=None):
        """Initialize with a numeric port to listen on.

        @param reactor: the reactor to register with; defaults to the global
            twisted.internet.reactor.
        """
        self.port = port
        self.factory = factory
        self.backlog = backlog
        self.interface = interface
        if reactor is None:
            from twisted.internet import reactor
        self.reactor = reactor
        self.overlapped = pywintypes.OVERLAPPED()
        self.overlapped.hEvent = win32event.CreateEvent(None, 0, 0, None)
        # Receives the local and remote addresses written by AcceptEx.
        self.buffer = win32file.AllocateReadBuffer(64)

    def __repr__(self):
        return "<%s on %s>" % (self.factory.__class__, self.port)

    def createInternetSocket(self):
        """(internal) create an AF_INET socket.
        """
        s = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
        return s

    def __getstate__(self):
        """(internal) get my state for persistence, minus live sockets.
        """
        dct = self.__dict__.copy()
        for transient in ('socket', 'fileno', 'newSocket'):
            if transient in dct:
                del dct[transient]
        return dct

    def startListening(self):
        """Create and bind my socket, and begin listening on it.

        This is called on unserialization, and must be called after creating a
        server to begin listening on the specified port.
        """
        log.msg("%s starting on %s"%(self.factory.__class__, self.port))
        skt = self.createInternetSocket()
        skt.setblocking(0)
        skt.bind((self.interface, self.port))
        skt.listen(self.backlog)
        self.connected = 1
        self.socket = skt
        winSocket = skt.fileno()
        # The reactor calls doRead when an AcceptEx completes.
        self.overlapped.object = "doRead"
        self.reactor.registerFile(winSocket, self)
        self.startReading()

    def startReading(self):
        """Start an overlapped AcceptEx into a fresh socket."""
        self.newSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.newSocket.setblocking(0)
        win32file.AcceptEx(self.socket, self.newSocket, self.buffer, self.overlapped)

    def doRead(self):
        """Called when an accept on my socket has completed.

        This finishes setting up the accepted socket, asks the factory's
        buildProtocol() for a protocol and wraps the socket in a transport,
        then starts accepting the next connection.
        """
        if not self.connected:
            return
        try:
            skt = self.newSocket
            del self.newSocket

            # make new socket inherit properties from the port's socket
            skt.setsockopt(socket.SOL_SOCKET, win32file.SO_UPDATE_ACCEPT_CONTEXT, struct.pack("I", self.socket.fileno()))

            # get new socket's address
            family, localaddr, addr = win32file.GetAcceptExSockaddrs(self.socket, self.buffer)

            # build the new protocol
            protocol = self.factory.buildProtocol(addr)
            if protocol is None:
                skt.close()
            else:
                s = self.sessionno
                self.sessionno = s+1
                transport = self.transport(skt, protocol, addr, self, s)
                protocol.makeConnection(transport, self)
        except:
            log.deferr()

        self.startReading()

    def loseConnection(self):
        """ Stop accepting connections on this port.

        This schedules self.connectionLost() to run on the next reactor
        iteration, which shuts down my socket.
        """
        self.disconnecting = 1
        if self.connected:
            self.reactor.callLater(0, self.connectionLost)

    def connectionLost(self):
        """Close my listening socket and any socket awaiting an accept.

        Calls after the first one do nothing.
        """
        if self.disconnected:
            return
        log.msg('(Port %s Closed)' % self.port)
        self.disconnected = 1
        self.connected = 0
        self.socket.close()
        del self.socket
        # The socket prepared for the pending AcceptEx would otherwise leak;
        # doRead ignores the cancelled accept because connected is now 0.
        if hasattr(self, 'newSocket'):
            self.newSocket.close()
            del self.newSocket
        self.factory.stopFactory()

    def logPrefix(self):
        """Returns the name of my class, to prefix log entries with.
        """
        return str(self.factory.__class__)

    def getHost(self):
        """Returns a tuple of ('INET', hostname, port).

        This indicates the servers address.
        """
        return ('INET',)+self.socket.getsockname()


class Client(Connection):
    """A client for TCP (and similar) sockets.

    The host name is resolved (unless it is already an IP address), and the
    connect() call is handed to threadtask's dispatcher.  When it succeeds,
    the reactor thread finishes initializing this object as a Connection.
    """
    def __init__(self, host, port, protocol, timeout=None, connector=None, reactor=None):
        """Initialize the client, set up its socket, and schedule the connection attempt.

        @param timeout: seconds after which the attempt is reported as failed
            if it has not connected yet, or None for no timeout.
        @param reactor: the reactor to use; defaults to the global
            twisted.internet.reactor.
        """
        if reactor is None:
            from twisted.internet import reactor
        self.reactor = reactor
        self.socket = self.createInternetSocket()
        self.addr = (host, port)
        self.protocol = protocol
        self.host = host
        self.port = port
        self.connector = connector
        self.logstr = self.protocol.__class__.__name__+",client"
        if timeout is not None:
            self.reactor.callLater(timeout, self.failIfNotConnected)
        self.reactor.callLater(0, self.resolveAddress)

    def failIfNotConnected(self, *ignored):
        """Report a failed connection attempt, at most once.

        Extra positional arguments (e.g. a resolver failure) are ignored.
        """
        if (not self.connected) and (not self.disconnected):
            # Mark the attempt as finished so a later timeout, resolver error
            # or connect error does not report the failure again, and so a
            # late successful connect is discarded by _connected.
            self.disconnected = 1
            if self.connector:
                self.connector.connectionFailed()
            self.protocol.connectionFailed()

    def createInternetSocket(self):
        """(internal) Create an AF_INET socket.
        """
        # factored out so as to minimise the code necessary for SecureClient
        return socket.socket(socket.AF_INET,socket.SOCK_STREAM)

    def resolveAddress(self):
        """Resolve the host name (if needed), then start connecting."""
        if abstract.isIPAddress(self.addr[0]):
            self._setRealAddress(self.addr[0])
        else:
            self.reactor.resolve(self.addr[0]
                            ).addCallbacks(
                self._setRealAddress, self.failIfNotConnected
                ).arm()

    def _setRealAddress(self, address):
        """Record the resolved address and dispatch _connect to a thread."""
        self.realAddress = (address, self.addr[1])
        import threadtask
        threadtask.theDispatcher.dispatch(log.logOwner.owner(), self._connect)

    def _connect(self):
        """Runs in thread"""
        try:
            self.socket.connect(self.realAddress)
        except socket.error:
            self.reactor.callFromThread(self.failIfNotConnected)
        else:
            self.reactor.callFromThread(self._connected)

    def _connected(self):
        """Called in the reactor thread when connect() succeeded."""
        if self.disconnected:
            # The attempt was already reported as failed (e.g. it timed out);
            # drop the late connection instead of also reporting success.
            self.socket.close()
            return
        self.connected = 1
        Connection.__init__(self, self.socket, self.protocol)
        self.protocol.makeConnection(self)
        self.startReading()

    def connectionLost(self):
        """Report a failed attempt if never connected, else lose the connection."""
        if not self.connected:
            self.failIfNotConnected()
        else:
            Connection.connectionLost(self)
            if self.connector:
                self.connector.connectionLost()

    def getHost(self):
        """Returns a tuple of ('INET', hostname, port).

        This indicates the address from which I am connecting.
        """
        return ('INET',)+self.socket.getsockname()

    def getPeer(self):
        """Returns a tuple of ('INET', hostname, port).

        This indicates the address that I am connected to.  I implement
        twisted.protocols.protocol.Transport.
        """
        return ('INET',)+self.addr

    def __repr__(self):
        s = '<%s to %s at %x>' % (self.__class__, self.addr, id(self))
        return s



# Public API of this module: the reactor class and its installer.
__all__ = [
    "Win32Reactor",
    "install",
]