changeset 486:b2dc9c2f69e0 metamorph-for-all

First stab to use the metamorphic db to do all database stuff.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 24 Aug 2018 13:56:06 +0200
parents 7a8644e9e50e
children 8a0737aa6ab6
files pkg/auth/connection.go pkg/auth/opendb.go pkg/auth/pool.go pkg/auth/session.go pkg/controllers/json.go pkg/controllers/publish.go pkg/controllers/pwreset.go pkg/controllers/routes.go pkg/controllers/user.go
diffstat 9 files changed, 89 insertions(+), 106 deletions(-) [+]
line wrap: on
line diff
--- a/pkg/auth/connection.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/auth/connection.go	Fri Aug 24 13:56:06 2018 +0200
@@ -1,10 +1,8 @@
 package auth
 
 import (
-	"database/sql"
 	"errors"
 	"io"
-	"log"
 	"sync"
 	"time"
 
@@ -21,9 +19,7 @@
 type Connection struct {
 	session *Session
 
-	access   time.Time
-	db       *sql.DB
-	refCount int
+	access time.Time
 
 	mu sync.Mutex
 }
@@ -90,10 +86,4 @@
 }
 
 func (c *Connection) close() {
-	if c.db != nil {
-		if err := c.db.Close(); err != nil {
-			log.Printf("warn: %v\n", err)
-		}
-		c.db = nil
-	}
 }
--- a/pkg/auth/opendb.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/auth/opendb.go	Fri Aug 24 13:56:06 2018 +0200
@@ -1,8 +1,10 @@
 package auth
 
 import (
+	"context"
 	"database/sql"
 	"errors"
+	"sync"
 
 	"github.com/jackc/pgx"
 	"github.com/jackc/pgx/stdlib"
@@ -28,6 +30,45 @@
 	return stdlib.OpenDB(cc), nil
 }
 
