Clean up duplicate database code and add tests.

This commit is contained in:
jholdstock 2020-06-11 10:48:00 +01:00 committed by David Hill
parent 78692fea88
commit a95b214b3f
3 changed files with 82 additions and 66 deletions

View File

@ -26,6 +26,7 @@ func TestDatabase(t *testing.T) {
"testGetTicketByHash": testGetTicketByHash, "testGetTicketByHash": testGetTicketByHash,
"testUpdateTicket": testUpdateTicket, "testUpdateTicket": testUpdateTicket,
"testTicketFeeExpired": testTicketFeeExpired, "testTicketFeeExpired": testTicketFeeExpired,
"testFilterTickets": testFilterTickets,
"testAddressIndex": testAddressIndex, "testAddressIndex": testAddressIndex,
} }
@ -55,5 +56,4 @@ func TestDatabase(t *testing.T) {
} }
} }
// TODO: Add tests for CountTickets, GetUnconfirmedTickets, GetPendingFees, // TODO: Add tests for CountTickets.
// GetUnconfirmedFees, GetAllTickets.

View File

@ -170,56 +170,39 @@ func (vdb *VspDatabase) CountTickets() (int, int, error) {
return total, feePaid, err return total, feePaid, err
} }
// GetUnconfirmedTickets returns tickets which are not yet confirmed.
func (vdb *VspDatabase) GetUnconfirmedTickets() ([]Ticket, error) { func (vdb *VspDatabase) GetUnconfirmedTickets() ([]Ticket, error) {
var tickets []Ticket return vdb.filterTickets(func(t Ticket) bool {
err := vdb.db.View(func(tx *bolt.Tx) error { return !t.Confirmed
ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK)
return ticketBkt.ForEach(func(k, v []byte) error {
var ticket Ticket
err := json.Unmarshal(v, &ticket)
if err != nil {
return fmt.Errorf("could not unmarshal ticket: %v", err)
}
if !ticket.Confirmed {
tickets = append(tickets, ticket)
}
return nil
})
}) })
return tickets, err
} }
// GetPendingFees returns tickets which are confirmed and have a fee tx which is
// not yet broadcast.
func (vdb *VspDatabase) GetPendingFees() ([]Ticket, error) { func (vdb *VspDatabase) GetPendingFees() ([]Ticket, error) {
var tickets []Ticket return vdb.filterTickets(func(t Ticket) bool {
err := vdb.db.View(func(tx *bolt.Tx) error { return t.Confirmed && t.FeeTxStatus == FeeReceieved
ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK)
return ticketBkt.ForEach(func(k, v []byte) error {
var ticket Ticket
err := json.Unmarshal(v, &ticket)
if err != nil {
return fmt.Errorf("could not unmarshal ticket: %v", err)
}
// Add ticket if it is confirmed, and we have a fee tx which is not
// yet broadcast.
if ticket.Confirmed &&
ticket.FeeTxStatus == FeeReceieved {
tickets = append(tickets, ticket)
}
return nil
})
}) })
return tickets, err
} }
// GetUnconfirmedFees returns tickets with a fee tx that is broadcast but not
// confirmed yet.
func (vdb *VspDatabase) GetUnconfirmedFees() ([]Ticket, error) { func (vdb *VspDatabase) GetUnconfirmedFees() ([]Ticket, error) {
return vdb.filterTickets(func(t Ticket) bool {
return t.FeeTxStatus == FeeBroadcast
})
}
// GetAllTickets returns all tickets in the database.
func (vdb *VspDatabase) GetAllTickets() ([]Ticket, error) {
return vdb.filterTickets(func(t Ticket) bool {
return true
})
}
// filterTickets accepts a filter function and returns all tickets from the
// database which match the filter.
func (vdb *VspDatabase) filterTickets(filter func(Ticket) bool) ([]Ticket, error) {
var tickets []Ticket var tickets []Ticket
err := vdb.db.View(func(tx *bolt.Tx) error { err := vdb.db.View(func(tx *bolt.Tx) error {
ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK)
@ -231,8 +214,7 @@ func (vdb *VspDatabase) GetUnconfirmedFees() ([]Ticket, error) {
return fmt.Errorf("could not unmarshal ticket: %v", err) return fmt.Errorf("could not unmarshal ticket: %v", err)
} }
// Add ticket if fee tx is broadcast but not confirmed yet. if filter(ticket) {
if ticket.FeeTxStatus == FeeBroadcast {
tickets = append(tickets, ticket) tickets = append(tickets, ticket)
} }
@ -242,24 +224,3 @@ func (vdb *VspDatabase) GetUnconfirmedFees() ([]Ticket, error) {
return tickets, err return tickets, err
} }
func (vdb *VspDatabase) GetAllTickets() ([]Ticket, error) {
var tickets []Ticket
err := vdb.db.View(func(tx *bolt.Tx) error {
ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK)
return ticketBkt.ForEach(func(k, v []byte) error {
var ticket Ticket
err := json.Unmarshal(v, &ticket)
if err != nil {
return fmt.Errorf("could not unmarshal ticket: %v", err)
}
tickets = append(tickets, ticket)
return nil
})
})
return tickets, err
}

View File

@ -153,3 +153,58 @@ func testTicketFeeExpired(t *testing.T) {
t.Fatal("expected ticket to be expired") t.Fatal("expected ticket to be expired")
} }
} }
func testFilterTickets(t *testing.T) {
// Insert a ticket
ticket := exampleTicket()
err := db.InsertNewTicket(ticket)
if err != nil {
t.Fatalf("error storing ticket in database: %v", err)
}
// Insert another ticket
ticket.Hash = ticket.Hash + "1"
ticket.FeeAddress = ticket.FeeAddress + "1"
ticket.FeeAddressIndex = ticket.FeeAddressIndex + 1
ticket.Confirmed = !ticket.Confirmed
err = db.InsertNewTicket(ticket)
if err != nil {
t.Fatalf("error storing ticket in database: %v", err)
}
// Expect all tickets returned.
retrieved, err := db.filterTickets(func(t Ticket) bool {
return true
})
if err != nil {
t.Fatalf("error filtering tickets: %v", err)
}
if len(retrieved) != 2 {
t.Fatal("expected to find 2 tickets")
}
// Only one ticket should be confirmed.
retrieved, err = db.filterTickets(func(t Ticket) bool {
return t.Confirmed
})
if err != nil {
t.Fatalf("error filtering tickets: %v", err)
}
if len(retrieved) != 1 {
t.Fatal("expected to find 2 tickets")
}
if retrieved[0].Confirmed != true {
t.Fatal("expected retrieved ticket to be confirmed")
}
// Expect no tickets with confirmed fee.
retrieved, err = db.filterTickets(func(t Ticket) bool {
return t.FeeTxStatus == FeeConfirmed
})
if err != nil {
t.Fatalf("error filtering tickets: %v", err)
}
if len(retrieved) != 0 {
t.Fatal("expected to find 0 tickets")
}
}