changeset 128:441a8ee637c5

Added claims checker + example.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Thu, 28 Jun 2018 16:13:58 +0200
parents 44794c641277
children ee5a3dd8e972
files auth/middleware.go cmd/tokenserver/main.go
diffstat 2 files changed, 34 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/auth/middleware.go	Thu Jun 28 13:39:14 2018 +0200
+++ b/auth/middleware.go	Thu Jun 28 16:13:58 2018 +0200
@@ -52,3 +52,27 @@
 		next.ServeHTTP(rw, req)
 	})
 }
+
+func ClaimsChecker(next http.Handler, check func(*Claims) bool) http.Handler {
+	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+		claims, ok := GetClaims(req)
+		if !ok || !check(claims) {
+			http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
+			return
+		}
+		next.ServeHTTP(rw, req)
+	})
+}
+
+func HasRole(roles ...string) func(*Claims) bool {
+	return func(claims *Claims) bool {
+		for _, r1 := range roles {
+			for _, r2 := range claims.Roles {
+				if r1 == r2 {
+					return true
+				}
+			}
+		}
+		return false
+	}
+}
--- a/cmd/tokenserver/main.go	Thu Jun 28 13:39:14 2018 +0200
+++ b/cmd/tokenserver/main.go	Thu Jun 28 16:13:58 2018 +0200
@@ -10,6 +10,12 @@
 	"gemma.intevation.de/gemma/auth"
 )
 
+func sysAdmin(rw http.ResponseWriter, req *http.Request) {
+	claims, _ := auth.GetClaims(req)
+	rw.Header().Set("Content-Type", "text/plain")
+	fmt.Fprintf(rw, "%s is a sys_admin\n", claims.User)
+}
+
 func renew(rw http.ResponseWriter, req *http.Request) {
 	token, _ := auth.GetToken(req)
 	newToken, err := auth.ConnPool.Replace(token, auth.GenerateToken)
@@ -60,6 +66,10 @@
 	mux.Handle("/", http.StripPrefix("/", http.FileServer(http.Dir(p))))
 	mux.HandleFunc("/api/token", token)
 	mux.Handle("/api/logout", auth.JWTMiddleware(http.HandlerFunc(token)))
+	mux.Handle("/api/renew", auth.JWTMiddleware(http.HandlerFunc(renew)))
+	mux.Handle("/api/sys_admin",
+		auth.JWTMiddleware(
+			auth.ClaimsChecker(http.HandlerFunc(sysAdmin), auth.HasRole("sys_admin"))))
 
 	addr := fmt.Sprintf("%s:%d", *host, *port)
 	log.Fatalln(http.ListenAndServe(addr, mux))