changeset 1929:cd8a7e3698bc beta

fixes #340 session cleanup for celery tasks
author Marcin Kuzminski <marcin@python-works.com>
date Fri, 20 Jan 2012 08:11:00 +0200
parents 470dd49966f3
children a69573cfcb00
files rhodecode/lib/celerylib/__init__.py rhodecode/lib/celerylib/tasks.py
diffstat 2 files changed, 57 insertions(+), 39 deletions(-) [+]
line wrap: on
line diff
--- a/rhodecode/lib/celerylib/__init__.py	Fri Jan 20 08:08:41 2012 +0200
+++ b/rhodecode/lib/celerylib/__init__.py	Fri Jan 20 08:11:00 2012 +0200
@@ -29,6 +29,7 @@
 import traceback
 import logging
 from os.path import dirname as dn, join as jn
+from pylons import config
 
 from hashlib import md5
 from decorator import decorator
@@ -37,15 +38,17 @@
 from rhodecode import CELERY_ON
 from rhodecode.lib import str2bool, safe_str
 from rhodecode.lib.pidlock import DaemonLock, LockHeld
+from rhodecode.model import init_model
+from rhodecode.model import meta
+from rhodecode.model.db import Statistics, Repository, User
+
+from sqlalchemy import engine_from_config
 
 from celery.messaging import establish_connection
 
-
 log = logging.getLogger(__name__)
 
 
-
-
 class ResultWrapper(object):
     def __init__(self, task):
         self.task = task
@@ -103,3 +106,22 @@
             return 'Task with key %s already running' % lockkey
 
     return decorator(__wrapper, func)
+
+
+def get_session():
+    if CELERY_ON:
+        engine = engine_from_config(config, 'sqlalchemy.db1.')
+        init_model(engine)
+    sa = meta.Session
+    return sa
+
+
+def dbsession(func):
+    def __wrapper(func, *fargs, **fkwargs):
+        try:
+            ret = func(*fargs, **fkwargs)
+            return ret
+        finally:
+            meta.Session.remove()
+
+    return decorator(__wrapper, func)
--- a/rhodecode/lib/celerylib/tasks.py	Fri Jan 20 08:08:41 2012 +0200
+++ b/rhodecode/lib/celerylib/tasks.py	Fri Jan 20 08:11:00 2012 +0200
@@ -41,18 +41,15 @@
 
 from rhodecode import CELERY_ON
 from rhodecode.lib import LANGUAGES_EXTENSIONS_MAP, safe_str
-from rhodecode.lib.celerylib import run_task, locked_task, str2bool, \
-    __get_lockkey, LockHeld, DaemonLock
+from rhodecode.lib.celerylib import run_task, locked_task, dbsession, \
+    str2bool, __get_lockkey, LockHeld, DaemonLock, get_session
 from rhodecode.lib.helpers import person
 from rhodecode.lib.rcmail.smtp_mailer import SmtpMailer
 from rhodecode.lib.utils import add_cache, action_logger
 from rhodecode.lib.compat import json, OrderedDict
 
-from rhodecode.model import init_model
-from rhodecode.model import meta
 from rhodecode.model.db import Statistics, Repository, User
 
-from sqlalchemy import engine_from_config
 
 add_cache(config)
 
@@ -60,13 +57,6 @@
            'reset_user_password', 'send_email']
 
 
-def get_session():
-    if CELERY_ON:
-        engine = engine_from_config(config, 'sqlalchemy.db1.')
-        init_model(engine)
-    sa = meta.Session
-    return sa
-
 def get_logger(cls):
     if CELERY_ON:
         try:
@@ -81,21 +71,23 @@
 
 @task(ignore_result=True)
 @locked_task
+@dbsession
 def whoosh_index(repo_location, full_index):
     from rhodecode.lib.indexers.daemon import WhooshIndexingDaemon
-
-    # log = whoosh_index.get_logger(whoosh_index)
+    log = whoosh_index.get_logger(whoosh_index)
+    DBS = get_session()
 
     index_location = config['index_dir']
     WhooshIndexingDaemon(index_location=index_location,
-                         repo_location=repo_location, sa=get_session())\
+                         repo_location=repo_location, sa=DBS)\
                          .run(full_index=full_index)
 
 
 @task(ignore_result=True)
+@dbsession
 def get_commits_stats(repo_name, ts_min_y, ts_max_y):
     log = get_logger(get_commits_stats)
-
+    DBS = get_session()
     lockkey = __get_lockkey('get_commits_stats', repo_name, ts_min_y,
                             ts_max_y)
     lockkey_path = config['here']
