202 lines
6.4 KiB
Python
202 lines
6.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
celery.backends.database
|
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
SQLAlchemy result store backend.
|
|
|
|
"""
|
|
from __future__ import absolute_import
|
|
|
|
import logging
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
|
|
from celery import states
|
|
from celery.backends.base import BaseBackend
|
|
from celery.exceptions import ImproperlyConfigured
|
|
from celery.five import range
|
|
from celery.utils.timeutils import maybe_timedelta
|
|
|
|
from .models import Task
|
|
from .models import TaskSet
|
|
from .session import SessionManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
__all__ = ['DatabaseBackend']
|
|
|
|
|
|
def _sqlalchemy_installed():
    """Return the ``sqlalchemy`` module, or fail with a configuration error.

    :returns: the imported ``sqlalchemy`` module.
    :raises ImproperlyConfigured: if ``sqlalchemy`` cannot be imported.

    """
    try:
        import sqlalchemy
    except ImportError:
        # The two literals are concatenated, so the first one must end
        # with a space to keep the sentences apart in the final message.
        raise ImproperlyConfigured(
            'The database result backend requires SQLAlchemy to be installed. '
            'See http://pypi.python.org/pypi/SQLAlchemy')
    return sqlalchemy


# Run the check at import time so a missing SQLAlchemy is reported as a
# configuration error before the ``sqlalchemy`` imports that follow.
_sqlalchemy_installed()
|
|
|
|
from sqlalchemy.exc import DatabaseError, InvalidRequestError
|
|
from sqlalchemy.orm.exc import StaleDataError
|
|
|
|
|
|
@contextmanager
def session_cleanup(session):
    """Context manager that rolls back on error and always closes.

    If the managed block raises an ``Exception``, ``session.rollback()``
    is called and the exception is re-raised.  ``session.close()`` is
    called on every exit path.

    :param session: object providing ``rollback()`` and ``close()``.

    """
    try:
        try:
            yield
        except Exception:
            # Undo whatever the failed block left pending, then let the
            # original error continue to the caller.
            session.rollback()
            raise
    finally:
        session.close()
|
|
|
|
|
|
def retry(fun):
    """Decorate ``fun`` so that failed database operations are retried.

    The returned wrapper accepts an extra ``max_retries`` keyword argument
    (default 3).  It is consumed by the wrapper and is never forwarded to
    ``fun``.  ``fun`` is called until it returns, at most ``max_retries``
    times.  Each ``DatabaseError``, ``InvalidRequestError`` or
    ``StaleDataError`` is logged as a warning; the one raised by the last
    attempt is re-raised.  Any other exception propagates immediately.

    A ``max_retries`` lower than 1 is treated as 1, so the wrapped
    operation is always attempted at least once instead of being
    silently skipped.

    :param fun: the callable to wrap.
    :returns: the wrapping function (metadata copied from ``fun``).

    """

    @wraps(fun)
    def _inner(*args, **kwargs):
        # Clamp to one attempt: with zero iterations the loop below would
        # fall through and return None without ever calling ``fun``.
        attempts = max(1, kwargs.pop('max_retries', 3))

        for attempt in range(1, attempts + 1):
            try:
                return fun(*args, **kwargs)
            except (DatabaseError, InvalidRequestError, StaleDataError):
                logger.warning(
                    "Failed operation %s. Retrying %s more times.",
                    fun.__name__, attempts - attempt,
                    exc_info=True,
                )
                # Out of attempts: surface the last error to the caller.
                if attempt >= attempts:
                    raise

    return _inner
|
|
|
|
|
|
class DatabaseBackend(BaseBackend):
    """The database result backend.

    Stores task and group results through SQLAlchemy sessions obtained
    from :meth:`ResultSession`, using the ``Task`` and ``TaskSet`` models.

    """
    # ResultSet.iterate should sleep this much between each poll,
    # to not bombard the database with queries.
    subpolling_interval = 0.5

    def __init__(self, dburi=None, expires=None,
                 engine_options=None, url=None, **kwargs):
        """Configure the backend from arguments and the app configuration.

        :keyword dburi: connection string; used when ``url`` is not given.
        :keyword expires: passed through ``prepare_expires`` and
            ``maybe_timedelta`` to produce ``self.expires``.
        :keyword engine_options: mapping merged with
            ``CELERY_RESULT_ENGINE_OPTIONS``; on duplicate keys the
            configuration value wins.
        :keyword url: connection string; takes precedence over ``dburi``.
        :keyword short_lived_sessions: read from ``kwargs`` (which are
            also forwarded to the base class), defaulting to
            ``CELERY_RESULT_DB_SHORT_LIVED_SESSIONS``.

        :raises ImproperlyConfigured: if no connection string is found in
            ``url``, ``dburi`` or ``CELERY_RESULT_DBURI``.

        """
        # The `url` argument was added later and is used by
        # the app to set backend by url (celery.backends.get_backend_by_url)
        super(DatabaseBackend, self).__init__(**kwargs)
        conf = self.app.conf
        self.expires = maybe_timedelta(self.prepare_expires(expires))
        self.dburi = url or dburi or conf.CELERY_RESULT_DBURI
        self.engine_options = dict(
            engine_options or {},
            **conf.CELERY_RESULT_ENGINE_OPTIONS or {})
        self.short_lived_sessions = kwargs.get(
            'short_lived_sessions',
            conf.CELERY_RESULT_DB_SHORT_LIVED_SESSIONS,
        )

        # The table names are set on the shared model classes, so this
        # assignment affects every user of ``Task`` and ``TaskSet``.
        tablenames = conf.CELERY_RESULT_DB_TABLENAMES or {}
        Task.__table__.name = tablenames.get('task', 'celery_taskmeta')
        TaskSet.__table__.name = tablenames.get('group', 'celery_tasksetmeta')

        if not self.dburi:
            raise ImproperlyConfigured(
                'Missing connection string! Do you have '
                'CELERY_RESULT_DBURI set to a real value?')

    def ResultSession(self, session_manager=SessionManager()):
        """Return a session created by ``session_manager.session_factory``.

        The factory receives this backend's ``dburi``,
        ``short_lived_sessions`` flag and engine options.

        The default ``session_manager`` is evaluated once, when the class
        body is executed, so every call that does not pass its own manager
        shares that single ``SessionManager`` instance.
        NOTE(review): presumably intentional so the manager's state is
        reused between calls -- confirm against ``SessionManager``.

        """
        return session_manager.session_factory(
            dburi=self.dburi,
            short_lived_sessions=self.short_lived_sessions,
            **self.engine_options
        )

    @retry
    def _store_result(self, task_id, result, status,
                      traceback=None, max_retries=3, **kwargs):
        """Store return value and status of an executed task.

        Updates the existing ``Task`` row for ``task_id`` or creates one,
        then commits.  A ``max_retries`` keyword passed by the caller is
        consumed by the ``retry`` decorator before this method runs.

        :returns: ``result`` unchanged.

        """
        session = self.ResultSession()
        with session_cleanup(session):
            task = session.query(Task).filter(
                Task.task_id == task_id).first()
            if not task:
                task = Task(task_id)
                session.add(task)
                session.flush()
            task.result = result
            task.status = status
            task.traceback = traceback
            session.commit()
            return result

    @retry
    def _get_task_meta_for(self, task_id):
        """Get task metadata for a task by id.

        When no row exists for ``task_id`` a new, unsaved ``Task`` with
        status ``PENDING`` and result ``None`` is used instead.

        :returns: the value of the task's ``to_dict()``.

        """
        session = self.ResultSession()
        with session_cleanup(session):
            task = session.query(Task).filter(
                Task.task_id == task_id).first()
            if not task:
                # Not added to the session: only used to build the dict.
                task = Task(task_id)
                task.status = states.PENDING
                task.result = None
            return task.to_dict()

    @retry
    def _save_group(self, group_id, result):
        """Store the result of an executed group.

        :returns: ``result`` unchanged.

        """
        session = self.ResultSession()
        with session_cleanup(session):
            group = TaskSet(group_id, result)
            session.add(group)
            session.flush()
            session.commit()
            return result

    @retry
    def _restore_group(self, group_id):
        """Get metadata for group by id.

        :returns: the group's ``to_dict()`` value, or ``None`` when no
            ``TaskSet`` row matches ``group_id``.

        """
        session = self.ResultSession()
        with session_cleanup(session):
            group = session.query(TaskSet).filter(
                TaskSet.taskset_id == group_id).first()
            if group:
                return group.to_dict()

    @retry
    def _delete_group(self, group_id):
        """Delete metadata for group by id."""
        session = self.ResultSession()
        with session_cleanup(session):
            session.query(TaskSet).filter(
                TaskSet.taskset_id == group_id).delete()
            session.flush()
            session.commit()

    @retry
    def _forget(self, task_id):
        """Forget about result."""
        session = self.ResultSession()
        with session_cleanup(session):
            session.query(Task).filter(Task.task_id == task_id).delete()
            session.commit()

    def cleanup(self):
        """Delete expired metadata.

        Removes every ``Task`` and ``TaskSet`` row whose ``date_done`` is
        earlier than ``self.app.now() - self.expires``, then commits.

        """
        session = self.ResultSession()
        # Both deletes use the same threshold, so compute it once.
        cutoff = self.app.now() - self.expires
        with session_cleanup(session):
            session.query(Task).filter(Task.date_done < cutoff).delete()
            session.query(TaskSet).filter(TaskSet.date_done < cutoff).delete()
            session.commit()

    def __reduce__(self, args=(), kwargs=None):
        """Return the pickling recipe, including this backend's settings.

        ``dburi``, ``expires`` and ``engine_options`` are added to the
        keyword arguments before delegating to the base class.

        :keyword args: positional arguments forwarded to the base class.
        :keyword kwargs: optional mapping of extra keyword arguments; it
            is copied, never modified.

        """
        # Work on a fresh dict: a shared ``{}`` default (or the caller's
        # mapping) must not accumulate values between calls.
        kwargs = dict(kwargs or {})
        kwargs.update(
            dburi=self.dburi,
            expires=self.expires,
            engine_options=self.engine_options,
        )
        return super(DatabaseBackend, self).__reduce__(args, kwargs)
|