diff rhodecode/model/scm.py @ 1512:bf263968da47

merge beta in stable branch
author Marcin Kuzminski <marcin@python-works.com>
date Fri, 07 Oct 2011 01:08:50 +0200
parents 8363b0d20c41 4aba7be311e8
children 752b0a7b7679
line wrap: on
line diff
--- a/rhodecode/model/scm.py	Thu May 12 19:50:48 2011 +0200
+++ b/rhodecode/model/scm.py	Fri Oct 07 01:08:50 2011 +0200
@@ -27,29 +27,23 @@
 import traceback
 import logging
 
+from sqlalchemy.exc import DatabaseError
+
 from vcs import get_backend
-from vcs.utils.helpers import get_scm
-from vcs.exceptions import RepositoryError, VCSError
+from vcs.exceptions import RepositoryError
 from vcs.utils.lazy import LazyProperty
-
-from mercurial import ui
-
-from beaker.cache import cache_region, region_invalidate
+from vcs.nodes import FileNode
 
 from rhodecode import BACKENDS
 from rhodecode.lib import helpers as h
+from rhodecode.lib import safe_str
 from rhodecode.lib.auth import HasRepoPermissionAny
-from rhodecode.lib.utils import get_repos, make_ui, action_logger
+from rhodecode.lib.utils import get_repos as get_filesystem_repos, make_ui, \
+    action_logger, EmptyChangeset
 from rhodecode.model import BaseModel
 from rhodecode.model.user import UserModel
-
 from rhodecode.model.db import Repository, RhodeCodeUi, CacheInvalidation, \
     UserFollowing, UserLog
-from rhodecode.model.caching_query import FromCache
-
-from sqlalchemy.orm import joinedload
-from sqlalchemy.orm.session import make_transient
-from sqlalchemy.exc import DatabaseError
 
 log = logging.getLogger(__name__)
 
@@ -69,6 +63,61 @@
     def __repr__(self):
         return "<%s('id:%s')>" % (self.__class__.__name__, self.repo_id)
 
+class CachedRepoList(object):
+
+    def __init__(self, db_repo_list, repos_path, order_by=None):
+        self.db_repo_list = db_repo_list
+        self.repos_path = repos_path
+        self.order_by = order_by
+        self.reversed = (order_by or '').startswith('-')
+
+    def __len__(self):
+        return len(self.db_repo_list)
+
+    def __repr__(self):
+        return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
+
+    def __iter__(self):
+        for dbr in self.db_repo_list:
+
+            scmr = dbr.scm_instance_cached
+
+            # check permission at this level
+            if not HasRepoPermissionAny('repository.read', 'repository.write',
+                                        'repository.admin')(dbr.repo_name,
+                                                            'get repo check'):
+                continue
+
+            if scmr is None:
+                log.error('%s this repository is present in database but it '
+                          'cannot be created as an scm instance',
+                          dbr.repo_name)
+                continue
+
+            last_change = scmr.last_change
+            tip = h.get_changeset_safe(scmr, 'tip')
+
+            tmp_d = {}
+            tmp_d['name'] = dbr.repo_name
+            tmp_d['name_sort'] = tmp_d['name'].lower()
+            tmp_d['description'] = dbr.description
+            tmp_d['description_sort'] = tmp_d['description']
+            tmp_d['last_change'] = last_change
+            tmp_d['last_change_sort'] = time.mktime(last_change \
+                                                    .timetuple())
+            tmp_d['tip'] = tip.raw_id
+            tmp_d['tip_sort'] = tip.revision
+            tmp_d['rev'] = tip.revision
+            tmp_d['contact'] = dbr.user.full_contact
+            tmp_d['contact_sort'] = tmp_d['contact']
+            tmp_d['owner_sort'] = tmp_d['contact']
+            tmp_d['repo_archives'] = list(scmr._get_archives())
+            tmp_d['last_msg'] = tip.message
+            tmp_d['author'] = tip.author
+            tmp_d['dbrepo'] = dbr.get_dict()
+            tmp_d['dbrepo_fork'] = dbr.fork.get_dict() if dbr.fork \
+                                                                    else {}
+            yield tmp_d
 
 class ScmModel(BaseModel):
     """Generic Scm Model
@@ -83,21 +132,22 @@
 
         return q.ui_value
 
-    def repo_scan(self, repos_path, baseui):
+    def repo_scan(self, repos_path=None):
         """Listing of repositories in given path. This path should not be a
         repository itself. Return a dictionary of repository objects
 
         :param repos_path: path to directory containing repositories
