diff rhodecode/model/scm.py @ 3960:5293d4bbb1ea

Merged dev into stable/default/master branch
author Marcin Kuzminski <marcin@python-works.com>
date Fri, 07 Jun 2013 00:31:11 +0200
parents 51596d9ef2f8 ebde99f10d77
children 3648a2b2e17a
line wrap: on
line diff
--- a/rhodecode/model/scm.py	Mon May 20 12:26:09 2013 +0200
+++ b/rhodecode/model/scm.py	Fri Jun 07 00:31:11 2013 +0200
@@ -30,7 +30,7 @@
 import logging
 import cStringIO
 import pkg_resources
-from os.path import dirname as dn, join as jn
+from os.path import join as jn
 
 from sqlalchemy import func
 from pylons.i18n.translation import _
@@ -46,13 +46,15 @@
 from rhodecode.lib import helpers as h
 from rhodecode.lib.utils2 import safe_str, safe_unicode, get_server_url,\
     _set_extras
-from rhodecode.lib.auth import HasRepoPermissionAny, HasReposGroupPermissionAny
+from rhodecode.lib.auth import HasRepoPermissionAny, HasReposGroupPermissionAny,\
+    HasUserGroupPermissionAny
 from rhodecode.lib.utils import get_filesystem_repos, make_ui, \
-    action_logger, REMOVED_REPO_PAT
+    action_logger
 from rhodecode.model import BaseModel
 from rhodecode.model.db import Repository, RhodeCodeUi, CacheInvalidation, \
     UserFollowing, UserLog, User, RepoGroup, PullRequest
 from rhodecode.lib.hooks import log_push_action
+from rhodecode.lib.exceptions import NonRelativePathError
 
 log = logging.getLogger(__name__)
 
@@ -96,16 +98,15 @@
         return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
 
     def __iter__(self):
-        # pre-propagated cache_map to save executing select statements
+        # pre-propagated valid_cache_keys to save executing select statements
         # for each repo
-        cache_map = CacheInvalidation.get_cache_map()
+        valid_cache_keys = CacheInvalidation.get_valid_cache_keys()
 
         for dbr in self.db_repo_list:
-            scmr = dbr.scm_instance_cached(cache_map)
+            scmr = dbr.scm_instance_cached(valid_cache_keys)
             # check permission at this level
             if not HasRepoPermissionAny(
-                *self.perm_set
-            )(dbr.repo_name, 'get repo check'):
+                *self.perm_set)(dbr.repo_name, 'get repo check'):
                 continue
 
             try:
@@ -150,8 +151,7 @@
         for dbr in self.db_repo_list:
             # check permission at this level
             if not HasRepoPermissionAny(
-                *self.perm_set
-            )(dbr.repo_name, 'get repo check'):
+                *self.perm_set)(dbr.repo_name, 'get repo check'):
                 continue
 
             tmp_d = {}
@@ -165,36 +165,69 @@
             yield tmp_d
 
 
-class GroupList(object):
-
-    def __init__(self, db_repo_group_list, perm_set=None):
+class _PermCheckIterator(object):
+    def __init__(self, obj_list, obj_attr, perm_set, perm_checker):
         """
-        Creates iterator from given list of group objects, additionally
+        Creates iterator from given list of objects, additionally
         checking permission for them from perm_set var
 
-        :param db_repo_group_list:
-        :param perm_set: list of permissons to check
+        :param obj_list: list of db objects
+        :param obj_attr: attribute of object to pass into perm_checker
+        :param perm_set: list of permissions to check
+        :param perm_checker: callable to check permissions against
         """
-        self.db_repo_group_list = db_repo_group_list
-        if not perm_set:
-            perm_set = ['group.read', 'group.write', 'group.admin']
+        self.obj_list = obj_list
+        self.obj_attr = obj_attr
         self.perm_set = perm_set
+        self.perm_checker = perm_checker
 
     def __len__(self):
-        return len(self.db_repo_group_list)
+        return len(self.obj_list)
 
     def __repr__(self):
         return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
 
     def __iter__(self):
-        for dbgr in self.db_repo_group_list:
+        for db_obj in self.obj_list:
             # check permission at this level
-            if not HasReposGroupPermissionAny(
-                *self.perm_set
-            )(dbgr.group_name, 'get group repo check'):
+            name = getattr(db_obj, self.obj_attr, None)
+            if not self.perm_checker(*self.perm_set)(name, self.__class__.__name__):
                 continue
 
