view auth/middleware.go @ 143:abfac07bd82a vue-gettext

closing branch vue-gettext
author Thomas Junk <thomas.junk@intevation.de>
date Mon, 02 Jul 2018 09:37:53 +0200
parents 441a8ee637c5
children 0c56c56a1c44
line wrap: on
line source

package auth

import (
	"context"
	"fmt"
	"net/http"
	"regexp"
)

// extractToken matches an HTTP Authorization header value of the form
// "Bearer <token>" and captures the token in submatch 1.
//
// The pattern is anchored, so the whole header value must consist of the
// scheme followed by a single whitespace-free token (surrounding whitespace is
// tolerated). This stops a "Bearer" substring embedded elsewhere in the value
// (e.g. "XBearer t" or "Basic x Bearer t") from being accepted. The scheme
// name is matched case-insensitively, as RFC 7235 §2.1 specifies for
// auth-schemes.
var extractToken = regexp.MustCompile(`(?i)^\s*Bearer\s+(\S+)\s*$`)

// contextType is the unexported key type for the values this package stores
// in a request context. Because no other package can name this type, its keys
// cannot collide with context keys defined elsewhere.
type contextType int

// Context keys under which JWTMiddleware stores the decoded claims and the
// raw bearer token string.
const (
	claimsKey contextType = 0
	tokenKey  contextType = 1
)

// GetClaims returns the *Claims stored in req's context under claimsKey, which
// JWTMiddleware populates. The boolean is false, and the pointer nil, when no
// *Claims value is present, e.g. because the request did not pass through
// JWTMiddleware.
func GetClaims(req *http.Request) (*Claims, bool) {
	stored := req.Context().Value(claimsKey)
	if c, found := stored.(*Claims); found {
		return c, true
	}
	return nil, false
}

// GetToken returns the raw bearer token string stored in req's context under
// tokenKey, which JWTMiddleware populates. The boolean is false, and the
// string empty, when no token string is present.
func GetToken(req *http.Request) (string, bool) {
	if tok, found := req.Context().Value(tokenKey).(string); found {
		return tok, true
	}
	return "", false
}

// JWTMiddleware authenticates requests that carry an "Authorization: Bearer
// <token>" header. The token is decoded by TokenToClaims; on success the
// resulting claims and the raw token string are stored in the request context
// (read them back with GetClaims and GetToken) and the request is forwarded to
// next.
//
// If the header does not contain a bearer token, or TokenToClaims returns an
// error, the request is answered with 401 Unauthorized and next is not called.
// Every 401 carries a "WWW-Authenticate: Bearer" challenge, which RFC 7235
// §3.1 requires for 401 responses and RFC 6750 §3 specifies for bearer-token
// protected resources.
func JWTMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		// A successful match yields the full match plus the one capture group.
		token := extractToken.FindStringSubmatch(req.Header.Get("Authorization"))
		if len(token) != 2 {
			rw.Header().Set("WWW-Authenticate", "Bearer")
			http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}

		claims, err := TokenToClaims(token[1])
		if err != nil {
			rw.Header().Set("WWW-Authenticate", "Bearer")
			// NOTE(review): the TokenToClaims error text is echoed to the
			// client; confirm it never carries sensitive detail.
			http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusUnauthorized)
			return
		}

		ctx := context.WithValue(req.Context(), claimsKey, claims)
		ctx = context.WithValue(ctx, tokenKey, token[1])

		next.ServeHTTP(rw, req.WithContext(ctx))
	})
}

// ClaimsChecker wraps next so that it only runs for requests whose claims
// satisfy check. The claims are read with GetClaims, so the handler is
// expected to sit behind JWTMiddleware, which stores them in the context.
//
// When no (or nil) claims are present the client is not authenticated: the
// request is rejected with 401 Unauthorized plus a "WWW-Authenticate: Bearer"
// challenge (required for 401 by RFC 7235 §3.1). When claims are present but
// check returns false, the client is authenticated but not permitted, so the
// request is rejected with 403 Forbidden. A nil *Claims is never passed to
// check.
func ClaimsChecker(next http.Handler, check func(*Claims) bool) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		claims, ok := GetClaims(req)
		if !ok || claims == nil {
			rw.Header().Set("WWW-Authenticate", "Bearer")
			http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}
		if !check(claims) {
			http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden)
			return
		}
		next.ServeHTTP(rw, req)
	})
}

// HasRole returns a predicate, suitable as the check argument of
// ClaimsChecker, that reports whether claims.Roles contains at least one of
// the given roles (exact string comparison). With no roles the predicate
// always returns false. A nil *Claims is treated as holding no roles instead
// of causing a nil-pointer dereference.
func HasRole(roles ...string) func(*Claims) bool {
	return func(claims *Claims) bool {
		if claims == nil {
			return false
		}
		for _, want := range roles {
			for _, have := range claims.Roles {
				if want == have {
					return true
				}
			}
		}
		return false
	}
}