1131 lines
38 KiB
Python
1131 lines
38 KiB
Python
from __future__ import absolute_import, print_function
|
|
|
|
import os
|
|
import socket
|
|
|
|
from collections import deque
|
|
from datetime import datetime, timedelta
|
|
from threading import Event
|
|
|
|
from amqp import ChannelError
|
|
from kombu import Connection
|
|
from kombu.common import QoS, ignore_errors
|
|
from kombu.transport.base import Message
|
|
|
|
from celery.app.defaults import DEFAULTS
|
|
from celery.bootsteps import RUN, CLOSE, StartStopStep
|
|
from celery.concurrency.base import BasePool
|
|
from celery.datastructures import AttributeDict
|
|
from celery.exceptions import (
|
|
WorkerShutdown, WorkerTerminate, TaskRevokedError,
|
|
)
|
|
from celery.five import Empty, range, Queue as FastQueue
|
|
from celery.utils import uuid
|
|
from celery.worker import components
|
|
from celery.worker import consumer
|
|
from celery.worker.consumer import Consumer as __Consumer
|
|
from celery.worker.job import Request
|
|
from celery.utils import worker_direct
|
|
from celery.utils.serialization import pickle
|
|
from celery.utils.timer2 import Timer
|
|
|
|
from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging
|
|
|
|
|
|
def MockStep(step=None):
    """Return *step* (or a brand-new ``Mock``) decorated to look like a bootstep.

    The object receives a mocked ``blueprint`` whose ``name`` is ``'MockNS'``
    and a ``name`` attribute embedding ``id()`` of the object, so each
    instance gets a distinct name.
    """
    if step is None:
        step = Mock()
    step.blueprint = Mock()
    step.blueprint.name = 'MockNS'
    step.name = 'MockStep(%s)' % id(step)
    return step
|
|
|
|
|
|
def mock_event_dispatcher():
    """Build a ``Mock`` event dispatcher belonging to the ``worker`` group.

    Its ``_outbound_buffer`` is a real, empty ``deque`` (fresh per call)
    rather than an auto-generated mock attribute.
    """
    dispatcher = Mock(name='event_dispatcher')
    dispatcher.groups = ['worker']
    dispatcher._outbound_buffer = deque()
    return dispatcher
|
|
|
|
|
|
class PlaceHolder(object):
    """Empty placeholder type with no attributes or behaviour of its own."""
|
|
|
|
|
|
def find_step(obj, typ):
    """Return the step registered on ``obj.blueprint`` for step class *typ*.

    Steps are looked up by ``typ.name``; a missing step raises ``KeyError``.
    """
    registry = obj.blueprint.steps
    return registry[typ.name]
|
|
|
|
|
|
class Consumer(__Consumer):
    """Worker consumer with the Mingle, Gossip and Heart steps off by default.

    Passing any of the ``without_*`` keywords explicitly overrides the
    default chosen here.
    """

    def __init__(self, *args, **kwargs):
        # without_mingle -> Mingle step, without_gossip -> Gossip step,
        # without_heartbeat -> Heart step.
        for option in ('without_mingle', 'without_gossip',
                       'without_heartbeat'):
            kwargs.setdefault(option, True)
        super(Consumer, self).__init__(*args, **kwargs)
|
|
|
|
|
|
class _MyKombuConsumer(Consumer):
    """Test consumer with class-level mock consumers and a default pool.

    Unless the caller supplies ``pool``, a ``BasePool(2)`` is used.
    ``restart_heartbeat`` is stubbed out to just drop the heart.
    """
    # NOTE: these mocks are class attributes, shared by all instances.
    broadcast_consumer = Mock()
    task_consumer = Mock()

    def __init__(self, *args, **kwargs):
        fallback_pool = BasePool(2)
        kwargs.setdefault('pool', fallback_pool)
        super(_MyKombuConsumer, self).__init__(*args, **kwargs)

    def restart_heartbeat(self):
        """Discard the heart instead of restarting it."""
        self.heart = None
|
|
|
|
|
|
class MyKombuConsumer(Consumer):
    """Consumer whose ``loop`` does nothing."""

    def loop(self, *args, **kwargs):
        """Accept and ignore any arguments; performs no work."""
        return None
|
|
|
|
|
|
class MockNode(object):
    """Minimal node stand-in that records the command names it receives.

    ``commands`` is created per instance so that recordings made by one
    node (and hence one test) are not visible through any other instance.
    """

    def __init__(self):
        self.commands = []

    def handle_message(self, body, message):
        """Append ``body['command']`` (or ``None``) to ``commands``.

        The ``'command'`` key is popped, so *body* is mutated.
        *message* is accepted for signature compatibility and ignored.
        """
        self.commands.append(body.pop('command', None))
