from sets import Set
import warnings
from twisted.internet import interfaces, defer, main
from twisted.persisted import styles
from twisted.python import log, failure
from ops import ReadFileOp, WriteFileOp
from util import StateEventMachineType
class ConnectedSocket(log.Logger, styles.Ephemeral, object):
    """Transport wrapping an already-connected socket.

    Reads and writes are issued through ``ReadFileOp`` / ``WriteFileOp``
    objects; their completions come back as the ``readDone`` / ``readErr`` /
    ``writeDone`` / ``writeErr`` events. The instance moves through the
    states "connected" -> "disconnecting" -> "disconnected", and every event
    is handled by the method named ``handle_<state>_<event>``.

    NOTE(review): methods such as ``write`` and ``loseConnection`` are called
    in this class but never defined here; presumably
    ``StateEventMachineType`` generates one dispatcher per entry of
    ``events`` -- confirm in ``util``.
    """

    __metaclass__ = StateEventMachineType
    __implements__ = interfaces.ITransport, interfaces.IProducer, interfaces.IConsumer
    events = ["write", "loseConnection", "writeDone", "writeErr", "readDone", "readErr", "shutdown"]
    # ``**`` is right-associative: 2**(2**(2**2)) == 2**16 == 65536.
    bufferSize = 2**2**2**2
    producer = None
    # Default so milkProducer() can read the flag before stfuProducer() has
    # ever run (previously an AttributeError for streaming producers).
    producerPaused = False
    writing = False
    reading = False
    write_shutdown = False
    read_shutdown = False

    def __init__(self, socket, protocol, sockfactory):
        """Set up buffers and the read/write operations for ``socket``.

        @param socket: the connected socket object; its ``fileno()``,
            ``shutdown()`` and ``close()`` methods are used.
        @param protocol: receives ``dataReceived`` and ``connectionLost``.
        @param sockfactory: notified through ``connectionLost(reason)`` when
            the connection goes away.
        """
        self.state = "connected"
        # Imported here rather than at module level, as in the original code.
        from twisted.internet import reactor
        self.socket = socket
        self.protocol = protocol
        self.sf = sockfactory
        self.writebuf = []
        self.readbuf = reactor.AllocateReadBuffer(self.bufferSize)
        self.reactor = reactor
        self.bufferEvents = {"buffer full": Set(), "buffer empty": Set()}
        # Number of bytes of writebuf[0] that have already been written.
        self.offset = 0
        self.writeBufferedSize = 0
        self.read_op = ReadFileOp(self)
        self.write_op = WriteFileOp(self)

    def addBufferCallback(self, handler, event):
        """Register ``handler`` for ``event`` ("buffer full" / "buffer empty")."""
        self.bufferEvents[event].add(handler)

    def removeBufferCallback(self, handler, event):
        """Unregister ``handler``; raises KeyError if it is not registered."""
        self.bufferEvents[event].remove(handler)

    def callBufferHandlers(self, event, *a, **kw):
        """Invoke every handler registered for ``event``.

        A copy of the set is iterated because handlers may add or remove
        callbacks (including themselves) while running.
        """
        for handler in self.bufferEvents[event].copy():
            handler(*a, **kw)

    def handle_connected_write(self, data):
        """Queue ``data`` for writing and start a write if none is pending."""
        # Coalesce into the last chunk while it is still below bufferSize;
        # otherwise start a new chunk.
        if self.writebuf and len(self.writebuf[-1]) < self.bufferSize:
            self.writebuf[-1] += data
        else:
            self.writebuf.append(data)
        self.writeBufferedSize += len(data)
        if self.writeBufferedSize >= self.bufferSize:
            self.callBufferHandlers(event="buffer full")
        if not self.writing:
            self.startWriting()

    handle_disconnecting_write = handle_connected_write

    def handle_disconnected_write(self, data):
        """Discard writes made after the connection is gone."""
        pass

    def writeSequence(self, iovec):
        """Write the concatenation of the strings in ``iovec``."""
        self.write("".join(iovec))

    def _cbDisconnecting(self):
        """Finish a pending ``loseConnection`` once nothing is left to send.

        Does nothing while a producer is still registered, or when the
        disconnect has already been completed (the handler can be invoked a
        second time from a stale copy of the callback set).
        """
        if self.producer or self.state != "disconnecting":
            return
        self.removeBufferCallback(self._cbDisconnecting, "buffer empty")
        self.connectionLost(failure.Failure(main.CONNECTION_DONE))

    def handle_connected_loseConnection(self):
        """Stop reading and close, after flushing any write in progress."""
        self.stopReading()
        if self.writing:
            # Defer the close until the write buffer has drained.
            self.addBufferCallback(self._cbDisconnecting, "buffer empty")
            self.state = "disconnecting"
        else:
            self.connectionLost(failure.Failure(main.CONNECTION_DONE))

    def handle_disconnecting_loseConnection(self):
        """Ignore repeated requests; a disconnect is already under way."""
        pass

    handle_disconnected_loseConnection = handle_disconnecting_loseConnection

    def _cbWriteShutdown(self):
        """Shut down the write side once the write buffer has drained."""
        self.removeBufferCallback(self._cbWriteShutdown, "buffer empty")
        self.socket.shutdown(1)

    def handle_connected_shutdown(self, write = False, read = False):
        """Half-close the connection in the requested direction(s).

        Each direction is shut down at most once. The write side is only
        shut down after pending data has been written.

        NOTE(review): 0 and 1 are presumably the values of socket.SHUT_RD and
        socket.SHUT_WR -- confirm against the socket type used here.
        """
        if read and not self.read_shutdown:
            self.read_shutdown = True
            self.socket.shutdown(0)
        if write and not self.write_shutdown:
            self.write_shutdown = True
            if self.writing:
                self.addBufferCallback(self._cbWriteShutdown, "buffer empty")
            else:
                self.socket.shutdown(1)

    def connectionLost(self, reason):
        """Tear the connection down and notify the factory and the protocol.

        Protocols whose ``connectionLost`` takes no ``reason`` argument are
        still supported, with a DeprecationWarning.
        """
        self.state = "disconnected"
        protocol = self.protocol
        del self.protocol
        self.socket.close()
        del self.socket
        self.sf.connectionLost(reason)
        try:
            protocol.connectionLost(reason)
        except TypeError as e:
            # Only the exact arity error is treated as an old-style protocol;
            # any other TypeError is a genuine failure and propagates.
            if e.args and e.args[0] == "connectionLost() takes exactly 1 argument (2 given)":
                warnings.warn("Protocol %s's connectionLost should accept a reason argument" % protocol,
                              category=DeprecationWarning, stacklevel=2)
                protocol.connectionLost()
            else:
                raise

    def startReading(self):
        """Issue a read into ``self.readbuf``; a WindowsError ends the connection."""
        self.reading = True
        try:
            self.read_op.initiateOp(self.socket.fileno(), self.readbuf)
        except WindowsError:
            self.connectionLost(failure.Failure(main.CONNECTION_DONE))

    def stopReading(self):
        """Stop re-issuing reads once the current one completes."""
        self.reading = False

    def handle_connected_readDone(self, bytes):
        """Deliver the first ``bytes`` bytes of the read buffer to the protocol."""
        self.protocol.dataReceived(self.readbuf[:bytes])
        if self.reading:
            self.startReading()

    def handle_disconnecting_readDone(self, bytes):
        """Drop data that arrives while disconnecting."""
        pass

    def handle_connected_readErr(self, ret, bytes):
        """Treat any read error as the end of the connection."""
        self.connectionLost(failure.Failure(main.CONNECTION_DONE))

    handle_disconnecting_readErr = handle_connected_readErr

    def handle_disconnected_readErr(self, ret, bytes):
        """Ignore read errors reported after disconnection."""
        pass

    def handle_disconnected_readDone(self, bytes):
        """Ignore read completions reported after disconnection."""
        pass

    def startWriting(self):
        """Issue a write for the unsent tail of the first queued chunk."""
        self.writing = True
        b = buffer(self.writebuf[0], self.offset)
        try:
            self.write_op.initiateOp(self.socket.fileno(), b)
        except WindowsError:
            self.connectionLost(failure.Failure(main.CONNECTION_DONE))

    def stopWriting(self):
        """Mark that no write is in progress."""
        self.writing = False

    def handle_connected_writeDone(self, bytes):
        """Account for ``bytes`` written, then continue or signal "buffer empty"."""
        self.offset += bytes
        self.writeBufferedSize -= bytes
        if self.offset == len(self.writebuf[0]):
            # First chunk fully sent; move on to the next one.
            del self.writebuf[0]
            self.offset = 0
        if not self.writebuf:
            self.writing = False
            self.callBufferHandlers(event="buffer empty")
        else:
            # Either a partial write or more chunks queued: keep going.
            self.startWriting()

    handle_disconnecting_writeDone = handle_connected_writeDone

    def handle_connected_writeErr(self, ret, bytes):
        """Treat any write error as the end of the connection."""
        self.connectionLost(failure.Failure(main.CONNECTION_DONE))

    handle_disconnecting_writeErr = handle_connected_writeErr

    def handle_disconnected_writeErr(self, ret, bytes):
        """Ignore write errors reported after disconnection."""
        pass

    def handle_disconnected_writeDone(self, bytes):
        """Ignore write completions reported after disconnection."""
        pass

    def registerProducer(self, producer, streaming):
        """Register ``producer`` to feed this transport.

        @param streaming: true for a producer that writes until paused;
            false for one that is asked for data via ``resumeProducing``.
        @raise RuntimeError: if another producer is still registered.

        On a disconnected transport the producer is stopped immediately and
        not registered.
        """
        if self.producer is not None:
            raise RuntimeError("Cannot register producer %s, because producer %s was never unregistered." % (producer, self.producer))
        if self.state == "disconnected":
            producer.stopProducing()
        else:
            self.producer = producer
            self.streamingProducer = streaming
            # A new producer starts unpaused regardless of the previous one.
            self.producerPaused = False
            self.addBufferCallback(self.milkProducer, "buffer empty")
            self.addBufferCallback(self.stfuProducer, "buffer full")
            if not streaming:
                producer.resumeProducing()

    def milkProducer(self):
        """"buffer empty" handler: ask the producer for more data.

        A streaming producer is only resumed if it had been paused.
        """
        if not self.streamingProducer or self.producerPaused:
            self.producer.resumeProducing()
            self.producerPaused = False

    def stfuProducer(self):
        """"buffer full" handler: pause the producer."""
        self.producerPaused = True
        self.producer.pauseProducing()

    def unregisterProducer(self):
        """Detach the current producer; a no-op if none is registered.

        If a ``loseConnection`` was waiting on this producer and nothing is
        left to write, the disconnect is completed here, because no further
        "buffer empty" event would otherwise trigger it.
        """
        if self.producer is None:
            return
        self.removeBufferCallback(self.stfuProducer, "buffer full")
        self.removeBufferCallback(self.milkProducer, "buffer empty")
        self.producer = None
        if self.state == "disconnecting" and not self.writing:
            self._cbDisconnecting()

    def stopConsuming(self):
        """Detach the producer and close the connection."""
        self.unregisterProducer()
        self.loseConnection()

    def resumeProducing(self):
        """IProducer: resume delivering received data."""
        self.startReading()

    def pauseProducing(self):
        """IProducer: stop delivering received data."""
        self.stopReading()

    def stopProducing(self):
        """IProducer: close the connection."""
        self.loseConnection()

    def __repr__(self):
        # NOTE(review): ``repstr`` is not assigned anywhere in this class;
        # presumably subclasses set it -- confirm.
        return self.repstr

    def logPrefix(self):
        # NOTE(review): ``logstr`` is not assigned anywhere in this class;
        # presumably subclasses set it -- confirm.
        return self.logstr

    # True while a loseConnection is waiting for pending writes to finish.
    disconnecting = property(lambda self: self.state == "disconnecting")