view auth/inmemory.go @ 196:b67208d82543

Make test output more comprehensive. Running all tests in one transaction ensures the final output reports any failing test, not just those in the last transaction (i.e. test script). The price is that no traces of the tests are left in the database, because we have to roll back in order to leave no test roles in the cluster.
author Tom Gottfried <tom@intevation.de>
date Fri, 20 Jul 2018 18:31:45 +0200
parents 3349bfc2a047
children
line wrap: on
line source

package auth

import (
	"database/sql"
	"log"
	"sort"
	"time"
)

// InMemoryConnectionPool keeps one Connection per session token. The map is
// owned by the goroutine started in NewInMemoryConnectionPool; the pool's
// methods never touch it directly but send a closure over cmds, which that
// goroutine executes with the pool as its argument.
type InMemoryConnectionPool struct {
	// conns maps a session token to its connection.
	conns map[string]*Connection

	// cmds carries closures to be executed on the pool goroutine.
	cmds chan func(*InMemoryConnectionPool)
}

// NewInMemoryConnectionPool returns an empty pool and starts the goroutine
// (run) that executes all operations on it.
func NewInMemoryConnectionPool() *InMemoryConnectionPool {
	pool := new(InMemoryConnectionPool)
	pool.conns = map[string]*Connection{}
	// Unbuffered: each sender blocks until the pool goroutine picks up its command.
	pool.cmds = make(chan func(*InMemoryConnectionPool))
	go pool.run()
	return pool
}

// run is the pool's event loop. Every pool method hands its work to this
// goroutine via cp.cmds, so commands execute one at a time here. Besides
// commands, it calls cleanDB every minute and cleanToken every five minutes.
//
// Persistent tickers are used instead of time.After inside the select:
// time.After creates fresh timers on every loop iteration, so the one-minute
// timer would always win over the five-minute one (cleanToken would never
// run), and every incoming command would restart both countdowns (under
// steady traffic cleanDB would never run either).
func (cp *InMemoryConnectionPool) run() {
	dbTicker := time.NewTicker(time.Minute)
	defer dbTicker.Stop()
	tokenTicker := time.NewTicker(5 * time.Minute)
	defer tokenTicker.Stop()

	for {
		select {
		case cmd := <-cp.cmds:
			cmd(cp)
		case <-dbTicker.C:
			cp.cleanDB()
		case <-tokenTicker.C:
			cp.cleanToken()
		}
	}
}

// cleanDB calls close() on every connection that has no outstanding
// references and whose last() timestamp is older than maxDBIdle. The entries
// stay in the pool; only close() is invoked. Must run on the pool goroutine.
func (cp *InMemoryConnectionPool) cleanDB() {
	cutoff := time.Now().Add(-maxDBIdle)
	for _, c := range cp.conns {
		if c.refCount > 0 {
			continue // still in use by a Do call
		}
		if c.last().Before(cutoff) {
			c.close()
		}
	}
}

// cleanToken removes every connection whose session's ExpiresAt (Unix
// seconds) lies in the past: it calls close() on it and deletes its token
// from the pool, regardless of the connection's reference count.
// Must run on the pool goroutine.
func (cp *InMemoryConnectionPool) cleanToken() {
	now := time.Now()
	for tok, c := range cp.conns {
		if !time.Unix(c.session.ExpiresAt, 0).Before(now) {
			continue
		}
		// TODO: Be more graceful here?
		c.close()
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(cp.conns, tok)
	}
}

// Delete closes the connection registered under token and removes it from
// the pool. It reports whether token was present.
func (cp *InMemoryConnectionPool) Delete(token string) bool {
	found := make(chan bool)
	cp.cmds <- func(p *InMemoryConnectionPool) {
		c, ok := p.conns[token]
		if ok {
			c.close()
			delete(p.conns, token)
		}
		found <- ok
	}
	return <-found
}

// Add attaches session to the connection registered under token, creating
// and registering a new connection if none exists, and returns it.
func (cp *InMemoryConnectionPool) Add(token string, session *Session) *Connection {
	out := make(chan *Connection)
	cp.cmds <- func(p *InMemoryConnectionPool) {
		c := p.conns[token]
		if c == nil {
			c = new(Connection)
			p.conns[token] = c
		}
		c.set(session)
		out <- c
	}
	return <-out
}

