diff rhodecode/model/user.py @ 2031:82a88013a3fd

merge 1.3 into stable
author Marcin Kuzminski <marcin@python-works.com>
date Sun, 26 Feb 2012 17:25:09 +0200
parents 95c3e33ef32e b63adad7c4af
children ecd59c28f432
line wrap: on
line diff
--- a/rhodecode/model/user.py	Sun Feb 19 20:21:14 2012 +0200
+++ b/rhodecode/model/user.py	Sun Feb 26 17:25:09 2012 +0200
@@ -7,7 +7,7 @@
 
     :created_on: Apr 9, 2010
     :author: marcink
-    :copyright: (C) 2009-2011 Marcin Kuzminski <marcin@python-works.com>
+    :copyright: (C) 2010-2012 Marcin Kuzminski <marcin@python-works.com>
     :license: GPLv3, see COPYING for more details.
 """
 # This program is free software: you can redistribute it and/or modify
@@ -26,13 +26,16 @@
 import logging
 import traceback
 
+from pylons import url
 from pylons.i18n.translation import _
 
 from rhodecode.lib import safe_unicode
+from rhodecode.lib.caching_query import FromCache
+
 from rhodecode.model import BaseModel
-from rhodecode.model.caching_query import FromCache
-from rhodecode.model.db import User, RepoToPerm, Repository, Permission, \
-    UserToPerm, UsersGroupRepoToPerm, UsersGroupToPerm, UsersGroupMember
+from rhodecode.model.db import User, UserRepoToPerm, Repository, Permission, \
+    UserToPerm, UsersGroupRepoToPerm, UsersGroupToPerm, UsersGroupMember, \
+    Notification, RepoGroup, UserRepoGroupToPerm, UsersGroup
 from rhodecode.lib.exceptions import DefaultUserException, \
     UserOwnsReposException
 
@@ -42,13 +45,28 @@
 
 log = logging.getLogger(__name__)
 
-PERM_WEIGHTS = {'repository.none': 0,
-                'repository.read': 1,
-                'repository.write': 3,
-                'repository.admin': 3}
+
+PERM_WEIGHTS = {
+    'repository.none': 0,
+    'repository.read': 1,
+    'repository.write': 3,
+    'repository.admin': 4,
+    'group.none': 0,
+    'group.read': 1,
+    'group.write': 3,
+    'group.admin': 4,
+}
 
 
 class UserModel(BaseModel):
+
+    def __get_user(self, user):
+        return self._get_instance(User, user, callback=User.get_by_username)
+
+    def __get_perm(self, permission):
+        return self._get_instance(Permission, permission,
+                                  callback=Permission.get_by_key)
+
     def get(self, user_id, cache=False):
         user = self.sa.query(User)
         if cache:
@@ -56,6 +74,9 @@
                                           "get_user_%s" % user_id))
         return user.get(user_id)
 
+    def get_user(self, user):
+        return self.__get_user(user)
+
     def get_by_username(self, username, cache=False, case_insensitive=False):
 
         if case_insensitive:
@@ -69,13 +90,7 @@
         return user.scalar()
 
     def get_by_api_key(self, api_key, cache=False):
-
-        user = self.sa.query(User)\
-                .filter(User.api_key == api_key)
-        if cache:
-            user = user.options(FromCache("sql_cache_short",
-                                          "get_user_%s" % api_key))
-        return user.scalar()
+        return User.get_by_api_key(api_key, cache)
 
     def create(self, form_data):
         try:
@@ -85,18 +100,91 @@
 
             new_user.api_key = generate_api_key(form_data['username'])
             self.sa.add(new_user)
-            self.sa.commit()
             return new_user
         except:
             log.error(traceback.format_exc())
-            self.sa.rollback()
             raise
 
+    def create_or_update(self, username, password, email, name, lastname,
+                         active=True, admin=False, ldap_dn=None):
+        """
+        Creates a new instance if not found, or updates current one
+
+        :param username:
+        :param password:
+        :param email:
+        :param active:
+        :param name:
+        :param lastname:
+        :param active:
+        :param admin:
+        :param ldap_dn:
+        """
+
+        from rhodecode.lib.auth import get_crypt_password
+
+        log.debug('Checking for %s account in RhodeCode database' % username)
+        user = User.get_by_username(username, case_insensitive=True)
+        if user is None:
+            log.debug('creating new user %s' % username)
+            new_user = User()
+        else:
+            log.debug('updating user %s' % username)
+            new_user = user
+
+        try:
+            new_user.username = username
+            new_user.admin = admin
+            new_user.password = get_crypt_password(password)
+            new_user.api_key = generate_api_key(username)
+            new_user.email = email
+            new_user.active = active
+            new_user.ldap_dn = safe_unicode(ldap_dn) if ldap_dn else None
+            new_user.name = name
+            new_user.lastname = lastname
+            self.sa.add(new_user)
+            return new_user
+        except (DatabaseError,):
+            log.error(traceback.format_exc())
+            raise
+
+    def create_for_container_auth(self, username, attrs):
+        """
+        Creates the given user if it's not already in the database
+
+        :param username:
+        :param attrs:
+        """
+        if self.get_by_username(username, case_insensitive=True) is None:
+
+            # autogenerate email for container account without one
+            generate_email = lambda usr: '%s@container_auth.account' % usr
+
+            try:
+                new_user = User()
+                new_user.username = username
+                new_user.password = None
+                new_user.api_key = generate_api_key(username)
+                new_user.email = attrs['email']
+                new_user.active = attrs.get('active', True)
+                new_user.name = attrs['name'] or generate_email(username)
+                new_user.lastname = attrs['lastname']
+
+                self.sa.add(new_user)
+                return new_user
+            except (DatabaseError,):
+                log.error(traceback.format_exc())
+                self.sa.rollback()
+                raise
+        log.debug('User %s already exists. Skipping creation of account'
+                  ' for container auth.', username)
+        return None
+
     def create_ldap(self, username, password, user_dn, attrs):
         """
         Checks if user is in database, if not creates this user marked
         as ldap user
