view auth/inmemory.go @ 151:3349bfc2a047

Shutdown server gracefully.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Mon, 02 Jul 2018 13:23:31 +0200
parents 0116aae1071b
children
line wrap: on
line source

package auth

import (
	"database/sql"
	"log"
	"time"
)

// InMemoryConnectionPool keeps the known connections in a map and
// funnels every operation on that map through a command channel that
// is drained by a single goroutine (see run).
type InMemoryConnectionPool struct {
	// conns holds the connections, keyed by session token.
	conns map[string]*Connection
	// cmds carries closures that are executed one at a time by run.
	cmds  chan func(*InMemoryConnectionPool)
}

// NewInMemoryConnectionPool creates an empty pool with an unbuffered
// command channel and starts the goroutine that processes the commands
// before handing the pool to the caller.
func NewInMemoryConnectionPool() *InMemoryConnectionPool {
	pool := new(InMemoryConnectionPool)
	pool.conns = make(map[string]*Connection)
	pool.cmds = make(chan func(*InMemoryConnectionPool))

	go pool.run()

	return pool
}

// run is the event loop of the pool. It executes the queued commands
// one at a time and periodically triggers the two cleanup passes:
// cleanDB once a minute and cleanToken every five minutes.
//
// Persistent tickers are used instead of calling time.After inside the
// select: fresh time.After timers would be restarted on every loop
// iteration, so the five-minute timer could never fire before the
// one-minute timer, and any incoming command would postpone both.
func (cp *InMemoryConnectionPool) run() {
	dbTicker := time.NewTicker(time.Minute)
	defer dbTicker.Stop()

	tokenTicker := time.NewTicker(5 * time.Minute)
	defer tokenTicker.Stop()

	for {
		select {
		case cmd := <-cp.cmds:
			cmd(cp)
		case <-dbTicker.C:
			cp.cleanDB()
		case <-tokenTicker.C:
			cp.cleanToken()
		}
	}
}

// cleanDB closes every connection that is currently not referenced
// (refCount <= 0) and whose last use lies further back than maxDBIdle.
// The entries stay in the map; only close is called on them.
func (cp *InMemoryConnectionPool) cleanDB() {
	cutoff := time.Now().Add(-maxDBIdle)
	for _, c := range cp.conns {
		if c.refCount > 0 {
			// Still in use, leave it alone.
			continue
		}
		if c.last().Before(cutoff) {
			c.close()
		}
	}
}

// cleanToken closes and forgets every connection whose session expiry
// time (ExpiresAt, in Unix seconds) lies before the current time.
func (cp *InMemoryConnectionPool) cleanToken() {
	now := time.Now()
	for token, c := range cp.conns {
		if !time.Unix(c.session.ExpiresAt, 0).Before(now) {
			// Session is still valid.
			continue
		}
		// TODO: Be more graceful here?
		c.close()
		delete(cp.conns, token)
	}
}

// Delete closes the connection stored under token and removes it from
// the pool. It reports whether an entry for token existed.
func (cp *InMemoryConnectionPool) Delete(token string) bool {
	done := make(chan bool)

	cp.cmds <- func(pool *InMemoryConnectionPool) {
		c, ok := pool.conns[token]
		if ok {
			c.close()
			delete(pool.conns, token)
		}
		done <- ok
	}

	return <-done
}

// Add associates session with token. If no connection is stored under
// token yet, a new one is created and registered first. The connection
// that now carries the session is returned.
func (cp *InMemoryConnectionPool) Add(token string, session *Session) *Connection {
	reply := make(chan *Connection)

	cp.cmds <- func(pool *InMemoryConnectionPool) {
		c := pool.conns[token]
		if c == nil {
			c = new(Connection)
			pool.conns[token] = c
		}
		c.set(session)
		reply <- c
	}

	return <-reply
}

