import types, struct
from twisted.internet import protocol, error
from twisted.python.failure import Failure
from slicer import RootSlicer, RootUnslicer, \
UnbananaFailure, VocabSlicer, SimpleTokens
from tokens import Violation, SIZE_LIMIT, STRING, LIST, INT, NEG, \
LONGINT, LONGNEG, VOCAB, FLOAT, OPEN, CLOSE, ABORT, tokenNames, \
BananaError, BananaError2
def int2b128(integer, stream):
    """Encode a non-negative integer as base-128 digits.

    Digits are emitted least significant first, each as a one-character
    string passed to stream(). Every emitted character is below 0x80, so
    header digits never look like a type byte. Zero encodes as a single
    chr(0).

    @raise ValueError: if integer is negative. This used to be an assert;
        with assertions disabled (python -O) a negative value looped
        forever, because shifting -1 right never reaches 0.
    """
    if integer < 0:
        raise ValueError("can only encode positive integers")
    if integer == 0:
        stream(chr(0))
        return
    while integer:
        stream(chr(integer & 0x7f))
        integer >>= 7
def b1282int(st):
    """Decode base-128 digits (least significant first) into an integer.

    This is the inverse of int2b128. An empty string decodes to 0.
    """
    value = 0
    # Horner's rule: walk from the most significant digit down.
    for digit in st[::-1]:
        value = value * 128 + ord(digit)
    return value
