view auth/connection.go @ 126:89cf2e7672ff

Implemented an explicit token deletion under endpoint /api/logout.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Thu, 28 Jun 2018 13:09:38 +0200
parents c3e2cd7fa46f
children 44794c641277
line wrap: on
line source

package auth

import (
	"database/sql"
	"errors"
	"log"
	"time"
)

var (
	// ErrNoSuchToken is returned by ConnectionPool.Do when the given token
	// is not registered in the pool.
	//
	// NOTE(review): Go convention is a lowercase error string; the text is
	// kept as-is because callers may match on or display it.
	ErrNoSuchToken = errors.New("No such token")

	// ConnPool is the package-level pool, created at package initialization.
	ConnPool = NewConnectionPool()
)

// Limits applied by the pool's housekeeping.
const (
	// maxOpen is the number of simultaneously open database handles above
	// which trim closes the least recently used one.
	maxOpen = 16

	// maxDBIdle is how long a connection may go unused before cleanDB
	// closes its database handle (the token stays registered).
	maxDBIdle = 5 * time.Minute

	// maxTokenIdle is how long a connection may go unused before
	// cleanToken removes its token from the pool entirely.
	maxTokenIdle = 20 * time.Minute
)

// Connection is the per-token state kept by a ConnectionPool: the
// credentials registered for the token and a database handle that is only
// opened on demand.
type Connection struct {
	// user and password are handed to opendb whenever the handle has to be
	// (re)opened.
	user, password string

	// access is the time of the last registration or use; the cleanup
	// passes compare it against the idle limits.
	access time.Time

	// db is nil until the first Do for the token, and again after a
	// cleanup, trim or Delete closed it.
	db *sql.DB
}

// set stores new credentials on c and stamps it as accessed now. Any
// database handle already held in c.db is left untouched.
//
// NOTE(review): if the new credentials differ from the ones the open
// handle was created with, the handle keeps using the old ones until it is
// closed — confirm that is intended.
func (c *Connection) set(user, password string) {
	c.user, c.password, c.access = user, password, time.Now()
}

// ConnectionPool maps tokens to their Connection.
//
// The map is owned by the goroutine started in NewConnectionPool. Within
// this file, other goroutines never touch it directly; instead they send
// closures over cmds, which that goroutine executes one at a time.
type ConnectionPool struct {
	// conns maps a token to its connection state.
	conns map[string]*Connection

	// cmds carries work to be run on the pool's owner goroutine.
	cmds chan func(*ConnectionPool)
}

// NewConnectionPool returns an empty pool and starts the goroutine (run)
// that owns it. That goroutine has no stop mechanism, so it lives for the
// rest of the process.
func NewConnectionPool() *ConnectionPool {
	pool := new(ConnectionPool)
	pool.conns = map[string]*Connection{}
	// Unbuffered: a sender blocks until the owner goroutine picks the
	// command up.
	pool.cmds = make(chan func(*ConnectionPool))
	go pool.run()
	return pool
}

// run is the body of the pool's owner goroutine and never returns. It
// executes queued commands one at a time, which serializes every access
// to cp.conns, and periodically expires idle state:
//   - every minute, cleanDB closes database handles idle past maxDBIdle;
//   - every five minutes, cleanToken drops tokens idle past maxTokenIdle.
//
// Long-lived tickers are used on purpose. The previous version called
// time.After inside the select, which creates fresh timers on every loop
// iteration. Each incoming command therefore restarted both countdowns, and
// the five-minute timer could never fire before the one-minute timer
// created alongside it, so cleanToken never ran.
func (cp *ConnectionPool) run() {
	// The tickers are never stopped: this goroutine runs for the lifetime
	// of the pool.
	dbTicker := time.NewTicker(time.Minute)
	tokenTicker := time.NewTicker(time.Minute * 5)

	for {
		select {
		case cmd := <-cp.cmds:
			cmd(cp)
		case <-dbTicker.C:
			cp.cleanDB()
		case <-tokenTicker.C:
			cp.cleanToken()
		}
	}
}

// cleanDB closes the database handle of every connection that has not been
// accessed for longer than maxDBIdle. The token itself stays registered, so
// a later Do reopens the database with the stored credentials.
//
// It is called from run on the pool's owner goroutine and does no locking
// of its own.
func (cp *ConnectionPool) cleanDB() {
	valid := time.Now().Add(-maxDBIdle)
	for _, con := range cp.conns {
		// A nil db means the handle was never opened (Add without Do) or was
		// already closed by an earlier pass. Calling Close on a nil *sql.DB
		// panics and would take down the pool goroutine.
		if con.db == nil || !con.access.Before(valid) {
			continue
		}
		if err := con.db.Close(); err != nil {
			log.Printf("warn: %v\n", err)
		}
		con.db = nil
	}
}

