diff auth/connection.go @ 134:0c56c56a1c44 remove-jwt

Removed the JWT layer from the session management.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 29 Jun 2018 17:17:20 +0200
parents f4523620ba5d
children 0116aae1071b
line wrap: on
line diff
--- a/auth/connection.go	Thu Jun 28 17:26:38 2018 +0200
+++ b/auth/connection.go	Fri Jun 29 17:17:20 2018 +0200
@@ -4,6 +4,7 @@
 	"database/sql"
 	"errors"
 	"log"
+	"sync"
 	"time"
 )
 
@@ -17,18 +18,31 @@
 )
 
 type Connection struct {
-	user     string
-	password string
+	session *Session
 
 	access   time.Time
 	db       *sql.DB
 	refCount int
+
+	mu sync.Mutex
+}
+
+func (c *Connection) set(session *Session) {
+	c.session = session
+	c.touch()
 }
 
-func (c *Connection) set(user, password string) {
-	c.user = user
-	c.password = password
+func (c *Connection) touch() {
+	c.mu.Lock()
 	c.access = time.Now()
+	c.mu.Unlock()
+}
+
+func (c *Connection) last() time.Time {
+	c.mu.Lock()
+	access := c.access
+	c.mu.Unlock()
+	return access
 }
 
 func (c *Connection) close() {
@@ -70,7 +84,7 @@
 func (cp *ConnectionPool) cleanDB() {
 	valid := time.Now().Add(-maxDBIdle)
 	for _, con := range cp.conns {
-		if con.refCount <= 0 && con.access.Before(valid) {
+		if con.refCount <= 0 && con.last().Before(valid) {
 			con.close()
 		}
 	}
@@ -79,14 +93,7 @@
 func (cp *ConnectionPool) cleanToken() {
 	now := time.Now()
 	for token, con := range cp.conns {
-		claims, err := TokenToClaims(token)
-		if err != nil { // Should not happen.
-			log.Printf("error: %v\n", err)
-			con.close()
-			delete(cp.conns, token)
-			continue
-		}
-		expires := time.Unix(claims.ExpiresAt, 0)
+		expires := time.Unix(con.session.ExpiresAt, 0)
 		if expires.Before(now) {
 			// TODO: Be more graceful here?
 			con.close()
@@ -110,46 +117,16 @@
 	return <-res
 }
 
-func (cp *ConnectionPool) Replace(
-	token string,
-	replace func(string, string) (string, error)) (string, error) {
-
-	type res struct {
-		token string
-		err   error
-	}
-
-	resCh := make(chan res)
-
-	cp.cmds <- func(cp *ConnectionPool) {
-		conn, found := cp.conns[token]
-		if !found {
-			resCh <- res{err: ErrNoSuchToken}
-			return
-		}
-		newToken, err := replace(conn.user, conn.password)
-		if err == nil {
-			delete(cp.conns, token)
-			cp.conns[newToken] = conn
-		}
-		resCh <- res{token: newToken, err: err}
-	}
-
-	r := <-resCh
-	return r.token, r.err
-}
-
-func (cp *ConnectionPool) Add(token, user, password string) *Connection {
+func (cp *ConnectionPool) Add(token string, session *Session) *Connection {
 	res := make(chan *Connection)
 
 	cp.cmds <- func(cp *ConnectionPool) {
-
 		con := cp.conns[token]
 		if con == nil {
 			con = &Connection{}
 			cp.conns[token] = con
 		}
-		con.set(user, password)
+		con.set(session)
 		res <- con
 	}
 
@@ -157,6 +134,33 @@
 	return con
 }
 
+func (cp ConnectionPool) Renew(token string) (string, error) {
+
+	type result struct {
+		newToken string
+		err      error
+	}
+
+	resCh := make(chan result)
+
+	cp.cmds <- func(cp *ConnectionPool) {
+		con := cp.conns[token]
+		if con == nil {
+			resCh <- result{err: ErrNoSuchToken}
+		} else {
+			delete(cp.conns, token)
+			newToken := GenerateSessionKey()
+			// TODO: Ensure that this is not racy!
+			con.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix()
+			cp.conns[newToken] = con
+			resCh <- result{newToken: newToken}
+		}
+	}
+
+	r := <-resCh
+	return r.newToken, r.err
+}
+
 func (cp *ConnectionPool) trim(conn *Connection) {
 
 	conn.refCount--
@@ -168,8 +172,8 @@
 
 		for _, con := range cp.conns {
 			if con.db != nil && con.refCount <= 0 {
-				if con.access.Before(least) {
-					least = con.access
+				if last := con.last(); last.Before(least) {
+					least = last
 					oldest = con
 				}
 				count++
@@ -197,14 +201,15 @@
 			res <- result{err: ErrNoSuchToken}
 			return
 		}
-		con.access = time.Now()
+		con.touch()
 		if con.db != nil {
 			con.refCount++
 			res <- result{con: con}
 			return
 		}
 
-		db, err := opendb(con.user, con.password)
+		session := con.session
+		db, err := opendb(session.User, session.Password)
 		if err != nil {
 			res <- result{err: err}
 			return
@@ -228,3 +233,17 @@
 
 	return fn(r.con.db)
 }
+
+func (cp *ConnectionPool) Session(token string) *Session {
+	res := make(chan *Session)
+	cp.cmds <- func(cp *ConnectionPool) {
+		con := cp.conns[token]
+		if con == nil {
+			res <- nil
+		} else {
+			con.touch()
+			res <- con.session
+		}
+	}
+	return <-res
+}