diff pkg/models/common.go @ 1901:71b722809b2b

Stretch import: Added stub.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 18 Jan 2019 15:52:51 +0100
parents 6a67cd819e93
children 0bc0312105e4
line wrap: on
line diff
--- a/pkg/models/common.go	Fri Jan 18 15:04:53 2019 +0100
+++ b/pkg/models/common.go	Fri Jan 18 15:52:51 2019 +0100
@@ -14,8 +14,11 @@
 package models
 
 import (
+	"database/sql/driver"
 	"encoding/json"
 	"errors"
+	"fmt"
+	"strings"
 	"time"
 )
 
@@ -29,7 +32,13 @@
 
 const DateFormat = "2006-01-02"
 
-type Date struct{ time.Time }
+type (
+	Date struct{ time.Time }
+	// Country is a valid country 2 letter code.
+	Country string
+	// UniqueCountries is a list of unique countries.
+	UniqueCountries []Country
+)
 
 func (srd Date) MarshalJSON() ([]byte, error) {
 	return json.Marshal(srd.Format(DateFormat))
@@ -46,3 +55,59 @@
 	}
 	return err
 }
+
+var (
+	validCountries = []string{
+		"AT", "BG", "DE", "HU", "HR",
+		"MD", "RO", "RS", "SK", "UA",
+	}
+	errNoValidCountry = errors.New("Not a valid country")
+)
+
+// UnmarshalJSON ensures that the given string forms a valid
+// two letter country code.
+func (c *Country) UnmarshalJSON(data []byte) error {
+	var s string
+	if err := json.Unmarshal(data, &s); err != nil {
+		return err
+	}
+	s = strings.ToUpper(s)
+	for _, v := range validCountries {
+		if v == s {
+			*c = Country(v)
+			return nil
+		}
+	}
+	return errNoValidCountry
+}
+
+// Value implements the driver.Valuer interface.
+func (c Country) Value() (driver.Value, error) {
+	return string(c), nil
+}
+
+// Scan implements the sql.Scanner interfaces.
+func (c *Country) Scan(src interface{}) (err error) {
+	if s, ok := src.(string); ok {
+		*c = Country(s)
+	} else {
+		err = errNoString
+	}
+	return
+}
+
+func (uc *UniqueCountries) UnmarshalJSON(data []byte) error {
+	var countries []Country
+	if err := json.Unmarshal(data, &countries); err != nil {
+		return err
+	}
+	unique := map[Country]struct{}{}
+	for _, c := range countries {
+		if _, found := unique[c]; found {
+			return fmt.Errorf("country '%s' is not unique", string(c))
+		}
+		unique[c] = struct{}{}
+	}
+	*uc = countries
+	return nil
+}