From a95b214b3f7c65430ce3fd3c6e8f18c5e86b0fa3 Mon Sep 17 00:00:00 2001 From: jholdstock Date: Thu, 11 Jun 2020 10:48:00 +0100 Subject: [PATCH] Clean up duplicate database code and add tests. --- database/database_test.go | 4 +- database/ticket.go | 89 +++++++++++---------------------------- database/ticket_test.go | 55 ++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 66 deletions(-) diff --git a/database/database_test.go b/database/database_test.go index 367d8d2..c85e322 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -26,6 +26,7 @@ func TestDatabase(t *testing.T) { "testGetTicketByHash": testGetTicketByHash, "testUpdateTicket": testUpdateTicket, "testTicketFeeExpired": testTicketFeeExpired, + "testFilterTickets": testFilterTickets, "testAddressIndex": testAddressIndex, } @@ -55,5 +56,4 @@ func TestDatabase(t *testing.T) { } } -// TODO: Add tests for CountTickets, GetUnconfirmedTickets, GetPendingFees, -// GetUnconfirmedFees, GetAllTickets. +// TODO: Add tests for CountTickets. diff --git a/database/ticket.go b/database/ticket.go index 66fb16f..8937fe7 100644 --- a/database/ticket.go +++ b/database/ticket.go @@ -170,56 +170,39 @@ func (vdb *VspDatabase) CountTickets() (int, int, error) { return total, feePaid, err } +// GetUnconfirmedTickets returns tickets which are not yet confirmed. func (vdb *VspDatabase) GetUnconfirmedTickets() ([]Ticket, error) { - var tickets []Ticket - err := vdb.db.View(func(tx *bolt.Tx) error { - ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) - - return ticketBkt.ForEach(func(k, v []byte) error { - var ticket Ticket - err := json.Unmarshal(v, &ticket) - if err != nil { - return fmt.Errorf("could not unmarshal ticket: %v", err) - } - - if !ticket.Confirmed { - tickets = append(tickets, ticket) - } - - return nil - }) + return vdb.filterTickets(func(t Ticket) bool { + return !t.Confirmed }) - - return tickets, err } +// GetPendingFees returns tickets which are confirmed and have a fee tx which is +// not yet broadcast. func (vdb *VspDatabase) GetPendingFees() ([]Ticket, error) { - var tickets []Ticket - err := vdb.db.View(func(tx *bolt.Tx) error { - ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) - - return ticketBkt.ForEach(func(k, v []byte) error { - var ticket Ticket - err := json.Unmarshal(v, &ticket) - if err != nil { - return fmt.Errorf("could not unmarshal ticket: %v", err) - } - - // Add ticket if it is confirmed, and we have a fee tx which is not - // yet broadcast. - if ticket.Confirmed && - ticket.FeeTxStatus == FeeReceieved { - tickets = append(tickets, ticket) - } - - return nil - }) + return vdb.filterTickets(func(t Ticket) bool { + return t.Confirmed && t.FeeTxStatus == FeeReceieved }) - - return tickets, err } +// GetUnconfirmedFees returns tickets with a fee tx that is broadcast but not +// confirmed yet. func (vdb *VspDatabase) GetUnconfirmedFees() ([]Ticket, error) { + return vdb.filterTickets(func(t Ticket) bool { + return t.FeeTxStatus == FeeBroadcast + }) +} + +// GetAllTickets returns all tickets in the database. +func (vdb *VspDatabase) GetAllTickets() ([]Ticket, error) { + return vdb.filterTickets(func(t Ticket) bool { + return true + }) +} + +// filterTickets accepts a filter function and returns all tickets from the +// database which match the filter. +func (vdb *VspDatabase) filterTickets(filter func(Ticket) bool) ([]Ticket, error) { var tickets []Ticket err := vdb.db.View(func(tx *bolt.Tx) error { ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) @@ -231,8 +214,7 @@ func (vdb *VspDatabase) GetUnconfirmedFees() ([]Ticket, error) { return fmt.Errorf("could not unmarshal ticket: %v", err) } - // Add ticket if fee tx is broadcast but not confirmed yet. - if ticket.FeeTxStatus == FeeBroadcast { + if filter(ticket) { tickets = append(tickets, ticket) } @@ -242,24 +224,3 @@ func (vdb *VspDatabase) GetUnconfirmedFees() ([]Ticket, error) { return tickets, err } - -func (vdb *VspDatabase) GetAllTickets() ([]Ticket, error) { - var tickets []Ticket - err := vdb.db.View(func(tx *bolt.Tx) error { - ticketBkt := tx.Bucket(vspBktK).Bucket(ticketBktK) - - return ticketBkt.ForEach(func(k, v []byte) error { - var ticket Ticket - err := json.Unmarshal(v, &ticket) - if err != nil { - return fmt.Errorf("could not unmarshal ticket: %v", err) - } - - tickets = append(tickets, ticket) - - return nil - }) - }) - - return tickets, err -} diff --git a/database/ticket_test.go b/database/ticket_test.go index f8c99c4..18fc543 100644 --- a/database/ticket_test.go +++ b/database/ticket_test.go @@ -153,3 +153,58 @@ func testTicketFeeExpired(t *testing.T) { t.Fatal("expected ticket to be expired") } } + +func testFilterTickets(t *testing.T) { + // Insert a ticket + ticket := exampleTicket() + err := db.InsertNewTicket(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Insert another ticket + ticket.Hash = ticket.Hash + "1" + ticket.FeeAddress = ticket.FeeAddress + "1" + ticket.FeeAddressIndex = ticket.FeeAddressIndex + 1 + ticket.Confirmed = !ticket.Confirmed + err = db.InsertNewTicket(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + // Expect all tickets returned. + retrieved, err := db.filterTickets(func(t Ticket) bool { + return true + }) + if err != nil { + t.Fatalf("error filtering tickets: %v", err) + } + if len(retrieved) != 2 { + t.Fatal("expected to find 2 tickets") + } + + // Only one ticket should be confirmed. + retrieved, err = db.filterTickets(func(t Ticket) bool { + return t.Confirmed + }) + if err != nil { + t.Fatalf("error filtering tickets: %v", err) + } + if len(retrieved) != 1 { + t.Fatal("expected to find 2 tickets") + } + if retrieved[0].Confirmed != true { + t.Fatal("expected retrieved ticket to be confirmed") + } + + // Expect no tickets with confirmed fee. + retrieved, err = db.filterTickets(func(t Ticket) bool { + return t.FeeTxStatus == FeeConfirmed + }) + if err != nil { + t.Fatalf("error filtering tickets: %v", err) + } + if len(retrieved) != 0 { + t.Fatal("expected to find 0 tickets") + } +}