-        
+
         :param username:
         :param password:
         :param user_dn:
@@ -105,31 +193,36 @@
         from rhodecode.lib.auth import get_crypt_password
         log.debug('Checking for such ldap account in RhodeCode database')
         if self.get_by_username(username, case_insensitive=True) is None:
+
+            # autogenerate email for ldap account without one
+            generate_email = lambda usr: '%s@ldap.account' % usr
+
             try:
                 new_user = User()
+                username = username.lower()
                 # add ldap account always lowercase
-                new_user.username = username.lower()
+                new_user.username = username
                 new_user.password = get_crypt_password(password)
                 new_user.api_key = generate_api_key(username)
-                new_user.email = attrs['email']
-                new_user.active = True
+                new_user.email = attrs['email'] or generate_email(username)
+                new_user.active = attrs.get('active', True)
                 new_user.ldap_dn = safe_unicode(user_dn)
                 new_user.name = attrs['name']
                 new_user.lastname = attrs['lastname']
 
                 self.sa.add(new_user)
-                self.sa.commit()
-                return True
+                return new_user
             except (DatabaseError,):
                 log.error(traceback.format_exc())
                 self.sa.rollback()
                 raise
         log.debug('this %s user exists skipping creation of ldap account',
                   username)
-        return False
+        return None
 
     def create_registration(self, form_data):
-        from rhodecode.lib.celerylib import tasks, run_task
+        from rhodecode.model.notification import NotificationModel
+
         try:
             new_user = User()
             for k, v in form_data.items():
@@ -137,18 +230,26 @@
                     setattr(new_user, k, v)
 
             self.sa.add(new_user)
