Mercurial > gemma
diff pkg/models/common.go @ 1901:71b722809b2b
Stretch import: Added stub.
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Fri, 18 Jan 2019 15:52:51 +0100 |
parents | 6a67cd819e93 |
children | 0bc0312105e4 |
line wrap: on
line diff
--- a/pkg/models/common.go Fri Jan 18 15:04:53 2019 +0100 +++ b/pkg/models/common.go Fri Jan 18 15:52:51 2019 +0100 @@ -14,8 +14,11 @@ package models import ( + "database/sql/driver" "encoding/json" "errors" + "fmt" + "strings" "time" ) @@ -29,7 +32,13 @@ const DateFormat = "2006-01-02" -type Date struct{ time.Time } +type ( + Date struct{ time.Time } + // Country is a valid country 2 letter code. + Country string + // UniqueCountries is a list of unique countries. + UniqueCountries []Country +) func (srd Date) MarshalJSON() ([]byte, error) { return json.Marshal(srd.Format(DateFormat)) @@ -46,3 +55,59 @@ } return err } + +var ( + validCountries = []string{ + "AT", "BG", "DE", "HU", "HR", + "MD", "RO", "RS", "SK", "UA", + } + errNoValidCountry = errors.New("Not a valid country") +) + +// UnmarshalJSON ensures that the given string forms a valid +// two letter country code. +func (c *Country) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + s = strings.ToUpper(s) + for _, v := range validCountries { + if v == s { + *c = Country(v) + return nil + } + } + return errNoValidCountry +} + +// Value implements the driver.Valuer interface. +func (c Country) Value() (driver.Value, error) { + return string(c), nil +} + +// Scan implements the sql.Scanner interfaces. +func (c *Country) Scan(src interface{}) (err error) { + if s, ok := src.(string); ok { + *c = Country(s) + } else { + err = errNoString + } + return +} + +func (uc *UniqueCountries) UnmarshalJSON(data []byte) error { + var countries []Country + if err := json.Unmarshal(data, &countries); err != nil { + return err + } + unique := map[Country]struct{}{} + for _, c := range countries { + if _, found := unique[c]; found { + return fmt.Errorf("country '%s' is not unique", string(c)) + } + unique[c] = struct{}{} + } + *uc = countries + return nil +}