v3tool: Use shutdown context.

Using the shutdown context provided by the internal signal package means
v3tool can exit cleanly if shutdown is requested by the user.
This commit is contained in:
jholdstock 2023-09-18 19:27:01 +01:00 committed by Jamie Holdstock
parent 9660de7d9f
commit d1eddafb52
2 changed files with 66 additions and 30 deletions

View File

@ -28,24 +28,24 @@ type dcrwallet struct {
*wsrpc.Client *wsrpc.Client
} }
func newWalletRPC(rpcURL, rpcUser, rpcPass string) (*dcrwallet, error) { func newWalletRPC(ctx context.Context, rpcURL, rpcUser, rpcPass string) (*dcrwallet, error) {
tlsOpt := wsrpc.WithTLSConfig(&tls.Config{ tlsOpt := wsrpc.WithTLSConfig(&tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
}) })
authOpt := wsrpc.WithBasicAuth(rpcUser, rpcPass) authOpt := wsrpc.WithBasicAuth(rpcUser, rpcPass)
rpc, err := wsrpc.Dial(context.TODO(), rpcURL, tlsOpt, authOpt) rpc, err := wsrpc.Dial(ctx, rpcURL, tlsOpt, authOpt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &dcrwallet{rpc}, nil return &dcrwallet{rpc}, nil
} }
func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) { func (w *dcrwallet) createFeeTx(ctx context.Context, feeAddress string, fee int64) (string, error) {
amounts := make(map[string]float64) amounts := make(map[string]float64)
amounts[feeAddress] = dcrutil.Amount(fee).ToCoin() amounts[feeAddress] = dcrutil.Amount(fee).ToCoin()
var msgtxstr string var msgtxstr string
err := w.Call(context.TODO(), "createrawtransaction", &msgtxstr, nil, amounts) err := w.Call(ctx, "createrawtransaction", &msgtxstr, nil, amounts)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -55,7 +55,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
ConfTarget: &zero, ConfTarget: &zero,
} }
var fundTx wallettypes.FundRawTransactionResult var fundTx wallettypes.FundRawTransactionResult
err = w.Call(context.TODO(), "fundrawtransaction", &fundTx, msgtxstr, "default", &opt) err = w.Call(ctx, "fundrawtransaction", &fundTx, msgtxstr, "default", &opt)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -77,7 +77,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
var locked bool var locked bool
const unlock = false const unlock = false
err = w.Call(context.TODO(), "lockunspent", &locked, unlock, transactions) err = w.Call(ctx, "lockunspent", &locked, unlock, transactions)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -87,7 +87,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
} }
var signedTx wallettypes.SignRawTransactionResult var signedTx wallettypes.SignRawTransactionResult
err = w.Call(context.TODO(), "signrawtransaction", &signedTx, fundTx.Hex) err = w.Call(ctx, "signrawtransaction", &signedTx, fundTx.Hex)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -97,9 +97,9 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) {
return signedTx.Hex, nil return signedTx.Hex, nil
} }
func (w *dcrwallet) SignMessage(_ context.Context, msg string, commitmentAddr stdaddr.Address) ([]byte, error) { func (w *dcrwallet) SignMessage(ctx context.Context, msg string, commitmentAddr stdaddr.Address) ([]byte, error) {
var signature string var signature string
err := w.Call(context.TODO(), "signmessage", &signature, commitmentAddr.String(), msg) err := w.Call(ctx, "signmessage", &signature, commitmentAddr.String(), msg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -107,19 +107,19 @@ func (w *dcrwallet) SignMessage(_ context.Context, msg string, commitmentAddr st
return base64.StdEncoding.DecodeString(signature) return base64.StdEncoding.DecodeString(signature)
} }
func (w *dcrwallet) dumpPrivKey(addr stdaddr.Address) (string, error) { func (w *dcrwallet) dumpPrivKey(ctx context.Context, addr stdaddr.Address) (string, error) {
var privKeyStr string var privKeyStr string
err := w.Call(context.TODO(), "dumpprivkey", &privKeyStr, addr.String()) err := w.Call(ctx, "dumpprivkey", &privKeyStr, addr.String())
if err != nil { if err != nil {
return "", err return "", err
} }
return privKeyStr, nil return privKeyStr, nil
} }
func (w *dcrwallet) getTickets() (*wallettypes.GetTicketsResult, error) { func (w *dcrwallet) getTickets(ctx context.Context) (*wallettypes.GetTicketsResult, error) {
var tickets wallettypes.GetTicketsResult var tickets wallettypes.GetTicketsResult
const includeImmature = true const includeImmature = true
err := w.Call(context.TODO(), "gettickets", &tickets, includeImmature) err := w.Call(ctx, "gettickets", &tickets, includeImmature)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -128,9 +128,9 @@ func (w *dcrwallet) getTickets() (*wallettypes.GetTicketsResult, error) {
// getTicketDetails returns the ticket hex, privkey for voting, and the // getTicketDetails returns the ticket hex, privkey for voting, and the
// commitment address. // commitment address.
func (w *dcrwallet) getTicketDetails(ticketHash string) (string, string, stdaddr.Address, error) { func (w *dcrwallet) getTicketDetails(ctx context.Context, ticketHash string) (string, string, stdaddr.Address, error) {
var getTransactionResult wallettypes.GetTransactionResult var getTransactionResult wallettypes.GetTransactionResult
err := w.Call(context.TODO(), "gettransaction", &getTransactionResult, ticketHash, false) err := w.Call(ctx, "gettransaction", &getTransactionResult, ticketHash, false)
if err != nil { if err != nil {
fmt.Printf("gettransaction: %v\n", err) fmt.Printf("gettransaction: %v\n", err)
return "", "", nil, err return "", "", nil, err
@ -160,7 +160,7 @@ func (w *dcrwallet) getTicketDetails(ticketHash string) (string, string, stdaddr
return "", "", nil, err return "", "", nil, err
} }
privKeyStr, err := w.dumpPrivKey(submissionAddr[0]) privKeyStr, err := w.dumpPrivKey(ctx, submissionAddr[0])
if err != nil { if err != nil {
return "", "", nil, err return "", "", nil, err
} }

View File

@ -7,6 +7,7 @@ package main
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -14,6 +15,7 @@ import (
"github.com/decred/slog" "github.com/decred/slog"
"github.com/decred/vspd/client/v2" "github.com/decred/vspd/client/v2"
"github.com/decred/vspd/internal/signal"
"github.com/decred/vspd/types/v2" "github.com/decred/vspd/types/v2"
) )
@ -55,8 +57,13 @@ func run() int {
log := slog.NewBackend(os.Stdout).Logger("") log := slog.NewBackend(os.Stdout).Logger("")
log.SetLevel(slog.LevelTrace) log.SetLevel(slog.LevelTrace)
walletRPC, err := newWalletRPC(rpcURL, rpcUser, rpcPass) ctx := signal.ShutdownListener(log)
walletRPC, err := newWalletRPC(ctx, rpcURL, rpcUser, rpcPass)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("%v", err) log.Errorf("%v", err)
return 1 return 1
} }
@ -85,8 +92,11 @@ func run() int {
} }
// Get list of tickets // Get list of tickets
tickets, err := walletRPC.getTickets() tickets, err := walletRPC.getTickets(ctx)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("%v", err) log.Errorf("%v", err)
return 1 return 1
} }
@ -102,9 +112,17 @@ func run() int {
} }
for i := 0; i < len(tickets.Hashes); i++ { for i := 0; i < len(tickets.Hashes); i++ {
// Stop if shutdown requested.
if ctx.Err() != nil {
return 0
}
ticketHash := tickets.Hashes[i] ticketHash := tickets.Hashes[i]
hex, privKeyStr, commitmentAddr, err := walletRPC.getTicketDetails(ticketHash) hex, privKeyStr, commitmentAddr, err := walletRPC.getTicketDetails(ctx, ticketHash)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("%v", err) log.Errorf("%v", err)
return 1 return 1
} }
@ -125,19 +143,25 @@ func run() int {
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
} }
feeAddrResp, err := vClient.FeeAddress(context.TODO(), feeAddrReq, commitmentAddr) feeAddrResp, err := vClient.FeeAddress(ctx, feeAddrReq, commitmentAddr)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("getFeeAddress error: %v", err) log.Errorf("getFeeAddress error: %v", err)
break return 1
} }
log.Infof("feeAddress: %v", feeAddrResp.FeeAddress) log.Infof("feeAddress: %v", feeAddrResp.FeeAddress)
log.Infof("privKeyStr: %v", privKeyStr) log.Infof("privKeyStr: %v", privKeyStr)
feeTx, err := walletRPC.createFeeTx(feeAddrResp.FeeAddress, feeAddrResp.FeeAmount) feeTx, err := walletRPC.createFeeTx(ctx, feeAddrResp.FeeAddress, feeAddrResp.FeeAmount)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("createFeeTx error: %v", err) log.Errorf("createFeeTx error: %v", err)
break return 1
} }
voteChoices := map[string]string{"autorevocations": "no"} voteChoices := map[string]string{"autorevocations": "no"}
@ -160,8 +184,11 @@ func run() int {
TreasuryPolicy: treasury, TreasuryPolicy: treasury,
} }
_, err = vClient.PayFee(context.TODO(), payFeeReq, commitmentAddr) _, err = vClient.PayFee(ctx, payFeeReq, commitmentAddr)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("payFee error: %v", err) log.Errorf("payFee error: %v", err)
continue continue
} }
@ -170,10 +197,13 @@ func run() int {
TicketHash: ticketHash, TicketHash: ticketHash,
} }
_, err = vClient.TicketStatus(context.TODO(), ticketStatusReq, commitmentAddr) _, err = vClient.TicketStatus(ctx, ticketStatusReq, commitmentAddr)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("getTicketStatus error: %v", err) log.Errorf("getTicketStatus error: %v", err)
break return 1
} }
voteChoices["autorevocations"] = "yes" voteChoices["autorevocations"] = "yes"
@ -189,16 +219,22 @@ func run() int {
TreasuryPolicy: treasury, TreasuryPolicy: treasury,
} }
_, err = vClient.SetVoteChoices(context.TODO(), voteChoiceReq, commitmentAddr) _, err = vClient.SetVoteChoices(ctx, voteChoiceReq, commitmentAddr)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("setVoteChoices error: %v", err) log.Errorf("setVoteChoices error: %v", err)
break return 1
} }
_, err = vClient.TicketStatus(context.TODO(), ticketStatusReq, commitmentAddr) _, err = vClient.TicketStatus(ctx, ticketStatusReq, commitmentAddr)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return 0
}
log.Errorf("getTicketStatus error: %v", err) log.Errorf("getTicketStatus error: %v", err)
break return 1
} }
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)