changeset 321:974a5e4c0055

Persist password reset requests in database.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Thu, 02 Aug 2018 16:40:14 +0200
parents e4bf72cda62e
children 34ecfd8dc11e
files controllers/pwreset.go
diffstat 1 files changed, 114 insertions(+), 163 deletions(-) [+]
line wrap: on
line diff
--- a/controllers/pwreset.go	Thu Aug 02 15:14:38 2018 +0200
+++ b/controllers/pwreset.go	Thu Aug 02 16:40:14 2018 +0200
@@ -10,7 +10,6 @@
 	"net/http"
 	"os/exec"
 	"strings"
-	"sync"
 	"text/template"
 	"time"
 
@@ -22,6 +21,25 @@
 )
 
 const (
+	insertRequestSQL = `INSERT INTO pw_reset.password_reset_requests
+    (hash, username) VALUES ($1, $2)`
+
+	countRequestsSQL = `SELECT count(*) FROM pw_reset.password_reset_requests`
+
+	countRequestsUserSQL = `SELECT count(*) FROM pw_reset.password_reset_requests
+    WHERE username = $1`
+
+	deleteRequestSQL = `DELETE FROM pw_reset.password_reset_requests
+    WHERE hash = $1`
+
+	findRequestSQL = `SELECT lu.email_address, lu.username
+    FROM pw_reset.password_reset_requests prr
+    JOIN pw_reset.list_users lu on prr.username = lu.username
+    WHERE prr.hash = $1`
+
+	cleanupRequestsSQL = `DELETE FROM pw_reset.password_reset_requests
+    WHERE issued < $1`
+
 	userExistsSQL = `SELECT email_address
     FROM pw_reset.list_users WHERE username = $1`
 
@@ -30,9 +48,12 @@
 )
 
 const (
-	passwordResetValid         = time.Hour
+	hashLength                 = 16
+	passwordLength             = 20
+	passwordResetValid         = 12 * time.Hour
 	maxPasswordResets          = 1000
 	maxPasswordRequestsPerUser = 5
+	cleanupPause               = 15 * time.Minute
 )
 
 var (
@@ -45,7 +66,7 @@
 
 {{ .HTTPS }}://{{ .Server }}/api/users/passwordreset/{{ .Hash }}
 
-The link is only valid for one hour.
+The link is only valid for 12 hours.
 
 Best regards
     Your service team`))
@@ -63,83 +84,30 @@
     Your service team`))
 )
 
