view controllers/externalwfs.go @ 354:aa24b5691838

Simplified WFS proxy code a bit.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Tue, 07 Aug 2018 16:40:08 +0200
parents 23d4a9104b0c
children e170075c22ac
line wrap: on
line source

package controllers

import (
	"compress/flate"
	"compress/gzip"
	"encoding/xml"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"strings"

	"github.com/gorilla/mux"
	"golang.org/x/net/html/charset"

	"gemma.intevation.de/gemma/config"
)

// RoundTripFunc is an adapter that lets an ordinary function be used
// wherever an http.RoundTripper is expected.
type RoundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by calling the function itself
// and handing back whatever it returns.
func (fn RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}

// externalWFSDirector redirects req to the external WFS selected by the
// "wfs" route variable. The upstream base URL is taken from the "url"
// entry of the matching configuration; the "rest" route variable and the
// raw query string of the incoming request are appended to it.
//
// The upstream base URL and the public proxy URL are stored in the
// X-Gemma-From and X-Gemma-To request headers so that later stages can
// pick them up.
//
// Every configuration problem is logged and aborts the request by
// panicking with http.ErrAbortHandler.
func externalWFSDirector(req *http.Request) {

	// fail logs the formatted message and aborts the handler.
	fail := func(format string, args ...interface{}) {
		log.Printf(format, args...)
		panic(http.ErrAbortHandler)
	}

	services := config.ExternalWFSs()
	// len of a nil map is zero, so this covers the unconfigured case, too.
	if len(services) == 0 {
		fail("No external WFS proxy config found\n")
	}

	params := mux.Vars(req)
	name, tail := params["wfs"], params["rest"]

	log.Printf("rest: %s\n", tail)

	entry, known := services[name]
	if !known {
		fail("No config found for %s\n", name)
	}

	settings, isMap := entry.(map[string]interface{})
	if !isMap {
		fail("error: badly configured external WFS %s\n", name)
	}

	rawURL, hasURL := settings["url"]
	if !hasURL {
		fail("error: missing url for external WFS %s\n", name)
	}

	base, isString := rawURL.(string)
	if !isString {
		fail("error: badly configured url for external WFS %s\n", name)
	}

	log.Printf("%v\n", base)
	target := base + "/" + tail + "?" + req.URL.RawQuery
	log.Printf("%v\n", target)

	parsed, err := url.Parse(target)
	if err != nil {
		fail("Invalid url: %v\n", err)
	}

	req.URL = parsed
	req.Header.Set("X-Gemma-From", base)
	// The public address has to be derived before req.Host is
	// overwritten with the upstream host below.
	req.Header.Set(
		"X-Gemma-To",
		useHTTPS(req)+"://"+req.Host+"/api/externalwfs/"+name)

	req.Host = parsed.Host
}

// externalWFSTransport sends req through http.DefaultTransport.
//
// The X-Gemma-From and X-Gemma-To headers are removed from the outgoing
// request and copied onto the response instead. If-None-Match is
// dropped from the request as well to prevent some caching effects.
// A transport error is returned unchanged together with a nil response.
func externalWFSTransport(req *http.Request) (*http.Response, error) {

	origin := req.Header.Get("X-Gemma-From")
	target := req.Header.Get("X-Gemma-To")

	for _, key := range []string{
		"X-Gemma-From",
		"X-Gemma-To",
		"If-None-Match",
	} {
		req.Header.Del(key)
	}

	resp, err := http.DefaultTransport.RoundTrip(req)
	if err != nil {
		return nil, err
	}

	resp.Header.Set("X-Gemma-From", origin)
	resp.Header.Set("X-Gemma-To", target)

	return resp, nil
}

// nopCloser wraps an io.Writer and adds a Close method which does nothing.
type nopCloser struct {
	io.Writer
}

// Close implements io.Closer. It never fails.
func (nopCloser) Close() error { return nil }

