# HG changeset patch # User Sascha L. Teichmann # Date 1534347050 -7200 # Node ID c1047fd04a3ab51315452ebb17032f53bb83c6b8 # Parent a9440a4826aa0bb00e77aab93cfb4f7c29d8e133 Moved project specific Go packages to new pkg folder. diff -r a9440a4826aa -r c1047fd04a3a Dockerfile --- a/Dockerfile Wed Aug 15 17:13:28 2018 +0200 +++ b/Dockerfile Wed Aug 15 17:30:50 2018 +0200 @@ -11,11 +11,8 @@ # Copy only backend stuff COPY 3rdpartylibs.sh ./ -COPY auth ./auth/ +COPY pkg ./pkg/ COPY cmd ./cmd/ -COPY config ./config/ -COPY controllers ./controllers/ -COPY misc ./misc/ COPY Makefile ./ COPY example_conf.toml ./ diff -r a9440a4826aa -r c1047fd04a3a auth/connection.go --- a/auth/connection.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,99 +0,0 @@ -package auth - -import ( - "database/sql" - "errors" - "io" - "log" - "sync" - "time" - - "gemma.intevation.de/gemma/misc" -) - -var ErrNoSuchToken = errors.New("No such token") - -const ( - maxOpen = 16 - maxDBIdle = time.Minute * 5 -) - -type Connection struct { - session *Session - - access time.Time - db *sql.DB - refCount int - - mu sync.Mutex -} - -func (c *Connection) serialize(w io.Writer) error { - if err := c.session.serialize(w); err != nil { - return err - } - access, err := c.last().MarshalText() - if err != nil { - return err - } - wr := misc.BinWriter{w, nil} - wr.WriteBin(uint32(len(access))) - wr.WriteBin(access) - return wr.Err -} - -func (c *Connection) deserialize(r io.Reader) error { - session := new(Session) - if err := session.deserialize(r); err != nil { - return err - } - - rd := misc.BinReader{r, nil} - var l uint32 - rd.ReadBin(&l) - access := make([]byte, l) - rd.ReadBin(access) - - if rd.Err != nil { - return rd.Err - } - - var t time.Time - if err := t.UnmarshalText(access); err != nil { - return err - } - - *c = Connection{ - session: session, - access: t, - } - - return nil -} - -func (c *Connection) set(session *Session) { - c.session = session - c.touch() -} - -func (c *Connection) touch() { - c.mu.Lock() - c.access = time.Now() - c.mu.Unlock() -} - -func (c *Connection) last() time.Time { - c.mu.Lock() - access := c.access - c.mu.Unlock() - return access -} - -func (c *Connection) close() { - if c.db != nil { - if err := c.db.Close(); err != nil { - log.Printf("warn: %v\n", err) - } - c.db = nil - } -} diff -r a9440a4826aa -r c1047fd04a3a auth/middleware.go --- a/auth/middleware.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,79 +0,0 @@ -package auth - -import ( - "context" - "net/http" - "strings" -) - -type contextType int - -const ( - sessionKey contextType = iota - tokenKey -) - -func GetSession(req *http.Request) (*Session, bool) { - session, ok := req.Context().Value(sessionKey).(*Session) - return session, ok -} - -func GetToken(req *http.Request) (string, bool) { - token, ok := req.Context().Value(tokenKey).(string) - return token, ok -} - -func SessionMiddleware(next http.Handler) http.Handler { - - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - - auth := req.Header.Get("X-Gemma-Auth") - - token := strings.TrimSpace(auth) - if token == "" { - http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - return - } - - session := ConnPool.Session(token) - if session == nil { - http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - return - } - - ctx := req.Context() - ctx = context.WithValue(ctx, sessionKey, session) - ctx = context.WithValue(ctx, tokenKey, token) - req = req.WithContext(ctx) - - next.ServeHTTP(rw, req) - }) -} - -func SessionChecker(next http.Handler, check func(*Session) bool) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - claims, ok := GetSession(req) - if !ok || !check(claims) { - http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - return - } - next.ServeHTTP(rw, req) - }) -} - -func HasRole(roles ...string) func(*Session) bool { - return func(session *Session) bool { - for _, r1 := range roles { - if session.Roles.Has(r1) { - return true - } - } - return false - } -} - -func EnsureRole(roles ...string) func(http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { - return SessionMiddleware(SessionChecker(handler, HasRole(roles...))) - } -} diff -r a9440a4826aa -r c1047fd04a3a auth/opendb.go --- a/auth/opendb.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,69 +0,0 @@ -package auth - -import ( - "database/sql" - "fmt" - "strings" - - "gemma.intevation.de/gemma/config" - - _ "github.com/jackc/pgx/stdlib" -) - -const driver = "pgx" - -// dbQuote quotes strings to be able to contain whitespace -// and backslashes in database DSN strings. -var dbQuote = strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace - -// dbDSN creates a data source name suitable for sql.Open on -// PostgreSQL databases. -func dbDSN(host string, port uint, dbname, user, password string, sslmode string) string { - return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=%s", - dbQuote(host), port, dbQuote(dbname), - dbQuote(user), dbQuote(password), sslmode) -} - -func OpenDB(user, password string) (*sql.DB, error) { - dsn := dbDSN( - config.DBHost(), config.DBPort(), - config.DBName(), - user, password, - config.DBSSLMode()) - return sql.Open(driver, dsn) -} - -const allRoles = ` -WITH RECURSIVE cte AS ( - SELECT oid FROM pg_roles WHERE rolname = current_user - UNION ALL - SELECT m.roleid - FROM cte - JOIN pg_auth_members m ON m.member = cte.oid -) -SELECT rolname FROM pg_roles -WHERE oid IN (SELECT oid FROM cte) AND rolname <> current_user` - -func AllOtherRoles(user, password string) ([]string, error) { - db, err := OpenDB(user, password) - if err != nil { - return nil, err - } - defer db.Close() - rows, err := db.Query(allRoles) - if err != nil { - return nil, err - } - defer rows.Close() - - roles := []string{} // explicit empty by intention. - - for rows.Next() { - var role string - if err := rows.Scan(&role); err != nil { - return nil, err - } - roles = append(roles, role) - } - return roles, rows.Err() -} diff -r a9440a4826aa -r c1047fd04a3a auth/pool.go --- a/auth/pool.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,322 +0,0 @@ -package auth - -import ( - "bytes" - "database/sql" - "log" - "time" - - bolt "github.com/coreos/bbolt" -) - -// ConnPool is the global connection pool. -var ConnPool *ConnectionPool - -type ConnectionPool struct { - storage *bolt.DB - conns map[string]*Connection - cmds chan func(*ConnectionPool) -} - -var sessionsBucket = []byte("sessions") - -func NewConnectionPool(filename string) (*ConnectionPool, error) { - - pcp := &ConnectionPool{ - conns: make(map[string]*Connection), - cmds: make(chan func(*ConnectionPool)), - } - if err := pcp.openStorage(filename); err != nil { - return nil, err - } - go pcp.run() - return pcp, nil -} - -// openStorage opens a storage file. -func (pcp *ConnectionPool) openStorage(filename string) error { - - // No file, nothing to restore/persist. - if filename == "" { - return nil - } - - db, err := bolt.Open(filename, 0600, nil) - if err != nil { - return err - } - - err = db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucketIfNotExists(sessionsBucket) - if err != nil { - return err - } - - // pre-load sessions - c := b.Cursor() - - for k, v := c.First(); k != nil; k, v = c.Next() { - var conn Connection - if err := conn.deserialize(bytes.NewReader(v)); err != nil { - return err - } - pcp.conns[string(k)] = &conn - } - - return nil - }) - - if err != nil { - db.Close() - return err - } - - pcp.storage = db - return nil -} - -func (pcp *ConnectionPool) run() { - for { - select { - case cmd := <-pcp.cmds: - cmd(pcp) - case <-time.After(time.Minute): - pcp.cleanDB() - case <-time.After(time.Minute * 5): - pcp.cleanToken() - } - } -} - -func (pcp *ConnectionPool) cleanDB() { - valid := time.Now().Add(-maxDBIdle) - for _, con := range pcp.conns { - if con.refCount <= 0 && con.last().Before(valid) { - con.close() - } - } -} - -func (pcp *ConnectionPool) cleanToken() { - now := time.Now() - for token, con := range pcp.conns { - expires := time.Unix(con.session.ExpiresAt, 0) - if expires.Before(now) { - // TODO: Be more graceful here? - con.close() - delete(pcp.conns, token) - pcp.remove(token) - } - } -} - -func (pcp *ConnectionPool) remove(token string) { - if pcp.storage == nil { - return - } - err := pcp.storage.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(sessionsBucket) - return b.Delete([]byte(token)) - }) - if err != nil { - log.Printf("error: %v\n", err) - } -} - -func (pcp *ConnectionPool) Delete(token string) bool { - res := make(chan bool) - pcp.cmds <- func(pcp *ConnectionPool) { - conn, found := pcp.conns[token] - if !found { - res <- false - return - } - conn.close() - delete(pcp.conns, token) - pcp.remove(token) - res <- true - } - return <-res -} - -func (pcp *ConnectionPool) store(token string, con *Connection) { - if pcp.storage == nil { - return - } - err := pcp.storage.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(sessionsBucket) - var buf bytes.Buffer - if err := con.serialize(&buf); err != nil { - return err - } - return b.Put([]byte(token), buf.Bytes()) - }) - if err != nil { - log.Printf("error: %v\n", err) - } -} - -func (pcp *ConnectionPool) Add(token string, session *Session) *Connection { - res := make(chan *Connection) - - pcp.cmds <- func(cp *ConnectionPool) { - con := pcp.conns[token] - if con == nil { - con = &Connection{} - pcp.conns[token] = con - } - con.set(session) - pcp.store(token, con) - res <- con - } - - con := <-res - return con -} - -func (pcp *ConnectionPool) Renew(token string) (string, error) { - - type result struct { - newToken string - err error - } - - resCh := make(chan result) - - pcp.cmds <- func(cp *ConnectionPool) { - con := pcp.conns[token] - if con == nil { - resCh <- result{err: ErrNoSuchToken} - } else { - delete(pcp.conns, token) - pcp.remove(token) - newToken := GenerateSessionKey() - // TODO: Ensure that this is not racy! - con.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix() - pcp.conns[newToken] = con - pcp.store(newToken, con) - resCh <- result{newToken: newToken} - } - } - - r := <-resCh - return r.newToken, r.err -} - -func (pcp *ConnectionPool) trim(conn *Connection) { - - conn.refCount-- - - for { - least := time.Now() - var count int - var oldest *Connection - - for _, con := range pcp.conns { - if con.db != nil && con.refCount <= 0 { - if last := con.last(); last.Before(least) { - least = last - oldest = con - } - count++ - } - } - if count <= maxOpen { - break - } - oldest.close() - } -} - -func (pcp *ConnectionPool) Do(token string, fn func(*sql.DB) error) error { - - type result struct { - con *Connection - err error - } - - res := make(chan result) - - pcp.cmds <- func(pcp *ConnectionPool) { - con := pcp.conns[token] - if con == nil { - res <- result{err: ErrNoSuchToken} - return - } - con.touch() - // store the session here. The ref counting for - // open db connections is irrelevant for persistence - // as they all come up closed when the system reboots. - pcp.store(token, con) - - if con.db != nil { - con.refCount++ - res <- result{con: con} - return - } - - session := con.session - db, err := OpenDB(session.User, session.Password) - if err != nil { - res <- result{err: err} - return - } - con.db = db - con.refCount++ - res <- result{con: con} - } - - r := <-res - - if r.err != nil { - return r.err - } - - defer func() { - pcp.cmds <- func(pcp *ConnectionPool) { - pcp.trim(r.con) - } - }() - - return fn(r.con.db) -} - -func (pcp *ConnectionPool) Session(token string) *Session { - res := make(chan *Session) - pcp.cmds <- func(pcp *ConnectionPool) { - con := pcp.conns[token] - if con == nil { - res <- nil - } else { - con.touch() - pcp.store(token, con) - res <- con.session - } - } - return <-res -} - -func (pcp *ConnectionPool) Logout(user string) { - pcp.cmds <- func(pcp *ConnectionPool) { - for token, con := range pcp.conns { - if con.session.User == user { - if db := con.db; db != nil { - con.db = nil - db.Close() - } - delete(pcp.conns, token) - pcp.remove(token) - } - } - } -} - -func (pcp *ConnectionPool) Shutdown() error { - if db := pcp.storage; db != nil { - log.Println("info: shutdown persistent connection pool.") - pcp.storage = nil - return db.Close() - } - log.Println("info: shutdown in-memory connection pool.") - return nil -} diff -r a9440a4826aa -r c1047fd04a3a auth/session.go --- a/auth/session.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,90 +0,0 @@ -package auth - -import ( - "encoding/base64" - "io" - "time" - - "gemma.intevation.de/gemma/common" - "gemma.intevation.de/gemma/misc" -) - -type Roles []string - -type Session struct { - ExpiresAt int64 `json:"expires"` - User string `json:"user"` - Password string `json:"password"` - Roles Roles `json:"roles"` -} - -func (r Roles) Has(role string) bool { - for _, x := range r { - if x == role { - return true - } - } - return false -} - -const ( - sessionKeyLength = 20 - maxTokenValid = time.Hour * 3 -) - -func NewSession(user, password string, roles []string) *Session { - - // Create the Claims - return &Session{ - ExpiresAt: time.Now().Add(maxTokenValid).Unix(), - User: user, - Password: password, - Roles: roles, - } -} - -func (s *Session) serialize(w io.Writer) error { - wr := misc.BinWriter{w, nil} - wr.WriteBin(s.ExpiresAt) - wr.WriteString(s.User) - wr.WriteString(s.Password) - wr.WriteBin(uint32(len(s.Roles))) - for _, role := range s.Roles { - wr.WriteString(role) - } - return wr.Err -} - -func (s *Session) deserialize(r io.Reader) error { - var x Session - var n uint32 - rd := misc.BinReader{r, nil} - rd.ReadBin(&x.ExpiresAt) - rd.ReadString(&x.User) - rd.ReadString(&x.Password) - rd.ReadBin(&n) - x.Roles = make(Roles, n) - for i := uint32(0); n > 0 && i < n; i++ { - rd.ReadString(&x.Roles[i]) - } - if rd.Err == nil { - *s = x - } - return rd.Err -} - -func GenerateSessionKey() string { - return base64.URLEncoding.EncodeToString( - common.GenerateRandomKey(sessionKeyLength)) -} - -func GenerateSession(user, password string) (string, *Session, error) { - roles, err := AllOtherRoles(user, password) - if err != nil { - return "", nil, err - } - token := GenerateSessionKey() - session := NewSession(user, password, roles) - ConnPool.Add(token, session) - return token, session, nil -} diff -r a9440a4826aa -r c1047fd04a3a cmd/gemma/geoserver.go --- a/cmd/gemma/geoserver.go Wed Aug 15 17:13:28 2018 +0200 +++ b/cmd/gemma/geoserver.go Wed Aug 15 17:30:50 2018 +0200 @@ -7,8 +7,8 @@ "log" "net/http" - "gemma.intevation.de/gemma/config" - "gemma.intevation.de/gemma/misc" + "gemma.intevation.de/gemma/pkg/config" + "gemma.intevation.de/gemma/pkg/misc" ) const ( diff -r a9440a4826aa -r c1047fd04a3a cmd/gemma/main.go --- a/cmd/gemma/main.go Wed Aug 15 17:13:28 2018 +0200 +++ b/cmd/gemma/main.go Wed Aug 15 17:30:50 2018 +0200 @@ -14,9 +14,9 @@ "github.com/rs/cors" "github.com/spf13/cobra" - "gemma.intevation.de/gemma/auth" - "gemma.intevation.de/gemma/config" - "gemma.intevation.de/gemma/controllers" + "gemma.intevation.de/gemma/pkg/auth" + "gemma.intevation.de/gemma/pkg/config" + "gemma.intevation.de/gemma/pkg/controllers" ) func prepareConnectionPool() { diff -r a9440a4826aa -r c1047fd04a3a common/random.go --- a/common/random.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,48 +0,0 @@ -package common - -import ( - "bytes" - "crypto/rand" - "io" - "log" - "math/big" -) - -func GenerateRandomKey(length int) []byte { - k := make([]byte, length) - if _, err := io.ReadFull(rand.Reader, k); err != nil { - return nil - } - return k -} - -func RandomString(n int) string { - - const ( - special = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" - alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + - "abcdefghijklmnopqrstuvwxyz" + - "0123456789" + - special - ) - - max := big.NewInt(int64(len(alphabet))) - out := make([]byte, n) - - for i := 0; i < 1000; i++ { - for i := range out { - v, err := rand.Int(rand.Reader, max) - if err != nil { - log.Panicf("error: %v\n", err) - } - out[i] = alphabet[v.Int64()] - } - // Ensure at least one special char. - if bytes.IndexAny(out, special) >= 0 { - return string(out) - } - } - log.Println("warn: Your random generator may be broken.") - out[0] = special[0] - return string(out) -} diff -r a9440a4826aa -r c1047fd04a3a config/config.go --- a/config/config.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,186 +0,0 @@ -package config - -import ( - "crypto/sha256" - "fmt" - "log" - "sync" - - homedir "github.com/mitchellh/go-homedir" - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "gemma.intevation.de/gemma/common" -) - -// This is not part of the persistent config. -var configFile string - -func ConfigFile() string { return configFile } - -func DBHost() string { return viper.GetString("dbhost") } -func DBPort() uint { return uint(viper.GetInt32("dbport")) } -func DBName() string { return viper.GetString("dbname") } -func DBSSLMode() string { return viper.GetString("dbssl") } -func SessionStore() string { return viper.GetString("sessions") } -func Web() string { return viper.GetString("web") } -func WebHost() string { return viper.GetString("host") } -func WebPort() uint { return uint(viper.GetInt32("port")) } - -func ServiceUser() string { return viper.GetString("service-user") } -func ServicePassword() string { return viper.GetString("service-password") } - -func SysAdmin() string { return viper.GetString("sys-admin") } -func SysAdminPassword() string { return viper.GetString("sys-admin-password") } - -func MailHost() string { return viper.GetString("mail-host") } -func MailPort() uint { return uint(viper.GetInt32("mail-port")) } -func MailUser() string { return viper.GetString("mail-user") } -func MailPassword() string { return viper.GetString("mail-password") } -func MailFrom() string { return viper.GetString("mail-from") } -func MailHelo() string { return viper.GetString("mail-helo") } - -func AllowedOrigins() []string { return viper.GetStringSlice("allowed-origins") } - -func ExternalWFSs() map[string]interface{} { return viper.GetStringMap("external-wfs") } - -func GeoServerURL() string { return viper.GetString("geoserver-url") } -func GeoServerUser() string { return viper.GetString("geoserver-user") } -func GeoServerPassword() string { return viper.GetString("geoserver-password") } -func GeoServerTables() []string { return viper.GetStringSlice("geoserver-tables") } - -var ( - proxyKeyOnce sync.Once - proxyKey []byte - - proxyPrefixOnce sync.Once - proxyPrefix string -) - -func ProxyKey() []byte { - fetchKey := func() { - if proxyKey == nil { - key := []byte(viper.GetString("proxy-key")) - if len(key) == 0 { - key = common.GenerateRandomKey(64) - } - hash := sha256.New() - hash.Write(key) - proxyKey = hash.Sum(nil) - } - } - proxyKeyOnce.Do(fetchKey) - return proxyKey -} - -func ProxyPrefix() string { - fetchPrefix := func() { - if proxyPrefix == "" { - proxyPrefix = fmt.Sprintf("http://%s:%d", WebHost(), WebPort()) - } - } - proxyPrefixOnce.Do(fetchPrefix) - return proxyPrefix -} - -var RootCmd = &cobra.Command{ - Use: "gemma", - Short: "gemma is a server for waterway monitoring and management", -} - -var allowedOrigins = []string{ - // TODO: Fill me! -} - -var geoTables = []string{ - "fairway_dimensions", -} - -func init() { - cobra.OnInitialize(initConfig) - fl := RootCmd.PersistentFlags() - fl.StringVarP(&configFile, "config", "c", "", "config file (default is $HOME/.gemma.toml)") - - vbind := func(name string) { viper.BindPFlag(name, fl.Lookup(name)) } - - str := func(name, value, usage string) { - fl.String(name, value, usage) - vbind(name) - } - strP := func(name, shorthand, value, usage string) { - fl.StringP(name, shorthand, value, usage) - vbind(name) - } - ui := func(name string, value uint, usage string) { - fl.Uint(name, value, usage) - vbind(name) - } - uiP := func(name, shorthand string, value uint, usage string) { - fl.UintP(name, shorthand, value, usage) - vbind(name) - } - strSl := func(name string, value []string, usage string) { - fl.StringSlice(name, value, usage) - vbind(name) - } - - strP("dbhost", "H", "localhost", "host of the database") - uiP("dbport", "P", 5432, "port of the database") - strP("dbname", "d", "gemma", "name of the database") - strP("dbssl", "S", "prefer", "SSL mode of the database") - - strP("sessions", "s", "", "path to the sessions file") - - strP("web", "w", "./web", "path to the web files") - strP("host", "o", "localhost", "host of the web app") - uiP("port", "p", 8000, "port of the web app") - - str("service-user", "postgres", "user to do service tasks") - str("service-password", "", "password of user to do service tasks") - - str("sys-admin", "postgres", "user to do admin tasks") - str("sys-admin-password", "", "password of user to do admin tasks") - - str("mail-host", "localhost", "server to send mail with") - ui("mail-port", 465, "port of server to send mail with") - str("mail-user", "gemma", "user to send mail with") - str("mail-password", "", "password of user to send mail with") - str("mail-from", "noreplay@localhost", "from line of mails") - str("mail-helo", "localhost", "name of server to send mail from.") - - strSl("allowed-origins", allowedOrigins, "allow access for remote origins") - - str("geoserver-url", "http://localhost:8080/geoserver", "URL to GeoServer") - str("geoserver-user", "admin", "GeoServer user") - str("geoserver-password", "geoserver", "GeoServer password") - strSl("geoserver-tables", geoTables, "tables to publish with GeoServer") - - str("proxy-key", "", `signing key for proxy URLs. Defaults to random key.`) - str("proxy-prefix", "", `URL prefix of proxy. Defaults to "http://${web-host}:${web-port}"`) - -} - -func initConfig() { - // Don't forget to read config either from cfgFile or from home directory! - if configFile != "" { - // Use config file from the flag. - viper.SetConfigFile(configFile) - } else { - // Find home directory. - home, err := homedir.Dir() - if err != nil { - log.Fatalf("error: %v\n", err) - } - - // Search config in home directory with name ".cobra" (without extension). - viper.AddConfigPath(home) - viper.SetConfigName(".gemma") - } - if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok && configFile == "" { - // Don't bother if not found. - return - } - log.Fatalf("Can't read config: %v\n", err) - } -} diff -r a9440a4826aa -r c1047fd04a3a controllers/json.go --- a/controllers/json.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,111 +0,0 @@ -package controllers - -import ( - "database/sql" - "encoding/json" - "fmt" - "log" - "net/http" - - "github.com/jackc/pgx" - - "gemma.intevation.de/gemma/auth" -) - -type JSONResult struct { - Code int - Result interface{} -} - -type JSONHandler struct { - Input func() interface{} - Handle func(interface{}, *http.Request, *sql.DB) (JSONResult, error) -} - -type JSONError struct { - Code int - Message string -} - -func (je JSONError) Error() string { - return fmt.Sprintf("%d: %s", je.Code, je.Message) -} - -func (j *JSONHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - - var input interface{} - if j.Input != nil { - input = j.Input() - defer req.Body.Close() - if err := json.NewDecoder(req.Body).Decode(input); err != nil { - http.Error(rw, "error: "+err.Error(), http.StatusBadRequest) - return - } - } - - var jr JSONResult - var err error - - if token, ok := auth.GetToken(req); ok { - err = auth.ConnPool.Do(token, func(db *sql.DB) (err error) { - jr, err = j.Handle(input, req, db) - return err - }) - } else { - jr, err = j.Handle(input, req, nil) - } - - if err != nil { - switch e := err.(type) { - case pgx.PgError: - var res = struct { - Result string `json:"result"` - Code string `json:"code"` - Message string `json:"message"` - }{ - Result: "failure", - Code: e.Code, - Message: e.Message, - } - rw.Header().Set("Content-Type", "application/json") - rw.WriteHeader(http.StatusInternalServerError) - if err := json.NewEncoder(rw).Encode(&res); err != nil { - log.Printf("error: %v\n", err) - } - case JSONError: - rw.Header().Set("Content-Type", "application/json") - if e.Code == 0 { - e.Code = http.StatusInternalServerError - } - rw.WriteHeader(e.Code) - var res = struct { - Message string `json:"message"` - }{ - Message: e.Message, - } - if err := json.NewEncoder(rw).Encode(&res); err != nil { - log.Printf("error: %v\n", err) - } - default: - log.Printf("err: %v\n", err) - http.Error(rw, - "error: "+err.Error(), - http.StatusInternalServerError) - } - return - } - - if jr.Code == 0 { - jr.Code = http.StatusOK - } - - if jr.Code != http.StatusNoContent { - rw.Header().Set("Content-Type", "application/json") - } - rw.WriteHeader(jr.Code) - if jr.Code != http.StatusNoContent { - if err := json.NewEncoder(rw).Encode(jr.Result); err != nil { - log.Printf("error: %v\n", err) - } - } -} diff -r a9440a4826aa -r c1047fd04a3a controllers/proxy.go --- a/controllers/proxy.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,380 +0,0 @@ -package controllers - -import ( - "compress/flate" - "compress/gzip" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/xml" - "io" - "io/ioutil" - "log" - "net/http" - "net/url" - "regexp" - "strings" - "time" - - "gemma.intevation.de/gemma/config" - "github.com/gorilla/mux" - "golang.org/x/net/html/charset" -) - -// proxyBlackList is a set of URLs that should not be rewritten by the proxy. -var proxyBlackList = map[string]struct{}{ - "http://www.w3.org/2001/XMLSchema-instance": struct{}{}, - "http://www.w3.org/1999/xlink": struct{}{}, - "http://www.w3.org/2001/XMLSchema": struct{}{}, - "http://www.w3.org/XML/1998/namespace": struct{}{}, - "http://www.opengis.net/wfs/2.0": struct{}{}, - "http://www.opengis.net/ows/1.1": struct{}{}, - "http://www.opengis.net/gml/3.2": struct{}{}, - "http://www.opengis.net/fes/2.0": struct{}{}, - "http://schemas.opengis.net/gml": struct{}{}, -} - -func findEntry(entry string) (string, bool) { - external := config.ExternalWFSs() - if external == nil || len(external) == 0 { - return "", false - } - alias, found := external[entry] - if !found { - return "", false - } - data, ok := alias.(map[string]interface{}) - if !ok { - return "", false - } - urlS, found := data["url"] - if !found { - return "", false - } - url, ok := urlS.(string) - return url, ok -} - -func proxyDirector(req *http.Request) { - - log.Printf("proxyDirector: %s\n", req.RequestURI) - - abort := func(format string, args ...interface{}) { - log.Printf(format, args...) - panic(http.ErrAbortHandler) - } - - vars := mux.Vars(req) - - var s string - - if entry, found := vars["entry"]; found { - if s, found = findEntry(entry); !found { - abort("Cannot find entry '%s'\n", entry) - } - } else { - expectedMAC, err := base64.URLEncoding.DecodeString(vars["hash"]) - if err != nil { - abort("Cannot base64 decode hash: %v\n", err) - } - url, err := base64.URLEncoding.DecodeString(vars["url"]) - if err != nil { - abort("Cannot base64 decode url: %v\n", err) - } - - mac := hmac.New(sha256.New, config.ProxyKey()) - mac.Write(url) - messageMAC := mac.Sum(nil) - - s = string(url) - - if !hmac.Equal(messageMAC, expectedMAC) { - abort("HMAC of URL %s failed.\n", s) - } - } - - nURL := s + "?" + req.URL.RawQuery - //log.Printf("%v\n", nURL) - - u, err := url.Parse(nURL) - if err != nil { - abort("Invalid url: %v\n", err) - } - req.URL = u - - req.Host = u.Host - //req.Header.Del("If-None-Match") - //log.Printf("headers: %v\n", req.Header) -} - -type nopCloser struct { - io.Writer -} - -func (nopCloser) Close() error { return nil } - -func encoding(h http.Header) ( - func(io.Reader) (io.ReadCloser, error), - func(io.Writer) (io.WriteCloser, error), -) { - switch enc := h.Get("Content-Encoding"); { - case strings.Contains(enc, "gzip"): - log.Println("gzip compression") - return func(r io.Reader) (io.ReadCloser, error) { - return gzip.NewReader(r) - }, - func(w io.Writer) (io.WriteCloser, error) { - return gzip.NewWriter(w), nil - } - case strings.Contains(enc, "deflate"): - log.Println("Deflate compression") - return func(r io.Reader) (io.ReadCloser, error) { - return flate.NewReader(r), nil - }, - func(w io.Writer) (io.WriteCloser, error) { - return flate.NewWriter(w, flate.DefaultCompression) - } - default: - log.Println("No content compression") - return func(r io.Reader) (io.ReadCloser, error) { - if r2, ok := r.(io.ReadCloser); ok { - return r2, nil - } - return ioutil.NopCloser(r), nil - }, - func(w io.Writer) (io.WriteCloser, error) { - if w2, ok := w.(io.WriteCloser); ok { - return w2, nil - } - return nopCloser{w}, nil - } - } -} - -func proxyModifyResponse(resp *http.Response) error { - - if !isXML(resp.Header) { - return nil - } - - pr, pw := io.Pipe() - - var ( - r io.ReadCloser - w io.WriteCloser - err error - ) - - reader, writer := encoding(resp.Header) - - if r, err = reader(resp.Body); err != nil { - return err - } - - if w, err = writer(pw); err != nil { - return err - } - - go func(force io.ReadCloser) { - start := time.Now() - defer func() { - //r.Close() - w.Close() - pw.Close() - force.Close() - log.Printf("rewrite took %s\n", time.Since(start)) - }() - if err := rewrite(w, r); err != nil { - log.Printf("rewrite failed: %v\n", err) - return - } - log.Println("rewrite successful") - }(resp.Body) - - resp.Body = pr - - return nil -} - -var xmlContentTypes = []string{ - "application/xml", - "text/xml", - "application/gml+xml", -} - -func isXML(h http.Header) bool { - for _, t := range h["Content-Type"] { - t = strings.ToLower(t) - for _, ct := range xmlContentTypes { - if strings.Contains(t, ct) { - return true - } - } - } - return false -} - -var replaceRe = regexp.MustCompile(`\b(https?://[^\s\?]*)`) - -func replace(s string) string { - - proxyKey := config.ProxyKey() - proxyPrefix := config.ProxyPrefix() + "/api/proxy/" - - return replaceRe.ReplaceAllStringFunc(s, func(s string) string { - if _, found := proxyBlackList[s]; found { - return s - } - mac := hmac.New(sha256.New, proxyKey) - b := []byte(s) - mac.Write(b) - expectedMAC := mac.Sum(nil) - - hash := base64.URLEncoding.EncodeToString(expectedMAC) - enc := base64.URLEncoding.EncodeToString(b) - return proxyPrefix + hash + "/" + enc - }) -} - -func rewrite(w io.Writer, r io.Reader) error { - - decoder := xml.NewDecoder(r) - decoder.CharsetReader = charset.NewReaderLabel - - encoder := xml.NewEncoder(w) - - var n nsdef - -tokens: - for { - tok, err := decoder.Token() - switch { - case tok == nil && err == io.EOF: - break tokens - case err != nil: - return err - } - - switch t := tok.(type) { - case xml.StartElement: - t = t.Copy() - - isDef := n.isDef(t.Name.Space) - n = n.push() - - for i := range t.Attr { - t.Attr[i].Value = replace(t.Attr[i].Value) - n.checkDef(&t.Attr[i]) - } - - for i := range t.Attr { - n.adjust(&t.Attr[i]) - } - - switch { - case isDef: - t.Name.Space = "" - default: - if s := n.lookup(t.Name.Space); s != "" { - t.Name.Space = "" - t.Name.Local = s + ":" + t.Name.Local - } - } - tok = t - - case xml.CharData: - tok = xml.CharData(replace(string(t))) - - case xml.EndElement: - s := n.lookup(t.Name.Space) - - n = n.pop() - - if n.isDef(t.Name.Space) { - t.Name.Space = "" - } else if s != "" { - t.Name.Space = "" - t.Name.Local = s + ":" + t.Name.Local - } - tok = t - } - - if err := encoder.EncodeToken(tok); err != nil { - return err - } - } - - return encoder.Flush() -} - -type nsframe struct { - def string - ns map[string]string -} - -type nsdef []nsframe - -func (n nsdef) setDef(def string) { - if l := len(n); l > 0 { - n[l-1].def = def - } -} - -func (n nsdef) isDef(s string) bool { - for i := len(n) - 1; i >= 0; i-- { - if x := n[i].def; x != "" { - return s == x - } - } - return false -} - -func (n nsdef) define(ns, s string) { - if l := len(n); l > 0 { - n[l-1].ns[ns] = s - } -} - -func (n nsdef) lookup(ns string) string { - for i := len(n) - 1; i >= 0; i-- { - if s := n[i].ns[ns]; s != "" { - return s - } - } - return "" -} - -func (n nsdef) checkDef(at *xml.Attr) { - if at.Name.Space == "" && at.Name.Local == "xmlns" { - n.setDef(at.Value) - } -} - -func (n nsdef) adjust(at *xml.Attr) { - switch { - case at.Name.Space == "xmlns": - n.define(at.Value, at.Name.Local) - at.Name.Local = "xmlns:" + at.Name.Local - at.Name.Space = "" - - case at.Name.Space != "": - if n.isDef(at.Name.Space) { - at.Name.Space = "" - } else if s := n.lookup(at.Name.Space); s != "" { - at.Name.Local = s + ":" + at.Name.Local - at.Name.Space = "" - } - } -} - -func (n nsdef) push() nsdef { - return append(n, nsframe{ns: make(map[string]string)}) -} - -func (n nsdef) pop() nsdef { - if l := len(n); l > 0 { - n[l-1] = nsframe{} - n = n[:l-1] - } - return n -} diff -r a9440a4826aa -r c1047fd04a3a controllers/pwreset.go --- a/controllers/pwreset.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,279 +0,0 @@ -package controllers - -import ( - "bytes" - "database/sql" - "encoding/hex" - "log" - "net/http" - "os/exec" - "strings" - "text/template" - "time" - - "github.com/gorilla/mux" - - "gemma.intevation.de/gemma/auth" - "gemma.intevation.de/gemma/common" - "gemma.intevation.de/gemma/config" - "gemma.intevation.de/gemma/misc" -) - -const ( - insertRequestSQL = `INSERT INTO pw_reset.password_reset_requests - (hash, username) VALUES ($1, $2)` - - countRequestsSQL = `SELECT count(*) FROM pw_reset.password_reset_requests` - - countRequestsUserSQL = `SELECT count(*) FROM pw_reset.password_reset_requests - WHERE username = $1` - - deleteRequestSQL = `DELETE FROM pw_reset.password_reset_requests - WHERE hash = $1` - - findRequestSQL = `SELECT lu.email_address, lu.username - FROM pw_reset.password_reset_requests prr - JOIN pw_reset.list_users lu on prr.username = lu.username - WHERE prr.hash = $1` - - cleanupRequestsSQL = `DELETE FROM pw_reset.password_reset_requests - WHERE issued < $1` - - userExistsSQL = `SELECT email_address - FROM pw_reset.list_users WHERE username = $1` - - updatePasswordSQL = `UPDATE pw_reset.list_users - SET pw = $1 WHERE username = $2` -) - -const ( - hashLength = 16 - passwordLength = 20 - passwordResetValid = 12 * time.Hour - maxPasswordResets = 1000 - maxPasswordRequestsPerUser = 5 - cleanupPause = 15 * time.Minute -) - -var ( - passwordResetRequestMailTmpl = template.Must( - template.New("request").Parse(`You have requested a password change -for your account {{ .User }} on -{{ .HTTPS }}://{{ .Server }} - -Please follow this link to get to the page where you can change your password. - -{{ .HTTPS }}://{{ .Server }}/api/users/passwordreset/{{ .Hash }} - -The link is only valid for 12 hours. - -Best regards - Your service team`)) - - passwordResetMailTmpl = template.Must( - template.New("reset").Parse(`Your password for your account {{ .User }} on -{{ .HTTPS }}://{{ .Server }} - -has been changed to - {{ .Password }} - -Change it as soon as possible. - -Best regards - Your service team`)) -) - -func asServiceUser(fn func(*sql.DB) error) error { - db, err := auth.OpenDB(config.ServiceUser(), config.ServicePassword()) - if err == nil { - defer db.Close() - err = fn(db) - } - return err -} - -func init() { - go removeOutdated() -} - -func removeOutdated() { - for { - time.Sleep(cleanupPause) - err := asServiceUser(func(db *sql.DB) error { - good := time.Now().Add(-passwordResetValid) - _, err := db.Exec(cleanupRequestsSQL, good) - return err - }) - if err != nil { - log.Printf("error: %v\n", err) - } - } -} - -func requestMessageBody(https, user, hash, server string) string { - var content = struct { - User string - HTTPS string - Server string - Hash string - }{ - User: user, - HTTPS: https, - Server: server, - Hash: hash, - } - var buf bytes.Buffer - if err := passwordResetRequestMailTmpl.Execute(&buf, &content); err != nil { - log.Printf("error: %v\n", err) - } - return buf.String() -} - -func changedMessageBody(https, user, password, server string) string { - var content = struct { - User string - HTTPS string - Server string - Password string - }{ - User: user, - HTTPS: https, - Server: server, - Password: password, - } - var buf bytes.Buffer - if err := passwordResetMailTmpl.Execute(&buf, &content); err != nil { - log.Printf("error: %v\n", err) - } - return buf.String() -} - -func useHTTPS(req *http.Request) string { - if strings.ToLower(req.URL.Scheme) == "https" { - return "https" - } - return "http" -} - -func generateHash() string { - return hex.EncodeToString(common.GenerateRandomKey(hashLength)) -} - -func generateNewPassword() string { - // First try pwgen - out, err := exec.Command("pwgen", "-y", "20", "1").Output() - if err == nil { - return strings.TrimSpace(string(out)) - } - // Use internal generator. - return common.RandomString(20) -} - -func passwordResetRequest( - input interface{}, - req *http.Request, - _ *sql.DB, -) (jr JSONResult, err error) { - - user := input.(*PWResetUser) - - if user.User == "" { - err = JSONError{http.StatusBadRequest, "Invalid user name"} - return - } - - var hash, email string - - if err = asServiceUser(func(db *sql.DB) error { - - var count int64 - if err := db.QueryRow(countRequestsSQL).Scan(&count); err != nil { - return err - } - - // Limit total number of password requests. - if count >= maxPasswordResets { - return JSONError{ - Code: http.StatusServiceUnavailable, - Message: "Too much password reset request", - } - } - - err := db.QueryRow(userExistsSQL, user.User).Scan(&email) - - switch { - case err == sql.ErrNoRows: - return JSONError{http.StatusNotFound, "User does not exist."} - case err != nil: - return err - } - - if err := db.QueryRow(countRequestsUserSQL, user.User).Scan(&count); err != nil { - return err - } - - // Limit requests per user - if count >= maxPasswordRequestsPerUser { - return JSONError{ - Code: http.StatusServiceUnavailable, - Message: "Too much password reset requests for user", - } - } - - hash = generateHash() - _, err = db.Exec(insertRequestSQL, hash, user.User) - return err - }); err == nil { - body := requestMessageBody(useHTTPS(req), user.User, hash, req.Host) - - if err = misc.SendMail(email, "Password Reset Link", body); err == nil { - jr.Result = &struct { - SendTo string `json:"send-to"` - }{email} - } - } - return -} - -func passwordReset( - _ interface{}, - req *http.Request, - _ *sql.DB, -) (jr JSONResult, err error) { - - hash := mux.Vars(req)["hash"] - if _, err = hex.DecodeString(hash); err != nil { - err = JSONError{http.StatusBadRequest, "Invalid hash"} - return - } - - var email, user, password string - - if err = asServiceUser(func(db *sql.DB) error { - err := db.QueryRow(findRequestSQL, hash).Scan(&email, &user) - switch { - case err == sql.ErrNoRows: - return JSONError{http.StatusNotFound, "No such hash"} - case err != nil: - return err - } - password = generateNewPassword() - res, err := db.Exec(updatePasswordSQL, password, user) - if err != nil { - return err - } - if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { - return JSONError{http.StatusNotFound, "User not found"} - } - _, err = db.Exec(deleteRequestSQL, hash) - return err - }); err == nil { - body := changedMessageBody(useHTTPS(req), user, password, req.Host) - if err = misc.SendMail(email, "Password Reset Done", body); err == nil { - jr.Result = &struct { - SendTo string `json:"send-to"` - }{email} - } - } - return -} diff -r a9440a4826aa -r c1047fd04a3a controllers/routes.go --- a/controllers/routes.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,77 +0,0 @@ -package controllers - -import ( - "net/http" - "net/http/httputil" - - "gemma.intevation.de/gemma/auth" - - "github.com/gorilla/mux" -) - -func BindRoutes(m *mux.Router) { - - api := m.PathPrefix("/api").Subrouter() - - var ( - sysAdmin = auth.EnsureRole("sys_admin") - all = auth.EnsureRole("sys_admin", "waterway_admin", "waterway_user") - ) - - // User management. - api.Handle("/users", all(&JSONHandler{ - Handle: listUsers, - })).Methods(http.MethodGet) - - api.Handle("/users", sysAdmin(&JSONHandler{ - Input: func() interface{} { return new(User) }, - Handle: createUser, - })).Methods(http.MethodPost) - - api.Handle("/users/{user}", all(&JSONHandler{ - Handle: listUser, - })).Methods(http.MethodGet) - - api.Handle("/users/{user}", all(&JSONHandler{ - Input: func() interface{} { return new(User) }, - Handle: updateUser, - })).Methods(http.MethodPut) - - api.Handle("/users/{user}", sysAdmin(&JSONHandler{ - Handle: deleteUser, - })).Methods(http.MethodDelete) - - // Password resets. - api.Handle("/users/passwordreset", &JSONHandler{ - Input: func() interface{} { return new(PWResetUser) }, - Handle: passwordResetRequest, - }).Methods(http.MethodPost) - - api.Handle("/users/passwordreset/{hash}", &JSONHandler{ - Handle: passwordReset, - }).Methods(http.MethodGet) - - // Proxy for external WFSs. - proxy := &httputil.ReverseProxy{ - Director: proxyDirector, - ModifyResponse: proxyModifyResponse, - } - - api.Handle(`/proxy/{hash}/{url}`, proxy). - Methods( - http.MethodGet, http.MethodPost, - http.MethodPut, http.MethodDelete) - - api.Handle("/proxy/{entry}", proxy). - Methods( - http.MethodGet, http.MethodPost, - http.MethodPut, http.MethodDelete) - - // Token handling: Login/Logout. - api.HandleFunc("/login", login). - Methods(http.MethodGet, http.MethodPost) - api.Handle("/logout", auth.SessionMiddleware(http.HandlerFunc(logout))). - Methods(http.MethodGet, http.MethodPost) - api.Handle("/renew", auth.SessionMiddleware(http.HandlerFunc(renew))). - Methods(http.MethodGet, http.MethodPost) -} diff -r a9440a4826aa -r c1047fd04a3a controllers/token.go --- a/controllers/token.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,90 +0,0 @@ -package controllers - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - - "gemma.intevation.de/gemma/auth" -) - -func sendJSON(rw http.ResponseWriter, data interface{}) { - rw.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(rw).Encode(data); err != nil { - log.Printf("error: %v\n", err) - } -} - -func renew(rw http.ResponseWriter, req *http.Request) { - token, _ := auth.GetToken(req) - newToken, err := auth.ConnPool.Renew(token) - switch { - case err == auth.ErrNoSuchToken: - http.NotFound(rw, req) - return - case err != nil: - http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusInternalServerError) - return - } - - session, _ := auth.GetSession(req) - - var result = struct { - Token string `json:"token"` - Expires int64 `json:"expires"` - User string `json:"user"` - Roles []string `json:"roles"` - }{ - Token: newToken, - Expires: session.ExpiresAt, - User: session.User, - Roles: session.Roles, - } - - sendJSON(rw, &result) -} - -func logout(rw http.ResponseWriter, req *http.Request) { - token, _ := auth.GetToken(req) - deleted := auth.ConnPool.Delete(token) - if !deleted { - http.NotFound(rw, req) - return - } - rw.Header().Set("Content-Type", "text/plain") - fmt.Fprintln(rw, "token deleted") -} - -func login(rw http.ResponseWriter, req *http.Request) { - - var ( - user = req.FormValue("user") - password = req.FormValue("password") - ) - - if user == "" || password == "" { - http.Error(rw, "Invalid credentials", http.StatusBadRequest) - return - } - - token, session, err := auth.GenerateSession(user, password) - if err != nil { - http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusUnauthorized) - return - } - - var result = struct { - Token string `json:"token"` - Expires int64 `json:"expires"` - User string `json:"user"` - Roles []string `json:"roles"` - }{ - Token: token, - Expires: session.ExpiresAt, - User: session.User, - Roles: session.Roles, - } - - sendJSON(rw, &result) -} diff -r a9440a4826aa -r c1047fd04a3a controllers/types.go --- a/controllers/types.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,150 +0,0 @@ -package controllers - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "regexp" - "strings" -) - -type ( - Email string - Country string - Role string - - BoundingBox struct { - X1 float64 `json:"x1"` - Y1 float64 `json:"y1"` - X2 float64 `json:"x2"` - Y2 float64 `json:"y2"` - } - - User struct { - User string `json:"user"` - Role Role `json:"role"` - Password string `json:"password,omitempty"` - Email Email `json:"email"` - Country Country `json:"country"` - Extent *BoundingBox `json:"extent"` - } - - PWResetUser struct { - User string `json:"user"` - } -) - -var ( - // https://stackoverflow.com/questions/201323/how-to-validate-an-email-address-using-a-regular-expression - emailRe = regexp.MustCompile( - `(?:[a-z0-9!#$%&'*+/=?^_` + "`" + - `{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_` + "`" + - `{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]` + - `|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")` + - `@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?` + - `|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}` + - `(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]` + - `:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]` + - `|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])`) - - errNoEmailAddress = errors.New("Not a valid email address") - errNoString = errors.New("Not a string") -) - -func (e *Email) UnmarshalJSON(data []byte) error { - var s string - if err := json.Unmarshal(data, &s); err != nil { - return err - } - if !emailRe.MatchString(s) { - return errNoEmailAddress - } - *e = Email(s) - return nil -} - -func (e Email) Value() (driver.Value, error) { - return string(e), nil -} - -func (e *Email) Scan(src interface{}) (err error) { - if s, ok := src.(string); ok { - *e = Email(s) - } else { - err = errNoString - } - return -} - -var ( - validCountries = []string{ - "AT", "BG", "DE", "HU", "HR", - "MD", "RO", "RS", "SK", "UA", - } - errNoValidCountry = errors.New("Not a valid country") -) - -func (c *Country) UnmarshalJSON(data []byte) error { - var s string - if err := json.Unmarshal(data, &s); err != nil { - return err - } - s = strings.ToUpper(s) - for _, v := range validCountries { - if v == s { - *c = Country(v) - return nil - } - } - return errNoValidCountry -} - -func (c Country) Value() (driver.Value, error) { - return string(c), nil -} - -func (c *Country) Scan(src interface{}) (err error) { - if s, ok := src.(string); ok { - *c = Country(s) - } else { - err = errNoString - } - return -} - -var ( - validRoles = []string{ - "waterway_user", - "waterway_admin", - "sys_admin", - } - errNoValidRole = errors.New("Not a valid role") -) - -func (r Role) Value() (driver.Value, error) { - return string(r), nil -} - -func (r *Role) Scan(src interface{}) (err error) { - if s, ok := src.(string); ok { - *r = Role(s) - } else { - err = errNoString - } - return -} - -func (r *Role) UnmarshalJSON(data []byte) error { - var s string - if err := json.Unmarshal(data, &s); err != nil { - return err - } - s = strings.ToLower(s) - for _, v := range validRoles { - if v == s { - *r = Role(v) - return nil - } - } - return errNoValidRole -} diff -r a9440a4826aa -r c1047fd04a3a controllers/user.go --- a/controllers/user.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,288 +0,0 @@ -package controllers - -import ( - "database/sql" - "fmt" - "net/http" - - "github.com/gorilla/mux" - - "gemma.intevation.de/gemma/auth" -) - -const ( - createUserSQL = `INSERT INTO users.list_users - VALUES ($1, $2, $3, $4, NULL, $5)` - createUserExtentSQL = `INSERT INTO users.list_users - VALUES ($1, $2, $3, $4, - ST_MakeBox2D(ST_Point($5, $6), ST_Point($7, $8)), $9)` - - updateUserUnprivSQL = `UPDATE users.list_users - SET (pw, map_extent, email_address) - = ($2, ST_MakeBox2D(ST_Point($3, $4), ST_Point($5, $6)), $7) - WHERE username = $1` - updateUserSQL = `UPDATE users.list_users - SET (rolname, username, pw, country, map_extent, email_address) - = ($2, $3, $4, $5, NULL, $6) - WHERE username = $1` - updateUserExtentSQL = `UPDATE users.list_users - SET (rolname, username, pw, country, map_extent, email_address) - = ($2, $3, $4, $5, ST_MakeBox2D(ST_Point($6, $7), ST_Point($8, $9)), $10) - WHERE username = $1` - - deleteUserSQL = `DELETE FROM users.list_users WHERE username = $1` - - listUsersSQL = `SELECT - rolname, - username, - country, - email_address, - ST_XMin(map_extent), ST_YMin(map_extent), - ST_XMax(map_extent), ST_YMax(map_extent) -FROM users.list_users` - - listUserSQL = `SELECT - rolname, - country, - email_address, - ST_XMin(map_extent), ST_YMin(map_extent), - ST_XMax(map_extent), ST_YMax(map_extent) -FROM users.list_users -WHERE username = $1` -) - -func deleteUser( - _ interface{}, req *http.Request, - db *sql.DB, -) (jr JSONResult, err error) { - - user := mux.Vars(req)["user"] - if user == "" { - err = JSONError{http.StatusBadRequest, "error: user empty"} - return - } - - session, _ := auth.GetSession(req) - if session.User == user { - err = JSONError{http.StatusBadRequest, "error: cannot delete yourself"} - return - } - - var res sql.Result - - if res, err = db.Exec(deleteUserSQL, user); err != nil { - return - } - - if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { - err = JSONError{ - Code: http.StatusNotFound, - Message: fmt.Sprintf("Cannot find user %s.", user), - } - return - } - - // Running in a go routine should not be necessary. - go func() { auth.ConnPool.Logout(user) }() - - jr = JSONResult{Code: http.StatusNoContent} - return -} - -func updateUser( - input interface{}, req *http.Request, - db *sql.DB, -) (jr JSONResult, err error) { - - user := mux.Vars(req)["user"] - if user == "" { - err = JSONError{http.StatusBadRequest, "error: user empty"} - return - } - - newUser := input.(*User) - var res sql.Result - - if s, _ := auth.GetSession(req); s.Roles.Has("sys_admin") { - if newUser.Extent == nil { - res, err = db.Exec( - updateUserSQL, - user, - newUser.Role, - newUser.User, - newUser.Password, - newUser.Country, - newUser.Email, - ) - } else { - res, err = db.Exec( - updateUserExtentSQL, - user, - newUser.Role, - newUser.User, - newUser.Password, - newUser.Country, - newUser.Extent.X1, newUser.Extent.Y1, - newUser.Extent.X2, newUser.Extent.Y2, - newUser.Email, - ) - } - } else { - if newUser.Extent == nil { - err = JSONError{http.StatusBadRequest, "extent is mandatory"} - return - } - res, err = db.Exec( - updateUserUnprivSQL, - user, - newUser.Password, - newUser.Extent.X1, newUser.Extent.Y1, - newUser.Extent.X2, newUser.Extent.Y2, - newUser.Email, - ) - } - - if err != nil { - return - } - - if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { - err = JSONError{ - Code: http.StatusNotFound, - Message: fmt.Sprintf("Cannot find user %s.", user), - } - return - } - - if user != newUser.User { - // Running in a go routine should not be necessary. - go func() { auth.ConnPool.Logout(user) }() - } - - jr = JSONResult{ - Code: http.StatusCreated, - Result: struct { - Result string `json:"result"` - }{"success"}, - } - return -} - -func createUser( - input interface{}, req *http.Request, - db *sql.DB, -) (jr JSONResult, err error) { - - user := input.(*User) - - if user.Extent == nil { - _, err = db.Exec( - createUserSQL, - user.Role, - user.User, - user.Password, - user.Country, - user.Email, - ) - } else { - _, err = db.Exec( - createUserExtentSQL, - user.Role, - user.User, - user.Password, - user.Country, - user.Extent.X1, user.Extent.Y1, - user.Extent.X2, user.Extent.Y2, - user.Email, - ) - } - - if err != nil { - return - } - - jr = JSONResult{ - Code: http.StatusCreated, - Result: struct { - Result string `json:"result"` - }{"success"}, - } - return -} - -func listUsers( - _ interface{}, req *http.Request, - db *sql.DB, -) (jr JSONResult, err error) { - - var rows *sql.Rows - - rows, err = db.Query(listUsersSQL) - if err != nil { - return - } - defer rows.Close() - - var users []*User - - for rows.Next() { - user := &User{Extent: &BoundingBox{}} - if err = rows.Scan( - &user.Role, - &user.User, - &user.Country, - &user.Email, - &user.Extent.X1, &user.Extent.Y1, - &user.Extent.X2, &user.Extent.Y2, - ); err != nil { - return - } - users = append(users, user) - } - - jr = JSONResult{ - Result: struct { - Users []*User `json:"users"` - }{users}, - } - return -} - -func listUser( - _ interface{}, req *http.Request, - db *sql.DB, -) (jr JSONResult, err error) { - - user := mux.Vars(req)["user"] - if user == "" { - err = JSONError{http.StatusBadRequest, "error: user empty"} - return - } - - result := &User{ - User: user, - Extent: &BoundingBox{}, - } - - err = db.QueryRow(listUserSQL, user).Scan( - &result.Role, - &result.Country, - &result.Email, - &result.Extent.X1, &result.Extent.Y1, - &result.Extent.X2, &result.Extent.Y2, - ) - - switch { - case err == sql.ErrNoRows: - err = JSONError{ - Code: http.StatusNotFound, - Message: fmt.Sprintf("Cannot find user %s.", user), - } - return - case err != nil: - return - } - - jr.Result = result - return -} diff -r a9440a4826aa -r c1047fd04a3a misc/encode.go --- a/misc/encode.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,76 +0,0 @@ -package misc - -import ( - "encoding/base64" - "encoding/binary" - "io" -) - -type BinReader struct { - io.Reader - Err error -} - -func (r *BinReader) Read(buf []byte) (int, error) { - if r.Err != nil { - return 0, r.Err - } - var n int - n, r.Err = r.Read(buf) - return n, r.Err -} - -func (r *BinReader) ReadBin(x interface{}) { - if r.Err == nil { - r.Err = binary.Read(r.Reader, binary.BigEndian, x) - } -} - -func (r *BinReader) ReadString(s *string) { - if r.Err != nil { - return - } - var l uint32 - if r.Err = binary.Read(r.Reader, binary.BigEndian, &l); r.Err != nil { - return - } - b := make([]byte, l) - if r.Err = binary.Read(r.Reader, binary.BigEndian, b); r.Err != nil { - return - } - *s = string(b) -} - -type BinWriter struct { - io.Writer - Err error -} - -func (w *BinWriter) Write(buf []byte) (int, error) { - if w.Err != nil { - return 0, w.Err - } - var n int - n, w.Err = w.Writer.Write(buf) - return n, w.Err -} - -func (w *BinWriter) WriteBin(x interface{}) { - if w.Err == nil { - w.Err = binary.Write(w.Writer, binary.BigEndian, x) - } -} - -func (w *BinWriter) WriteString(s string) { - if w.Err == nil { - w.Err = binary.Write(w.Writer, binary.BigEndian, uint32(len(s))) - } - if w.Err == nil { - w.Err = binary.Write(w.Writer, binary.BigEndian, []byte(s)) - } -} - -func BasicAuth(user, password string) string { - auth := user + ":" + password - return base64.StdEncoding.EncodeToString([]byte(auth)) -} diff -r a9440a4826aa -r c1047fd04a3a misc/mail.go --- a/misc/mail.go Wed Aug 15 17:13:28 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,26 +0,0 @@ -package misc - -import ( - gomail "gopkg.in/gomail.v2" - - "gemma.intevation.de/gemma/config" -) - -func SendMail(email, subject, body string) error { - m := gomail.NewMessage() - m.SetHeader("From", config.MailFrom()) - m.SetHeader("To", email) - m.SetHeader("Subject", subject) - m.SetBody("text/plain", body) - - d := gomail.Dialer{ - Host: config.MailHost(), - Port: int(config.MailPort()), - Username: config.MailUser(), - Password: config.MailPassword(), - LocalName: config.MailHelo(), - SSL: config.MailPort() == 465, - } - - return d.DialAndSend(m) -} diff -r a9440a4826aa -r c1047fd04a3a pkg/auth/connection.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/auth/connection.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,99 @@ +package auth + +import ( + "database/sql" + "errors" + "io" + "log" + "sync" + "time" + + "gemma.intevation.de/gemma/pkg/misc" +) + +var ErrNoSuchToken = errors.New("No such token") + +const ( + maxOpen = 16 + maxDBIdle = time.Minute * 5 +) + +type Connection struct { + session *Session + + access time.Time + db *sql.DB + refCount int + + mu sync.Mutex +} + +func (c *Connection) serialize(w io.Writer) error { + if err := c.session.serialize(w); err != nil { + return err + } + access, err := c.last().MarshalText() + if err != nil { + return err + } + wr := misc.BinWriter{w, nil} + wr.WriteBin(uint32(len(access))) + wr.WriteBin(access) + return wr.Err +} + +func (c *Connection) deserialize(r io.Reader) error { + session := new(Session) + if err := session.deserialize(r); err != nil { + return err + } + + rd := misc.BinReader{r, nil} + var l uint32 + rd.ReadBin(&l) + access := make([]byte, l) + rd.ReadBin(access) + + if rd.Err != nil { + return rd.Err + } + + var t time.Time + if err := t.UnmarshalText(access); err != nil { + return err + } + + *c = Connection{ + session: session, + access: t, + } + + return nil +} + +func (c *Connection) set(session *Session) { + c.session = session + c.touch() +} + +func (c *Connection) touch() { + c.mu.Lock() + c.access = time.Now() + c.mu.Unlock() +} + +func (c *Connection) last() time.Time { + c.mu.Lock() + access := c.access + c.mu.Unlock() + return access +} + +func (c *Connection) close() { + if c.db != nil { + if err := c.db.Close(); err != nil { + log.Printf("warn: %v\n", err) + } + c.db = nil + } +} diff -r a9440a4826aa -r c1047fd04a3a pkg/auth/middleware.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/auth/middleware.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,79 @@ +package auth + +import ( + "context" + "net/http" + "strings" +) + +type contextType int + +const ( + sessionKey contextType = iota + tokenKey +) + +func GetSession(req *http.Request) (*Session, bool) { + session, ok := req.Context().Value(sessionKey).(*Session) + return session, ok +} + +func GetToken(req *http.Request) (string, bool) { + token, ok := req.Context().Value(tokenKey).(string) + return token, ok +} + +func SessionMiddleware(next http.Handler) http.Handler { + + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + + auth := req.Header.Get("X-Gemma-Auth") + + token := strings.TrimSpace(auth) + if token == "" { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + session := ConnPool.Session(token) + if session == nil { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + ctx := req.Context() + ctx = context.WithValue(ctx, sessionKey, session) + ctx = context.WithValue(ctx, tokenKey, token) + req = req.WithContext(ctx) + + next.ServeHTTP(rw, req) + }) +} + +func SessionChecker(next http.Handler, check func(*Session) bool) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + claims, ok := GetSession(req) + if !ok || !check(claims) { + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + next.ServeHTTP(rw, req) + }) +} + +func HasRole(roles ...string) func(*Session) bool { + return func(session *Session) bool { + for _, r1 := range roles { + if session.Roles.Has(r1) { + return true + } + } + return false + } +} + +func EnsureRole(roles ...string) func(http.Handler) http.Handler { + return func(handler http.Handler) http.Handler { + return SessionMiddleware(SessionChecker(handler, HasRole(roles...))) + } +} diff -r a9440a4826aa -r c1047fd04a3a pkg/auth/opendb.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/auth/opendb.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,69 @@ +package auth + +import ( + "database/sql" + "fmt" + "strings" + + "gemma.intevation.de/gemma/pkg/config" + + _ "github.com/jackc/pgx/stdlib" +) + +const driver = "pgx" + +// dbQuote quotes strings to be able to contain whitespace +// and backslashes in database DSN strings. +var dbQuote = strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace + +// dbDSN creates a data source name suitable for sql.Open on +// PostgreSQL databases. +func dbDSN(host string, port uint, dbname, user, password string, sslmode string) string { + return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=%s", + dbQuote(host), port, dbQuote(dbname), + dbQuote(user), dbQuote(password), sslmode) +} + +func OpenDB(user, password string) (*sql.DB, error) { + dsn := dbDSN( + config.DBHost(), config.DBPort(), + config.DBName(), + user, password, + config.DBSSLMode()) + return sql.Open(driver, dsn) +} + +const allRoles = ` +WITH RECURSIVE cte AS ( + SELECT oid FROM pg_roles WHERE rolname = current_user + UNION ALL + SELECT m.roleid + FROM cte + JOIN pg_auth_members m ON m.member = cte.oid +) +SELECT rolname FROM pg_roles +WHERE oid IN (SELECT oid FROM cte) AND rolname <> current_user` + +func AllOtherRoles(user, password string) ([]string, error) { + db, err := OpenDB(user, password) + if err != nil { + return nil, err + } + defer db.Close() + rows, err := db.Query(allRoles) + if err != nil { + return nil, err + } + defer rows.Close() + + roles := []string{} // explicit empty by intention. + + for rows.Next() { + var role string + if err := rows.Scan(&role); err != nil { + return nil, err + } + roles = append(roles, role) + } + return roles, rows.Err() +} diff -r a9440a4826aa -r c1047fd04a3a pkg/auth/pool.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/auth/pool.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,322 @@ +package auth + +import ( + "bytes" + "database/sql" + "log" + "time" + + bolt "github.com/coreos/bbolt" +) + +// ConnPool is the global connection pool. +var ConnPool *ConnectionPool + +type ConnectionPool struct { + storage *bolt.DB + conns map[string]*Connection + cmds chan func(*ConnectionPool) +} + +var sessionsBucket = []byte("sessions") + +func NewConnectionPool(filename string) (*ConnectionPool, error) { + + pcp := &ConnectionPool{ + conns: make(map[string]*Connection), + cmds: make(chan func(*ConnectionPool)), + } + if err := pcp.openStorage(filename); err != nil { + return nil, err + } + go pcp.run() + return pcp, nil +} + +// openStorage opens a storage file. +func (pcp *ConnectionPool) openStorage(filename string) error { + + // No file, nothing to restore/persist. + if filename == "" { + return nil + } + + db, err := bolt.Open(filename, 0600, nil) + if err != nil { + return err + } + + err = db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists(sessionsBucket) + if err != nil { + return err + } + + // pre-load sessions + c := b.Cursor() + + for k, v := c.First(); k != nil; k, v = c.Next() { + var conn Connection + if err := conn.deserialize(bytes.NewReader(v)); err != nil { + return err + } + pcp.conns[string(k)] = &conn + } + + return nil + }) + + if err != nil { + db.Close() + return err + } + + pcp.storage = db + return nil +} + +func (pcp *ConnectionPool) run() { + for { + select { + case cmd := <-pcp.cmds: + cmd(pcp) + case <-time.After(time.Minute): + pcp.cleanDB() + case <-time.After(time.Minute * 5): + pcp.cleanToken() + } + } +} + +func (pcp *ConnectionPool) cleanDB() { + valid := time.Now().Add(-maxDBIdle) + for _, con := range pcp.conns { + if con.refCount <= 0 && con.last().Before(valid) { + con.close() + } + } +} + +func (pcp *ConnectionPool) cleanToken() { + now := time.Now() + for token, con := range pcp.conns { + expires := time.Unix(con.session.ExpiresAt, 0) + if expires.Before(now) { + // TODO: Be more graceful here? + con.close() + delete(pcp.conns, token) + pcp.remove(token) + } + } +} + +func (pcp *ConnectionPool) remove(token string) { + if pcp.storage == nil { + return + } + err := pcp.storage.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(sessionsBucket) + return b.Delete([]byte(token)) + }) + if err != nil { + log.Printf("error: %v\n", err) + } +} + +func (pcp *ConnectionPool) Delete(token string) bool { + res := make(chan bool) + pcp.cmds <- func(pcp *ConnectionPool) { + conn, found := pcp.conns[token] + if !found { + res <- false + return + } + conn.close() + delete(pcp.conns, token) + pcp.remove(token) + res <- true + } + return <-res +} + +func (pcp *ConnectionPool) store(token string, con *Connection) { + if pcp.storage == nil { + return + } + err := pcp.storage.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(sessionsBucket) + var buf bytes.Buffer + if err := con.serialize(&buf); err != nil { + return err + } + return b.Put([]byte(token), buf.Bytes()) + }) + if err != nil { + log.Printf("error: %v\n", err) + } +} + +func (pcp *ConnectionPool) Add(token string, session *Session) *Connection { + res := make(chan *Connection) + + pcp.cmds <- func(cp *ConnectionPool) { + con := pcp.conns[token] + if con == nil { + con = &Connection{} + pcp.conns[token] = con + } + con.set(session) + pcp.store(token, con) + res <- con + } + + con := <-res + return con +} + +func (pcp *ConnectionPool) Renew(token string) (string, error) { + + type result struct { + newToken string + err error + } + + resCh := make(chan result) + + pcp.cmds <- func(cp *ConnectionPool) { + con := pcp.conns[token] + if con == nil { + resCh <- result{err: ErrNoSuchToken} + } else { + delete(pcp.conns, token) + pcp.remove(token) + newToken := GenerateSessionKey() + // TODO: Ensure that this is not racy! + con.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix() + pcp.conns[newToken] = con + pcp.store(newToken, con) + resCh <- result{newToken: newToken} + } + } + + r := <-resCh + return r.newToken, r.err +} + +func (pcp *ConnectionPool) trim(conn *Connection) { + + conn.refCount-- + + for { + least := time.Now() + var count int + var oldest *Connection + + for _, con := range pcp.conns { + if con.db != nil && con.refCount <= 0 { + if last := con.last(); last.Before(least) { + least = last + oldest = con + } + count++ + } + } + if count <= maxOpen { + break + } + oldest.close() + } +} + +func (pcp *ConnectionPool) Do(token string, fn func(*sql.DB) error) error { + + type result struct { + con *Connection + err error + } + + res := make(chan result) + + pcp.cmds <- func(pcp *ConnectionPool) { + con := pcp.conns[token] + if con == nil { + res <- result{err: ErrNoSuchToken} + return + } + con.touch() + // store the session here. The ref counting for + // open db connections is irrelevant for persistence + // as they all come up closed when the system reboots. + pcp.store(token, con) + + if con.db != nil { + con.refCount++ + res <- result{con: con} + return + } + + session := con.session + db, err := OpenDB(session.User, session.Password) + if err != nil { + res <- result{err: err} + return + } + con.db = db + con.refCount++ + res <- result{con: con} + } + + r := <-res + + if r.err != nil { + return r.err + } + + defer func() { + pcp.cmds <- func(pcp *ConnectionPool) { + pcp.trim(r.con) + } + }() + + return fn(r.con.db) +} + +func (pcp *ConnectionPool) Session(token string) *Session { + res := make(chan *Session) + pcp.cmds <- func(pcp *ConnectionPool) { + con := pcp.conns[token] + if con == nil { + res <- nil + } else { + con.touch() + pcp.store(token, con) + res <- con.session + } + } + return <-res +} + +func (pcp *ConnectionPool) Logout(user string) { + pcp.cmds <- func(pcp *ConnectionPool) { + for token, con := range pcp.conns { + if con.session.User == user { + if db := con.db; db != nil { + con.db = nil + db.Close() + } + delete(pcp.conns, token) + pcp.remove(token) + } + } + } +} + +func (pcp *ConnectionPool) Shutdown() error { + if db := pcp.storage; db != nil { + log.Println("info: shutdown persistent connection pool.") + pcp.storage = nil + return db.Close() + } + log.Println("info: shutdown in-memory connection pool.") + return nil +} diff -r a9440a4826aa -r c1047fd04a3a pkg/auth/session.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/auth/session.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,90 @@ +package auth + +import ( + "encoding/base64" + "io" + "time" + + "gemma.intevation.de/gemma/pkg/common" + "gemma.intevation.de/gemma/pkg/misc" +) + +type Roles []string + +type Session struct { + ExpiresAt int64 `json:"expires"` + User string `json:"user"` + Password string `json:"password"` + Roles Roles `json:"roles"` +} + +func (r Roles) Has(role string) bool { + for _, x := range r { + if x == role { + return true + } + } + return false +} + +const ( + sessionKeyLength = 20 + maxTokenValid = time.Hour * 3 +) + +func NewSession(user, password string, roles []string) *Session { + + // Create the Claims + return &Session{ + ExpiresAt: time.Now().Add(maxTokenValid).Unix(), + User: user, + Password: password, + Roles: roles, + } +} + +func (s *Session) serialize(w io.Writer) error { + wr := misc.BinWriter{w, nil} + wr.WriteBin(s.ExpiresAt) + wr.WriteString(s.User) + wr.WriteString(s.Password) + wr.WriteBin(uint32(len(s.Roles))) + for _, role := range s.Roles { + wr.WriteString(role) + } + return wr.Err +} + +func (s *Session) deserialize(r io.Reader) error { + var x Session + var n uint32 + rd := misc.BinReader{r, nil} + rd.ReadBin(&x.ExpiresAt) + rd.ReadString(&x.User) + rd.ReadString(&x.Password) + rd.ReadBin(&n) + x.Roles = make(Roles, n) + for i := uint32(0); n > 0 && i < n; i++ { + rd.ReadString(&x.Roles[i]) + } + if rd.Err == nil { + *s = x + } + return rd.Err +} + +func GenerateSessionKey() string { + return base64.URLEncoding.EncodeToString( + common.GenerateRandomKey(sessionKeyLength)) +} + +func GenerateSession(user, password string) (string, *Session, error) { + roles, err := AllOtherRoles(user, password) + if err != nil { + return "", nil, err + } + token := GenerateSessionKey() + session := NewSession(user, password, roles) + ConnPool.Add(token, session) + return token, session, nil +} diff -r a9440a4826aa -r c1047fd04a3a pkg/common/random.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/common/random.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,48 @@ +package common + +import ( + "bytes" + "crypto/rand" + "io" + "log" + "math/big" +) + +func GenerateRandomKey(length int) []byte { + k := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return nil + } + return k +} + +func RandomString(n int) string { + + const ( + special = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" + alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + "abcdefghijklmnopqrstuvwxyz" + + "0123456789" + + special + ) + + max := big.NewInt(int64(len(alphabet))) + out := make([]byte, n) + + for i := 0; i < 1000; i++ { + for i := range out { + v, err := rand.Int(rand.Reader, max) + if err != nil { + log.Panicf("error: %v\n", err) + } + out[i] = alphabet[v.Int64()] + } + // Ensure at least one special char. + if bytes.IndexAny(out, special) >= 0 { + return string(out) + } + } + log.Println("warn: Your random generator may be broken.") + out[0] = special[0] + return string(out) +} diff -r a9440a4826aa -r c1047fd04a3a pkg/config/config.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/config/config.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,186 @@ +package config + +import ( + "crypto/sha256" + "fmt" + "log" + "sync" + + homedir "github.com/mitchellh/go-homedir" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "gemma.intevation.de/gemma/pkg/common" +) + +// This is not part of the persistent config. +var configFile string + +func ConfigFile() string { return configFile } + +func DBHost() string { return viper.GetString("dbhost") } +func DBPort() uint { return uint(viper.GetInt32("dbport")) } +func DBName() string { return viper.GetString("dbname") } +func DBSSLMode() string { return viper.GetString("dbssl") } +func SessionStore() string { return viper.GetString("sessions") } +func Web() string { return viper.GetString("web") } +func WebHost() string { return viper.GetString("host") } +func WebPort() uint { return uint(viper.GetInt32("port")) } + +func ServiceUser() string { return viper.GetString("service-user") } +func ServicePassword() string { return viper.GetString("service-password") } + +func SysAdmin() string { return viper.GetString("sys-admin") } +func SysAdminPassword() string { return viper.GetString("sys-admin-password") } + +func MailHost() string { return viper.GetString("mail-host") } +func MailPort() uint { return uint(viper.GetInt32("mail-port")) } +func MailUser() string { return viper.GetString("mail-user") } +func MailPassword() string { return viper.GetString("mail-password") } +func MailFrom() string { return viper.GetString("mail-from") } +func MailHelo() string { return viper.GetString("mail-helo") } + +func AllowedOrigins() []string { return viper.GetStringSlice("allowed-origins") } + +func ExternalWFSs() map[string]interface{} { return viper.GetStringMap("external-wfs") } + +func GeoServerURL() string { return viper.GetString("geoserver-url") } +func GeoServerUser() string { return viper.GetString("geoserver-user") } +func GeoServerPassword() string { return viper.GetString("geoserver-password") } +func GeoServerTables() []string { return viper.GetStringSlice("geoserver-tables") } + +var ( + proxyKeyOnce sync.Once + proxyKey []byte + + proxyPrefixOnce sync.Once + proxyPrefix string +) + +func ProxyKey() []byte { + fetchKey := func() { + if proxyKey == nil { + key := []byte(viper.GetString("proxy-key")) + if len(key) == 0 { + key = common.GenerateRandomKey(64) + } + hash := sha256.New() + hash.Write(key) + proxyKey = hash.Sum(nil) + } + } + proxyKeyOnce.Do(fetchKey) + return proxyKey +} + +func ProxyPrefix() string { + fetchPrefix := func() { + if proxyPrefix == "" { + proxyPrefix = fmt.Sprintf("http://%s:%d", WebHost(), WebPort()) + } + } + proxyPrefixOnce.Do(fetchPrefix) + return proxyPrefix +} + +var RootCmd = &cobra.Command{ + Use: "gemma", + Short: "gemma is a server for waterway monitoring and management", +} + +var allowedOrigins = []string{ + // TODO: Fill me! +} + +var geoTables = []string{ + "fairway_dimensions", +} + +func init() { + cobra.OnInitialize(initConfig) + fl := RootCmd.PersistentFlags() + fl.StringVarP(&configFile, "config", "c", "", "config file (default is $HOME/.gemma.toml)") + + vbind := func(name string) { viper.BindPFlag(name, fl.Lookup(name)) } + + str := func(name, value, usage string) { + fl.String(name, value, usage) + vbind(name) + } + strP := func(name, shorthand, value, usage string) { + fl.StringP(name, shorthand, value, usage) + vbind(name) + } + ui := func(name string, value uint, usage string) { + fl.Uint(name, value, usage) + vbind(name) + } + uiP := func(name, shorthand string, value uint, usage string) { + fl.UintP(name, shorthand, value, usage) + vbind(name) + } + strSl := func(name string, value []string, usage string) { + fl.StringSlice(name, value, usage) + vbind(name) + } + + strP("dbhost", "H", "localhost", "host of the database") + uiP("dbport", "P", 5432, "port of the database") + strP("dbname", "d", "gemma", "name of the database") + strP("dbssl", "S", "prefer", "SSL mode of the database") + + strP("sessions", "s", "", "path to the sessions file") + + strP("web", "w", "./web", "path to the web files") + strP("host", "o", "localhost", "host of the web app") + uiP("port", "p", 8000, "port of the web app") + + str("service-user", "postgres", "user to do service tasks") + str("service-password", "", "password of user to do service tasks") + + str("sys-admin", "postgres", "user to do admin tasks") + str("sys-admin-password", "", "password of user to do admin tasks") + + str("mail-host", "localhost", "server to send mail with") + ui("mail-port", 465, "port of server to send mail with") + str("mail-user", "gemma", "user to send mail with") + str("mail-password", "", "password of user to send mail with") + str("mail-from", "noreplay@localhost", "from line of mails") + str("mail-helo", "localhost", "name of server to send mail from.") + + strSl("allowed-origins", allowedOrigins, "allow access for remote origins") + + str("geoserver-url", "http://localhost:8080/geoserver", "URL to GeoServer") + str("geoserver-user", "admin", "GeoServer user") + str("geoserver-password", "geoserver", "GeoServer password") + strSl("geoserver-tables", geoTables, "tables to publish with GeoServer") + + str("proxy-key", "", `signing key for proxy URLs. Defaults to random key.`) + str("proxy-prefix", "", `URL prefix of proxy. Defaults to "http://${web-host}:${web-port}"`) + +} + +func initConfig() { + // Don't forget to read config either from cfgFile or from home directory! + if configFile != "" { + // Use config file from the flag. + viper.SetConfigFile(configFile) + } else { + // Find home directory. + home, err := homedir.Dir() + if err != nil { + log.Fatalf("error: %v\n", err) + } + + // Search config in home directory with name ".cobra" (without extension). + viper.AddConfigPath(home) + viper.SetConfigName(".gemma") + } + if err := viper.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok && configFile == "" { + // Don't bother if not found. + return + } + log.Fatalf("Can't read config: %v\n", err) + } +} diff -r a9440a4826aa -r c1047fd04a3a pkg/controllers/json.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/controllers/json.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,111 @@ +package controllers + +import ( + "database/sql" + "encoding/json" + "fmt" + "log" + "net/http" + + "github.com/jackc/pgx" + + "gemma.intevation.de/gemma/pkg/auth" +) + +type JSONResult struct { + Code int + Result interface{} +} + +type JSONHandler struct { + Input func() interface{} + Handle func(interface{}, *http.Request, *sql.DB) (JSONResult, error) +} + +type JSONError struct { + Code int + Message string +} + +func (je JSONError) Error() string { + return fmt.Sprintf("%d: %s", je.Code, je.Message) +} + +func (j *JSONHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + + var input interface{} + if j.Input != nil { + input = j.Input() + defer req.Body.Close() + if err := json.NewDecoder(req.Body).Decode(input); err != nil { + http.Error(rw, "error: "+err.Error(), http.StatusBadRequest) + return + } + } + + var jr JSONResult + var err error + + if token, ok := auth.GetToken(req); ok { + err = auth.ConnPool.Do(token, func(db *sql.DB) (err error) { + jr, err = j.Handle(input, req, db) + return err + }) + } else { + jr, err = j.Handle(input, req, nil) + } + + if err != nil { + switch e := err.(type) { + case pgx.PgError: + var res = struct { + Result string `json:"result"` + Code string `json:"code"` + Message string `json:"message"` + }{ + Result: "failure", + Code: e.Code, + Message: e.Message, + } + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusInternalServerError) + if err := json.NewEncoder(rw).Encode(&res); err != nil { + log.Printf("error: %v\n", err) + } + case JSONError: + rw.Header().Set("Content-Type", "application/json") + if e.Code == 0 { + e.Code = http.StatusInternalServerError + } + rw.WriteHeader(e.Code) + var res = struct { + Message string `json:"message"` + }{ + Message: e.Message, + } + if err := json.NewEncoder(rw).Encode(&res); err != nil { + log.Printf("error: %v\n", err) + } + default: + log.Printf("err: %v\n", err) + http.Error(rw, + "error: "+err.Error(), + http.StatusInternalServerError) + } + return + } + + if jr.Code == 0 { + jr.Code = http.StatusOK + } + + if jr.Code != http.StatusNoContent { + rw.Header().Set("Content-Type", "application/json") + } + rw.WriteHeader(jr.Code) + if jr.Code != http.StatusNoContent { + if err := json.NewEncoder(rw).Encode(jr.Result); err != nil { + log.Printf("error: %v\n", err) + } + } +} diff -r a9440a4826aa -r c1047fd04a3a pkg/controllers/proxy.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/controllers/proxy.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,381 @@ +package controllers + +import ( + "compress/flate" + "compress/gzip" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/xml" + "io" + "io/ioutil" + "log" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/gorilla/mux" + "golang.org/x/net/html/charset" + + "gemma.intevation.de/gemma/pkg/config" +) + +// proxyBlackList is a set of URLs that should not be rewritten by the proxy. +var proxyBlackList = map[string]struct{}{ + "http://www.w3.org/2001/XMLSchema-instance": struct{}{}, + "http://www.w3.org/1999/xlink": struct{}{}, + "http://www.w3.org/2001/XMLSchema": struct{}{}, + "http://www.w3.org/XML/1998/namespace": struct{}{}, + "http://www.opengis.net/wfs/2.0": struct{}{}, + "http://www.opengis.net/ows/1.1": struct{}{}, + "http://www.opengis.net/gml/3.2": struct{}{}, + "http://www.opengis.net/fes/2.0": struct{}{}, + "http://schemas.opengis.net/gml": struct{}{}, +} + +func findEntry(entry string) (string, bool) { + external := config.ExternalWFSs() + if external == nil || len(external) == 0 { + return "", false + } + alias, found := external[entry] + if !found { + return "", false + } + data, ok := alias.(map[string]interface{}) + if !ok { + return "", false + } + urlS, found := data["url"] + if !found { + return "", false + } + url, ok := urlS.(string) + return url, ok +} + +func proxyDirector(req *http.Request) { + + log.Printf("proxyDirector: %s\n", req.RequestURI) + + abort := func(format string, args ...interface{}) { + log.Printf(format, args...) + panic(http.ErrAbortHandler) + } + + vars := mux.Vars(req) + + var s string + + if entry, found := vars["entry"]; found { + if s, found = findEntry(entry); !found { + abort("Cannot find entry '%s'\n", entry) + } + } else { + expectedMAC, err := base64.URLEncoding.DecodeString(vars["hash"]) + if err != nil { + abort("Cannot base64 decode hash: %v\n", err) + } + url, err := base64.URLEncoding.DecodeString(vars["url"]) + if err != nil { + abort("Cannot base64 decode url: %v\n", err) + } + + mac := hmac.New(sha256.New, config.ProxyKey()) + mac.Write(url) + messageMAC := mac.Sum(nil) + + s = string(url) + + if !hmac.Equal(messageMAC, expectedMAC) { + abort("HMAC of URL %s failed.\n", s) + } + } + + nURL := s + "?" + req.URL.RawQuery + //log.Printf("%v\n", nURL) + + u, err := url.Parse(nURL) + if err != nil { + abort("Invalid url: %v\n", err) + } + req.URL = u + + req.Host = u.Host + //req.Header.Del("If-None-Match") + //log.Printf("headers: %v\n", req.Header) +} + +type nopCloser struct { + io.Writer +} + +func (nopCloser) Close() error { return nil } + +func encoding(h http.Header) ( + func(io.Reader) (io.ReadCloser, error), + func(io.Writer) (io.WriteCloser, error), +) { + switch enc := h.Get("Content-Encoding"); { + case strings.Contains(enc, "gzip"): + log.Println("gzip compression") + return func(r io.Reader) (io.ReadCloser, error) { + return gzip.NewReader(r) + }, + func(w io.Writer) (io.WriteCloser, error) { + return gzip.NewWriter(w), nil + } + case strings.Contains(enc, "deflate"): + log.Println("Deflate compression") + return func(r io.Reader) (io.ReadCloser, error) { + return flate.NewReader(r), nil + }, + func(w io.Writer) (io.WriteCloser, error) { + return flate.NewWriter(w, flate.DefaultCompression) + } + default: + log.Println("No content compression") + return func(r io.Reader) (io.ReadCloser, error) { + if r2, ok := r.(io.ReadCloser); ok { + return r2, nil + } + return ioutil.NopCloser(r), nil + }, + func(w io.Writer) (io.WriteCloser, error) { + if w2, ok := w.(io.WriteCloser); ok { + return w2, nil + } + return nopCloser{w}, nil + } + } +} + +func proxyModifyResponse(resp *http.Response) error { + + if !isXML(resp.Header) { + return nil + } + + pr, pw := io.Pipe() + + var ( + r io.ReadCloser + w io.WriteCloser + err error + ) + + reader, writer := encoding(resp.Header) + + if r, err = reader(resp.Body); err != nil { + return err + } + + if w, err = writer(pw); err != nil { + return err + } + + go func(force io.ReadCloser) { + start := time.Now() + defer func() { + //r.Close() + w.Close() + pw.Close() + force.Close() + log.Printf("rewrite took %s\n", time.Since(start)) + }() + if err := rewrite(w, r); err != nil { + log.Printf("rewrite failed: %v\n", err) + return + } + log.Println("rewrite successful") + }(resp.Body) + + resp.Body = pr + + return nil +} + +var xmlContentTypes = []string{ + "application/xml", + "text/xml", + "application/gml+xml", +} + +func isXML(h http.Header) bool { + for _, t := range h["Content-Type"] { + t = strings.ToLower(t) + for _, ct := range xmlContentTypes { + if strings.Contains(t, ct) { + return true + } + } + } + return false +} + +var replaceRe = regexp.MustCompile(`\b(https?://[^\s\?]*)`) + +func replace(s string) string { + + proxyKey := config.ProxyKey() + proxyPrefix := config.ProxyPrefix() + "/api/proxy/" + + return replaceRe.ReplaceAllStringFunc(s, func(s string) string { + if _, found := proxyBlackList[s]; found { + return s + } + mac := hmac.New(sha256.New, proxyKey) + b := []byte(s) + mac.Write(b) + expectedMAC := mac.Sum(nil) + + hash := base64.URLEncoding.EncodeToString(expectedMAC) + enc := base64.URLEncoding.EncodeToString(b) + return proxyPrefix + hash + "/" + enc + }) +} + +func rewrite(w io.Writer, r io.Reader) error { + + decoder := xml.NewDecoder(r) + decoder.CharsetReader = charset.NewReaderLabel + + encoder := xml.NewEncoder(w) + + var n nsdef + +tokens: + for { + tok, err := decoder.Token() + switch { + case tok == nil && err == io.EOF: + break tokens + case err != nil: + return err + } + + switch t := tok.(type) { + case xml.StartElement: + t = t.Copy() + + isDef := n.isDef(t.Name.Space) + n = n.push() + + for i := range t.Attr { + t.Attr[i].Value = replace(t.Attr[i].Value) + n.checkDef(&t.Attr[i]) + } + + for i := range t.Attr { + n.adjust(&t.Attr[i]) + } + + switch { + case isDef: + t.Name.Space = "" + default: + if s := n.lookup(t.Name.Space); s != "" { + t.Name.Space = "" + t.Name.Local = s + ":" + t.Name.Local + } + } + tok = t + + case xml.CharData: + tok = xml.CharData(replace(string(t))) + + case xml.EndElement: + s := n.lookup(t.Name.Space) + + n = n.pop() + + if n.isDef(t.Name.Space) { + t.Name.Space = "" + } else if s != "" { + t.Name.Space = "" + t.Name.Local = s + ":" + t.Name.Local + } + tok = t + } + + if err := encoder.EncodeToken(tok); err != nil { + return err + } + } + + return encoder.Flush() +} + +type nsframe struct { + def string + ns map[string]string +} + +type nsdef []nsframe + +func (n nsdef) setDef(def string) { + if l := len(n); l > 0 { + n[l-1].def = def + } +} + +func (n nsdef) isDef(s string) bool { + for i := len(n) - 1; i >= 0; i-- { + if x := n[i].def; x != "" { + return s == x + } + } + return false +} + +func (n nsdef) define(ns, s string) { + if l := len(n); l > 0 { + n[l-1].ns[ns] = s + } +} + +func (n nsdef) lookup(ns string) string { + for i := len(n) - 1; i >= 0; i-- { + if s := n[i].ns[ns]; s != "" { + return s + } + } + return "" +} + +func (n nsdef) checkDef(at *xml.Attr) { + if at.Name.Space == "" && at.Name.Local == "xmlns" { + n.setDef(at.Value) + } +} + +func (n nsdef) adjust(at *xml.Attr) { + switch { + case at.Name.Space == "xmlns": + n.define(at.Value, at.Name.Local) + at.Name.Local = "xmlns:" + at.Name.Local + at.Name.Space = "" + + case at.Name.Space != "": + if n.isDef(at.Name.Space) { + at.Name.Space = "" + } else if s := n.lookup(at.Name.Space); s != "" { + at.Name.Local = s + ":" + at.Name.Local + at.Name.Space = "" + } + } +} + +func (n nsdef) push() nsdef { + return append(n, nsframe{ns: make(map[string]string)}) +} + +func (n nsdef) pop() nsdef { + if l := len(n); l > 0 { + n[l-1] = nsframe{} + n = n[:l-1] + } + return n +} diff -r a9440a4826aa -r c1047fd04a3a pkg/controllers/pwreset.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/controllers/pwreset.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,279 @@ +package controllers + +import ( + "bytes" + "database/sql" + "encoding/hex" + "log" + "net/http" + "os/exec" + "strings" + "text/template" + "time" + + "github.com/gorilla/mux" + + "gemma.intevation.de/gemma/pkg/auth" + "gemma.intevation.de/gemma/pkg/common" + "gemma.intevation.de/gemma/pkg/config" + "gemma.intevation.de/gemma/pkg/misc" +) + +const ( + insertRequestSQL = `INSERT INTO pw_reset.password_reset_requests + (hash, username) VALUES ($1, $2)` + + countRequestsSQL = `SELECT count(*) FROM pw_reset.password_reset_requests` + + countRequestsUserSQL = `SELECT count(*) FROM pw_reset.password_reset_requests + WHERE username = $1` + + deleteRequestSQL = `DELETE FROM pw_reset.password_reset_requests + WHERE hash = $1` + + findRequestSQL = `SELECT lu.email_address, lu.username + FROM pw_reset.password_reset_requests prr + JOIN pw_reset.list_users lu on prr.username = lu.username + WHERE prr.hash = $1` + + cleanupRequestsSQL = `DELETE FROM pw_reset.password_reset_requests + WHERE issued < $1` + + userExistsSQL = `SELECT email_address + FROM pw_reset.list_users WHERE username = $1` + + updatePasswordSQL = `UPDATE pw_reset.list_users + SET pw = $1 WHERE username = $2` +) + +const ( + hashLength = 16 + passwordLength = 20 + passwordResetValid = 12 * time.Hour + maxPasswordResets = 1000 + maxPasswordRequestsPerUser = 5 + cleanupPause = 15 * time.Minute +) + +var ( + passwordResetRequestMailTmpl = template.Must( + template.New("request").Parse(`You have requested a password change +for your account {{ .User }} on +{{ .HTTPS }}://{{ .Server }} + +Please follow this link to get to the page where you can change your password. + +{{ .HTTPS }}://{{ .Server }}/api/users/passwordreset/{{ .Hash }} + +The link is only valid for 12 hours. + +Best regards + Your service team`)) + + passwordResetMailTmpl = template.Must( + template.New("reset").Parse(`Your password for your account {{ .User }} on +{{ .HTTPS }}://{{ .Server }} + +has been changed to + {{ .Password }} + +Change it as soon as possible. + +Best regards + Your service team`)) +) + +func asServiceUser(fn func(*sql.DB) error) error { + db, err := auth.OpenDB(config.ServiceUser(), config.ServicePassword()) + if err == nil { + defer db.Close() + err = fn(db) + } + return err +} + +func init() { + go removeOutdated() +} + +func removeOutdated() { + for { + time.Sleep(cleanupPause) + err := asServiceUser(func(db *sql.DB) error { + good := time.Now().Add(-passwordResetValid) + _, err := db.Exec(cleanupRequestsSQL, good) + return err + }) + if err != nil { + log.Printf("error: %v\n", err) + } + } +} + +func requestMessageBody(https, user, hash, server string) string { + var content = struct { + User string + HTTPS string + Server string + Hash string + }{ + User: user, + HTTPS: https, + Server: server, + Hash: hash, + } + var buf bytes.Buffer + if err := passwordResetRequestMailTmpl.Execute(&buf, &content); err != nil { + log.Printf("error: %v\n", err) + } + return buf.String() +} + +func changedMessageBody(https, user, password, server string) string { + var content = struct { + User string + HTTPS string + Server string + Password string + }{ + User: user, + HTTPS: https, + Server: server, + Password: password, + } + var buf bytes.Buffer + if err := passwordResetMailTmpl.Execute(&buf, &content); err != nil { + log.Printf("error: %v\n", err) + } + return buf.String() +} + +func useHTTPS(req *http.Request) string { + if strings.ToLower(req.URL.Scheme) == "https" { + return "https" + } + return "http" +} + +func generateHash() string { + return hex.EncodeToString(common.GenerateRandomKey(hashLength)) +} + +func generateNewPassword() string { + // First try pwgen + out, err := exec.Command("pwgen", "-y", "20", "1").Output() + if err == nil { + return strings.TrimSpace(string(out)) + } + // Use internal generator. + return common.RandomString(20) +} + +func passwordResetRequest( + input interface{}, + req *http.Request, + _ *sql.DB, +) (jr JSONResult, err error) { + + user := input.(*PWResetUser) + + if user.User == "" { + err = JSONError{http.StatusBadRequest, "Invalid user name"} + return + } + + var hash, email string + + if err = asServiceUser(func(db *sql.DB) error { + + var count int64 + if err := db.QueryRow(countRequestsSQL).Scan(&count); err != nil { + return err + } + + // Limit total number of password requests. + if count >= maxPasswordResets { + return JSONError{ + Code: http.StatusServiceUnavailable, + Message: "Too much password reset request", + } + } + + err := db.QueryRow(userExistsSQL, user.User).Scan(&email) + + switch { + case err == sql.ErrNoRows: + return JSONError{http.StatusNotFound, "User does not exist."} + case err != nil: + return err + } + + if err := db.QueryRow(countRequestsUserSQL, user.User).Scan(&count); err != nil { + return err + } + + // Limit requests per user + if count >= maxPasswordRequestsPerUser { + return JSONError{ + Code: http.StatusServiceUnavailable, + Message: "Too much password reset requests for user", + } + } + + hash = generateHash() + _, err = db.Exec(insertRequestSQL, hash, user.User) + return err + }); err == nil { + body := requestMessageBody(useHTTPS(req), user.User, hash, req.Host) + + if err = misc.SendMail(email, "Password Reset Link", body); err == nil { + jr.Result = &struct { + SendTo string `json:"send-to"` + }{email} + } + } + return +} + +func passwordReset( + _ interface{}, + req *http.Request, + _ *sql.DB, +) (jr JSONResult, err error) { + + hash := mux.Vars(req)["hash"] + if _, err = hex.DecodeString(hash); err != nil { + err = JSONError{http.StatusBadRequest, "Invalid hash"} + return + } + + var email, user, password string + + if err = asServiceUser(func(db *sql.DB) error { + err := db.QueryRow(findRequestSQL, hash).Scan(&email, &user) + switch { + case err == sql.ErrNoRows: + return JSONError{http.StatusNotFound, "No such hash"} + case err != nil: + return err + } + password = generateNewPassword() + res, err := db.Exec(updatePasswordSQL, password, user) + if err != nil { + return err + } + if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { + return JSONError{http.StatusNotFound, "User not found"} + } + _, err = db.Exec(deleteRequestSQL, hash) + return err + }); err == nil { + body := changedMessageBody(useHTTPS(req), user, password, req.Host) + if err = misc.SendMail(email, "Password Reset Done", body); err == nil { + jr.Result = &struct { + SendTo string `json:"send-to"` + }{email} + } + } + return +} diff -r a9440a4826aa -r c1047fd04a3a pkg/controllers/routes.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/controllers/routes.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,77 @@ +package controllers + +import ( + "net/http" + "net/http/httputil" + + "github.com/gorilla/mux" + + "gemma.intevation.de/gemma/pkg/auth" +) + +func BindRoutes(m *mux.Router) { + + api := m.PathPrefix("/api").Subrouter() + + var ( + sysAdmin = auth.EnsureRole("sys_admin") + all = auth.EnsureRole("sys_admin", "waterway_admin", "waterway_user") + ) + + // User management. + api.Handle("/users", all(&JSONHandler{ + Handle: listUsers, + })).Methods(http.MethodGet) + + api.Handle("/users", sysAdmin(&JSONHandler{ + Input: func() interface{} { return new(User) }, + Handle: createUser, + })).Methods(http.MethodPost) + + api.Handle("/users/{user}", all(&JSONHandler{ + Handle: listUser, + })).Methods(http.MethodGet) + + api.Handle("/users/{user}", all(&JSONHandler{ + Input: func() interface{} { return new(User) }, + Handle: updateUser, + })).Methods(http.MethodPut) + + api.Handle("/users/{user}", sysAdmin(&JSONHandler{ + Handle: deleteUser, + })).Methods(http.MethodDelete) + + // Password resets. + api.Handle("/users/passwordreset", &JSONHandler{ + Input: func() interface{} { return new(PWResetUser) }, + Handle: passwordResetRequest, + }).Methods(http.MethodPost) + + api.Handle("/users/passwordreset/{hash}", &JSONHandler{ + Handle: passwordReset, + }).Methods(http.MethodGet) + + // Proxy for external WFSs. + proxy := &httputil.ReverseProxy{ + Director: proxyDirector, + ModifyResponse: proxyModifyResponse, + } + + api.Handle(`/proxy/{hash}/{url}`, proxy). + Methods( + http.MethodGet, http.MethodPost, + http.MethodPut, http.MethodDelete) + + api.Handle("/proxy/{entry}", proxy). + Methods( + http.MethodGet, http.MethodPost, + http.MethodPut, http.MethodDelete) + + // Token handling: Login/Logout. + api.HandleFunc("/login", login). + Methods(http.MethodGet, http.MethodPost) + api.Handle("/logout", auth.SessionMiddleware(http.HandlerFunc(logout))). + Methods(http.MethodGet, http.MethodPost) + api.Handle("/renew", auth.SessionMiddleware(http.HandlerFunc(renew))). + Methods(http.MethodGet, http.MethodPost) +} diff -r a9440a4826aa -r c1047fd04a3a pkg/controllers/token.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/controllers/token.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,90 @@ +package controllers + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + + "gemma.intevation.de/gemma/pkg/auth" +) + +func sendJSON(rw http.ResponseWriter, data interface{}) { + rw.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(rw).Encode(data); err != nil { + log.Printf("error: %v\n", err) + } +} + +func renew(rw http.ResponseWriter, req *http.Request) { + token, _ := auth.GetToken(req) + newToken, err := auth.ConnPool.Renew(token) + switch { + case err == auth.ErrNoSuchToken: + http.NotFound(rw, req) + return + case err != nil: + http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusInternalServerError) + return + } + + session, _ := auth.GetSession(req) + + var result = struct { + Token string `json:"token"` + Expires int64 `json:"expires"` + User string `json:"user"` + Roles []string `json:"roles"` + }{ + Token: newToken, + Expires: session.ExpiresAt, + User: session.User, + Roles: session.Roles, + } + + sendJSON(rw, &result) +} + +func logout(rw http.ResponseWriter, req *http.Request) { + token, _ := auth.GetToken(req) + deleted := auth.ConnPool.Delete(token) + if !deleted { + http.NotFound(rw, req) + return + } + rw.Header().Set("Content-Type", "text/plain") + fmt.Fprintln(rw, "token deleted") +} + +func login(rw http.ResponseWriter, req *http.Request) { + + var ( + user = req.FormValue("user") + password = req.FormValue("password") + ) + + if user == "" || password == "" { + http.Error(rw, "Invalid credentials", http.StatusBadRequest) + return + } + + token, session, err := auth.GenerateSession(user, password) + if err != nil { + http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusUnauthorized) + return + } + + var result = struct { + Token string `json:"token"` + Expires int64 `json:"expires"` + User string `json:"user"` + Roles []string `json:"roles"` + }{ + Token: token, + Expires: session.ExpiresAt, + User: session.User, + Roles: session.Roles, + } + + sendJSON(rw, &result) +} diff -r a9440a4826aa -r c1047fd04a3a pkg/controllers/types.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/controllers/types.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,150 @@ +package controllers + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "regexp" + "strings" +) + +type ( + Email string + Country string + Role string + + BoundingBox struct { + X1 float64 `json:"x1"` + Y1 float64 `json:"y1"` + X2 float64 `json:"x2"` + Y2 float64 `json:"y2"` + } + + User struct { + User string `json:"user"` + Role Role `json:"role"` + Password string `json:"password,omitempty"` + Email Email `json:"email"` + Country Country `json:"country"` + Extent *BoundingBox `json:"extent"` + } + + PWResetUser struct { + User string `json:"user"` + } +) + +var ( + // https://stackoverflow.com/questions/201323/how-to-validate-an-email-address-using-a-regular-expression + emailRe = regexp.MustCompile( + `(?:[a-z0-9!#$%&'*+/=?^_` + "`" + + `{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_` + "`" + + `{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]` + + `|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")` + + `@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?` + + `|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}` + + `(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]` + + `:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]` + + `|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])`) + + errNoEmailAddress = errors.New("Not a valid email address") + errNoString = errors.New("Not a string") +) + +func (e *Email) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + if !emailRe.MatchString(s) { + return errNoEmailAddress + } + *e = Email(s) + return nil +} + +func (e Email) Value() (driver.Value, error) { + return string(e), nil +} + +func (e *Email) Scan(src interface{}) (err error) { + if s, ok := src.(string); ok { + *e = Email(s) + } else { + err = errNoString + } + return +} + +var ( + validCountries = []string{ + "AT", "BG", "DE", "HU", "HR", + "MD", "RO", "RS", "SK", "UA", + } + errNoValidCountry = errors.New("Not a valid country") +) + +func (c *Country) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + s = strings.ToUpper(s) + for _, v := range validCountries { + if v == s { + *c = Country(v) + return nil + } + } + return errNoValidCountry +} + +func (c Country) Value() (driver.Value, error) { + return string(c), nil +} + +func (c *Country) Scan(src interface{}) (err error) { + if s, ok := src.(string); ok { + *c = Country(s) + } else { + err = errNoString + } + return +} + +var ( + validRoles = []string{ + "waterway_user", + "waterway_admin", + "sys_admin", + } + errNoValidRole = errors.New("Not a valid role") +) + +func (r Role) Value() (driver.Value, error) { + return string(r), nil +} + +func (r *Role) Scan(src interface{}) (err error) { + if s, ok := src.(string); ok { + *r = Role(s) + } else { + err = errNoString + } + return +} + +func (r *Role) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + s = strings.ToLower(s) + for _, v := range validRoles { + if v == s { + *r = Role(v) + return nil + } + } + return errNoValidRole +} diff -r a9440a4826aa -r c1047fd04a3a pkg/controllers/user.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/controllers/user.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,288 @@ +package controllers + +import ( + "database/sql" + "fmt" + "net/http" + + "github.com/gorilla/mux" + + "gemma.intevation.de/gemma/pkg/auth" +) + +const ( + createUserSQL = `INSERT INTO users.list_users + VALUES ($1, $2, $3, $4, NULL, $5)` + createUserExtentSQL = `INSERT INTO users.list_users + VALUES ($1, $2, $3, $4, + ST_MakeBox2D(ST_Point($5, $6), ST_Point($7, $8)), $9)` + + updateUserUnprivSQL = `UPDATE users.list_users + SET (pw, map_extent, email_address) + = ($2, ST_MakeBox2D(ST_Point($3, $4), ST_Point($5, $6)), $7) + WHERE username = $1` + updateUserSQL = `UPDATE users.list_users + SET (rolname, username, pw, country, map_extent, email_address) + = ($2, $3, $4, $5, NULL, $6) + WHERE username = $1` + updateUserExtentSQL = `UPDATE users.list_users + SET (rolname, username, pw, country, map_extent, email_address) + = ($2, $3, $4, $5, ST_MakeBox2D(ST_Point($6, $7), ST_Point($8, $9)), $10) + WHERE username = $1` + + deleteUserSQL = `DELETE FROM users.list_users WHERE username = $1` + + listUsersSQL = `SELECT + rolname, + username, + country, + email_address, + ST_XMin(map_extent), ST_YMin(map_extent), + ST_XMax(map_extent), ST_YMax(map_extent) +FROM users.list_users` + + listUserSQL = `SELECT + rolname, + country, + email_address, + ST_XMin(map_extent), ST_YMin(map_extent), + ST_XMax(map_extent), ST_YMax(map_extent) +FROM users.list_users +WHERE username = $1` +) + +func deleteUser( + _ interface{}, req *http.Request, + db *sql.DB, +) (jr JSONResult, err error) { + + user := mux.Vars(req)["user"] + if user == "" { + err = JSONError{http.StatusBadRequest, "error: user empty"} + return + } + + session, _ := auth.GetSession(req) + if session.User == user { + err = JSONError{http.StatusBadRequest, "error: cannot delete yourself"} + return + } + + var res sql.Result + + if res, err = db.Exec(deleteUserSQL, user); err != nil { + return + } + + if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { + err = JSONError{ + Code: http.StatusNotFound, + Message: fmt.Sprintf("Cannot find user %s.", user), + } + return + } + + // Running in a go routine should not be necessary. + go func() { auth.ConnPool.Logout(user) }() + + jr = JSONResult{Code: http.StatusNoContent} + return +} + +func updateUser( + input interface{}, req *http.Request, + db *sql.DB, +) (jr JSONResult, err error) { + + user := mux.Vars(req)["user"] + if user == "" { + err = JSONError{http.StatusBadRequest, "error: user empty"} + return + } + + newUser := input.(*User) + var res sql.Result + + if s, _ := auth.GetSession(req); s.Roles.Has("sys_admin") { + if newUser.Extent == nil { + res, err = db.Exec( + updateUserSQL, + user, + newUser.Role, + newUser.User, + newUser.Password, + newUser.Country, + newUser.Email, + ) + } else { + res, err = db.Exec( + updateUserExtentSQL, + user, + newUser.Role, + newUser.User, + newUser.Password, + newUser.Country, + newUser.Extent.X1, newUser.Extent.Y1, + newUser.Extent.X2, newUser.Extent.Y2, + newUser.Email, + ) + } + } else { + if newUser.Extent == nil { + err = JSONError{http.StatusBadRequest, "extent is mandatory"} + return + } + res, err = db.Exec( + updateUserUnprivSQL, + user, + newUser.Password, + newUser.Extent.X1, newUser.Extent.Y1, + newUser.Extent.X2, newUser.Extent.Y2, + newUser.Email, + ) + } + + if err != nil { + return + } + + if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { + err = JSONError{ + Code: http.StatusNotFound, + Message: fmt.Sprintf("Cannot find user %s.", user), + } + return + } + + if user != newUser.User { + // Running in a go routine should not be necessary. + go func() { auth.ConnPool.Logout(user) }() + } + + jr = JSONResult{ + Code: http.StatusCreated, + Result: struct { + Result string `json:"result"` + }{"success"}, + } + return +} + +func createUser( + input interface{}, req *http.Request, + db *sql.DB, +) (jr JSONResult, err error) { + + user := input.(*User) + + if user.Extent == nil { + _, err = db.Exec( + createUserSQL, + user.Role, + user.User, + user.Password, + user.Country, + user.Email, + ) + } else { + _, err = db.Exec( + createUserExtentSQL, + user.Role, + user.User, + user.Password, + user.Country, + user.Extent.X1, user.Extent.Y1, + user.Extent.X2, user.Extent.Y2, + user.Email, + ) + } + + if err != nil { + return + } + + jr = JSONResult{ + Code: http.StatusCreated, + Result: struct { + Result string `json:"result"` + }{"success"}, + } + return +} + +func listUsers( + _ interface{}, req *http.Request, + db *sql.DB, +) (jr JSONResult, err error) { + + var rows *sql.Rows + + rows, err = db.Query(listUsersSQL) + if err != nil { + return + } + defer rows.Close() + + var users []*User + + for rows.Next() { + user := &User{Extent: &BoundingBox{}} + if err = rows.Scan( + &user.Role, + &user.User, + &user.Country, + &user.Email, + &user.Extent.X1, &user.Extent.Y1, + &user.Extent.X2, &user.Extent.Y2, + ); err != nil { + return + } + users = append(users, user) + } + + jr = JSONResult{ + Result: struct { + Users []*User `json:"users"` + }{users}, + } + return +} + +func listUser( + _ interface{}, req *http.Request, + db *sql.DB, +) (jr JSONResult, err error) { + + user := mux.Vars(req)["user"] + if user == "" { + err = JSONError{http.StatusBadRequest, "error: user empty"} + return + } + + result := &User{ + User: user, + Extent: &BoundingBox{}, + } + + err = db.QueryRow(listUserSQL, user).Scan( + &result.Role, + &result.Country, + &result.Email, + &result.Extent.X1, &result.Extent.Y1, + &result.Extent.X2, &result.Extent.Y2, + ) + + switch { + case err == sql.ErrNoRows: + err = JSONError{ + Code: http.StatusNotFound, + Message: fmt.Sprintf("Cannot find user %s.", user), + } + return + case err != nil: + return + } + + jr.Result = result + return +} diff -r a9440a4826aa -r c1047fd04a3a pkg/misc/encode.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/misc/encode.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,76 @@ +package misc + +import ( + "encoding/base64" + "encoding/binary" + "io" +) + +type BinReader struct { + io.Reader + Err error +} + +func (r *BinReader) Read(buf []byte) (int, error) { + if r.Err != nil { + return 0, r.Err + } + var n int + n, r.Err = r.Read(buf) + return n, r.Err +} + +func (r *BinReader) ReadBin(x interface{}) { + if r.Err == nil { + r.Err = binary.Read(r.Reader, binary.BigEndian, x) + } +} + +func (r *BinReader) ReadString(s *string) { + if r.Err != nil { + return + } + var l uint32 + if r.Err = binary.Read(r.Reader, binary.BigEndian, &l); r.Err != nil { + return + } + b := make([]byte, l) + if r.Err = binary.Read(r.Reader, binary.BigEndian, b); r.Err != nil { + return + } + *s = string(b) +} + +type BinWriter struct { + io.Writer + Err error +} + +func (w *BinWriter) Write(buf []byte) (int, error) { + if w.Err != nil { + return 0, w.Err + } + var n int + n, w.Err = w.Writer.Write(buf) + return n, w.Err +} + +func (w *BinWriter) WriteBin(x interface{}) { + if w.Err == nil { + w.Err = binary.Write(w.Writer, binary.BigEndian, x) + } +} + +func (w *BinWriter) WriteString(s string) { + if w.Err == nil { + w.Err = binary.Write(w.Writer, binary.BigEndian, uint32(len(s))) + } + if w.Err == nil { + w.Err = binary.Write(w.Writer, binary.BigEndian, []byte(s)) + } +} + +func BasicAuth(user, password string) string { + auth := user + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} diff -r a9440a4826aa -r c1047fd04a3a pkg/misc/mail.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/misc/mail.go Wed Aug 15 17:30:50 2018 +0200 @@ -0,0 +1,26 @@ +package misc + +import ( + gomail "gopkg.in/gomail.v2" + + "gemma.intevation.de/gemma/pkg/config" +) + +func SendMail(email, subject, body string) error { + m := gomail.NewMessage() + m.SetHeader("From", config.MailFrom()) + m.SetHeader("To", email) + m.SetHeader("Subject", subject) + m.SetBody("text/plain", body) + + d := gomail.Dialer{ + Host: config.MailHost(), + Port: int(config.MailPort()), + Username: config.MailUser(), + Password: config.MailPassword(), + LocalName: config.MailHelo(), + SSL: config.MailPort() == 465, + } + + return d.DialAndSend(m) +}