# HG changeset patch # User Sascha L. Teichmann # Date 1535117431 -7200 # Node ID c10c76c9279712405fc5ebd2e63b966a12079e64 # Parent 22e1bf563a0420c6423f628ff24f5c4880042345 Use metamorphic database connections for auth.RunAs(). diff -r 22e1bf563a04 -r c10c76c92797 pkg/auth/opendb.go --- a/pkg/auth/opendb.go Fri Aug 24 15:12:22 2018 +0200 +++ b/pkg/auth/opendb.go Fri Aug 24 15:30:31 2018 +0200 @@ -12,6 +12,8 @@ "gemma.intevation.de/gemma/pkg/config" ) +var ErrNoMetamorphUser = errors.New("No metamorphic user configured") + func OpenDB(user, password string) (*sql.DB, error) { // To ease SSL config ride a bit on parsing. @@ -43,9 +45,11 @@ if m.db != nil { return m.db, nil } - db, err := OpenDB( - config.MetamorphDBUser(), - config.MetamorhpDBPassword()) + user := config.MetamorphDBUser() + if user == "" { + return nil, ErrNoMetamorphUser + } + db, err := OpenDB(user, config.MetamorhpDBPassword()) if err != nil { return nil, err } @@ -81,8 +85,6 @@ WHERE oid IN (SELECT oid FROM cte) AND rolname <> current_user AND EXISTS (SELECT 1 FROM users.list_users WHERE username = current_user)` -var ErrNoMetamorphUser = errors.New("No metamorphic user configured") - func AllOtherRoles(user, password string) (Roles, error) { db, err := OpenDB(user, password) if err != nil { @@ -107,18 +109,11 @@ return roles, rows.Err() } -func RunAs(role string, fn func(*sql.DB) error) error { - user := config.MetamorphDBUser() - if user == "" { - return ErrNoMetamorphUser - } - db, err := OpenDB(user, config.MetamorhpDBPassword()) +func RunAs(role string, ctx context.Context, fn func(*sql.Conn) error) error { + conn, err := MetamorphConn(ctx, role) if err != nil { - return nil + return err } - defer db.Close() - if _, err = db.Exec(`SELECT public.setrole_plan($1)`, role); err == nil { - err = fn(db) - } - return err + defer conn.Close() + return fn(conn) } diff -r 22e1bf563a04 -r c10c76c92797 pkg/controllers/pwreset.go --- a/pkg/controllers/pwreset.go Fri Aug 24 15:12:22 2018 +0200 +++ b/pkg/controllers/pwreset.go Fri Aug 24 15:30:31 2018 +0200 @@ -2,6 +2,7 @@ import ( "bytes" + "context" "database/sql" "encoding/hex" "log" @@ -92,11 +93,14 @@ func removeOutdated() { for { time.Sleep(cleanupPause) - err := auth.RunAs(pwResetRole, func(db *sql.DB) error { - good := time.Now().Add(-passwordResetValid) - _, err := db.Exec(cleanupRequestsSQL, good) - return err - }) + err := auth.RunAs( + pwResetRole, context.Background(), + func(conn *sql.Conn) error { + good := time.Now().Add(-passwordResetValid) + _, err := conn.ExecContext( + context.Background(), cleanupRequestsSQL, good) + return err + }) if err != nil { log.Printf("error: %v\n", err) } @@ -177,46 +181,52 @@ var hash, email string - if err = auth.RunAs(pwResetRole, func(db *sql.DB) error { + ctx := req.Context() - var count int64 - if err := db.QueryRow(countRequestsSQL).Scan(&count); err != nil { - return err - } + if err = auth.RunAs( + pwResetRole, ctx, + func(conn *sql.Conn) error { - // Limit total number of password requests. - if count >= maxPasswordResets { - return JSONError{ - Code: http.StatusServiceUnavailable, - Message: "Too much password reset request", + var count int64 + if err := conn.QueryRowContext( + ctx, countRequestsSQL).Scan(&count); err != nil { + return err } - } - err := db.QueryRow(userExistsSQL, user.User).Scan(&email) + // Limit total number of password requests. + if count >= maxPasswordResets { + return JSONError{ + Code: http.StatusServiceUnavailable, + Message: "Too much password reset request", + } + } + + err := conn.QueryRowContext(ctx, userExistsSQL, user.User).Scan(&email) - switch { - case err == sql.ErrNoRows: - return JSONError{http.StatusNotFound, "User does not exist."} - case err != nil: - return err - } + switch { + case err == sql.ErrNoRows: + return JSONError{http.StatusNotFound, "User does not exist."} + case err != nil: + return err + } - if err := db.QueryRow(countRequestsUserSQL, user.User).Scan(&count); err != nil { - return err - } + if err := conn.QueryRowContext( + ctx, countRequestsUserSQL, user.User).Scan(&count); err != nil { + return err + } - // Limit requests per user - if count >= maxPasswordRequestsPerUser { - return JSONError{ - Code: http.StatusServiceUnavailable, - Message: "Too much password reset requests for user", + // Limit requests per user + if count >= maxPasswordRequestsPerUser { + return JSONError{ + Code: http.StatusServiceUnavailable, + Message: "Too much password reset requests for user", + } } - } - hash = generateHash() - _, err = db.Exec(insertRequestSQL, hash, user.User) - return err - }); err == nil { + hash = generateHash() + _, err = conn.ExecContext(ctx, insertRequestSQL, hash, user.User) + return err + }); err == nil { body := requestMessageBody(useHTTPS(req), user.User, hash, req.Host) if err = misc.SendMail(email, "Password Reset Link", body); err == nil { @@ -242,25 +252,28 @@ var email, user, password string - if err = auth.RunAs(pwResetRole, func(db *sql.DB) error { - err := db.QueryRow(findRequestSQL, hash).Scan(&email, &user) - switch { - case err == sql.ErrNoRows: - return JSONError{http.StatusNotFound, "No such hash"} - case err != nil: + ctx := req.Context() + + if err = auth.RunAs( + pwResetRole, ctx, func(conn *sql.Conn) error { + err := conn.QueryRowContext(ctx, findRequestSQL, hash).Scan(&email, &user) + switch { + case err == sql.ErrNoRows: + return JSONError{http.StatusNotFound, "No such hash"} + case err != nil: + return err + } + password = generateNewPassword() + res, err := conn.ExecContext(ctx, updatePasswordSQL, password, user) + if err != nil { + return err + } + if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { + return JSONError{http.StatusNotFound, "User not found"} + } + _, err = conn.ExecContext(ctx, deleteRequestSQL, hash) return err - } - password = generateNewPassword() - res, err := db.Exec(updatePasswordSQL, password, user) - if err != nil { - return err - } - if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { - return JSONError{http.StatusNotFound, "User not found"} - } - _, err = db.Exec(deleteRequestSQL, hash) - return err - }); err == nil { + }); err == nil { body := changedMessageBody(useHTTPS(req), user, password, req.Host) if err = misc.SendMail(email, "Password Reset Done", body); err == nil { jr.Result = &struct { diff -r 22e1bf563a04 -r c10c76c92797 pkg/models/extservices.go --- a/pkg/models/extservices.go Fri Aug 24 15:12:22 2018 +0200 +++ b/pkg/models/extservices.go Fri Aug 24 15:30:31 2018 +0200 @@ -1,6 +1,7 @@ package models import ( + "context" "database/sql" "log" "sort" @@ -47,25 +48,27 @@ func (es *ExtServices) load() error { // make empty slice to prevent retry if slice is empty. es.entries = []ExtEntry{} - return auth.RunAs("sys_admin", func(db *sql.DB) error { - rows, err := db.Query(selectExternalServices) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var entry ExtEntry - if err := rows.Scan( - &entry.Name, - &entry.URL, - &entry.WFS, - ); err != nil { + return auth.RunAs("sys_admin", context.Background(), + func(conn *sql.Conn) error { + rows, err := conn.QueryContext( + context.Background(), selectExternalServices) + if err != nil { return err } - es.entries = append(es.entries, entry) - } - return rows.Err() - }) + defer rows.Close() + for rows.Next() { + var entry ExtEntry + if err := rows.Scan( + &entry.Name, + &entry.URL, + &entry.WFS, + ); err != nil { + return err + } + es.entries = append(es.entries, entry) + } + return rows.Err() + }) } func (es *ExtServices) Invalidate() { diff -r 22e1bf563a04 -r c10c76c92797 pkg/models/intservices.go --- a/pkg/models/intservices.go Fri Aug 24 15:12:22 2018 +0200 +++ b/pkg/models/intservices.go Fri Aug 24 15:30:31 2018 +0200 @@ -1,6 +1,7 @@ package models import ( + "context" "database/sql" "log" "sync" @@ -64,24 +65,26 @@ func (ps *IntServices) load() error { // make empty slice to prevent retry if slice is empty. ps.entries = []IntEntry{} - return auth.RunAs("sys_admin", func(db *sql.DB) error { - rows, err := db.Query(selectPublishedServices) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var entry IntEntry - if err := rows.Scan( - &entry.Name, &entry.Style, - &entry.WFS, &entry.WFS, - ); err != nil { + return auth.RunAs("sys_admin", context.Background(), + func(conn *sql.Conn) error { + rows, err := conn.QueryContext( + context.Background(), selectPublishedServices) + if err != nil { return err } - ps.entries = append(ps.entries, entry) - } - return rows.Err() - }) + defer rows.Close() + for rows.Next() { + var entry IntEntry + if err := rows.Scan( + &entry.Name, &entry.Style, + &entry.WFS, &entry.WFS, + ); err != nil { + return err + } + ps.entries = append(ps.entries, entry) + } + return rows.Err() + }) return nil }