Mercurial > gemma
diff pkg/controllers/proxy.go @ 414:c1047fd04a3a
Moved project specific Go packages to new pkg folder.
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Wed, 15 Aug 2018 17:30:50 +0200 |
parents | controllers/proxy.go@cdd63547930a |
children | 6627c48363a0 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/controllers/proxy.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,381 @@ +package controllers + +import ( + "compress/flate" + "compress/gzip" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/xml" + "io" + "io/ioutil" + "log" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/gorilla/mux" + "golang.org/x/net/html/charset" + + "gemma.intevation.de/gemma/pkg/config" +) + +// proxyBlackList is a set of URLs that should not be rewritten by the proxy. +var proxyBlackList = map[string]struct{}{ + "http://www.w3.org/2001/XMLSchema-instance": struct{}{}, + "http://www.w3.org/1999/xlink": struct{}{}, + "http://www.w3.org/2001/XMLSchema": struct{}{}, + "http://www.w3.org/XML/1998/namespace": struct{}{}, + "http://www.opengis.net/wfs/2.0": struct{}{}, + "http://www.opengis.net/ows/1.1": struct{}{}, + "http://www.opengis.net/gml/3.2": struct{}{}, + "http://www.opengis.net/fes/2.0": struct{}{}, + "http://schemas.opengis.net/gml": struct{}{}, +} + +func findEntry(entry string) (string, bool) { + external := config.ExternalWFSs() + if external == nil || len(external) == 0 { + return "", false + } + alias, found := external[entry] + if !found { + return "", false + } + data, ok := alias.(map[string]interface{}) + if !ok { + return "", false + } + urlS, found := data["url"] + if !found { + return "", false + } + url, ok := urlS.(string) + return url, ok +} + +func proxyDirector(req *http.Request) { + + log.Printf("proxyDirector: %s\n", req.RequestURI) + + abort := func(format string, args ...interface{}) { + log.Printf(format, args...) + panic(http.ErrAbortHandler) + } + + vars := mux.Vars(req) + + var s string + + if entry, found := vars["entry"]; found { + if s, found = findEntry(entry); !found { + abort("Cannot find entry '%s'\n", entry) + } + } else { + expectedMAC, err := base64.URLEncoding.DecodeString(vars["hash"]) + if err != nil { + abort("Cannot base64 decode hash: %v\n", err) + } + url, err := base64.URLEncoding.DecodeString(vars["url"]) + if err != nil { + abort("Cannot base64 decode url: %v\n", err) + } + + mac := hmac.New(sha256.New, config.ProxyKey()) + mac.Write(url) + messageMAC := mac.Sum(nil) + + s = string(url) + + if !hmac.Equal(messageMAC, expectedMAC) { + abort("HMAC of URL %s failed.\n", s) + } + } + + nURL := s + "?" + req.URL.RawQuery + //log.Printf("%v\n", nURL) + + u, err := url.Parse(nURL) + if err != nil { + abort("Invalid url: %v\n", err) + } + req.URL = u + + req.Host = u.Host + //req.Header.Del("If-None-Match") + //log.Printf("headers: %v\n", req.Header) +} + +type nopCloser struct { + io.Writer +} + +func (nopCloser) Close() error { return nil } + +func encoding(h http.Header) ( + func(io.Reader) (io.ReadCloser, error), + func(io.Writer) (io.WriteCloser, error), +) { + switch enc := h.Get("Content-Encoding"); { + case strings.Contains(enc, "gzip"): + log.Println("gzip compression") + return func(r io.Reader) (io.ReadCloser, error) { + return gzip.NewReader(r) + }, + func(w io.Writer) (io.WriteCloser, error) { + return gzip.NewWriter(w), nil + } + case strings.Contains(enc, "deflate"): + log.Println("Deflate compression") + return func(r io.Reader) (io.ReadCloser, error) { + return flate.NewReader(r), nil + }, + func(w io.Writer) (io.WriteCloser, error) { + return flate.NewWriter(w, flate.DefaultCompression) + } + default: + log.Println("No content compression") + return func(r io.Reader) (io.ReadCloser, error) { + if r2, ok := r.(io.ReadCloser); ok { + return r2, nil + } + return ioutil.NopCloser(r), nil + }, + func(w io.Writer) (io.WriteCloser, error) { + if w2, ok := w.(io.WriteCloser); ok { + return w2, nil + } + return nopCloser{w}, nil + } + } +} + +func proxyModifyResponse(resp *http.Response) error { + + if !isXML(resp.Header) { + return nil + } + + pr, pw := io.Pipe() + + var ( + r io.ReadCloser + w io.WriteCloser + err error + ) + + reader, writer := encoding(resp.Header) + + if r, err = reader(resp.Body); err != nil { + return err + } + + if w, err = writer(pw); err != nil { + return err + } + + go func(force io.ReadCloser) { + start := time.Now() + defer func() { + //r.Close() + w.Close() + pw.Close() + force.Close() + log.Printf("rewrite took %s\n", time.Since(start)) + }() + if err := rewrite(w, r); err != nil { + log.Printf("rewrite failed: %v\n", err) + return + } + log.Println("rewrite successful") + }(resp.Body) + + resp.Body = pr + + return nil +} + +var xmlContentTypes = []string{ + "application/xml", + "text/xml", + "application/gml+xml", +} + +func isXML(h http.Header) bool { + for _, t := range h["Content-Type"] { + t = strings.ToLower(t) + for _, ct := range xmlContentTypes { + if strings.Contains(t, ct) { + return true + } + } + } + return false +} + +var replaceRe = regexp.MustCompile(`\b(https?://[^\s\?]*)`) + +func replace(s string) string { + + proxyKey := config.ProxyKey() + proxyPrefix := config.ProxyPrefix() + "/api/proxy/" + + return replaceRe.ReplaceAllStringFunc(s, func(s string) string { + if _, found := proxyBlackList[s]; found { + return s + } + mac := hmac.New(sha256.New, proxyKey) + b := []byte(s) + mac.Write(b) + expectedMAC := mac.Sum(nil) + + hash := base64.URLEncoding.EncodeToString(expectedMAC) + enc := base64.URLEncoding.EncodeToString(b) + return proxyPrefix + hash + "/" + enc + }) +} + +func rewrite(w io.Writer, r io.Reader) error { + + decoder := xml.NewDecoder(r) + decoder.CharsetReader = charset.NewReaderLabel + + encoder := xml.NewEncoder(w) + + var n nsdef + +tokens: + for { + tok, err := decoder.Token() + switch { + case tok == nil && err == io.EOF: + break tokens + case err != nil: + return err + } + + switch t := tok.(type) { + case xml.StartElement: + t = t.Copy() + + isDef := n.isDef(t.Name.Space) + n = n.push() + + for i := range t.Attr { + t.Attr[i].Value = replace(t.Attr[i].Value) + n.checkDef(&t.Attr[i]) + } + + for i := range t.Attr { + n.adjust(&t.Attr[i]) + } + + switch { + case isDef: + t.Name.Space = "" + default: + if s := n.lookup(t.Name.Space); s != "" { + t.Name.Space = "" + t.Name.Local = s + ":" + t.Name.Local + } + } + tok = t + + case xml.CharData: + tok = xml.CharData(replace(string(t))) + + case xml.EndElement: + s := n.lookup(t.Name.Space) + + n = n.pop() + + if n.isDef(t.Name.Space) { + t.Name.Space = "" + } else if s != "" { + t.Name.Space = "" + t.Name.Local = s + ":" + t.Name.Local + } + tok = t + } + + if err := encoder.EncodeToken(tok); err != nil { + return err + } + } + + return encoder.Flush() +} + +type nsframe struct { + def string + ns map[string]string +} + +type nsdef []nsframe + +func (n nsdef) setDef(def string) { + if l := len(n); l > 0 { + n[l-1].def = def + } +} + +func (n nsdef) isDef(s string) bool { + for i := len(n) - 1; i >= 0; i-- { + if x := n[i].def; x != "" { + return s == x + } + } + return false +} + +func (n nsdef) define(ns, s string) { + if l := len(n); l > 0 { + n[l-1].ns[ns] = s + } +} + +func (n nsdef) lookup(ns string) string { + for i := len(n) - 1; i >= 0; i-- { + if s := n[i].ns[ns]; s != "" { + return s + } + } + return "" +} + +func (n nsdef) checkDef(at *xml.Attr) { + if at.Name.Space == "" && at.Name.Local == "xmlns" { + n.setDef(at.Value) + } +} + +func (n nsdef) adjust(at *xml.Attr) { + switch { + case at.Name.Space == "xmlns": + n.define(at.Value, at.Name.Local) + at.Name.Local = "xmlns:" + at.Name.Local + at.Name.Space = "" + + case at.Name.Space != "": + if n.isDef(at.Name.Space) { + at.Name.Space = "" + } else if s := n.lookup(at.Name.Space); s != "" { + at.Name.Local = s + ":" + at.Name.Local + at.Name.Space = "" + } + } +} + +func (n nsdef) push() nsdef { + return append(n, nsframe{ns: make(map[string]string)}) +} + +func (n nsdef) pop() nsdef { + if l := len(n); l > 0 { + n[l-1] = nsframe{} + n = n[:l-1] + } + return n +}