"""HyperText Transfer Protocol implementation.
The second coming.
Future Plans:
- HTTP client support will at some point be refactored to support HTTP/1.1.
- Accept chunked data from clients in server.
- Other missing HTTP features from the RFC.
Maintainer: U{James Y Knight <mailto:foom@fuhm.net>}
"""
from cStringIO import StringIO
import tempfile
import base64, binascii
import cgi
import socket
import math
import time
import calendar
import warnings
import os
from twisted.internet import interfaces, reactor, protocol, address
from twisted.protocols import policies, basic
from twisted.python import log, components
try: from twisted.protocols._c_urlarg import unquote
except ImportError:
from urllib import unquote
import responsecode
protocol_version = "HTTP/1.1"
def parse_qs(qs, keep_blank_values=0, strict_parsing=0, unquote=unquote):
"""like cgi.parse_qs, only with custom unquote function"""
d = {}
items = [s2 for s1 in qs.split("&") for s2 in s1.split(";")]
for item in items:
try:
k, v = item.split("=", 1)
except ValueError:
if strict_parsing:
raise
continue
if v or keep_blank_values:
k = unquote(k.replace("+", " "))
v = unquote(v.replace("+", " "))
if k in d:
d[k].append(v)
else:
d[k] = [v]
return d
def toChunk(data):
"""Convert string to a chunk.
@returns: a tuple of strings representing the chunked encoding of data"""
return ("%x\r\n" % len(data), data, "\r\n")
def fromChunk(data):
"""Convert chunk to string.
@returns: tuple (result, remaining), may raise ValueError.
"""
prefix, rest = data.split('\r\n', 1)
length = int(prefix, 16)
if not rest[length:length+2] == '\r\n':
raise ValueError, "chunk must end with CRLF"
return rest[:length], rest[length+2:]
class StringTransport:
"""
I am a StringIO wrapper that conforms for the transport API. I support
the 'writeSequence' method.
"""
def __init__(self):
self.s = StringIO()
def writeSequence(self, seq):
self.s.write(''.join(seq))
def __getattr__(self, attr):
return getattr(self.__dict__['s'], attr)
NO_BODY_CODES = (204, 304)
class Request:
"""A HTTP request.
Subclasses should override the process() method to determine how
the request will be processed.
@ivar method: The HTTP method that was used.
@ivar uri: The full URI that was requested (includes arguments).
@ivar path: The path only (arguments not included).
@ivar args: All of the arguments, including URL and POST arguments.
@type args: A mapping of strings (the argument names) to lists of values.
i.e., ?foo=bar&foo=baz&quux=spam results in
{'foo': ['bar', 'baz'], 'quux': ['spam']}.
@ivar received_headers: All received headers
"""
__implements__ = interfaces.IConsumer,
producer = None
finished = 0
code = responsecode.OK
code_message = responsecode.RESPONSES[responsecode.OK]
startedWriting = 0
chunked = 0
sentLength = 0
_foreceSSL = False
def __init__(self, channel, queued, command, path, version, raw_in_headers):
"""
@param channel: the channel we're connected to.
@param queued: are we in the request queue, or can we start writing to
the transport?
"""
self.channel = channel
self.queued = queued
self.method = command
self.uri = path
self.clientproto = version
self.out_headers = http_headers.ResponseHeaders()
self.in_headers = http_headers.RequestHeaders(raw_in_headers)
if queued:
self.transport = StringTransport()
else:
self.transport = self.channel.transport
def _cleanup(self):
"""Called when have finished responding and are no longer queued."""
if self.producer:
log.err(RuntimeError("Producer was not unregistered for %s" % self.uri))
self.unregisterProducer()
self.channel.requestDone(self)
del self.channel
def noLongerQueued(self):
"""Notify the object that it is no longer queued.
We start writing whatever data we have to the transport, etc.
This method is not intended for users.
"""
if not self.queued:
raise RuntimeError, "noLongerQueued() got called unnecessarily."
self.queued = 0
data = self.transport.getvalue()
self.transport = self.channel.transport
if data:
self.transport.write(data)
if (self.producer is not None) and not self.finished:
self.transport.registerProducer(self.producer, True)
if self.finished:
self._cleanup()
def handleContentChunk(self, data):
"""Called by channel when a piece of data has been received.
Should be overridden by a subclass to do something appropriate."""
pass
def handleContentComplete(self):
"""Called by channel when all data has been received.
This method is not intended for users.
"""
self.args = {}
self.stack = []
x = self.uri.split('?')
if len(x) == 1:
self.path = self.uri
else:
if len(x) != 2:
log.msg("May ignore parts of this invalid URI: %s"
% repr(self.uri))
self.path, argstring = x[0], x[1]
self.args = parse_qs(argstring, 1)
self.client = self.channel.transport.getPeer()
self.host = self.channel.transport.getHost()
def __repr__(self):
return '<%s %s %s>'% (self.method, self.uri, self.clientproto)
def registerProducer(self, producer, streaming=True):
"""Register a producer.
Only streaming (push) producers supported."""
if not streaming:
raise ValueError, "non-streaming (pull) producers not supported."
if self.producer:
raise ValueError, "registering producer %s before previous one (%s) was unregistered" % (producer, self.producer)
self.producer = producer
if self.queued:
producer.pauseProducing()
else:
self.transport.registerProducer(producer, True)
def unregisterProducer(self):
"""Unregister the producer."""
if not self.queued:
self.transport.unregisterProducer()
self.producer = None
def _sendError(self, code, resp=''):
self.transport.write('%s %s %s\r\n\r\n' % (self.clientproto, code, resp))
def finish(self):
"""We are finished writing data."""
if self.finished:
warnings.warn("Warning! request.finish called twice.", stacklevel=2)
return
if not self.startedWriting:
self.write('')
if self.chunked and self.code in NO_BODY_CODES:
raise RuntimeError, "we screwed up"
if self.chunked:
self.transport.write("0\r\n\r\n")
if hasattr(self.channel, "factory"):
self.channel.factory.log(self)
self.finished = 1
if not self.queued:
self._cleanup()
def _initialWrite(self):
self.startedWriting = 1
version = self.clientproto
l = []
l.append('%s %s %s\r\n' % (version, self.code,
self.code_message))
if ((version == "HTTP/1.1") and
(self.headers.get('content-length', None) is None) and
(self.code not in NO_BODY_CODES)):
l.append("%s: %s\r\n" % ('Transfer-encoding', 'chunked'))
self.chunked = 1
for name, value in self.headers.items():
l.append("%s: %s\r\n" % (name.capitalize(), value))
l.append("\r\n")
self.transport.writeSequence(l)
if self.method == "HEAD":
self.write = lambda data: None
return False
if self.code in NO_BODY_CODES:
self.write = lambda data: None
return False
return True
def write(self, data):
"""
Write some data as a result of an HTTP request. The first
time this is called, it writes out response data.
"""
if not self.startedWriting:
if not self._initialWrite():
return
self.sentLength = self.sentLength + len(data)
if data:
if self.chunked:
self.transport.writeSequence(toChunk(data))
else:
self.transport.write(data)
def writeFile(self, file):
"""
Write data from a file, possibly more efficiently than write(data)
would do. Otherwise identical to write(file.read()).
v """
self.write(file.read())
def setResponseCode(self, code, message=None):
"""Set the HTTP response code.
"""
self.code = code
if message:
self.code_message = message
else:
self.code_message = RESPONSES.get(code, "Unknown Status")
def setHeader(self, k, v):
"""Set an outgoing HTTP header.
"""
self.headers[k.lower()] = v
def redirect(self, url):
"""Utility function that does a redirect.
The request should have finish() called after this.
"""
self.setResponseCode(FOUND)
self.setHeader("location", url)
def setLastModified(self, when):
"""Set the X{Last-Modified} time for the response to this request.
If I am called more than once, I ignore attempts to set
Last-Modified earlier, only replacing the Last-Modified time
if it is to a later value.
@param when: The last time the resource being returned was
modified, in seconds since the epoch.
@type when: number
"""
when = long(math.ceil(when))
lastModified = self.getRespHeader('Last-Modified')
if not lastModified or (lastModified < when):
self.setRespHeader('Last-Modified', when)
def checkBody(self):
"""Check to see if this request should have a body. As a side-effect
may modify my response code to L{NOT_MODIFIED} or L{PRECONDITION_FAILED},
if appropriate.
Call this function after setting the ETag and Last-Modified
output headers, but before actually proceeding with request
processing.
This examines the appropriate request headers for conditionals,
the existing response headers and sets the response code as necessary.
@return: True if you should write a body, False if you should
not.
"""
tags = self.getReqHeader("if-none-match")
etag = self.getRespHeader("etag")
if tags:
if (etag in tags) or ('*' in tags):
self.setResponseCode(((self.method in ("HEAD", "GET"))
and NOT_MODIFIED)
or PRECONDITION_FAILED)
return False
modified_since = self.getReqHeader('if-modified-since')
if modified_since:
if modified_since >= self.lastModified:
self.setResponseCode(NOT_MODIFIED)
return False
if self.method == "HEAD":
return False
return True
def getRequestHostname(self):
"""Get the hostname that the user passed in to the request.
This will either use the Host: header (if it is available) or the
host we are listening on if the header is unavailable.
"""
return (self.getReqHeader('host') or
socket.gethostbyaddr(self.getHost()[1])[0]
).split(':')[0]
def getHost(self):
"""Get my originally requesting transport's host.
Don't rely on the 'transport' attribute, since Request objects may be
copied remotely. For information on this method's return value, see
twisted.internet.tcp.Port.
"""
return self.host
def setHost(self, host, port, ssl=0):
"""Change the host and port the request thinks it's using.
This method is useful for working with reverse HTTP proxies (e.g.
both Squid and Apache's mod_proxy can do this), when the address
the HTTP client is using is different than the one we're listening on.
For example, Apache may be listening on https://www.example.com, and then
forwarding requests to http://localhost:8080, but we don't want HTML produced
by Twisted to say 'http://localhost:8080', they should say 'https://www.example.com',
so we do::
request.setHost('www.example.com', 443, ssl=1)
This method is experimental.
"""
self._forceSSL = ssl
self.received_headers["host"] = host
self.host = address.IPv4Address("TCP", host, port)
def getClientIP(self):
if isinstance(self.client, address.IPv4Address):
return self.client.host
else:
return None
def isSecure(self):
return self._forceSSL or components.implements(self.channel.transport, interfaces.ISSLTransport)
def _authorize(self):
try:
authh = self.getReqHeaderRaw("Authorization")
if not authh:
self.user = self.password = ''
return
bas, upw = authh.split()
if bas.lower() != "basic":
raise ValueError
upw = base64.decodestring(upw)
self.user, self.password = upw.split(':', 1)
except (binascii.Error, ValueError):
self.user = self.password = ""
except:
log.err()
self.user = self.password = ""
def getUser(self):
try:
return self.user
except:
pass
self._authorize()
return self.user
def getPassword(self):
try:
return self.password
except:
pass
self._authorize()
return self.password
def connectionLost(self, reason):
"""connection was lost"""
pass
class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin):
"""A receiver for HTTP requests. Handles the hop-by-hop behavior."""
maxHeaderLength = 10240 requestFactory = Request
_partialheader = ''
_first_line = 1
_headerlen = 0
_savedTimeOut = None
def __init__(self):
self.reqHeaders = {}
self.requests = []
def connectionMade(self):
self.setTimeout(self.timeOut)
def lineReceived(self, line):
self.resetTimeout()
if self._first_line:
if not self.persistent:
self.dataReceived = self.lineReceived = lambda *args: None
return
if not line and self._first_line == 1:
self._first_line = 2
return
self._first_line = 0
parts = line.split()
if len(parts) != 3:
self.transport.write("HTTP/1.1 400 Bad Request\r\n\r\n")
self.transport.loseConnection()
return
self._command, self._path, self._version = parts
elif line == '':
if self._partialheader:
self.headerReceived(self._partialheader)
self._partialheader = ''
self.allHeadersReceived()
if self.length == 0:
self.allContentReceived()
else:
self.setRawMode()
elif line[0] in ' \t':
self._partialheader = self._partialheader+line
else:
if self._partialheader:
self.headerReceived(self._partialheader)
self._partialheader = line
def headerReceived(self, line):
"""Store this header away. Check for too much header data
(> maxHeaderLength) and abort the connection if so.
"""
name,val = line.split(':', 1)
val.lstrip(' \t')
old = self._reqHeaders.get(name, None)
if old is None:
old = []
self._reqHeaders[name]=old
old.append(val)
self._headerlen = self._headerlen+ len(line)
if self._headerlen > self.maxHeaderLength:
self.transport.write("HTTP/1.1 400 Bad Request\r\n\r\n")
self.transport.loseConnection()
def allHeadersReceived(self):
self.length = 0
self.persistent = False
connHeaders = self.splitConnectionHeaders()
request = self.requestFactory(self, len(self.requests), self._command, self._path, self._version, self._reqHeaders)
self.requests.append(request)
del self._command, self._path, self._version
del self._headerlen
self._reqHeaders = {}
self.persistent = self.checkPersistence(request)
self.length=req.getReqHeader('Content-Length')
def allContentReceived(self):
self.length = 0
self._first_line = 1
if self.timeOut:
self._savedTimeOut = self.setTimeout(None)
req = self.requests[-1]
req.handleContentComplete()
def rawDataReceived(self, data):
if len(data) < self.length:
self.requests[-1].handleContentChunk(data)
self.length = self.length - len(data)
else:
self.requests[-1].handleContentChunk(data[:self.length])
extraneous = data[self.length:]
self.allContentReceived()
self.setLineMode(extraneous)
def requestDone(self, request):
"""Called by first request in queue when it is done."""
if request != self.requests[0]: raise TypeError
del self.requests[0]
if self.persistent:
if self.requests:
self.requests[0].noLongerQueued()
else:
if self._savedTimeOut:
self.setTimeout(self._savedTimeOut)
else:
self.transport.loseConnection()
def timeoutConnection(self):
log.msg("Timing out client: %s" % str(self.transport.getPeer()))
policies.TimeoutMixin.timeoutConnection(self)
def connectionLost(self, reason):
self.setTimeout(None)
for request in self.requests:
request.connectionLost(reason)
def splitConnectionHeaders(self):
def move(name):
h = reqHeaders.getRawHeader(name, None)
if h is not None:
reqHeaders.removeHeader(name)
connHeaders.setRawHeader(name, h)
connHeaderNames = ['Connection', 'Keep-Alive', 'Proxy-Authenticate', 'Proxy-Authorization', 'TE', 'Trailers', 'Transfer-Encoding', 'Upgrade']
reqHeaders = self._reqHeaders
connHeaders = http_headers.Headers()
move('Connection')
if connHeaders.hasHeader('Connection'):
if self._version != "1.1":
for name in connHeaders.getHeader('Connection'):
if reqHeaders.hasHeader(name):
reqHeaders.removeHeader(name)
else:
connHeaderNames.extend(connHeaders.getHeader('Connection'))
for headername in connHeaders:
move(headername)
h = reqHeaders.getRawHeader('Content-Length', None)
if h is not None:
connHeaders.setRawHeader('Content-Length', h)
return connHeaders
def checkPersistence(self, request):
"""Check if the channel should close or not."""
if self.version == "HTTP/1.1":
if 'close' in self.getReqHeader('connection'):
self.addRespHeader('connection', 'close')
return 0
else:
return 1
else:
return 0
class HTTPFactory(protocol.ServerFactory):
"""Factory for HTTP server."""
protocol = HTTPChannel
logPath = None
timeOut = 60 * 60 * 12
def __init__(self, logPath=None, timeout=60*60*12):
if logPath is not None:
logPath = os.path.abspath(logPath)
self.logPath = logPath
self.timeOut = timeout
def buildProtocol(self, addr):
p = protocol.ServerFactory.buildProtocol(self, addr)
p.timeOut = self.timeOut
return p
def startFactory(self):
_logDateTimeStart()
if self.logPath:
self.logFile = self._openLogFile(self.logPath)
else:
self.logFile = log.logfile
def stopFactory(self):
if hasattr(self, "logFile"):
if self.logFile != log.logfile:
self.logFile.close()
del self.logFile
_logDateTimeStop()
def _openLogFile(self, path):
"""Override in subclasses, e.g. to use twisted.python.logfile."""
f = open(path, "a", 1)
f.seek(2, 0)
return f
def log(self, request):
"""Log a request's result to the logfile, by default in combined log format."""
line = '%s - - %s "%s" %d %s "%s" "%s"\n' % (
request.getClientIP(),
_logDateTime,
'%s %s %s' % (request.method, request.uri, request.clientproto),
request.code,
request.sentLength or "-",
request.getReqHeader("referer") or "-",
request.getReqHeader("user-agent") or "-")
self.logFile.write(line)