"""Integration testing utilities."""
from __future__ import absolute_import, print_function, unicode_literals

import socket
import sys
from collections import defaultdict
from functools import partial
from itertools import count

from kombu.utils.functional import retry_over_time

from celery import states
from celery.exceptions import TimeoutError
from celery.five import items
from celery.result import ResultSet
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds as _humanize_seconds

# Warning template formatted by ManagerMixin.wait_for when ``emit_warning``
# is set: ``{0}`` is the description of the awaited event, ``when`` is the
# humanized delay before the next attempt, and ``exc`` is the exception that
# caused the retry.
E_STILL_WAITING = 'Still waiting for {0}.  Trying again {when}: {exc!r}'

# Variant of celery.utils.time.humanize_seconds with ``microseconds=True``
# pre-applied.
humanize_seconds = partial(_humanize_seconds, microseconds=True)

class Sentinel(Exception):
    """Signifies the end of something.

    Raised by ``ManagerMixin.true_or_raise`` when a polled predicate is
    still falsy, and passed as the ``catch`` exception to ``wait_for`` so
    that the check is retried.
    """
class ManagerMixin(object):
    """Mixin that adds :class:`Manager` capabilities.

    The host class must set ``self.app`` before calling
    :meth:`_init_manager`; several helpers reach the app through it.
    """

    def _init_manager(self,
                      block_timeout=30 * 60.0, no_join=False,
                      stdout=None, stderr=None):
        # type: (float, bool, TextIO, TextIO) -> None
        """Set up output streams, connection error types and join policy.

        Arguments:
            block_timeout (float): Stored on the instance as-is.
            no_join (bool): When true, :meth:`join` becomes a no-op.
            stdout (TextIO): Stream for remarks; defaults to ``sys.stdout``.
            stderr (TextIO): Error stream; defaults to ``sys.stderr``.
        """
        self.stdout = sys.stdout if stdout is None else stdout
        self.stderr = sys.stderr if stderr is None else stderr
        # Exception classes that join() reports as a lost connection
        # (and then retries) instead of letting them propagate.
        self.connerrors = self.app.connection().recoverable_connection_errors
        self.block_timeout = block_timeout
        self.no_join = no_join

    def remark(self, s, sep='-'):
        # type: (str, str) -> None
        """Print ``s`` prefixed with ``sep`` to the configured stdout."""
        print('{0}{1}'.format(sep, s), file=self.stdout)

    def missing_results(self, r):
        # type: (Sequence[AsyncResult]) -> Sequence[str]
        """Return the ids of results in ``r`` not in their backend cache."""
        return [res.id for res in r if res.id not in res.backend._cache]

    def wait_for(
        self,
        fun,  # type: Callable
        catch,  # type: Sequence[Any]
        desc="thing",  # type: str
        args=(),  # type: Tuple
        kwargs=None,  # type: Dict
        errback=None,  # type: Callable
        max_retries=10,  # type: int
        interval_start=0.1,  # type: float
        interval_step=0.5,  # type: float
        interval_max=5.0,  # type: float
        emit_warning=False,  # type: bool
        **options  # type: Any
    ):
        # type: (...) -> Any
        """Wait for event to happen.

        The `catch` argument specifies the exception that means the event
        has not happened yet.

        Arguments:
            fun (Callable): Called with ``args``/``kwargs`` on each attempt.
            catch (Sequence[Any]): Exception(s) that trigger a retry.
            desc (str): Description used in the warning message.
            errback (Callable): Optional; called as
                ``errback(exc, interval, retries)`` after each failure.
            max_retries (int): Retry budget passed to the retry helper.
            interval_start (float): Initial delay between attempts.
            interval_step (float): Delay increment per attempt.
            interval_max (float): Upper bound for the delay.
            emit_warning (bool): Emit a warning through ``self.warn`` on
                every failed attempt.
            **options: Extra keyword arguments for the retry helper.

        Returns:
            Any: Whatever :meth:`retry_over_time` returns.
        """
        kwargs = {} if not kwargs else kwargs

        def on_error(exc, intervals, retries):
            interval = next(intervals)
            if emit_warning:
                # NOTE(review): ``warn`` is not defined on this mixin; the
                # host class presumably provides it -- confirm.
                self.warn(E_STILL_WAITING.format(
                    desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
                ))
            if errback:
                errback(exc, interval, retries)
            return interval

        return self.retry_over_time(
            fun, catch,
            args=args, kwargs=kwargs,
            errback=on_error, max_retries=max_retries,
            interval_start=interval_start, interval_step=interval_step,
            # Previously accepted but never forwarded, so the retry helper
            # silently fell back to its own default maximum.
            interval_max=interval_max,
            **options
        )

    def ensure_not_for_a_while(self, fun, catch,
                               desc='thing', max_retries=20,
                               interval_start=0.1, interval_step=0.02,
                               interval_max=1.0, emit_warning=False,
                               **options):
        """Make sure something does not happen (at least for a while).

        Returns ``None`` when every attempt ended in ``catch``.

        Raises:
            AssertionError: If ``fun`` succeeded, i.e. the event happened.
        """
        try:
            # No ``return`` here: returning from the ``try`` suite would
            # skip the ``else`` clause and the assertion could never fire.
            self.wait_for(
                fun, catch, desc=desc, max_retries=max_retries,
                interval_start=interval_start, interval_step=interval_step,
                interval_max=interval_max, emit_warning=emit_warning,
                **options
            )
        except catch:
            # The event never happened within the retry budget: success.
            pass
        else:
            raise AssertionError('Should not have happened: {0}'.format(desc))

    def retry_over_time(self, *args, **kwargs):
        """Delegate to :func:`kombu.utils.functional.retry_over_time`."""
        return retry_over_time(*args, **kwargs)

    def join(self, r, propagate=False, max_retries=10, **kwargs):
        """Wait for result(s) ``r``, retrying on timeouts and lost connections.

        Returns ``None`` immediately when ``no_join`` is set, otherwise the
        value of ``r.get(...)``.  A falsy ``max_retries`` retries forever.

        Raises:
            AssertionError: If no attempt produced the results.
        """
        if self.no_join:
            return
        if not isinstance(r, ResultSet):
            r = self.app.ResultSet([r])
        received = []

        def on_result(task_id, value):
            received.append(task_id)

        attempts = range(max_retries) if max_retries else count(0)
        for _ in attempts:
            # Only count results received during the current attempt.
            received[:] = []
            try:
                return r.get(callback=on_result, propagate=propagate, **kwargs)
            except (socket.timeout, TimeoutError) as exc:
                waiting_for = self.missing_results(r)
                self.remark(
                    'Still waiting for {0}/{1}: [{2}]: {3!r}'.format(
                        len(r) - len(received), len(r),
                        truncate(', '.join(waiting_for)), exc), '!',
                )
            except self.connerrors as exc:
                self.remark('join: connection lost: {0!r}'.format(exc), '!')
        raise AssertionError('Test failed: Missing task results')

    def inspect(self, timeout=3.0):
        """Return an inspect object for the app using ``timeout``."""
        return self.app.control.inspect(timeout=timeout)

    def query_tasks(self, ids, timeout=0.5):
        """Yield ``(hostname, reply)`` pairs from a ``query_task`` inspect."""
        replies = self.inspect(timeout).query_task(*ids) or {}
        for reply in items(replies):
            yield reply

    def query_task_states(self, ids, timeout=0.5):
        """Return a mapping of state name to the set of task ids in it."""
        # Named ``by_state`` so the imported ``states`` module is not shadowed.
        by_state = defaultdict(set)
        for _hostname, reply in self.query_tasks(ids, timeout=timeout):
            for task_id, (state, _) in items(reply):
                by_state[state].add(task_id)
        return by_state

    def assert_accepted(self, ids, interval=0.5,
                        desc='waiting for tasks to be accepted', **policy):
        """Wait until every task in ``ids`` is accepted by a worker."""
        return self.assert_task_worker_state(
            self.is_accepted, ids, interval=interval, desc=desc, **policy
        )

    def assert_received(self, ids, interval=0.5,
                        desc='waiting for tasks to be received', **policy):
        """Wait until every task in ``ids`` is received by a worker."""
        # Uses is_received (not is_accepted) so that tasks that are only
        # 'reserved' also satisfy the check.
        return self.assert_task_worker_state(
            self.is_received, ids, interval=interval, desc=desc, **policy
        )

    def assert_result_tasks_in_progress_or_completed(
        self,
        async_results,
        interval=0.5,
        desc='waiting for tasks to be started or completed',
        **policy
    ):
        """Wait until every result is in the STARTED or SUCCESS state."""
        return self.assert_task_state_from_result(
            self.is_result_task_in_progress,
            async_results,
            interval=interval, desc=desc, **policy
        )

    def assert_task_state_from_result(self, fun, results,
                                      interval=0.5, **policy):
        """Poll ``fun(results, timeout=interval)`` until it is truthy."""
        return self.wait_for(
            partial(self.true_or_raise, fun, results, timeout=interval),
            (Sentinel,), **policy
        )

    @staticmethod
    def is_result_task_in_progress(results, **kwargs):
        """Return true if all results are STARTED or SUCCESS.

        Extra keyword arguments are accepted and ignored.
        """
        possible_states = (states.STARTED, states.SUCCESS)
        return all(result.state in possible_states for result in results)

    def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
        """Poll ``fun(ids, timeout=interval)`` until it is truthy."""
        return self.wait_for(
            partial(self.true_or_raise, fun, ids, timeout=interval),
            (Sentinel,), **policy
        )

    def is_received(self, ids, **kwargs):
        """Return true if every id is 'reserved', 'active' or 'ready'."""
        return self._ids_matches_state(
            ['reserved', 'active', 'ready'], ids, **kwargs)

    def is_accepted(self, ids, **kwargs):
        """Return true if every id is 'active' or 'ready'."""
        return self._ids_matches_state(['active', 'ready'], ids, **kwargs)

    def _ids_matches_state(self, expected_states, ids, timeout=0.5):
        """Return true if each id is in at least one expected state."""
        by_state = self.query_task_states(ids, timeout=timeout)
        return all(
            any(t in by_state[k] for k in expected_states)
            for t in ids
        )

    def true_or_raise(self, fun, *args, **kwargs):
        """Return ``fun(*args, **kwargs)``; raise Sentinel if it is falsy."""
        res = fun(*args, **kwargs)
        if not res:
            raise Sentinel()
        return res
class Manager(ManagerMixin):
    """Test helpers for task integration tests."""

    def __init__(self, app, **kwargs):
        """Bind the manager to ``app`` and initialize the mixin state.

        Arguments:
            app: The application under test (presumably a Celery app
                instance -- confirm against callers).
            **kwargs: Forwarded unchanged to ``_init_manager``.
        """
        # Assigned before _init_manager runs so the mixin can rely on the
        # attribute being present during initialization.
        self.app = app
        self._init_manager(**kwargs)