changeset 3194:eeff2cc4ff9d

controllers: re-factored the SQL filter to a tree like structure to be of more general use.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Wed, 08 May 2019 13:11:30 +0200
parents 8329c6d3cf2a
children 88c31298eba8
files pkg/controllers/common.go pkg/controllers/gauges.go pkg/controllers/importqueue.go
diffstat 3 files changed, 145 insertions(+), 89 deletions(-) [+]
line wrap: on
line diff
--- a/pkg/controllers/common.go	Wed May 08 12:45:21 2019 +0200
+++ b/pkg/controllers/common.go	Wed May 08 13:11:30 2019 +0200
@@ -18,26 +18,61 @@
 	"strings"
 )
 
-type filterBuilder struct {
-	stmt    strings.Builder
-	args    []interface{}
-	hasCond bool
+type (
+	filterNode interface {
+		serialize(*strings.Builder, *[]interface{})
+	}
+
+	filterTerm struct {
+		format string
+		args   []interface{}
+	}
+
+	filterNot struct {
+		filterNode
+	}
+
+	filterAnd []filterNode
+	filterOr  []filterNode
+)
+
+func (ft *filterTerm) serialize(stmt *strings.Builder, args *[]interface{}) {
+	indices := make([]interface{}, len(ft.args))
+	for i := range indices {
+		indices[i] = len(*args) + i + 1
+	}
+	fmt.Fprintf(stmt, ft.format, indices...)
+	*args = append(*args, (*ft).args...)
 }
 
-func (fb *filterBuilder) arg(format string, v ...interface{}) {
-	indices := make([]interface{}, len(v))
-	for i := range indices {
-		indices[i] = len(fb.args) + i + 1
+func buildFilterTerm(format string, args ...interface{}) *filterTerm {
+	return &filterTerm{format: format, args: args}
+}
+
+func (fa filterAnd) serialize(stmt *strings.Builder, args *[]interface{}) {
+	for i, node := range fa {
+		if i > 0 {
+			stmt.WriteString(" AND ")
+		}
+		stmt.WriteByte('(')
+		node.serialize(stmt, args)
+		stmt.WriteByte(')')
 	}
-	fmt.Fprintf(&fb.stmt, format, indices...)
-	fb.args = append(fb.args, v...)
 }
 
-func (fb *filterBuilder) and(format string, v ...interface{}) {
-	if fb.hasCond {
-		fb.stmt.WriteString(" AND ")
-	} else {
-		fb.hasCond = true
+func (fo filterOr) serialize(stmt *strings.Builder, args *[]interface{}) {
+	for i, node := range fo {
+		if i > 0 {
+			stmt.WriteString(" OR ")
+		}
+		stmt.WriteByte('(')
+		node.serialize(stmt, args)
+		stmt.WriteByte(')')
 	}
-	fb.arg(format, v...)
 }
+
+func (fn *filterNot) serialize(stmt *strings.Builder, args *[]interface{}) {
+	stmt.WriteString("NOT (")
+	fn.filterNode.serialize(stmt, args)
+	stmt.WriteByte(')')
+}
--- a/pkg/controllers/gauges.go	Wed May 08 12:45:21 2019 +0200
+++ b/pkg/controllers/gauges.go	Wed May 08 13:11:30 2019 +0200
@@ -628,32 +628,29 @@
 		return
 	}
 
-	var fb filterBuilder
-	fb.stmt.WriteString(selectWaterlevelsSQL)
-
-	fb.and(
-		" fk_gauge_id = ($%d::char(2), $%d::char(3), $%d::char(5), $%d::char(5), $%d::int) ",
-		isrs.CountryCode,
-		isrs.LoCode,
-		isrs.FairwaySection,
-		isrs.Orc,
-		isrs.Hectometre,
-	)
-
-	fb.and(
-		`(NOT predicted
-         OR (
-           date_issue = (
-             SELECT max(date_issue) FROM waterway.gauge_measurements
-             WHERE fk_gauge_id = ($%d::char(2), $%d::char(3), $%d::char(5), $%d::char(5), $%d::int)
-           )
-         ))`,
-		isrs.CountryCode,
-		isrs.LoCode,
-		isrs.FairwaySection,
-		isrs.Orc,
-		isrs.Hectometre,
-	)
+	filters := filterAnd{
+		buildFilterTerm(
+			"fk_gauge_id = ($%d::char(2), $%d::char(3), $%d::char(5), $%d::char(5), $%d::int)",
+			isrs.CountryCode,
+			isrs.LoCode,
+			isrs.FairwaySection,
+			isrs.Orc,
+			isrs.Hectometre,
+		),
+		&filterOr{
+			&filterNot{&filterTerm{format: "predicted"}},
+			buildFilterTerm(
+				`date_issue = (
+                 SELECT max(date_issue) FROM waterway.gauge_measurements
+                 WHERE fk_gauge_id = ($%d::char(2), $%d::char(3), $%d::char(5), $%d::char(5), $%d::int)`,
+				isrs.CountryCode,
+				isrs.LoCode,
+				isrs.FairwaySection,
+				isrs.Orc,
+				isrs.Hectometre,
+			),
+		},
+	}
 
 	if from := req.FormValue("from"); from != "" {
 		fromTime, err := time.Parse(models.ImportTimeFormat, from)
@@ -663,7 +660,7 @@
 				http.StatusBadRequest)
 			return
 		}
-		fb.and("measure_date >= $%d", fromTime)
+		filters = append(filters, buildFilterTerm("measure_date >= $%d", fromTime))
 	}
 
 	if to := req.FormValue("to"); to != "" {
@@ -674,14 +671,20 @@
 				http.StatusBadRequest)
 			return
 		}
-		fb.and("measure_date <= $%d", toTime)
+		filters = append(filters, buildFilterTerm("measure_date <= $%d", toTime))
 	}
 