// cleanToken unregisters every token whose connection has not been accessed
// for longer than maxTokenIdle, closing its database handle first if one is
// open. A failure to close is logged, matching Delete.
//
// It is called from run on the pool's owner goroutine and does no locking
// of its own.
func (cp *ConnectionPool) cleanToken() {
	valid := time.Now().Add(-maxTokenIdle)
	for token, con := range cp.conns {
		if !con.access.Before(valid) {
			continue
		}
		if con.db != nil {
			// TODO: Be more graceful here?
			if err := con.db.Close(); err != nil {
				log.Printf("warn: %v\n", err)
			}
			con.db = nil
		}
		// Deleting the current key while ranging over a map is allowed.
		delete(cp.conns, token)
	}
}

// Delete unregisters token and closes its database handle if one is open.
// It reports whether the token was registered. A failure to close the
// handle is logged rather than returned.
//
// The work runs on the pool's owner goroutine; Delete blocks until it is
// done.
func (cp *ConnectionPool) Delete(token string) bool {
	done := make(chan bool)
	cp.cmds <- func(pool *ConnectionPool) {
		con, found := pool.conns[token]
		if found {
			delete(pool.conns, token)
			if db := con.db; db != nil {
				if err := db.Close(); err != nil {
					log.Printf("warn: %v\n", err)
				}
				con.db = nil
			}
		}
		done <- found
	}
	return <-done
}

// Add registers user and password under token and returns the token's
// Connection. If the token is already registered, its existing Connection
// is reused: the credentials and access time are overwritten and any open
// database handle is kept.
//
// NOTE(review): the returned *Connection is also mutated later on the pool
// goroutine (Do, cleanDB, trim); callers should not read its fields
// without going through the pool.
func (cp *ConnectionPool) Add(token, user, password string) *Connection {
	out := make(chan *Connection)
	cp.cmds <- func(pool *ConnectionPool) {
		entry := pool.conns[token]
		if entry == nil {
			entry = new(Connection)
			pool.conns[token] = entry
		}
		entry.set(user, password)
		out <- entry
	}
	return <-out
}

// trim enforces the maxOpen limit. When more than maxOpen connections hold
// an open database handle, it closes the handle of the one with the oldest
// access time; at most one handle is closed per call. Do queues trim on the
// pool's owner goroutine after every use. A failure to close is logged.
func trim(cp *ConnectionPool) {
	var (
		count  int
		oldest *Connection
	)
	for _, con := range cp.conns {
		if con.db == nil {
			continue
		}
		count++
		// Seed from the first open candidate rather than time.Now(). This
		// guarantees oldest is non-nil whenever count > 0, even if access
		// times are not strictly in the past.
		if oldest == nil || con.access.Before(oldest.access) {
			oldest = con
		}
	}
	if count <= maxOpen {
		return
	}
	if err := oldest.db.Close(); err != nil {
		log.Printf("warn: %v\n", err)
	}
	oldest.db = nil
}

// Do runs fn with the database handle belonging to token. If no handle is
// open yet, it first opens one with the token's stored credentials.
// Calling Do refreshes the token's access time.
//
// It returns ErrNoSuchToken if token is not registered, the opendb error
// if the handle cannot be opened, and otherwise whatever fn returns. After
// fn returns, trim is queued on the pool's owner goroutine.
//
// NOTE(review): fn runs outside the pool goroutine, so cleanDB, trim or
// Delete may close the handle while fn is still using it.
func (cp *ConnectionPool) Do(token string, fn func(*sql.DB) error) error {
	type result struct {
		db  *sql.DB
		err error
	}

	res := make(chan result)

	cp.cmds <- func(cp *ConnectionPool) {
		con := cp.conns[token]
		if con == nil {
			res <- result{err: ErrNoSuchToken}
			return
		}
		con.access = time.Now()
		if con.db == nil {
			db, err := opendb(con.user, con.password)
			if err != nil {
				res <- result{err: err}
				return
			}
			con.db = db
		}
		// Always reply; the caller is blocked on res. Hand out the handle
		// itself rather than the *Connection, so the caller never reads
		// con.db outside the pool goroutine while cleanDB, trim or Delete
		// may be resetting it.
		res <- result{db: con.db}
	}

	r := <-res
	if r.err != nil {
		return r.err
	}

	defer func() { cp.cmds <- trim }()

	return fn(r.db)
}