-            self.sa.commit()
+            self.sa.flush()
+
+            # notification to admins
+            subject = _('new user registration')
             body = ('New user registration\n'
-                    'username: %s\n'
-                    'email: %s\n')
-            body = body % (form_data['username'], form_data['email'])
+                    '---------------------\n'
+                    '- Username: %s\n'
+                    '- Full Name: %s\n'
+                    '- Email: %s\n')
+            body = body % (new_user.username, new_user.full_name,
+                           new_user.email)
+            edit_url = url('edit_user', id=new_user.user_id, qualified=True)
+            kw = {'registered_user_url': edit_url}
+            NotificationModel().create(created_by=new_user, subject=subject,
+                                       body=body, recipients=None,
+                                       type_=Notification.TYPE_REGISTRATION,
+                                       email_kwargs=kw)
 
-            run_task(tasks.send_email, None,
-                     _('[RhodeCode] New User registration'),
-                     body)
         except:
             log.error(traceback.format_exc())
-            self.sa.rollback()
             raise
 
     def update(self, user_id, form_data):
@@ -167,10 +268,8 @@
                     setattr(user, k, v)
 
             self.sa.add(user)
-            self.sa.commit()
         except:
             log.error(traceback.format_exc())
-            self.sa.rollback()
             raise
 
     def update_my_account(self, user_id, form_data):
@@ -189,15 +288,14 @@
                         setattr(user, k, v)
 
             self.sa.add(user)
-            self.sa.commit()
         except:
             log.error(traceback.format_exc())
-            self.sa.rollback()
             raise
 
-    def delete(self, user_id):
+    def delete(self, user):
+        user = self.__get_user(user)
+
         try:
-            user = self.get(user_id, cache=False)
             if user.username == 'default':
                 raise DefaultUserException(
                                 _("You can't remove this user since it's"
@@ -209,10 +307,8 @@
                                                'remove those repositories') \
                                                % user.repositories)
             self.sa.delete(user)
-            self.sa.commit()
         except:
             log.error(traceback.format_exc())
-            self.sa.rollback()
             raise
 
     def reset_password_link(self, data):
@@ -243,16 +339,19 @@
             else:
                 dbuser = self.get(user_id)
 
-            if dbuser is not None:
-                log.debug('filling %s data', dbuser)
+            if dbuser is not None and dbuser.active:
+                log.debug('filling %s data' % dbuser)
                 for k, v in dbuser.get_dict().items():
                     setattr(auth_user, k, v)
+            else:
+                return False
 
         except:
             log.error(traceback.format_exc())
             auth_user.is_authenticated = False
+            return False
 
-        return auth_user
+        return True
 
     def fill_perms(self, user):
         """
@@ -262,98 +361,109 @@
 
         :param user: user instance to fill his perms
         """
-
-        user.permissions['repositories'] = {}
-        user.permissions['global'] = set()
+        RK = 'repositories'
+        GK = 'repositories_groups'
+        GLOBAL = 'global'
+        user.permissions[RK] = {}
+        user.permissions[GK] = {}
+        user.permissions[GLOBAL] = set()
 
         #======================================================================
         # fetch default permissions
         #======================================================================
-        default_user = self.get_by_username('default', cache=True)
+        default_user = User.get_by_username('default', cache=True)
+        default_user_id = default_user.user_id
 
-        default_perms = self.sa.query(RepoToPerm, Repository, Permission)\
-            .join((Repository, RepoToPerm.repository_id ==
-                   Repository.repo_id))\
-            .join((Permission, RepoToPerm.permission_id ==
-                   Permission.permission_id))\
-            .filter(RepoToPerm.user == default_user).all()
+        default_repo_perms = Permission.get_default_perms(default_user_id)
+        default_repo_groups_perms = Permission.get_default_group_perms(default_user_id)
 
         if user.is_admin:
             #==================================================================
-            # #admin have all default rights set to admin
+            # admin user have all default rights for repositories
+            # and groups set to admin
             #==================================================================
-            user.permissions['global'].add('hg.admin')
+            user.permissions[GLOBAL].add('hg.admin')
 
-            for perm in default_perms:
+            # repositories
+            for perm in default_repo_perms:
+                r_k = perm.UserRepoToPerm.repository.repo_name
                 p = 'repository.admin'
