diff pkg/controllers/json.go @ 4242:1458c9b0fdaa json-handler-middleware

Made the sql.Conn in function accessible via the context of the request.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Thu, 22 Aug 2019 10:18:13 +0200
parents 09f9ae3d0526
children d776110b4db0
line wrap: on
line diff
--- a/pkg/controllers/json.go	Thu Aug 22 09:20:38 2019 +0200
+++ b/pkg/controllers/json.go	Thu Aug 22 10:18:13 2019 +0200
@@ -14,6 +14,7 @@
 package controllers
 
 import (
+	"context"
 	"database/sql"
 	"encoding/json"
 	"fmt"
@@ -49,7 +50,7 @@
 	// in is the data structure returned by Input. Its nil if Input is nil.
 	// req is the incoming HTTP request.
 	// conn is the impersonated connection to the database.
-	Handle func(in interface{}, rep *http.Request, conn *sql.Conn) (JSONResult, error)
+	Handle func(in interface{}, rep *http.Request) (JSONResult, error)
 	// NoConn if set to true no database connection is established and
 	// the conn parameter of the Handle call is nil.
 	NoConn bool
@@ -74,6 +75,17 @@
 	return fmt.Sprintf("%d: %s", je.Code, je.Message)
 }
 
+type jsonHandlerType int
+
+const jsonDBKey jsonHandlerType = 0
+
+func JSONConn(req *http.Request) *sql.Conn {
+	if conn, ok := req.Context().Value(jsonDBKey).(*sql.Conn); ok {
+		return conn
+	}
+	return nil
+}
+
 // ServeHTTP makes the JSONHandler a middleware.
 func (j *JSONHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 
@@ -101,15 +113,18 @@
 
 	if token, ok := auth.GetToken(req); ok && !j.NoConn {
 		if session := auth.Sessions.Session(token); session != nil {
-			err = auth.RunAs(req.Context(), session.User, func(conn *sql.Conn) error {
-				jr, err = j.Handle(input, req, conn)
+			parent := req.Context()
+			err = auth.RunAs(parent, session.User, func(conn *sql.Conn) error {
+				ctx := context.WithValue(parent, jsonDBKey, conn)
+				r := req.WithContext(ctx)
+				jr, err = j.Handle(input, r)
 				return err
 			})
 		} else {
 			err = auth.ErrNoSuchToken
 		}
 	} else {
-		jr, err = j.Handle(input, req, nil)
+		jr, err = j.Handle(input, req)
 	}
 
 	if err != nil {