diff --git a/cmd/vote-validator/dcrdata.go b/cmd/vote-validator/dcrdata.go index 60f6847..77833e4 100644 --- a/cmd/vote-validator/dcrdata.go +++ b/cmd/vote-validator/dcrdata.go @@ -6,6 +6,7 @@ package main import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -46,7 +47,7 @@ type dcrdataClient struct { URL string } -func (d *dcrdataClient) txns(txnHashes []string, spends bool) ([]tx, error) { +func (d *dcrdataClient) txns(ctx context.Context, txnHashes []string, spends bool) ([]tx, error) { jsonData, err := json.Marshal(txns{ Transactions: txnHashes, }) @@ -55,7 +56,7 @@ func (d *dcrdataClient) txns(txnHashes []string, spends bool) ([]tx, error) { } url := fmt.Sprintf("%s/api/txs?spends=%t", d.URL, spends) - request, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + request, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) if err != nil { return nil, err } diff --git a/cmd/vote-validator/main.go b/cmd/vote-validator/main.go index d1d09e9..2854fb3 100644 --- a/cmd/vote-validator/main.go +++ b/cmd/vote-validator/main.go @@ -5,6 +5,7 @@ package main import ( + "context" "encoding/binary" "encoding/hex" "errors" @@ -16,6 +17,7 @@ import ( "github.com/decred/vspd/database" "github.com/decred/vspd/internal/config" + "github.com/decred/vspd/internal/signal" ) const ( @@ -79,6 +81,8 @@ func run() int { return 1 } + ctx := signal.ShutdownListener(log) + // Get all voted tickets from database. dbTickets, err := vdb.GetVotedTickets() if err != nil { @@ -105,14 +109,22 @@ func run() int { // Use dcrdata to get spender info for voted tickets (dcrd can't do this). log.Infof("Getting vote info from %s", dcrdata.URL) for i := 0; i < numTickets; i += chunkSize { + // Stop if shutdown requested. + if ctx.Err() != nil { + return 0 + } + end := i + chunkSize if end > numTickets { end = numTickets } // Get the tx info for each ticket. - ticketTxns, err := dcrdata.txns(ticketHashes[i:end], true) + ticketTxns, err := dcrdata.txns(ctx, ticketHashes[i:end], true) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Error(err) return 1 } @@ -125,8 +137,11 @@ func run() int { mapSpenderToTicket[spenderHash] = txn.TxID } - spenderTxns, err := dcrdata.txns(spenderHashes, false) + spenderTxns, err := dcrdata.txns(ctx, spenderHashes, false) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Error(err) return 1 } @@ -193,6 +208,10 @@ func run() int { } for _, t := range sorted[0:cfg.ToCheck] { + // Stop if shutdown requested. + if ctx.Err() != nil { + return 0 + } if t.voteVersion != latestVoteVersion { results.wrongVersion = append(results.wrongVersion, t)