// Renew replaces token by a freshly generated session key. The
// connection is re-registered under the new key and the expiry time of
// its session is pushed to now + maxTokenValid. The new token is
// returned; ErrNoSuchToken is returned if token is unknown.
func (cp *InMemoryConnectionPool) Renew(token string) (string, error) {

	type reply struct {
		token string
		err   error
	}

	replies := make(chan reply)

	cp.cmds <- func(pool *InMemoryConnectionPool) {
		c := pool.conns[token]
		if c == nil {
			replies <- reply{err: ErrNoSuchToken}
			return
		}

		// Move the connection from the old key to a new one.
		delete(pool.conns, token)
		fresh := GenerateSessionKey()
		// TODO: Ensure that this is not racy!
		c.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix()
		pool.conns[fresh] = c

		replies <- reply{token: fresh}
	}

	out := <-replies
	return out.token, out.err
}

// trim releases one reference on conn and then closes unreferenced
// connections, least recently used first, until at most maxOpen
// connections with an open database handle and no references remain.
//
// It is invoked from a command queued by Do, i.e. it runs inside the
// goroutine that owns cp.conns.
//
// NOTE(review): the loop only terminates if close makes the connection
// drop out of the candidate set (db == nil) - confirm against
// Connection.close.
func (cp *InMemoryConnectionPool) trim(conn *Connection) {

	conn.refCount--

	for {
		var (
			count  int
			least  time.Time
			oldest *Connection
		)

		for _, con := range cp.conns {
			if con.db == nil || con.refCount > 0 {
				continue
			}
			count++
			// Always seed with the first candidate. Comparing against
			// time.Now() could leave oldest nil when a timestamp is
			// not strictly before "now", which would panic below.
			if last := con.last(); oldest == nil || last.Before(least) {
				least = last
				oldest = con
			}
		}

		// The nil check guards against a negative maxOpen with no
		// candidates at all.
		if count <= maxOpen || oldest == nil {
			break
		}
		oldest.close()
	}
}

// Do runs fn with the database handle that belongs to token.
//
// The connection is looked up and touched inside the pool goroutine.
// If it has no open database handle yet, one is opened with the
// credentials of its session. The reference count of the connection is
// raised for the duration of fn and released again (via trim) when Do
// returns.
//
// ErrNoSuchToken is returned for an unknown token, the error of opendb
// if opening the database fails, otherwise whatever fn returns.
func (cp *InMemoryConnectionPool) Do(token string, fn func(*sql.DB) error) error {

	type outcome struct {
		con *Connection
		err error
	}

	replies := make(chan outcome)

	cp.cmds <- func(pool *InMemoryConnectionPool) {
		c := pool.conns[token]
		if c == nil {
			replies <- outcome{err: ErrNoSuchToken}
			return
		}
		c.touch()

		// Open the database lazily on first use.
		if c.db == nil {
			db, err := opendb(c.session.User, c.session.Password)
			if err != nil {
				replies <- outcome{err: err}
				return
			}
			c.db = db
		}

		c.refCount++
		replies <- outcome{con: c}
	}

	got := <-replies
	if got.err != nil {
		return got.err
	}

	// Give the reference back once fn is done.
	release := func() {
		cp.cmds <- func(pool *InMemoryConnectionPool) {
			pool.trim(got.con)
		}
	}
	defer release()

	return fn(got.con.db)
}

// Session returns the session stored for token, or nil if the token is
// unknown. A successful lookup also touches the connection.
func (cp *InMemoryConnectionPool) Session(token string) *Session {
	reply := make(chan *Session)

	cp.cmds <- func(pool *InMemoryConnectionPool) {
		c := pool.conns[token]
		if c == nil {
			reply <- nil
			return
		}
		c.touch()
		reply <- c.session
	}

	return <-reply
}

// Shutdown logs that the in-memory pool is shutting down. It currently
// performs no further cleanup and always returns nil.
func (cp *InMemoryConnectionPool) Shutdown() error {
	const msg = "info: shutdown in-memory connection pool."
	log.Println(msg)
	return nil
}