def long_to_bytes(n, blocksize=0):
"""long_to_bytes(n:long, blocksize:int) : string
Convert a long integer to a byte string.
If optional blocksize is given and greater than zero, pad the front of
the byte string with binary zeros so that the length is a multiple of
blocksize.
"""
s = ''
n = long(n)
pack = struct.pack
while n > 0:
s = pack('>I', n & 0xffffffffL) + s
n = n >> 32
for i in range(len(s)):
if s[i] != '\000':
break
else:
s = '\000'
i = 0
s = s[i:]
if blocksize > 0 and len(s) % blocksize:
s = (blocksize - len(s) % blocksize) * '\000' + s
return s
def bytes_to_long(s):
"""bytes_to_long(string) : long
Convert a byte string to a long integer.
This is (essentially) the inverse of long_to_bytes().
"""
acc = 0L
unpack = struct.unpack
length = len(s)
if length % 4:
extra = (4 - length % 4)
s = '\000' * extra + s
length = length + extra
for i in range(0, length, 4):
acc = (acc << 32) + unpack('>I', s[i:i+4])[0]
return acc
# dataReceived treats the first byte >= this value as a token's type byte;
# header digits written by int2b128 are always below it.
HIGH_BIT_SET = '\x80'
class Banana(protocol.Protocol):
slicerClass = RootSlicer
unslicerClass = RootUnslicer
hangupOnLengthViolation = False
debug = False
def __init__(self):
self.initSend()
self.initReceive()
def initSend(self):
self.rootSlicer = self.slicerClass()
self.rootSlicer.protocol = self
self.slicerStack = [self.rootSlicer]
self.openCount = 0
self.outgoingVocabulary = {}
def send(self, obj):
assert(len(self.slicerStack) == 1)
assert(isinstance(self.slicerStack[0], self.slicerClass))
if type(obj) in SimpleTokens:
self.sendToken(obj)
return
self.doSlice(obj)
def setOutgoingVocabulary(self, vocabDict):
for key,value in vocabDict.items():
assert(isinstance(key, types.IntType))
assert(isinstance(value, types.StringType))
s = VocabSlicer()
s.protocol = self
self.slicerStack.append(s)
self.doSlice(vocabDict)
self.slicerStack.pop(-1)
self.outgoingVocabulary = dict(zip(vocabDict.values(),
vocabDict.keys()))
def doSlice(self, obj):
slicer = self.slicerStack[-1]
slicer.start(obj)
slicer.slice(obj)
slicer.finish(obj)
def slice(self, obj):
child = None
for i in range(len(self.slicerStack)-1, -1, -1):
child = self.slicerStack[i].newSlicer(obj)
if child:
break
if child == None:
raise "nothing to send for obj '%s' (type '%s')" % (obj, type(obj))
self.slice2(child, obj)
def slice2(self, child, obj):
child.protocol = self
self.slicerStack.append(child)
self.doSlice(obj)
self.slicerStack.pop(-1)
def setRefID(self, obj, refid):
for i in range(len(self.slicerStack)-1, -1, -1):
self.slicerStack[i].setRefID(obj, refid)
def getRefID(self, refid):
for i in range(len(self.slicerStack)-1, -1, -1):
obj = self.slicerStack[i].getRefID(refid)
if obj is not None:
return obj
return None
def sendOpen(self, opentype):
openID = self.openCount
self.openCount += 1
int2b128(openID, self.transport.write)
self.transport.write(OPEN)
self.sendToken(opentype)
return openID
def sendToken(self, obj):
write = self.transport.write
if isinstance(obj, types.IntType) or isinstance(obj, types.LongType):
if obj >= 2**31:
s = long_to_bytes(obj)
int2b128(len(s), write)
write(LONGINT)
write(s)
elif obj >= 0:
int2b128(obj, write)
write(INT)
elif -obj > 2**31: s = long_to_bytes(-obj)
int2b128(len(s), write)
write(LONGNEG)
write(s)
else:
int2b128(-obj, write)
write(NEG)
elif isinstance(obj, types.FloatType):
write(FLOAT)
write(struct.pack("!d", obj))
elif isinstance(obj, types.StringType):
if self.outgoingVocabulary.has_key(obj):
symbolID = self.outgoingVocabulary[obj]
int2b128(symbolID, write)
write(VOCAB)
else:
if len(obj) > SIZE_LIMIT:
raise BananaError, \
"string is too long to send (%d)" % len(obj)
int2b128(len(obj), write)
write(STRING)
write(obj)
else:
raise BananaError, "could not send object: %s" % repr(obj)
def sendClose(self, openID):
int2b128(openID, self.transport.write)
self.transport.write(CLOSE)
def sendAbort(self, count=0):
int2b128(count, self.transport.write)
self.transport.write(ABORT)
def initReceive(self):
self.rootUnslicer = self.unslicerClass()
self.rootUnslicer.protocol = self
self.receiveStack = [self.rootUnslicer]
self.objectCounter = 0
self.objects = {}
self.inOpen = False self.opentype = []
self.incomingVocabulary = {}
self.buffer = ''
self.skipBytes = 0 self.discardCount = 0 self.exploded = None
def printStack(self, verbose=0):
print "STACK:"
for s in self.receiveStack:
if verbose:
d = s.__dict__.copy()
del d['protocol']
print " %s: %s" % (s, d)
else:
print " %s" % s
def setObject(self, count, obj):
for i in range(len(self.receiveStack)-1, -1, -1):
self.receiveStack[i].setObject(count, obj)
def getObject(self, count):
for i in range(len(self.receiveStack)-1, -1, -1):
obj = self.receiveStack[i].getObject(count)
if obj is not None:
return obj
raise ValueError, "dangling reference '%d'" % count
def setIncomingVocabulary(self, vocabDict):
self.incomingVocabulary = vocabDict
def getLimit(self, typebyte):
top = self.receiveStack[-1]
if self.inOpen:
limit = top.openerCheckToken(typebyte, self.opentype)
else:
limit = top.checkToken(typebyte) if self.debug: print "getLimit(0x%x)=%s" % (ord(typebyte), limit)
return limit
def dataReceived(self, chunk):
if self.skipBytes:
if len(chunk) < self.skipBytes:
self.skipBytes -= len(chunk)
return
chunk = chunk[self.skipBytes:]
self.skipBytes = 0
buffer = self.buffer + chunk
gotItem = self.handleToken
while buffer:
assert self.buffer != buffer, "This ain't right: %s %s" % (repr(self.buffer), repr(buffer))
self.buffer = buffer
pos = 0
for ch in buffer:
if ch >= HIGH_BIT_SET:
break
pos = pos + 1
else:
if pos > 64:
raise BananaError("token prefix is limited to 64 bytes")
return
typebyte = buffer[pos]
sizelimit = SIZE_LIMIT
rejected = False
if self.discardCount:
rejected = True
self.inOpen = False
if not rejected:
if typebyte not in (ABORT, CLOSE):
try:
sizelimit = self.getLimit(typebyte)
except Violation:
where = self.describe()
e = BananaError("schema rejected %s token" % \
tokenNames[typebyte],
where + "<checkToken>")
rejected = True
gotItem(UnbananaFailure(self.describe(), e))
except BananaError, e:
where = self.describe()
e.where = where
raise e
except:
e = BananaError2(Failure(),
self.describe() + "<checkToken>")
raise e
header = buffer[:pos]
rest = buffer[pos+1:]
if len(header) > 64:
raise BananaError("token prefix is limited to 64 bytes")
if typebyte == LIST:
raise BananaError("oldbanana peer detected, " +
"compatibility code not yet written")
elif typebyte == STRING:
strlen = b1282int(header)
if not rejected and sizelimit != None and strlen > sizelimit:
if self.hangupOnLengthViolation:
raise BananaError("String too long.")
else:
rejected = True
e = BananaError("String too long.")
gotItem(UnbananaFailure(self.describe(), e))
if len(rest) >= strlen:
buffer = rest[strlen:]
obj = rest[:strlen]
else:
if rejected:
self.skipBytes = strlen - len(rest)
self.buffer = ""
return
elif typebyte == INT:
buffer = rest
header = b1282int(header)
obj = int(header)
elif typebyte == NEG:
buffer = rest
header = b1282int(header)
obj = -int(header)
elif typebyte == LONGINT or typebyte == LONGNEG:
strlen = b1282int(header)
if not rejected and sizelimit != None and strlen > sizelimit:
if self.hangupOnLengthViolation:
raise BananaError("Longint too long.")
else:
rejected = True
e = BananaError("Longint too long.")
gotItem(UnbananaFailure(self.describe(), e))
if len(rest) >= strlen:
buffer = rest[strlen:]
obj = bytes_to_long(rest[:strlen])
if typebyte == LONGNEG:
obj = -obj
else:
if rejected:
self.skipBytes = strlen - len(rest)
self.buffer = ""
return
elif typebyte == VOCAB:
buffer = rest
header = b1282int(header)
obj = self.incomingVocabulary[header]
elif typebyte == FLOAT:
if len(rest) >= 8:
buffer = rest[8:]
obj = struct.unpack("!d", rest[:8])[0]
else:
return
elif typebyte == OPEN:
buffer = rest
self.openCount = b1282int(header)
if rejected:
self.discardCount += 1
else:
if self.inOpen:
raise BananaError("OPEN token followed by OPEN")
self.inOpen = True
self.opentype = []
continue
elif typebyte == CLOSE:
buffer = rest
count = b1282int(header)
if self.discardCount:
self.discardCount -= 1
else:
self.handleClose(count)
continue
elif typebyte == ABORT:
buffer = rest
count = b1282int(header)
self.discardCount += 1
e = BananaError("ABORT received")
gotItem(UnbananaFailure(self.describe(), e))
continue
else:
raise BananaError(("Invalid Type Byte 0x%x" % ord(typebyte)))
if not rejected:
if self.inOpen:
self.handleOpen(self.openCount, obj)
else:
gotItem(obj)
else:
pass
self.buffer = ''
def handleOpen(self, openCount, indexToken):
self.opentype.append(indexToken)
opentype = tuple(self.opentype)
if self.debug:
print "handleOpen(%d,%s)" % (openCount, indexToken)
objectCount = self.objectCounter
top = self.receiveStack[-1]
try:
child = top.doOpen(opentype)
if not child:
if self.debug:
print " doOpen wants more index tokens"
return if self.debug:
print " opened[%d] with %s" % (openCount, child)
except Violation:
self.discardCount += 1
self.inOpen = False
where = self.describe() + ".<OPEN(%s)>" % (opentype,)
failure = UnbananaFailure(where, Failure())
top.receiveChild(failure)
return
except BananaError:
raise
except:
where = self.describe() + ".<OPEN(%s)>" % (opentype,)
raise BananaError2(Failure(), where)
self.objectCounter += 1
self.inOpen = False
child.protocol = self
child.openCount = openCount
self.receiveStack.append(child)
try:
child.start(objectCount)
except Violation:
where = self.describe() + ".<START>"
f = UnbananaFailure(where, Failure())
self.abandonUnslicer(f, child)
except BananaError:
raise
except:
where = self.describe() + ".<START>"
raise BananaError2(Failure(), where)
def handleToken(self, token):
top = self.receiveStack[-1]
if self.debug: print "handleToken(%s)" % token
try:
top.receiveChild(token)
except Violation:
f = UnbananaFailure(self.describe(), Failure())
self.abandonUnslicer(f, top)
except BananaError:
raise
except:
where = self.describe() + ".<receiveChild(%s)>" % (token,)
raise BananaError2(Failure(), where)
def handleClose(self, closeCount):
if self.debug:
print "handleClose(%d)" % closeCount
if self.receiveStack[-1].openCount != closeCount:
print "LOST SYNC"
self.printStack()
assert(0)
child = self.receiveStack[-1]
try:
obj = child.receiveClose()
except Violation:
where = self.describe() + ".<CLOSE>"
obj = UnbananaFailure(where, Failure())
except BananaError:
raise
except:
where = self.describe() + ".<CLOSE>"
raise BananaError2(Failure(), where)
if self.debug: print "receiveClose returned", obj
try:
child.finish()
except Violation:
where = self.describe() + ".<FINISH>"
obj = UnbananaFailure(where, Failure())
except BananaError:
raise
except:
where = self.describe() + ".<FINISH>"
raise BananaError2(Failure(), where)
self.receiveStack.pop()
parent = self.receiveStack[-1]
try:
if self.debug: print "receiveChild()"
if isinstance(obj, UnbananaFailure):
if self.debug: print "%s .childFinished for UF" % parent
self.startDiscarding(obj, parent)
parent.receiveChild(obj)
except Violation:
f = UnbananaFailure(self.describe(), Failure())
self.abandonUnslicer(f, parent)
except BananaError:
raise
except:
where = self.describe() + ".<receiveChild(%s)>" % (obj,)
raise BananaError2(Failure(), where)
def abandonUnslicer(self, failure, leaf=None):
"""The top-most Unslicer has decided to give up. We must discard all
tokens until the matching CLOSE is received. The UnbananaFailure
must be delivered to the late unslicer's parent.
leaf is a paranoia debug check, used to make sure abandonUnslicer is
called by the slicer that is currently in control.
"""
if self.debug:
print "## abandonUnslicer called"
if isinstance(failure, UnbananaFailure):
print "## while decoding '%s'" % failure.where
print "## current stack leading up to abandonUnslicer:"
import traceback
traceback.print_stack()
if not isinstance(failure, UnbananaFailure) and failure.failure:
print "## exception that triggered abandonUnslicer:"
print failure.failure.getBriefTraceback()
old = self.receiveStack.pop()
try:
old.finish() except Violation:
pass
assert leaf == old
if not self.receiveStack:
print "RootUnslicer broken! hang up or else"
raise RuntimeError, "RootUnslicer broken: hang up or else"
self.discardCount += 1 top = self.receiveStack[-1]
try:
top.receiveChild(failure)
except Violation:
self.abandonUnslicer(failure, top)
except BananaError:
raise
except:
where = self.describe() + \
".<abandonUnslicer-receiveChild(%s)>" % failure
raise BananaError2(Failure(), where)
def describe(self):
where = []
for i in self.receiveStack:
try:
piece = i.describeSelf()
except:
piece = "???"
where.append(piece)
return ".".join(where)
def receivedObject(self, obj):
"""Decoded objects are delivered here, unless you use a RootUnslicer
variant which does something else in its .childFinished method.
"""
raise NotImplementedError