diff pkg/imports/misc.go @ 2758:a996f2ca9fa5

Simplified savepoint handling.
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Thu, 21 Mar 2019 15:43:02 +0100
parents 542d3441c2d8
children f464cbcdf2f2
line wrap: on
line diff
--- a/pkg/imports/misc.go	Thu Mar 21 14:29:47 2019 +0100
+++ b/pkg/imports/misc.go	Thu Mar 21 15:43:02 2019 +0100
@@ -34,25 +34,33 @@
 	ctx context.Context,
 	tx *sql.Tx,
 	name string,
-) error {
-	_, err := tx.ExecContext(ctx, "SAVEPOINT "+name)
-	return err
-}
+) func(func() error) error {
+
+	var (
+		savepoint = "SAVEPOINT " + name
+		rollback  = "ROLLBACK TO SAVEPOINT " + name
+		release   = "RELEASE SAVEPOINT " + name
+	)
 
-func RollbackToSavepoint(
-	ctx context.Context,
-	tx *sql.Tx,
-	name string,
-) error {
-	_, err := tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT "+name)
-	return err
-}
+	return func(fn func() error) (err error) {
+		if _, err = tx.ExecContext(ctx, savepoint); err != nil {
+			return
+		}
+		var done bool
+		defer func() {
+			if !done {
+				_, err2 := tx.ExecContext(ctx, rollback)
+				if err == nil {
+					err = err2
+				}
+			}
+		}()
+		err = fn()
 
-func ReleaseSavepoint(
-	ctx context.Context,
-	tx *sql.Tx,
-	name string,
-) error {
-	_, err := tx.ExecContext(ctx, "RELEASE SAVEPOINT "+name)
-	return err
+		if err == nil {
+			done = true
+			_, err = tx.ExecContext(ctx, release)
+		}
+		return
+	}
 }