"""Test methods in twisted.internet.threads and reactor thread APIs."""
from twisted.trial import unittest
from twisted.internet import threads, reactor
from twisted.python import threadable, failure
import time
import atexit
class Counter:
    """Shared counter used by the thread tests to detect overlapping calls."""

    index = 0
    problem = 0

    def sync_add(self):
        """A thread-safe method."""
        self.add()

    def add(self):
        """A non thread-safe method.

        Increment ``index`` by one.  If ``index`` changes between being read
        and being re-checked, set ``problem`` and raise ValueError so that an
        overlap between threads becomes visible to the test.
        """
        newIndex = self.index + 1
        # Another thread could jump in here and increment self.index on us.
        if newIndex != self.index + 1:
            self.problem = 1
            raise ValueError
        # An overlap right here would go undetected: we would overwrite the
        # other thread's result and the counter would lose a count.
        self.index = newIndex

    # Names of the methods that threadable.synchronize makes thread-safe.
    synchronized = ["sync_add"]
# Make the methods named in Counter.synchronized (sync_add) thread-safe.
# add() is deliberately left unsynchronized.
threadable.synchronize(Counter)
class ThreadsTestCase(unittest.TestCase):
    """Test twisted.internet.threads."""

    def _waitForCount(self, counter, target, timeout=5):
        """Poll counter.index until it reaches target, failing on problems.

        The test fails if the index ever goes backwards, if the counter
        reports an overlap, or if the index makes no progress for
        *timeout* seconds.
        """
        when = time.time()
        oldIndex = 0
        while counter.index < target:
            # Take one snapshot per iteration so every check sees the same value.
            index = counter.index
            self.failIf(index < oldIndex, "counter went backwards")
            self.failIf(counter.problem, "threads reported overlap")
            now = time.time()
            if index > oldIndex:
                # Progress was made, so restart the stall timer.
                when = now
            elif now > when + timeout:
                if index > 0:
                    self.fail("threads lost a count, index is %d, "
                              "time is %s, when is %s" % (index, now, when))
                else:
                    self.fail("threads never started")
            oldIndex = index
            # Yield briefly so the worker threads get to run.
            time.sleep(0.001)
        # A lost count would keep us in the loop above; this states the goal.
        self.assertEquals(counter.index, target, "threads lost a count")

    def testCallInThread(self):
        c = Counter()
        for _ in range(1000):
            reactor.callInThread(c.sync_add)
        self._waitForCount(c, 1000)

    def testCallMultiple(self):
        c = Counter()
        # callMultipleInThread is documented to run the whole list in one
        # thread, so the unsynchronized add() is expected to be safe here.
        commands = [(c.add, (), {})] * 1000
        threads.callMultipleInThread(commands)
        self._waitForCount(c, 1000)

    def testSuggestThreadPoolSize(self):
        """Smoke test: resizing the reactor thread pool must not raise."""
        reactor.suggestThreadPoolSize(34)
        reactor.suggestThreadPoolSize(4)
class DeferredResultTestCase(unittest.TestCase):
    """Test threads.deferToThread"""

    def setUp(self):
        # done: set when a result or failure arrives, or when the timeout fires.
        # gotResult: set only when a result or failure actually arrived.
        self.done = 0
        self.gotResult = 0
        self.result = None
        self.error = None

    def _timeout(self):
        """Stop waiting without recording a result."""
        self.done = 1

    def _resultCallback(self, result):
        """Record a successful result for the test method to check."""
        # Asserting here would turn a wrong value into an unhandled Deferred
        # failure plus a misleading "timeout" message, so just record it.
        self.result = result
        self.gotResult = 1
        self.done = 1

    def _resultErrback(self, error):
        """Record a failure; returning None ends the Deferred's failure chain."""
        self.error = error
        self.gotResult = 1
        self.done = 1

    def _waitForResult(self, timeout=1):
        """Iterate the reactor until a result arrives or *timeout* seconds pass.

        Fails with "timeout" if nothing arrived.  The timeout DelayedCall is
        always cancelled so that it cannot leak into later tests.
        """
        t = reactor.callLater(timeout, self._timeout)
        try:
            while not self.done:
                reactor.iterate()
        finally:
            if t.active():
                t.cancel()
        self.failUnless(self.gotResult, "timeout")

    def _assertTypeErrorFailure(self):
        """Check that the recorded failure wraps a TypeError."""
        self.failUnless(isinstance(self.error, failure.Failure))
        self.assertEquals(self.error.type, TypeError)

    def testDeferredResult(self):
        d = threads.deferToThread(lambda x, y=5: x + y, 3, y=4)
        d.addCallback(self._resultCallback)
        self._waitForResult()
        self.assertEquals(self.result, 7)

    def testDeferredFailure(self):
        def raiseError():
            raise TypeError
        d = threads.deferToThread(raiseError)
        d.addErrback(self._resultErrback)
        self._waitForResult()
        self._assertTypeErrorFailure()

    def OFFtestDeferredFailure2(self):
        # Disabled: without the "test" prefix, trial does not collect this.
        # It runs the reactor once (crashing it after 1 second) with a thread
        # call queued, then checks that deferToThread still delivers failures.
        def nothing():
            pass
        reactor.callLater(1, reactor.crash)
        reactor.callInThread(nothing)
        reactor.run()

        def raiseError():
            raise TypeError
        d = threads.deferToThread(raiseError)
        d.addErrback(self._resultErrback)
        self._waitForResult()
        self._assertTypeErrorFailure()