changeset 6453:ebe7d95f698b

auth: simplify the double invoked auth classes used for permission checking Also avoid storing request state in the Has Permission instances. The instances er temporary and only used once and there is thus not real problem, but it is simpler and cleaner this way.
author Mads Kiilerich <mads@kiilerich.com>
date Sun, 22 Jan 2017 01:16:52 +0100
parents 3dcf1f82311a
children 4aaeac1e7ba3
files kallithea/lib/auth.py
diffstat 1 files changed, 77 insertions(+), 145 deletions(-) [+]
line wrap: on
line diff
--- a/kallithea/lib/auth.py	Sat Dec 24 01:27:47 2016 +0100
+++ b/kallithea/lib/auth.py	Sun Jan 22 01:16:52 2017 +0100
@@ -715,6 +715,7 @@
     return HTTPFound(location=url('login_home', came_from=p))
 
 
+# Use as decorator
 class LoginRequired(object):
     """Client must be logged in as a valid User (but the "default" user,
     if enabled, is considered valid), or we'll redirect to the login page.
@@ -772,6 +773,8 @@
             log.warning('user %s NOT authenticated with regular auth @ %s', user, loc)
             raise _redirect_to_login()
 
+
+# Use as decorator
 class NotAnonymous(object):
     """Ensures that client is not logged in as the "default" user, and
     redirects to the login page otherwise. Must be used together with
@@ -782,126 +785,106 @@
 
     def __wrapper(self, func, *fargs, **fkwargs):
         cls = fargs[0]
-        self.user = request.authuser
+        user = request.authuser
 
-        log.debug('Checking if user is not anonymous @%s', cls)
+        log.debug('Checking that user %s is not anonymous @%s', user.username, cls)
 
-        if self.user.is_default_user:
+        if user.is_default_user:
             raise _redirect_to_login(_('You need to be a registered user to '
                                        'perform this action'))
         else:
             return func(*fargs, **fkwargs)
 
 
-class PermsDecorator(object):
+class _PermsDecorator(object):
     """Base class for controller decorators"""
 
     def __init__(self, *required_perms):
-        self.required_perms = set(required_perms)
-        self.user_perms = None
+        self.required_perms = required_perms # usually very short - a list is thus fine
 
     def __call__(self, func):
         return decorator(self.__wrapper, func)
 
     def __wrapper(self, func, *fargs, **fkwargs):
         cls = fargs[0]
-        self.user = request.authuser
-        self.user_perms = self.user.permissions
+        user = request.authuser
         log.debug('checking %s permissions %s for %s %s',
-          self.__class__.__name__, self.required_perms, cls, self.user)
+          self.__class__.__name__, self.required_perms, cls, user)
 
-        if self.check_permissions():
-            log.debug('Permission granted for %s %s', cls, self.user)
+        if self.check_permissions(user):
+            log.debug('Permission granted for %s %s', cls, user)
             return func(*fargs, **fkwargs)
 
         else:
-            log.debug('Permission denied for %s %s', cls, self.user)
-            if self.user.is_default_user:
+            log.debug('Permission denied for %s %s', cls, user)
+            if user.is_default_user:
                 raise _redirect_to_login(_('You need to be signed in to view this page'))
             else:
                 raise HTTPForbidden()
 
-    def check_permissions(self):
+    def check_permissions(self, user):
         raise NotImplementedError()
 
 
-class HasPermissionAnyDecorator(PermsDecorator):
+class HasPermissionAnyDecorator(_PermsDecorator):
     """
-    Checks for access permission for any of given predicates. In order to
-    fulfill the request any of predicates must be meet
+    Checks the user has any of the given global permissions.
     """
 
-    def check_permissions(self):
-        if self.required_perms.intersection(self.user_perms.get('global')):
-            return True
-        return False
+    def check_permissions(self, user):
+        global_permissions = user.permissions['global'] # usually very short
+        return any(p in global_permissions for p in self.required_perms)
 
 
-class HasRepoPermissionAnyDecorator(PermsDecorator):
+class HasRepoPermissionAnyDecorator(_PermsDecorator):
     """
-    Checks for access permission for any of given predicates for specific
-    repository. In order to fulfill the request any of predicates must be meet
+    Checks the user has any of given permissions for the requested repository.
     """
 
