v3tool: Use shutdown context.

Using the shutdown context provided by the internal signal package means
v3tool can exit cleanly if shutdown is requested by the user.
This commit is contained in:
jholdstock 2023-09-18 19:27:01 +01:00 committed by Jamie Holdstock
parent 9660de7d9f
commit d1eddafb52
2 changed files with 66 additions and 30 deletions

View File

@ -28,24 +28,24 @@ type dcrwallet struct {
*wsrpc.Client
}
func newWalletRPC(rpcURL, rpcUser, rpcPass string) (*dcrwallet, error) {
func newWalletRPC(ctx context.Context, rpcURL, rpcUser, rpcPass string) (*dcrwallet, error) {
tlsOpt := wsrpc.WithTLSConfig(&tls.Config{
InsecureSkipVerify: true,
})
authOpt := wsrpc.WithBasicAuth(rpcUser, rpcPass)
rpc, err := wsrpc.Dial(context.TODO(), rpcURL, tlsOpt, authOpt)
rpc, err := wsrpc.Dial(ctx, rpcURL, tlsOpt, authOpt)
if err != nil {
return nil, err
}
return &dcrwallet{rpc}, nil
}
func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
func (w *dcrwallet) createFeeTx(ctx context.Context, feeAddress string, fee int64) (string, error) {
amounts := make(map[string]float64)
amounts[feeAddress] = dcrutil.Amount(fee).ToCoin()
var msgtxstr string
err := w.Call(context.TODO(), "createrawtransaction", &msgtxstr, nil, amounts)
err := w.Call(ctx, "createrawtransaction", &msgtxstr, nil, amounts)
if err != nil {
return "", err
}
@ -55,7 +55,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
ConfTarget: &zero,
}
var fundTx wallettypes.FundRawTransactionResult
err = w.Call(context.TODO(), "fundrawtransaction", &fundTx, msgtxstr, "default", &opt)
err = w.Call(ctx, "fundrawtransaction", &fundTx, msgtxstr, "default", &opt)
if err != nil {
return "", err
}
@ -77,7 +77,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
var locked bool
const unlock = false
err = w.Call(context.TODO(), "lockunspent", &locked, unlock, transactions)
err = w.Call(ctx, "lockunspent", &locked, unlock, transactions)
if err != nil {
return "", err
}
@ -87,7 +87,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
}
var signedTx wallettypes.SignRawTransactionResult
err = w.Call(context.TODO(), "signrawtransaction", &signedTx, fundTx.Hex)
err = w.Call(ctx, "signrawtransaction", &signedTx, fundTx.Hex)
if err != nil {
return "", err
}
@ -97,9 +97,9 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
return signedTx.Hex, nil
}
func (w *dcrwallet) SignMessage(_ context.Context, msg string, commitmentAddr stdaddr.Address) ([]byte, error) {
func (w *dcrwallet) SignMessage(ctx context.Context, msg string, commitmentAddr stdaddr.Address) ([]byte, error) {
var signature string
err := w.Call(context.TODO(), "signmessage", &signature, commitmentAddr.String(), msg)
err := w.Call(ctx, "signmessage", &signature, commitmentAddr.String(), msg)
if err != nil {
return nil, err
}
@ -107,19 +107,19 @@ func (w *dcrwallet) SignMessage(_ context.Context, msg string, commitmentAddr st
return base64.StdEncoding.DecodeString(signature)
}
func (w *dcrwallet) dumpPrivKey(addr stdaddr.Address) (string, error) {
func (w *dcrwallet) dumpPrivKey(ctx context.Context, addr stdaddr.Address) (string, error) {
var privKeyStr string
err := w.Call(context.TODO(), "dumpprivkey", &privKeyStr, addr.String())
err := w.Call(ctx, "dumpprivkey", &privKeyStr, addr.String())
if err != nil {
return "", err
}
return privKeyStr, nil
}
func (w *dcrwallet) getTickets() (*wallettypes.GetTicketsResult, error) {
func (w *dcrwallet) getTickets(ctx context.Context) (*wallettypes.GetTicketsResult, error) {
var tickets wallettypes.GetTicketsResult
const includeImmature = true
err := w.Call(context.TODO(), "gettickets", &tickets, includeImmature)
err := w.Call(ctx, "gettickets", &tickets, includeImmature)
if err != nil {
return nil, err
}
@ -128,9 +128,9 @@ func (w *dcrwallet) getTickets() (*wallettypes.GetTicketsResult, error) {
// getTicketDetails returns the ticket hex, privkey for voting, and the
// commitment address.
func (w *dcrwallet) getTicketDetails(ticketHash string) (string, string, stdaddr.Address, error) {
func (w *dcrwallet) getTicketDetails(ctx context.Context, ticketHash string) (string, string, stdaddr.Address, error) {
var getTransactionResult wallettypes.GetTransactionResult
err := w.Call(context.TODO(), "gettransaction", &getTransactionResult, ticketHash, false)
err := w.Call(ctx, "gettransaction", &getTransactionResult, ticketHash, false)
if err != nil {
fmt.Printf("gettransaction: %v\n", err)
return "", "", nil, err
@ -160,7 +160,7 @@ func (w *dcrwallet) getTicketDetails(ticketHash string) (string, string, stdaddr
return "", "", nil, err
}
privKeyStr, err := w.dumpPrivKey(submissionAddr[0])
privKeyStr, err := w.dumpPrivKey(ctx, submissionAddr[0])
if err != nil {
return "", "", nil, err
}

View File

@ -7,6 +7,7 @@ package main
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"os"
@ -14,6 +15,7 @@ import (
"github.com/decred/slog"
"github.com/decred/vspd/client/v2"
"github.com/decred/vspd/internal/signal"
"github.com/decred/vspd/types/v2"
)
@ -55,8 +57,13 @@ func run() int {
log := slog.NewBackend(os.Stdout).Logger("")
log.SetLevel(slog.LevelTrace)
walletRPC, err := newWalletRPC(rpcURL, rpcUser, rpcPass)
ctx := signal.ShutdownListener(log)
walletRPC, err := newWalletRPC(ctx, rpcURL, rpcUser, rpcPass)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("%v", err)
return 1
}
@ -85,8 +92,11 @@ func run() int {
}
// Get list of tickets
tickets, err := walletRPC.getTickets()
tickets, err := walletRPC.getTickets(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("%v", err)
return 1
}
@ -102,9 +112,17 @@ func run() int {
}
for i := 0; i < len(tickets.Hashes); i++ {
// Stop if shutdown requested.
if ctx.Err() != nil {
return 0
}
ticketHash := tickets.Hashes[i]
hex, privKeyStr, commitmentAddr, err := walletRPC.getTicketDetails(ticketHash)
hex, privKeyStr, commitmentAddr, err := walletRPC.getTicketDetails(ctx, ticketHash)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("%v", err)
return 1
}
@ -125,19 +143,25 @@ func run() int {
Timestamp: time.Now().Unix(),
}
feeAddrResp, err := vClient.FeeAddress(context.TODO(), feeAddrReq, commitmentAddr)
feeAddrResp, err := vClient.FeeAddress(ctx, feeAddrReq, commitmentAddr)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("getFeeAddress error: %v", err)
break
return 1
}
log.Infof("feeAddress: %v", feeAddrResp.FeeAddress)
log.Infof("privKeyStr: %v", privKeyStr)
feeTx, err := walletRPC.createFeeTx(feeAddrResp.FeeAddress, feeAddrResp.FeeAmount)
feeTx, err := walletRPC.createFeeTx(ctx, feeAddrResp.FeeAddress, feeAddrResp.FeeAmount)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("createFeeTx error: %v", err)
break
return 1
}
voteChoices := map[string]string{"autorevocations": "no"}
@ -160,8 +184,11 @@ func run() int {
TreasuryPolicy: treasury,
}
_, err = vClient.PayFee(context.TODO(), payFeeReq, commitmentAddr)
_, err = vClient.PayFee(ctx, payFeeReq, commitmentAddr)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("payFee error: %v", err)
continue
}
@ -170,10 +197,13 @@ func run() int {
TicketHash: ticketHash,
}
_, err = vClient.TicketStatus(context.TODO(), ticketStatusReq, commitmentAddr)
_, err = vClient.TicketStatus(ctx, ticketStatusReq, commitmentAddr)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("getTicketStatus error: %v", err)
break
return 1
}
voteChoices["autorevocations"] = "yes"
@ -189,16 +219,22 @@ func run() int {
TreasuryPolicy: treasury,
}
_, err = vClient.SetVoteChoices(context.TODO(), voteChoiceReq, commitmentAddr)
_, err = vClient.SetVoteChoices(ctx, voteChoiceReq, commitmentAddr)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("setVoteChoices error: %v", err)
break
return 1
}
_, err = vClient.TicketStatus(context.TODO(), ticketStatusReq, commitmentAddr)
_, err = vClient.TicketStatus(ctx, ticketStatusReq, commitmentAddr)
if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("getTicketStatus error: %v", err)
break
return 1
}
time.Sleep(1 * time.Second)