-        :param baseui: baseui instance to instantiate MercurialRepostitory with
         """
 
         log.info('scanning for repositories in %s', repos_path)
 
-        if not isinstance(baseui, ui.ui):
-            baseui = make_ui('db')
+        if repos_path is None:
+            repos_path = self.repos_path
+
+        baseui = make_ui('db')
         repos_list = {}
 
-        for name, path in get_repos(repos_path):
+        for name, path in get_filesystem_repos(repos_path, recursive=True):
             try:
                 if name in repos_list:
                     raise RepositoryError('Duplicate repository name %s '
@@ -107,7 +157,10 @@
                     klass = get_backend(path[0])
 
                     if path[0] == 'hg' and path[0] in BACKENDS.keys():
-                        repos_list[name] = klass(path[1], baseui=baseui)
+
+                        # for mercurial we need to have an str path
+                        repos_list[name] = klass(safe_str(path[1]),
+                                                 baseui=baseui)
 
                     if path[0] == 'git' and path[0] in BACKENDS.keys():
                         repos_list[name] = klass(path[1])
@@ -116,121 +169,23 @@
 
         return repos_list
 
-    def get_repos(self, all_repos=None):
-        """Get all repos from db and for each repo create it's
-        backend instance.and fill that backed with information from database
+    def get_repos(self, all_repos=None, sort_key=None):
+        """
+        Get all repos from db and for each repo create it's
+        backend instance and fill that backed with information from database
 
-        :param all_repos: give specific repositories list, good for filtering
+        :param all_repos: list of repository names as strings
+            give specific repositories list, good for filtering
         """
-
         if all_repos is None:
             all_repos = self.sa.query(Repository)\
-                .order_by(Repository.repo_name).all()
-
-        #get the repositories that should be invalidated
-        invalidation_list = [str(x.cache_key) for x in \
-                             self.sa.query(CacheInvalidation.cache_key)\
-                             .filter(CacheInvalidation.cache_active == False)\
-                             .all()]
-
-        for r in all_repos:
-
-            repo = self.get(r.repo_name, invalidation_list)
-
-            if repo is not None:
-                last_change = repo.last_change
-                tip = h.get_changeset_safe(repo, 'tip')
-
-                tmp_d = {}
-                tmp_d['name'] = repo.name
-                tmp_d['name_sort'] = tmp_d['name'].lower()
-                tmp_d['description'] = repo.dbrepo.description
-                tmp_d['description_sort'] = tmp_d['description']
-                tmp_d['last_change'] = last_change
-                tmp_d['last_change_sort'] = time.mktime(last_change \
-                                                        .timetuple())
-                tmp_d['tip'] = tip.raw_id
-                tmp_d['tip_sort'] = tip.revision
-                tmp_d['rev'] = tip.revision
-                tmp_d['contact'] = repo.dbrepo.user.full_contact
-                tmp_d['contact_sort'] = tmp_d['contact']
-                tmp_d['owner_sort'] = tmp_d['contact']
-                tmp_d['repo_archives'] = list(repo._get_archives())
-                tmp_d['last_msg'] = tip.message
-                tmp_d['repo'] = repo
-                yield tmp_d
-
-    def get_repo(self, repo_name):
-        return self.get(repo_name)
-
-    def get(self, repo_name, invalidation_list=None):
-        """Get's repository from given name, creates BackendInstance and
-        propagates it's data from database with all additional information
-
-        :param repo_name:
-        :param invalidation_list: if a invalidation list is given the get
-            method should not manually check if this repository needs
-            invalidation and just invalidate the repositories in list
-
-        """
-        if not HasRepoPermissionAny('repository.read', 'repository.write',
-                            'repository.admin')(repo_name, 'get repo check'):
-            return
+                        .filter(Repository.group_id == None)\
+                        .order_by(Repository.repo_name).all()
 
