view pkg/controllers/proxy.go @ 5520:05db984d3db1

Improve performance of bottleneck area calculation Avoid buffer calculations by replacing them with simple distance comparisons and calculate the boundary of the result geometry only once per iteration. In some edge cases with very large numbers of iterations, this reduced the runtime of a bottleneck import by a factor of more than twenty.
author Tom Gottfried <tom@intevation.de>
date Thu, 21 Oct 2021 19:50:39 +0200
parents 5f47eeea988d
children 31973f6f5cca
line wrap: on
line source

// This is Free Software under GNU Affero General Public License v >= 3.0
// without warranty, see README.md and license for details.
//
// SPDX-License-Identifier: AGPL-3.0-or-later
// License-Filename: LICENSES/AGPL-3.0.txt
//
// Copyright (C) 2018 by via donau
//   – Österreichische Wasserstraßen-Gesellschaft mbH
// Software engineering by Intevation GmbH
//
// Author(s):
//  * Sascha L. Teichmann <sascha.teichmann@intevation.de>

package controllers

import (
	"compress/flate"
	"compress/gzip"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/xml"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"regexp"
	"strings"

	"github.com/gorilla/mux"
	"golang.org/x/net/html/charset"

	"gemma.intevation.de/gemma/pkg/config"
	"gemma.intevation.de/gemma/pkg/log"
	"gemma.intevation.de/gemma/pkg/middleware"
)

// proxyBlackList is the set of URLs that are passed through verbatim
// instead of being replaced by a signed proxy URL. The entries are
// identifiers of XML namespaces and schemas, not locations the client
// is expected to fetch through the proxy.
var proxyBlackList = map[string]struct{}{
	"http://www.w3.org/2001/XMLSchema-instance": {},
	"http://www.w3.org/1999/xlink":              {},
	"http://www.w3.org/2001/XMLSchema":          {},
	"http://www.w3.org/XML/1998/namespace":      {},
	"http://www.opengis.net/wfs/2.0":            {},
	"http://www.opengis.net/ows/1.1":            {},
	"http://www.opengis.net/gml/3.2":            {},
	"http://www.opengis.net/fes/2.0":            {},
	"http://schemas.opengis.net/gml":            {},
	"http://www.opengis.net/wfs":                {},
}

// proxyDirector builds a director function that redirects an incoming
// request to its upstream target.
//
// The target is determined from the route variables of the request:
// if an "entry" variable is present it is resolved with lookup,
// otherwise the base64 encoded "url" variable is used after its HMAC
// has been checked against the base64 encoded "hash" variable.
// The raw query of the incoming request is appended to the target.
//
// The returned function panics with middleware.ErrNotFound if lookup
// does not know the entry and with http.ErrAbortHandler if decoding,
// HMAC verification or URL parsing fails.
func proxyDirector(lookup func(string) (string, bool)) func(*http.Request) {

	// fail logs the reason and aborts the handling of the request.
	fail := func(format string, args ...interface{}) {
		log.Errorf(format, args...)
		panic(http.ErrAbortHandler)
	}

	return func(req *http.Request) {

		routeVars := mux.Vars(req)

		var target string

		entry, named := routeVars["entry"]

		switch {
		case named:
			// Pre-configured upstream addressed by name.
			resolved, known := lookup(entry)
			if !known {
				log.Warnf("cannot find entry '%s'\n", entry)
				panic(middleware.ErrNotFound)
			}
			target = resolved

		default:
			// Signed upstream URL: only accept it if the hash matches.
			wantMAC, err := base64.URLEncoding.DecodeString(routeVars["hash"])
			if err != nil {
				fail("Cannot base64 decode hash: %v\n", err)
			}
			rawURL, err := base64.URLEncoding.DecodeString(routeVars["url"])
			if err != nil {
				fail("Cannot base64 decode url: %v\n", err)
			}

			signer := hmac.New(sha256.New, config.ProxyKey())
			signer.Write(rawURL)

			target = string(rawURL)

			if gotMAC := signer.Sum(nil); !hmac.Equal(gotMAC, wantMAC) {
				fail("HMAC of URL %s failed.\n", target)
			}
		}

		upstream, err := url.Parse(target + "?" + req.URL.RawQuery)
		if err != nil {
			fail("Invalid url: %v\n", err)
		}

		req.URL = upstream
		req.Host = upstream.Host
	}
}

