changeset 189:96bb671cdd98

Some input checking for email, roles and valid countries.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 20 Jul 2018 11:09:36 +0200
parents ee3093966a6d
children 3457a60fb12d 01c5794608e0
files cmd/tokenserver/user.go
diffstat 1 files changed, 105 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/cmd/tokenserver/user.go	Fri Jul 20 10:27:42 2018 +0200
+++ b/cmd/tokenserver/user.go	Fri Jul 20 11:09:36 2018 +0200
@@ -3,29 +3,37 @@
 import (
 	"database/sql"
 	"encoding/json"
-	"fmt"
+	"errors"
 	"log"
 	"net/http"
+	"regexp"
+	"strings"
 
 	"gemma.intevation.de/gemma/auth"
 	"github.com/jackc/pgx"
 )
 
-type BoundingBox struct {
-	X1 float64 `json:"x1"`
-	Y1 float64 `json:"y1"`
-	X2 float64 `json:"x2"`
-	Y2 float64 `json:"y2"`
-}
+type (
+	Email   string
+	Country string
+	Role    string
 
-type User struct {
-	User     string       `json:"user"`
-	Role     string       `json:"role"`
-	Password string       `json:"password"`
-	Email    string       `json:"email"`
-	Country  string       `json:"country"`
-	Extent   *BoundingBox `json:"extent"`
-}
+	BoundingBox struct {
+		X1 float64 `json:"x1"`
+		Y1 float64 `json:"y1"`
+		X2 float64 `json:"x2"`
+		Y2 float64 `json:"y2"`
+	}
+
+	User struct {
+		User     string       `json:"user"`
+		Role     Role         `json:"role"`
+		Password string       `json:"password"`
+		Email    Email        `json:"email"`
+		Country  Country      `json:"country"`
+		Extent   *BoundingBox `json:"extent"`
+	}
+)
 
 const (
 	createUserSQL       = `SELECT create_user($1, $2, $3, $4, NULL, $5)`
@@ -33,14 +41,87 @@
   ST_MakeBox2D(ST_Point($5, $6), ST_Point($7, $8)), $9)`
 )
 
+var (
+	// https://stackoverflow.com/questions/201323/how-to-validate-an-email-address-using-a-regular-expression
+	emailRe = regexp.MustCompile(
+		`(?:[a-z0-9!#$%&'*+/=?^_` + "`" +
+			`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_` + "`" +
+			`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]` +
+			`|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")` +
+			`@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?` +
+			`|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}` +
+			`(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]` +
+			`:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]` +
+			`|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])`)
+	errNoEmailAddress = errors.New("Not a valid email address")
+)
+
+func (e *Email) UnmarshalJSON(data []byte) error {
+	var s string
+	if err := json.Unmarshal(data, &s); err != nil {
+		return err
+	}
+	if !emailRe.MatchString(s) {
+		return errNoEmailAddress
+	}
+	*e = Email(s)
+	return nil
+}
+
+var (
+	validCountries = []string{
+		"AT", "BG", "DE", "HU", "HR",
+		"MD", "RO", "RS", "SK", "UA",
+	}
+	errNoValidCountry = errors.New("Not a valid country")
+)
+
+func (c *Country) UnmarshalJSON(data []byte) error {
+	var s string
+	if err := json.Unmarshal(data, &s); err != nil {
+		return err
+	}
+	s = strings.ToUpper(s)
+	for _, v := range validCountries {
+		if v == s {
+			*c = Country(v)
+			return nil
+		}
+	}
+	return errNoValidCountry
+}
+
+var (
+	validRoles = []string{
+		"waterway_user",
+		"waterway_admin",
+		"sys_admin",
+	}
+	errNoValidRole = errors.New("Not a valid role")
+)
+
+func (r *Role) UnmarshalJSON(data []byte) error {
+	var s string
+	if err := json.Unmarshal(data, &s); err != nil {
+		return err
+	}
+	s = strings.ToLower(s)
+	for _, v := range validRoles {
+		if v == s {
+			*r = Role(v)
+			return nil
+		}
+	}
+	return errNoValidRole
+}
+
 func createUser(rw http.ResponseWriter, req *http.Request) {
 
 	var user User
 
 	defer req.Body.Close()
 	if err := json.NewDecoder(req.Body).Decode(&user); err != nil {
-		log.Printf("err: %v\n", err)
-		http.Error(rw, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
+		http.Error(rw, "error: "+err.Error(), http.StatusBadRequest)
 		return
 	}
 
@@ -49,22 +130,22 @@
 		if user.Extent == nil {
 			_, err = db.Exec(
 				createUserSQL,
-				user.Role,
+				string(user.Role),
 				user.User,
 				user.Password,
-				user.Country,
-				user.Email,
+				string(user.Country),
+				string(user.Email),
 			)
 		} else {
 			_, err = db.Exec(
 				createUserSQL,
-				user.Role,
+				string(user.Role),
 				user.User,
 				user.Password,
-				user.Country,
+				string(user.Country),
 				user.Extent.X1, user.Extent.Y1,
 				user.Extent.X2, user.Extent.Y2,
-				user.Email,
+				string(user.Email),
 			)
 		}
 		return
@@ -84,7 +165,7 @@
 		} else {
 			log.Printf("err: %v\n", err)
 			http.Error(rw,
-				fmt.Sprintf("error: %s", err.Error()),
+				"error: "+err.Error(),
 				http.StatusInternalServerError)
 			return
 		}