Mercurial > gemma
changeset 128:441a8ee637c5
Added claims checker + example.
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Thu, 28 Jun 2018 16:13:58 +0200 |
parents | 44794c641277 |
children | ee5a3dd8e972 |
files | auth/middleware.go cmd/tokenserver/main.go |
diffstat | 2 files changed, 34 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/auth/middleware.go Thu Jun 28 13:39:14 2018 +0200 +++ b/auth/middleware.go Thu Jun 28 16:13:58 2018 +0200 @@ -52,3 +52,27 @@ next.ServeHTTP(rw, req) }) } + +func ClaimsChecker(next http.Handler, check func(*Claims) bool) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + claims, ok := GetClaims(req) + if !ok || !check(claims) { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + next.ServeHTTP(rw, req) + }) +} + +func HasRole(roles ...string) func(*Claims) bool { + return func(claims *Claims) bool { + for _, r1 := range roles { + for _, r2 := range claims.Roles { + if r1 == r2 { + return true + } + } + } + return false + } +}
--- a/cmd/tokenserver/main.go Thu Jun 28 13:39:14 2018 +0200 +++ b/cmd/tokenserver/main.go Thu Jun 28 16:13:58 2018 +0200 @@ -10,6 +10,12 @@ "gemma.intevation.de/gemma/auth" ) +func sysAdmin(rw http.ResponseWriter, req *http.Request) { + claims, _ := auth.GetClaims(req) + rw.Header().Set("Content-Type", "text/plain") + fmt.Fprintf(rw, "%s is a sys_admin\n", claims.User) +} + func renew(rw http.ResponseWriter, req *http.Request) { token, _ := auth.GetToken(req) newToken, err := auth.ConnPool.Replace(token, auth.GenerateToken) @@ -60,6 +66,10 @@ mux.Handle("/", http.StripPrefix("/", http.FileServer(http.Dir(p)))) mux.HandleFunc("/api/token", token) mux.Handle("/api/logout", auth.JWTMiddleware(http.HandlerFunc(token))) + mux.Handle("/api/renew", auth.JWTMiddleware(http.HandlerFunc(renew))) + mux.Handle("/api/sys_admin", + auth.JWTMiddleware( + auth.ClaimsChecker(http.HandlerFunc(sysAdmin), auth.HasRole("sys_admin")))) addr := fmt.Sprintf("%s:%d", *host, *port) log.Fatalln(http.ListenAndServe(addr, mux))