|
|
|
|
|
|
class MockEventDispatcher(object):
    """Event-dispatcher double recording sent event names and close/flush.

    The mutable ``sent`` and ``_outbound_buffer`` lists are created per
    instance so that one dispatcher's recordings never show up in another.
    """
    closed = False
    flushed = False

    def __init__(self):
        self.sent = []
        self._outbound_buffer = []

    def send(self, event, *args, **kwargs):
        """Record the event name; all other arguments are ignored."""
        self.sent.append(event)

    def close(self):
        self.closed = True

    def flush(self):
        self.flushed = True
|
|
|
|
|
|
class MockHeart(object):
    """Heart stand-in whose ``stop`` merely records that it was called."""

    closed = False

    def stop(self):
        """Mark this heart as closed (sets ``closed`` on the instance)."""
        self.closed = True
|
|
|
|
|
|
def create_message(channel, **data):
    """Build a pickle-serialized kombu ``Message`` whose body is *data*.

    A random ``id`` is added to *data* when the caller did not pass one.
    As a side effect, ``channel.no_ack_consumers`` is reset to an empty set.
    The message's ``accept`` list is set to the pickle content type,
    presumably so ``decode()`` will accept it — confirm against kombu.
    """
    data.setdefault('id', uuid())
    channel.no_ack_consumers = set()
    payload = pickle.dumps(dict(data))
    message = Message(
        channel,
        body=payload,
        content_type='application/x-python-serialize',
        content_encoding='binary',
        delivery_info={'consumer_tag': 'mock'},
    )
    message.accept = ['application/x-python-serialize']
    return message
