changeset 4242:1458c9b0fdaa json-handler-middleware

Made the sql.Conn in function accessible via the context of the request.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Thu, 22 Aug 2019 10:18:13 +0200
parents a4f76e170290
children d776110b4db0
files pkg/controllers/cross.go pkg/controllers/diff.go pkg/controllers/gauges.go pkg/controllers/importconfig.go pkg/controllers/importqueue.go pkg/controllers/json.go pkg/controllers/manualimports.go pkg/controllers/printtemplates.go pkg/controllers/publish.go pkg/controllers/pwreset.go pkg/controllers/search.go pkg/controllers/srimports.go pkg/controllers/surveys.go pkg/controllers/system.go pkg/controllers/user.go
diffstat 15 files changed, 61 insertions(+), 66 deletions(-) [+]
line wrap: on
line diff
--- a/pkg/controllers/cross.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/cross.go	Thu Aug 22 10:18:13 2019 +0200
@@ -67,13 +67,13 @@
 func crossSection(
 	input interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	csi := input.(*models.CrossSectionInput)
 
 	start := time.Now()
 	ctx := req.Context()
+	conn := JSONConn(req)
 
 	tree, err := octree.FromCache(
 		ctx, conn,
--- a/pkg/controllers/diff.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/diff.go	Thu Aug 22 10:18:13 2019 +0200
@@ -87,7 +87,6 @@
 func diffCalculation(
 	input interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	begin := time.Now()
@@ -97,6 +96,8 @@
 
 	ctx := req.Context()
 
+	conn := JSONConn(req)
+
 	var id int64
 	err = conn.QueryRowContext(
 		ctx,
--- a/pkg/controllers/gauges.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/gauges.go	Thu Aug 22 10:18:13 2019 +0200
@@ -577,7 +577,6 @@
 func nashSutcliffe(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 	gauge := mux.Vars(req)["gauge"]
 
@@ -604,7 +603,7 @@
 
 	var values []observedPredictedValues
 
-	if values, err = loadNashSutcliffeData(ctx, conn, isrs, when); err != nil {
+	if values, err = loadNashSutcliffeData(ctx, JSONConn(req), isrs, when); err != nil {
 		return
 	}
 
--- a/pkg/controllers/importconfig.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/importconfig.go	Thu Aug 22 10:18:13 2019 +0200
@@ -31,7 +31,6 @@
 func runImportConfig(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	id, _ := strconv.ParseInt(mux.Vars(req)["id"], 10, 64)
@@ -39,7 +38,7 @@
 	ctx := req.Context()
 
 	var jobID int64
-	if jobID, err = imports.RunConfiguredImportContext(ctx, conn, id); err != nil {
+	if jobID, err = imports.RunConfiguredImportContext(ctx, JSONConn(req), id); err != nil {
 		return
 	}
 
@@ -59,7 +58,6 @@
 func modifyImportConfig(
 	input interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	ctx := req.Context()
@@ -68,6 +66,8 @@
 
 	id, _ := strconv.ParseInt(mux.Vars(req)["id"], 10, 64)
 
+	conn := JSONConn(req)
+
 	var pc *imports.PersistentConfig
 	pc, err = imports.LoadPersistentConfigContext(ctx, conn, id)
 	switch {
@@ -152,7 +152,6 @@
 func infoImportConfig(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	ctx := req.Context()
@@ -161,7 +160,7 @@
 
 	var cfg *imports.PersistentConfig
 
-	cfg, err = imports.LoadPersistentConfigContext(ctx, conn, id)
+	cfg, err = imports.LoadPersistentConfigContext(ctx, JSONConn(req), id)
 	switch {
 	case err != nil:
 		return
@@ -211,7 +210,6 @@
 func deleteImportConfig(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	ctx := req.Context()
@@ -219,7 +217,7 @@
 	id, _ := strconv.ParseInt(mux.Vars(req)["id"], 10, 64)
 
 	var tx *sql.Tx
-	if tx, err = conn.BeginTx(ctx, nil); err != nil {
+	if tx, err = JSONConn(req).BeginTx(ctx, nil); err != nil {
 		return
 	}
 	defer tx.Rollback()
@@ -262,7 +260,6 @@
 func addImportConfig(
 	input interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	cfg := input.(*imports.ImportConfigIn)
@@ -294,7 +291,7 @@
 	ctx := req.Context()
 
 	var tx *sql.Tx
-	if tx, err = conn.BeginTx(ctx, nil); err != nil {
+	if tx, err = JSONConn(req).BeginTx(ctx, nil); err != nil {
 		return
 	}
 	defer tx.Rollback()
@@ -332,14 +329,13 @@
 func listImportConfigs(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	ctx := req.Context()
 	configs := []*imports.ImportConfigOut{}
 
 	if err = imports.ListAllPersistentConfigurationsContext(
-		ctx, conn,
+		ctx, JSONConn(req),
 		func(config *imports.ImportConfigOut) error {
 			configs = append(configs, config)
 			return nil
--- a/pkg/controllers/importqueue.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/importqueue.go	Thu Aug 22 10:18:13 2019 +0200
@@ -231,7 +231,6 @@
 func listImports(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	var list, before, after *filledStmt
@@ -242,6 +241,8 @@
 
 	ctx := req.Context()
 
+	conn := JSONConn(req)
+
 	// Fast path for counting
 
 	switch count := strings.ToLower(req.FormValue("count")); count {
@@ -324,13 +325,14 @@
 func importLogs(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	ctx := req.Context()
 
 	id, _ := strconv.ParseInt(mux.Vars(req)["id"], 10, 64)
 
+	conn := JSONConn(req)
+
 	// Check if he have such a import job first.
 	var summary sql.NullString
 	var enqueued time.Time
@@ -399,14 +401,13 @@
 func deleteImport(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	ctx := req.Context()
 	id, _ := strconv.ParseInt(mux.Vars(req)["id"], 10, 64)
 
 	var tx *sql.Tx
-	tx, err = conn.BeginTx(ctx, nil)
+	tx, err = JSONConn(req).BeginTx(ctx, nil)
 	if err != nil {
 		return
 	}
@@ -467,7 +468,6 @@
 func reviewImports(
 	reviews interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (JSONResult, error) {
 
 	rs := *reviews.(*[]models.Review)
@@ -480,6 +480,8 @@
 
 	results := make([]reviewResult, len(rs))
 
+	conn := JSONConn(req)
+
 	for i := range rs {
 		rev := &rs[i]
 		msg, err := decideImport(req, conn, rev.ID, string(rev.State))
@@ -500,7 +502,6 @@
 func reviewImport(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	vars := mux.Vars(req)
@@ -508,7 +509,7 @@
 	state := vars["state"]
 
 	var msg string
-	if msg, err = decideImport(req, conn, id, state); err != nil {
+	if msg, err = decideImport(req, JSONConn(req), id, state); err != nil {
 		return
 	}
 
--- a/pkg/controllers/json.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/json.go	Thu Aug 22 10:18:13 2019 +0200
@@ -14,6 +14,7 @@
 package controllers
 
 import (
+	"context"
 	"database/sql"
 	"encoding/json"
 	"fmt"
@@ -49,7 +50,7 @@
 	// in is the data structure returned by Input. Its nil if Input is nil.
 	// req is the incoming HTTP request.
 	// conn is the impersonated connection to the database.
-	Handle func(in interface{}, rep *http.Request, conn *sql.Conn) (JSONResult, error)
+	Handle func(in interface{}, rep *http.Request) (JSONResult, error)
 	// NoConn if set to true no database connection is established and
 	// the conn parameter of the Handle call is nil.
 	NoConn bool
@@ -74,6 +75,17 @@
 	return fmt.Sprintf("%d: %s", je.Code, je.Message)
 }
 
+type jsonHandlerType int
+
+const jsonDBKey jsonHandlerType = 0
+
+func JSONConn(req *http.Request) *sql.Conn {
+	if conn, ok := req.Context().Value(jsonDBKey).(*sql.Conn); ok {
+		return conn
+	}
+	return nil
+}
+
 // ServeHTTP makes the JSONHandler a middleware.
 func (j *JSONHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 
@@ -101,15 +113,18 @@
 
 	if token, ok := auth.GetToken(req); ok && !j.NoConn {
 		if session := auth.Sessions.Session(token); session != nil {
-			err = auth.RunAs(req.Context(), session.User, func(conn *sql.Conn) error {
-				jr, err = j.Handle(input, req, conn)
+			parent := req.Context()
+			err = auth.RunAs(parent, session.User, func(conn *sql.Conn) error {
+				ctx := context.WithValue(parent, jsonDBKey, conn)
+				r := req.WithContext(ctx)
+				jr, err = j.Handle(input, r)
 				return err
 			})
 		} else {
 			err = auth.ErrNoSuchToken
 		}
 	} else {
-		jr, err = j.Handle(input, req, nil)
+		jr, err = j.Handle(input, req)
 	}
 
 	if err != nil {
--- a/pkg/controllers/manualimports.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/manualimports.go	Thu Aug 22 10:18:13 2019 +0200
@@ -15,16 +15,16 @@
 package controllers
 
 import (
-	"database/sql"
 	"log"
 	"net/http"
 	"time"
 
+	"github.com/gorilla/mux"
+
 	"gemma.intevation.de/gemma/pkg/auth"
 	"gemma.intevation.de/gemma/pkg/common"
 	"gemma.intevation.de/gemma/pkg/imports"
 	"gemma.intevation.de/gemma/pkg/models"
-	"github.com/gorilla/mux"
 )
 
 func importModel(req *http.Request) interface{} {
@@ -40,7 +40,6 @@
 func manualImport(
 	input interface{},
 	req *http.Request,
-	_ *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	kind := imports.JobKind(mux.Vars(req)["kind"])
--- a/pkg/controllers/printtemplates.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/printtemplates.go	Thu Aug 22 10:18:13 2019 +0200
@@ -73,7 +73,6 @@
 func listPrintTemplates(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	ts := mux.Vars(req)["type"]
@@ -94,7 +93,7 @@
 	stmt.WriteString(" ORDER BY date_info DESC")
 
 	var rows *sql.Rows
-	if rows, err = conn.QueryContext(req.Context(), stmt.String(), args...); err != nil {
+	if rows, err = JSONConn(req).QueryContext(req.Context(), stmt.String(), args...); err != nil {
 		return
 	}
 	defer rows.Close()
@@ -134,7 +133,6 @@
 func fetchPrintTemplate(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	vars := mux.Vars(req)
@@ -142,7 +140,7 @@
 
 	ctx := req.Context()
 	var data pgtype.Bytea
-	err = conn.QueryRowContext(ctx, selectPrintTemplateSQL, name, typ).Scan(&data)
+	err = JSONConn(req).QueryRowContext(ctx, selectPrintTemplateSQL, name, typ).Scan(&data)
 
 	switch {
 	case err == sql.ErrNoRows:
@@ -167,7 +165,6 @@
 func createPrintTemplate(
 	input interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	vars := mux.Vars(req)
@@ -192,7 +189,7 @@
 
 	ctx := req.Context()
 	var tx *sql.Tx
-	if tx, err = conn.BeginTx(ctx, nil); err != nil {
+	if tx, err = JSONConn(req).BeginTx(ctx, nil); err != nil {
 		return
 	}
 	defer tx.Rollback()
@@ -233,7 +230,6 @@
 func deletePrintTemplate(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	vars := mux.Vars(req)
@@ -241,7 +237,7 @@
 
 	ctx := req.Context()
 	var tx *sql.Tx
-	if tx, err = conn.BeginTx(ctx, nil); err != nil {
+	if tx, err = JSONConn(req).BeginTx(ctx, nil); err != nil {
 		return
 	}
 	defer tx.Rollback()
@@ -286,7 +282,6 @@
 func updatePrintTemplate(
 	input interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	vars := mux.Vars(req)
@@ -311,7 +306,7 @@
 
 	ctx := req.Context()
 	var tx *sql.Tx
-	if tx, err = conn.BeginTx(ctx, nil); err != nil {
+	if tx, err = JSONConn(req).BeginTx(ctx, nil); err != nil {
 		return
 	}
 	defer tx.Rollback()
--- a/pkg/controllers/publish.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/publish.go	Thu Aug 22 10:18:13 2019 +0200
@@ -14,13 +14,12 @@
 package controllers
 
 import (
-	"database/sql"
 	"net/http"
 
 	"gemma.intevation.de/gemma/pkg/models"
 )
 
-func published(_ interface{}, req *http.Request, _ *sql.Conn) (jr JSONResult, err error) {
+func published(_ interface{}, req *http.Request) (jr JSONResult, err error) {
 	jr = JSONResult{
 		Result: struct {
 			Internal []models.IntEntry `json:"internal"`
--- a/pkg/controllers/pwreset.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/pwreset.go	Thu Aug 22 10:18:13 2019 +0200
@@ -238,7 +238,6 @@
 func passwordResetRequest(
 	input interface{},
 	req *http.Request,
-	_ *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	// We do the checks and the emailing in background
--- a/pkg/controllers/search.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/search.go	Thu Aug 22 10:18:13 2019 +0200
@@ -41,7 +41,6 @@
 func searchFeature(
 	input interface{},
 	req *http.Request,
-	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	s := input.(*models.SearchRequest)
@@ -53,7 +52,7 @@
 	}
 
 	var result string
-	err = db.QueryRowContext(
+	err = JSONConn(req).QueryRowContext(
 		req.Context(),
 		searchMostSQL,
 		s.SearchString,
@@ -70,11 +69,10 @@
 func listBottlenecks(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	var result string
-	err = conn.QueryRowContext(
+	err = JSONConn(req).QueryRowContext(
 		req.Context(), listBottlenecksSQL).Scan(&result)
 
 	switch {
--- a/pkg/controllers/srimports.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/srimports.go	Thu Aug 22 10:18:13 2019 +0200
@@ -15,7 +15,6 @@
 
 import (
 	"archive/zip"
-	"database/sql"
 	"encoding/hex"
 	"fmt"
 	"log"
@@ -193,7 +192,6 @@
 func uploadSoundingResult(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	var dir string
@@ -248,7 +246,7 @@
 				messages = append(messages,
 					fmt.Sprintf("'meta.json' found but invalid: %v", err))
 			} else {
-				errs := meta.Validate(req.Context(), conn)
+				errs := meta.Validate(req.Context(), JSONConn(req))
 				for _, err := range errs {
 					messages = append(messages,
 						fmt.Sprintf("invalid 'meta.json': %v", err))
--- a/pkg/controllers/surveys.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/surveys.go	Thu Aug 22 10:18:13 2019 +0200
@@ -44,14 +44,13 @@
 func listSurveys(
 	_ interface{},
 	req *http.Request,
-	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	bottleneckName := mux.Vars(req)["bottleneck"]
 
 	var rows *sql.Rows
 
-	rows, err = db.QueryContext(req.Context(), listSurveysSQL, bottleneckName)
+	rows, err = JSONConn(req).QueryContext(req.Context(), listSurveysSQL, bottleneckName)
 	if err != nil {
 		return
 	}
--- a/pkg/controllers/system.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/system.go	Thu Aug 22 10:18:13 2019 +0200
@@ -59,7 +59,6 @@
 
 func showSystemLog(
 	_ interface{}, req *http.Request,
-	_ *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	serviceName := mux.Vars(req)["service"]
@@ -102,7 +101,6 @@
 
 func getSystemConfig(
 	_ interface{}, req *http.Request,
-	_ *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	cfg := config.PublishedConfig()
@@ -123,11 +121,10 @@
 func getSystemSettings(
 	_ interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	var rows *sql.Rows
-	if rows, err = conn.QueryContext(req.Context(), getSettingsSQL); err != nil {
+	if rows, err = JSONConn(req).QueryContext(req.Context(), getSettingsSQL); err != nil {
 		return
 	}
 	defer rows.Close()
@@ -311,14 +308,13 @@
 func setSystemSettings(
 	input interface{},
 	req *http.Request,
-	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	settings := input.(*map[string]string)
 
 	ctx := req.Context()
 	var tx *sql.Tx
-	if tx, err = conn.BeginTx(ctx, nil); err != nil {
+	if tx, err = JSONConn(req).BeginTx(ctx, nil); err != nil {
 		return
 	}
 	defer tx.Rollback()
--- a/pkg/controllers/user.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/user.go	Thu Aug 22 10:18:13 2019 +0200
@@ -99,7 +99,6 @@
 
 func deleteUser(
 	_ interface{}, req *http.Request,
-	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := mux.Vars(req)["user"]
@@ -116,6 +115,8 @@
 
 	ctx := req.Context()
 
+	db := JSONConn(req)
+
 	// Remove scheduled tasks.
 	ids, err2 := scheduler.ScheduledUserIDs(ctx, db, user)
 	if err2 == nil {
@@ -150,7 +151,6 @@
 func updateUser(
 	input interface{},
 	req *http.Request,
-	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := models.UserName(mux.Vars(req)["user"])
@@ -162,6 +162,8 @@
 	newUser := input.(*models.User)
 	var res sql.Result
 
+	db := JSONConn(req)
+
 	if s, _ := auth.GetSession(req); s.Roles.Has("sys_admin") {
 		if newUser.Extent == nil {
 			res, err = db.ExecContext(
@@ -233,11 +235,12 @@
 func createUser(
 	input interface{},
 	req *http.Request,
-	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := input.(*models.User)
 
+	db := JSONConn(req)
+
 	if user.Extent == nil {
 		_, err = db.ExecContext(
 			req.Context(),
@@ -280,12 +283,11 @@
 func listUsers(
 	_ interface{},
 	req *http.Request,
-	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	var rows *sql.Rows
 
-	rows, err = db.QueryContext(req.Context(), listUsersSQL)
+	rows, err = JSONConn(req).QueryContext(req.Context(), listUsersSQL)
 	if err != nil {
 		return
 	}
@@ -319,7 +321,6 @@
 func listUser(
 	_ interface{},
 	req *http.Request,
-	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := models.UserName(mux.Vars(req)["user"])
@@ -333,7 +334,7 @@
 		Extent: &models.BoundingBox{},
 	}
 
-	err = db.QueryRowContext(req.Context(), listUserSQL, user).Scan(
+	err = JSONConn(req).QueryRowContext(req.Context(), listUserSQL, user).Scan(
 		&result.Role,
 		&result.Country,
 		&result.Email,
@@ -359,7 +360,6 @@
 func sendTestMail(
 	_ interface{},
 	req *http.Request,
-	db *sql.Conn,
 ) (jr JSONResult, err error) {
 
 	user := models.UserName(mux.Vars(req)["user"])
@@ -373,7 +373,7 @@
 		Extent: &models.BoundingBox{},
 	}
 
-	err = db.QueryRowContext(req.Context(), listUserSQL, user).Scan(
+	err = JSONConn(req).QueryRowContext(req.Context(), listUserSQL, user).Scan(
 		&userData.Role,
 		&userData.Country,
 		&userData.Email,