changeset 134:0c56c56a1c44 remove-jwt

Removed the JWT layer from the session management.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 29 Jun 2018 17:17:20 +0200
parents f4523620ba5d
children d93ccf1aba1b
files 3rdpartylibs.sh auth/connection.go auth/middleware.go auth/token.go cmd/tokenserver/main.go
diffstat 5 files changed, 164 insertions(+), 116 deletions(-) [+]
line wrap: on
line diff
--- a/3rdpartylibs.sh	Thu Jun 28 17:26:38 2018 +0200
+++ b/3rdpartylibs.sh	Fri Jun 29 17:17:20 2018 +0200
@@ -1,3 +1,2 @@
 #!/bin/env sh
-go get -u -v github.com/dgrijalva/jwt-go
 go get -u -v github.com/jackc/pgx
--- a/auth/connection.go	Thu Jun 28 17:26:38 2018 +0200
+++ b/auth/connection.go	Fri Jun 29 17:17:20 2018 +0200
@@ -4,6 +4,7 @@
 	"database/sql"
 	"errors"
 	"log"
+	"sync"
 	"time"
 )
 
@@ -17,18 +18,31 @@
 )
 
 type Connection struct {
-	user     string
-	password string
+	session *Session
 
 	access   time.Time
 	db       *sql.DB
 	refCount int
+
+	mu sync.Mutex
+}
+
+func (c *Connection) set(session *Session) {
+	c.session = session
+	c.touch()
 }
 
