Mercurial > gemma
changeset 3194:eeff2cc4ff9d
controllers: re-factored the SQL filter to a tree like structure to be of more general use.
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Wed, 08 May 2019 13:11:30 +0200 |
parents | 8329c6d3cf2a |
children | 88c31298eba8 |
files | pkg/controllers/common.go pkg/controllers/gauges.go pkg/controllers/importqueue.go |
diffstat | 3 files changed, 145 insertions(+), 89 deletions(-) [+] |
line wrap: on
line diff
--- a/pkg/controllers/common.go Wed May 08 12:45:21 2019 +0200 +++ b/pkg/controllers/common.go Wed May 08 13:11:30 2019 +0200 @@ -18,26 +18,61 @@ "strings" ) -type filterBuilder struct { - stmt strings.Builder - args []interface{} - hasCond bool +type ( + filterNode interface { + serialize(*strings.Builder, *[]interface{}) + } + + filterTerm struct { + format string + args []interface{} + } + + filterNot struct { + filterNode + } + + filterAnd []filterNode + filterOr []filterNode +) + +func (ft *filterTerm) serialize(stmt *strings.Builder, args *[]interface{}) { + indices := make([]interface{}, len(ft.args)) + for i := range indices { + indices[i] = len(*args) + i + 1 + } + fmt.Fprintf(stmt, ft.format, indices...) + *args = append(*args, (*ft).args...) } -func (fb *filterBuilder) arg(format string, v ...interface{}) { - indices := make([]interface{}, len(v)) - for i := range indices { - indices[i] = len(fb.args) + i + 1 +func buildFilterTerm(format string, args ...interface{}) *filterTerm { + return &filterTerm{format: format, args: args} +} + +func (fa filterAnd) serialize(stmt *strings.Builder, args *[]interface{}) { + for i, node := range fa { + if i > 0 { + stmt.WriteString(" AND ") + } + stmt.WriteByte('(') + node.serialize(stmt, args) + stmt.WriteByte(')') } - fmt.Fprintf(&fb.stmt, format, indices...) - fb.args = append(fb.args, v...) } -func (fb *filterBuilder) and(format string, v ...interface{}) { - if fb.hasCond { - fb.stmt.WriteString(" AND ") - } else { - fb.hasCond = true +func (fo filterOr) serialize(stmt *strings.Builder, args *[]interface{}) { + for i, node := range fo { + if i > 0 { + stmt.WriteString(" OR ") + } + stmt.WriteByte('(') + node.serialize(stmt, args) + stmt.WriteByte(')') } - fb.arg(format, v...) } + +func (fn *filterNot) serialize(stmt *strings.Builder, args *[]interface{}) { + stmt.WriteString("NOT (") + fn.filterNode.serialize(stmt, args) + stmt.WriteByte(')') +}
--- a/pkg/controllers/gauges.go Wed May 08 12:45:21 2019 +0200 +++ b/pkg/controllers/gauges.go Wed May 08 13:11:30 2019 +0200 @@ -628,32 +628,29 @@ return } - var fb filterBuilder - fb.stmt.WriteString(selectWaterlevelsSQL) - - fb.and( - " fk_gauge_id = ($%d::char(2), $%d::char(3), $%d::char(5), $%d::char(5), $%d::int) ", - isrs.CountryCode, - isrs.LoCode, - isrs.FairwaySection, - isrs.Orc, - isrs.Hectometre, - ) - - fb.and( - `(NOT predicted - OR ( - date_issue = ( - SELECT max(date_issue) FROM waterway.gauge_measurements - WHERE fk_gauge_id = ($%d::char(2), $%d::char(3), $%d::char(5), $%d::char(5), $%d::int) - ) - ))`, - isrs.CountryCode, - isrs.LoCode, - isrs.FairwaySection, - isrs.Orc, - isrs.Hectometre, - ) + filters := filterAnd{ + buildFilterTerm( + "fk_gauge_id = ($%d::char(2), $%d::char(3), $%d::char(5), $%d::char(5), $%d::int)", + isrs.CountryCode, + isrs.LoCode, + isrs.FairwaySection, + isrs.Orc, + isrs.Hectometre, + ), + &filterOr{ + &filterNot{&filterTerm{format: "predicted"}}, + buildFilterTerm( + `date_issue = ( + SELECT max(date_issue) FROM waterway.gauge_measurements + WHERE fk_gauge_id = ($%d::char(2), $%d::char(3), $%d::char(5), $%d::char(5), $%d::int)`, + isrs.CountryCode, + isrs.LoCode, + isrs.FairwaySection, + isrs.Orc, + isrs.Hectometre, + ), + }, + } if from := req.FormValue("from"); from != "" { fromTime, err := time.Parse(models.ImportTimeFormat, from) @@ -663,7 +660,7 @@ http.StatusBadRequest) return } - fb.and("measure_date >= $%d", fromTime) + filters = append(filters, buildFilterTerm("measure_date >= $%d", fromTime)) } if to := req.FormValue("to"); to != "" { @@ -674,14 +671,20 @@ http.StatusBadRequest) return } - fb.and("measure_date <= $%d", toTime) + filters = append(filters, buildFilterTerm("measure_date <= $%d", toTime)) } + var stmt strings.Builder + var args []interface{} + + stmt.WriteString(selectWaterlevelsSQL) + filters.serialize(&stmt, &args) + conn := middleware.GetDBConn(req) ctx := req.Context() - rows, err := conn.QueryContext(ctx, fb.stmt.String(), fb.args...) + rows, err := conn.QueryContext(ctx, stmt.String(), args...) if err != nil { http.Error( rw, fmt.Sprintf("error: %v", err),
--- a/pkg/controllers/importqueue.go Wed May 08 12:45:21 2019 +0200 +++ b/pkg/controllers/importqueue.go Wed May 08 13:11:30 2019 +0200 @@ -132,30 +132,22 @@ return &ta } -func buildFilters(req *http.Request) (l, b, a *filterBuilder, err error) { +type filledStmt struct { + stmt strings.Builder + args []interface{} +} - l = new(filterBuilder) - a = new(filterBuilder) - b = new(filterBuilder) +func buildFilters(req *http.Request) (*filledStmt, *filledStmt, *filledStmt, error) { + + var l, a, b filterAnd var noBefore, noAfter bool - var counting bool - - switch count := strings.ToLower(req.FormValue("count")); count { - case "1", "t", "true": - counting = true - l.stmt.WriteString(selectImportsCountSQL) - default: - l.stmt.WriteString(selectImportsSQL) - } - a.stmt.WriteString(selectAfterSQL) - b.stmt.WriteString(selectBeforeSQL) - - cond := func(format string, v ...interface{}) { - l.and(format, v...) - a.and(format, v...) - b.and(format, v...) + cond := func(format string, args ...interface{}) { + term := &filterTerm{format: format, args: args} + l = append(l, term) + a = append(l, term) + b = append(b, term) } if query := req.FormValue("query"); query != "" { @@ -181,23 +173,23 @@ } if from := req.FormValue("from"); from != "" { - var fromTime time.Time - if fromTime, err = time.Parse(models.ImportTimeFormat, from); err != nil { - return + fromTime, err := time.Parse(models.ImportTimeFormat, from) + if err != nil { + return nil, nil, nil, err } - l.and(" enqueued >= $%d ", fromTime) - b.and(" enqueued < $%d", fromTime) + l = append(l, buildFilterTerm("enqueued >= $%d", fromTime)) + b = append(b, buildFilterTerm("enqueued < $%d", fromTime)) } else { noBefore = true } if to := req.FormValue("to"); to != "" { - var toTime time.Time - if toTime, err = time.Parse(models.ImportTimeFormat, to); err != nil { - return + toTime, err := time.Parse(models.ImportTimeFormat, to) + if err != nil { + return nil, nil, nil, err } - l.and(" enqueued <= $%d ", toTime) - a.and(" enqueued > $%d", toTime) + l = append(l, buildFilterTerm("enqueued <= $%d", toTime)) + a = append(a, buildFilterTerm("enqueued > $%d", toTime)) } else { noAfter = true } @@ -207,32 +199,58 @@ cond(" id IN (SELECT id FROM warned) ") } - if !l.hasCond { - l.stmt.WriteString(" TRUE ") + fl := &filledStmt{} + fa := &filledStmt{} + fb := &filledStmt{} + + fa.stmt.WriteString(selectAfterSQL) + fb.stmt.WriteString(selectBeforeSQL) + + var counting bool + + switch count := strings.ToLower(req.FormValue("count")); count { + case "1", "t", "true": + counting = true + fl.stmt.WriteString(selectImportsCountSQL) + default: + fl.stmt.WriteString(selectImportsSQL) } - if !b.hasCond { - b.stmt.WriteString(" TRUE ") + + if len(l) == 0 { + fl.stmt.WriteString(" TRUE ") + } else { + l.serialize(&fl.stmt, &fl.args) } - if !a.hasCond { - a.stmt.WriteString(" TRUE ") + + if len(b) == 0 { + fb.stmt.WriteString(" TRUE ") + } else { + b.serialize(&fb.stmt, &fb.args) + } + + if len(a) == 0 { + fa.stmt.WriteString(" TRUE ") + } else { + a.serialize(&fa.stmt, &fa.args) } if !counting { - l.stmt.WriteString(" ORDER BY enqueued DESC ") - a.stmt.WriteString(" ORDER BY enqueued LIMIT 1") - b.stmt.WriteString(" ORDER BY enqueued DESC LIMIT 1") + fl.stmt.WriteString(" ORDER BY enqueued DESC ") + fa.stmt.WriteString(" ORDER BY enqueued LIMIT 1") + fb.stmt.WriteString(" ORDER BY enqueued DESC LIMIT 1") } if noBefore { - b = nil + fb = nil } if noAfter { - a = nil + fa = nil } - return + + return fl, fb, fa, nil } -func neighbored(ctx context.Context, conn *sql.Conn, fb *filterBuilder) *models.ImportTime { +func neighbored(ctx context.Context, conn *sql.Conn, fb *filledStmt) *models.ImportTime { var when time.Time err := conn.QueryRowContext(ctx, fb.stmt.String(), fb.args...).Scan(&when) @@ -252,7 +270,7 @@ conn *sql.Conn, ) (jr JSONResult, err error) { - var list, before, after *filterBuilder + var list, before, after *filledStmt if list, before, after, err = buildFilters(req); err != nil { return