# HG changeset patch # User Sascha L. Teichmann # Date 1535117613 -7200 # Node ID af1a198391f3b8bb53ed1aef9d95b361cd22a508 # Parent c10c76c9279712405fc5ebd2e63b966a12079e64# Parent f6d61657b48793335d5737a5f6b0fb2f5a302907 Merged default into metamorph-for-all branch. diff -r f6d61657b487 -r af1a198391f3 cmd/gemma/main.go --- a/cmd/gemma/main.go Fri Aug 24 15:17:35 2018 +0200 +++ b/cmd/gemma/main.go Fri Aug 24 15:33:33 2018 +0200 @@ -19,13 +19,13 @@ "gemma.intevation.de/gemma/pkg/controllers" ) -func prepareConnectionPool() { - // Install connection pool - cp, err := auth.NewConnectionPool(config.SessionStore()) +func prepareSessionStore() { + // Install session store + ss, err := auth.NewSessionStore(config.SessionStore()) if err != nil { log.Fatalf("Error with session store: %v\n", err) } - auth.ConnPool = cp + auth.Sessions = ss } func start(cmd *cobra.Command, args []string) { @@ -35,7 +35,7 @@ log.Fatalf("error: %v\n", err) } - prepareConnectionPool() + prepareSessionStore() // Do GeoServer setup in background. go func() { @@ -88,7 +88,7 @@ <-done - if err := auth.ConnPool.Shutdown(); err != nil { + if err := auth.Sessions.Shutdown(); err != nil { log.Fatalf("error: %v\n", err) } } diff -r f6d61657b487 -r af1a198391f3 pkg/auth/connection.go --- a/pkg/auth/connection.go Fri Aug 24 15:17:35 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,99 +0,0 @@ -package auth - -import ( - "database/sql" - "errors" - "io" - "log" - "sync" - "time" - - "gemma.intevation.de/gemma/pkg/misc" -) - -var ErrNoSuchToken = errors.New("No such token") - -const ( - maxOpen = 16 - maxDBIdle = time.Minute * 5 -) - -type Connection struct { - session *Session - - access time.Time - db *sql.DB - refCount int - - mu sync.Mutex -} - -func (c *Connection) serialize(w io.Writer) error { - if err := c.session.serialize(w); err != nil { - return err - } - access, err := c.last().MarshalText() - if err != nil { - return err - } - wr := misc.BinWriter{w, nil} - wr.WriteBin(uint32(len(access))) - wr.WriteBin(access) - return wr.Err -} - -func (c *Connection) deserialize(r io.Reader) error { - session := new(Session) - if err := session.deserialize(r); err != nil { - return err - } - - rd := misc.BinReader{r, nil} - var l uint32 - rd.ReadBin(&l) - access := make([]byte, l) - rd.ReadBin(access) - - if rd.Err != nil { - return rd.Err - } - - var t time.Time - if err := t.UnmarshalText(access); err != nil { - return err - } - - *c = Connection{ - session: session, - access: t, - } - - return nil -} - -func (c *Connection) set(session *Session) { - c.session = session - c.touch() -} - -func (c *Connection) touch() { - c.mu.Lock() - c.access = time.Now() - c.mu.Unlock() -} - -func (c *Connection) last() time.Time { - c.mu.Lock() - access := c.access - c.mu.Unlock() - return access -} - -func (c *Connection) close() { - if c.db != nil { - if err := c.db.Close(); err != nil { - log.Printf("warn: %v\n", err) - } - c.db = nil - } -} diff -r f6d61657b487 -r af1a198391f3 pkg/auth/middleware.go --- a/pkg/auth/middleware.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/auth/middleware.go Fri Aug 24 15:33:33 2018 +0200 @@ -35,7 +35,7 @@ return } - session := ConnPool.Session(token) + session := Sessions.Session(token) if session == nil { http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return diff -r f6d61657b487 -r af1a198391f3 pkg/auth/opendb.go --- a/pkg/auth/opendb.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/auth/opendb.go Fri Aug 24 15:33:33 2018 +0200 @@ -1,8 +1,10 @@ package auth import ( + "context" "database/sql" "errors" + "sync" "github.com/jackc/pgx" "github.com/jackc/pgx/stdlib" @@ -10,6 +12,8 @@ "gemma.intevation.de/gemma/pkg/config" ) +var ErrNoMetamorphUser = errors.New("No metamorphic user configured") + func OpenDB(user, password string) (*sql.DB, error) { // To ease SSL config ride a bit on parsing. @@ -28,6 +32,47 @@ return stdlib.OpenDB(cc), nil } +type metamorph struct { + sync.Mutex + db *sql.DB +} + +var mm metamorph + +func (m *metamorph) open() (*sql.DB, error) { + m.Lock() + defer m.Unlock() + if m.db != nil { + return m.db, nil + } + user := config.MetamorphDBUser() + if user == "" { + return nil, ErrNoMetamorphUser + } + db, err := OpenDB(user, config.MetamorhpDBPassword()) + if err != nil { + return nil, err + } + m.db = db + return db, nil +} + +func MetamorphConn(ctx context.Context, user string) (*sql.Conn, error) { + db, err := mm.open() + if err != nil { + return nil, err + } + conn, err := db.Conn(ctx) + if err != nil { + return nil, err + } + if _, err := conn.ExecContext(ctx, `SELECT public.setrole_plan($1)`, user); err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + const allRoles = ` WITH RECURSIVE cte AS ( SELECT oid FROM pg_roles WHERE rolname = current_user @@ -40,8 +85,6 @@ WHERE oid IN (SELECT oid FROM cte) AND rolname <> current_user AND EXISTS (SELECT 1 FROM users.list_users WHERE username = current_user)` -var ErrNoMetamorphUser = errors.New("No metamorphic user configured") - func AllOtherRoles(user, password string) (Roles, error) { db, err := OpenDB(user, password) if err != nil { @@ -66,18 +109,11 @@ return roles, rows.Err() } -func RunAs(role string, fn func(*sql.DB) error) error { - user := config.MetamorphDBUser() - if user == "" { - return ErrNoMetamorphUser - } - db, err := OpenDB(user, config.MetamorhpDBPassword()) +func RunAs(role string, ctx context.Context, fn func(*sql.Conn) error) error { + conn, err := MetamorphConn(ctx, role) if err != nil { - return nil + return err } - defer db.Close() - if _, err = db.Exec(`SELECT public.setrole_plan($1)`, role); err == nil { - err = fn(db) - } - return err + defer conn.Close() + return fn(conn) } diff -r f6d61657b487 -r af1a198391f3 pkg/auth/pool.go --- a/pkg/auth/pool.go Fri Aug 24 15:17:35 2018 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,322 +0,0 @@ -package auth - -import ( - "bytes" - "database/sql" - "log" - "time" - - bolt "github.com/coreos/bbolt" -) - -// ConnPool is the global connection pool. -var ConnPool *ConnectionPool - -type ConnectionPool struct { - storage *bolt.DB - conns map[string]*Connection - cmds chan func(*ConnectionPool) -} - -var sessionsBucket = []byte("sessions") - -func NewConnectionPool(filename string) (*ConnectionPool, error) { - - pcp := &ConnectionPool{ - conns: make(map[string]*Connection), - cmds: make(chan func(*ConnectionPool)), - } - if err := pcp.openStorage(filename); err != nil { - return nil, err - } - go pcp.run() - return pcp, nil -} - -// openStorage opens a storage file. -func (pcp *ConnectionPool) openStorage(filename string) error { - - // No file, nothing to restore/persist. - if filename == "" { - return nil - } - - db, err := bolt.Open(filename, 0600, nil) - if err != nil { - return err - } - - err = db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucketIfNotExists(sessionsBucket) - if err != nil { - return err - } - - // pre-load sessions - c := b.Cursor() - - for k, v := c.First(); k != nil; k, v = c.Next() { - var conn Connection - if err := conn.deserialize(bytes.NewReader(v)); err != nil { - return err - } - pcp.conns[string(k)] = &conn - } - - return nil - }) - - if err != nil { - db.Close() - return err - } - - pcp.storage = db - return nil -} - -func (pcp *ConnectionPool) run() { - for { - select { - case cmd := <-pcp.cmds: - cmd(pcp) - case <-time.After(time.Minute): - pcp.cleanDB() - case <-time.After(time.Minute * 5): - pcp.cleanToken() - } - } -} - -func (pcp *ConnectionPool) cleanDB() { - valid := time.Now().Add(-maxDBIdle) - for _, con := range pcp.conns { - if con.refCount <= 0 && con.last().Before(valid) { - con.close() - } - } -} - -func (pcp *ConnectionPool) cleanToken() { - now := time.Now() - for token, con := range pcp.conns { - expires := time.Unix(con.session.ExpiresAt, 0) - if expires.Before(now) { - // TODO: Be more graceful here? - con.close() - delete(pcp.conns, token) - pcp.remove(token) - } - } -} - -func (pcp *ConnectionPool) remove(token string) { - if pcp.storage == nil { - return - } - err := pcp.storage.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(sessionsBucket) - return b.Delete([]byte(token)) - }) - if err != nil { - log.Printf("error: %v\n", err) - } -} - -func (pcp *ConnectionPool) Delete(token string) bool { - res := make(chan bool) - pcp.cmds <- func(pcp *ConnectionPool) { - conn, found := pcp.conns[token] - if !found { - res <- false - return - } - conn.close() - delete(pcp.conns, token) - pcp.remove(token) - res <- true - } - return <-res -} - -func (pcp *ConnectionPool) store(token string, con *Connection) { - if pcp.storage == nil { - return - } - err := pcp.storage.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(sessionsBucket) - var buf bytes.Buffer - if err := con.serialize(&buf); err != nil { - return err - } - return b.Put([]byte(token), buf.Bytes()) - }) - if err != nil { - log.Printf("error: %v\n", err) - } -} - -func (pcp *ConnectionPool) Add(token string, session *Session) *Connection { - res := make(chan *Connection) - - pcp.cmds <- func(cp *ConnectionPool) { - con := pcp.conns[token] - if con == nil { - con = &Connection{} - pcp.conns[token] = con - } - con.set(session) - pcp.store(token, con) - res <- con - } - - con := <-res - return con -} - -func (pcp *ConnectionPool) Renew(token string) (string, error) { - - type result struct { - newToken string - err error - } - - resCh := make(chan result) - - pcp.cmds <- func(cp *ConnectionPool) { - con := pcp.conns[token] - if con == nil { - resCh <- result{err: ErrNoSuchToken} - } else { - delete(pcp.conns, token) - pcp.remove(token) - newToken := GenerateSessionKey() - // TODO: Ensure that this is not racy! - con.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix() - pcp.conns[newToken] = con - pcp.store(newToken, con) - resCh <- result{newToken: newToken} - } - } - - r := <-resCh - return r.newToken, r.err -} - -func (pcp *ConnectionPool) trim(conn *Connection) { - - conn.refCount-- - - for { - least := time.Now() - var count int - var oldest *Connection - - for _, con := range pcp.conns { - if con.db != nil && con.refCount <= 0 { - if last := con.last(); last.Before(least) { - least = last - oldest = con - } - count++ - } - } - if count <= maxOpen { - break - } - oldest.close() - } -} - -func (pcp *ConnectionPool) Do(token string, fn func(*sql.DB) error) error { - - type result struct { - con *Connection - err error - } - - res := make(chan result) - - pcp.cmds <- func(pcp *ConnectionPool) { - con := pcp.conns[token] - if con == nil { - res <- result{err: ErrNoSuchToken} - return - } - con.touch() - // store the session here. The ref counting for - // open db connections is irrelevant for persistence - // as they all come up closed when the system reboots. - pcp.store(token, con) - - if con.db != nil { - con.refCount++ - res <- result{con: con} - return - } - - session := con.session - db, err := OpenDB(session.User, session.Password) - if err != nil { - res <- result{err: err} - return - } - con.db = db - con.refCount++ - res <- result{con: con} - } - - r := <-res - - if r.err != nil { - return r.err - } - - defer func() { - pcp.cmds <- func(pcp *ConnectionPool) { - pcp.trim(r.con) - } - }() - - return fn(r.con.db) -} - -func (pcp *ConnectionPool) Session(token string) *Session { - res := make(chan *Session) - pcp.cmds <- func(pcp *ConnectionPool) { - con := pcp.conns[token] - if con == nil { - res <- nil - } else { - con.touch() - pcp.store(token, con) - res <- con.session - } - } - return <-res -} - -func (pcp *ConnectionPool) Logout(user string) { - pcp.cmds <- func(pcp *ConnectionPool) { - for token, con := range pcp.conns { - if con.session.User == user { - if db := con.db; db != nil { - con.db = nil - db.Close() - } - delete(pcp.conns, token) - pcp.remove(token) - } - } - } -} - -func (pcp *ConnectionPool) Shutdown() error { - if db := pcp.storage; db != nil { - log.Println("info: shutdown persistent connection pool.") - pcp.storage = nil - return db.Close() - } - log.Println("info: shutdown in-memory connection pool.") - return nil -} diff -r f6d61657b487 -r af1a198391f3 pkg/auth/session.go --- a/pkg/auth/session.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/auth/session.go Fri Aug 24 15:33:33 2018 +0200 @@ -4,6 +4,7 @@ "encoding/base64" "errors" "io" + "sync" "time" "gemma.intevation.de/gemma/pkg/common" @@ -15,8 +16,11 @@ type Session struct { ExpiresAt int64 `json:"expires"` User string `json:"user"` - Password string `json:"password"` Roles Roles `json:"roles"` + + // private fields for managing expiration. + access time.Time + mu sync.Mutex } func (r Roles) Has(role string) bool { @@ -48,16 +52,14 @@ return &Session{ ExpiresAt: time.Now().Add(maxTokenValid).Unix(), User: user, - Password: password, Roles: roles, } } -func (s *Session) serialize(w io.Writer) error { +func (s *Session) serializePublic(w io.Writer) error { wr := misc.BinWriter{w, nil} wr.WriteBin(s.ExpiresAt) wr.WriteString(s.User) - wr.WriteString(s.Password) wr.WriteBin(uint32(len(s.Roles))) for _, role := range s.Roles { wr.WriteString(role) @@ -65,22 +67,76 @@ return wr.Err } +func (s *Session) serialize(w io.Writer) error { + + access, err := s.last().MarshalText() + if err != nil { + return err + } + + wr := misc.BinWriter{w, nil} + wr.WriteBin(s.ExpiresAt) + wr.WriteString(s.User) + wr.WriteBin(uint32(len(s.Roles))) + for _, role := range s.Roles { + wr.WriteString(role) + } + wr.WriteBin(uint32(len(access))) + wr.WriteBin(access) + return wr.Err +} + func (s *Session) deserialize(r io.Reader) error { - var x Session + + var session Session + var n uint32 rd := misc.BinReader{r, nil} - rd.ReadBin(&x.ExpiresAt) - rd.ReadString(&x.User) - rd.ReadString(&x.Password) + rd.ReadBin(&session.ExpiresAt) + rd.ReadString(&session.User) rd.ReadBin(&n) - x.Roles = make(Roles, n) + session.Roles = make(Roles, n) + for i := uint32(0); n > 0 && i < n; i++ { - rd.ReadString(&x.Roles[i]) + rd.ReadString(&session.Roles[i]) + } + + if rd.Err != nil { + return rd.Err + } + + var l uint32 + rd.ReadBin(&l) + access := make([]byte, l) + rd.ReadBin(access) + + if rd.Err != nil { + return rd.Err + } + + var t time.Time + if err := t.UnmarshalText(access); err != nil { + return err } - if rd.Err == nil { - *s = x - } - return rd.Err + + session.access = t + + *s = session + + return nil +} + +func (c *Session) touch() { + c.mu.Lock() + c.access = time.Now() + c.mu.Unlock() +} + +func (c *Session) last() time.Time { + c.mu.Lock() + access := c.access + c.mu.Unlock() + return access } func GenerateSessionKey() string { @@ -100,6 +156,6 @@ } token := GenerateSessionKey() session := NewSession(user, password, roles) - ConnPool.Add(token, session) + Sessions.Add(token, session) return token, session, nil } diff -r f6d61657b487 -r af1a198391f3 pkg/auth/store.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pkg/auth/store.go Fri Aug 24 15:33:33 2018 +0200 @@ -0,0 +1,227 @@ +package auth + +import ( + "bytes" + "errors" + "log" + "time" + + bolt "github.com/coreos/bbolt" +) + +var ErrNoSuchToken = errors.New("No such token") + +// Sessions is the global connection pool. +var Sessions *SessionStore + +type SessionStore struct { + storage *bolt.DB + sessions map[string]*Session + cmds chan func(*SessionStore) +} + +var sessionsBucket = []byte("sessions") + +func NewSessionStore(filename string) (*SessionStore, error) { + + pcp := &SessionStore{ + sessions: make(map[string]*Session), + cmds: make(chan func(*SessionStore)), + } + if err := pcp.openStorage(filename); err != nil { + return nil, err + } + go pcp.run() + return pcp, nil +} + +// openStorage opens a storage file. +func (pcp *SessionStore) openStorage(filename string) error { + + // No file, nothing to restore/persist. + if filename == "" { + return nil + } + + db, err := bolt.Open(filename, 0600, nil) + if err != nil { + return err + } + + err = db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists(sessionsBucket) + if err != nil { + return err + } + + // pre-load sessions + c := b.Cursor() + + for k, v := c.First(); k != nil; k, v = c.Next() { + var session Session + if err := session.deserialize(bytes.NewReader(v)); err != nil { + return err + } + pcp.sessions[string(k)] = &session + } + + return nil + }) + + if err != nil { + db.Close() + return err + } + + pcp.storage = db + return nil +} + +func (pcp *SessionStore) run() { + for { + select { + case cmd := <-pcp.cmds: + cmd(pcp) + case <-time.After(time.Minute * 5): + pcp.cleanToken() + } + } +} + +func (pcp *SessionStore) cleanToken() { + now := time.Now() + for token, session := range pcp.sessions { + expires := time.Unix(session.ExpiresAt, 0) + if expires.Before(now) { + delete(pcp.sessions, token) + pcp.remove(token) + } + } +} + +func (pcp *SessionStore) remove(token string) { + if pcp.storage == nil { + return + } + err := pcp.storage.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(sessionsBucket) + return b.Delete([]byte(token)) + }) + if err != nil { + log.Printf("error: %v\n", err) + } +} + +func (pcp *SessionStore) Delete(token string) bool { + res := make(chan bool) + pcp.cmds <- func(pcp *SessionStore) { + if _, found := pcp.sessions[token]; !found { + res <- false + return + } + delete(pcp.sessions, token) + pcp.remove(token) + res <- true + } + return <-res +} + +func (pcp *SessionStore) store(token string, session *Session) { + if pcp.storage == nil { + return + } + err := pcp.storage.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(sessionsBucket) + var buf bytes.Buffer + if err := session.serialize(&buf); err != nil { + return err + } + return b.Put([]byte(token), buf.Bytes()) + }) + if err != nil { + log.Printf("error: %v\n", err) + } +} + +func (pcp *SessionStore) Add(token string, session *Session) *Session { + res := make(chan *Session) + + pcp.cmds <- func(cp *SessionStore) { + s := pcp.sessions[token] + if s == nil { + s = session + pcp.sessions[token] = session + } + s.touch() + pcp.store(token, s) + res <- s + } + + s := <-res + return s +} + +func (pcp *SessionStore) Renew(token string) (string, error) { + + type result struct { + newToken string + err error + } + + resCh := make(chan result) + + pcp.cmds <- func(cp *SessionStore) { + session := pcp.sessions[token] + if session == nil { + resCh <- result{err: ErrNoSuchToken} + } else { + delete(pcp.sessions, token) + pcp.remove(token) + newToken := GenerateSessionKey() + // TODO: Ensure that this is not racy! + session.ExpiresAt = time.Now().Add(maxTokenValid).Unix() + pcp.sessions[newToken] = session + pcp.store(newToken, session) + resCh <- result{newToken: newToken} + } + } + + r := <-resCh + return r.newToken, r.err +} + +func (pcp *SessionStore) Session(token string) *Session { + res := make(chan *Session) + pcp.cmds <- func(pcp *SessionStore) { + session := pcp.sessions[token] + if session == nil { + res <- nil + } else { + session.touch() + pcp.store(token, session) + res <- session + } + } + return <-res +} + +func (pcp *SessionStore) Logout(user string) { + pcp.cmds <- func(pcp *SessionStore) { + for token, session := range pcp.sessions { + if session.User == user { + delete(pcp.sessions, token) + pcp.remove(token) + } + } + } +} + +func (pcp *SessionStore) Shutdown() error { + if db := pcp.storage; db != nil { + log.Println("info: shutdown persistent connection pool.") + pcp.storage = nil + return db.Close() + } + log.Println("info: shutdown in-memory connection pool.") + return nil +} diff -r f6d61657b487 -r af1a198391f3 pkg/controllers/json.go --- a/pkg/controllers/json.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/controllers/json.go Fri Aug 24 15:33:33 2018 +0200 @@ -19,7 +19,8 @@ type JSONHandler struct { Input func() interface{} - Handle func(interface{}, *http.Request, *sql.DB) (JSONResult, error) + Handle func(interface{}, *http.Request, *sql.Conn) (JSONResult, error) + NoConn bool } type JSONError struct { @@ -46,11 +47,16 @@ var jr JSONResult var err error - if token, ok := auth.GetToken(req); ok { - err = auth.ConnPool.Do(token, func(db *sql.DB) (err error) { - jr, err = j.Handle(input, req, db) - return err - }) + if token, ok := auth.GetToken(req); ok && !j.NoConn { + if session := auth.Sessions.Session(token); session != nil { + var conn *sql.Conn + if conn, err = auth.MetamorphConn(req.Context(), session.User); err != nil { + defer conn.Close() + jr, err = j.Handle(input, req, conn) + } + } else { + err = auth.ErrNoSuchToken + } } else { jr, err = j.Handle(input, req, nil) } diff -r f6d61657b487 -r af1a198391f3 pkg/controllers/publish.go --- a/pkg/controllers/publish.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/controllers/publish.go Fri Aug 24 15:33:33 2018 +0200 @@ -7,7 +7,7 @@ "gemma.intevation.de/gemma/pkg/models" ) -func published(_ interface{}, req *http.Request, _ *sql.DB) (jr JSONResult, err error) { +func published(_ interface{}, req *http.Request, _ *sql.Conn) (jr JSONResult, err error) { jr = JSONResult{ Result: struct { Internal []models.IntEntry `json:"internal"` diff -r f6d61657b487 -r af1a198391f3 pkg/controllers/pwreset.go --- a/pkg/controllers/pwreset.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/controllers/pwreset.go Fri Aug 24 15:33:33 2018 +0200 @@ -2,6 +2,7 @@ import ( "bytes" + "context" "database/sql" "encoding/hex" "log" @@ -92,11 +93,14 @@ func removeOutdated() { for { time.Sleep(cleanupPause) - err := auth.RunAs(pwResetRole, func(db *sql.DB) error { - good := time.Now().Add(-passwordResetValid) - _, err := db.Exec(cleanupRequestsSQL, good) - return err - }) + err := auth.RunAs( + pwResetRole, context.Background(), + func(conn *sql.Conn) error { + good := time.Now().Add(-passwordResetValid) + _, err := conn.ExecContext( + context.Background(), cleanupRequestsSQL, good) + return err + }) if err != nil { log.Printf("error: %v\n", err) } @@ -165,7 +169,7 @@ func passwordResetRequest( input interface{}, req *http.Request, - _ *sql.DB, + _ *sql.Conn, ) (jr JSONResult, err error) { user := input.(*models.PWResetUser) @@ -177,46 +181,52 @@ var hash, email string - if err = auth.RunAs(pwResetRole, func(db *sql.DB) error { + ctx := req.Context() - var count int64 - if err := db.QueryRow(countRequestsSQL).Scan(&count); err != nil { - return err - } + if err = auth.RunAs( + pwResetRole, ctx, + func(conn *sql.Conn) error { - // Limit total number of password requests. - if count >= maxPasswordResets { - return JSONError{ - Code: http.StatusServiceUnavailable, - Message: "Too much password reset request", + var count int64 + if err := conn.QueryRowContext( + ctx, countRequestsSQL).Scan(&count); err != nil { + return err } - } - err := db.QueryRow(userExistsSQL, user.User).Scan(&email) + // Limit total number of password requests. + if count >= maxPasswordResets { + return JSONError{ + Code: http.StatusServiceUnavailable, + Message: "Too much password reset request", + } + } + + err := conn.QueryRowContext(ctx, userExistsSQL, user.User).Scan(&email) - switch { - case err == sql.ErrNoRows: - return JSONError{http.StatusNotFound, "User does not exist."} - case err != nil: - return err - } + switch { + case err == sql.ErrNoRows: + return JSONError{http.StatusNotFound, "User does not exist."} + case err != nil: + return err + } - if err := db.QueryRow(countRequestsUserSQL, user.User).Scan(&count); err != nil { - return err - } + if err := conn.QueryRowContext( + ctx, countRequestsUserSQL, user.User).Scan(&count); err != nil { + return err + } - // Limit requests per user - if count >= maxPasswordRequestsPerUser { - return JSONError{ - Code: http.StatusServiceUnavailable, - Message: "Too much password reset requests for user", + // Limit requests per user + if count >= maxPasswordRequestsPerUser { + return JSONError{ + Code: http.StatusServiceUnavailable, + Message: "Too much password reset requests for user", + } } - } - hash = generateHash() - _, err = db.Exec(insertRequestSQL, hash, user.User) - return err - }); err == nil { + hash = generateHash() + _, err = conn.ExecContext(ctx, insertRequestSQL, hash, user.User) + return err + }); err == nil { body := requestMessageBody(useHTTPS(req), user.User, hash, req.Host) if err = misc.SendMail(email, "Password Reset Link", body); err == nil { @@ -231,7 +241,7 @@ func passwordReset( _ interface{}, req *http.Request, - _ *sql.DB, + _ *sql.Conn, ) (jr JSONResult, err error) { hash := mux.Vars(req)["hash"] @@ -242,25 +252,28 @@ var email, user, password string - if err = auth.RunAs(pwResetRole, func(db *sql.DB) error { - err := db.QueryRow(findRequestSQL, hash).Scan(&email, &user) - switch { - case err == sql.ErrNoRows: - return JSONError{http.StatusNotFound, "No such hash"} - case err != nil: + ctx := req.Context() + + if err = auth.RunAs( + pwResetRole, ctx, func(conn *sql.Conn) error { + err := conn.QueryRowContext(ctx, findRequestSQL, hash).Scan(&email, &user) + switch { + case err == sql.ErrNoRows: + return JSONError{http.StatusNotFound, "No such hash"} + case err != nil: + return err + } + password = generateNewPassword() + res, err := conn.ExecContext(ctx, updatePasswordSQL, password, user) + if err != nil { + return err + } + if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { + return JSONError{http.StatusNotFound, "User not found"} + } + _, err = conn.ExecContext(ctx, deleteRequestSQL, hash) return err - } - password = generateNewPassword() - res, err := db.Exec(updatePasswordSQL, password, user) - if err != nil { - return err - } - if n, err2 := res.RowsAffected(); err2 == nil && n == 0 { - return JSONError{http.StatusNotFound, "User not found"} - } - _, err = db.Exec(deleteRequestSQL, hash) - return err - }); err == nil { + }); err == nil { body := changedMessageBody(useHTTPS(req), user, password, req.Host) if err = misc.SendMail(email, "Password Reset Done", body); err == nil { jr.Result = &struct { diff -r f6d61657b487 -r af1a198391f3 pkg/controllers/routes.go --- a/pkg/controllers/routes.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/controllers/routes.go Fri Aug 24 15:33:33 2018 +0200 @@ -92,6 +92,7 @@ api.Handle("/published", any(&JSONHandler{ Handle: published, + NoConn: true, })).Methods(http.MethodGet) // Token handling: Login/Logout. diff -r f6d61657b487 -r af1a198391f3 pkg/controllers/token.go --- a/pkg/controllers/token.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/controllers/token.go Fri Aug 24 15:33:33 2018 +0200 @@ -19,7 +19,7 @@ func renew(rw http.ResponseWriter, req *http.Request) { token, _ := auth.GetToken(req) - newToken, err := auth.ConnPool.Renew(token) + newToken, err := auth.Sessions.Renew(token) switch { case err == auth.ErrNoSuchToken: http.NotFound(rw, req) @@ -48,7 +48,7 @@ func logout(rw http.ResponseWriter, req *http.Request) { token, ok := auth.GetToken(req) - if !ok || !auth.ConnPool.Delete(token) { + if !ok || !auth.Sessions.Delete(token) { http.NotFound(rw, req) return } diff -r f6d61657b487 -r af1a198391f3 pkg/controllers/user.go --- a/pkg/controllers/user.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/controllers/user.go Fri Aug 24 15:33:33 2018 +0200 @@ -54,7 +54,7 @@ func deleteUser( _ interface{}, req *http.Request, - db *sql.DB, + db *sql.Conn, ) (jr JSONResult, err error) { user := mux.Vars(req)["user"] @@ -71,7 +71,7 @@ var res sql.Result - if res, err = db.Exec(deleteUserSQL, user); err != nil { + if res, err = db.ExecContext(req.Context(), deleteUserSQL, user); err != nil { return } @@ -84,15 +84,16 @@ } // Running in a go routine should not be necessary. - go func() { auth.ConnPool.Logout(user) }() + go func() { auth.Sessions.Logout(user) }() jr = JSONResult{Code: http.StatusNoContent} return } func updateUser( - input interface{}, req *http.Request, - db *sql.DB, + input interface{}, + req *http.Request, + db *sql.Conn, ) (jr JSONResult, err error) { user := models.UserName(mux.Vars(req)["user"]) @@ -106,7 +107,8 @@ if s, _ := auth.GetSession(req); s.Roles.Has("sys_admin") { if newUser.Extent == nil { - res, err = db.Exec( + res, err = db.ExecContext( + req.Context(), updateUserSQL, user, newUser.Role, @@ -116,7 +118,8 @@ newUser.Email, ) } else { - res, err = db.Exec( + res, err = db.ExecContext( + req.Context(), updateUserExtentSQL, user, newUser.Role, @@ -133,7 +136,8 @@ err = JSONError{http.StatusBadRequest, "extent is mandatory"} return } - res, err = db.Exec( + res, err = db.ExecContext( + req.Context(), updateUserUnprivSQL, user, newUser.Password, @@ -157,7 +161,7 @@ if user != newUser.User { // Running in a go routine should not be necessary. - go func() { auth.ConnPool.Logout(string(user)) }() + go func() { auth.Sessions.Logout(string(user)) }() } jr = JSONResult{ @@ -170,14 +174,16 @@ } func createUser( - input interface{}, req *http.Request, - db *sql.DB, + input interface{}, + req *http.Request, + db *sql.Conn, ) (jr JSONResult, err error) { user := input.(*models.User) if user.Extent == nil { - _, err = db.Exec( + _, err = db.ExecContext( + req.Context(), createUserSQL, user.Role, user.User, @@ -186,7 +192,8 @@ user.Email, ) } else { - _, err = db.Exec( + _, err = db.ExecContext( + req.Context(), createUserExtentSQL, user.Role, user.User, @@ -212,13 +219,14 @@ } func listUsers( - _ interface{}, req *http.Request, - db *sql.DB, + _ interface{}, + req *http.Request, + db *sql.Conn, ) (jr JSONResult, err error) { var rows *sql.Rows - rows, err = db.Query(listUsersSQL) + rows, err = db.QueryContext(req.Context(), listUsersSQL) if err != nil { return } @@ -250,8 +258,9 @@ } func listUser( - _ interface{}, req *http.Request, - db *sql.DB, + _ interface{}, + req *http.Request, + db *sql.Conn, ) (jr JSONResult, err error) { user := models.UserName(mux.Vars(req)["user"]) @@ -265,7 +274,7 @@ Extent: &models.BoundingBox{}, } - err = db.QueryRow(listUserSQL, user).Scan( + err = db.QueryRowContext(req.Context(), listUserSQL, user).Scan( &result.Role, &result.Country, &result.Email, diff -r f6d61657b487 -r af1a198391f3 pkg/models/extservices.go --- a/pkg/models/extservices.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/models/extservices.go Fri Aug 24 15:33:33 2018 +0200 @@ -1,6 +1,7 @@ package models import ( + "context" "database/sql" "log" "sort" @@ -47,25 +48,27 @@ func (es *ExtServices) load() error { // make empty slice to prevent retry if slice is empty. es.entries = []ExtEntry{} - return auth.RunAs("sys_admin", func(db *sql.DB) error { - rows, err := db.Query(selectExternalServices) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var entry ExtEntry - if err := rows.Scan( - &entry.Name, - &entry.URL, - &entry.WFS, - ); err != nil { + return auth.RunAs("sys_admin", context.Background(), + func(conn *sql.Conn) error { + rows, err := conn.QueryContext( + context.Background(), selectExternalServices) + if err != nil { return err } - es.entries = append(es.entries, entry) - } - return rows.Err() - }) + defer rows.Close() + for rows.Next() { + var entry ExtEntry + if err := rows.Scan( + &entry.Name, + &entry.URL, + &entry.WFS, + ); err != nil { + return err + } + es.entries = append(es.entries, entry) + } + return rows.Err() + }) } func (es *ExtServices) Invalidate() { diff -r f6d61657b487 -r af1a198391f3 pkg/models/intservices.go --- a/pkg/models/intservices.go Fri Aug 24 15:17:35 2018 +0200 +++ b/pkg/models/intservices.go Fri Aug 24 15:33:33 2018 +0200 @@ -1,6 +1,7 @@ package models import ( + "context" "database/sql" "log" "sync" @@ -64,24 +65,26 @@ func (ps *IntServices) load() error { // make empty slice to prevent retry if slice is empty. ps.entries = []IntEntry{} - return auth.RunAs("sys_admin", func(db *sql.DB) error { - rows, err := db.Query(selectPublishedServices) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var entry IntEntry - if err := rows.Scan( - &entry.Name, &entry.Style, - &entry.WFS, &entry.WFS, - ); err != nil { + return auth.RunAs("sys_admin", context.Background(), + func(conn *sql.Conn) error { + rows, err := conn.QueryContext( + context.Background(), selectPublishedServices) + if err != nil { return err } - ps.entries = append(ps.entries, entry) - } - return rows.Err() - }) + defer rows.Close() + for rows.Next() { + var entry IntEntry + if err := rows.Scan( + &entry.Name, &entry.Style, + &entry.WFS, &entry.WFS, + ); err != nil { + return err + } + ps.entries = append(ps.entries, entry) + } + return rows.Err() + }) return nil }