-                user.permissions['repositories'][perm.RepoToPerm.
-                                                 repository.repo_name] = p
+                user.permissions[RK][r_k] = p
+
+            # repositories groups
+            for perm in default_repo_groups_perms:
+                rg_k = perm.UserRepoGroupToPerm.group.group_name
+                p = 'group.admin'
+                user.permissions[GK][rg_k] = p
 
         else:
             #==================================================================
-            # set default permissions
+            # set default permissions first for repositories and groups
             #==================================================================
             uid = user.user_id
 
-            #default global
+            # default global permissions
             default_global_perms = self.sa.query(UserToPerm)\
-                .filter(UserToPerm.user == default_user)
+                .filter(UserToPerm.user_id == default_user_id)
 
             for perm in default_global_perms:
-                user.permissions['global'].add(perm.permission.permission_name)
+                user.permissions[GLOBAL].add(perm.permission.permission_name)
 
-            #default for repositories
-            for perm in default_perms:
-                if perm.Repository.private and not (perm.Repository.user_id ==
-                                                    uid):
-                    #diself.sable defaults for private repos,
+            # default for repositories
+            for perm in default_repo_perms:
+                r_k = perm.UserRepoToPerm.repository.repo_name
+                if perm.Repository.private and not (perm.Repository.user_id == uid):
+                    # disable defaults for private repos,
                     p = 'repository.none'
                 elif perm.Repository.user_id == uid:
-                    #set admin if owner
+                    # set admin if owner
                     p = 'repository.admin'
                 else:
                     p = perm.Permission.permission_name
 
-                user.permissions['repositories'][perm.RepoToPerm.
-                                                 repository.repo_name] = p
+                user.permissions[RK][r_k] = p
+
+            # default for repositories groups
+            for perm in default_repo_groups_perms:
+                rg_k = perm.UserRepoGroupToPerm.group.group_name
+                p = perm.Permission.permission_name
+                user.permissions[GK][rg_k] = p
 
             #==================================================================
             # overwrite default with user permissions if any
             #==================================================================
 
-            #user global
+            # user global
             user_perms = self.sa.query(UserToPerm)\
                     .options(joinedload(UserToPerm.permission))\
                     .filter(UserToPerm.user_id == uid).all()
 
             for perm in user_perms:
-                user.permissions['global'].add(perm.permission.
-                                               permission_name)
+                user.permissions[GLOBAL].add(perm.permission.permission_name)
 
-            #user repositories
-            user_repo_perms = self.sa.query(RepoToPerm, Permission,
-                                            Repository)\
-                .join((Repository, RepoToPerm.repository_id ==
-                       Repository.repo_id))\
-                .join((Permission, RepoToPerm.permission_id ==
-                       Permission.permission_id))\
-                .filter(RepoToPerm.user_id == uid).all()
+            # user repositories
+            user_repo_perms = \
+             self.sa.query(UserRepoToPerm, Permission, Repository)\
+             .join((Repository, UserRepoToPerm.repository_id == Repository.repo_id))\
+             .join((Permission, UserRepoToPerm.permission_id == Permission.permission_id))\
+             .filter(UserRepoToPerm.user_id == uid)\
+             .all()
 
             for perm in user_repo_perms:
                 # set admin if owner
+                r_k = perm.UserRepoToPerm.repository.repo_name
                 if perm.Repository.user_id == uid:
                     p = 'repository.admin'
                 else:
                     p = perm.Permission.permission_name
-                user.permissions['repositories'][perm.RepoToPerm.
-                                                 repository.repo_name] = p
+                user.permissions[RK][r_k] = p
 
             #==================================================================
             # check if user is part of groups for this repository and fill in
             # (or replace with higher) permissions
             #==================================================================
 
