view contrib/gmaggregate/main.go @ 5711:2dd155cc95ec revive-cleanup

Fix all revive issues (w/o machine-generated stuff).
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Tue, 20 Feb 2024 22:22:57 +0100
parents e9ef27c75e5c
children 6270951dda28

// This is Free Software under GNU Affero General Public License v >= 3.0
// without warranty, see README.md and license for details.
//
// SPDX-License-Identifier: AGPL-3.0-or-later
// License-Filename: LICENSE
//
// Copyright (C) 2021 by via donau
//   - Österreichische Wasserstraßen-Gesellschaft mbH
// Software engineering by Intevation GmbH
//
// Author(s):
//  * Sascha L. Teichmann <sascha.teichmann@intevation.de>

//go:generate ragel -Z -G2 -o matcher.go matcher.rl
//go:generate go fmt matcher.go

package main

import (
	"container/heap"
	"context"
	"database/sql"
	"encoding/csv"
	"flag"
	"fmt"
	"log"
	"os"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	_ "github.com/jackc/pgx/v4/stdlib"
)

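// The SQL statements below implement the two phases: read the logs of all
// 'gm' imports, stage the filtered lines in the helper table filtered_logs,
// and finally replace the old logs with the staged ones.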
const (
	selectOldGMLogsSQL = `
  SELECT
    lo.import_id,
    lo.time,
    lo.kind,
    lo.msg
  FROM import.imports im
  JOIN import.import_logs lo
    ON lo.import_id = im.id
  WHERE im.kind = 'gm'
  ORDER BY lo.import_id`

	createFilteredLogsSQL = `
  CREATE TABLE filtered_logs (
    import_id integer                  NOT NULL,
    time      timestamp with time zone NOT NULL,
    kind      log_type                 NOT NULL,
    msg       text                     NOT NULL
  )`

	insertFilteredLogsSQL = `
  INSERT INTO filtered_logs (import_id, time, kind, msg)
    VALUES ($1, $2, $3::log_type, $4)`

	deleteOldGMLogsSQL = `
  DELETE FROM import.import_logs WHERE import_id IN (
    SELECT import_id FROM filtered_logs)`

	copyDataSQL = `
  INSERT INTO import.import_logs (import_id, time, kind, msg)
  SELECT import_id, time, kind, msg
  FROM filtered_logs`

	dropFilteredLogsSQL = `DROP TABLE filtered_logs`
)

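// phases is a bit set of the work phases to run.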
type phases int

const (
	nonePhase   phases = 0
	filterPhase phases = 1 << iota // == 2 (iota is 1 here)
	transferPhase                  // == 4
)

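// gauge accumulates the per-gauge statistics that the generated matcher
// extracts from the raw log lines.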
type gauge struct {
	gid           string
	unknown       bool
	assumeZPG     bool
	ignMeasCodes  []string
	rescaleErrors []string
	missingValues []string
	assumeCM      int
	badValues     int
	measurements  int
	predictions   int
}

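// aggregator condenses the log lines of a single import: matched lines are
// folded into per-gauge statistics, unmatched lines are passed through.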
type aggregator struct {
	current string
	hold    *line

	lastGauge *gauge
	gauges    []*gauge

	stack [4]string
}

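// line is a single import log entry.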
type line struct {
	time time.Time
	kind string
	msg  string
}

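// importLines carries the log lines of one import together with the sequence
// number used to restore the input order after concurrent processing.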
type importLines struct {
	seq   int
	id    int64
	lines []line
}

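// processor re-serializes the results of the concurrent aggregators: a
// min-heap ordered by seq plus a condition variable allow drain to hand
// the results to the writer in their original input order.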
type processor struct {
	cond       *sync.Cond
	aggregated []*importLines
	nextOutSeq int
	done       bool
}

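// writer abstracts the two destinations for the filtered logs:
// a CSV file or the database staging table.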
type writer interface {
	prepare(context.Context, *sql.Conn) error
	write(*importLines)
	finish()
	error() error
}

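// csvWriter writes the filtered logs as CSV rows (import_id, time, kind, msg).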
type csvWriter struct {
	err  error
	file *os.File
	out  *csv.Writer
	row  [4]string // import_id, time, kind, msg
}

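// sqlWriter stores the filtered logs in the staging table inside
// a single transaction.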
type sqlWriter struct {
	err  error
	ctx  context.Context
	tx   *sql.Tx
	stmt *sql.Stmt
}

func (ps phases) has(p phases) bool {
	return ps&p == p
}

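// parsePhases parses a comma separated list of phase names. For example
// parsePhases("filter,transfer") yields filterPhase|transferPhase;
// unknown (or empty) entries are rejected.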
func parsePhases(s string) (phases, error) {
	ps := nonePhase
	for _, x := range strings.Split(s, ",") {
		switch strings.ToLower(strings.TrimSpace(x)) {
		case "transfer":
			ps |= transferPhase
		case "filter":
			ps |= filterPhase
		default:
			return nonePhase, fmt.Errorf("invalid phase '%s'", x)
		}
	}
	return ps, nil
}

func (g *gauge) getAssumeZPG() bool               { return g.assumeZPG }
func (g *gauge) getUnknown() bool                 { return g.unknown }
func (g *gauge) getIgnoredMeasureCodes() []string { return g.ignMeasCodes }
func (g *gauge) getRescaleErrors() []string       { return g.rescaleErrors }
func (g *gauge) getMissingValues() []string       { return g.missingValues }
func (g *gauge) getAssumeCM() int                 { return g.assumeCM }
func (g *gauge) getBadValues() int                { return g.badValues }
func (g *gauge) getPredictions() int              { return g.predictions }
func (g *gauge) getMeasurements() int             { return g.measurements }
func (g *gauge) nothingChanged() bool             { return g.measurements == 0 && g.predictions == 0 }

func (agg *aggregator) reset() {
	agg.current = ""
	agg.hold = nil
	agg.lastGauge = nil
	agg.gauges = nil
}

func (agg *aggregator) find(name string) *gauge {
	if agg.lastGauge != nil && name == agg.lastGauge.gid {
		return agg.lastGauge
	}
	for _, g := range agg.gauges {
		if g.gid == name {
			agg.lastGauge = g
			return g
		}
	}
	g := &gauge{gid: name}
	agg.gauges = append(agg.gauges, g)
	agg.lastGauge = g
	return g
}

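// extend appends needle to haystack if it is not already contained.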
func extend(haystack []string, needle string) []string {
	for _, straw := range haystack {
		if straw == needle {
			return haystack
		}
	}
	return append(haystack, needle)
}

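// logBool logs the header followed by the ids of all gauges for
// which access returns true.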
func (agg *aggregator) logBool(
	access func(*gauge) bool,
	header string,
	log func(string),
) {
	var sb strings.Builder
	for _, g := range agg.gauges {
		if access(g) {
			if sb.Len() == 0 {
				sb.WriteString(header)
			} else {
				sb.WriteString(", ")
			}
			sb.WriteString(g.gid)
		}
	}
	if sb.Len() > 0 {
		log(sb.String())
	}
}

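// logInt groups the gauges by their count in ascending order and logs
// them as e.g. "New measurements: 2 (g1, g2); 5 (g3)".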
func (agg *aggregator) logInt(
	access func(*gauge) int,
	header string,
	log func(string),
) {
	gs := make([]*gauge, 0, len(agg.gauges))
	for _, g := range agg.gauges {
		if access(g) > 0 {
			gs = append(gs, g)
		}
	}

	if len(gs) == 0 {
		return
	}

	sort.SliceStable(gs, func(i, j int) bool {
		return access(gs[i]) < access(gs[j])
	})

	var sb strings.Builder
	var last int

	for _, g := range gs {
		if c := access(g); c != last {
			if sb.Len() == 0 {
				sb.WriteString(header)
			} else {
				sb.WriteString("); ")
			}
			sb.WriteString(strconv.Itoa(c))
			sb.WriteString(" (")
			last = c
		} else {
			sb.WriteString(", ")
		}
		sb.WriteString(g.gid)
	}

	sb.WriteByte(')')
	log(sb.String())
}

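// logString logs the header followed by each gauge id with its collected
// values, e.g. "Missing mandatory values: g1 (v1; v2), g2 (v3)".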
func (agg *aggregator) logString(
	access func(*gauge) []string,
	header string,
	log func(string),
) {
	var sb strings.Builder
	for _, g := range agg.gauges {
		if s := access(g); len(s) > 0 {
			if sb.Len() == 0 {
				sb.WriteString(header)
			} else {
				sb.WriteString(", ")
			}
			sb.WriteString(g.gid)
			sb.WriteString(" (")
			for i, v := range s {
				if i > 0 {
					sb.WriteString("; ")
				}
				sb.WriteString(v)
			}
			sb.WriteByte(')')
		}
	}
	if sb.Len() > 0 {
		log(sb.String())
	}
}

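// aggregate appends the summary lines for the collected statistics to out.
// Each synthetic line gets a timestamp strictly after last and after its
// predecessor; a line held back during matching is re-emitted at the end.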
func (agg *aggregator) aggregate(out []line, last time.Time) []line {

	// Guarantee that the synthetic lines get timestamps after the lines already put out.
	if n := len(out); n > 0 && !out[n-1].time.Before(last) {
		last = out[n-1].time.Add(time.Millisecond)
	}

	log := func(kind, msg string) {
		out = append(out, line{last, kind, msg})
		last = last.Add(time.Millisecond)
	}

	infoLog := func(msg string) { log("info", msg) }
	warnLog := func(msg string) { log("warn", msg) }
	errLog := func(msg string) { log("error", msg) }

	agg.logBool(
		(*gauge).getUnknown,
		"Cannot find following gauges: ",
		warnLog)

	agg.logBool(
		(*gauge).getAssumeZPG,
		"'Reference_code' not specified. Assuming 'ZPG': ",
		warnLog)

	agg.logInt(
		(*gauge).getAssumeCM,
		"'Unit' not specified. Assuming 'cm': ",
		warnLog)

	agg.logInt(
		(*gauge).getBadValues,
		"Ignored measurements with value -99999: ",
		warnLog)

	agg.logString(
		(*gauge).getMissingValues,
		"Missing mandatory values: ",
		warnLog)

	agg.logString(
		(*gauge).getRescaleErrors,
		"Cannot convert units: ",
		errLog)

	agg.logString(
		(*gauge).getIgnoredMeasureCodes,
		"Ignored measure codes: ",
		warnLog)

	agg.logInt(
		(*gauge).getPredictions,
		"New predictions: ",
		infoLog)

	agg.logInt(
		(*gauge).getMeasurements,
		"New measurements: ",
		infoLog)

	agg.logBool(
		(*gauge).nothingChanged,
		"No changes for: ",
		infoLog)

	if agg.hold != nil {
		agg.hold.time = last
		out = append(out, *agg.hold)
	}
	return out
}

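// run condenses one import at a time and hands the result to pr, which
// restores the input order. The producer guarantees that every importLines
// contains at least one line.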
func (agg *aggregator) run(
	wg *sync.WaitGroup,
	logs <-chan *importLines,
	pr *processor,
) {
	defer wg.Done()
	for l := range logs {
		// Sort by time here instead of in the database to take
		// advantage of the concurrent workers.
		lines := l.lines
		sort.Slice(lines, func(i, j int) bool {
			return lines[i].time.Before(lines[j].time)
		})

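		// Reuse the backing array of lines for the passed-through output.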
		out := lines[:0:len(lines)]
		for i := range lines {
			line := &lines[i]
			if !agg.match(line.msg, line) {
				out = append(out, *line)
			}
		}
		l.lines = agg.aggregate(out, lines[len(lines)-1].time)
		pr.consume(l)
		agg.reset()
	}
}

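// timeFormat renders timestamps the way PostgreSQL prints timestamptz values.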
const timeFormat = "2006-01-02 15:04:05.999999-07"

func newCSVWriter(filename string) (*csvWriter, error) {

	f, err := os.Create(filename)
	if err != nil {
		return nil, err
	}

	return &csvWriter{
		file: f,
		out:  csv.NewWriter(f),
	}, nil
}

func (cw *csvWriter) prepare(context.Context, *sql.Conn) error {
	return nil
}

func (cw *csvWriter) error() error { return cw.err }

func (cw *csvWriter) write(entry *importLines) {
	if cw.err != nil {
		return
	}
	row := cw.row[:]
	row[0] = strconv.FormatInt(entry.id, 10)
	for _, l := range entry.lines {
		row[1] = l.time.Format(timeFormat)
		row[2] = l.kind
		row[3] = l.msg
		if cw.err = cw.out.Write(row); cw.err != nil {
			log.Printf("error: Writing to CSV file failed: %v\n", cw.err)
			return
		}
	}
}

func (cw *csvWriter) finish() {
	cw.out.Flush()
	if err := cw.out.Error(); err != nil {
		log.Printf("error: flushing CSV file failed: %v\n", err)
	}
	if err := cw.file.Close(); err != nil {
		log.Printf("Closing CSV file failed: %v\n", err)
	}
}

func (sw *sqlWriter) prepare(ctx context.Context, conn *sql.Conn) error {

	tx, err := conn.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	if _, err := tx.ExecContext(ctx, createFilteredLogsSQL); err != nil {
		tx.Rollback()
		return fmt.Errorf("cannot create new log table: %v", err)
	}

	stmt, err := tx.PrepareContext(ctx, insertFilteredLogsSQL)
	if err != nil {
		tx.Rollback()
		return err
	}

	sw.ctx = ctx
	sw.tx = tx
	sw.stmt = stmt
	return nil
}

func (sw *sqlWriter) error() error { return sw.err }

func (sw *sqlWriter) write(entry *importLines) {
	if sw.err != nil {
		return
	}
	for _, l := range entry.lines {
		if _, sw.err = sw.stmt.ExecContext(
			sw.ctx,
			entry.id,
			l.time,
			l.kind,
			l.msg,
		); sw.err != nil {
			log.Printf("error: writing log line to db failed: %v\n", sw.err)
			return
		}
	}
}

func (sw *sqlWriter) finish() {
	if err := sw.stmt.Close(); err != nil {
		log.Printf("error: closing statement failed: %v\n", err)
	}
	if sw.err == nil {
		if err := sw.tx.Commit(); err != nil {
			log.Printf("error: committing transaction failed: %v\n", err)
		}
	} else if err := sw.tx.Rollback(); err != nil {
		log.Printf("error: rolling back transaction failed: %v\n", err)
	}
}
}

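// Push, Pop, Len, Less and Swap implement heap.Interface ordered by
// ascending seq, so the result with the smallest outstanding sequence
// number is always at the front.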
func (pr *processor) Push(x interface{}) {
	pr.aggregated = append(pr.aggregated, x.(*importLines))
}

func (pr *processor) Pop() interface{} {
	n := len(pr.aggregated)
	x := pr.aggregated[n-1]
	pr.aggregated[n-1] = nil
	pr.aggregated = pr.aggregated[:n-1]
	return x
}

func (pr *processor) Len() int { return len(pr.aggregated) }

func (pr *processor) Less(i, j int) bool {
	return pr.aggregated[i].seq < pr.aggregated[j].seq
}

func (pr *processor) Swap(i, j int) {
	pr.aggregated[i], pr.aggregated[j] = pr.aggregated[j], pr.aggregated[i]
}

func (pr *processor) consume(l *importLines) {
	pr.cond.L.Lock()
	heap.Push(pr, l)
	pr.cond.L.Unlock()
	pr.cond.Signal()
}

func (pr *processor) quit() {
	pr.cond.L.Lock()
	pr.done = true
	pr.cond.L.Unlock()
	pr.cond.Signal()
}

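// drain waits until the result with the next output sequence number is
// available, writes it, and repeats. After quit it flushes the remaining
// results in heap order.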
func (pr *processor) drain(write func(*importLines)) {

	for {
		pr.cond.L.Lock()
		for !pr.done &&
			(len(pr.aggregated) == 0 || pr.aggregated[0].seq != pr.nextOutSeq) {
			pr.cond.Wait()
		}
		if pr.done {
			for len(pr.aggregated) > 0 {
				write(heap.Pop(pr).(*importLines))
			}
			pr.cond.L.Unlock()
			return
		}
		l := heap.Pop(pr).(*importLines)
		pr.nextOutSeq++
		pr.cond.L.Unlock()
		write(l)
	}
}

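// filterPhase streams the old 'gm' import logs from the database, fans the
// lines of each import out to the aggregator workers, and writes the
// condensed logs in input order via wr.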
func (pr *processor) filterPhase(db *sql.DB, worker int, wr writer) error {

	log.Println("filter phase started")

	ctx := context.Background()

	con1, err := db.Conn(ctx)
	if err != nil {
		return err
	}
	defer con1.Close()

	con2, err := db.Conn(ctx)
	if err != nil {
		return err
	}
	defer con2.Close()

	tx, err := con1.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
	if err != nil {
		return err
	}
	defer tx.Rollback()

	if err := wr.prepare(ctx, con2); err != nil {
		return err
	}
	defer wr.finish()

	logs := make(chan *importLines)
	var wg sync.WaitGroup

	for i := 0; i < worker; i++ {
		wg.Add(1)
		go new(aggregator).run(&wg, logs, pr)
	}

	writeDone := make(chan struct{})

	go func() {
		defer close(writeDone)
		pr.drain(wr.write)
	}()

	log.Println("Querying for old logs started. (Can take a while.)")
	rows, err := tx.QueryContext(ctx, selectOldGMLogsSQL)
	if err != nil {
		return err
	}
	defer rows.Close()

	log.Println("Querying done. (Maybe restart the gemma server, now?)")

	var (
		count    int64
		current  *importLines
		seq      int
		l        line
		importID int64
		start    = time.Now()
		last     = start
	)

	log.Println("Filtering started.")
	for rows.Next() {
		if err := rows.Scan(&importID, &l.time, &l.kind, &l.msg); err != nil {
			return err
		}

		if current == nil || importID != current.id {
			if current != nil {
				logs <- current
			}
			current = &importLines{
				seq: seq,
				id:  importID,
			}
			seq++
		}
		current.lines = append(current.lines, l)

		if count++; count%1_000_000 == 0 {
			now := time.Now()
			diff := now.Sub(last)
			log.Printf("lines: %d rate: %.2f lines/s\n",
				count,
				1_000_000/diff.Seconds())
			last = now
		}
	}
	if current != nil && len(current.lines) > 0 {
		logs <- current
	}
	close(logs)
	wg.Wait()

	pr.quit()

	<-writeDone

	rate := float64(count) / time.Since(start).Seconds()
	log.Printf("lines: %d rate: %.2f lines/s imports: %d\n",
		count, rate, seq)
	return nil
}

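// transferPhase replaces the old 'gm' logs with the filtered ones and
// drops the staging table, all in one transaction.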
func (pr *processor) transferPhase(db *sql.DB) error {
	log.Println("Transfer phase started.")
	ctx := context.Background()
	conn, err := db.Conn(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()
	tx, err := conn.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()

	for _, sql := range []string{
		deleteOldGMLogsSQL,
		copyDataSQL,
		dropFilteredLogsSQL,
	} {
		if _, err := tx.ExecContext(ctx, sql); err != nil {
			return err
		}
	}
	return tx.Commit()
}

func newProcessor() *processor {
	return &processor{
		cond: sync.NewCond(new(sync.Mutex)),
	}
}

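// process connects to the database and runs the requested phases with
// either a CSV or an SQL writer.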
func process(
	host, dbname string, port int,
	worker int,
	csvFile string,
	ps phases,
) error {

	p := newProcessor()
	var wr writer

	if csvFile != "" {
		var err error
		if wr, err = newCSVWriter(csvFile); err != nil {
			return fmt.Errorf("cannot create CSV file: %v", err)
		}
	} else {
		wr = new(sqlWriter)
	}

	dsn := fmt.Sprintf("host=%s dbname=%s port=%d", host, dbname, port)
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		return err
	}
	defer db.Close()

	if ps.has(filterPhase) {
		if err := p.filterPhase(db, worker, wr); err != nil {
			return err
		}
	}
	if ps.has(transferPhase) {
		if err := p.transferPhase(db); err != nil {
			return err
		}
	}

	return nil
}

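// Typical invocations (assuming the binary is built as gmaggregate):
//
//	gmaggregate -d gemma -w 8                  # filter and transfer in the database
//	gmaggregate -c filtered.csv -phases filter // only filter, dump to CSV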
func main() {
	var (
		host      = flag.String("h", "/var/run/postgresql", "database host")
		dbname    = flag.String("d", "gemma", "database")
		port      = flag.Int("p", 5432, "database port")
		worker    = flag.Int("w", runtime.NumCPU(), "number of aggregation workers")
		csvFile   = flag.String("c", "", "CSV file to be written")
		phaseList = flag.String("phases", "filter,transfer", "phases to run: filter and/or transfer")
	)

	flag.Parse()

	ps, err := parsePhases(*phaseList)
	if err != nil {
		log.Fatalf("error: %v\n", err)
	}

	start := time.Now()
	if err := process(*host, *dbname, *port, *worker, *csvFile, ps); err != nil {
		log.Fatalf("error: %v\n", err)
	}
	log.Printf("time took: %s\n", time.Since(start))
}