// nopCloser turns an io.Writer into an io.WriteCloser
// whose Close method does nothing.
type nopCloser struct {
	io.Writer
}

// Close implements io.Closer. It never fails and has no effect.
func (nopCloser) Close() error { return nil }

// encoding inspects the Content-Encoding header in h and returns a
// pair of constructors matching it: the first wraps a reader so that
// the decoded content can be read from it, the second wraps a writer
// so that content written to it is encoded again.
//
// "gzip" and "deflate" are recognized regardless of letter case as
// HTTP content codings are case-insensitive. Everything else is
// treated as not encoded: reader and writer are handed back as they
// are if they already implement Close or wrapped with a no-op Close
// otherwise.
func encoding(h http.Header) (
	func(io.Reader) (io.ReadCloser, error),
	func(io.Writer) (io.WriteCloser, error),
) {
	// Normalize the case first: servers may send e.g. "GZIP".
	switch enc := strings.ToLower(h.Get("Content-Encoding")); {
	case strings.Contains(enc, "gzip"):
		//log.Debugf("gzip compression")
		return func(r io.Reader) (io.ReadCloser, error) {
				return gzip.NewReader(r)
			},
			func(w io.Writer) (io.WriteCloser, error) {
				return gzip.NewWriter(w), nil
			}
	case strings.Contains(enc, "deflate"):
		//log.Debugf("deflate compression")
		// NOTE(review): this reads and writes raw DEFLATE streams.
		// HTTP specifies "deflate" as zlib wrapped data -- confirm
		// that raw streams are what the upstream servers send.
		return func(r io.Reader) (io.ReadCloser, error) {
				return flate.NewReader(r), nil
			},
			func(w io.Writer) (io.WriteCloser, error) {
				return flate.NewWriter(w, flate.DefaultCompression)
			}
	default:
		//log.Debugf("no content compression")
		return func(r io.Reader) (io.ReadCloser, error) {
				if r2, ok := r.(io.ReadCloser); ok {
					return r2, nil
				}
				return ioutil.NopCloser(r), nil
			},
			func(w io.Writer) (io.WriteCloser, error) {
				if w2, ok := w.(io.WriteCloser); ok {
					return w2, nil
				}
				return nopCloser{w}, nil
			}
	}
}

// proxyModifyResponse builds a response modifier which rewrites the
// URLs found in XML responses of the upstream server.
//
// Every response gets the header "X-Content-Type-Options: nosniff".
// Responses which are not XML according to isXML are passed through
// untouched. For XML responses the body is replaced by the reading
// end of a pipe which is fed by a goroutine running rewrite on the
// decoded upstream body. suffix is handed to rewrite unchanged.
//
// The returned function only fails if the decoder or encoder for the
// content encoding of the response cannot be set up.
func proxyModifyResponse(suffix string) func(*http.Response) error {

	return func(resp *http.Response) error {

		resp.Header.Set("X-Content-Type-Options", "nosniff")

		if !isXML(resp.Header) {
			return nil
		}

		reader, writer := encoding(resp.Header)

		upstream := resp.Body

		r, err := reader(upstream)
		if err != nil {
			return err
		}

		pr, pw := io.Pipe()

		w, err := writer(pw)
		if err != nil {
			return err
		}

		// The rewritten body differs in size from the upstream body,
		// so the length announced by the upstream server must not be
		// forwarded to the client.
		resp.Header.Del("Content-Length")
		resp.ContentLength = -1

		go func() {
			//start := time.Now()

			// The upstream body is closed last, after the pipe
			// has been shut down.
			defer upstream.Close()

			if err := rewrite(suffix, w, r); err != nil {
				log.Errorf("rewrite failed: %v\n", err)
				// Hand the error to the reading side so that a
				// broken rewrite is not mistaken for a complete
				// document.
				pw.CloseWithError(err)
				return
			}

			// Closing w flushes pending compressed data into the pipe.
			if err := w.Close(); err != nil {
				log.Errorf("rewrite failed: %v\n", err)
				pw.CloseWithError(err)
				return
			}

			// w may be pw itself; closing the pipe twice is harmless.
			pw.Close()
			//log.Debugf("rewrite took %s\n", time.Since(start))
		}()

		resp.Body = pr

		return nil
	}
}