// Renew re-registers the connection stored under token under a freshly
// generated key and pushes its session's ExpiresAt to maxTokenValid from
// now. It returns the new key, or ErrNoSuchToken if token is unknown.
// The old token is no longer valid afterwards.
func (cp *InMemoryConnectionPool) Renew(token string) (string, error) {
	type renewal struct {
		token string
		err   error
	}

	done := make(chan renewal)
	cp.cmds <- func(p *InMemoryConnectionPool) {
		c := p.conns[token]
		if c == nil {
			done <- renewal{err: ErrNoSuchToken}
			return
		}
		delete(p.conns, token)
		fresh := GenerateSessionKey()
		// TODO: Ensure that this is not racy!
		// NOTE(review): Session() hands c.session to callers outside the pool
		// goroutine, so this write can race with their reads.
		c.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix()
		p.conns[fresh] = c
		done <- renewal{token: fresh}
	}

	out := <-done
	return out.token, out.err
}

// trim releases the reference on conn taken by Do and then enforces the
// maxOpen limit on idle connections (open db, refCount <= 0): if there are
// more than maxOpen of them, close() is called on the excess ones with the
// oldest last() timestamps. Must run on the pool goroutine.
//
// The idle set is collected once and sorted rather than repeatedly scanning
// for the single oldest entry. A "find oldest" scan seeded with time.Now()
// only picks entries strictly before that instant, so it can end with no
// candidate while still over the limit (nil dereference on the pool
// goroutine), and it only terminates if close() clears db.
func (cp *InMemoryConnectionPool) trim(conn *Connection) {
	conn.refCount--

	var idle []*Connection
	for _, con := range cp.conns {
		if con.db != nil && con.refCount <= 0 {
			idle = append(idle, con)
		}
	}

	excess := len(idle) - maxOpen
	if excess <= 0 {
		return
	}

	// Least recently used first, so the first excess entries are the ones to drop.
	sort.Slice(idle, func(i, j int) bool {
		return idle[i].last().Before(idle[j].last())
	})
	for _, con := range idle[:excess] {
		con.close()
	}
}

// Do calls fn with the *sql.DB of the connection registered under token.
// The database is opened lazily via opendb with the session's User and
// Password on first use. The connection is touched and its refCount is
// incremented before fn runs; when Do returns (also if fn panics), trim is
// scheduled on the pool goroutine to release that reference.
//
// It returns ErrNoSuchToken for an unknown token, the opendb error if opening
// fails, and otherwise whatever fn returns.
func (cp *InMemoryConnectionPool) Do(token string, fn func(*sql.DB) error) error {
	type acquired struct {
		con *Connection
		err error
	}

	ch := make(chan acquired)
	cp.cmds <- func(p *InMemoryConnectionPool) {
		c := p.conns[token]
		if c == nil {
			ch <- acquired{err: ErrNoSuchToken}
			return
		}
		c.touch()
		if c.db == nil {
			s := c.session
			db, err := opendb(s.User, s.Password)
			if err != nil {
				ch <- acquired{err: err}
				return
			}
			c.db = db
		}
		c.refCount++
		ch <- acquired{con: c}
	}

	a := <-ch
	if a.err != nil {
		return a.err
	}

	defer func() {
		cp.cmds <- func(p *InMemoryConnectionPool) {
			p.trim(a.con)
		}
	}()

	return fn(a.con.db)
}

// Session returns the session registered under token, or nil if token is
// unknown. A successful lookup also touches the connection. The returned
// pointer is shared with the pool (Renew writes its ExpiresAt).
func (cp *InMemoryConnectionPool) Session(token string) *Session {
	out := make(chan *Session)
	cp.cmds <- func(p *InMemoryConnectionPool) {
		var s *Session
		if c := p.conns[token]; c != nil {
			c.touch()
			s = c.session
		}
		out <- s
	}
	return <-out
}

// Shutdown logs that the pool is shutting down and returns nil. It does not
// stop the pool goroutine or close any connections.
func (cp *InMemoryConnectionPool) Shutdown() error {
	const msg = "info: shutdown in-memory connection pool."
	log.Println(msg)
	return nil
}