|
|
|
|
|
|
class test_Consumer(AppCase):
    """Tests for the worker ``Consumer`` and its bootsteps.

    Each test gets a fresh ``FastQueue`` buffer (the consumer's
    ``on_task`` target), a ``Timer`` and a non-shared ``foo_task``.
    """

    def setup(self):
        self.buffer = FastQueue()
        self.timer = Timer()

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        self.timer.stop()

    # stats() must report the consumer's prefetch count and broker info.
    def test_info(self):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)
        l.connection = Mock()
        l.connection.info.return_value = {'foo': 'bar'}
        l.controller = l.app.WorkController()
        l.controller.pool = Mock()
        l.controller.pool.info.return_value = [Mock(), Mock()]
        l.controller.consumer = l
        info = l.controller.stats()
        self.assertEqual(info['prefetch_count'], 10)
        self.assertTrue(info['broker'])

    # start() on a CLOSE'd blueprint must not raise.
    def test_start_when_closed(self):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = CLOSE
        l.start()

    # Connection lifecycle across start / restart / shutdown / stop.
    def test_connection(self):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)

        l.blueprint.start(l)
        self.assertIsInstance(l.connection, Connection)

        l.blueprint.state = RUN
        l.event_dispatcher = None
        l.blueprint.restart(l)
        self.assertTrue(l.connection)

        l.blueprint.state = RUN
        l.shutdown()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.blueprint.start(l)
        self.assertIsInstance(l.connection, Connection)
        l.blueprint.restart(l)

        l.stop()
        l.shutdown()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    # Shutting down Connection/Events/Heart steps closes their resources.
    def test_close_connection(self):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        step = find_step(l, consumer.Connection)
        conn = l.connection = Mock()
        step.shutdown(l)
        self.assertTrue(conn.close.called)
        self.assertIsNone(l.connection)

        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        eventer = l.event_dispatcher = mock_event_dispatcher()
        eventer.enabled = True
        heart = l.heart = MockHeart()
        l.blueprint.state = RUN
        Events = find_step(l, consumer.Events)
        Events.shutdown(l)
        Heart = find_step(l, consumer.Heart)
        Heart.shutdown(l)
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    # A message without a known task triggers a warning.
    @patch('celery.worker.consumer.warn')
    def test_receive_message_unknown(self, warn):
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        backend = Mock()
        m = create_message(backend, unknown={'baz': '!!!'})
        l.event_dispatcher = mock_event_dispatcher()
        l.node = MockNode()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertTrue(warn.call_count)

    # An eta that overflows timestamp conversion still acks the message.
    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        to_timestamp.side_effect = OverflowError()
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        m = create_message(Mock(), task=self.foo_task.name,
                           args=(2, 2),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = mock_event_dispatcher()
        l.node = MockNode()
        l.update_strategies()
        l.qos = Mock()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertTrue(m.acknowledged)

    # Non-mapping kwargs are reported as an invalid task message.
    @patch('celery.worker.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.event_dispatcher = mock_event_dispatcher()
        l.steps.pop()
        m = create_message(Mock(), task=self.foo_task.name,
                           args=(1, 2), kwargs='foobarbaz', id=1)
        l.update_strategies()
        l.event_dispatcher = mock_event_dispatcher()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertIn('Received invalid task message', error.call_args[0][0])

    # Undecodable messages are acked and logged as critical.
    @patch('celery.worker.consumer.crit')
    def test_on_decode_error(self, crit):
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        message = MockMessage()
        l.on_decode_error(message, KeyError('foo'))
        self.assertTrue(message.ack.call_count)
        self.assertIn("Can't decode message body", crit.call_args[0][0])

    def _get_on_message(self, l):
        """Run one aborted loop iteration and return the registered callback.

        ``drain_events`` raises ``WorkerShutdown`` so ``l.loop`` exits right
        after registering its message callback on ``l.task_consumer``.
        """
        if l.qos is None:
            l.qos = Mock()
        l.event_dispatcher = mock_event_dispatcher()
        l.task_consumer = Mock()
        l.connection = Mock()
        l.connection.drain_events.side_effect = WorkerShutdown()

        with self.assertRaises(WorkerShutdown):
            l.loop(*l.loop_args())
        self.assertTrue(l.task_consumer.register_callback.called)
        return l.task_consumer.register_callback.call_args[0][0]

    # A valid task message ends up as a Request in the buffer.
    def test_receieve_message(self):
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.event_dispatcher = mock_event_dispatcher()
        m = create_message(Mock(), task=self.foo_task.name,
                           args=[2, 4, 8], kwargs={})
        l.update_strategies()
        callback = self._get_on_message(l)
        callback(m.decode(), m)

        in_bucket = self.buffer.get_nowait()
        self.assertIsInstance(in_bucket, Request)
        self.assertEqual(in_bucket.name, self.foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.timer.empty())

    # A channel error raised from loop propagates out of start().
    def test_start_channel_error(self):

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        l = MockConsumer(self.buffer.put, timer=self.timer,
                         send_events=False, pool=BasePool(), app=self.app)
        l.channel_errors = (KeyError, )
        with self.assertRaises(KeyError):
            l.start()
        l.timer.stop()

    # A connection error is retried; the next (non-connection) error escapes.
    def test_start_connection_error(self):

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        l = MockConsumer(self.buffer.put, timer=self.timer,
                         send_events=False, pool=BasePool(), app=self.app)

        l.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, l.start)
        l.timer.stop()

    # socket.timeout from drain_events must not escape the loop.
    def test_loop_ignores_socket_timeout(self):

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.connection = Connection()
        l.task_consumer = Mock()
        l.connection.obj = l
        l.qos = QoS(l.task_consumer.qos, 10)
        l.loop(*l.loop_args())

    # socket.error propagates while RUN, but not once the blueprint is CLOSE.
    def test_loop_when_socket_error(self):

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        c = l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)
        with self.assertRaises(socket.error):
            l.loop(*l.loop_args())

        l.blueprint.state = CLOSE
        l.connection = c
        l.loop(*l.loop_args())

    # The loop starts consuming and applies QoS prefetch updates.
    def test_loop(self):

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)

        l.loop(*l.loop_args())
        l.loop(*l.loop_args())
        self.assertTrue(l.task_consumer.consume.call_count)
        l.task_consumer.qos.assert_called_with(prefetch_count=10)
        self.assertEqual(l.qos.value, 10)
        l.qos.decrement_eventually()
        self.assertEqual(l.qos.value, 9)
        l.qos.update()
        self.assertEqual(l.qos.value, 9)
        l.task_consumer.qos.assert_called_with(prefetch_count=9)

    # ignore_errors swallows connection/channel errors only.
    def test_ignore_errors(self):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.connection_errors = (AttributeError, KeyError, )
        l.channel_errors = (SyntaxError, )
        ignore_errors(l, Mock(side_effect=AttributeError('foo')))
        ignore_errors(l, Mock(side_effect=KeyError('foo')))
        ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
        with self.assertRaises(IndexError):
            ignore_errors(l, Mock(side_effect=IndexError('foo')))

    # apply_eta_task reserves the request, lowers QoS and buffers it.
    def test_apply_eta_task(self):
        from celery.worker import state
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.qos = QoS(None, 10)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        try:
            self.assertIn(task, state.reserved_requests)
            self.assertEqual(l.qos.value, qos - 1)
            self.assertIs(self.buffer.get_nowait(), task)
        finally:
            # state.reserved_requests is module-global: remove our dummy
            # entry so it cannot leak into later tests.
            state.reserved_requests.discard(task)

    # An eta message is scheduled on the timer (currently skipped).
    def test_receieve_message_eta_isoformat(self):
        raise SkipTest('broken test, may fail at random')
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        m = create_message(
            Mock(), task=self.foo_task.name,
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
            args=[2, 4, 8], kwargs={},
        )

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 1)
        current_pcount = l.qos.value
        l.event_dispatcher = mock_event_dispatcher()
        l.enabled = False
        l.update_strategies()
        callback = self._get_on_message(l)
        callback(m.decode(), m)
        l.timer.stop()
        l.timer.join(1)

        items = [entry[2] for entry in self.timer.queue]
        found = any(item.args[0].name == self.foo_task.name
                    for item in items)
        self.assertTrue(found)
        self.assertGreater(l.qos.value, current_pcount)
        l.timer.stop()

    # The pidbox callback forwards messages; ValueError triggers a reset.
    def test_pidbox_callback(self):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        con = find_step(l, consumer.Control).box
        con.node = Mock()
        con.reset = Mock()

        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = KeyError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = ValueError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')
        self.assertTrue(con.reset.called)

    # A revoked task id is never put into the buffer.
    def test_revoke(self):
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        backend = Mock()
        task_id = uuid()
        t = create_message(backend, task=self.foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        callback = self._get_on_message(l)
        callback(t.decode(), t)
        self.assertTrue(self.buffer.empty())

    # Unregistered task names are not buffered or scheduled.
    def test_receieve_message_not_registered(self):
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        backend = Mock()
        m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})

        l.event_dispatcher = mock_event_dispatcher()
        callback = self._get_on_message(l)
        self.assertFalse(callback(m.decode(), m))
        with self.assertRaises(Empty):
            self.buffer.get_nowait()
        self.assertTrue(self.timer.empty())

    # A connection error raised by reject() is logged, not propagated.
    @patch('celery.worker.consumer.warn')
    @patch('celery.worker.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        backend = Mock()
        m = create_message(backend, args=[2, 4, 8], kwargs={})

        l.event_dispatcher = mock_event_dispatcher()
        l.connection_errors = (socket.error, )
        m.reject = Mock()
        m.reject.side_effect = socket.error('foo')
        callback = self._get_on_message(l)
        self.assertFalse(callback(m.decode(), m))
        self.assertTrue(warn.call_count)
        with self.assertRaises(Empty):
            self.buffer.get_nowait()
        self.assertTrue(self.timer.empty())
        m.reject.assert_called_with(requeue=False)
        self.assertTrue(logger.critical.call_count)

    # Full start/restart path for an eta message (currently skipped).
    def test_receive_message_eta(self):
        raise SkipTest('broken test, may fail at random')
        import sys
        from functools import partial
        if os.environ.get('C_DEBUG_TEST'):
            pp = partial(print, file=sys.__stderr__)
        else:
            def pp(*args, **kwargs):
                pass
        pp('TEST RECEIVE MESSAGE ETA')
        pp('+CREATE MYKOMBUCONSUMER')
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        pp('-CREATE MYKOMBUCONSUMER')
        l.steps.pop()
        l.event_dispatcher = mock_event_dispatcher()
        backend = Mock()
        pp('+ CREATE MESSAGE')
        m = create_message(
            backend, task=self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )
        pp('- CREATE MESSAGE')

        try:
            pp('+ BLUEPRINT START 1')
            l.blueprint.start(l)
            pp('- BLUEPRINT START 1')
            p = l.app.conf.BROKER_CONNECTION_RETRY
            l.app.conf.BROKER_CONNECTION_RETRY = False
            pp('+ BLUEPRINT START 2')
            l.blueprint.start(l)
            pp('- BLUEPRINT START 2')
            l.app.conf.BROKER_CONNECTION_RETRY = p
            pp('+ BLUEPRINT RESTART')
            l.blueprint.restart(l)
            pp('- BLUEPRINT RESTART')
            l.event_dispatcher = mock_event_dispatcher()
            pp('+ GET ON MESSAGE')
            callback = self._get_on_message(l)
            pp('- GET ON MESSAGE')
            pp('+ CALLBACK')
            callback(m.decode(), m)
            pp('- CALLBACK')
        finally:
            pp('+ STOP TIMER')
            l.timer.stop()
            pp('- STOP TIMER')
            try:
                pp('+ JOIN TIMER')
                l.timer.join()
                pp('- JOIN TIMER')
            except RuntimeError:
                pass

        in_hold = l.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, self.foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        with self.assertRaises(Empty):
            self.buffer.get_nowait()

    # reset() closes the node channel even if close raises a conn error.
    def test_reset_pidbox_node(self):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        con = find_step(l, consumer.Control).box
        con.node = Mock()
        chan = con.node.channel = Mock()
        l.connection = Mock()
        chan.close.side_effect = socket.error('foo')
        l.connection_errors = (socket.error, )
        con.reset()
        chan.close.assert_called_with()

    # Green pools get a gPidbox whose loop is spawned on the pool.
    def test_reset_pidbox_node_green(self):
        from celery.worker.pidbox import gPidbox
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
                            app=self.app)
        con = find_step(l, consumer.Control)
        self.assertIsInstance(con.box, gPidbox)
        con.start(l)
        l.pool.spawn_n.assert_called_with(
            con.box.loop, l,
        )

    # The green pidbox loop consumes, then closes its private connection.
    def test__green_pidbox_node(self):
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
                            app=self.app)
        l.node = Mock()
        controller = find_step(l, consumer.Control)

        class BConsumer(Mock):

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call times out; the second stops the node loop.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        l.connect = lambda: Connection(obj=l)
        controller = find_step(l, consumer.Control)
        controller.box.loop(l)

        self.assertTrue(controller.box.node.listen.called)
        self.assertTrue(controller.box.consumer)
        controller.box.consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    # connect() retries after a transport connection error.
    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.sleep')
    def test_connect_errback(self, sleep, connect):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        from kombu.transport.memory import Transport

        def effect():
            if connect.call_count > 1:
                return
            raise ChannelError('error')
        connect.side_effect = effect

        # Patch the class attribute only for the duration of the connect so
        # the override does not leak into every later test in the process.
        with patch.object(Transport, 'connection_errors', (ChannelError, )):
            l.connect()
        connect.assert_called_with()

    # Stopping the Control step with an already-stopped node must not block.
    def test_stop_pidbox_node(self):
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        cont = find_step(l, consumer.Control)
        cont._node_stopped = Event()
        cont._node_shutdown = Event()
        cont._node_stopped.set()
        cont.stop(l)

    # start() re-enters loop and syncs QoS until a non-retried error escapes.
    def test_start__loop(self):

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError('foo')

        init_callback = Mock()
        l = _Consumer(self.buffer.put, timer=self.timer,
                      init_callback=init_callback, app=self.app)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = Connection()
        l.iterations = 0

        def raises_KeyError(*args, **kwargs):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError('foo')

        l.loop = raises_KeyError
        with self.assertRaises(KeyError):
            l.start()
        self.assertEqual(l.iterations, 2)
        self.assertEqual(l.qos.prev, l.qos.value)

        init_callback.reset_mock()
        l = _Consumer(self.buffer.put, timer=self.timer, app=self.app,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = Connection()
        l.loop = Mock(side_effect=socket.error('foo'))
        with self.assertRaises(socket.error):
            l.start()
        self.assertTrue(l.loop.call_count)

    # Starting the blueprint works when no pool is configured.
    def test_reset_connection_with_no_node(self):
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.steps.pop()
        self.assertEqual(None, l.pool)
        l.blueprint.start(l)
|
|
|
|
|
|
class test_WorkController(AppCase):
|
|
|
|
def setup(self):
|
|
self.worker = self.create_worker()
|
|
from celery import worker
|
|
self._logger = worker.logger
|
|
self._comp_logger = components.logger
|
|
self.logger = worker.logger = Mock()
|
|
self.comp_logger = components.logger = Mock()
|
|
|
|
@self.app.task(shared=False)
|
|
def foo_task(x, y, z):
|
|
return x * y * z
|
|
self.foo_task = foo_task
|
|
|
|
def teardown(self):
|
|
from celery import worker
|
|
worker.logger = self._logger
|
|
components.logger = self._comp_logger
|
|
|
|
def create_worker(self, **kw):
|
|
worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
|
|
worker.blueprint.shutdown_complete.set()
|
|
return worker
|
|
|
|
def test_on_consumer_ready(self):
|
|
self.worker.on_consumer_ready(Mock())
|
|
|
|
def test_setup_queues_worker_direct(self):
|
|
self.app.conf.CELERY_WORKER_DIRECT = True
|
|
self.app.amqp.__dict__['queues'] = Mock()
|
|
self.worker.setup_queues({})
|
|
self.app.amqp.queues.select_add.assert_called_with(
|
|
worker_direct(self.worker.hostname),
|
|
)
|
|
|
|
def test_send_worker_shutdown(self):
|
|
with patch('celery.signals.worker_shutdown') as ws:
|
|
self.worker._send_worker_shutdown()
|
|
ws.send.assert_called_with(sender=self.worker)
|
|
|
|
def test_process_shutdown_on_worker_shutdown(self):
|
|
raise SkipTest('unstable test')
|
|
from celery.concurrency.prefork import process_destructor
|
|
from celery.concurrency.asynpool import Worker
|
|
with patch('celery.signals.worker_process_shutdown') as ws:
|
|
Worker._make_shortcuts = Mock()
|
|
with patch('os._exit') as _exit:
|
|
worker = Worker(None, None, on_exit=process_destructor)
|
|
worker._do_exit(22, 3.1415926)
|
|
ws.send.assert_called_with(
|
|
sender=None, pid=22, exitcode=3.1415926,
|
|
)
|
|
_exit.assert_called_with(3.1415926)
|
|
|
|
def test_process_task_revoked_release_semaphore(self):
|
|
self.worker._quick_release = Mock()
|
|
req = Mock()
|
|
req.execute_using_pool.side_effect = TaskRevokedError
|
|
self.worker._process_task(req)
|
|
self.worker._quick_release.assert_called_with()
|
|
|
|
delattr(self.worker, '_quick_release')
|
|
self.worker._process_task(req)
|
|
|
|
def test_shutdown_no_blueprint(self):
|
|
self.worker.blueprint = None
|
|
self.worker._shutdown()
|
|
|
|
@patch('celery.platforms.create_pidlock')
|
|
def test_use_pidfile(self, create_pidlock):
|
|
create_pidlock.return_value = Mock()
|
|
worker = self.create_worker(pidfile='pidfilelockfilepid')
|
|
worker.steps = []
|
|
worker.start()
|
|
self.assertTrue(create_pidlock.called)
|
|
worker.stop()
|
|
self.assertTrue(worker.pidlock.release.called)
|
|
|
|
@patch('celery.platforms.signals')
|
|
@patch('celery.platforms.set_mp_process_title')
|
|
def test_process_initializer(self, set_mp_process_title, _signals):
|
|
with restore_logging():
|
|
from celery import signals
|
|
from celery._state import _tls
|
|
from celery.concurrency.prefork import (
|
|
process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
|
|
)
|
|
|
|
def on_worker_process_init(**kwargs):
|
|
on_worker_process_init.called = True
|
|
on_worker_process_init.called = False
|
|
signals.worker_process_init.connect(on_worker_process_init)
|
|
|
|
def Loader(*args, **kwargs):
|
|
loader = Mock(*args, **kwargs)
|
|
loader.conf = {}
|
|
loader.override_backends = {}
|
|
return loader
|
|
|
|
with self.Celery(loader=Loader) as app:
|
|
app.conf = AttributeDict(DEFAULTS)
|
|
process_initializer(app, 'awesome.worker.com')
|
|
_signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
|
|
_signals.reset.assert_any_call(*WORKER_SIGRESET)
|
|
self.assertTrue(app.loader.init_worker.call_count)
|
|
self.assertTrue(on_worker_process_init.called)
|
|
self.assertIs(_tls.current_app, app)
|
|
set_mp_process_title.assert_called_with(
|
|
'celeryd', hostname='awesome.worker.com',
|
|
)
|
|
|
|
with patch('celery.app.trace.setup_worker_optimizations') as S:
|
|
os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
|
|
try:
|
|
process_initializer(app, 'luke.worker.com')
|
|
S.assert_called_with(app)
|
|
finally:
|
|
os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
|
|
|
|
def test_attrs(self):
|
|
worker = self.worker
|
|
self.assertIsNotNone(worker.timer)
|
|
self.assertIsInstance(worker.timer, Timer)
|
|
self.assertIsNotNone(worker.pool)
|
|
self.assertIsNotNone(worker.consumer)
|
|
self.assertTrue(worker.steps)
|
|
|
|
def test_with_embedded_beat(self):
|
|
worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
|
|
self.assertTrue(worker.beat)
|
|
self.assertIn(worker.beat, [w.obj for w in worker.steps])
|
|
|
|
def test_with_autoscaler(self):
|
|
worker = self.create_worker(
|
|
autoscale=[10, 3], send_events=False,
|
|
timer_cls='celery.utils.timer2.Timer',
|
|
)
|
|
self.assertTrue(worker.autoscaler)
|
|
|
|
def test_dont_stop_or_terminate(self):
|
|
worker = self.app.WorkController(concurrency=1, loglevel=0)
|
|
worker.stop()
|
|
self.assertNotEqual(worker.blueprint.state, CLOSE)
|
|
worker.terminate()
|
|
self.assertNotEqual(worker.blueprint.state, CLOSE)
|
|
|
|
sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
|
|
try:
|
|
worker.blueprint.state = RUN
|
|
worker.stop(in_sighandler=True)
|
|
self.assertNotEqual(worker.blueprint.state, CLOSE)
|
|
worker.terminate(in_sighandler=True)
|
|
self.assertNotEqual(worker.blueprint.state, CLOSE)
|
|
finally:
|
|
worker.pool.signal_safe = sigsafe
|
|
|
|
def test_on_timer_error(self):
|
|
worker = self.app.WorkController(concurrency=1, loglevel=0)
|
|
|
|
try:
|
|
raise KeyError('foo')
|
|
except KeyError as exc:
|
|
components.Timer(worker).on_timer_error(exc)
|
|
msg, args = self.comp_logger.error.call_args[0]
|
|
self.assertIn('KeyError', msg % args)
|
|
|
|
def test_on_timer_tick(self):
|
|
worker = self.app.WorkController(concurrency=1, loglevel=10)
|
|
|
|
components.Timer(worker).on_timer_tick(30.0)
|
|
xargs = self.comp_logger.debug.call_args[0]
|
|
fmt, arg = xargs[0], xargs[1]
|
|
self.assertEqual(30.0, arg)
|
|
self.assertIn('Next eta %s secs', fmt)
|
|
|
|
def test_process_task(self):
|
|
worker = self.worker
|
|
worker.pool = Mock()
|
|
backend = Mock()
|
|
m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
|
|
kwargs={})
|
|
task = Request(m.decode(), message=m, app=self.app)
|
|
worker._process_task(task)
|
|
self.assertEqual(worker.pool.apply_async.call_count, 1)
|
|
worker.pool.stop()
|
|
|
|
def test_process_task_raise_base(self):
|
|
worker = self.worker
|
|
worker.pool = Mock()
|
|
worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
|
|
backend = Mock()
|
|
m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
|
|
kwargs={})
|
|
task = Request(m.decode(), message=m, app=self.app)
|
|
worker.steps = []
|
|
worker.blueprint.state = RUN
|
|
with self.assertRaises(KeyboardInterrupt):
|
|
worker._process_task(task)
|
|
|
|
def test_process_task_raise_WorkerTerminate(self):
|
|
worker = self.worker
|
|
worker.pool = Mock()
|
|
worker.pool.apply_async.side_effect = WorkerTerminate()
|
|
backend = Mock()
|
|
m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
|
|
kwargs={})
|
|
task = Request(m.decode(), message=m, app=self.app)
|
|
worker.steps = []
|
|
worker.blueprint.state = RUN
|
|
with self.assertRaises(SystemExit):
|
|
worker._process_task(task)
|
|
|
|
def test_process_task_raise_regular(self):
|
|
worker = self.worker
|
|
worker.pool = Mock()
|
|
worker.pool.apply_async.side_effect = KeyError('some exception')
|
|
backend = Mock()
|
|
m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10],
|
|
kwargs={})
|
|
task = Request(m.decode(), message=m, app=self.app)
|
|
worker._process_task(task)
|
|
worker.pool.stop()
|
|
|
|
def test_start_catches_base_exceptions(self):
|
|
worker1 = self.create_worker()
|
|
worker1.blueprint.state = RUN
|
|
stc = MockStep()
|
|
stc.start.side_effect = WorkerTerminate()
|
|
worker1.steps = [stc]
|
|
worker1.start()
|
|
stc.start.assert_called_with(worker1)
|
|
self.assertTrue(stc.terminate.call_count)
|
|
|
|
worker2 = self.create_worker()
|
|
worker2.blueprint.state = RUN
|
|
sec = MockStep()
|
|
sec.start.side_effect = WorkerShutdown()
|
|
sec.terminate = None
|
|
worker2.steps = [sec]
|
|
worker2.start()
|
|
self.assertTrue(sec.stop.call_count)
|
|
|
|
def test_state_db(self):
|
|
from celery.worker import state
|
|
Persistent = state.Persistent
|
|
|
|
state.Persistent = Mock()
|
|
try:
|
|
worker = self.create_worker(state_db='statefilename')
|
|
self.assertTrue(worker._persistence)
|
|
finally:
|
|
state.Persistent = Persistent
|
|
|
|
def test_process_task_sem(self):
|
|
worker = self.worker
|
|
worker._quick_acquire = Mock()
|
|
|
|
req = Mock()
|
|
worker._process_task_sem(req)
|
|
worker._quick_acquire.assert_called_with(worker._process_task, req)
|
|
|
|
def test_signal_consumer_close(self):
|
|
worker = self.worker
|
|
worker.consumer = Mock()
|
|
|
|
worker.signal_consumer_close()
|
|
worker.consumer.close.assert_called_with()
|
|
|
|
worker.consumer.close.side_effect = AttributeError()
|
|
worker.signal_consumer_close()
|
|
|
|
def test_start__stop(self):
|
|
worker = self.worker
|
|
worker.blueprint.shutdown_complete.set()
|
|
worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
|
|
worker.blueprint.state = RUN
|
|
worker.blueprint.started = 4
|
|
for w in worker.steps:
|
|
w.start = Mock()
|
|
w.close = Mock()
|
|
w.stop = Mock()
|
|
|
|
worker.start()
|
|
for w in worker.steps:
|
|
self.assertTrue(w.start.call_count)
|
|
worker.consumer = Mock()
|
|
worker.stop()
|
|
for stopstep in worker.steps:
|
|
self.assertTrue(stopstep.close.call_count)
|
|
self.assertTrue(stopstep.stop.call_count)
|
|
|
|
# Doesn't close pool if no pool.
|
|
worker.start()
|
|
worker.pool = None
|
|
worker.stop()
|
|
|
|
# test that stop of None is not attempted
|
|
worker.steps[-1] = None
|
|
worker.start()
|
|
worker.stop()
|
|
|
|
def test_step_raises(self):
|
|
worker = self.worker
|
|
step = Mock()
|
|
worker.steps = [step]
|
|
step.start.side_effect = TypeError()
|
|
worker.stop = Mock()
|
|
worker.start()
|
|
worker.stop.assert_called_with()
|
|
|
|
def test_state(self):
|
|
self.assertTrue(self.worker.state)
|
|
|
|
def test_start__terminate(self):
|
|
worker = self.worker
|
|
worker.blueprint.shutdown_complete.set()
|
|
worker.blueprint.started = 5
|
|
worker.blueprint.state = RUN
|
|
worker.steps = [MockStep() for _ in range(5)]
|
|
worker.start()
|
|
for w in worker.steps[:3]:
|
|
self.assertTrue(w.start.call_count)
|
|
self.assertTrue(worker.blueprint.started, len(worker.steps))
|
|
self.assertEqual(worker.blueprint.state, RUN)
|
|
worker.terminate()
|
|
for step in worker.steps:
|
|
self.assertTrue(step.terminate.call_count)
|
|
|
|
def test_Queues_pool_no_sem(self):
|
|
w = Mock()
|
|
w.pool_cls.uses_semaphore = False
|
|
components.Queues(w).create(w)
|
|
self.assertIs(w.process_task, w._process_task)
|
|
|
|
def test_Hub_crate(self):
|
|
w = Mock()
|
|
x = components.Hub(w)
|
|
x.create(w)
|
|
self.assertTrue(w.timer.max_interval)
|
|
|
|
def test_Pool_crate_threaded(self):
|
|
w = Mock()
|
|
w._conninfo.connection_errors = w._conninfo.channel_errors = ()
|
|
w.pool_cls = Mock()
|
|
w.use_eventloop = False
|
|
pool = components.Pool(w)
|
|
pool.create(w)
|
|
|
|
def test_Pool_create(self):
|
|
from kombu.async.semaphore import LaxBoundedSemaphore
|
|
w = Mock()
|
|
w._conninfo.connection_errors = w._conninfo.channel_errors = ()
|
|
w.hub = Mock()
|
|
|
|
PoolImp = Mock()
|
|
poolimp = PoolImp.return_value = Mock()
|
|
poolimp._pool = [Mock(), Mock()]
|
|
poolimp._cache = {}
|
|
poolimp._fileno_to_inq = {}
|
|
poolimp._fileno_to_outq = {}
|
|
|
|
from celery.concurrency.prefork import TaskPool as _TaskPool
|
|
|
|
class MockTaskPool(_TaskPool):
|
|
Pool = PoolImp
|
|
|
|
@property
|
|
def timers(self):
|
|
return {Mock(): 30}
|
|
|
|
w.pool_cls = MockTaskPool
|
|
w.use_eventloop = True
|
|
w.consumer.restart_count = -1
|
|
pool = components.Pool(w)
|
|
pool.create(w)
|
|
pool.register_with_event_loop(w, w.hub)
|
|
self.assertIsInstance(w.semaphore, LaxBoundedSemaphore)
|
|
P = w.pool
|
|
P.start()
|