-            #users group global
+            # users group global
             user_perms_from_users_groups = self.sa.query(UsersGroupToPerm)\
                 .options(joinedload(UsersGroupToPerm.permission))\
                 .join((UsersGroupMember, UsersGroupToPerm.users_group_id ==
@@ -361,30 +471,82 @@
                 .filter(UsersGroupMember.user_id == uid).all()
 
             for perm in user_perms_from_users_groups:
-                user.permissions['global'].add(perm.permission.permission_name)
+                user.permissions[GLOBAL].add(perm.permission.permission_name)
 
-            #users group repositories
-            user_repo_perms_from_users_groups = self.sa.query(
-                                                UsersGroupRepoToPerm,
-                                                Permission, Repository,)\
-                .join((Repository, UsersGroupRepoToPerm.repository_id ==
-                       Repository.repo_id))\
-                .join((Permission, UsersGroupRepoToPerm.permission_id ==
-                       Permission.permission_id))\
-                .join((UsersGroupMember, UsersGroupRepoToPerm.users_group_id ==
-                       UsersGroupMember.users_group_id))\
-                .filter(UsersGroupMember.user_id == uid).all()
+            # users group repositories
+            user_repo_perms_from_users_groups = \
+             self.sa.query(UsersGroupRepoToPerm, Permission, Repository,)\
+             .join((Repository, UsersGroupRepoToPerm.repository_id == Repository.repo_id))\
+             .join((Permission, UsersGroupRepoToPerm.permission_id == Permission.permission_id))\
+             .join((UsersGroupMember, UsersGroupRepoToPerm.users_group_id == UsersGroupMember.users_group_id))\
+             .filter(UsersGroupMember.user_id == uid)\
+             .all()
 
             for perm in user_repo_perms_from_users_groups:
+                r_k = perm.UsersGroupRepoToPerm.repository.repo_name
                 p = perm.Permission.permission_name
-                cur_perm = user.permissions['repositories'][perm.
-                                                    UsersGroupRepoToPerm.
-                                                    repository.repo_name]
-                #overwrite permission only if it's greater than permission
+                cur_perm = user.permissions[RK][r_k]
+                # overwrite permission only if it's greater than permission
                 # given from other sources
                 if PERM_WEIGHTS[p] > PERM_WEIGHTS[cur_perm]:
-                    user.permissions['repositories'][perm.UsersGroupRepoToPerm.
-                                                     repository.repo_name] = p
+                    user.permissions[RK][r_k] = p
+
+            #==================================================================
+            # get access for this user for repos group and override defaults
+            #==================================================================
+
+            # user repositories groups
+            user_repo_groups_perms = \
+             self.sa.query(UserRepoGroupToPerm, Permission, RepoGroup)\
+             .join((RepoGroup, UserRepoGroupToPerm.group_id == RepoGroup.group_id))\
+             .join((Permission, UserRepoGroupToPerm.permission_id == Permission.permission_id))\
+             .filter(UserRepoToPerm.user_id == uid)\
+             .all()
+
+            for perm in user_repo_groups_perms:
+                rg_k = perm.UserRepoGroupToPerm.group.group_name
+                p = perm.Permission.permission_name
+                cur_perm = user.permissions[GK][rg_k]
+                if PERM_WEIGHTS[p] > PERM_WEIGHTS[cur_perm]:
+                    user.permissions[GK][rg_k] = p
 
         return user
 
+    def has_perm(self, user, perm):
+        if not isinstance(perm, Permission):
+            raise Exception('perm needs to be an instance of Permission class '
+                            'got %s instead' % type(perm))
+
+        user = self.__get_user(user)
+
+        return UserToPerm.query().filter(UserToPerm.user == user)\
+            .filter(UserToPerm.permission == perm).scalar() is not None
+
+    def grant_perm(self, user, perm):
+        """
+        Grant user global permissions
+
+        :param user:
+        :param perm:
+        """
+        user = self.__get_user(user)
+        perm = self.__get_perm(perm)
+        new = UserToPerm()
+        new.user = user
+        new.permission = perm
+        self.sa.add(new)
+
+    def revoke_perm(self, user, perm):
+        """
+        Revoke users global permissions
+
+        :param user:
+        :param perm:
+        """
+        user = self.__get_user(user)
+        perm = self.__get_perm(perm)
+
+        obj = UserToPerm.query().filter(UserToPerm.user == user)\
+                .filter(UserToPerm.permission == perm).scalar()
+        if obj:
+            self.sa.delete(obj)