diff --git a/.gitignore b/.gitignore index e5ec9ef..963d265 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ dcrvsp +/database/test.db diff --git a/README.md b/README.md index bfe1f1f..13c56cc 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ - [gin-gonic](https://github.com/gin-gonic/gin) webserver - [bbolt](https://github.com/etcd-io/bbolt) database + - Tickets are stored in a single bucket, using ticket hash as the key and a + json encoded representation of the ticket as the value. ## MVP features diff --git a/database/database.go b/database/database.go index 33a5674..81680e4 100644 --- a/database/database.go +++ b/database/database.go @@ -8,24 +8,26 @@ import ( bolt "go.etcd.io/bbolt" ) -// VspDatabase wraps an instance of bbolt DB and provides VSP specific +// VspDatabase wraps an instance of bolt.DB and provides VSP specific // convenience functions. type VspDatabase struct { db *bolt.DB } +// The keys used in the database. var ( - // vspBkt is the main parent bucket of the VSP. All values and other buckets - // are nested within it. - vspBkt = []byte("vspbkt") - ticketsBkt = []byte("ticketsbkt") - versionK = []byte("version") - version = 1 + // vspbkt is the main parent bucket of the VSP database. All values and + // other buckets are nested within it. + vspBktK = []byte("vspbkt") + // ticketbkt stores all tickets known by this VSP. + ticketBktK = []byte("ticketbkt") + // version is the current database version. + versionK = []byte("version") ) -// New initialises and returns a database connection. If no database file is -// found at the provided path, a new one will be created. Returns an open -// database connection which should be closed after use. +// New initialises and returns a database. If no database file is found at the +// provided path, a new one will be created. Returns an open database which +// should always be closed after use. func New(dbFile string) (*VspDatabase, error) { db, err := bolt.Open(dbFile, 0600, &bolt.Options{Timeout: 1 * time.Second}) if err != nil { @@ -40,26 +42,35 @@ func New(dbFile string) (*VspDatabase, error) { return &VspDatabase{db: db}, nil } +// Close releases all database resources. It will block waiting for any open +// transactions to finish before closing the database and returning. +func (vdb *VspDatabase) Close() error { + return vdb.db.Close() +} + // createBuckets creates all storage buckets of the VSP if they don't already // exist. func createBuckets(db *bolt.DB) error { return db.Update(func(tx *bolt.Tx) error { - if tx.Bucket(vspBkt) == nil { - parentBkt, err := tx.CreateBucket(vspBkt) + if tx.Bucket(vspBktK) == nil { + // Create parent bucket. + vspBkt, err := tx.CreateBucket(vspBktK) if err != nil { - return fmt.Errorf("failed to create %s bucket: %v", string(vspBkt), err) + return fmt.Errorf("failed to create %s bucket: %v", string(vspBktK), err) } + // Initialise with database version 1. vbytes := make([]byte, 4) - binary.LittleEndian.PutUint32(vbytes, uint32(version)) - err = parentBkt.Put(versionK, vbytes) + binary.LittleEndian.PutUint32(vbytes, uint32(1)) + err = vspBkt.Put(versionK, vbytes) if err != nil { return err } - _, err = parentBkt.CreateBucket(ticketsBkt) + // Create ticket bucket. + _, err = vspBkt.CreateBucket(ticketBktK) if err != nil { - return fmt.Errorf("failed to create %s bucket: %v", string(ticketsBkt), err) + return fmt.Errorf("failed to create %s bucket: %v", string(ticketBktK), err) } } diff --git a/database/database_test.go b/database/database_test.go new file mode 100644 index 0000000..fed6875 --- /dev/null +++ b/database/database_test.go @@ -0,0 +1,196 @@ +package database + +import ( + "os" + "testing" +) + +var ( + testDb = "test.db" + ticket = Ticket{ + Hash: "Hash", + CommitmentSignature: "CommitmentSignature", + FeeAddress: "FeeAddress", + Address: "Address", + SDiff: 1, + BlockHeight: 2, + VoteBits: 3, + VotingKey: "VotingKey", + } + db *VspDatabase +) + +// TestDatabase runs all database tests. +func TestDatabase(t *testing.T) { + // Ensure we are starting with a clean environment. + os.Remove(testDb) + + // All sub-tests to run. + tests := map[string]func(*testing.T){ + "testInsertFeeAddress": testInsertFeeAddress, + "testGetFeeAddressByTicketHash": testGetFeeAddressByTicketHash, + "testGetFeesByFeeAddress": testGetFeesByFeeAddress, + "testInsertFeeAddressVotingKey": testInsertFeeAddressVotingKey, + "testGetInactiveFeeAddresses": testGetInactiveFeeAddresses, + } + + for testName, test := range tests { + // Create a new blank database for each sub-test. + var err error + db, err = New(testDb) + if err != nil { + t.Fatalf("error creating test database: %v", err) + } + + // Run the sub-test. + t.Run(testName, test) + + // Close and remove test database after each sub-test. + db.Close() + os.Remove(testDb) + } +} + +func testInsertFeeAddress(t *testing.T) { + // Insert a ticket into the database. + err := db.InsertFeeAddress(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Inserting a ticket with the same hash should fail. + err = db.InsertFeeAddress(ticket) + if err == nil { + t.Fatal("expected an error inserting ticket with duplicate hash") + } +} + +func testGetFeeAddressByTicketHash(t *testing.T) { + // Insert a ticket into the database. + err := db.InsertFeeAddress(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Retrieve ticket from database. + retrieved, err := db.GetFeeAddressByTicketHash(ticket.Hash) + if err != nil { + t.Fatalf("error retrieving ticket by ticket hash: %v", err) + } + + // Check ticket fields match expected. + if retrieved.Hash != ticket.Hash || + retrieved.CommitmentSignature != ticket.CommitmentSignature || + retrieved.FeeAddress != ticket.FeeAddress || + retrieved.Address != ticket.Address || + retrieved.SDiff != ticket.SDiff || + retrieved.BlockHeight != ticket.BlockHeight || + retrieved.VoteBits != ticket.VoteBits || + retrieved.VotingKey != ticket.VotingKey { + t.Fatal("retrieved ticket value didnt match expected") + } + + // Error if non-existent ticket requested. + _, err = db.GetFeeAddressByTicketHash("Not a real ticket hash") + if err == nil { + t.Fatal("expected an error while retrieving a non-existent ticket") + } +} + +func testGetFeesByFeeAddress(t *testing.T) { + // Insert a ticket into the database. + err := db.InsertFeeAddress(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Retrieve ticket using its fee address. + retrieved, err := db.GetFeesByFeeAddress(ticket.FeeAddress) + if err != nil { + t.Fatalf("error retrieving ticket by fee address: %v", err) + } + + // Check it is the correct ticket. + if retrieved.FeeAddress != ticket.FeeAddress { + t.Fatal("retrieved ticket FeeAddress didnt match expected") + } + + // Error if non-existent ticket requested. + _, err = db.GetFeesByFeeAddress("Not a real fee address") + if err == nil { + t.Fatal("expected an error while retrieving a non-existent ticket") + } + + // Insert another ticket into the database with the same fee address. + ticket.Hash = ticket.Hash + "2" + err = db.InsertFeeAddress(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Error when more than one ticket matches + _, err = db.GetFeesByFeeAddress(ticket.FeeAddress) + if err == nil { + t.Fatal("expected an error when multiple tickets are found") + } +} + +func testInsertFeeAddressVotingKey(t *testing.T) { + // Insert a ticket into the database. + err := db.InsertFeeAddress(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Update values. + newVotingKey := ticket.VotingKey + "2" + newVoteBits := ticket.VoteBits + 2 + err = db.InsertFeeAddressVotingKey(ticket.Address, newVotingKey, newVoteBits) + if err != nil { + t.Fatalf("error updating votingkey and votebits: %v", err) + } + + // Retrieve ticket from database. + retrieved, err := db.GetFeeAddressByTicketHash(ticket.Hash) + if err != nil { + t.Fatalf("error retrieving ticket by ticket hash: %v", err) + } + + // Check ticket fields match expected. + if newVoteBits != retrieved.VoteBits || + newVotingKey != retrieved.VotingKey { + t.Fatal("retrieved ticket value didnt match expected") + } +} + +func testGetInactiveFeeAddresses(t *testing.T) { + // Insert a ticket into the database. + err := db.InsertFeeAddress(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Insert a ticket with empty voting key into the database. + ticket.Hash = ticket.Hash + "2" + newFeeAddr := ticket.FeeAddress + "2" + ticket.FeeAddress = newFeeAddr + ticket.VotingKey = "" + err = db.InsertFeeAddress(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Retrieve unused fee address from database. + feeAddrs, err := db.GetInactiveFeeAddresses() + if err != nil { + t.Fatalf("error retrieving inactive fee addresses: %v", err) + } + + // Check we have one value, and its the expected one. + if len(feeAddrs) != 1 { + t.Fatal("expected 1 unused fee address") + } + if feeAddrs[0] != newFeeAddr { + t.Fatal("fee address didnt match expected") + } +} diff --git a/database/ticket.go b/database/ticket.go new file mode 100644 index 0000000..59c2eb7 --- /dev/null +++ b/database/ticket.go @@ -0,0 +1,140 @@ +package database + +import ( + "encoding/json" + "fmt" + + bolt "go.etcd.io/bbolt" +) + +type Ticket struct { + Hash string `json:"hash"` + CommitmentSignature string `json:"commitmentsignature"` + FeeAddress string `json:"feeaddress"` + Address string `json:"address"` + SDiff int64 `json:"sdiff"` + BlockHeight int64 `json:"blockheight"` + VoteBits uint16 `json:"votebits"` + VotingKey string `json:"votingkey"` +} + +func (vdb *VspDatabase) InsertFeeAddress(ticket Ticket) error { + return vdb.db.Update(func(tx *bolt.Tx) error { + ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) + + if ticketBkt.Get([]byte(ticket.Hash)) != nil { + return fmt.Errorf("ticket already exists with hash %s", ticket.Hash) + } + + ticketBytes, err := json.Marshal(ticket) + if err != nil { + return err + } + + return ticketBkt.Put([]byte(ticket.Hash), ticketBytes) + }) +} + +func (vdb *VspDatabase) InsertFeeAddressVotingKey(address, votingKey string, voteBits uint16) error { + return vdb.db.Update(func(tx *bolt.Tx) error { + ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) + c := ticketBkt.Cursor() + + for k, v := c.First(); k != nil; k, v = c.Next() { + var ticket Ticket + err := json.Unmarshal(v, &ticket) + if err != nil { + return fmt.Errorf("could not unmarshal ticket: %v", err) + } + + if ticket.Address == address { + ticket.VotingKey = votingKey + ticket.VoteBits = voteBits + ticketBytes, err := json.Marshal(ticket) + if err != nil { + return err + } + ticketBkt.Put(k, ticketBytes) + } + } + + return nil + }) +} + +func (vdb *VspDatabase) GetInactiveFeeAddresses() ([]string, error) { + var addrs []string + err := vdb.db.View(func(tx *bolt.Tx) error { + ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) + c := ticketBkt.Cursor() + + for k, v := c.First(); k != nil; k, v = c.Next() { + var ticket Ticket + err := json.Unmarshal(v, &ticket) + if err != nil { + return fmt.Errorf("could not unmarshal ticket: %v", err) + } + + if ticket.VotingKey == "" { + addrs = append(addrs, ticket.FeeAddress) + } + } + + return nil + }) + + return addrs, err +} + +func (vdb *VspDatabase) GetFeesByFeeAddress(feeAddr string) (*Ticket, error) { + var tickets []Ticket + err := vdb.db.View(func(tx *bolt.Tx) error { + ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) + c := ticketBkt.Cursor() + + for k, v := c.First(); k != nil; k, v = c.Next() { + var ticket Ticket + err := json.Unmarshal(v, &ticket) + if err != nil { + return fmt.Errorf("could not unmarshal ticket: %v", err) + } + + if ticket.FeeAddress == feeAddr { + tickets = append(tickets, ticket) + } + } + + return nil + }) + + if err != nil { + return nil, err + } + + if len(tickets) != 1 { + return nil, fmt.Errorf("expected 1 ticket with fee address %s, found %d", feeAddr, len(tickets)) + } + + return &tickets[0], nil +} + +func (vdb *VspDatabase) GetFeeAddressByTicketHash(ticketHash string) (Ticket, error) { + var ticket Ticket + err := vdb.db.View(func(tx *bolt.Tx) error { + ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) + + ticketBytes := ticketBkt.Get([]byte(ticketHash)) + if ticketBytes == nil { + return fmt.Errorf("no ticket found with hash %s", ticketHash) + } + + err := json.Unmarshal(ticketBytes, &ticket) + if err != nil { + return fmt.Errorf("could not unmarshal ticket: %v", err) + } + + return nil + }) + + return ticket, err +} diff --git a/database/tickets.go b/database/tickets.go deleted file mode 100644 index 26d0d95..0000000 --- a/database/tickets.go +++ /dev/null @@ -1,32 +0,0 @@ -package database - -type Ticket struct { - Hash string - CommitmentSignature string - FeeAddress string - Address string - SDiff int64 - BlockHeight int64 - VoteBits uint16 - VotingKey string -} - -func (db *VspDatabase) InsertFeeAddressVotingKey(address, votingKey string, voteBits uint16) error { - return nil -} - -func (db *VspDatabase) InsertFeeAddress(t Ticket) error { - return nil -} - -func (db *VspDatabase) GetInactiveFeeAddresses() ([]string, error) { - return []string{""}, nil -} - -func (db *VspDatabase) GetFeesByFeeAddress(feeAddr string) (Ticket, error) { - return Ticket{}, nil -} - -func (db *VspDatabase) GetFeeAddressByTicketHash() (Ticket, error) { - return Ticket{}, nil -} diff --git a/main.go b/main.go index 4292008..eec830a 100644 --- a/main.go +++ b/main.go @@ -70,7 +70,6 @@ func initConfig() (*Config, error) { } func main() { - cfg, err := initConfig() if err != nil { log.Fatalf("config error: %v", err) @@ -81,6 +80,8 @@ func main() { log.Fatalf("database error: %v", err) } + defer db.Close() + // Start HTTP server log.Printf("Listening on %s", listen) log.Print(newRouter().Run(listen))