// encoding inspects the Content-Encoding header in h and returns a pair
// of constructors: the first wraps a reader so that it yields the decoded
// content, the second wraps a writer so that everything written to it is
// encoded the same way again.
//
// Content codings are compared case-insensitively, so "GZIP" and "gzip"
// are treated alike. Anything that is neither gzip nor deflate is passed
// through untouched; in that case the given reader/writer is returned
// as is if it already is a closer and gets a no-op Close otherwise.
//
// NOTE(review): HTTP specifies the "deflate" coding as zlib-wrapped data
// whereas compress/flate handles raw DEFLATE streams. This is kept as it
// was; confirm against the upstream servers actually in use.
func encoding(h http.Header) (
	func(io.Reader) (io.ReadCloser, error),
	func(io.Writer) (io.WriteCloser, error),
) {
	// Lower-case once so the matching below is case-insensitive.
	switch enc := strings.ToLower(h.Get("Content-Encoding")); {
	case strings.Contains(enc, "gzip"):
		log.Println("gzip compression")
		return func(r io.Reader) (io.ReadCloser, error) {
				return gzip.NewReader(r)
			},
			func(w io.Writer) (io.WriteCloser, error) {
				return gzip.NewWriter(w), nil
			}
	case strings.Contains(enc, "deflate"):
		log.Println("Deflate compression")
		return func(r io.Reader) (io.ReadCloser, error) {
				return flate.NewReader(r), nil
			},
			func(w io.Writer) (io.WriteCloser, error) {
				return flate.NewWriter(w, flate.DefaultCompression)
			}
	default:
		log.Println("No content compression")
		return func(r io.Reader) (io.ReadCloser, error) {
				if r2, ok := r.(io.ReadCloser); ok {
					return r2, nil
				}
				return ioutil.NopCloser(r), nil
			},
			func(w io.Writer) (io.WriteCloser, error) {
				if w2, ok := w.(io.WriteCloser); ok {
					return w2, nil
				}
				return nopCloser{w}, nil
			}
	}
}

// externalWFSModifyResponse rewrites XML responses of an external WFS so
// that occurrences of the upstream URL are replaced by the proxy URL.
//
// Both URLs are read from the X-Gemma-From and X-Gemma-To response
// headers, which are removed afterwards. Responses which are not XML
// according to isXML are left untouched otherwise.
//
// For XML responses the body is replaced by the read end of a pipe. A
// goroutine decodes the original body according to its Content-Encoding,
// rewrites it and re-encodes the result into the pipe. As the rewritten
// body generally differs in size from the original one, the upstream
// Content-Length is discarded.
//
// An error is returned only if the decoding reader or the encoding
// writer cannot be set up. Failures during the rewrite itself are logged
// and reported to whoever reads the new body.
func externalWFSModifyResponse(resp *http.Response) error {

	from := resp.Header.Get("X-Gemma-From")
	to := resp.Header.Get("X-Gemma-To")
	resp.Header.Del("X-Gemma-From")
	resp.Header.Del("X-Gemma-To")

	if !isXML(resp.Header) {
		return nil
	}

	log.Printf("rewrite from %s to %s\n", from, to)

	reader, writer := encoding(resp.Header)

	body := resp.Body

	r, err := reader(body)
	if err != nil {
		return err
	}

	pr, pw := io.Pipe()

	w, err := writer(pw)
	if err != nil {
		return err
	}

	// The length announced by the upstream server describes the
	// original body, not the rewritten one.
	resp.Header.Del("Content-Length")
	resp.ContentLength = -1

	go func() {
		// The original body has to be closed in any case.
		defer body.Close()

		err := rewrite(w, r, from, to)
		if err == nil {
			// Flushes pending output of a compressing writer.
			err = w.Close()
		}
		if err != nil {
			log.Printf("rewrite failed: %v\n", err)
			// Hand the failure to the reader instead of a plain EOF,
			// which would make a truncated document look complete.
			// This must be the first close of pw as later closes do
			// not replace the error.
			pw.CloseWithError(err)
			return
		}
		log.Println("rewrite successful")
		pw.Close()
	}()

	resp.Body = pr
	return nil
}

// isXML reports whether any Content-Type value in h mentions one of the
// media types "text/xml" or "application/xml". The comparison ignores
// case. Only values stored under the canonical "Content-Type" key are
// considered.
func isXML(h http.Header) bool {
	markers := [...]string{"text/xml", "application/xml"}

	for _, value := range h["Content-Type"] {
		lowered := strings.ToLower(value)
		for _, marker := range markers {
			if strings.Contains(lowered, marker) {
				return true
			}
		}
	}
	return false
}

