view pkg/auth/pool.go @ 493:8a0737aa6ab6 metamorph-for-all

The connection pool is now only a session store.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 24 Aug 2018 14:25:05 +0200
parents b2dc9c2f69e0
children
line wrap: on
line source

package auth

import (
	"bytes"
	"log"
	"time"

	bolt "github.com/coreos/bbolt"
)

// SessionStore keeps the active sessions in memory, keyed by their token.
// The exported operations hand their work as closures to a single goroutine
// (see run) through cmds; when storage is non-nil the changes are also
// written to a bolt database.
type SessionStore struct {
	// storage is the optional persistent backing store; nil means the
	// sessions are kept in memory only.
	storage *bolt.DB

	// conns maps a session token to its connection.
	conns map[string]*Connection

	// cmds carries the closures that the run goroutine executes one by one.
	cmds chan func(*SessionStore)
}

var (
	// Sessions is the global connection pool shared by the package.
	// It stays nil until it is assigned (presumably the result of
	// NewSessionStore — verify at the call site).
	Sessions *SessionStore

	// sessionsBucket is the name of the bolt bucket holding the serialized
	// sessions.
	sessionsBucket = []byte("sessions")
)

// NewSessionStore creates a session store and starts the goroutine that
// executes its commands.
//
// A non-empty filename is opened as a bolt database (file mode 0600) and the
// sessions already stored there are loaded into memory. An empty filename
// gives a purely in-memory store. If opening or loading fails, the error is
// returned and no goroutine is started.
func NewSessionStore(filename string) (*SessionStore, error) {
	store := &SessionStore{
		conns: make(map[string]*Connection),
		cmds:  make(chan func(*SessionStore)),
	}

	err := store.openStorage(filename)
	if err != nil {
		return nil, err
	}

	go store.run()
	return store, nil
}

// openStorage attaches a bolt database to the store and pre-loads the
// sessions persisted in it.
//
// An empty filename is not an error: nothing is opened and pcp.storage stays
// nil, so the store works in memory only. Otherwise the "sessions" bucket is
// created if it does not exist yet, and every entry in it is deserialized
// into pcp.conns. On any failure the database is closed again and the error
// is returned.
func (pcp *SessionStore) openStorage(filename string) error {
	if filename == "" {
		// No file, nothing to restore/persist.
		return nil
	}

	db, err := bolt.Open(filename, 0600, nil)
	if err != nil {
		return err
	}

	preload := func(tx *bolt.Tx) error {
		bucket, err := tx.CreateBucketIfNotExists(sessionsBucket)
		if err != nil {
			return err
		}
		cur := bucket.Cursor()
		for key, val := cur.First(); key != nil; key, val = cur.Next() {
			con := new(Connection)
			if err := con.deserialize(bytes.NewReader(val)); err != nil {
				return err
			}
			pcp.conns[string(key)] = con
		}
		return nil
	}

	if err := db.Update(preload); err != nil {
		db.Close()
		return err
	}

	pcp.storage = db
	return nil
}

// run is the store's event loop. It executes the closures received on
// pcp.cmds one at a time, which serializes their access to pcp.conns and
// pcp.storage. It also purges expired sessions every five minutes.
//
// The cleanup is driven by a single ticker. Creating a new time.After in each
// loop iteration would restart the five-minute timer on every command, so
// under steady traffic (at least one request per five minutes) expired
// sessions would never be removed.
//
// NOTE(review): the loop has no exit path, so this goroutine lives for the
// rest of the process, even after Shutdown.
func (pcp *SessionStore) run() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case cmd := <-pcp.cmds:
			cmd(pcp)
		case <-ticker.C:
			pcp.cleanToken()
		}
	}
}

// cleanToken drops every session whose ExpiresAt (Unix seconds) lies before
// the current time, from memory and from the persistent store. It is invoked
// from run and modifies pcp.conns without any locking.
func (pcp *SessionStore) cleanToken() {
	now := time.Now()
	for tok, c := range pcp.conns {
		if !time.Unix(c.session.ExpiresAt, 0).Before(now) {
			continue
		}
		// Deleting the current key while ranging over a map is allowed in Go.
		delete(pcp.conns, tok)
		pcp.remove(tok)
	}
}

// remove deletes the persisted entry for token. For an in-memory store
// (storage == nil) it does nothing. Storage errors are logged, not returned.
func (pcp *SessionStore) remove(token string) {
	db := pcp.storage
	if db == nil {
		return
	}

	key := []byte(token)
	if err := db.Update(func(tx *bolt.Tx) error {
		return tx.Bucket(sessionsBucket).Delete(key)
	}); err != nil {
		log.Printf("error: %v\n", err)
	}
}

// Delete removes the session stored under token from memory and from the
// persistent store. It reports whether such a session existed.
func (pcp *SessionStore) Delete(token string) bool {
	found := make(chan bool)
	pcp.cmds <- func(s *SessionStore) {
		_, ok := s.conns[token]
		if ok {
			delete(s.conns, token)
			s.remove(token)
		}
		found <- ok
	}
	return <-found
}