-            yield dbgr
+            yield db_obj
+
+
+class RepoList(_PermCheckIterator):
+
+    def __init__(self, db_repo_list, perm_set=None):
+        if not perm_set:
+            perm_set = ['repository.read', 'repository.write', 'repository.admin']
+
+        super(RepoList, self).__init__(obj_list=db_repo_list,
+                    obj_attr='repo_name', perm_set=perm_set,
+                    perm_checker=HasRepoPermissionAny)
+
+
+class RepoGroupList(_PermCheckIterator):
+
+    def __init__(self, db_repo_group_list, perm_set=None):
+        if not perm_set:
+            perm_set = ['group.read', 'group.write', 'group.admin']
+
+        super(RepoGroupList, self).__init__(obj_list=db_repo_group_list,
+                    obj_attr='group_name', perm_set=perm_set,
+                    perm_checker=HasReposGroupPermissionAny)
+
+
+class UserGroupList(_PermCheckIterator):
+
+    def __init__(self, db_user_group_list, perm_set=None):
+        if not perm_set:
+            perm_set = ['usergroup.read', 'usergroup.write', 'usergroup.admin']
+
+        super(UserGroupList, self).__init__(obj_list=db_user_group_list,
+                    obj_attr='users_group_name', perm_set=perm_set,
+                    perm_checker=HasUserGroupPermissionAny)
 
 
 class ScmModel(BaseModel):
@@ -293,20 +326,18 @@
         if all_groups is None:
             all_groups = RepoGroup.query()\
                 .filter(RepoGroup.group_parent_id == None).all()
-        return [x for x in GroupList(all_groups)]
+        return [x for x in RepoGroupList(all_groups)]
 
     def mark_for_invalidation(self, repo_name):
         """
-        Puts cache invalidation task into db for
-        further global cache invalidation
+        Mark caches of this repo invalid in the database.
 
-        :param repo_name: this repo that should invalidation take place
+        :param repo_name: the repo for which caches should be marked invalid
         """
-        invalidated_keys = CacheInvalidation.set_invalidate(repo_name=repo_name)
+        CacheInvalidation.set_invalidate(repo_name)
         repo = Repository.get_by_repo_name(repo_name)
         if repo:
             repo.update_changeset_cache()
-        return invalidated_keys
 
     def toggle_following_repo(self, follow_repo_id, user_id):
 
