diff --git a/database/database_test.go b/database/database_test.go index 556009e..ad58e92 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -6,19 +6,28 @@ package database import ( "context" + "crypto/ed25519" + "io/ioutil" + "net/http" + "net/http/httptest" "os" + "strconv" "sync" "testing" "time" ) -var ( +const ( testDb = "test.db" backupDb = "test.db-backup" - db *VspDatabase + feeXPub = "feexpub" maxVoteChangeRecords = 3 ) +var ( + db *VspDatabase +) + // TestDatabase runs all database tests. func TestDatabase(t *testing.T) { // Ensure we are starting with a clean environment. @@ -27,14 +36,17 @@ func TestDatabase(t *testing.T) { // All sub-tests to run. tests := map[string]func(*testing.T){ + "testCreateNew": testCreateNew, "testInsertNewTicket": testInsertNewTicket, "testGetTicketByHash": testGetTicketByHash, "testUpdateTicket": testUpdateTicket, "testTicketFeeExpired": testTicketFeeExpired, "testFilterTickets": testFilterTickets, + "testCountTickets": testCountTickets, "testAddressIndex": testAddressIndex, "testDeleteTicket": testDeleteTicket, "testVoteChangeRecords": testVoteChangeRecords, + "testHTTPBackup": testHTTPBackup, } for testName, test := range tests { @@ -42,7 +54,7 @@ func TestDatabase(t *testing.T) { var err error var wg sync.WaitGroup ctx, cancel := context.WithCancel(context.TODO()) - err = CreateNew(testDb, "feexpub") + err = CreateNew(testDb, feeXPub) if err != nil { t.Fatalf("error creating test database: %v", err) } @@ -63,4 +75,88 @@ func TestDatabase(t *testing.T) { } } -// TODO: Add tests for CountTickets. +func testCreateNew(t *testing.T) { + // A newly created DB should contain a signing keypair. + priv, pub, err := db.KeyPair() + if err != nil { + t.Fatalf("error getting keypair: %v", err) + } + + // Ensure keypair can be used for signing/verifying messages. + msg := []byte("msg") + sig := ed25519.Sign(priv, msg) + if !ed25519.Verify(pub, msg, sig) { + t.Fatalf("keypair from database could not be used to sign/verify a message") + } + + // A newly created DB should have a cookie secret. + secret, err := db.GetCookieSecret() + if err != nil { + t.Fatalf("error getting cookie secret: %v", err) + } + + if len(secret) != 32 { + t.Fatalf("expected a 32 byte cookie secret, got %d bytes", len(secret)) + } + + // A newly created DB should store the fee xpub it was initialized with. + retrievedXPub, err := db.GetFeeXPub() + if err != nil { + t.Fatalf("error getting fee xpub: %v", err) + } + + if retrievedXPub != feeXPub { + t.Fatalf("expected fee xpub %v, got %v", feeXPub, retrievedXPub) + } +} + +func testHTTPBackup(t *testing.T) { + // Capture the HTTP response written by the backup func. + rr := httptest.NewRecorder() + err := db.BackupDB(rr) + if err != nil { + t.Fatal(err) + } + + // Check the HTTP status is OK. + if status := rr.Code; status != http.StatusOK { + t.Fatalf("wrong HTTP status code: expected %v, got %v", + http.StatusOK, status) + } + + // Check HTTP headers. + header := "Content-Type" + expected := "application/octet-stream" + if actual := rr.Header().Get(header); actual != expected { + t.Errorf("wrong %s header: expected %s, got %s", + header, expected, actual) + } + + header = "Content-Disposition" + expected = `attachment; filename="vspd.db"` + if actual := rr.Header().Get(header); actual != expected { + t.Errorf("wrong %s header: expected %s, got %s", + header, expected, actual) + } + + header = "Content-Length" + cLength, err := strconv.Atoi(rr.Header().Get(header)) + if err != nil { + t.Fatalf("could not convert %s to integer: %v", header, err) + } + + if cLength <= 0 { + t.Fatalf("expected a %s greater than zero, got %d", header, cLength) + } + + // Check reported length matches actual. + body, err := ioutil.ReadAll(rr.Result().Body) + if err != nil { + t.Fatalf("could not read http response body: %v", err) + } + + if len(body) != cLength { + t.Fatalf("expected reported content-length to match actual body length. %v != %v", + cLength, len(body)) + } +} diff --git a/database/ticket_test.go b/database/ticket_test.go index 4517872..8f98ede 100644 --- a/database/ticket_test.go +++ b/database/ticket_test.go @@ -236,3 +236,79 @@ func testFilterTickets(t *testing.T) { t.Fatal("expected to find 0 tickets") } } + +func testCountTickets(t *testing.T) { + count := func(test string, expectedVoting, expectedVoted, expectedRevoked int64) { + voting, voted, revoked, err := db.CountTickets() + if err != nil { + t.Fatalf("error counting tickets: %v", err) + } + + if voting != expectedVoting { + t.Fatalf("test %s: expected %d voting tickets, got %d", + test, expectedVoting, voting) + } + if voted != expectedVoted { + t.Fatalf("test %s: expected %d voted tickets, got %d", + test, expectedVoted, voted) + } + if revoked != expectedRevoked { + t.Fatalf("test %s: expected %d revoked tickets, got %d", + test, expectedRevoked, revoked) + } + } + + // Initial counts should all be zero. + count("empty db", 0, 0, 0) + + // Insert a ticket with non-confirmed fee into the database. + // This should not be counted. + ticket := exampleTicket() + ticket.FeeTxStatus = FeeReceieved + err := db.InsertNewTicket(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + count("unconfirmed fee", 0, 0, 0) + + // Insert a ticket with confirmed fee into the database. + // This should be counted. + ticket.Hash = ticket.Hash + "1" + ticket.FeeAddress = ticket.FeeAddress + "1" + ticket.FeeAddressIndex = ticket.FeeAddressIndex + 1 + ticket.FeeTxStatus = FeeConfirmed + err = db.InsertNewTicket(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + count("confirmed fee", 1, 0, 0) + + // Insert a voted ticket into the database. + // This should be counted. + ticket.Hash = ticket.Hash + "1" + ticket.FeeAddress = ticket.FeeAddress + "1" + ticket.FeeAddressIndex = ticket.FeeAddressIndex + 1 + ticket.FeeTxStatus = FeeConfirmed + ticket.Outcome = Voted + err = db.InsertNewTicket(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + count("voted", 1, 1, 0) + + // Insert a revoked ticket into the database. + ticket.Hash = ticket.Hash + "1" + ticket.FeeAddress = ticket.FeeAddress + "1" + ticket.FeeAddressIndex = ticket.FeeAddressIndex + 1 + ticket.FeeTxStatus = FeeConfirmed + ticket.Outcome = Revoked + err = db.InsertNewTicket(ticket) + if err != nil { + t.Fatalf("error storing ticket in database: %v", err) + } + + count("revoked", 1, 1, 1) +}