+	var stmt strings.Builder
+	var args []interface{}
+
+	stmt.WriteString(selectWaterlevelsSQL)
+	filters.serialize(&stmt, &args)
+
 	conn := middleware.GetDBConn(req)
 
 	ctx := req.Context()
 
-	rows, err := conn.QueryContext(ctx, fb.stmt.String(), fb.args...)
+	rows, err := conn.QueryContext(ctx, stmt.String(), args...)
 	if err != nil {
 		http.Error(
 			rw, fmt.Sprintf("error: %v", err),
--- a/pkg/controllers/importqueue.go	Wed May 08 12:45:21 2019 +0200
+++ b/pkg/controllers/importqueue.go	Wed May 08 13:11:30 2019 +0200
@@ -132,30 +132,22 @@
 	return &ta
 }
 
-func buildFilters(req *http.Request) (l, b, a *filterBuilder, err error) {
+type filledStmt struct {
+	stmt strings.Builder
+	args []interface{}
+}
 
-	l = new(filterBuilder)
-	a = new(filterBuilder)
-	b = new(filterBuilder)
+func buildFilters(req *http.Request) (*filledStmt, *filledStmt, *filledStmt, error) {
+
+	var l, a, b filterAnd
 
 	var noBefore, noAfter bool
 
-	var counting bool
-
-	switch count := strings.ToLower(req.FormValue("count")); count {
-	case "1", "t", "true":
-		counting = true
-		l.stmt.WriteString(selectImportsCountSQL)
-	default:
-		l.stmt.WriteString(selectImportsSQL)
-	}
-	a.stmt.WriteString(selectAfterSQL)
-	b.stmt.WriteString(selectBeforeSQL)
-
-	cond := func(format string, v ...interface{}) {
-		l.and(format, v...)
-		a.and(format, v...)
-		b.and(format, v...)
+	cond := func(format string, args ...interface{}) {
+		term := &filterTerm{format: format, args: args}
+		l = append(l, term)
+		a = append(l, term)
+		b = append(b, term)
 	}
 
 	if query := req.FormValue("query"); query != "" {
@@ -181,23 +173,23 @@
 	}
 
 	if from := req.FormValue("from"); from != "" {
-		var fromTime time.Time
-		if fromTime, err = time.Parse(models.ImportTimeFormat, from); err != nil {
-			return
+		fromTime, err := time.Parse(models.ImportTimeFormat, from)
+		if err != nil {
+			return nil, nil, nil, err
 		}
-		l.and(" enqueued >= $%d ", fromTime)
-		b.and(" enqueued < $%d", fromTime)
+		l = append(l, buildFilterTerm("enqueued >= $%d", fromTime))
+		b = append(b, buildFilterTerm("enqueued < $%d", fromTime))
 	} else {
 		noBefore = true
 	}
 
 	if to := req.FormValue("to"); to != "" {
-		var toTime time.Time
-		if toTime, err = time.Parse(models.ImportTimeFormat, to); err != nil {
-			return
+		toTime, err := time.Parse(models.ImportTimeFormat, to)
+		if err != nil {
+			return nil, nil, nil, err
 		}
-		l.and(" enqueued <= $%d ", toTime)
-		a.and(" enqueued > $%d", toTime)
+		l = append(l, buildFilterTerm("enqueued <= $%d", toTime))
+		a = append(a, buildFilterTerm("enqueued > $%d", toTime))
 	} else {
 		noAfter = true
 	}
@@ -207,32 +199,58 @@
 		cond(" id IN (SELECT id FROM warned) ")
 	}
 
-	if !l.hasCond {
-		l.stmt.WriteString(" TRUE ")
+	fl := &filledStmt{}
+	fa := &filledStmt{}
+	fb := &filledStmt{}
+
+	fa.stmt.WriteString(selectAfterSQL)
+	fb.stmt.WriteString(selectBeforeSQL)
+
+	var counting bool
+
+	switch count := strings.ToLower(req.FormValue("count")); count {
+	case "1", "t", "true":
+		counting = true
+		fl.stmt.WriteString(selectImportsCountSQL)
+	default:
+		fl.stmt.WriteString(selectImportsSQL)
 	}
-	if !b.hasCond {
-		b.stmt.WriteString(" TRUE ")
+
+	if len(l) == 0 {
+		fl.stmt.WriteString(" TRUE ")
+	} else {
+		l.serialize(&fl.stmt, &fl.args)
 	}
-	if !a.hasCond {
-		a.stmt.WriteString(" TRUE ")
+
+	if len(b) == 0 {
+		fb.stmt.WriteString(" TRUE ")
+	} else {
+		b.serialize(&fb.stmt, &fb.args)
+	}
+
+	if len(a) == 0 {
+		fa.stmt.WriteString(" TRUE ")
+	} else {
+		a.serialize(&fa.stmt, &fa.args)
 	}
 
 	if !counting {
-		l.stmt.WriteString(" ORDER BY enqueued DESC ")
-		a.stmt.WriteString(" ORDER BY enqueued LIMIT 1")
-		b.stmt.WriteString(" ORDER BY enqueued DESC LIMIT 1")
+		fl.stmt.WriteString(" ORDER BY enqueued DESC ")
+		fa.stmt.WriteString(" ORDER BY enqueued LIMIT 1")
+		fb.stmt.WriteString(" ORDER BY enqueued DESC LIMIT 1")
 	}
 
 	if noBefore {
-		b = nil
+		fb = nil
 	}
 	if noAfter {
-		a = nil
+		fa = nil
 	}
-	return
+
+	return fl, fb, fa, nil
 }
 
-func neighbored(ctx context.Context, conn *sql.Conn, fb *filterBuilder) *models.ImportTime {
+func neighbored(ctx context.Context, conn *sql.Conn, fb *filledStmt) *models.ImportTime {
 
 	var when time.Time
 	err := conn.QueryRowContext(ctx, fb.stmt.String(), fb.args...).Scan(&when)
@@ -252,7 +270,7 @@
 	conn *sql.Conn,
 ) (jr JSONResult, err error) {
 
-	var list, before, after *filterBuilder
+	var list, before, after *filledStmt
 
 	if list, before, after, err = buildFilters(req); err != nil {
 		return