Mercurial > gemma
changeset 145:b537ee0d3dcd
Merged.
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Mon, 02 Jul 2018 10:32:40 +0200 |
parents | 836132579c6f (current diff) d93ccf1aba1b (diff) |
children | f48306e87cfb |
files | 3rdpartylibs.sh |
diffstat | 4 files changed, 164 insertions(+), 115 deletions(-) [+] |
line wrap: on
line diff
--- a/auth/connection.go Mon Jul 02 09:38:32 2018 +0200 +++ b/auth/connection.go Mon Jul 02 10:32:40 2018 +0200 @@ -4,6 +4,7 @@ "database/sql" "errors" "log" + "sync" "time" ) @@ -17,18 +18,31 @@ ) type Connection struct { - user string - password string + session *Session access time.Time db *sql.DB refCount int + + mu sync.Mutex +} + +func (c *Connection) set(session *Session) { + c.session = session + c.touch() } -func (c *Connection) set(user, password string) { - c.user = user - c.password = password +func (c *Connection) touch() { + c.mu.Lock() c.access = time.Now() + c.mu.Unlock() +} + +func (c *Connection) last() time.Time { + c.mu.Lock() + access := c.access + c.mu.Unlock() + return access } func (c *Connection) close() { @@ -70,7 +84,7 @@ func (cp *ConnectionPool) cleanDB() { valid := time.Now().Add(-maxDBIdle) for _, con := range cp.conns { - if con.refCount <= 0 && con.access.Before(valid) { + if con.refCount <= 0 && con.last().Before(valid) { con.close() } } @@ -79,14 +93,7 @@ func (cp *ConnectionPool) cleanToken() { now := time.Now() for token, con := range cp.conns { - claims, err := TokenToClaims(token) - if err != nil { // Should not happen. - log.Printf("error: %v\n", err) - con.close() - delete(cp.conns, token) - continue - } - expires := time.Unix(claims.ExpiresAt, 0) + expires := time.Unix(con.session.ExpiresAt, 0) if expires.Before(now) { // TODO: Be more graceful here? con.close() @@ -110,46 +117,16 @@ return <-res } -func (cp *ConnectionPool) Replace( - token string, - replace func(string, string) (string, error)) (string, error) { - - type res struct { - token string - err error - } - - resCh := make(chan res) - - cp.cmds <- func(cp *ConnectionPool) { - conn, found := cp.conns[token] - if !found { - resCh <- res{err: ErrNoSuchToken} - return - } - newToken, err := replace(conn.user, conn.password) - if err == nil { - delete(cp.conns, token) - cp.conns[newToken] = conn - } - resCh <- res{token: newToken, err: err} - } - - r := <-resCh - return r.token, r.err -} - -func (cp *ConnectionPool) Add(token, user, password string) *Connection { +func (cp *ConnectionPool) Add(token string, session *Session) *Connection { res := make(chan *Connection) cp.cmds <- func(cp *ConnectionPool) { - con := cp.conns[token] if con == nil { con = &Connection{} cp.conns[token] = con } - con.set(user, password) + con.set(session) res <- con } @@ -157,6 +134,33 @@ return con } +func (cp ConnectionPool) Renew(token string) (string, error) { + + type result struct { + newToken string + err error + } + + resCh := make(chan result) + + cp.cmds <- func(cp *ConnectionPool) { + con := cp.conns[token] + if con == nil { + resCh <- result{err: ErrNoSuchToken} + } else { + delete(cp.conns, token) + newToken := GenerateSessionKey() + // TODO: Ensure that this is not racy! + con.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix() + cp.conns[newToken] = con + resCh <- result{newToken: newToken} + } + } + + r := <-resCh + return r.newToken, r.err +} + func (cp *ConnectionPool) trim(conn *Connection) { conn.refCount-- @@ -168,8 +172,8 @@ for _, con := range cp.conns { if con.db != nil && con.refCount <= 0 { - if con.access.Before(least) { - least = con.access + if last := con.last(); last.Before(least) { + least = last oldest = con } count++ @@ -197,14 +201,15 @@ res <- result{err: ErrNoSuchToken} return } - con.access = time.Now() + con.touch() if con.db != nil { con.refCount++ res <- result{con: con} return } - db, err := opendb(con.user, con.password) + session := con.session + db, err := opendb(session.User, session.Password) if err != nil { res <- result{err: err} return @@ -228,3 +233,17 @@ return fn(r.con.db) } + +func (cp *ConnectionPool) Session(token string) *Session { + res := make(chan *Session) + cp.cmds <- func(cp *ConnectionPool) { + con := cp.conns[token] + if con == nil { + res <- nil + } else { + con.touch() + res <- con.session + } + } + return <-res +}
--- a/auth/middleware.go Mon Jul 02 09:38:32 2018 +0200 +++ b/auth/middleware.go Mon Jul 02 10:32:40 2018 +0200 @@ -2,23 +2,20 @@ import ( "context" - "fmt" "net/http" - "regexp" + "strings" ) -var extractToken = regexp.MustCompile(`\s*Bearer\s+(\S+)`) - type contextType int const ( - claimsKey contextType = iota + sessionKey contextType = iota tokenKey ) -func GetClaims(req *http.Request) (*Claims, bool) { - claims, ok := req.Context().Value(claimsKey).(*Claims) - return claims, ok +func GetSession(req *http.Request) (*Session, bool) { + session, ok := req.Context().Value(sessionKey).(*Session) + return session, ok } func GetToken(req *http.Request) (string, bool) { @@ -26,36 +23,36 @@ return token, ok } -func JWTMiddleware(next http.Handler) http.Handler { +func SessionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - auth := req.Header.Get("Authorization") + auth := req.Header.Get("X-Gemma-Auth") - token := extractToken.FindStringSubmatch(auth) - if len(token) != 2 { + token := strings.TrimSpace(auth) + if token == "" { http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } - claims, err := TokenToClaims(token[1]) - if err != nil { - http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusUnauthorized) + session := ConnPool.Session(token) + if session == nil { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } ctx := req.Context() - ctx = context.WithValue(ctx, claimsKey, claims) - ctx = context.WithValue(ctx, tokenKey, token[1]) + ctx = context.WithValue(ctx, sessionKey, session) + ctx = context.WithValue(ctx, tokenKey, token) req = req.WithContext(ctx) next.ServeHTTP(rw, req) }) } -func ClaimsChecker(next http.Handler, check func(*Claims) bool) http.Handler { +func SessionChecker(next http.Handler, check func(*Session) bool) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - claims, ok := GetClaims(req) + claims, ok := GetSession(req) if !ok || !check(claims) { http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return @@ -64,10 +61,10 @@ }) } -func HasRole(roles ...string) func(*Claims) bool { - return func(claims *Claims) bool { +func HasRole(roles ...string) func(*Session) bool { + return func(session *Session) bool { for _, r1 := range roles { - for _, r2 := range claims.Roles { + for _, r2 := range session.Roles { if r1 == r2 { return true }
--- a/auth/token.go Mon Jul 02 09:38:32 2018 +0200 +++ b/auth/token.go Mon Jul 02 10:32:40 2018 +0200 @@ -1,53 +1,54 @@ package auth import ( + "crypto/rand" + "encoding/base64" + "io" "time" - - "gemma.intevation.de/gemma/config" - - jwt "github.com/dgrijalva/jwt-go" ) -type Claims struct { - jwt.StandardClaims - - User string `json:"user"` - Roles []string `json:"roles"` +type Session struct { + ExpiresAt int64 `json:"expires"` + User string `json:"user"` + Password string `json:"password"` + Roles []string `json:"roles"` } -const maxTokenValid = time.Hour * 3 +const ( + sessionKeyLength = 20 + maxTokenValid = time.Hour * 3 +) -func NewToken(user string, roles []string) (string, error) { +func NewSession(user, password string, roles []string) *Session { // Create the Claims - claims := &Claims{ - StandardClaims: jwt.StandardClaims{ - ExpiresAt: jwt.TimeFunc().Add(maxTokenValid).Unix(), - }, - User: user, - Roles: roles, + return &Session{ + ExpiresAt: time.Now().Add(maxTokenValid).Unix(), + User: user, + Password: password, + Roles: roles, } +} - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(config.Config.JWTSignKey) +func GenerateSessionKey() string { + return base64.URLEncoding.EncodeToString(GenerateRandomKey(sessionKeyLength)) } -func TokenToClaims(token string) (*Claims, error) { - claims := &Claims{} - _, err := jwt.ParseWithClaims(token, claims, - func(*jwt.Token) (interface{}, error) { return config.Config.JWTSignKey, nil }) - return claims, err +func GenerateRandomKey(length int) []byte { + k := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return nil + } + return k } -func GenerateToken(user, password string) (string, error) { +func GenerateSession(user, password string) (string, *Session, error) { roles, err := AllOtherRoles(user, password) if err != nil { - return "", err + return "", nil, err } - token, err := NewToken(user, roles) - if err != nil { - return "", err - } - ConnPool.Add(token, user, password) - return token, nil + token := GenerateSessionKey() + session := NewSession(user, password, roles) + ConnPool.Add(token, session) + return token, session, nil }
--- a/cmd/tokenserver/main.go Mon Jul 02 09:38:32 2018 +0200 +++ b/cmd/tokenserver/main.go Mon Jul 02 10:32:40 2018 +0200 @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "flag" "fmt" "log" @@ -11,14 +12,14 @@ ) func sysAdmin(rw http.ResponseWriter, req *http.Request) { - claims, _ := auth.GetClaims(req) + session, _ := auth.GetSession(req) rw.Header().Set("Content-Type", "text/plain") - fmt.Fprintf(rw, "%s is a sys_admin\n", claims.User) + fmt.Fprintf(rw, "%s is a sys_admin\n", session.User) } func renew(rw http.ResponseWriter, req *http.Request) { token, _ := auth.GetToken(req) - newToken, err := auth.ConnPool.Replace(token, auth.GenerateToken) + newToken, err := auth.ConnPool.Renew(token) switch { case err == auth.ErrNoSuchToken: http.NotFound(rw, req) @@ -27,8 +28,25 @@ http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusInternalServerError) return } + + session, _ := auth.GetSession(req) + + var result = struct { + Token string `json:"token"` + Expires int64 `json:"expires"` + User string `json:"user"` + Roles []string `json:"roles"` + }{ + Token: newToken, + Expires: session.ExpiresAt, + User: session.User, + Roles: session.Roles, + } + rw.Header().Set("Content-Type", "text/plain") - fmt.Fprintf(rw, "%s\n", newToken) + if err := json.NewEncoder(rw).Encode(&result); err != nil { + log.Printf("error: %v\n", err) + } } func logout(rw http.ResponseWriter, req *http.Request) { @@ -46,15 +64,29 @@ user := req.FormValue("user") password := req.FormValue("password") - token, err := auth.GenerateToken(user, password) + token, session, err := auth.GenerateSession(user, password) if err != nil { http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusInternalServerError) return } - rw.Header().Set("Content-Type", "text/plain") - fmt.Fprintf(rw, "%s\n", token) + var result = struct { + Token string `json:"token"` + Expires int64 `json:"expires"` + User string `json:"user"` + Roles []string `json:"roles"` + }{ + Token: token, + Expires: session.ExpiresAt, + User: session.User, + Roles: session.Roles, + } + + rw.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(rw).Encode(&result); err != nil { + log.Printf("error: %v\n", err) + } } func main() { @@ -65,11 +97,11 @@ mux := http.NewServeMux() mux.Handle("/", http.StripPrefix("/", http.FileServer(http.Dir(p)))) mux.HandleFunc("/api/token", token) - mux.Handle("/api/logout", auth.JWTMiddleware(http.HandlerFunc(token))) - mux.Handle("/api/renew", auth.JWTMiddleware(http.HandlerFunc(renew))) + mux.Handle("/api/logout", auth.SessionMiddleware(http.HandlerFunc(token))) + mux.Handle("/api/renew", auth.SessionMiddleware(http.HandlerFunc(renew))) mux.Handle("/api/sys_admin", - auth.JWTMiddleware( - auth.ClaimsChecker(http.HandlerFunc(sysAdmin), auth.HasRole("sys_admin")))) + auth.SessionMiddleware( + auth.SessionChecker(http.HandlerFunc(sysAdmin), auth.HasRole("sys_admin")))) addr := fmt.Sprintf("%s:%d", *host, *port) log.Fatalln(http.ListenAndServe(addr, mux))