diff pkg/middleware/jsonhandler.go @ 4244:4394daeea96a json-handler-middleware

Moved JSONHandler into middleware package.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Thu, 22 Aug 2019 11:26:48 +0200
parents pkg/controllers/json.go@d776110b4db0
children f4ec3558460e
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pkg/middleware/jsonhandler.go	Thu Aug 22 11:26:48 2019 +0200
@@ -0,0 +1,208 @@
+// This is Free Software under GNU Affero General Public License v >= 3.0
+// without warranty, see README.md and license for details.
+//
+// SPDX-License-Identifier: AGPL-3.0-or-later
+// License-Filename: LICENSES/AGPL-3.0.txt
+//
+// Copyright (C) 2018 by via donau
+//   – Österreichische Wasserstraßen-Gesellschaft mbH
+// Software engineering by Intevation GmbH
+//
+// Author(s):
+//  * Sascha L. Teichmann <sascha.teichmann@intevation.de>
+
+package middleware
+
+import (
+	"context"
+	"database/sql"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+
+	"github.com/jackc/pgx"
+
+	"gemma.intevation.de/gemma/pkg/auth"
+)
+
+// JSONResult defines the return type of JSONHandler handler function.
+type JSONResult struct {
+	// Code is the HTTP status code to be set which defaults to http.StatusOK (200).
+	Code int
+	// Result is serialized to JSON.
+	// If the type is an io.Reader its copied through.
+	Result interface{}
+}
+
+// JSONDefaultLimit is default size limit in bytes of an accepted
+// input document.
+const JSONDefaultLimit = 2048
+
+// JSONHandler implements a middleware to ease the handing JSON input
+// streams and return JSON documents as output.
+type JSONHandler struct {
+	// Input (if not nil) is called to fill a data structure
+	// returned by this function.
+	Input func(*http.Request) interface{}
+	// Handle is called to handle the incoming HTTP request.
+	// in is the data structure returned by Input. Its nil if Input is nil.
+	Handle func(rep *http.Request) (JSONResult, error)
+	// NoConn if set to true no database connection is established and
+	// the conn parameter of the Handle call is nil.
+	NoConn bool
+	// Limit overides the default size of accepted input documents.
+	// Set to a negative value to allow an arbitrary size.
+	// Handle with care!
+	Limit int64
+}
+
+// JSONError is an error if returned by the JSONHandler.Handle function
+// which ends up encoded as a JSON document.
+type JSONError struct {
+	// Code is the HTTP status code of the result defaults
+	// to http.StatusInternalServerError if not set.
+	Code int
+	// The message of the error.
+	Message string
+}
+
+// Error implements the error interface.
+func (je JSONError) Error() string {
+	return fmt.Sprintf("%d: %s", je.Code, je.Message)
+}
+
+type jsonHandlerType int
+
+const (
+	jsonHandlerConnKey jsonHandlerType = iota
+	jsonHandlerInputKey
+)
+
+// JSONConn extracts the impersonated sql.Conn from the context of the request.
+func JSONConn(req *http.Request) *sql.Conn {
+	if conn, ok := req.Context().Value(jsonHandlerConnKey).(*sql.Conn); ok {
+		return conn
+	}
+	return nil
+}
+
+// JSONInput extracts the de-serialized input from the context of the request.
+func JSONInput(req *http.Request) interface{} {
+	return req.Context().Value(jsonHandlerInputKey)
+}
+
+// ServeHTTP makes the JSONHandler a middleware.
+func (j *JSONHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+
+	if j.Input != nil {
+		input := j.Input(req)
+		defer req.Body.Close()
+		var r io.Reader
+		switch {
+		case j.Limit == 0:
+			r = io.LimitReader(req.Body, JSONDefaultLimit)
+		case j.Limit > 0:
+			r = io.LimitReader(req.Body, j.Limit)
+		default:
+			r = req.Body
+		}
+		if err := json.NewDecoder(r).Decode(input); err != nil {
+			http.Error(rw, "error: "+err.Error(), http.StatusBadRequest)
+			return
+		}
+		parent := req.Context()
+		ctx := context.WithValue(parent, jsonHandlerInputKey, input)
+		req = req.WithContext(ctx)
+	}
+
+	var jr JSONResult
+	var err error
+
+	if token, ok := auth.GetToken(req); ok && !j.NoConn {
+		if session := auth.Sessions.Session(token); session != nil {
+			parent := req.Context()
+			err = auth.RunAs(parent, session.User, func(conn *sql.Conn) error {
+				ctx := context.WithValue(parent, jsonHandlerConnKey, conn)
+				req = req.WithContext(ctx)
+				jr, err = j.Handle(req)
+				return err
+			})
+		} else {
+			err = auth.ErrNoSuchToken
+		}
+	} else {
+		jr, err = j.Handle(req)
+	}
+
+	if err != nil {
+		log.Printf("error: %v\n", err)
+		switch e := err.(type) {
+		case pgx.PgError:
+			var res = struct {
+				Result  string `json:"result"`
+				Code    string `json:"code"`
+				Message string `json:"message"`
+			}{
+				Result:  "failure",
+				Code:    e.Code,
+				Message: e.Message,
+			}
+			rw.Header().Set("Content-Type", "application/json")
+			rw.WriteHeader(http.StatusInternalServerError)
+			if err := json.NewEncoder(rw).Encode(&res); err != nil {
+				log.Printf("error: %v\n", err)
+			}
+		case JSONError:
+			rw.Header().Set("Content-Type", "application/json")
+			if e.Code == 0 {
+				e.Code = http.StatusInternalServerError
+			}
+			rw.WriteHeader(e.Code)
+			var res = struct {
+				Message string `json:"message"`
+			}{
+				Message: e.Message,
+			}
+			if err := json.NewEncoder(rw).Encode(&res); err != nil {
+				log.Printf("error: %v\n", err)
+			}
+		default:
+			http.Error(rw,
+				"error: "+err.Error(),
+				http.StatusInternalServerError)
+		}
+		return
+	}
+
+	if jr.Code == 0 {
+		jr.Code = http.StatusOK
+	}
+
+	if jr.Code != http.StatusNoContent {
+		rw.Header().Set("Content-Type", "application/json")
+	}
+	rw.WriteHeader(jr.Code)
+	if jr.Code != http.StatusNoContent {
+		var err error
+		if r, ok := jr.Result.(io.Reader); ok {
+			_, err = io.Copy(rw, r)
+		} else {
+			err = json.NewEncoder(rw).Encode(jr.Result)
+		}
+		if err != nil {
+			log.Printf("error: %v\n", err)
+		}
+	}
+}
+
+// SendJSON sends data JSON encoded to the response writer
+// with a given HTTP status code.
+func SendJSON(rw http.ResponseWriter, code int, data interface{}) {
+	rw.Header().Set("Content-Type", "application/json")
+	rw.WriteHeader(code)
+	if err := json.NewEncoder(rw).Encode(data); err != nil {
+		log.Printf("error: %v\n", err)
+	}
+}