diff --git a/database/addressindex.go b/database/addressindex.go index de9ebba..8338d6d 100644 --- a/database/addressindex.go +++ b/database/addressindex.go @@ -5,8 +5,6 @@ package database import ( - "encoding/binary" - bolt "go.etcd.io/bbolt" ) @@ -22,7 +20,7 @@ func (vdb *VspDatabase) GetLastAddressIndex() (uint32, error) { return nil } - idx = binary.LittleEndian.Uint32(idxBytes) + idx = bytesToUint32(idxBytes) return nil }) @@ -35,11 +33,7 @@ func (vdb *VspDatabase) GetLastAddressIndex() (uint32, error) { func (vdb *VspDatabase) SetLastAddressIndex(idx uint32) error { err := vdb.db.Update(func(tx *bolt.Tx) error { vspBkt := tx.Bucket(vspBktK) - - idxBytes := make([]byte, 4) - binary.LittleEndian.PutUint32(idxBytes, idx) - - return vspBkt.Put(lastAddressIndexK, idxBytes) + return vspBkt.Put(lastAddressIndexK, uint32ToBytes(idx)) }) diff --git a/database/database.go b/database/database.go index 3ed2905..e2513f0 100644 --- a/database/database.go +++ b/database/database.go @@ -85,6 +85,16 @@ func writeHotBackupFile(db *bolt.DB) error { return err } +func uint32ToBytes(i uint32) []byte { + bytes := make([]byte, 4) + binary.LittleEndian.PutUint32(bytes, i) + return bytes +} + +func bytesToUint32(bytes []byte) uint32 { + return binary.LittleEndian.Uint32(bytes) +} + // CreateNew intializes a new bbolt database with all of the necessary vspd // buckets, and inserts: // - the provided extended pubkey (to be used for deriving fee addresses). @@ -108,10 +118,8 @@ func CreateNew(dbFile, feeXPub string) error { return fmt.Errorf("failed to create %s bucket: %w", string(vspBktK), err) } - // Initialize with database version 1. - vbytes := make([]byte, 4) - binary.LittleEndian.PutUint32(vbytes, uint32(1)) - err = vspBkt.Put(versionK, vbytes) + // Initialize with initial database version (1). + err = vspBkt.Put(versionK, uint32ToBytes(initialVersion)) if err != nil { return err } @@ -188,7 +196,19 @@ func Open(ctx context.Context, shutdownWg *sync.WaitGroup, dbFile string, backup return nil, fmt.Errorf("unable to open db file: %w", err) } - log.Debugf("Opened database file %s", dbFile) + vdb := &VspDatabase{db: db, maxVoteChangeRecords: maxVoteChangeRecords} + + dbVersion, err := vdb.Version() + if err != nil { + return nil, fmt.Errorf("unable to get db version: %w", err) + } + + log.Debugf("Opened database (version=%d, file=%s)", dbVersion, dbFile) + + err = vdb.Upgrade(dbVersion) + if err != nil { + return nil, fmt.Errorf("database upgrade failed: %w", err) + } // Start a ticker to update the backup file at the specified interval. shutdownWg.Add(1) @@ -209,7 +229,7 @@ func Open(ctx context.Context, shutdownWg *sync.WaitGroup, dbFile string, backup } }() - return &VspDatabase{db: db, maxVoteChangeRecords: maxVoteChangeRecords}, nil + return vdb, nil } // Close will close the database and then make a copy of the database to the @@ -305,9 +325,9 @@ func (vdb *VspDatabase) KeyPair() (ed25519.PrivateKey, ed25519.PublicKey, error) return signKey, pubKey, err } -// GetFeeXPub retrieves the extended pubkey used for generating fee addresses +// FeeXPub retrieves the extended pubkey used for generating fee addresses // from the database. -func (vdb *VspDatabase) GetFeeXPub() (string, error) { +func (vdb *VspDatabase) FeeXPub() (string, error) { var feeXPub string err := vdb.db.View(func(tx *bolt.Tx) error { vspBkt := tx.Bucket(vspBktK) @@ -325,9 +345,9 @@ func (vdb *VspDatabase) GetFeeXPub() (string, error) { return feeXPub, err } -// GetCookieSecret retrieves the generated cookie store secret key from the +// CookieSecret retrieves the generated cookie store secret key from the // database. -func (vdb *VspDatabase) GetCookieSecret() ([]byte, error) { +func (vdb *VspDatabase) CookieSecret() ([]byte, error) { var cookieSecret []byte err := vdb.db.View(func(tx *bolt.Tx) error { vspBkt := tx.Bucket(vspBktK) @@ -345,6 +365,21 @@ func (vdb *VspDatabase) GetCookieSecret() ([]byte, error) { return cookieSecret, err } +// Version returns the current database version. +func (vdb *VspDatabase) Version() (uint32, error) { + var version uint32 + err := vdb.db.View(func(tx *bolt.Tx) error { + bytes := tx.Bucket(vspBktK).Get(versionK) + version = bytesToUint32(bytes) + return nil + }) + if err != nil { + return 0, err + } + + return version, nil +} + // BackupDB streams a backup of the database over an http response writer. func (vdb *VspDatabase) BackupDB(w http.ResponseWriter) error { w.Header().Set("Content-Type", "application/octet-stream") diff --git a/database/database_test.go b/database/database_test.go index ad58e92..0ea37dc 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -90,7 +90,7 @@ func testCreateNew(t *testing.T) { } // A newly created DB should have a cookie secret. - secret, err := db.GetCookieSecret() + secret, err := db.CookieSecret() if err != nil { t.Fatalf("error getting cookie secret: %v", err) } @@ -100,7 +100,7 @@ func testCreateNew(t *testing.T) { } // A newly created DB should store the fee xpub it was initialized with. - retrievedXPub, err := db.GetFeeXPub() + retrievedXPub, err := db.FeeXPub() if err != nil { t.Fatalf("error getting fee xpub: %v", err) } diff --git a/database/database_upgrades.go b/database/database_upgrades.go new file mode 100644 index 0000000..c8e8af7 --- /dev/null +++ b/database/database_upgrades.go @@ -0,0 +1,32 @@ +package database + +import ( + "fmt" +) + +const ( + // initialVersion is the version of a freshly created database which has had + // no upgrades applied. + initialVersion = 1 + + // latestVersion is the latest version of the bolt database that is + // understood by vspd. Databases with recorded versions higher than + // this will fail to open (meaning any upgrades prevent reverting to older + // software). + latestVersion = initialVersion +) + +// Upgrade will update the database to the latest known version. +func (vdb *VspDatabase) Upgrade(currentVersion uint32) error { + if currentVersion == latestVersion { + // No upgrades required. + return nil + } + + if currentVersion > latestVersion { + // Database is too new. + return fmt.Errorf("expected database version <= %d, got %d", latestVersion, currentVersion) + } + + return nil +} diff --git a/database/votechange.go b/database/votechange.go index 981e9ff..1ffa697 100644 --- a/database/votechange.go +++ b/database/votechange.go @@ -5,7 +5,6 @@ package database import ( - "encoding/binary" "encoding/json" "fmt" "math" @@ -44,7 +43,7 @@ func (vdb *VspDatabase) SaveVoteChange(ticketHash string, record VoteChangeRecor lowest := uint32(math.MaxUint32) err = bkt.ForEach(func(k, v []byte) error { count++ - key := binary.LittleEndian.Uint32(k) + key := bytesToUint32(k) if key > highest { highest = key } @@ -60,9 +59,7 @@ func (vdb *VspDatabase) SaveVoteChange(ticketHash string, record VoteChangeRecor // If bucket is at (or over) the limit of max allowed records, remove // the oldest one. if count >= vdb.maxVoteChangeRecords { - keyBytes := make([]byte, 4) - binary.LittleEndian.PutUint32(keyBytes, lowest) - err = bkt.Delete(keyBytes) + err = bkt.Delete(uint32ToBytes(lowest)) if err != nil { return fmt.Errorf("failed to delete old vote change record: %w", err) } @@ -75,15 +72,12 @@ func (vdb *VspDatabase) SaveVoteChange(ticketHash string, record VoteChangeRecor newKey = highest + 1 } - keyBytes := make([]byte, 4) - binary.LittleEndian.PutUint32(keyBytes, newKey) - // Insert record. recordBytes, err := json.Marshal(record) if err != nil { return fmt.Errorf("could not marshal vote change record: %w", err) } - err = bkt.Put(keyBytes, recordBytes) + err = bkt.Put(uint32ToBytes(newKey), recordBytes) if err != nil { return fmt.Errorf("could not store vote change record: %w", err) } @@ -113,7 +107,7 @@ func (vdb *VspDatabase) GetVoteChanges(ticketHash string) (map[uint32]VoteChange return fmt.Errorf("could not unmarshal vote change record: %w", err) } - records[binary.LittleEndian.Uint32(k)] = record + records[bytesToUint32(k)] = record return nil }) diff --git a/webapi/webapi.go b/webapi/webapi.go index 6e1ae70..3f9e03c 100644 --- a/webapi/webapi.go +++ b/webapi/webapi.go @@ -78,7 +78,7 @@ func Start(ctx context.Context, requestShutdownChan chan struct{}, shutdownWg *s if err != nil { return fmt.Errorf("db.GetLastAddressIndex error: %w", err) } - feeXPub, err := vdb.GetFeeXPub() + feeXPub, err := vdb.FeeXPub() if err != nil { return fmt.Errorf("db.GetFeeXPub error: %w", err) } @@ -88,7 +88,7 @@ func Start(ctx context.Context, requestShutdownChan chan struct{}, shutdownWg *s } // Get the secret key used to initialize the cookie store. - cookieSecret, err := vdb.GetCookieSecret() + cookieSecret, err := vdb.CookieSecret() if err != nil { return fmt.Errorf("db.GetCookieSecret error: %w", err) }