// xmlContentTypes lists the lower case media types
// which are treated as XML documents.
var xmlContentTypes = []string{
	"application/xml",
	"text/xml",
	"application/gml+xml",
	"application/vnd.ogc.wms_xml",
	"application/vnd.ogc.se_xml",
}

// isXML reports whether at least one of the "Content-Type" values
// of h contains one of the media types in xmlContentTypes.
// The header values are compared case-insensitively.
func isXML(h http.Header) bool {
	values := h["Content-Type"]
	for i := range values {
		lowered := strings.ToLower(values[i])
		for j := range xmlContentTypes {
			if strings.Contains(lowered, xmlContentTypes[j]) {
				return true
			}
		}
	}
	return false
}

// replaceRe matches http and https URLs up to, but not including,
// the first whitespace, question mark or quote character.
var replaceRe = regexp.MustCompile(`\b(https?://[^\s\?'"]*)`)

// replace substitutes every URL in s matched by replaceRe with a
// proxy URL. URLs listed in proxyBlackList are kept as they are.
// A proxy URL consists of the configured proxy prefix followed by
// suffix, the base64 encoded HMAC-SHA256 of the original URL, a slash
// and the base64 encoded original URL.
func replace(suffix, s string) string {

	key := config.ProxyKey()
	prefix := config.ProxyPrefix() + suffix

	// sign converts a single matched URL into its proxy URL.
	sign := func(link string) string {
		if _, keep := proxyBlackList[link]; keep {
			return link
		}

		raw := []byte(link)

		signer := hmac.New(sha256.New, key)
		signer.Write(raw)

		return prefix +
			base64.URLEncoding.EncodeToString(signer.Sum(nil)) +
			"/" +
			base64.URLEncoding.EncodeToString(raw)
	}

	return replaceRe.ReplaceAllStringFunc(s, sign)
}

// rewrite copies the XML document read from r token by token to w.
// On its way attribute values, character data and directives are
// passed through replace with the given suffix. Element and attribute
// names are converted back to their prefixed form with the help of
// the namespace declarations seen so far.
// It returns the first error of the decoder or the encoder, or the
// result of the final flush of the encoder.
func rewrite(suffix string, w io.Writer, r io.Reader) error {

	decoder := xml.NewDecoder(r)
	// Let the decoder handle documents which are not encoded in UTF-8.
	decoder.CharsetReader = charset.NewReaderLabel

	encoder := xml.NewEncoder(w)

	// n is the stack of namespace scopes, one frame per open element.
	var n nsdef

tokens:
	for {
		tok, err := decoder.Token()
		switch {
		case tok == nil && err == io.EOF:
			// Regular end of the document.
			break tokens
		case err != nil:
			return err
		}

		switch t := tok.(type) {
		case xml.StartElement:
			// Work on a copy as the attributes are modified below.
			t = t.Copy()

			// Evaluated before the frame of this element is pushed,
			// so only the enclosing scopes are taken into account.
			// NOTE(review): a default namespace declared on this very
			// element is therefore not considered for its own name --
			// confirm that this is intended.
			isDef := n.isDef(t.Name.Space)
			n = n.push()

			// First pass: rewrite the attribute values and record a
			// default namespace declaration. The rewritten value is
			// what gets recorded.
			for i := range t.Attr {
				t.Attr[i].Value = replace(suffix, t.Attr[i].Value)
				n.checkDef(&t.Attr[i])
			}

			// Second pass: register prefix declarations and convert
			// the attribute names to their prefixed form.
			for i := range t.Attr {
				n.adjust(&t.Attr[i])
			}

			switch {
			case isDef:
				// Element lives in the default namespace: no prefix.
				t.Name.Space = ""
			default:
				// Use the prefix bound to the namespace if there is one.
				if s := n.lookup(t.Name.Space); s != "" {
					t.Name.Space = ""
					t.Name.Local = s + ":" + t.Name.Local
				}
			}
			tok = t

		case xml.CharData:
			tok = xml.CharData(replace(suffix, string(t)))

		case xml.Directive:
			tok = xml.Directive(replace(suffix, string(t)))

		case xml.EndElement:
			// Look up the prefix while the frame of the element is
			// still on the stack: it may have been declared there.
			s := n.lookup(t.Name.Space)

			n = n.pop()

			// Like for the start element the default namespace is
			// checked against the enclosing scopes only.
			if n.isDef(t.Name.Space) {
				t.Name.Space = ""
			} else if s != "" {
				t.Name.Space = ""
				t.Name.Local = s + ":" + t.Name.Local
			}
			tok = t
		}

		// All other kinds of tokens are written unmodified.
		if err := encoder.EncodeToken(tok); err != nil {
			return err
		}
	}

	return encoder.Flush()
}

