view cmd/tokenserver/main.go @ 151:3349bfc2a047

Shutdown server gracefully.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Mon, 02 Jul 2018 13:23:31 +0200
parents 0c56c56a1c44
children fe3a88f00b0a
line wrap: on
line source

package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"syscall"

	"gemma.intevation.de/gemma/auth"
)

// sysAdmin answers with a one-line plain-text message stating that the
// user of the current session is a sys_admin.
//
// main registers it behind auth.SessionMiddleware and
// auth.SessionChecker(..., auth.HasRole("sys_admin")), so the session
// looked up here is expected to exist; the lookup's second result is
// intentionally ignored.
func sysAdmin(rw http.ResponseWriter, req *http.Request) {
	s, _ := auth.GetSession(req)
	h := rw.Header()
	h.Set("Content-Type", "text/plain")
	fmt.Fprintf(rw, "%s is a sys_admin\n", s.User)
}

// renew exchanges the token attached to the request for a fresh one.
//
// Responses:
//   - 404 Not Found if auth.ConnPool.Renew reports auth.ErrNoSuchToken.
//   - 500 Internal Server Error for any other renewal error.
//   - Otherwise a JSON document with the new token, the session expiry,
//     the user name and the user's roles.
//
// The token and session are read from the request, which main routes
// through auth.SessionMiddleware before this handler runs.
func renew(rw http.ResponseWriter, req *http.Request) {
	token, _ := auth.GetToken(req)
	newToken, err := auth.ConnPool.Renew(token)
	switch {
	case err == auth.ErrNoSuchToken:
		http.NotFound(rw, req)
		return
	case err != nil:
		http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
		return
	}

	// NOTE(review): the session comes from the request, so ExpiresAt may
	// describe the pre-renewal session unless Renew updates it in place —
	// confirm against the auth package.
	session, _ := auth.GetSession(req)

	var result = struct {
		Token   string   `json:"token"`
		Expires int64    `json:"expires"`
		User    string   `json:"user"`
		Roles   []string `json:"roles"`
	}{
		Token:   newToken,
		Expires: session.ExpiresAt,
		User:    session.User,
		Roles:   session.Roles,
	}

	// The body is JSON, so advertise it as such (same as the token handler).
	rw.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(rw).Encode(&result); err != nil {
		log.Printf("error: %v\n", err)
	}
}

// logout removes the request's token from auth.ConnPool.
//
// If the pool reports that nothing was deleted, the client gets a
// 404 Not Found. Otherwise the reply is the plain-text line
// "token deleted".
func logout(rw http.ResponseWriter, req *http.Request) {
	tok, _ := auth.GetToken(req)
	if auth.ConnPool.Delete(tok) {
		rw.Header().Set("Content-Type", "text/plain")
		fmt.Fprintln(rw, "token deleted")
		return
	}
	http.NotFound(rw, req)
}

// token logs a user in.
//
// It passes the "user" and "password" form values to
// auth.GenerateSession. Any error from that call is reported as
// 500 Internal Server Error. On success it replies with a JSON
// document holding the issued token, its expiry, the user name and
// the user's roles.
func token(rw http.ResponseWriter, req *http.Request) {
	type tokenResponse struct {
		Token   string   `json:"token"`
		Expires int64    `json:"expires"`
		User    string   `json:"user"`
		Roles   []string `json:"roles"`
	}

	tok, session, err := auth.GenerateSession(
		req.FormValue("user"),
		req.FormValue("password"),
	)
	if err != nil {
		http.Error(rw, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
		return
	}

	result := tokenResponse{
		Token:   tok,
		Expires: session.ExpiresAt,
		User:    session.User,
		Roles:   session.Roles,
	}

	rw.Header().Set("Content-Type", "application/json")
	enc := json.NewEncoder(rw)
	if encErr := enc.Encode(&result); encErr != nil {
		// Headers are already sent at this point; all we can do is log.
		log.Printf("error: %v\n", encErr)
	}
}

// main runs the token server.
//
// It serves static files from ./web plus the /api/token, /api/logout,
// /api/renew and /api/sys_admin endpoints on -host:-port. It blocks
// until either the listener fails or the process receives SIGINT or
// SIGTERM. After that it shuts the HTTP server down gracefully, waits
// for the serving goroutine to finish, and finally shuts down
// auth.ConnPool.
func main() {
	port := flag.Int("port", 8000, "port to listen at.")
	host := flag.String("host", "localhost", "host to listen at.")
	flag.Parse()

	p, err := filepath.Abs("./web")
	if err != nil {
		log.Fatalf("error: %v\n", err)
	}

	mux := http.NewServeMux()
	mux.Handle("/", http.StripPrefix("/", http.FileServer(http.Dir(p))))
	mux.HandleFunc("/api/token", token)
	// /api/logout must reach the logout handler, which deletes the token.
	mux.Handle("/api/logout", auth.SessionMiddleware(http.HandlerFunc(logout)))
	mux.Handle("/api/renew", auth.SessionMiddleware(http.HandlerFunc(renew)))
	mux.Handle("/api/sys_admin",
		auth.SessionMiddleware(
			auth.SessionChecker(http.HandlerFunc(sysAdmin), auth.HasRole("sys_admin"))))

	addr := fmt.Sprintf("%s:%d", *host, *port)

	server := http.Server{Addr: addr, Handler: mux}

	// done receives ListenAndServe's result exactly once and is then closed,
	// so a later receive never blocks.
	done := make(chan error)

	go func() {
		defer close(done)
		done <- server.ListenAndServe()
	}()

	// signal.Notify never blocks when delivering, so the channel needs a
	// buffer or a signal arriving early is lost. SIGKILL (os.Kill) cannot
	// be caught, so it is not registered.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

	select {
	case err := <-done:
		if err != nil && err != http.ErrServerClosed {
			log.Fatalf("error: %v\n", err)
		}
	case <-sigChan:
	}

	// Restore default signal handling so a second interrupt can still
	// terminate the process if the graceful shutdown hangs.
	signal.Stop(sigChan)

	if err := server.Shutdown(context.Background()); err != nil {
		log.Printf("error: %v\n", err)
	}

	<-done

	if err := auth.ConnPool.Shutdown(); err != nil {
		log.Fatalf("error: %v\n", err)
	}
}