@@ -455,12 +486,15 @@
         :param scm_type:
         """
         if scm_type == 'hg':
-            from rhodecode.lib.vcs.backends.hg import \
-                MercurialInMemoryChangeset as IMC
-        elif scm_type == 'git':
-            from rhodecode.lib.vcs.backends.git import \
-                GitInMemoryChangeset as IMC
-        return IMC
+            from rhodecode.lib.vcs.backends.hg import MercurialInMemoryChangeset
+            return MercurialInMemoryChangeset
+
+        if scm_type == 'git':
+            from rhodecode.lib.vcs.backends.git import GitInMemoryChangeset
+            return GitInMemoryChangeset
+
+        raise Exception('Invalid scm_type, must be one of hg,git got %s'
+                        % (scm_type,))
 
     def pull_changes(self, repo, username):
         dbrepo = self.__get_repo(repo)
@@ -473,6 +507,14 @@
         try:
             if repo.alias == 'git':
                 repo.fetch(clone_uri)
+                # git doesn't really have something like post-fetch action
+                # we fake that now. #TODO: extract fetched revisions somehow
+                # here
+                self._handle_push(repo,
+                                  username=username,
+                                  action='push_remote',
+                                  repo_name=repo_name,
+                                  revisions=[])
             else:
                 self._handle_rc_scm_extras(username, dbrepo.repo_name,
                                            repo.alias, action='push_remote')
@@ -516,44 +558,76 @@
                           revisions=[tip.raw_id])
         return tip
 
-    def create_node(self, repo, repo_name, cs, user, author, message, content,
-                      f_path):
+    def create_nodes(self, user, repo, message, nodes, parent_cs=None,
+                     author=None, trigger_push_hook=True):
+        """
+        Commits given multiple nodes into repo
+
+        :param user: RhodeCode User object or user_id, the commiter
+        :param repo: RhodeCode Repository object
+        :param message: commit message
+        :param nodes: mapping {filename:{'content':content},...}
+        :param parent_cs: parent changeset, can be empty than it's initial commit
+        :param author: author of commit, cna be different that commiter only for git
+        :param trigger_push_hook: trigger push hooks
+
+        :returns: new commited changeset
+        """
+
         user = self._get_user(user)
-        IMC = self._get_IMC_module(repo.alias)
+        scm_instance = repo.scm_instance_no_cache()
 
-        # decoding here will force that we have proper encoded values
-        # in any other case this will throw exceptions and deny commit
-        if isinstance(content, (basestring,)):
-            content = safe_str(content)
-        elif isinstance(content, (file, cStringIO.OutputType,)):
-            content = content.read()
-        else:
-            raise Exception('Content is of unrecognized type %s' % (
-                type(content)
-            ))
+        processed_nodes = []
+        for f_path in nodes:
+            if f_path.startswith('/') or f_path.startswith('.') or '../' in f_path:
+                raise NonRelativePathError('%s is not an relative path' % f_path)
+            if f_path:
+                f_path = os.path.normpath(f_path)
+            content = nodes[f_path]['content']
+            f_path = safe_str(f_path)
+            # decoding here will force that we have proper encoded values
+            # in any other case this will throw exceptions and deny commit
+            if isinstance(content, (basestring,)):
+                content = safe_str(content)
+            elif isinstance(content, (file, cStringIO.OutputType,)):
+                content = content.read()
+            else:
+                raise Exception('Content is of unrecognized type %s' % (
+                    type(content)
+                ))
+            processed_nodes.append((f_path, content))
 
         message = safe_unicode(message)
-        author = safe_unicode(author)
-        path = safe_str(f_path)
-        m = IMC(repo)
+        commiter = user.full_contact
+        author = safe_unicode(author) if author else commiter
 
-        if isinstance(cs, EmptyChangeset):
+        IMC = self._get_IMC_module(scm_instance.alias)
+        imc = IMC(scm_instance)
+
+        if not parent_cs:
+            parent_cs = EmptyChangeset(alias=scm_instance.alias)
+
+        if isinstance(parent_cs, EmptyChangeset):
             # EmptyChangeset means we we're editing empty repository
             parents = None
         else:
-            parents = [cs]
-
-        m.add(FileNode(path, content=content))
-        tip = m.commit(message=message,
-                       author=author,
-                       parents=parents, branch=cs.branch)
+            parents = [parent_cs]
+        # add multiple nodes
+        for path, content in processed_nodes:
+            imc.add(FileNode(path, content=content))
 
-        self.mark_for_invalidation(repo_name)
-        self._handle_push(repo,
-                          username=user.username,
-                          action='push_local',
-                          repo_name=repo_name,
-                          revisions=[tip.raw_id])
+        tip = imc.commit(message=message,
+                         author=author,
+                         parents=parents,
+                         branch=parent_cs.branch)
+
+        self.mark_for_invalidation(repo.repo_name)
+        if trigger_push_hook:
+            self._handle_push(scm_instance,
+                              username=user.username,
+                              action='push_local',
+                              repo_name=repo.repo_name,
+                              revisions=[tip.raw_id])
         return tip
 
     def get_nodes(self, repo_name, revision, root_path='/', flat=True):
@@ -595,7 +669,6 @@
         grouped by type
 
         :param repo:
-        :type repo:
         """
 
         hist_l = []
@@ -654,7 +727,6 @@
             if os.path.exists(_hook_file):
                 # let's take a look at this hook, maybe it's rhodecode ?
                 log.debug('hook exists, checking if it is from rhodecode')
-                _HOOK_VER_PAT = re.compile(r'^RC_HOOK_VER')
                 with open(_hook_file, 'rb') as f:
                     data = f.read()
                     matches = re.compile(r'(?:%s)\s*=\s*(.*)'
@@ -671,7 +743,7 @@
                 _rhodecode_hook = True
 
             if _rhodecode_hook or force_create:
-                log.debug('writing %s hook file !' % h_type)
+                log.debug('writing %s hook file !' % (h_type,))
                 with open(_hook_file, 'wb') as f:
                     tmpl = tmpl.replace('_TMPL_', rhodecode.__version__)
                     f.write(tmpl)