// nsframe holds the namespace declarations made by a single element.
type nsframe struct {
	// def is the default namespace declared by the element,
	// empty if it does not declare one.
	def string
	// ns maps the namespaces declared by the element to their prefixes.
	ns map[string]string
}

// nsdef is a stack of namespace frames. The last frame
// belongs to the innermost element.
type nsdef []nsframe

// setDef stores def as the default namespace of the innermost frame.
// It does nothing if the stack is empty.
func (n nsdef) setDef(def string) {
	if len(n) == 0 {
		return
	}
	n[len(n)-1].def = def
}

// isDef reports whether s is the default namespace currently in
// effect, which is the one of the innermost frame having a non-empty
// default namespace.
func (n nsdef) isDef(s string) bool {
	for depth := len(n); depth > 0; depth-- {
		if current := n[depth-1].def; current != "" {
			return current == s
		}
	}
	return false
}

// define binds the namespace ns to the prefix s in the innermost
// frame. It does nothing if the stack is empty.
func (n nsdef) define(ns, s string) {
	if len(n) == 0 {
		return
	}
	n[len(n)-1].ns[ns] = s
}

// lookup returns the prefix bound to the namespace ns by the
// innermost frame which has a non-empty binding for it.
// It returns the empty string if there is none.
func (n nsdef) lookup(ns string) string {
	for depth := len(n); depth > 0; depth-- {
		if prefix := n[depth-1].ns[ns]; prefix != "" {
			return prefix
		}
	}
	return ""
}

// checkDef records the value of at as the default namespace
// if at is an unprefixed "xmlns" attribute.
func (n nsdef) checkDef(at *xml.Attr) {
	if at.Name.Local != "xmlns" || at.Name.Space != "" {
		return
	}
	n.setDef(at.Value)
}

// adjust converts the name of at to its serialized form.
// A prefix declaration ("xmlns:prefix") is registered in the innermost
// frame and renamed accordingly. For other attributes with a namespace
// the namespace is dropped if it is the default namespace, or replaced
// by its prefix if one is bound to it. Attributes without a namespace
// and attributes with an unknown namespace are left untouched.
func (n nsdef) adjust(at *xml.Attr) {
	space := at.Name.Space

	if space == "" {
		return
	}

	if space == "xmlns" {
		n.define(at.Value, at.Name.Local)
		at.Name.Local = "xmlns:" + at.Name.Local
		at.Name.Space = ""
		return
	}

	if n.isDef(space) {
		at.Name.Space = ""
		return
	}

	if prefix := n.lookup(space); prefix != "" {
		at.Name.Local = prefix + ":" + at.Name.Local
		at.Name.Space = ""
	}
}

// push returns the stack extended by a fresh, empty frame.
func (n nsdef) push() nsdef {
	return append(n, nsframe{ns: map[string]string{}})
}

// pop returns the stack without its innermost frame. The removed
// frame is cleared first. An empty stack is returned unchanged.
func (n nsdef) pop() nsdef {
	last := len(n) - 1
	if last < 0 {
		return n
	}
	n[last] = nsframe{}
	return n[:last]
}