// rewrite copies the XML document read from r to w token by token and
// replaces every occurrence of from by to in character data and in the
// values of ordinary attributes. Namespace declarations are never
// altered. An empty from disables the replacement.
//
// The decoder reports namespaces as URIs. To keep the prefixes of the
// input document, the prefix declarations in scope are tracked and
// element and attribute names are written as "prefix:local" again
// wherever a prefix is known for their namespace.
//
// It returns the first error of the decoder or the encoder, or the
// result of flushing the encoder after the end of the input.
func rewrite(w io.Writer, r io.Reader, from, to string) error {

	decoder := xml.NewDecoder(r)
	decoder.CharsetReader = charset.NewReaderLabel

	encoder := xml.NewEncoder(w)

	replace := func(s string) string {
		// With an empty search string strings.Replace would insert
		// the replacement between all runes.
		if from == "" {
			return s
		}
		return strings.Replace(s, from, to, -1)
	}

	var ns nsdef

tokens:
	for {
		tok, err := decoder.Token()
		switch {
		case tok == nil && err == io.EOF:
			break tokens
		case err != nil:
			return err
		}

		switch t := tok.(type) {
		case xml.StartElement:
			ns = ns.push()

			// Register the prefixes declared on this element first so
			// that they are known regardless of the attribute order.
			for _, at := range t.Attr {
				if at.Name.Space == "xmlns" {
					ns.define(at.Value, at.Name.Local)
				}
			}

			prefix := ns.lookup(t.Name.Space)

			attr := make([]xml.Attr, 0, len(t.Attr))

			for _, at := range t.Attr {
				switch {
				case at.Name.Space == "xmlns":
					// Prefix declaration: written back verbatim.
					at.Name = xml.Name{Local: "xmlns:" + at.Name.Local}

				case at.Name.Space == "" && at.Name.Local == "xmlns":
					// Default namespace declaration. If the element
					// keeps its namespace the encoder writes an xmlns
					// attribute on its own; passing this one through as
					// well would result in a duplicate attribute.
					if prefix == "" && t.Name.Space != "" {
						continue
					}

				default:
					if at.Name.Space != "" {
						if s := ns.lookup(at.Name.Space); s != "" {
							at.Name = xml.Name{Local: s + ":" + at.Name.Local}
						}
					}
					// URLs are often carried in attributes.
					at.Value = replace(at.Value)
				}

				attr = append(attr, at)
			}
			if prefix != "" {
				t.Name = xml.Name{Local: prefix + ":" + t.Name.Local}
			}
			t.Attr = attr
			tok = t

		case xml.CharData:
			tok = xml.CharData(replace(string(t)))

		case xml.EndElement:
			// Resolve the name before the scope is dropped as the
			// prefix may have been declared on this very element.
			if s := ns.lookup(t.Name.Space); s != "" {
				t.Name = xml.Name{Local: s + ":" + t.Name.Local}
				tok = t
			}
			ns = ns.pop()
		}
		if err := encoder.EncodeToken(tok); err != nil {
			return err
		}
	}

	return encoder.Flush()
}

// nsdef is a stack of namespace scopes. Each scope maps a key to a
// non-empty value; rewrite uses it to map namespace URIs to the
// prefixes declared for them. The innermost scope is the last element.
// The nil value is an empty stack.
type nsdef []map[string]string

// lookup returns the value stored for ns in the innermost scope that
// has a non-empty one, or "" if there is none.
func (n nsdef) lookup(ns string) string {
	for i := len(n) - 1; i >= 0; i-- {
		if s := n[i][ns]; s != "" {
			return s
		}
	}
	return ""
}

// push returns the stack with a new, empty innermost scope.
func (n nsdef) push() nsdef {
	return append(n, make(map[string]string))
}

// pop returns the stack without its innermost scope. Popping an empty
// stack returns it unchanged.
func (n nsdef) pop() nsdef {
	if l := len(n); l > 0 {
		// Drop the reference so the scope can be garbage collected.
		n[l-1] = nil
		n = n[:l-1]
	}
	return n
}

// define stores s for ns in the innermost scope. It does nothing if
// the stack is empty, be it nil or popped down to zero length.
func (n nsdef) define(ns, s string) {
	if l := len(n); l > 0 {
		n[l-1][ns] = s
	}
}