changeset 501:c10c76c92797 metamorph-for-all

Use metamorphic database connections for auth.RunAs().
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 24 Aug 2018 15:30:31 +0200
parents 22e1bf563a04
children af1a198391f3
files pkg/auth/opendb.go pkg/controllers/pwreset.go pkg/models/extservices.go pkg/models/intservices.go
diffstat 4 files changed, 118 insertions(+), 104 deletions(-) [+]
line wrap: on
line diff
--- a/pkg/auth/opendb.go	Fri Aug 24 15:12:22 2018 +0200
+++ b/pkg/auth/opendb.go	Fri Aug 24 15:30:31 2018 +0200
@@ -12,6 +12,8 @@
 	"gemma.intevation.de/gemma/pkg/config"
 )
 
+var ErrNoMetamorphUser = errors.New("No metamorphic user configured")
+
 func OpenDB(user, password string) (*sql.DB, error) {
 
 	// To ease SSL config ride a bit on parsing.
@@ -43,9 +45,11 @@
 	if m.db != nil {
 		return m.db, nil
 	}
-	db, err := OpenDB(
-		config.MetamorphDBUser(),
-		config.MetamorhpDBPassword())
+	user := config.MetamorphDBUser()
+	if user == "" {
+		return nil, ErrNoMetamorphUser
+	}
+	db, err := OpenDB(user, config.MetamorhpDBPassword())
 	if err != nil {
 		return nil, err
 	}
@@ -81,8 +85,6 @@
 WHERE oid IN (SELECT oid FROM cte) AND rolname <> current_user
 AND EXISTS (SELECT 1 FROM users.list_users WHERE username = current_user)`
 
-var ErrNoMetamorphUser = errors.New("No metamorphic user configured")
-
 func AllOtherRoles(user, password string) (Roles, error) {
 	db, err := OpenDB(user, password)
 	if err != nil {
@@ -107,18 +109,11 @@
 	return roles, rows.Err()
 }
 
-func RunAs(role string, fn func(*sql.DB) error) error {
-	user := config.MetamorphDBUser()
-	if user == "" {
-		return ErrNoMetamorphUser
-	}
-	db, err := OpenDB(user, config.MetamorhpDBPassword())
+func RunAs(role string, ctx context.Context, fn func(*sql.Conn) error) error {
+	conn, err := MetamorphConn(ctx, role)
 	if err != nil {
-		return nil
+		return err
 	}
-	defer db.Close()
-	if _, err = db.Exec(`SELECT public.setrole_plan($1)`, role); err == nil {
-		err = fn(db)
-	}
-	return err
+	defer conn.Close()
+	return fn(conn)
 }
--- a/pkg/controllers/pwreset.go	Fri Aug 24 15:12:22 2018 +0200
+++ b/pkg/controllers/pwreset.go	Fri Aug 24 15:30:31 2018 +0200
@@ -2,6 +2,7 @@
 
 import (
 	"bytes"
+	"context"
 	"database/sql"
 	"encoding/hex"
 	"log"
@@ -92,11 +93,14 @@
 func removeOutdated() {
 	for {
 		time.Sleep(cleanupPause)
-		err := auth.RunAs(pwResetRole, func(db *sql.DB) error {
-			good := time.Now().Add(-passwordResetValid)
-			_, err := db.Exec(cleanupRequestsSQL, good)
-			return err
-		})
+		err := auth.RunAs(
+			pwResetRole, context.Background(),
+			func(conn *sql.Conn) error {
+				good := time.Now().Add(-passwordResetValid)
+				_, err := conn.ExecContext(
+					context.Background(), cleanupRequestsSQL, good)
+				return err
+			})
 		if err != nil {
 			log.Printf("error: %v\n", err)
 		}
@@ -177,46 +181,52 @@
 
 	var hash, email string
 
-	if err = auth.RunAs(pwResetRole, func(db *sql.DB) error {
+	ctx := req.Context()
 
-		var count int64
-		if err := db.QueryRow(countRequestsSQL).Scan(&count); err != nil {
-			return err
-		}
+	if err = auth.RunAs(
+		pwResetRole, ctx,
+		func(conn *sql.Conn) error {
 
-		// Limit total number of password requests.
-		if count >= maxPasswordResets {
-			return JSONError{
-				Code:    http.StatusServiceUnavailable,
-				Message: "Too much password reset request",
+			var count int64
+			if err := conn.QueryRowContext(
+				ctx, countRequestsSQL).Scan(&count); err != nil {
+				return err
 			}
-		}
 
-		err := db.QueryRow(userExistsSQL, user.User).Scan(&email)
+			// Limit total number of password requests.
+			if count >= maxPasswordResets {
+				return JSONError{
+					Code:    http.StatusServiceUnavailable,
+					Message: "Too much password reset request",
+				}
+			}
+
+			err := conn.QueryRowContext(ctx, userExistsSQL, user.User).Scan(&email)
 
-		switch {
-		case err == sql.ErrNoRows:
-			return JSONError{http.StatusNotFound, "User does not exist."}
-		case err != nil:
-			return err
-		}
+			switch {
+			case err == sql.ErrNoRows:
+				return JSONError{http.StatusNotFound, "User does not exist."}
+			case err != nil:
+				return err
+			}
 
-		if err := db.QueryRow(countRequestsUserSQL, user.User).Scan(&count); err != nil {
-			return err
-		}
+			if err := conn.QueryRowContext(
+				ctx, countRequestsUserSQL, user.User).Scan(&count); err != nil {
+				return err
+			}
 
-		// Limit requests per user
-		if count >= maxPasswordRequestsPerUser {
-			return JSONError{
-				Code:    http.StatusServiceUnavailable,
-				Message: "Too much password reset requests for user",
+			// Limit requests per user
+			if count >= maxPasswordRequestsPerUser {
+				return JSONError{
+					Code:    http.StatusServiceUnavailable,
+					Message: "Too much password reset requests for user",
+				}
 			}
-		}
 
-		hash = generateHash()
-		_, err = db.Exec(insertRequestSQL, hash, user.User)
-		return err
-	}); err == nil {
+			hash = generateHash()
+			_, err = conn.ExecContext(ctx, insertRequestSQL, hash, user.User)
+			return err
+		}); err == nil {
 		body := requestMessageBody(useHTTPS(req), user.User, hash, req.Host)
 
 		if err = misc.SendMail(email, "Password Reset Link", body); err == nil {
@@ -242,25 +252,28 @@
 
 	var email, user, password string
 
-	if err = auth.RunAs(pwResetRole, func(db *sql.DB) error {
-		err := db.QueryRow(findRequestSQL, hash).Scan(&email, &user)
-		switch {
-		case err == sql.ErrNoRows:
-			return JSONError{http.StatusNotFound, "No such hash"}
-		case err != nil:
+	ctx := req.Context()
+
+	if err = auth.RunAs(
+		pwResetRole, ctx, func(conn *sql.Conn) error {
+			err := conn.QueryRowContext(ctx, findRequestSQL, hash).Scan(&email, &user)
+			switch {
+			case err == sql.ErrNoRows:
+				return JSONError{http.StatusNotFound, "No such hash"}
+			case err != nil:
+				return err
+			}
+			password = generateNewPassword()
+			res, err := conn.ExecContext(ctx, updatePasswordSQL, password, user)
+			if err != nil {
+				return err
+			}
+			if n, err2 := res.RowsAffected(); err2 == nil && n == 0 {
+				return JSONError{http.StatusNotFound, "User not found"}
+			}
+			_, err = conn.ExecContext(ctx, deleteRequestSQL, hash)
 			return err
-		}
-		password = generateNewPassword()
-		res, err := db.Exec(updatePasswordSQL, password, user)
-		if err != nil {
-			return err
-		}
-		if n, err2 := res.RowsAffected(); err2 == nil && n == 0 {
-			return JSONError{http.StatusNotFound, "User not found"}
-		}
-		_, err = db.Exec(deleteRequestSQL, hash)
-		return err
-	}); err == nil {
+		}); err == nil {
 		body := changedMessageBody(useHTTPS(req), user, password, req.Host)
 		if err = misc.SendMail(email, "Password Reset Done", body); err == nil {
 			jr.Result = &struct {
--- a/pkg/models/extservices.go	Fri Aug 24 15:12:22 2018 +0200
+++ b/pkg/models/extservices.go	Fri Aug 24 15:30:31 2018 +0200
@@ -1,6 +1,7 @@
 package models
 
 import (
+	"context"
 	"database/sql"
 	"log"
 	"sort"
@@ -47,25 +48,27 @@
 func (es *ExtServices) load() error {
 	// make empty slice to prevent retry if slice is empty.
 	es.entries = []ExtEntry{}
-	return auth.RunAs("sys_admin", func(db *sql.DB) error {
-		rows, err := db.Query(selectExternalServices)
-		if err != nil {
-			return err
-		}
-		defer rows.Close()
-		for rows.Next() {
-			var entry ExtEntry
-			if err := rows.Scan(
-				&entry.Name,
-				&entry.URL,
-				&entry.WFS,
-			); err != nil {
+	return auth.RunAs("sys_admin", context.Background(),
+		func(conn *sql.Conn) error {
+			rows, err := conn.QueryContext(
+				context.Background(), selectExternalServices)
+			if err != nil {
 				return err
 			}
-			es.entries = append(es.entries, entry)
-		}
-		return rows.Err()
-	})
+			defer rows.Close()
+			for rows.Next() {
+				var entry ExtEntry
+				if err := rows.Scan(
+					&entry.Name,
+					&entry.URL,
+					&entry.WFS,
+				); err != nil {
+					return err
+				}
+				es.entries = append(es.entries, entry)
+			}
+			return rows.Err()
+		})
 }
 
 func (es *ExtServices) Invalidate() {
--- a/pkg/models/intservices.go	Fri Aug 24 15:12:22 2018 +0200
+++ b/pkg/models/intservices.go	Fri Aug 24 15:30:31 2018 +0200
@@ -1,6 +1,7 @@
 package models
 
 import (
+	"context"
 	"database/sql"
 	"log"
 	"sync"
@@ -64,24 +65,26 @@
 func (ps *IntServices) load() error {
 	// make empty slice to prevent retry if slice is empty.
 	ps.entries = []IntEntry{}
-	return auth.RunAs("sys_admin", func(db *sql.DB) error {
-		rows, err := db.Query(selectPublishedServices)
-		if err != nil {
-			return err
-		}
-		defer rows.Close()
-		for rows.Next() {
-			var entry IntEntry
-			if err := rows.Scan(
-				&entry.Name, &entry.Style,
-				&entry.WFS, &entry.WFS,
-			); err != nil {
+	return auth.RunAs("sys_admin", context.Background(),
+		func(conn *sql.Conn) error {
+			rows, err := conn.QueryContext(
+				context.Background(), selectPublishedServices)
+			if err != nil {
 				return err
 			}
-			ps.entries = append(ps.entries, entry)
-		}
-		return rows.Err()
-	})
+			defer rows.Close()
+			for rows.Next() {
+				var entry IntEntry
+				if err := rows.Scan(
+					&entry.Name, &entry.Style,
+					&entry.WFS, &entry.WFS,
+				); err != nil {
+					return err
+				}
+				ps.entries = append(ps.entries, entry)
+			}
+			return rows.Err()
+		})
 	return nil
 }