-        #======================================================================
-        # CACHE FUNCTION
-        #======================================================================
-        @cache_region('long_term')
-        def _get_repo(repo_name):
-
-            repo_path = os.path.join(self.repos_path, repo_name)
-
-            try:
-                alias = get_scm(repo_path)[0]
-
-                log.debug('Creating instance of %s repository', alias)
-                backend = get_backend(alias)
-            except VCSError:
-                log.error(traceback.format_exc())
-                return
-
-            if alias == 'hg':
-                from pylons import app_globals as g
-                repo = backend(repo_path, create=False, baseui=g.baseui)
-                #skip hidden web repository
-                if repo._get_hidden():
-                    return
-            else:
-                repo = backend(repo_path, create=False)
+        repo_iter = CachedRepoList(all_repos, repos_path=self.repos_path,
+                                   order_by=sort_key)
 
-            dbrepo = self.sa.query(Repository)\
-                .options(joinedload(Repository.fork))\
-                .options(joinedload(Repository.user))\
-                .filter(Repository.repo_name == repo_name)\
-                .scalar()
-
-            make_transient(dbrepo)
-            if dbrepo.user:
-                make_transient(dbrepo.user)
-            if dbrepo.fork:
-                make_transient(dbrepo.fork)
-
-            repo.dbrepo = dbrepo
-            return repo
-
-        pre_invalidate = True
-        if invalidation_list is not None:
-            pre_invalidate = repo_name in invalidation_list
-
-        if pre_invalidate:
-            invalidate = self._should_invalidate(repo_name)
-
-            if invalidate:
-                log.info('invalidating cache for repository %s', repo_name)
-                region_invalidate(_get_repo, None, repo_name)
-                self._mark_invalidated(invalidate)
-
-        return _get_repo(repo_name)
+        return repo_iter
 
     def mark_for_invalidation(self, repo_name):
         """Puts cache invalidation task into db for
@@ -244,7 +199,7 @@
             .filter(CacheInvalidation.cache_key == repo_name).scalar()
 
         if cache:
-            #mark this cache as inactive
+            # mark this cache as inactive
             cache.cache_active = False
         else:
             log.debug('cache key not found in invalidation db -> creating one')
@@ -317,7 +272,7 @@
             self.sa.rollback()
             raise
 
-    def is_following_repo(self, repo_name, user_id):
+    def is_following_repo(self, repo_name, user_id, cache=False):
         r = self.sa.query(Repository)\
             .filter(Repository.repo_name == repo_name).scalar()
 
@@ -327,7 +282,7 @@
 
         return f is not None
 
-    def is_following_user(self, username, user_id):
+    def is_following_user(self, username, user_id, cache=False):
         u = UserModel(self.sa).get_by_username(username)
 
         f = self.sa.query(UserFollowing)\
@@ -337,13 +292,106 @@
         return f is not None
 
     def get_followers(self, repo_id):
+        if not isinstance(repo_id, int):
+            repo_id = getattr(Repository.by_repo_name(repo_id), 'repo_id')
+
         return self.sa.query(UserFollowing)\
                 .filter(UserFollowing.follows_repo_id == repo_id).count()
 
     def get_forks(self, repo_id):
+        if not isinstance(repo_id, int):
+            repo_id = getattr(Repository.by_repo_name(repo_id), 'repo_id')
+
         return self.sa.query(Repository)\
                 .filter(Repository.fork_id == repo_id).count()
 
+    def pull_changes(self, repo_name, username):
+        dbrepo = Repository.by_repo_name(repo_name)
+        clone_uri = dbrepo.clone_uri
+        if not clone_uri:
+            raise Exception("This repository doesn't have a clone uri")
+        
+        repo = dbrepo.scm_instance
+        try:
+            extras = {'ip': '',
+                      'username': username,
+                      'action': 'push_remote',
+                      'repository': repo_name}
+
+            #inject ui extra param to log this action via push logger
+            for k, v in extras.items():
+                repo._repo.ui.setconfig('rhodecode_extras', k, v)
+
+            repo.pull(clone_uri)
+            self.mark_for_invalidation(repo_name)
+        except:
+            log.error(traceback.format_exc())
+            raise
+
+    def commit_change(self, repo, repo_name, cs, user, author, message, content,
+                      f_path):
+
+        if repo.alias == 'hg':
+            from vcs.backends.hg import MercurialInMemoryChangeset as IMC
+        elif repo.alias == 'git':
+            from vcs.backends.git import GitInMemoryChangeset as IMC
+
+        # decoding here will force that we have proper encoded values
+        # in any other case this will throw exceptions and deny commit
+        content = safe_str(content)
+        message = safe_str(message)
+        path = safe_str(f_path)
+        author = safe_str(author)
+        m = IMC(repo)
+        m.change(FileNode(path, content))
+        tip = m.commit(message=message,
+                 author=author,
+                 parents=[cs], branch=cs.branch)
+
+        new_cs = tip.short_id
+        action = 'push_local:%s' % new_cs
+
+        action_logger(user, action, repo_name)
+
+        self.mark_for_invalidation(repo_name)
+
+    def create_node(self, repo, repo_name, cs, user, author, message, content,
+                      f_path):
+        if repo.alias == 'hg':
+            from vcs.backends.hg import MercurialInMemoryChangeset as IMC
+        elif repo.alias == 'git':
+            from vcs.backends.git import GitInMemoryChangeset as IMC
+        # decoding here will force that we have proper encoded values
+        # in any other case this will throw exceptions and deny commit
+        
+        if isinstance(content,(basestring,)):
+            content = safe_str(content)
+        elif isinstance(content,file):
+            content = content.read()
+            
+        message = safe_str(message)
+        path = safe_str(f_path)
+        author = safe_str(author)
+        m = IMC(repo)
+
+        if isinstance(cs, EmptyChangeset):
+            # Emptychangeset means we we're editing empty repository
+            parents = None
+        else:
+            parents = [cs]
+
+        m.add(FileNode(path, content=content))
+        tip = m.commit(message=message,
+                 author=author,
+                 parents=parents, branch=cs.branch)
+        new_cs = tip.short_id
+        action = 'push_local:%s' % new_cs
+
+        action_logger(user, action, repo_name)
+
+        self.mark_for_invalidation(repo_name)
+
+
     def get_unread_journal(self):
         return self.sa.query(UserLog).count()
 
@@ -354,27 +402,9 @@
         """
 
         ret = self.sa.query(CacheInvalidation)\
-            .options(FromCache('sql_cache_short',
-                           'get_invalidation_%s' % repo_name))\
             .filter(CacheInvalidation.cache_key == repo_name)\
             .filter(CacheInvalidation.cache_active == False)\
             .scalar()
 
         return ret
 
-    def _mark_invalidated(self, cache_key):
-        """ Marks all occurences of cache to invaldation as
-        already invalidated
-
-        :param cache_key:
-        """
-
-        if cache_key:
-            log.debug('marking %s as already invalidated', cache_key)
-        try:
-            cache_key.cache_active = True
-            self.sa.add(cache_key)
-            self.sa.commit()
-        except (DatabaseError,):
-            log.error(traceback.format_exc())
-            self.sa.rollback()