// store serializes con and writes it under token into the persistent store.
// For an in-memory store (storage == nil) it does nothing. Serialization and
// storage errors are logged, not returned.
func (pcp *SessionStore) store(token string, con *Connection) {
	db := pcp.storage
	if db == nil {
		return
	}

	put := func(tx *bolt.Tx) error {
		var data bytes.Buffer
		if err := con.serialize(&data); err != nil {
			return err
		}
		return tx.Bucket(sessionsBucket).Put([]byte(token), data.Bytes())
	}

	if err := db.Update(put); err != nil {
		log.Printf("error: %v\n", err)
	}
}

// Add registers session under token and returns the connection holding it.
// An existing connection for token is reused. Otherwise a new one is
// created. In both cases set is called with session and the connection is
// persisted.
func (pcp *SessionStore) Add(token string, session *Session) *Connection {
	out := make(chan *Connection)

	pcp.cmds <- func(s *SessionStore) {
		con := s.conns[token]
		if con == nil {
			con = new(Connection)
			s.conns[token] = con
		}
		con.set(session)
		s.store(token, con)
		out <- con
	}

	return <-out
}

// Renew moves the session stored under token to a freshly generated token
// and sets its expiry to maxTokenValid from now. Both the removal of the old
// entry and the new entry are persisted. It returns the new token, or
// ErrNoSuchToken if token is unknown.
func (pcp *SessionStore) Renew(token string) (string, error) {
	type renewal struct {
		token string
		err   error
	}
	done := make(chan renewal)

	pcp.cmds <- func(s *SessionStore) {
		con := s.conns[token]
		if con == nil {
			done <- renewal{err: ErrNoSuchToken}
			return
		}
		delete(s.conns, token)
		s.remove(token)

		fresh := GenerateSessionKey()
		// TODO: Ensure that this is not racy! ExpiresAt is written here,
		// while callers of Do/Session, which receive the same *Session
		// pointer, may be reading it concurrently.
		con.session.ExpiresAt = time.Now().Add(maxTokenValid).Unix()
		s.conns[fresh] = con
		s.store(fresh, con)
		done <- renewal{token: fresh}
	}

	r := <-done
	return r.token, r.err
}

// Do looks up the session for token, refreshes it (touch), persists the
// refreshed connection and returns the session. For an unknown token it
// returns a nil session and ErrNoSuchToken.
func (pcp *SessionStore) Do(token string) (*Session, error) {
	type reply struct {
		session *Session
		err     error
	}
	replies := make(chan reply)

	pcp.cmds <- func(s *SessionStore) {
		con := s.conns[token]
		if con == nil {
			replies <- reply{err: ErrNoSuchToken}
			return
		}
		con.touch()
		s.store(token, con)
		replies <- reply{session: con.session}
	}

	// On error the reply carries a nil session, so both fields can be
	// returned as they are.
	r := <-replies
	return r.session, r.err
}

// Session returns the session for token after refreshing it (touch) and
// persisting the connection. It returns nil if token is unknown.
func (pcp *SessionStore) Session(token string) *Session {
	got := make(chan *Session)
	pcp.cmds <- func(s *SessionStore) {
		con := s.conns[token]
		if con == nil {
			got <- nil
			return
		}
		con.touch()
		s.store(token, con)
		got <- con.session
	}
	return <-got
}

// Logout drops every session belonging to user, from memory and from the
// persistent store. The work is handed to the store's goroutine. Logout
// returns as soon as that goroutine has accepted it, without waiting for the
// sessions to be removed.
func (pcp *SessionStore) Logout(user string) {
	pcp.cmds <- func(s *SessionStore) {
		for tok, con := range s.conns {
			if con.session.User != user {
				continue
			}
			delete(s.conns, tok)
			s.remove(tok)
		}
	}
}

// Shutdown closes the persistent backing store, if there is one, and
// returns the error from closing it. Afterwards the store keeps serving
// sessions from memory only.
//
// The shutdown runs as a command on the store's goroutine (see run), like
// every other operation. Previously pcp.storage was read, cleared and closed
// directly on the caller's goroutine. That raced with store/remove, which
// read pcp.storage on the run goroutine, and could let a command call Update
// on an already-closed database.
//
// NOTE(review): this requires the run goroutine to be alive, i.e. the store
// must have been created by NewSessionStore.
func (pcp *SessionStore) Shutdown() error {
	res := make(chan error)
	pcp.cmds <- func(pcp *SessionStore) {
		db := pcp.storage
		if db == nil {
			log.Println("info: shutdown in-memory connection pool.")
			res <- nil
			return
		}
		log.Println("info: shutdown persistent connection pool.")
		pcp.storage = nil
		res <- db.Close()
	}
	return <-res
}