-func (c *Connection) set(user, password string) {
-	c.user = user
-	c.password = password
+func (c *Connection) touch() {
+	c.mu.Lock()
 	c.access = time.Now()
+	c.mu.Unlock()
+}
+
+func (c *Connection) last() time.Time {
+	c.mu.Lock()
+	access := c.access
+	c.mu.Unlock()
+	return access
 }
 
 func (c *Connection) close() {
@@ -70,7 +84,7 @@
 func (cp *ConnectionPool) cleanDB() {
 	valid := time.Now().Add(-maxDBIdle)
 	for _, con := range cp.conns {
-		if con.refCount <= 0 && con.access.Before(valid) {
+		if con.refCount <= 0 && con.last().Before(valid) {
 			con.close()
 		}
 	}
@@ -79,14 +93,7 @@
 func (cp *ConnectionPool) cleanToken() {
 	now := time.Now()
 	for token, con := range cp.conns {
-		claims, err := TokenToClaims(token)
-		if err != nil { // Should not happen.
-			log.Printf("error: %v\n", err)
-			con.close()
-			delete(cp.conns, token)
-			continue
-		}
-		expires := time.Unix(claims.ExpiresAt, 0)
+		expires := time.Unix(con.session.ExpiresAt, 0)
 		if expires.Before(now) {
 			// TODO: Be more graceful here?
 			con.close()
@@ -110,46 +117,16 @@
 	return <-res
 }
 
-func (cp *ConnectionPool) Replace(
-	token string,
-	replace func(string, string) (string, error)) (string, error) {
-
-	type res struct {
-		token string
-		err   error
-	}
-
-	resCh := make(chan res)
-
-	cp.cmds <- func(cp *ConnectionPool) {
-		conn, found := cp.conns[token]
-		if !found {
-			resCh <- res{err: ErrNoSuchToken}
-			return
-		}
-		newToken, err := replace(conn.user, conn.password)
-		if err == nil {
-			delete(cp.conns, token)
-			cp.conns[newToken] = conn
-		}
-		resCh <- res{token: newToken, err: err}
-	}
-
-	r := <-resCh
-	return r.token, r.err
-}
-
-func (cp *ConnectionPool) Add(token, user, password string) *Connection {
+func (cp *ConnectionPool) Add(token string, session *Session) *Connection {
 	res := make(chan *Connection)
 
 	cp.cmds <- func(cp *ConnectionPool) {
-
 		con := cp.conns[token]
 		if con == nil {
 			con = &Connection{}
 			cp.conns[token] = con
 		}
-		con.set(user, password)
+		con.set(session)
 		res <- con
 	}
 
@@ -157,6 +134,33 @@
 	return con
 }
 
+func (cp ConnectionPool) Renew(token string) (string, error) {
+
+	type result struct {
+		newToken string
+		err      error
+	}
+
+	resCh := make(chan result)
+
+	cp.cmds <- func(cp *ConnectionPool) {
+		con := cp.conns[token]
+		if con == nil {
+			resCh <- result{err: ErrNoSuchToken}
+		} else {
+			delete(cp.conns, token)
+			newToken := GenerateSessionKey()
+			// TODO: Ensure that this is not racy!
+			con.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix()
+			cp.conns[newToken] = con
+			resCh <- result{newToken: newToken}
+		}
+	}
+
+	r := <-resCh
+	return r.newToken, r.err
+}
+
 func (cp *ConnectionPool) trim(conn *Connection) {
 
 	conn.refCount--
@@ -168,8 +172,8 @@
 
 		for _, con := range cp.conns {
 			if con.db != nil && con.refCount <= 0 {
-				if con.access.Before(least) {
-					least = con.access
+				if last := con.last(); last.Before(least) {
+					least = last
 					oldest = con
 				}
 				count++
@@ -197,14 +201,15 @@
 			res <- result{err: ErrNoSuchToken}
 			return
 		}
-		con.access = time.Now()
+		con.touch()
 		if con.db != nil {
 			con.refCount++
 			res <- result{con: con}
 			return
 		}
 
-		db, err := opendb(con.user, con.password)
+		session := con.session
+		db, err := opendb(session.User, session.Password)
 		if err != nil {
 			res <- result{err: err}
 			return
@@ -228,3 +233,17 @@
 
 	return fn(r.con.db)
 }
+
+func (cp *ConnectionPool) Session(token string) *Session {
+	res := make(chan *Session)
+	cp.cmds <- func(cp *ConnectionPool) {
+		con := cp.conns[token]
+		if con == nil {
+			res <- nil
+		} else {
+			con.touch()
+			res <- con.session
+		}
+	}
+	return <-res
+}
--- a/auth/middleware.go	Thu Jun 28 17:26:38 2018 +0200
+++ b/auth/middleware.go	Fri Jun 29 17:17:20 2018 +0200
@@ -2,23 +2,20 @@
 
 import (
 	"context"
-	"fmt"
 	"net/http"
-	"regexp"
+	"strings"
 )
 
-var extractToken = regexp.MustCompile(`\s*Bearer\s+(\S+)`)
-
 type contextType int
 
 const (
-	claimsKey contextType = iota
+	sessionKey contextType = iota
 	tokenKey
 )
 
-func GetClaims(req *http.Request) (*Claims, bool) {
-	claims, ok := req.Context().Value(claimsKey).(*Claims)
-	return claims, ok
+func GetSession(req *http.Request) (*Session, bool) {
+	session, ok := req.Context().Value(sessionKey).(*Session)
+	return session, ok
 }
 
 func GetToken(req *http.Request) (string, bool) {
@@ -26,36 +23,36 @@
 	return token, ok
 }
 
-func JWTMiddleware(next http.Handler) http.Handler {
+func SessionMiddleware(next http.Handler) http.Handler {
 
 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 
-		auth := req.Header.Get("Authorization")
+		auth := req.Header.Get("X-Gemma-Auth")
 
-		token := extractToken.FindStringSubmatch(auth)
-		if len(token) != 2 {
+		token := strings.TrimSpace(auth)
+		if token == "" {
 			http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
 			return
 		}
 
-		claims, err := TokenToClaims(token[1])
-		if err != nil {
-			http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusUnauthorized)
+		session := ConnPool.Session(token)
+		if session == nil {
+			http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
 			return
 		}
 
 		ctx := req.Context()
-		ctx = context.WithValue(ctx, claimsKey, claims)
-		ctx = context.WithValue(ctx, tokenKey, token[1])
+		ctx = context.WithValue(ctx, sessionKey, session)
+		ctx = context.WithValue(ctx, tokenKey, token)
 		req = req.WithContext(ctx)
 
 		next.ServeHTTP(rw, req)
 	})
 }
 
-func ClaimsChecker(next http.Handler, check func(*Claims) bool) http.Handler {
+func SessionChecker(next http.Handler, check func(*Session) bool) http.Handler {
 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
-		claims, ok := GetClaims(req)
+		claims, ok := GetSession(req)
 		if !ok || !check(claims) {
 			http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
 			return
@@ -64,10 +61,10 @@
 	})
 }
 
-func HasRole(roles ...string) func(*Claims) bool {
-	return func(claims *Claims) bool {
+func HasRole(roles ...string) func(*Session) bool {
+	return func(session *Session) bool {
 		for _, r1 := range roles {
-			for _, r2 := range claims.Roles {
+			for _, r2 := range session.Roles {
 				if r1 == r2 {
 					return true
 				}
--- a/auth/token.go	Thu Jun 28 17:26:38 2018 +0200
+++ b/auth/token.go	Fri Jun 29 17:17:20 2018 +0200
@@ -1,53 +1,54 @@
 package auth
 
 import (
+	"crypto/rand"
+	"encoding/base64"
+	"io"
 	"time"
-
-	"gemma.intevation.de/gemma/config"
-
-	jwt "github.com/dgrijalva/jwt-go"
 )
 
-type Claims struct {
-	jwt.StandardClaims
-
-	User  string   `json:"user"`
-	Roles []string `json:"roles"`
+type Session struct {
+	ExpiresAt int64    `json:"expires"`
+	User      string   `json:"user"`
+	Password  string   `json:"password"`
+	Roles     []string `json:"roles"`
 }
 
-const maxTokenValid = time.Hour * 3
+const (
+	sessionKeyLength = 20
+	maxTokenValid    = time.Hour * 3
+)
 
-func NewToken(user string, roles []string) (string, error) {
+func NewSession(user, password string, roles []string) *Session {
 
 	// Create the Claims
-	claims := &Claims{
-		StandardClaims: jwt.StandardClaims{
-			ExpiresAt: jwt.TimeFunc().Add(maxTokenValid).Unix(),
-		},
-		User:  user,
-		Roles: roles,
+	return &Session{
+		ExpiresAt: time.Now().Add(maxTokenValid).Unix(),
+		User:      user,
+		Password:  password,
+		Roles:     roles,
 	}
+}
 
-	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
-	return token.SignedString(config.Config.JWTSignKey)
+func GenerateSessionKey() string {
+	return base64.URLEncoding.EncodeToString(GenerateRandomKey(sessionKeyLength))
 }
 
-func TokenToClaims(token string) (*Claims, error) {
-	claims := &Claims{}
-	_, err := jwt.ParseWithClaims(token, claims,
-		func(*jwt.Token) (interface{}, error) { return config.Config.JWTSignKey, nil })
-	return claims, err
+func GenerateRandomKey(length int) []byte {
+	k := make([]byte, length)
+	if _, err := io.ReadFull(rand.Reader, k); err != nil {
+		return nil
+	}
+	return k
 }
 
-func GenerateToken(user, password string) (string, error) {
+func GenerateSession(user, password string) (string, *Session, error) {
 	roles, err := AllOtherRoles(user, password)
 	if err != nil {
-		return "", err
+		return "", nil, err
 	}
-	token, err := NewToken(user, roles)
-	if err != nil {
-		return "", err
-	}
-	ConnPool.Add(token, user, password)
-	return token, nil
+	token := GenerateSessionKey()
+	session := NewSession(user, password, roles)
+	ConnPool.Add(token, session)
+	return token, session, nil
 }
--- a/cmd/tokenserver/main.go	Thu Jun 28 17:26:38 2018 +0200
+++ b/cmd/tokenserver/main.go	Fri Jun 29 17:17:20 2018 +0200
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"encoding/json"
 	"flag"
 	"fmt"
 	"log"
@@ -11,14 +12,14 @@
 )
 
 func sysAdmin(rw http.ResponseWriter, req *http.Request) {
-	claims, _ := auth.GetClaims(req)
+	session, _ := auth.GetSession(req)
 	rw.Header().Set("Content-Type", "text/plain")
-	fmt.Fprintf(rw, "%s is a sys_admin\n", claims.User)
+	fmt.Fprintf(rw, "%s is a sys_admin\n", session.User)
 }
 
 func renew(rw http.ResponseWriter, req *http.Request) {
 	token, _ := auth.GetToken(req)
-	newToken, err := auth.ConnPool.Replace(token, auth.GenerateToken)
+	newToken, err := auth.ConnPool.Renew(token)
 	switch {
 	case err == auth.ErrNoSuchToken:
 		http.NotFound(rw, req)
@@ -27,8 +28,25 @@
 		http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
 		return
 	}
+
+	session, _ := auth.GetSession(req)
+
+	var result = struct {
+		Token   string   `json:"token"`
+		Expires int64    `json:"expires"`
+		User    string   `json:"user"`
+		Roles   []string `json:"roles"`
+	}{
+		Token:   newToken,
+		Expires: session.ExpiresAt,
+		User:    session.User,
+		Roles:   session.Roles,
+	}
+
 	rw.Header().Set("Content-Type", "text/plain")
-	fmt.Fprintf(rw, "%s\n", newToken)
+	if err := json.NewEncoder(rw).Encode(&result); err != nil {
+		log.Printf("error: %v\n", err)
+	}
 }
 
 func logout(rw http.ResponseWriter, req *http.Request) {
@@ -46,15 +64,29 @@
 	user := req.FormValue("user")
 	password := req.FormValue("password")
 
-	token, err := auth.GenerateToken(user, password)
+	token, session, err := auth.GenerateSession(user, password)
 
 	if err != nil {
 		http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
 		return
 	}
 
-	rw.Header().Set("Content-Type", "text/plain")
-	fmt.Fprintf(rw, "%s\n", token)
+	var result = struct {
+		Token   string   `json:"token"`
+		Expires int64    `json:"expires"`
+		User    string   `json:"user"`
+		Roles   []string `json:"roles"`
+	}{
+		Token:   token,
+		Expires: session.ExpiresAt,
+		User:    session.User,
+		Roles:   session.Roles,
+	}
+
+	rw.Header().Set("Content-Type", "application/json")
+	if err := json.NewEncoder(rw).Encode(&result); err != nil {
+		log.Printf("error: %v\n", err)
+	}
 }
 
 func main() {
@@ -65,11 +97,11 @@
 	mux := http.NewServeMux()
 	mux.Handle("/", http.StripPrefix("/", http.FileServer(http.Dir(p))))
 	mux.HandleFunc("/api/token", token)
-	mux.Handle("/api/logout", auth.JWTMiddleware(http.HandlerFunc(token)))
-	mux.Handle("/api/renew", auth.JWTMiddleware(http.HandlerFunc(renew)))
+	mux.Handle("/api/logout", auth.SessionMiddleware(http.HandlerFunc(token)))
+	mux.Handle("/api/renew", auth.SessionMiddleware(http.HandlerFunc(renew)))
 	mux.Handle("/api/sys_admin",
-		auth.JWTMiddleware(
-			auth.ClaimsChecker(http.HandlerFunc(sysAdmin), auth.HasRole("sys_admin"))))
+		auth.SessionMiddleware(
+			auth.SessionChecker(http.HandlerFunc(sysAdmin), auth.HasRole("sys_admin"))))
 
 	addr := fmt.Sprintf("%s:%d", *host, *port)
 	log.Fatalln(http.ListenAndServe(addr, mux))