+type metamorph struct {
+	sync.Mutex
+	db *sql.DB
+}
+
+var mm metamorph
+
+func (m *metamorph) open() (*sql.DB, error) {
+	m.Lock()
+	defer m.Unlock()
+	if m.db != nil {
+		return m.db, nil
+	}
+	db, err := OpenDB(
+		config.MetamorphDBUser(),
+		config.MetamorhpDBPassword())
+	if err != nil {
+		return nil, err
+	}
+	m.db = db
+	return db, nil
+}
+
+func MetamorphConn(ctx context.Context, user string) (*sql.Conn, error) {
+	db, err := mm.open()
+	if err != nil {
+		return nil, err
+	}
+	conn, err := db.Conn(ctx)
+	if err != nil {
+		return nil, err
+	}
+	if _, err := conn.ExecContext(ctx, `SELECT public.setrole_plan($1)`, user); err != nil {
+		conn.Close()
+		return nil, err
+	}
+	return conn, nil
+}
+
 const allRoles = `
 WITH RECURSIVE cte AS (
    SELECT oid FROM pg_roles WHERE rolname = current_user
--- a/pkg/auth/pool.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/auth/pool.go	Fri Aug 24 13:56:06 2018 +0200
@@ -2,7 +2,6 @@
 
 import (
 	"bytes"
-	"database/sql"
 	"log"
 	"time"
 
@@ -89,12 +88,6 @@
 }
 
 func (pcp *ConnectionPool) cleanDB() {
-	valid := time.Now().Add(-maxDBIdle)
-	for _, con := range pcp.conns {
-		if con.refCount <= 0 && con.last().Before(valid) {
-			con.close()
-		}
-	}
 }
 
 func (pcp *ConnectionPool) cleanToken() {
@@ -203,36 +196,11 @@
 	return r.newToken, r.err
 }
 
-func (pcp *ConnectionPool) trim(conn *Connection) {
-
-	conn.refCount--
-
-	for {
-		least := time.Now()
-		var count int
-		var oldest *Connection
-
-		for _, con := range pcp.conns {
-			if con.db != nil && con.refCount <= 0 {
-				if last := con.last(); last.Before(least) {
-					least = last
-					oldest = con
-				}
-				count++
-			}
-		}
-		if count <= maxOpen {
-			break
-		}
-		oldest.close()
-	}
-}
-
-func (pcp *ConnectionPool) Do(token string, fn func(*sql.DB) error) error {
+func (pcp *ConnectionPool) Do(token string) (*Session, error) {
 
 	type result struct {
-		con *Connection
-		err error
+		session *Session
+		err     error
 	}
 
 	res := make(chan result)
@@ -244,41 +212,18 @@
 			return
 		}
 		con.touch()
-		// store the session here. The ref counting for
-		// open db connections is irrelevant for persistence
-		// as they all come up closed when the system reboots.
 		pcp.store(token, con)
 
-		if con.db != nil {
-			con.refCount++
-			res <- result{con: con}
-			return
-		}
-
-		session := con.session
-		db, err := OpenDB(session.User, session.Password)
-		if err != nil {
-			res <- result{err: err}
-			return
-		}
-		con.db = db
-		con.refCount++
-		res <- result{con: con}
+		res <- result{session: con.session}
 	}
 
 	r := <-res
 
 	if r.err != nil {
-		return r.err
+		return nil, r.err
 	}
 
-	defer func() {
-		pcp.cmds <- func(pcp *ConnectionPool) {
-			pcp.trim(r.con)
-		}
-	}()
-
-	return fn(r.con.db)
+	return r.session, nil
 }
 
 func (pcp *ConnectionPool) Session(token string) *Session {
@@ -300,10 +245,6 @@
 	pcp.cmds <- func(pcp *ConnectionPool) {
 		for token, con := range pcp.conns {
 			if con.session.User == user {
-				if db := con.db; db != nil {
-					con.db = nil
-					db.Close()
-				}
 				delete(pcp.conns, token)
 				pcp.remove(token)
 			}
--- a/pkg/auth/session.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/auth/session.go	Fri Aug 24 13:56:06 2018 +0200
@@ -15,7 +15,6 @@
 type Session struct {
 	ExpiresAt int64  `json:"expires"`
 	User      string `json:"user"`
-	Password  string `json:"password"`
 	Roles     Roles  `json:"roles"`
 }
 
@@ -48,7 +47,6 @@
 	return &Session{
 		ExpiresAt: time.Now().Add(maxTokenValid).Unix(),
 		User:      user,
-		Password:  password,
 		Roles:     roles,
 	}
 }
@@ -57,7 +55,6 @@
 	wr := misc.BinWriter{w, nil}
 	wr.WriteBin(s.ExpiresAt)
 	wr.WriteString(s.User)
-	wr.WriteString(s.Password)
 	wr.WriteBin(uint32(len(s.Roles)))
 	for _, role := range s.Roles {
 		wr.WriteString(role)
@@ -71,7 +68,6 @@
 	rd := misc.BinReader{r, nil}
 	rd.ReadBin(&x.ExpiresAt)
 	rd.ReadString(&x.User)
-	rd.ReadString(&x.Password)
 	rd.ReadBin(&n)
 	x.Roles = make(Roles, n)
 	for i := uint32(0); n > 0 && i < n; i++ {
--- a/pkg/controllers/json.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/controllers/json.go	Fri Aug 24 13:56:06 2018 +0200
@@ -19,7 +19,8 @@
 
 type JSONHandler struct {
 	Input  func() interface{}
-	Handle func(interface{}, *http.Request, *sql.DB) (JSONResult, error)
+	Handle func(interface{}, *http.Request, *sql.Conn) (JSONResult, error)
+	NoConn bool
 }
 
 type JSONError struct {
@@ -46,11 +47,15 @@
 	var jr JSONResult
 	var err error
 
-	if token, ok := auth.GetToken(req); ok {
-		err = auth.ConnPool.Do(token, func(db *sql.DB) (err error) {
-			jr, err = j.Handle(input, req, db)
-			return err
-		})
+	if token, ok := auth.GetToken(req); ok && !j.NoConn {
+		var session *auth.Session
+		if session, err = auth.ConnPool.Do(token); err != nil {
+			var conn *sql.Conn
+			if conn, err = auth.MetamorphConn(req.Context(), session.User); err != nil {
+				defer conn.Close()
+				jr, err = j.Handle(input, req, conn)
+			}
+		}
 	} else {
 		jr, err = j.Handle(input, req, nil)
 	}
--- a/pkg/controllers/publish.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/controllers/publish.go	Fri Aug 24 13:56:06 2018 +0200
@@ -7,7 +7,7 @@
 	"gemma.intevation.de/gemma/pkg/models"
 )
 
-func published(_ interface{}, req *http.Request, _ *sql.DB) (jr JSONResult, err error) {
+func published(_ interface{}, req *http.Request, _ *sql.Conn) (jr JSONResult, err error) {
 	jr = JSONResult{
 		Result: struct {
 			Internal []models.IntEntry `json:"internal"`
--- a/pkg/controllers/pwreset.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/controllers/pwreset.go	Fri Aug 24 13:56:06 2018 +0200
@@ -165,7 +165,7 @@
 func passwordResetRequest(
 	input interface{},
 	req *http.Request,
-	_ *sql.DB,
+	_ *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := input.(*models.PWResetUser)
@@ -231,7 +231,7 @@
 func passwordReset(
 	_ interface{},
 	req *http.Request,
-	_ *sql.DB,
+	_ *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	hash := mux.Vars(req)["hash"]
--- a/pkg/controllers/routes.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/controllers/routes.go	Fri Aug 24 13:56:06 2018 +0200
@@ -92,6 +92,7 @@
 
 	api.Handle("/published", any(&JSONHandler{
 		Handle: published,
+		NoConn: true,
 	})).Methods(http.MethodGet)
 
 	// Token handling: Login/Logout.
--- a/pkg/controllers/user.go	Fri Aug 24 12:14:56 2018 +0200
+++ b/pkg/controllers/user.go	Fri Aug 24 13:56:06 2018 +0200
@@ -54,7 +54,7 @@
 
 func deleteUser(
 	_ interface{}, req *http.Request,
-	db *sql.DB,
+	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := mux.Vars(req)["user"]
@@ -71,7 +71,7 @@
 
 	var res sql.Result
 
-	if res, err = db.Exec(deleteUserSQL, user); err != nil {
+	if res, err = db.ExecContext(req.Context(), deleteUserSQL, user); err != nil {
 		return
 	}
 
@@ -91,8 +91,9 @@
 }
 
 func updateUser(
-	input interface{}, req *http.Request,
-	db *sql.DB,
+	input interface{},
+	req *http.Request,
+	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := models.UserName(mux.Vars(req)["user"])
@@ -106,7 +107,8 @@
 
 	if s, _ := auth.GetSession(req); s.Roles.Has("sys_admin") {
 		if newUser.Extent == nil {
-			res, err = db.Exec(
+			res, err = db.ExecContext(
+				req.Context(),
 				updateUserSQL,
 				user,
 				newUser.Role,
@@ -116,7 +118,8 @@
 				newUser.Email,
 			)
 		} else {
-			res, err = db.Exec(
+			res, err = db.ExecContext(
+				req.Context(),
 				updateUserExtentSQL,
 				user,
 				newUser.Role,
@@ -133,7 +136,8 @@
 			err = JSONError{http.StatusBadRequest, "extent is mandatory"}
 			return
 		}
-		res, err = db.Exec(
+		res, err = db.ExecContext(
+			req.Context(),
 			updateUserUnprivSQL,
 			user,
 			newUser.Password,
@@ -170,14 +174,16 @@
 }
 
 func createUser(
-	input interface{}, req *http.Request,
-	db *sql.DB,
+	input interface{},
+	req *http.Request,
+	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := input.(*models.User)
 
 	if user.Extent == nil {
-		_, err = db.Exec(
+		_, err = db.ExecContext(
+			req.Context(),
 			createUserSQL,
 			user.Role,
 			user.User,
@@ -186,7 +192,8 @@
 			user.Email,
 		)
 	} else {
-		_, err = db.Exec(
+		_, err = db.ExecContext(
+			req.Context(),
 			createUserExtentSQL,
 			user.Role,
 			user.User,
@@ -212,13 +219,14 @@
 }
 
 func listUsers(
-	_ interface{}, req *http.Request,
-	db *sql.DB,
+	_ interface{},
+	req *http.Request,
+	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	var rows *sql.Rows
 
-	rows, err = db.Query(listUsersSQL)
+	rows, err = db.QueryContext(req.Context(), listUsersSQL)
 	if err != nil {
 		return
 	}
@@ -250,8 +258,9 @@
 }
 
 func listUser(
-	_ interface{}, req *http.Request,
-	db *sql.DB,
+	_ interface{},
+	req *http.Request,
+	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := models.UserName(mux.Vars(req)["user"])
@@ -265,7 +274,7 @@
 		Extent: &models.BoundingBox{},
 	}
 
-	err = db.QueryRow(listUserSQL, user).Scan(
+	err = db.QueryRowContext(req.Context(), listUserSQL, user).Scan(
 		&result.Role,
 		&result.Country,
 		&result.Email,