@@ -103,7 +95,6 @@
     log.info('running task with lockkey %s', lockkey)
 
     try:
-        sa = get_session()
         lock = l = DaemonLock(file_=jn(lockkey_path, lockkey))
 
         # for js data compatibilty cleans the key for person from '
@@ -128,9 +119,9 @@
         last_cs = None
         timegetter = itemgetter('time')
 
-        dbrepo = sa.query(Repository)\
+        dbrepo = DBS.query(Repository)\
             .filter(Repository.repo_name == repo_name).scalar()
-        cur_stats = sa.query(Statistics)\
+        cur_stats = DBS.query(Statistics)\
             .filter(Statistics.repository == dbrepo).scalar()
 
         if cur_stats is not None:
@@ -234,11 +225,11 @@
         try:
             stats.repository = dbrepo
             stats.stat_on_revision = last_cs.revision if last_cs else 0
-            sa.add(stats)
-            sa.commit()
+            DBS.add(stats)
+            DBS.commit()
         except:
             log.error(traceback.format_exc())
-            sa.rollback()
+            DBS.rollback()
             lock.release()
             return False
 
@@ -254,13 +245,14 @@
         return 'Task with key %s already running' % lockkey
 
 @task(ignore_result=True)
+@dbsession
 def send_password_link(user_email):
     from rhodecode.model.notification import EmailNotificationModel
 
     log = get_logger(send_password_link)
-
+    DBS = get_session()
+    
     try:
-        sa = get_session()
         user = User.get_by_email(user_email)
         if user:
             log.debug('password reset user found %s' % user)
@@ -283,28 +275,29 @@
     return True
 
 @task(ignore_result=True)
+@dbsession
 def reset_user_password(user_email):
     from rhodecode.lib import auth
 
     log = get_logger(reset_user_password)
-
+    DBS = get_session()
+    
     try:
         try:
-            sa = get_session()
             user = User.get_by_email(user_email)
             new_passwd = auth.PasswordGenerator().gen_password(8,
                              auth.PasswordGenerator.ALPHABETS_BIG_SMALL)
             if user:
                 user.password = auth.get_crypt_password(new_passwd)
                 user.api_key = auth.generate_api_key(user.username)
-                sa.add(user)
-                sa.commit()
+                DBS.add(user)
+                DBS.commit()
                 log.info('change password for %s', user_email)
             if new_passwd is None:
                 raise Exception('unable to generate new password')
         except:
             log.error(traceback.format_exc())
-            sa.rollback()
+            DBS.rollback()
 
         run_task(send_email, user_email,
                  'Your new password',
@@ -319,6 +312,7 @@
 
 
 @task(ignore_result=True)
+@dbsession
 def send_email(recipients, subject, body, html_body=''):
     """
     Sends an email with defined parameters from the .ini files.
@@ -330,7 +324,8 @@
     :param html_body: html version of body
     """
     log = get_logger(send_email)
-    sa = get_session()
+    DBS = get_session()
+    
     email_config = config
     subject = "%s %s" % (email_config.get('email_prefix'), subject)
     if not recipients:
@@ -361,6 +356,7 @@
 
 
 @task(ignore_result=True)
+@dbsession
 def create_repo_fork(form_data, cur_user):
     """
     Creates a fork of repository using interval VCS methods
@@ -371,11 +367,11 @@
     from rhodecode.model.repo import RepoModel
 
     log = get_logger(create_repo_fork)
-
-    Session = get_session()
+    DBS = create_repo_fork.DBS
+    
     base_path = Repository.base_path()
 
-    RepoModel(Session).create(form_data, cur_user, just_db=True, fork=True)
+    RepoModel(DBS).create(form_data, cur_user, just_db=True, fork=True)
 
     alias = form_data['repo_type']
     org_repo_name = form_data['org_path']
@@ -391,12 +387,12 @@
             src_url=safe_str(source_repo_path),
             update_after_clone=update_after_clone)
     action_logger(cur_user, 'user_forked_repo:%s' % fork_name,
-                   org_repo_name, '', Session)
+                   org_repo_name, '', DBS)
 
     action_logger(cur_user, 'user_created_fork:%s' % fork_name,
-                   fork_name, '', Session)
+                   fork_name, '', DBS)
     # finally commit at latest possible stage
-    Session.commit()
+    DBS.commit()
 
 def __get_codes_stats(repo_name):
     repo = Repository.get_by_repo_name(repo_name).scm_instance