From 3695587b48ab896506937ac9819d526b1659dd10 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 26 Feb 2024 11:40:12 +0100 Subject: [PATCH] Add DB migration infrastructure --- db.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++------ schema.sql | 2 -- diff --git a/db.go b/db.go index fff905e1975a1828bee34ab92f4c70f640354155..f1dbab190c309899e6d590b034659e2fda2f8369 100644 --- a/db.go +++ b/db.go @@ -13,6 +13,10 @@ //go:embed schema.sql var schema string +var migrations = []string{ + "", // migration #0 is reserved for schema initialization +} + var errNoDBRows = sql.ErrNoRows type DB struct { @@ -35,15 +39,13 @@ return db, nil } func (db *DB) init(ctx context.Context) error { - var n int - if err := db.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_schema").Scan(&n); err != nil { + version, err := db.upgrade(ctx) + if err != nil { return err - } else if n != 0 { - return nil } - if _, err := db.db.ExecContext(ctx, schema); err != nil { - return err + if version > 0 { + return nil } // TODO: drop this @@ -52,6 +54,44 @@ if err := defaultUser.SetPassword("root"); err != nil { return err } return db.StoreUser(ctx, &defaultUser) +} + +func (db *DB) upgrade(ctx context.Context) (version int, err error) { + if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil { + return 0, fmt.Errorf("failed to query schema version: %v", err) + } + + if version == len(migrations) { + return version, nil + } else if version > len(migrations) { + return version, fmt.Errorf("sinwon (version %d) older than schema (version %d)", len(migrations), version) + } + + tx, err := db.db.Begin() + if err != nil { + return version, err + } + defer tx.Rollback() + + if version == 0 { + if _, err := tx.Exec(schema); err != nil { + return version, fmt.Errorf("failed to initialize schema: %v", err) + } + } else { + for i := version; i < len(migrations); i++ { + if _, err := tx.Exec(migrations[i]); err != nil { + return version, fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + // For some reason prepared statements don't work here + _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations))) + if err != nil { + return version, fmt.Errorf("failed to bump schema version: %v", err) + } + + return version, tx.Commit() } func (db *DB) Close() error { diff --git a/schema.sql b/schema.sql index 7bfcd417e5110f27559ee62a04fe02f73c3c51e8..86a892f57152029daf1ab4dfde6ae56971ba85fe 100644 --- a/schema.sql +++ b/schema.sql @@ -1,5 +1,3 @@ -PRAGMA user_version = 1; - CREATE TABLE User ( id INTEGER PRIMARY KEY, username TEXT NOT NULL UNIQUE, -- 2.48.1