Mercurial > gemma
changeset 412:21fb992b1f5a
merge
author | Thomas Junk <thomas.junk@intevation.de> |
---|---|
date | Wed, 15 Aug 2018 16:54:49 +0200 |
parents | d428db60fad1 (current diff) 3f803d64a6ee (diff) |
children | a9440a4826aa |
files | controllers/externalwfs.go misc/random.go |
diffstat | 11 files changed, 545 insertions(+), 444 deletions(-) [+] |
line wrap: on
line diff
--- a/auth/session.go Wed Aug 15 16:54:23 2018 +0200 +++ b/auth/session.go Wed Aug 15 16:54:49 2018 +0200 @@ -5,6 +5,7 @@ "io" "time" + "gemma.intevation.de/gemma/common" "gemma.intevation.de/gemma/misc" ) @@ -74,7 +75,7 @@ func GenerateSessionKey() string { return base64.URLEncoding.EncodeToString( - misc.GenerateRandomKey(sessionKeyLength)) + common.GenerateRandomKey(sessionKeyLength)) } func GenerateSession(user, password string) (string, *Session, error) {
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/common/random.go Wed Aug 15 16:54:49 2018 +0200 @@ -0,0 +1,48 @@ +package common + +import ( + "bytes" + "crypto/rand" + "io" + "log" + "math/big" +) + +func GenerateRandomKey(length int) []byte { + k := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return nil + } + return k +} + +func RandomString(n int) string { + + const ( + special = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" + alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + "abcdefghijklmnopqrstuvwxyz" + + "0123456789" + + special + ) + + max := big.NewInt(int64(len(alphabet))) + out := make([]byte, n) + + for i := 0; i < 1000; i++ { + for i := range out { + v, err := rand.Int(rand.Reader, max) + if err != nil { + log.Panicf("error: %v\n", err) + } + out[i] = alphabet[v.Int64()] + } + // Ensure at least one special char. + if bytes.IndexAny(out, special) >= 0 { + return string(out) + } + } + log.Println("warn: Your random generator may be broken.") + out[0] = special[0] + return string(out) +}
--- a/config/config.go Wed Aug 15 16:54:23 2018 +0200 +++ b/config/config.go Wed Aug 15 16:54:49 2018 +0200 @@ -1,11 +1,16 @@ package config import ( + "crypto/sha256" + "fmt" "log" + "sync" homedir "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" "github.com/spf13/viper" + + "gemma.intevation.de/gemma/common" ) // This is not part of the persistent config. @@ -44,6 +49,40 @@ func GeoServerPassword() string { return viper.GetString("geoserver-password") } func GeoServerTables() []string { return viper.GetStringSlice("geoserver-tables") } +var ( + proxyKeyOnce sync.Once + proxyKey []byte + + proxyPrefixOnce sync.Once + proxyPrefix string +) + +func ProxyKey() []byte { + fetchKey := func() { + if proxyKey == nil { + key := []byte(viper.GetString("proxy-key")) + if len(key) == 0 { + key = common.GenerateRandomKey(64) + } + hash := sha256.New() + hash.Write(key) + proxyKey = hash.Sum(nil) + } + } + proxyKeyOnce.Do(fetchKey) + return proxyKey +} + +func ProxyPrefix() string { + fetchPrefix := func() { + if proxyPrefix == "" { + proxyPrefix = fmt.Sprintf("http://%s:%d", WebHost(), WebPort()) + } + } + proxyPrefixOnce.Do(fetchPrefix) + return proxyPrefix +} + var RootCmd = &cobra.Command{ Use: "gemma", Short: "gemma is a server for waterway monitoring and management", @@ -115,6 +154,10 @@ str("geoserver-user", "admin", "GeoServer user") str("geoserver-password", "geoserver", "GeoServer password") strSl("geoserver-tables", geoTables, "tables to publish with GeoServer") + + str("proxy-key", "", `signing key for proxy URLs. Defaults to random key.`) + str("proxy-prefix", "", `URL prefix of proxy. Defaults to "http://${web-host}:${web-port}"`) + } func initConfig() {
--- a/controllers/externalwfs.go Wed Aug 15 16:54:23 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,363 +0,0 @@ -package controllers - -import ( - "compress/flate" - "compress/gzip" - "encoding/xml" - "io" - "io/ioutil" - "log" - "net/http" - "net/url" - "strings" - "time" - - "github.com/gorilla/mux" - "golang.org/x/net/html/charset" - - "gemma.intevation.de/gemma/config" -) - -// roundTripFunc is a helper type to make externalWFSDirector a http.RoundTripper. -type roundTripFunc func(*http.Request) (*http.Response, error) - -func (rtf roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return rtf(req) -} - -func externalWFSDirector(req *http.Request) { - - abort := func(format string, args ...interface{}) { - log.Printf(format, args...) - panic(http.ErrAbortHandler) - } - - external := config.ExternalWFSs() - if external == nil || len(external) == 0 { - abort("No external WFS proxy config found\n") - } - vars := mux.Vars(req) - wfs := vars["wfs"] - rest := vars["rest"] - - log.Printf("rest: %s\n", rest) - - alias, found := external[wfs] - if !found { - abort("No config found for %s\n", wfs) - } - data, ok := alias.(map[string]interface{}) - if !ok { - abort("error: badly configured external WFS %s\n", wfs) - } - - urlS, found := data["url"] - if !found { - abort("error: missing url for external WFS %s\n", wfs) - } - - prefix, ok := urlS.(string) - if !ok { - abort("error: badly configured url for external WFS %s\n", wfs) - } - - https := useHTTPS(req) - - log.Printf("%v\n", prefix) - nURL := prefix + "/" + rest + "?" + req.URL.RawQuery - log.Printf("%v\n", nURL) - - u, err := url.Parse(nURL) - if err != nil { - abort("Invalid url: %v\n", err) - } - req.URL = u - req.Header.Set("X-Gemma-From", prefix) - to := https + "://" + req.Host + "/api/externalwfs/" + wfs - req.Header.Set("X-Gemma-To", to) - - req.Host = u.Host - - //log.Printf("headers: %v\n", req.Header) -} - -func externalWFSTransport(req *http.Request) (*http.Response, error) { - - from := req.Header.Get("X-Gemma-From") - to := req.Header.Get("X-Gemma-To") - req.Header.Del("X-Gemma-From") - req.Header.Del("X-Gemma-To") - - // To prevent some caching effects. - req.Header.Del("If-None-Match") - - resp, err := http.DefaultTransport.RoundTrip(req) - if err != nil { - return nil, err - } - resp.Header.Set("X-Gemma-From", from) - resp.Header.Set("X-Gemma-To", to) - - return resp, err -} - -type nopCloser struct { - io.Writer -} - -func (nopCloser) Close() error { return nil } - -func encoding(h http.Header) ( - func(io.Reader) (io.ReadCloser, error), - func(io.Writer) (io.WriteCloser, error), -) { - switch enc := h.Get("Content-Encoding"); { - case strings.Contains(enc, "gzip"): - log.Println("gzip compression") - return func(r io.Reader) (io.ReadCloser, error) { - return gzip.NewReader(r) - }, - func(w io.Writer) (io.WriteCloser, error) { - return gzip.NewWriter(w), nil - } - case strings.Contains(enc, "deflate"): - log.Println("Deflate compression") - return func(r io.Reader) (io.ReadCloser, error) { - return flate.NewReader(r), nil - }, - func(w io.Writer) (io.WriteCloser, error) { - return flate.NewWriter(w, flate.DefaultCompression) - } - default: - log.Println("No content compression") - return func(r io.Reader) (io.ReadCloser, error) { - if r2, ok := r.(io.ReadCloser); ok { - return r2, nil - } - return ioutil.NopCloser(r), nil - }, - func(w io.Writer) (io.WriteCloser, error) { - if w2, ok := w.(io.WriteCloser); ok { - return w2, nil - } - return nopCloser{w}, nil - } - } -} - -func externalWFSModifyResponse(resp *http.Response) error { - - from := resp.Header.Get("X-Gemma-From") - to := resp.Header.Get("X-Gemma-To") - resp.Header.Del("X-Gemma-From") - resp.Header.Del("X-Gemma-To") - - if !isXML(resp.Header) { - return nil - } - - log.Printf("rewrite from %s to %s\n", from, to) - - pr, pw := io.Pipe() - - var ( - r io.ReadCloser - w io.WriteCloser - err error - ) - - reader, writer := encoding(resp.Header) - - if r, err = reader(resp.Body); err != nil { - return err - } - - if w, err = writer(pw); err != nil { - return err - } - - go func(force io.ReadCloser) { - start := time.Now() - defer func() { - //r.Close() - w.Close() - pw.Close() - force.Close() - log.Printf("rewrite took %s\n", time.Since(start)) - }() - if err := rewrite(w, r, from, to); err != nil { - log.Printf("rewrite failed: %v\n", err) - return - } - log.Println("rewrite successful") - }(resp.Body) - - resp.Body = pr - - return nil -} - -var xmlContentTypes = []string{ - "application/xml", - "text/xml", - "application/gml+xml", -} - -func isXML(h http.Header) bool { - for _, t := range h["Content-Type"] { - t = strings.ToLower(t) - for _, ct := range xmlContentTypes { - if strings.Contains(t, ct) { - return true - } - } - } - return false -} - -func rewrite(w io.Writer, r io.Reader, from, to string) error { - - decoder := xml.NewDecoder(r) - decoder.CharsetReader = charset.NewReaderLabel - - encoder := xml.NewEncoder(w) - - replace := func(s string) string { - return strings.Replace(s, from, to, -1) - } - - var n nsdef - -tokens: - for { - tok, err := decoder.Token() - switch { - case tok == nil && err == io.EOF: - break tokens - case err != nil: - return err - } - - switch t := tok.(type) { - case xml.StartElement: - t = t.Copy() - - isDef := n.isDef(t.Name.Space) - n = n.push() - - for i := range t.Attr { - t.Attr[i].Value = replace(t.Attr[i].Value) - n.checkDef(&t.Attr[i]) - } - - for i := range t.Attr { - n.adjust(&t.Attr[i]) - } - - switch { - case isDef: - t.Name.Space = "" - default: - if s := n.lookup(t.Name.Space); s != "" { - t.Name.Space = "" - t.Name.Local = s + ":" + t.Name.Local - } - } - tok = t - - case xml.CharData: - tok = xml.CharData(replace(string(t))) - - case xml.EndElement: - s := n.lookup(t.Name.Space) - - n = n.pop() - - if n.isDef(t.Name.Space) { - t.Name.Space = "" - } else if s != "" { - t.Name.Space = "" - t.Name.Local = s + ":" + t.Name.Local - } - tok = t - } - - if err := encoder.EncodeToken(tok); err != nil { - return err - } - } - - return encoder.Flush() -} - -type nsframe struct { - def string - ns map[string]string -} - -type nsdef []nsframe - -func (n nsdef) setDef(def string) { - if l := len(n); l > 0 { - n[l-1].def = def - } -} - -func (n nsdef) isDef(s string) bool { - for i := len(n) - 1; i >= 0; i-- { - if x := n[i].def; x != "" { - return s == x - } - } - return false -} - -func (n nsdef) define(ns, s string) { - if l := len(n); l > 0 { - n[l-1].ns[ns] = s - } -} - -func (n nsdef) lookup(ns string) string { - for i := len(n) - 1; i >= 0; i-- { - if s := n[i].ns[ns]; s != "" { - return s - } - } - return "" -} - -func (n nsdef) checkDef(at *xml.Attr) { - if at.Name.Space == "" && at.Name.Local == "xmlns" { - n.setDef(at.Value) - } -} - -func (n nsdef) adjust(at *xml.Attr) { - switch { - case at.Name.Space == "xmlns": - n.define(at.Value, at.Name.Local) - at.Name.Local = "xmlns:" + at.Name.Local - at.Name.Space = "" - - case at.Name.Space != "": - if n.isDef(at.Name.Space) { - at.Name.Space = "" - } else if s := n.lookup(at.Name.Space); s != "" { - at.Name.Local = s + ":" + at.Name.Local - at.Name.Space = "" - } - } -} - -func (n nsdef) push() nsdef { - return append(n, nsframe{ns: make(map[string]string)}) -} - -func (n nsdef) pop() nsdef { - if l := len(n); l > 0 { - n[l-1] = nsframe{} - n = n[:l-1] - } - return n -}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/controllers/proxy.go Wed Aug 15 16:54:49 2018 +0200 @@ -0,0 +1,380 @@ +package controllers + +import ( + "compress/flate" + "compress/gzip" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/xml" + "io" + "io/ioutil" + "log" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "gemma.intevation.de/gemma/config" + "github.com/gorilla/mux" + "golang.org/x/net/html/charset" +) + +// proxyBlackList is a set of URLs that should not be rewritten by the proxy. +var proxyBlackList = map[string]struct{}{ + "http://www.w3.org/2001/XMLSchema-instance": struct{}{}, + "http://www.w3.org/1999/xlink": struct{}{}, + "http://www.w3.org/2001/XMLSchema": struct{}{}, + "http://www.w3.org/XML/1998/namespace": struct{}{}, + "http://www.opengis.net/wfs/2.0": struct{}{}, + "http://www.opengis.net/ows/1.1": struct{}{}, + "http://www.opengis.net/gml/3.2": struct{}{}, + "http://www.opengis.net/fes/2.0": struct{}{}, + "http://schemas.opengis.net/gml": struct{}{}, +} + +func findEntry(entry string) (string, bool) { + external := config.ExternalWFSs() + if external == nil || len(external) == 0 { + return "", false + } + alias, found := external[entry] + if !found { + return "", false + } + data, ok := alias.(map[string]interface{}) + if !ok { + return "", false + } + urlS, found := data["url"] + if !found { + return "", false + } + url, ok := urlS.(string) + return url, ok +} + +func proxyDirector(req *http.Request) { + + log.Printf("proxyDirector: %s\n", req.RequestURI) + + abort := func(format string, args ...interface{}) { + log.Printf(format, args...) + panic(http.ErrAbortHandler) + } + + vars := mux.Vars(req) + + var s string + + if entry, found := vars["entry"]; found { + if s, found = findEntry(entry); !found { + abort("Cannot find entry '%s'\n", entry) + } + } else { + expectedMAC, err := base64.URLEncoding.DecodeString(vars["hash"]) + if err != nil { + abort("Cannot base64 decode hash: %v\n", err) + } + url, err := base64.URLEncoding.DecodeString(vars["url"]) + if err != nil { + abort("Cannot base64 decode url: %v\n", err) + } + + mac := hmac.New(sha256.New, config.ProxyKey()) + mac.Write(url) + messageMAC := mac.Sum(nil) + + s = string(url) + + if !hmac.Equal(messageMAC, expectedMAC) { + abort("HMAC of URL %s failed.\n", s) + } + } + + nURL := s + "?" + req.URL.RawQuery + //log.Printf("%v\n", nURL) + + u, err := url.Parse(nURL) + if err != nil { + abort("Invalid url: %v\n", err) + } + req.URL = u + + req.Host = u.Host + //req.Header.Del("If-None-Match") + //log.Printf("headers: %v\n", req.Header) +} + +type nopCloser struct { + io.Writer +} + +func (nopCloser) Close() error { return nil } + +func encoding(h http.Header) ( + func(io.Reader) (io.ReadCloser, error), + func(io.Writer) (io.WriteCloser, error), +) { + switch enc := h.Get("Content-Encoding"); { + case strings.Contains(enc, "gzip"): + log.Println("gzip compression") + return func(r io.Reader) (io.ReadCloser, error) { + return gzip.NewReader(r) + }, + func(w io.Writer) (io.WriteCloser, error) { + return gzip.NewWriter(w), nil + } + case strings.Contains(enc, "deflate"): + log.Println("Deflate compression") + return func(r io.Reader) (io.ReadCloser, error) { + return flate.NewReader(r), nil + }, + func(w io.Writer) (io.WriteCloser, error) { + return flate.NewWriter(w, flate.DefaultCompression) + } + default: + log.Println("No content compression") + return func(r io.Reader) (io.ReadCloser, error) { + if r2, ok := r.(io.ReadCloser); ok { + return r2, nil + } + return ioutil.NopCloser(r), nil + }, + func(w io.Writer) (io.WriteCloser, error) { + if w2, ok := w.(io.WriteCloser); ok { + return w2, nil + } + return nopCloser{w}, nil + } + } +} + +func proxyModifyResponse(resp *http.Response) error { + + if !isXML(resp.Header) { + return nil + } + + pr, pw := io.Pipe() + + var ( + r io.ReadCloser + w io.WriteCloser + err error + ) + + reader, writer := encoding(resp.Header) + + if r, err = reader(resp.Body); err != nil { + return err + } + + if w, err = writer(pw); err != nil { + return err + } + + go func(force io.ReadCloser) { + start := time.Now() + defer func() { + //r.Close() + w.Close() + pw.Close() + force.Close() + log.Printf("rewrite took %s\n", time.Since(start)) + }() + if err := rewrite(w, r); err != nil { + log.Printf("rewrite failed: %v\n", err) + return + } + log.Println("rewrite successful") + }(resp.Body) + + resp.Body = pr + + return nil +} + +var xmlContentTypes = []string{ + "application/xml", + "text/xml", + "application/gml+xml", +} + +func isXML(h http.Header) bool { + for _, t := range h["Content-Type"] { + t = strings.ToLower(t) + for _, ct := range xmlContentTypes { + if strings.Contains(t, ct) { + return true + } + } + } + return false +} + +var replaceRe = regexp.MustCompile(`\b(https?://[^\s\?]*)`) + +func replace(s string) string { + + proxyKey := config.ProxyKey() + proxyPrefix := config.ProxyPrefix() + "/api/proxy/" + + return replaceRe.ReplaceAllStringFunc(s, func(s string) string { + if _, found := proxyBlackList[s]; found { + return s + } + mac := hmac.New(sha256.New, proxyKey) + b := []byte(s) + mac.Write(b) + expectedMAC := mac.Sum(nil) + + hash := base64.URLEncoding.EncodeToString(expectedMAC) + enc := base64.URLEncoding.EncodeToString(b) + return proxyPrefix + hash + "/" + enc + }) +} + +func rewrite(w io.Writer, r io.Reader) error { + + decoder := xml.NewDecoder(r) + decoder.CharsetReader = charset.NewReaderLabel + + encoder := xml.NewEncoder(w) + + var n nsdef + +tokens: + for { + tok, err := decoder.Token() + switch { + case tok == nil && err == io.EOF: + break tokens + case err != nil: + return err + } + + switch t := tok.(type) { + case xml.StartElement: + t = t.Copy() + + isDef := n.isDef(t.Name.Space) + n = n.push() + + for i := range t.Attr { + t.Attr[i].Value = replace(t.Attr[i].Value) + n.checkDef(&t.Attr[i]) + } + + for i := range t.Attr { + n.adjust(&t.Attr[i]) + } + + switch { + case isDef: + t.Name.Space = "" + default: + if s := n.lookup(t.Name.Space); s != "" { + t.Name.Space = "" + t.Name.Local = s + ":" + t.Name.Local + } + } + tok = t + + case xml.CharData: + tok = xml.CharData(replace(string(t))) + + case xml.EndElement: + s := n.lookup(t.Name.Space) + + n = n.pop() + + if n.isDef(t.Name.Space) { + t.Name.Space = "" + } else if s != "" { + t.Name.Space = "" + t.Name.Local = s + ":" + t.Name.Local + } + tok = t + } + + if err := encoder.EncodeToken(tok); err != nil { + return err + } + } + + return encoder.Flush() +} + +type nsframe struct { + def string + ns map[string]string +} + +type nsdef []nsframe + +func (n nsdef) setDef(def string) { + if l := len(n); l > 0 { + n[l-1].def = def + } +} + +func (n nsdef) isDef(s string) bool { + for i := len(n) - 1; i >= 0; i-- { + if x := n[i].def; x != "" { + return s == x + } + } + return false +} + +func (n nsdef) define(ns, s string) { + if l := len(n); l > 0 { + n[l-1].ns[ns] = s + } +} + +func (n nsdef) lookup(ns string) string { + for i := len(n) - 1; i >= 0; i-- { + if s := n[i].ns[ns]; s != "" { + return s + } + } + return "" +} + +func (n nsdef) checkDef(at *xml.Attr) { + if at.Name.Space == "" && at.Name.Local == "xmlns" { + n.setDef(at.Value) + } +} + +func (n nsdef) adjust(at *xml.Attr) { + switch { + case at.Name.Space == "xmlns": + n.define(at.Value, at.Name.Local) + at.Name.Local = "xmlns:" + at.Name.Local + at.Name.Space = "" + + case at.Name.Space != "": + if n.isDef(at.Name.Space) { + at.Name.Space = "" + } else if s := n.lookup(at.Name.Space); s != "" { + at.Name.Local = s + ":" + at.Name.Local + at.Name.Space = "" + } + } +} + +func (n nsdef) push() nsdef { + return append(n, nsframe{ns: make(map[string]string)}) +} + +func (n nsdef) pop() nsdef { + if l := len(n); l > 0 { + n[l-1] = nsframe{} + n = n[:l-1] + } + return n +}
--- a/controllers/pwreset.go Wed Aug 15 16:54:23 2018 +0200 +++ b/controllers/pwreset.go Wed Aug 15 16:54:49 2018 +0200 @@ -14,6 +14,7 @@ "github.com/gorilla/mux" "gemma.intevation.de/gemma/auth" + "gemma.intevation.de/gemma/common" "gemma.intevation.de/gemma/config" "gemma.intevation.de/gemma/misc" ) @@ -155,7 +156,7 @@ } func generateHash() string { - return hex.EncodeToString(misc.GenerateRandomKey(hashLength)) + return hex.EncodeToString(common.GenerateRandomKey(hashLength)) } func generateNewPassword() string { @@ -165,7 +166,7 @@ return strings.TrimSpace(string(out)) } // Use internal generator. - return misc.RandomString(20) + return common.RandomString(20) } func passwordResetRequest(
--- a/controllers/routes.go Wed Aug 15 16:54:23 2018 +0200 +++ b/controllers/routes.go Wed Aug 15 16:54:49 2018 +0200 @@ -52,13 +52,17 @@ }).Methods(http.MethodGet) // Proxy for external WFSs. - externalWFSProxy := &httputil.ReverseProxy{ - Director: externalWFSDirector, - Transport: roundTripFunc(externalWFSTransport), - ModifyResponse: externalWFSModifyResponse, + proxy := &httputil.ReverseProxy{ + Director: proxyDirector, + ModifyResponse: proxyModifyResponse, } - api.Handle("/externalwfs/{wfs}/{rest:.*}", externalWFSProxy). + api.Handle(`/proxy/{hash}/{url}`, proxy). + Methods( + http.MethodGet, http.MethodPost, + http.MethodPut, http.MethodDelete) + + api.Handle("/proxy/{entry}", proxy). Methods( http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete)
--- a/misc/random.go Wed Aug 15 16:54:23 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,48 +0,0 @@ -package misc - -import ( - "bytes" - "crypto/rand" - "io" - "log" - "math/big" -) - -func GenerateRandomKey(length int) []byte { - k := make([]byte, length) - if _, err := io.ReadFull(rand.Reader, k); err != nil { - return nil - } - return k -} - -func RandomString(n int) string { - - const ( - special = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" - alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + - "abcdefghijklmnopqrstuvwxyz" + - "0123456789" + - special - ) - - max := big.NewInt(int64(len(alphabet))) - out := make([]byte, n) - - for i := 0; i < 1000; i++ { - for i := range out { - v, err := rand.Int(rand.Reader, max) - if err != nil { - log.Panicf("error: %v\n", err) - } - out[i] = alphabet[v.Int64()] - } - // Ensure at least one special char. - if bytes.IndexAny(out, special) >= 0 { - return string(out) - } - } - log.Println("warn: Your random generator may be broken.") - out[0] = special[0] - return string(out) -}
--- a/schema/manage_users.sql Wed Aug 15 16:54:23 2018 +0200 +++ b/schema/manage_users.sql Wed Aug 15 16:54:49 2018 +0200 @@ -24,19 +24,6 @@ LANGUAGE plpgsql; --- Security-definer function to get current users country, which allows to --- restrict the view on user_profiles by country without infinite recursion -CREATE FUNCTION users.current_user_country() - RETURNS internal.user_profiles.country%TYPE - AS $$ - SELECT country FROM internal.user_profiles - WHERE username = session_user - $$ - LANGUAGE SQL - SECURITY DEFINER - STABLE PARALLEL SAFE; - - CREATE OR REPLACE VIEW users.list_users WITH (security_barrier) AS SELECT r.rolname, @@ -51,11 +38,23 @@ JOIN pg_roles r ON a.roleid = r.oid WHERE p.username = current_user OR pg_has_role('waterway_admin', 'MEMBER') - AND p.country = users.current_user_country() + AND p.country = ( + SELECT country FROM internal.user_profiles + WHERE username = current_user) OR pg_has_role('pw_reset', 'MEMBER') OR pg_has_role('sys_admin', 'MEMBER'); +CREATE OR REPLACE FUNCTION users.current_user_country() + RETURNS internal.user_profiles.country%TYPE + AS $$ + SELECT country FROM users.list_users + WHERE username = current_user + $$ + LANGUAGE SQL + STABLE PARALLEL SAFE; + + CREATE OR REPLACE FUNCTION internal.create_user() RETURNS trigger AS $$ BEGIN @@ -86,6 +85,31 @@ EXECUTE PROCEDURE internal.create_user(); +-- Prevent roles other than sys_admin and pw_reset to update any user but +-- themselves (affects waterway_admin) +CREATE OR REPLACE FUNCTION internal.authorize_update_user() RETURNS trigger +AS $$ +BEGIN + IF OLD.username <> current_user + AND NOT (pg_has_role('sys_admin', 'MEMBER') + OR pg_has_role('pw_reset', 'MEMBER')) + THEN + RETURN NULL; + ELSE + RETURN NEW; + END IF; +END; +$$ + LANGUAGE plpgsql; + +-- Note that PostgreSQL fires triggers for the same event in alphabetical +-- order! Make sure that authorization takes place before any other trigger +-- is fired that might execute otherwise unauthorized statements! +CREATE TRIGGER authorize_update_user INSTEAD OF UPDATE ON users.list_users + FOR EACH ROW + EXECUTE PROCEDURE internal.authorize_update_user(); + + CREATE OR REPLACE FUNCTION internal.update_user() RETURNS trigger AS $$ DECLARE @@ -93,14 +117,6 @@ BEGIN cur_username = OLD.username; - IF cur_username <> session_user - AND NOT (pg_has_role(session_user, 'sys_admin', 'MEMBER') - OR pg_has_role(session_user, 'pw_reset', 'MEMBER')) - THEN - -- Discard row. This is what WITH CHECK in an RLS policy would do. - RETURN NULL; - END IF; - UPDATE internal.user_profiles p SET (username, country, map_extent, email_address) = (NEW.username, NEW.country, NEW.map_extent, NEW.email_address)
--- a/schema/manage_users_tests.sql Wed Aug 15 16:54:23 2018 +0200 +++ b/schema/manage_users_tests.sql Wed Aug 15 16:54:49 2018 +0200 @@ -138,6 +138,25 @@ $$, 'Waterway admin cannot update attributes of other users in country'); +-- The above test will pass even if the password is actually updated in case +-- a trigger returns NULL after ALTER ROLE ... PASSWORD ... has been executed. +RESET SESSION AUTHORIZATION; +CREATE TEMP TABLE old_pw_hash AS + SELECT rolpassword FROM pg_authid WHERE rolname = 'test_user_at'; +SET SESSION AUTHORIZATION test_admin_at; +UPDATE users.list_users + SET pw = 'test_user_at2!' + WHERE username = 'test_user_at'; +RESET SESSION AUTHORIZATION; +SELECT set_eq($$ + SELECT rolpassword FROM old_pw_hash + $$, + $$ + SELECT rolpassword FROM pg_authid WHERE rolname = 'test_user_at' + $$, + 'Waterway admin cannot update password of other users in country'); + + SET SESSION AUTHORIZATION test_sys_admin1; SELECT lives_ok($$ @@ -223,8 +242,8 @@ -- To compare passwords, we need to run the following tests as superuser RESET SESSION AUTHORIZATION; -CREATE TEMP TABLE old_pw_hash AS - SELECT rolpassword FROM pg_authid WHERE rolname = 'test_user_at'; +UPDATE old_pw_hash SET rolpassword = ( + SELECT rolpassword FROM pg_authid WHERE rolname = 'test_user_at'); UPDATE users.list_users SET (rolname, username, pw, country, map_extent, email_address)
--- a/schema/run_tests.sh Wed Aug 15 16:54:23 2018 +0200 +++ b/schema/run_tests.sh Wed Aug 15 16:54:49 2018 +0200 @@ -16,7 +16,7 @@ -c 'SET client_min_messages TO WARNING' \ -c "DROP ROLE IF EXISTS $TEST_ROLES" \ -f tap_tests_data.sql \ - -c 'SELECT plan(44)' \ + -c 'SELECT plan(45)' \ -f auth_tests.sql \ -f manage_users_tests.sql \ -c 'SELECT * FROM finish()'