Mercurial > kallithea
diff rhodecode/model/scm.py @ 3960:5293d4bbb1ea
Merged dev into stable/default/master branch
author | Marcin Kuzminski <marcin@python-works.com> |
---|---|
date | Fri, 07 Jun 2013 00:31:11 +0200 |
parents | 51596d9ef2f8 ebde99f10d77 |
children | 3648a2b2e17a |
line wrap: on
line diff
--- a/rhodecode/model/scm.py Mon May 20 12:26:09 2013 +0200 +++ b/rhodecode/model/scm.py Fri Jun 07 00:31:11 2013 +0200 @@ -30,7 +30,7 @@ import logging import cStringIO import pkg_resources -from os.path import dirname as dn, join as jn +from os.path import join as jn from sqlalchemy import func from pylons.i18n.translation import _ @@ -46,13 +46,15 @@ from rhodecode.lib import helpers as h from rhodecode.lib.utils2 import safe_str, safe_unicode, get_server_url,\ _set_extras -from rhodecode.lib.auth import HasRepoPermissionAny, HasReposGroupPermissionAny +from rhodecode.lib.auth import HasRepoPermissionAny, HasReposGroupPermissionAny,\ + HasUserGroupPermissionAny from rhodecode.lib.utils import get_filesystem_repos, make_ui, \ - action_logger, REMOVED_REPO_PAT + action_logger from rhodecode.model import BaseModel from rhodecode.model.db import Repository, RhodeCodeUi, CacheInvalidation, \ UserFollowing, UserLog, User, RepoGroup, PullRequest from rhodecode.lib.hooks import log_push_action +from rhodecode.lib.exceptions import NonRelativePathError log = logging.getLogger(__name__) @@ -96,16 +98,15 @@ return '<%s (%s)>' % (self.__class__.__name__, self.__len__()) def __iter__(self): - # pre-propagated cache_map to save executing select statements + # pre-propagated valid_cache_keys to save executing select statements # for each repo - cache_map = CacheInvalidation.get_cache_map() + valid_cache_keys = CacheInvalidation.get_valid_cache_keys() for dbr in self.db_repo_list: - scmr = dbr.scm_instance_cached(cache_map) + scmr = dbr.scm_instance_cached(valid_cache_keys) # check permission at this level if not HasRepoPermissionAny( - *self.perm_set - )(dbr.repo_name, 'get repo check'): + *self.perm_set)(dbr.repo_name, 'get repo check'): continue try: @@ -150,8 +151,7 @@ for dbr in self.db_repo_list: # check permission at this level if not HasRepoPermissionAny( - *self.perm_set - )(dbr.repo_name, 'get repo check'): + *self.perm_set)(dbr.repo_name, 'get repo check'): continue tmp_d = {} @@ -165,36 +165,69 @@ yield tmp_d -class GroupList(object): - - def __init__(self, db_repo_group_list, perm_set=None): +class _PermCheckIterator(object): + def __init__(self, obj_list, obj_attr, perm_set, perm_checker): """ - Creates iterator from given list of group objects, additionally + Creates iterator from given list of objects, additionally checking permission for them from perm_set var - :param db_repo_group_list: - :param perm_set: list of permissons to check + :param obj_list: list of db objects + :param obj_attr: attribute of object to pass into perm_checker + :param perm_set: list of permissions to check + :param perm_checker: callable to check permissions against """ - self.db_repo_group_list = db_repo_group_list - if not perm_set: - perm_set = ['group.read', 'group.write', 'group.admin'] + self.obj_list = obj_list + self.obj_attr = obj_attr self.perm_set = perm_set + self.perm_checker = perm_checker def __len__(self): - return len(self.db_repo_group_list) + return len(self.obj_list) def __repr__(self): return '<%s (%s)>' % (self.__class__.__name__, self.__len__()) def __iter__(self): - for dbgr in self.db_repo_group_list: + for db_obj in self.obj_list: # check permission at this level - if not HasReposGroupPermissionAny( - *self.perm_set - )(dbgr.group_name, 'get group repo check'): + name = getattr(db_obj, self.obj_attr, None) + if not self.perm_checker(*self.perm_set)(name, self.__class__.__name__): continue - yield dbgr + yield db_obj + + +class RepoList(_PermCheckIterator): + + def __init__(self, db_repo_list, perm_set=None): + if not perm_set: + perm_set = ['repository.read', 'repository.write', 'repository.admin'] + + super(RepoList, self).__init__(obj_list=db_repo_list, + obj_attr='repo_name', perm_set=perm_set, + perm_checker=HasRepoPermissionAny) + + +class RepoGroupList(_PermCheckIterator): + + def __init__(self, db_repo_group_list, perm_set=None): + if not perm_set: + perm_set = ['group.read', 'group.write', 'group.admin'] + + super(RepoGroupList, self).__init__(obj_list=db_repo_group_list, + obj_attr='group_name', perm_set=perm_set, + perm_checker=HasReposGroupPermissionAny) + + +class UserGroupList(_PermCheckIterator): + + def __init__(self, db_user_group_list, perm_set=None): + if not perm_set: + perm_set = ['usergroup.read', 'usergroup.write', 'usergroup.admin'] + + super(UserGroupList, self).__init__(obj_list=db_user_group_list, + obj_attr='users_group_name', perm_set=perm_set, + perm_checker=HasUserGroupPermissionAny) class ScmModel(BaseModel): @@ -293,20 +326,18 @@ if all_groups is None: all_groups = RepoGroup.query()\ .filter(RepoGroup.group_parent_id == None).all() - return [x for x in GroupList(all_groups)] + return [x for x in RepoGroupList(all_groups)] def mark_for_invalidation(self, repo_name): """ - Puts cache invalidation task into db for - further global cache invalidation + Mark caches of this repo invalid in the database. - :param repo_name: this repo that should invalidation take place + :param repo_name: the repo for which caches should be marked invalid """ - invalidated_keys = CacheInvalidation.set_invalidate(repo_name=repo_name) + CacheInvalidation.set_invalidate(repo_name) repo = Repository.get_by_repo_name(repo_name) if repo: repo.update_changeset_cache() - return invalidated_keys def toggle_following_repo(self, follow_repo_id, user_id): @@ -455,12 +486,15 @@ :param scm_type: """ if scm_type == 'hg': - from rhodecode.lib.vcs.backends.hg import \ - MercurialInMemoryChangeset as IMC - elif scm_type == 'git': - from rhodecode.lib.vcs.backends.git import \ - GitInMemoryChangeset as IMC - return IMC + from rhodecode.lib.vcs.backends.hg import MercurialInMemoryChangeset + return MercurialInMemoryChangeset + + if scm_type == 'git': + from rhodecode.lib.vcs.backends.git import GitInMemoryChangeset + return GitInMemoryChangeset + + raise Exception('Invalid scm_type, must be one of hg,git got %s' + % (scm_type,)) def pull_changes(self, repo, username): dbrepo = self.__get_repo(repo) @@ -473,6 +507,14 @@ try: if repo.alias == 'git': repo.fetch(clone_uri) + # git doesn't really have something like post-fetch action + # we fake that now. #TODO: extract fetched revisions somehow + # here + self._handle_push(repo, + username=username, + action='push_remote', + repo_name=repo_name, + revisions=[]) else: self._handle_rc_scm_extras(username, dbrepo.repo_name, repo.alias, action='push_remote') @@ -516,44 +558,76 @@ revisions=[tip.raw_id]) return tip - def create_node(self, repo, repo_name, cs, user, author, message, content, - f_path): + def create_nodes(self, user, repo, message, nodes, parent_cs=None, + author=None, trigger_push_hook=True): + """ + Commits given multiple nodes into repo + + :param user: RhodeCode User object or user_id, the commiter + :param repo: RhodeCode Repository object + :param message: commit message + :param nodes: mapping {filename:{'content':content},...} + :param parent_cs: parent changeset, can be empty than it's initial commit + :param author: author of commit, cna be different that commiter only for git + :param trigger_push_hook: trigger push hooks + + :returns: new commited changeset + """ + user = self._get_user(user) - IMC = self._get_IMC_module(repo.alias) + scm_instance = repo.scm_instance_no_cache() - # decoding here will force that we have proper encoded values - # in any other case this will throw exceptions and deny commit - if isinstance(content, (basestring,)): - content = safe_str(content) - elif isinstance(content, (file, cStringIO.OutputType,)): - content = content.read() - else: - raise Exception('Content is of unrecognized type %s' % ( - type(content) - )) + processed_nodes = [] + for f_path in nodes: + if f_path.startswith('/') or f_path.startswith('.') or '../' in f_path: + raise NonRelativePathError('%s is not an relative path' % f_path) + if f_path: + f_path = os.path.normpath(f_path) + content = nodes[f_path]['content'] + f_path = safe_str(f_path) + # decoding here will force that we have proper encoded values + # in any other case this will throw exceptions and deny commit + if isinstance(content, (basestring,)): + content = safe_str(content) + elif isinstance(content, (file, cStringIO.OutputType,)): + content = content.read() + else: + raise Exception('Content is of unrecognized type %s' % ( + type(content) + )) + processed_nodes.append((f_path, content)) message = safe_unicode(message) - author = safe_unicode(author) - path = safe_str(f_path) - m = IMC(repo) + commiter = user.full_contact + author = safe_unicode(author) if author else commiter - if isinstance(cs, EmptyChangeset): + IMC = self._get_IMC_module(scm_instance.alias) + imc = IMC(scm_instance) + + if not parent_cs: + parent_cs = EmptyChangeset(alias=scm_instance.alias) + + if isinstance(parent_cs, EmptyChangeset): # EmptyChangeset means we we're editing empty repository parents = None else: - parents = [cs] - - m.add(FileNode(path, content=content)) - tip = m.commit(message=message, - author=author, - parents=parents, branch=cs.branch) + parents = [parent_cs] + # add multiple nodes + for path, content in processed_nodes: + imc.add(FileNode(path, content=content)) - self.mark_for_invalidation(repo_name) - self._handle_push(repo, - username=user.username, - action='push_local', - repo_name=repo_name, - revisions=[tip.raw_id]) + tip = imc.commit(message=message, + author=author, + parents=parents, + branch=parent_cs.branch) + + self.mark_for_invalidation(repo.repo_name) + if trigger_push_hook: + self._handle_push(scm_instance, + username=user.username, + action='push_local', + repo_name=repo.repo_name, + revisions=[tip.raw_id]) return tip def get_nodes(self, repo_name, revision, root_path='/', flat=True): @@ -595,7 +669,6 @@ grouped by type :param repo: - :type repo: """ hist_l = [] @@ -654,7 +727,6 @@ if os.path.exists(_hook_file): # let's take a look at this hook, maybe it's rhodecode ? log.debug('hook exists, checking if it is from rhodecode') - _HOOK_VER_PAT = re.compile(r'^RC_HOOK_VER') with open(_hook_file, 'rb') as f: data = f.read() matches = re.compile(r'(?:%s)\s*=\s*(.*)' @@ -671,7 +743,7 @@ _rhodecode_hook = True if _rhodecode_hook or force_create: - log.debug('writing %s hook file !' % h_type) + log.debug('writing %s hook file !' % (h_type,)) with open(_hook_file, 'wb') as f: tmpl = tmpl.replace('_TMPL_', rhodecode.__version__) f.write(tmpl)