-    def check_permissions(self):
+    def check_permissions(self, user):
         repo_name = get_repo_slug(request)
         try:
-            user_perms = set([self.user_perms['repositories'][repo_name]])
+            return user.permissions['repositories'][repo_name] in self.required_perms
         except KeyError:
             return False
 
-        if self.required_perms.intersection(user_perms):
-            return True
-        return False
 
-
-class HasRepoGroupPermissionAnyDecorator(PermsDecorator):
+class HasRepoGroupPermissionAnyDecorator(_PermsDecorator):
     """
-    Checks for access permission for any of given predicates for specific
-    repository group. In order to fulfill the request any of predicates must be meet
+    Checks the user has any of given permissions for the requested repository group.
     """
 
-    def check_permissions(self):
-        group_name = get_repo_group_slug(request)
+    def check_permissions(self, user):
+        repo_group_name = get_repo_group_slug(request)
         try:
-            user_perms = set([self.user_perms['repositories_groups'][group_name]])
+            return user.permissions['repositories_groups'][repo_group_name] in self.required_perms
         except KeyError:
             return False
 
-        if self.required_perms.intersection(user_perms):
-            return True
-        return False
 
-
-class HasUserGroupPermissionAnyDecorator(PermsDecorator):
+class HasUserGroupPermissionAnyDecorator(_PermsDecorator):
     """
     Checks for access permission for any of given predicates for specific
     user group. In order to fulfill the request any of predicates must be meet
     """
 
-    def check_permissions(self):
-        group_name = get_user_group_slug(request)
+    def check_permissions(self, user):
+        user_group_name = get_user_group_slug(request)
         try:
-            user_perms = set([self.user_perms['user_groups'][group_name]])
+            return user.permissions['user_groups'][user_group_name] in self.required_perms
         except KeyError:
             return False
 
-        if self.required_perms.intersection(user_perms):
-            return True
-        return False
-
 
 #==============================================================================
 # CHECK FUNCTIONS
 #==============================================================================
-class PermsFunction(object):
+
+class _PermsFunction(object):
     """Base function for other check functions"""
 
-    def __init__(self, *perms):
-        self.required_perms = set(perms)
-        self.user_perms = None
-        self.repo_name = None
-        self.group_name = None
+    def __init__(self, *required_perms):
+        self.required_perms = required_perms # usually very short - a list is thus fine
 
     def __nonzero__(self):
         """ Defend against accidentally forgetting to call the object
@@ -910,132 +893,81 @@
         """
         raise AssertionError(self.__class__.__name__ + ' is not a bool and must be called!')
 
-    def __call__(self, check_location='unspecified location'):
-        user = request.user
-        assert user
-        assert isinstance(user, AuthUser), user
-
-        cls_name = self.__class__.__name__
-        check_scope = self._scope()
-        log.debug('checking cls:%s %s usr:%s %s @ %s', cls_name,
-                  self.required_perms, user, check_scope,
-                  check_location)
-        self.user_perms = user.permissions
-
-        result = self.check_permissions()
-        result_text = 'granted' if result else 'denied'
-        log.debug('Permission to %s %s for user: %s @ %s',
-            check_scope, result_text, user, check_location)
-        return result
-
-    def check_permissions(self):
+    def __call__(self, *a, **b):
         raise NotImplementedError()
 
-    def _scope(self):
-        return '(unknown scope)'
 
+class HasPermissionAny(_PermsFunction):
 
-class HasPermissionAny(PermsFunction):
-    def check_permissions(self):
-        if self.required_perms.intersection(self.user_perms.get('global')):
-            return True
-        return False
+    def __call__(self, purpose=None):
+        global_permissions = request.user.permissions['global'] # usually very short
+        ok = any(p in global_permissions for p in self.required_perms)
 
-    def _scope(self):
-        return 'global'
+        log.error('Check %s for global %s (%s): %s' %
+            (request.user.username, self.required_perms, purpose, ok))
+        return ok
 
 
-class HasRepoPermissionAny(PermsFunction):
-    def __call__(self, repo_name=None, check_location=''):
-        self.repo_name = repo_name
-        return super(HasRepoPermissionAny, self).__call__(check_location)
+class HasRepoPermissionAny(_PermsFunction):
 
