Mercurial > gemma
changeset 189:96bb671cdd98
Some input checking for email, roles and valid countries.
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Fri, 20 Jul 2018 11:09:36 +0200 |
parents | ee3093966a6d |
children | 3457a60fb12d 01c5794608e0 |
files | cmd/tokenserver/user.go |
diffstat | 1 files changed, 105 insertions(+), 24 deletions(-) [+] |
line wrap: on
line diff
--- a/cmd/tokenserver/user.go Fri Jul 20 10:27:42 2018 +0200 +++ b/cmd/tokenserver/user.go Fri Jul 20 11:09:36 2018 +0200 @@ -3,29 +3,37 @@ import ( "database/sql" "encoding/json" - "fmt" + "errors" "log" "net/http" + "regexp" + "strings" "gemma.intevation.de/gemma/auth" "github.com/jackc/pgx" ) -type BoundingBox struct { - X1 float64 `json:"x1"` - Y1 float64 `json:"y1"` - X2 float64 `json:"x2"` - Y2 float64 `json:"y2"` -} +type ( + Email string + Country string + Role string -type User struct { - User string `json:"user"` - Role string `json:"role"` - Password string `json:"password"` - Email string `json:"email"` - Country string `json:"country"` - Extent *BoundingBox `json:"extent"` -} + BoundingBox struct { + X1 float64 `json:"x1"` + Y1 float64 `json:"y1"` + X2 float64 `json:"x2"` + Y2 float64 `json:"y2"` + } + + User struct { + User string `json:"user"` + Role Role `json:"role"` + Password string `json:"password"` + Email Email `json:"email"` + Country Country `json:"country"` + Extent *BoundingBox `json:"extent"` + } +) const ( createUserSQL = `SELECT create_user($1, $2, $3, $4, NULL, $5)` @@ -33,14 +41,87 @@ ST_MakeBox2D(ST_Point($5, $6), ST_Point($7, $8)), $9)` ) +var ( + // https://stackoverflow.com/questions/201323/how-to-validate-an-email-address-using-a-regular-expression + emailRe = regexp.MustCompile( + `(?:[a-z0-9!#$%&'*+/=?^_` + "`" + + `{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_` + "`" + + `{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]` + + `|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")` + + `@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?` + + `|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}` + + `(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]` + + `:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]` + + `|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])`) + errNoEmailAddress = errors.New("Not a valid email address") +) + +func (e *Email) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + if !emailRe.MatchString(s) { + return errNoEmailAddress + } + *e = Email(s) + return nil +} + +var ( + validCountries = []string{ + "AT", "BG", "DE", "HU", "HR", + "MD", "RO", "RS", "SK", "UA", + } + errNoValidCountry = errors.New("Not a valid country") +) + +func (c *Country) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + s = strings.ToUpper(s) + for _, v := range validCountries { + if v == s { + *c = Country(v) + return nil + } + } + return errNoValidCountry +} + +var ( + validRoles = []string{ + "waterway_user", + "waterway_admin", + "sys_admin", + } + errNoValidRole = errors.New("Not a valid role") +) + +func (r *Role) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + s = strings.ToLower(s) + for _, v := range validRoles { + if v == s { + *r = Role(v) + return nil + } + } + return errNoValidRole +} + func createUser(rw http.ResponseWriter, req *http.Request) { var user User defer req.Body.Close() if err := json.NewDecoder(req.Body).Decode(&user); err != nil { - log.Printf("err: %v\n", err) - http.Error(rw, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + http.Error(rw, "error: "+err.Error(), http.StatusBadRequest) return } @@ -49,22 +130,22 @@ if user.Extent == nil { _, err = db.Exec( createUserSQL, - user.Role, + string(user.Role), user.User, user.Password, - user.Country, - user.Email, + string(user.Country), + string(user.Email), ) } else { _, err = db.Exec( createUserSQL, - user.Role, + string(user.Role), user.User, user.Password, - user.Country, + string(user.Country), user.Extent.X1, user.Extent.Y1, user.Extent.X2, user.Extent.Y2, - user.Email, + string(user.Email), ) } return @@ -84,7 +165,7 @@ } else { log.Printf("err: %v\n", err) http.Error(rw, - fmt.Sprintf("error: %s", err.Error()), + "error: "+err.Error(), http.StatusInternalServerError) return }