comparison rhodecode/model/scm.py @ 2031:82a88013a3fd

merge 1.3 into stable
author Marcin Kuzminski <marcin@python-works.com>
date Sun, 26 Feb 2012 17:25:09 +0200
parents 1ce36a5f2305 324ac367a4da
children 6c6718c06ea2
comparison
equal deleted inserted replaced
2005:ab0e122b38a7 2031:82a88013a3fd
26 import time 26 import time
27 import traceback 27 import traceback
28 import logging 28 import logging
29 import cStringIO 29 import cStringIO
30 30
31 from sqlalchemy.exc import DatabaseError 31 from rhodecode.lib.vcs import get_backend
32 32 from rhodecode.lib.vcs.exceptions import RepositoryError
33 from vcs import get_backend 33 from rhodecode.lib.vcs.utils.lazy import LazyProperty
34 from vcs.exceptions import RepositoryError 34 from rhodecode.lib.vcs.nodes import FileNode
35 from vcs.utils.lazy import LazyProperty
36 from vcs.nodes import FileNode
37 35
38 from rhodecode import BACKENDS 36 from rhodecode import BACKENDS
39 from rhodecode.lib import helpers as h 37 from rhodecode.lib import helpers as h
40 from rhodecode.lib import safe_str 38 from rhodecode.lib import safe_str
41 from rhodecode.lib.auth import HasRepoPermissionAny 39 from rhodecode.lib.auth import HasRepoPermissionAny, HasReposGroupPermissionAny
42 from rhodecode.lib.utils import get_repos as get_filesystem_repos, make_ui, \ 40 from rhodecode.lib.utils import get_repos as get_filesystem_repos, make_ui, \
43 action_logger, EmptyChangeset 41 action_logger, EmptyChangeset
44 from rhodecode.model import BaseModel 42 from rhodecode.model import BaseModel
45 from rhodecode.model.db import Repository, RhodeCodeUi, CacheInvalidation, \ 43 from rhodecode.model.db import Repository, RhodeCodeUi, CacheInvalidation, \
46 UserFollowing, UserLog, User 44 UserFollowing, UserLog, User, RepoGroup
47 45
48 log = logging.getLogger(__name__) 46 log = logging.getLogger(__name__)
49 47
50 48
51 class UserTemp(object): 49 class UserTemp(object):
60 def __init__(self, repo_id): 58 def __init__(self, repo_id):
61 self.repo_id = repo_id 59 self.repo_id = repo_id
62 60
63 def __repr__(self): 61 def __repr__(self):
64 return "<%s('id:%s')>" % (self.__class__.__name__, self.repo_id) 62 return "<%s('id:%s')>" % (self.__class__.__name__, self.repo_id)
63
65 64
66 class CachedRepoList(object): 65 class CachedRepoList(object):
67 66
68 def __init__(self, db_repo_list, repos_path, order_by=None): 67 def __init__(self, db_repo_list, repos_path, order_by=None):
69 self.db_repo_list = db_repo_list 68 self.db_repo_list = db_repo_list
77 def __repr__(self): 76 def __repr__(self):
78 return '<%s (%s)>' % (self.__class__.__name__, self.__len__()) 77 return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
79 78
80 def __iter__(self): 79 def __iter__(self):
81 for dbr in self.db_repo_list: 80 for dbr in self.db_repo_list:
82
83 scmr = dbr.scm_instance_cached 81 scmr = dbr.scm_instance_cached
84
85 # check permission at this level 82 # check permission at this level
86 if not HasRepoPermissionAny('repository.read', 'repository.write', 83 if not HasRepoPermissionAny(
87 'repository.admin')(dbr.repo_name, 84 'repository.read', 'repository.write', 'repository.admin'
88 'get repo check'): 85 )(dbr.repo_name, 'get repo check'):
89 continue 86 continue
90 87
91 if scmr is None: 88 if scmr is None:
92 log.error('%s this repository is present in database but it ' 89 log.error(
93 'cannot be created as an scm instance', 90 '%s this repository is present in database but it '
94 dbr.repo_name) 91 'cannot be created as an scm instance' % dbr.repo_name
92 )
95 continue 93 continue
96 94
97 last_change = scmr.last_change 95 last_change = scmr.last_change
98 tip = h.get_changeset_safe(scmr, 'tip') 96 tip = h.get_changeset_safe(scmr, 'tip')
99 97
101 tmp_d['name'] = dbr.repo_name 99 tmp_d['name'] = dbr.repo_name
102 tmp_d['name_sort'] = tmp_d['name'].lower() 100 tmp_d['name_sort'] = tmp_d['name'].lower()
103 tmp_d['description'] = dbr.description 101 tmp_d['description'] = dbr.description
104 tmp_d['description_sort'] = tmp_d['description'] 102 tmp_d['description_sort'] = tmp_d['description']
105 tmp_d['last_change'] = last_change 103 tmp_d['last_change'] = last_change
106 tmp_d['last_change_sort'] = time.mktime(last_change \ 104 tmp_d['last_change_sort'] = time.mktime(last_change.timetuple())
107 .timetuple())
108 tmp_d['tip'] = tip.raw_id 105 tmp_d['tip'] = tip.raw_id
109 tmp_d['tip_sort'] = tip.revision 106 tmp_d['tip_sort'] = tip.revision
110 tmp_d['rev'] = tip.revision 107 tmp_d['rev'] = tip.revision
111 tmp_d['contact'] = dbr.user.full_contact 108 tmp_d['contact'] = dbr.user.full_contact
112 tmp_d['contact_sort'] = tmp_d['contact'] 109 tmp_d['contact_sort'] = tmp_d['contact']
113 tmp_d['owner_sort'] = tmp_d['contact'] 110 tmp_d['owner_sort'] = tmp_d['contact']
114 tmp_d['repo_archives'] = list(scmr._get_archives()) 111 tmp_d['repo_archives'] = list(scmr._get_archives())
115 tmp_d['last_msg'] = tip.message 112 tmp_d['last_msg'] = tip.message
116 tmp_d['author'] = tip.author 113 tmp_d['author'] = tip.author
117 tmp_d['dbrepo'] = dbr.get_dict() 114 tmp_d['dbrepo'] = dbr.get_dict()
118 tmp_d['dbrepo_fork'] = dbr.fork.get_dict() if dbr.fork \ 115 tmp_d['dbrepo_fork'] = dbr.fork.get_dict() if dbr.fork else {}
119 else {}
120 yield tmp_d 116 yield tmp_d
121 117
118
119 class GroupList(object):
120
121 def __init__(self, db_repo_group_list):
122 self.db_repo_group_list = db_repo_group_list
123
124 def __len__(self):
125 return len(self.db_repo_group_list)
126
127 def __repr__(self):
128 return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
129
130 def __iter__(self):
131 for dbgr in self.db_repo_group_list:
132 # check permission at this level
133 if not HasReposGroupPermissionAny(
134 'group.read', 'group.write', 'group.admin'
135 )(dbgr.group_name, 'get group repo check'):
136 continue
137
138 yield dbgr
139
140
122 class ScmModel(BaseModel): 141 class ScmModel(BaseModel):
123 """Generic Scm Model
124 """ 142 """
143 Generic Scm Model
144 """
145
146 def __get_repo(self, instance):
147 cls = Repository
148 if isinstance(instance, cls):
149 return instance
150 elif isinstance(instance, int) or str(instance).isdigit():
151 return cls.get(instance)
152 elif isinstance(instance, basestring):
153 return cls.get_by_repo_name(instance)
154 elif instance:
155 raise Exception('given object must be int, basestr or Instance'
156 ' of %s got %s' % (type(cls), type(instance)))
125 157
126 @LazyProperty 158 @LazyProperty
127 def repos_path(self): 159 def repos_path(self):
128 """Get's the repositories root path from database 160 """
161 Get's the repositories root path from database
129 """ 162 """
130 163
131 q = self.sa.query(RhodeCodeUi).filter(RhodeCodeUi.ui_key == '/').one() 164 q = self.sa.query(RhodeCodeUi).filter(RhodeCodeUi.ui_key == '/').one()
132 165
133 return q.ui_value 166 return q.ui_value
134 167
135 def repo_scan(self, repos_path=None): 168 def repo_scan(self, repos_path=None):
136 """Listing of repositories in given path. This path should not be a 169 """
170 Listing of repositories in given path. This path should not be a
137 repository itself. Return a dictionary of repository objects 171 repository itself. Return a dictionary of repository objects
138 172
139 :param repos_path: path to directory containing repositories 173 :param repos_path: path to directory containing repositories
140 """ 174 """
141 175
142 if repos_path is None: 176 if repos_path is None:
143 repos_path = self.repos_path 177 repos_path = self.repos_path
144 178
145 log.info('scanning for repositories in %s', repos_path) 179 log.info('scanning for repositories in %s' % repos_path)
146 180
147 baseui = make_ui('db') 181 baseui = make_ui('db')
148 repos_list = {} 182 repos = {}
149 183
150 for name, path in get_filesystem_repos(repos_path, recursive=True): 184 for name, path in get_filesystem_repos(repos_path, recursive=True):
151 185
152 # name need to be decomposed and put back together using the / 186 # name need to be decomposed and put back together using the /
153 # since this is internal storage separator for rhodecode 187 # since this is internal storage separator for rhodecode
154 name = Repository.url_sep().join(name.split(os.sep)) 188 name = Repository.url_sep().join(name.split(os.sep))
155 189
156 try: 190 try:
157 if name in repos_list: 191 if name in repos:
158 raise RepositoryError('Duplicate repository name %s ' 192 raise RepositoryError('Duplicate repository name %s '
159 'found in %s' % (name, path)) 193 'found in %s' % (name, path))
160 else: 194 else:
161 195
162 klass = get_backend(path[0]) 196 klass = get_backend(path[0])
163 197
164 if path[0] == 'hg' and path[0] in BACKENDS.keys(): 198 if path[0] == 'hg' and path[0] in BACKENDS.keys():
165 199 repos[name] = klass(safe_str(path[1]), baseui=baseui)
166 # for mercurial we need to have an str path
167 repos_list[name] = klass(safe_str(path[1]),
168 baseui=baseui)
169 200
170 if path[0] == 'git' and path[0] in BACKENDS.keys(): 201 if path[0] == 'git' and path[0] in BACKENDS.keys():
171 repos_list[name] = klass(path[1]) 202 repos[name] = klass(path[1])
172 except OSError: 203 except OSError:
173 continue 204 continue
174 205
175 return repos_list 206 return repos
176 207
177 def get_repos(self, all_repos=None, sort_key=None): 208 def get_repos(self, all_repos=None, sort_key=None):
178 """ 209 """
179 Get all repos from db and for each repo create it's 210 Get all repos from db and for each repo create it's
180 backend instance and fill that backed with information from database 211 backend instance and fill that backed with information from database
190 repo_iter = CachedRepoList(all_repos, repos_path=self.repos_path, 221 repo_iter = CachedRepoList(all_repos, repos_path=self.repos_path,
191 order_by=sort_key) 222 order_by=sort_key)
192 223
193 return repo_iter 224 return repo_iter
194 225
226 def get_repos_groups(self, all_groups=None):
227 if all_groups is None:
228 all_groups = RepoGroup.query()\
229 .filter(RepoGroup.group_parent_id == None).all()
230 group_iter = GroupList(all_groups)
231
232 return group_iter
233
195 def mark_for_invalidation(self, repo_name): 234 def mark_for_invalidation(self, repo_name):
196 """Puts cache invalidation task into db for 235 """Puts cache invalidation task into db for
197 further global cache invalidation 236 further global cache invalidation
198 237
199 :param repo_name: this repo that should invalidation take place 238 :param repo_name: this repo that should invalidation take place
200 """ 239 """
201 240 CacheInvalidation.set_invalidate(repo_name)
202 log.debug('marking %s for invalidation', repo_name) 241 CacheInvalidation.set_invalidate(repo_name + "_README")
203 cache = self.sa.query(CacheInvalidation)\
204 .filter(CacheInvalidation.cache_key == repo_name).scalar()
205
206 if cache:
207 # mark this cache as inactive
208 cache.cache_active = False
209 else:
210 log.debug('cache key not found in invalidation db -> creating one')
211 cache = CacheInvalidation(repo_name)
212
213 try:
214 self.sa.add(cache)
215 self.sa.commit()
216 except (DatabaseError,):
217 log.error(traceback.format_exc())
218 self.sa.rollback()
219 242
220 def toggle_following_repo(self, follow_repo_id, user_id): 243 def toggle_following_repo(self, follow_repo_id, user_id):
221 244
222 f = self.sa.query(UserFollowing)\ 245 f = self.sa.query(UserFollowing)\
223 .filter(UserFollowing.follows_repo_id == follow_repo_id)\ 246 .filter(UserFollowing.follows_repo_id == follow_repo_id)\
224 .filter(UserFollowing.user_id == user_id).scalar() 247 .filter(UserFollowing.user_id == user_id).scalar()
225 248
226 if f is not None: 249 if f is not None:
227
228 try: 250 try:
229 self.sa.delete(f) 251 self.sa.delete(f)
230 self.sa.commit()
231 action_logger(UserTemp(user_id), 252 action_logger(UserTemp(user_id),
232 'stopped_following_repo', 253 'stopped_following_repo',
233 RepoTemp(follow_repo_id)) 254 RepoTemp(follow_repo_id))
234 return 255 return
235 except: 256 except:
236 log.error(traceback.format_exc()) 257 log.error(traceback.format_exc())
237 self.sa.rollback()
238 raise 258 raise
239 259
240 try: 260 try:
241 f = UserFollowing() 261 f = UserFollowing()
242 f.user_id = user_id 262 f.user_id = user_id
243 f.follows_repo_id = follow_repo_id 263 f.follows_repo_id = follow_repo_id
244 self.sa.add(f) 264 self.sa.add(f)
245 self.sa.commit() 265
246 action_logger(UserTemp(user_id), 266 action_logger(UserTemp(user_id),
247 'started_following_repo', 267 'started_following_repo',
248 RepoTemp(follow_repo_id)) 268 RepoTemp(follow_repo_id))
249 except: 269 except:
250 log.error(traceback.format_exc()) 270 log.error(traceback.format_exc())
251 self.sa.rollback()
252 raise 271 raise
253 272
254 def toggle_following_user(self, follow_user_id, user_id): 273 def toggle_following_user(self, follow_user_id, user_id):
255 f = self.sa.query(UserFollowing)\ 274 f = self.sa.query(UserFollowing)\
256 .filter(UserFollowing.follows_user_id == follow_user_id)\ 275 .filter(UserFollowing.follows_user_id == follow_user_id)\
257 .filter(UserFollowing.user_id == user_id).scalar() 276 .filter(UserFollowing.user_id == user_id).scalar()
258 277
259 if f is not None: 278 if f is not None:
260 try: 279 try:
261 self.sa.delete(f) 280 self.sa.delete(f)
262 self.sa.commit()
263 return 281 return
264 except: 282 except:
265 log.error(traceback.format_exc()) 283 log.error(traceback.format_exc())
266 self.sa.rollback()
267 raise 284 raise
268 285
269 try: 286 try:
270 f = UserFollowing() 287 f = UserFollowing()
271 f.user_id = user_id 288 f.user_id = user_id
272 f.follows_user_id = follow_user_id 289 f.follows_user_id = follow_user_id
273 self.sa.add(f) 290 self.sa.add(f)
274 self.sa.commit()
275 except: 291 except:
276 log.error(traceback.format_exc()) 292 log.error(traceback.format_exc())
277 self.sa.rollback()
278 raise 293 raise
279 294
280 def is_following_repo(self, repo_name, user_id, cache=False): 295 def is_following_repo(self, repo_name, user_id, cache=False):
281 r = self.sa.query(Repository)\ 296 r = self.sa.query(Repository)\
282 .filter(Repository.repo_name == repo_name).scalar() 297 .filter(Repository.repo_name == repo_name).scalar()
307 if not isinstance(repo_id, int): 322 if not isinstance(repo_id, int):
308 repo_id = getattr(Repository.get_by_repo_name(repo_id), 'repo_id') 323 repo_id = getattr(Repository.get_by_repo_name(repo_id), 'repo_id')
309 324
310 return self.sa.query(Repository)\ 325 return self.sa.query(Repository)\
311 .filter(Repository.fork_id == repo_id).count() 326 .filter(Repository.fork_id == repo_id).count()
327
328 def mark_as_fork(self, repo, fork, user):
329 repo = self.__get_repo(repo)
330 fork = self.__get_repo(fork)
331 repo.fork = fork
332 self.sa.add(repo)
333 return repo
312 334
313 def pull_changes(self, repo_name, username): 335 def pull_changes(self, repo_name, username):
314 dbrepo = Repository.get_by_repo_name(repo_name) 336 dbrepo = Repository.get_by_repo_name(repo_name)
315 clone_uri = dbrepo.clone_uri 337 clone_uri = dbrepo.clone_uri
316 if not clone_uri: 338 if not clone_uri:
331 self.mark_for_invalidation(repo_name) 353 self.mark_for_invalidation(repo_name)
332 except: 354 except:
333 log.error(traceback.format_exc()) 355 log.error(traceback.format_exc())
334 raise 356 raise
335 357
336 def commit_change(self, repo, repo_name, cs, user, author, message, content, 358 def commit_change(self, repo, repo_name, cs, user, author, message,
337 f_path): 359 content, f_path):
338 360
339 if repo.alias == 'hg': 361 if repo.alias == 'hg':
340 from vcs.backends.hg import MercurialInMemoryChangeset as IMC 362 from rhodecode.lib.vcs.backends.hg import MercurialInMemoryChangeset as IMC
341 elif repo.alias == 'git': 363 elif repo.alias == 'git':
342 from vcs.backends.git import GitInMemoryChangeset as IMC 364 from rhodecode.lib.vcs.backends.git import GitInMemoryChangeset as IMC
343 365
344 # decoding here will force that we have proper encoded values 366 # decoding here will force that we have proper encoded values
345 # in any other case this will throw exceptions and deny commit 367 # in any other case this will throw exceptions and deny commit
346 content = safe_str(content) 368 content = safe_str(content)
347 message = safe_str(message) 369 message = safe_str(message)
361 self.mark_for_invalidation(repo_name) 383 self.mark_for_invalidation(repo_name)
362 384
363 def create_node(self, repo, repo_name, cs, user, author, message, content, 385 def create_node(self, repo, repo_name, cs, user, author, message, content,
364 f_path): 386 f_path):
365 if repo.alias == 'hg': 387 if repo.alias == 'hg':
366 from vcs.backends.hg import MercurialInMemoryChangeset as IMC 388 from rhodecode.lib.vcs.backends.hg import MercurialInMemoryChangeset as IMC
367 elif repo.alias == 'git': 389 elif repo.alias == 'git':
368 from vcs.backends.git import GitInMemoryChangeset as IMC 390 from rhodecode.lib.vcs.backends.git import GitInMemoryChangeset as IMC
369 # decoding here will force that we have proper encoded values 391 # decoding here will force that we have proper encoded values
370 # in any other case this will throw exceptions and deny commit 392 # in any other case this will throw exceptions and deny commit
371 393
372 if isinstance(content, (basestring,)): 394 if isinstance(content, (basestring,)):
373 content = safe_str(content) 395 content = safe_str(content)
398 420
399 action_logger(user, action, repo_name) 421 action_logger(user, action, repo_name)
400 422
401 self.mark_for_invalidation(repo_name) 423 self.mark_for_invalidation(repo_name)
402 424
425 def get_nodes(self, repo_name, revision, root_path='/', flat=True):
426 """
427 recursive walk in root dir and return a set of all path in that dir
428 based on repository walk function
429
430 :param repo_name: name of repository
431 :param revision: revision for which to list nodes
432 :param root_path: root path to list
433 :param flat: return as a list, if False returns a dict with decription
434
435 """
436 _files = list()
437 _dirs = list()
438 try:
439 _repo = self.__get_repo(repo_name)
440 changeset = _repo.scm_instance.get_changeset(revision)
441 root_path = root_path.lstrip('/')
442 for topnode, dirs, files in changeset.walk(root_path):
443 for f in files:
444 _files.append(f.path if flat else {"name": f.path,
445 "type": "file"})
446 for d in dirs:
447 _dirs.append(d.path if flat else {"name": d.path,
448 "type": "dir"})
449 except RepositoryError:
450 log.debug(traceback.format_exc())
451 raise
452
453 return _dirs, _files
403 454
404 def get_unread_journal(self): 455 def get_unread_journal(self):
405 return self.sa.query(UserLog).count() 456 return self.sa.query(UserLog).count()
406
407 def _should_invalidate(self, repo_name):
408 """Looks up database for invalidation signals for this repo_name
409
410 :param repo_name:
411 """
412
413 ret = self.sa.query(CacheInvalidation)\
414 .filter(CacheInvalidation.cache_key == repo_name)\
415 .filter(CacheInvalidation.cache_active == False)\
416 .scalar()
417
418 return ret
419