-    def check_permissions(self):
-        if not self.repo_name:
-            self.repo_name = get_repo_slug(request)
-
+    def __call__(self, repo_name, purpose=None):
         try:
-            self._user_perms = set(
-                [self.user_perms['repositories'][self.repo_name]]
-            )
+            ok = request.user.permissions['repositories'][repo_name] in self.required_perms
         except KeyError:
-            return False
-        if self.required_perms.intersection(self._user_perms):
-            return True
-        return False
+            ok = False
 
-    def _scope(self):
-        return 'repo:%s' % self.repo_name
+        log.error('Check %s for %s for repo %s (%s): %s' %
+            (request.user.username, self.required_perms, repo_name, purpose, ok))
+        return ok
 
 
-class HasRepoGroupPermissionAny(PermsFunction):
-    def __call__(self, group_name=None, check_location=''):
-        self.group_name = group_name
-        return super(HasRepoGroupPermissionAny, self).__call__(check_location)
+class HasRepoGroupPermissionAny(_PermsFunction):
 
-    def check_permissions(self):
+    def __call__(self, group_name, purpose=None):
         try:
-            self._user_perms = set(
-                [self.user_perms['repositories_groups'][self.group_name]]
-            )
+            ok = request.user.permissions['repositories_groups'][group_name] in self.required_perms
         except KeyError:
-            return False
-        if self.required_perms.intersection(self._user_perms):
-            return True
-        return False
+            ok = False
 
-    def _scope(self):
-        return 'repogroup:%s' % self.group_name
+        log.error('Check %s for %s for repo group %s (%s): %s' %
+            (request.user.username, self.required_perms, group_name, purpose, ok))
+        return ok
 
 
-class HasUserGroupPermissionAny(PermsFunction):
-    def __call__(self, user_group_name=None, check_location=''):
-        self.user_group_name = user_group_name
-        return super(HasUserGroupPermissionAny, self).__call__(check_location)
+class HasUserGroupPermissionAny(_PermsFunction):
 
-    def check_permissions(self):
+    def __call__(self, user_group_name, purpose=None):
         try:
-            self._user_perms = set(
-                [self.user_perms['user_groups'][self.user_group_name]]
-            )
+            ok = request.user.permissions['user_groups'][user_group_name] in self.required_perms
         except KeyError:
-            return False
-        if self.required_perms.intersection(self._user_perms):
-            return True
-        return False
+            ok = False
 
-    def _scope(self):
-        return 'usergroup:%s' % self.user_group_name
+        log.error('Check %s %s for user group %s (%s): %s' %
+            (request.user.username, self.required_perms, user_group_name, purpose, ok))
+        return ok
 
 
 #==============================================================================
 # SPECIAL VERSION TO HANDLE MIDDLEWARE AUTH
 #==============================================================================
+
 class HasPermissionAnyMiddleware(object):
     def __init__(self, *perms):
         self.required_perms = set(perms)
 
-    def __call__(self, user, repo_name):
-        # repo_name MUST be unicode, since we handle keys in permission
+    def __call__(self, user, repo_name, purpose=None):
+        # repo_name MUST be unicode, since we handle keys in ok
         # dict by unicode
         repo_name = safe_unicode(repo_name)
-        usr = AuthUser(user.user_id)
-        self.user_perms = set([usr.permissions['repositories'][repo_name]])
-        self.username = user.username
-        self.repo_name = repo_name
-        return self.check_permissions()
+        user = AuthUser(user.user_id)
 
-    def check_permissions(self):
-        log.debug('checking VCS protocol '
-                  'permissions %s for user:%s repository:%s', self.user_perms,
-                                                self.username, self.repo_name)
-        if self.required_perms.intersection(self.user_perms):
-            log.debug('Permission to repo: %s granted for user: %s @ %s',
-                      self.repo_name, self.username, 'PermissionMiddleware')
-            return True
-        log.debug('Permission to repo: %s denied for user: %s @ %s',
-                  self.repo_name, self.username, 'PermissionMiddleware')
-        return False
+        try:
+            ok = user.permissions['repositories'][repo_name] in self.required_perms
+        except KeyError:
+            ok = False
+
+        log.debug('Middleware check %s for %s for repo %s (%s): %s' % (user.username, self.required_perms, repo_name, purpose, ok))
+        return ok
 
 
 def check_ip_access(source_ip, allowed_ips=None):