Mercurial > gemma
diff auth/pool.go @ 205:2a152816fc38
Renamed the file containing the connection pool to a more suited one.
author | Sascha L. Teichmann <teichmann@intevation.de> |
---|---|
date | Sun, 22 Jul 2018 10:24:28 +0200 |
parents | auth/persistent.go@3d0988d9f867 |
children | 2fad2931a5a6 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/auth/pool.go Sun Jul 22 10:24:28 2018 +0200 @@ -0,0 +1,306 @@ +package auth + +import ( + "bytes" + "database/sql" + "log" + "time" + + bolt "github.com/coreos/bbolt" +) + +type ConnectionPool struct { + storage *bolt.DB + conns map[string]*Connection + cmds chan func(*ConnectionPool) +} + +var sessionsBucket = []byte("sessions") + +func NewConnectionPool(filename string) (*ConnectionPool, error) { + + pcp := &ConnectionPool{ + cmds: make(chan func(*ConnectionPool)), + } + if err := pcp.openStorage(filename); err != nil { + return nil, err + } + go pcp.run() + return pcp, nil +} + +// openStorage opens a storage file. +func (pcp *ConnectionPool) openStorage(filename string) error { + + // No file, nothing to restore/persist. + if filename == "" { + return nil + } + + db, err := bolt.Open(filename, 0600, nil) + if err != nil { + return err + } + + conns := make(map[string]*Connection) + err = db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists(sessionsBucket) + if err != nil { + return err + } + + // pre-load sessions + c := b.Cursor() + + for k, v := c.First(); k != nil; k, v = c.Next() { + var conn Connection + if err := conn.deserialize(bytes.NewReader(v)); err != nil { + return err + } + conns[string(k)] = &conn + } + + return nil + }) + + if err != nil { + db.Close() + return err + } + + pcp.storage = db + pcp.conns = conns + + return nil +} + +func (pcp *ConnectionPool) run() { + for { + select { + case cmd := <-pcp.cmds: + cmd(pcp) + case <-time.After(time.Minute): + pcp.cleanDB() + case <-time.After(time.Minute * 5): + pcp.cleanToken() + } + } +} + +func (pcp *ConnectionPool) cleanDB() { + valid := time.Now().Add(-maxDBIdle) + for _, con := range pcp.conns { + if con.refCount <= 0 && con.last().Before(valid) { + con.close() + } + } +} + +func (pcp *ConnectionPool) cleanToken() { + now := time.Now() + for token, con := range pcp.conns { + expires := time.Unix(con.session.ExpiresAt, 0) + if expires.Before(now) { + // TODO: Be more graceful here? + con.close() + delete(pcp.conns, token) + pcp.remove(token) + } + } +} + +func (pcp *ConnectionPool) remove(token string) { + if pcp.storage == nil { + return + } + err := pcp.storage.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(sessionsBucket) + return b.Delete([]byte(token)) + }) + if err != nil { + log.Printf("error: %v\n", err) + } +} + +func (pcp *ConnectionPool) Delete(token string) bool { + res := make(chan bool) + pcp.cmds <- func(pcp *ConnectionPool) { + conn, found := pcp.conns[token] + if !found { + res <- false + return + } + conn.close() + delete(pcp.conns, token) + pcp.remove(token) + res <- true + } + return <-res +} + +func (pcp *ConnectionPool) store(token string, con *Connection) { + if pcp.storage == nil { + return + } + err := pcp.storage.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(sessionsBucket) + var buf bytes.Buffer + if err := con.serialize(&buf); err != nil { + return err + } + return b.Put([]byte(token), buf.Bytes()) + }) + if err != nil { + log.Printf("error: %v\n", err) + } +} + +func (pcp *ConnectionPool) Add(token string, session *Session) *Connection { + res := make(chan *Connection) + + pcp.cmds <- func(cp *ConnectionPool) { + con := pcp.conns[token] + if con == nil { + con = &Connection{} + pcp.conns[token] = con + } + con.set(session) + pcp.store(token, con) + res <- con + } + + con := <-res + return con +} + +func (pcp *ConnectionPool) Renew(token string) (string, error) { + + type result struct { + newToken string + err error + } + + resCh := make(chan result) + + pcp.cmds <- func(cp *ConnectionPool) { + con := pcp.conns[token] + if con == nil { + resCh <- result{err: ErrNoSuchToken} + } else { + delete(pcp.conns, token) + pcp.remove(token) + newToken := GenerateSessionKey() + // TODO: Ensure that this is not racy! + con.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix() + pcp.conns[newToken] = con + pcp.store(newToken, con) + resCh <- result{newToken: newToken} + } + } + + r := <-resCh + return r.newToken, r.err +} + +func (pcp *ConnectionPool) trim(conn *Connection) { + + conn.refCount-- + + for { + least := time.Now() + var count int + var oldest *Connection + + for _, con := range pcp.conns { + if con.db != nil && con.refCount <= 0 { + if last := con.last(); last.Before(least) { + least = last + oldest = con + } + count++ + } + } + if count <= maxOpen { + break + } + oldest.close() + } +} + +func (pcp *ConnectionPool) Do(token string, fn func(*sql.DB) error) error { + + type result struct { + con *Connection + err error + } + + res := make(chan result) + + pcp.cmds <- func(pcp *ConnectionPool) { + con := pcp.conns[token] + if con == nil { + res <- result{err: ErrNoSuchToken} + return + } + con.touch() + // store the session here. The ref counting for + // open db connections is irrelevant for persistence + // as they all come up closed when the system reboots. + pcp.store(token, con) + + if con.db != nil { + con.refCount++ + res <- result{con: con} + return + } + + session := con.session + db, err := opendb(session.User, session.Password) + if err != nil { + res <- result{err: err} + return + } + con.db = db + con.refCount++ + res <- result{con: con} + } + + r := <-res + + if r.err != nil { + return r.err + } + + defer func() { + pcp.cmds <- func(pcp *ConnectionPool) { + pcp.trim(r.con) + } + }() + + return fn(r.con.db) +} + +func (pcp *ConnectionPool) Session(token string) *Session { + res := make(chan *Session) + pcp.cmds <- func(pcp *ConnectionPool) { + con := pcp.conns[token] + if con == nil { + res <- nil + } else { + con.touch() + pcp.store(token, con) + res <- con.session + } + } + return <-res +} + +func (pcp *ConnectionPool) Shutdown() error { + if db := pcp.storage; db != nil { + log.Println("info: shutdown persistent connection pool.") + pcp.storage = nil + return db.Close() + } + log.Println("info: shutdown in-memory connection pool.") + return nil +}