-type timedUser struct {
-	user  string
-	email string
-	time  time.Time
-}
-
-type resetRequests struct {
-	sync.Mutex
-	reqs map[string]*timedUser
-}
-
-var passwordResetRequests = func() *resetRequests {
-	rr := &resetRequests{reqs: map[string]*timedUser{}}
-	go func() {
-		for {
-			time.Sleep(2 * time.Minute)
-			rr.removeOutdated()
-		}
-	}()
-	return rr
-}()
-
-func (rr *resetRequests) len() int {
-	rr.Lock()
-	l := len(rr.reqs)
-	rr.Unlock()
-	return l
+func asServiceUser(fn func(*sql.DB) error) error {
+	cfg := &config.Config
+	db, err := auth.OpenDB(cfg.ServiceUser, cfg.ServicePassword)
+	if err == nil {
+		defer db.Close()
+		err = fn(db)
+	}
+	return err
 }
 
-func (rr *resetRequests) userAllowed(user string) bool {
-	rr.Lock()
-	defer rr.Unlock()
-	var count int
-	for _, v := range rr.reqs {
-		if v.user == user {
-			if count++; count >= maxPasswordRequestsPerUser {
-				return false
-			}
-		}
-	}
-	return true
-}
-
-func (rr *resetRequests) store(hash, user, email string) {
-	now := time.Now()
-	rr.Lock()
-	rr.reqs[hash] = &timedUser{user, email, now}
-	rr.Unlock()
+func init() {
+	go removeOutdated()
 }
 
-func (rr *resetRequests) fetch(hash string) *timedUser {
-	rr.Lock()
-	defer rr.Unlock()
-	tu := rr.reqs[hash]
-	if tu == nil {
-		return nil
-	}
-	if tu.time.Before(time.Now().Add(-passwordResetValid)) {
-		delete(rr.reqs, hash)
-		return nil
-	}
-	return tu
-}
-
-func (rr *resetRequests) delete(hash string) {
-	rr.Lock()
-	delete(rr.reqs, hash)
-	rr.Unlock()
-}
-
-func (rr *resetRequests) removeOutdated() {
-	good := time.Now().Add(-passwordResetValid)
-	rr.Lock()
-	defer rr.Unlock()
-	for k, v := range rr.reqs {
-		if v.time.Before(good) {
-			delete(rr.reqs, k)
+func removeOutdated() {
+	for {
+		time.Sleep(cleanupPause)
+		err := asServiceUser(func(db *sql.DB) error {
+			good := time.Now().Add(-passwordResetValid)
+			_, err := db.Exec(cleanupRequestsSQL, good)
+			return err
+		})
+		if err != nil {
+			log.Printf("error: %v\n", err)
 		}
 	}
 }
@@ -189,11 +157,6 @@
 	return "http"
 }
 
-const (
-	hashLength     = 32
-	passwordLength = 20
-)
-
 func generateHash() string {
 	return hex.EncodeToString(auth.GenerateRandomKey(hashLength))
 }
@@ -225,10 +188,8 @@
 	if err == nil {
 		return strings.TrimSpace(string(out))
 	}
-
 	// Use internal generator.
 	return randomString(20)
-
 }
 
 func sendMail(email, subject, body string) error {
@@ -254,18 +215,9 @@
 func passwordResetRequest(
 	input interface{},
 	req *http.Request,
-	db *sql.DB,
+	_ *sql.DB,
 ) (jr JSONResult, err error) {
 
-	// Limit total number of password requests.
-	if passwordResetRequests.len() >= maxPasswordResets {
-		err = JSONError{
-			Code:    http.StatusServiceUnavailable,
-			Message: "Too much password reset request",
-		}
-		return
-	}
-
 	user := input.(*PWResetUser)
 
 	if user.User == "" {
@@ -273,53 +225,63 @@
 		return
 	}
 
-	cfg := &config.Config
-	if db, err = auth.OpenDB(cfg.ServiceUser, cfg.ServicePassword); err != nil {
-		return
-	}
-	defer db.Close()
+	var hash, email string
+
+	if err = asServiceUser(func(db *sql.DB) error {
+
+		var count int64
+		if err := db.QueryRow(countRequestsSQL).Scan(&count); err != nil {
+			return err
+		}
 
-	var email string
-	err = db.QueryRow(userExistsSQL, user.User).Scan(&email)
+		// Limit total number of password requests.
+		if count >= maxPasswordResets {
+			return JSONError{
+				Code:    http.StatusServiceUnavailable,
+				Message: "Too much password reset request",
+			}
+		}
 
-	switch {
-	case err == sql.ErrNoRows:
-		err = JSONError{http.StatusNotFound, "User does not exist."}
-		return
-	case err != nil:
-		return
-	}
+		err := db.QueryRow(userExistsSQL, user.User).Scan(&email)
 
-	// Limit requests per user
-	if !passwordResetRequests.userAllowed(user.User) {
-		err = JSONError{
-			Code:    http.StatusServiceUnavailable,
-			Message: "Too much password reset requests for user",
+		switch {
+		case err == sql.ErrNoRows:
+			return JSONError{http.StatusNotFound, "User does not exist."}
+		case err != nil:
+			return err
 		}
-		return
-	}
 
-	hash := generateHash()
+		if err := db.QueryRow(countRequestsUserSQL, user.User).Scan(&count); err != nil {
+			return err
+		}
 
-	passwordResetRequests.store(hash, user.User, email)
-
-	body := requestMessageBody(useHTTPS(req), user.User, hash, req.Host)
+		// Limit requests per user
+		if count >= maxPasswordRequestsPerUser {
+			return JSONError{
+				Code:    http.StatusServiceUnavailable,
+				Message: "Too much password reset requests for user",
+			}
+		}
 
-	if err = sendMail(email, "Password Reset Link", body); err != nil {
-		return
-	}
+		hash = generateHash()
+		_, err = db.Exec(insertRequestSQL, hash, user.User)
+		return err
+	}); err == nil {
+		body := requestMessageBody(useHTTPS(req), user.User, hash, req.Host)
 
-	jr.Result = &struct {
-		SendTo string `json:"send-to"`
-	}{email}
-
+		if err = sendMail(email, "Password Reset Link", body); err == nil {
+			jr.Result = &struct {
+				SendTo string `json:"send-to"`
+			}{email}
+		}
+	}
 	return
 }
 
 func passwordReset(
 	_ interface{},
 	req *http.Request,
-	db *sql.DB,
+	_ *sql.DB,
 ) (jr JSONResult, err error) {
 
 	hash := mux.Vars(req)["hash"]
@@ -328,44 +290,33 @@
 		return
 	}
 
-	tu := passwordResetRequests.fetch(hash)
-	if tu == nil {
-		err = JSONError{http.StatusNotFound, "No such hash"}
-		return
-	}
-
-	password := generateNewPassword()
-
-	cfg := &config.Config
-	if db, err = auth.OpenDB(cfg.ServiceUser, cfg.ServicePassword); err != nil {
-		return
-	}
-	defer db.Close()
+	var email, user, password string
 
-	var res sql.Result
-	if res, err = db.Exec(
-		updatePasswordSQL,
-		password,
-		tu.user,
-	); err != nil {
-		return
+	if err = asServiceUser(func(db *sql.DB) error {
+		err := db.QueryRow(findRequestSQL, hash).Scan(&email, &user)
+		switch {
+		case err == sql.ErrNoRows:
+			return JSONError{http.StatusNotFound, "No such hash"}
+		case err != nil:
+			return err
+		}
+		password = generateNewPassword()
+		res, err := db.Exec(updatePasswordSQL, password, user)
+		if err != nil {
+			return err
+		}
+		if n, err2 := res.RowsAffected(); err2 == nil && n == 0 {
+			return JSONError{http.StatusNotFound, "User not found"}
+		}
+		_, err = db.Exec(deleteRequestSQL, hash)
+		return err
+	}); err == nil {
+		body := changedMessageBody(useHTTPS(req), user, password, req.Host)
+		if err = sendMail(email, "Password Reset Done", body); err == nil {
+			jr.Result = &struct {
+				SendTo string `json:"send-to"`
+			}{email}
+		}
 	}
-
-	passwordResetRequests.delete(hash)
-
-	if n, err2 := res.RowsAffected(); err2 == nil && n == 0 {
-		err = JSONError{http.StatusNotFound, "User not found"}
-		return
-	}
-
-	body := changedMessageBody(useHTTPS(req), tu.user, password, req.Host)
-	if err = sendMail(tu.email, "Password Reset Done", body); err != nil {
-		return
-	}
-
-	jr.Result = &struct {
-		Message string `json:"message"`
-	}{"User password changed"}
-
 	return
 }