diff --git a/cmd/v3tool/dcrwallet.go b/cmd/v3tool/dcrwallet.go index 7fd295f..0edf637 100644 --- a/cmd/v3tool/dcrwallet.go +++ b/cmd/v3tool/dcrwallet.go @@ -28,24 +28,24 @@ type dcrwallet struct { *wsrpc.Client } -func newWalletRPC(rpcURL, rpcUser, rpcPass string) (*dcrwallet, error) { +func newWalletRPC(ctx context.Context, rpcURL, rpcUser, rpcPass string) (*dcrwallet, error) { tlsOpt := wsrpc.WithTLSConfig(&tls.Config{ InsecureSkipVerify: true, }) authOpt := wsrpc.WithBasicAuth(rpcUser, rpcPass) - rpc, err := wsrpc.Dial(context.TODO(), rpcURL, tlsOpt, authOpt) + rpc, err := wsrpc.Dial(ctx, rpcURL, tlsOpt, authOpt) if err != nil { return nil, err } return &dcrwallet{rpc}, nil } -func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) { +func (w *dcrwallet) createFeeTx(ctx context.Context, feeAddress string, fee int64) (string, error) { amounts := make(map[string]float64) amounts[feeAddress] = dcrutil.Amount(fee).ToCoin() var msgtxstr string - err := w.Call(context.TODO(), "createrawtransaction", &msgtxstr, nil, amounts) + err := w.Call(ctx, "createrawtransaction", &msgtxstr, nil, amounts) if err != nil { return "", err } @@ -55,7 +55,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) { ConfTarget: &zero, } var fundTx wallettypes.FundRawTransactionResult - err = w.Call(context.TODO(), "fundrawtransaction", &fundTx, msgtxstr, "default", &opt) + err = w.Call(ctx, "fundrawtransaction", &fundTx, msgtxstr, "default", &opt) if err != nil { return "", err } @@ -77,7 +77,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) { var locked bool const unlock = false - err = w.Call(context.TODO(), "lockunspent", &locked, unlock, transactions) + err = w.Call(ctx, "lockunspent", &locked, unlock, transactions) if err != nil { return "", err } @@ -87,7 +87,7 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) { } var signedTx wallettypes.SignRawTransactionResult - err = w.Call(context.TODO(), "signrawtransaction", &signedTx, fundTx.Hex) + err = w.Call(ctx, "signrawtransaction", &signedTx, fundTx.Hex) if err != nil { return "", err } @@ -97,9 +97,9 @@ func (w *dcrwallet) createFeeTx(feeAddress string, fee int64) (string, error) { return signedTx.Hex, nil } -func (w *dcrwallet) SignMessage(_ context.Context, msg string, commitmentAddr stdaddr.Address) ([]byte, error) { +func (w *dcrwallet) SignMessage(ctx context.Context, msg string, commitmentAddr stdaddr.Address) ([]byte, error) { var signature string - err := w.Call(context.TODO(), "signmessage", &signature, commitmentAddr.String(), msg) + err := w.Call(ctx, "signmessage", &signature, commitmentAddr.String(), msg) if err != nil { return nil, err } @@ -107,19 +107,19 @@ func (w *dcrwallet) SignMessage(_ context.Context, msg string, commitmentAddr st return base64.StdEncoding.DecodeString(signature) } -func (w *dcrwallet) dumpPrivKey(addr stdaddr.Address) (string, error) { +func (w *dcrwallet) dumpPrivKey(ctx context.Context, addr stdaddr.Address) (string, error) { var privKeyStr string - err := w.Call(context.TODO(), "dumpprivkey", &privKeyStr, addr.String()) + err := w.Call(ctx, "dumpprivkey", &privKeyStr, addr.String()) if err != nil { return "", err } return privKeyStr, nil } -func (w *dcrwallet) getTickets() (*wallettypes.GetTicketsResult, error) { +func (w *dcrwallet) getTickets(ctx context.Context) (*wallettypes.GetTicketsResult, error) { var tickets wallettypes.GetTicketsResult const includeImmature = true - err := w.Call(context.TODO(), "gettickets", &tickets, includeImmature) + err := w.Call(ctx, "gettickets", &tickets, includeImmature) if err != nil { return nil, err } @@ -128,9 +128,9 @@ func (w *dcrwallet) getTickets() (*wallettypes.GetTicketsResult, error) { // getTicketDetails returns the ticket hex, privkey for voting, and the // commitment address. -func (w *dcrwallet) getTicketDetails(ticketHash string) (string, string, stdaddr.Address, error) { +func (w *dcrwallet) getTicketDetails(ctx context.Context, ticketHash string) (string, string, stdaddr.Address, error) { var getTransactionResult wallettypes.GetTransactionResult - err := w.Call(context.TODO(), "gettransaction", &getTransactionResult, ticketHash, false) + err := w.Call(ctx, "gettransaction", &getTransactionResult, ticketHash, false) if err != nil { fmt.Printf("gettransaction: %v\n", err) return "", "", nil, err @@ -160,7 +160,7 @@ func (w *dcrwallet) getTicketDetails(ticketHash string) (string, string, stdaddr return "", "", nil, err } - privKeyStr, err := w.dumpPrivKey(submissionAddr[0]) + privKeyStr, err := w.dumpPrivKey(ctx, submissionAddr[0]) if err != nil { return "", "", nil, err } diff --git a/cmd/v3tool/main.go b/cmd/v3tool/main.go index 7d63b71..b1863bd 100644 --- a/cmd/v3tool/main.go +++ b/cmd/v3tool/main.go @@ -7,6 +7,7 @@ package main import ( "context" "encoding/json" + "errors" "io" "net/http" "os" @@ -14,6 +15,7 @@ import ( "github.com/decred/slog" "github.com/decred/vspd/client/v2" + "github.com/decred/vspd/internal/signal" "github.com/decred/vspd/types/v2" ) @@ -55,8 +57,13 @@ func run() int { log := slog.NewBackend(os.Stdout).Logger("") log.SetLevel(slog.LevelTrace) - walletRPC, err := newWalletRPC(rpcURL, rpcUser, rpcPass) + ctx := signal.ShutdownListener(log) + + walletRPC, err := newWalletRPC(ctx, rpcURL, rpcUser, rpcPass) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("%v", err) return 1 } @@ -85,8 +92,11 @@ func run() int { } // Get list of tickets - tickets, err := walletRPC.getTickets() + tickets, err := walletRPC.getTickets(ctx) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("%v", err) return 1 } @@ -102,9 +112,17 @@ func run() int { } for i := 0; i < len(tickets.Hashes); i++ { + // Stop if shutdown requested. + if ctx.Err() != nil { + return 0 + } + ticketHash := tickets.Hashes[i] - hex, privKeyStr, commitmentAddr, err := walletRPC.getTicketDetails(ticketHash) + hex, privKeyStr, commitmentAddr, err := walletRPC.getTicketDetails(ctx, ticketHash) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("%v", err) return 1 } @@ -125,19 +143,25 @@ func run() int { Timestamp: time.Now().Unix(), } - feeAddrResp, err := vClient.FeeAddress(context.TODO(), feeAddrReq, commitmentAddr) + feeAddrResp, err := vClient.FeeAddress(ctx, feeAddrReq, commitmentAddr) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("getFeeAddress error: %v", err) - break + return 1 } log.Infof("feeAddress: %v", feeAddrResp.FeeAddress) log.Infof("privKeyStr: %v", privKeyStr) - feeTx, err := walletRPC.createFeeTx(feeAddrResp.FeeAddress, feeAddrResp.FeeAmount) + feeTx, err := walletRPC.createFeeTx(ctx, feeAddrResp.FeeAddress, feeAddrResp.FeeAmount) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("createFeeTx error: %v", err) - break + return 1 } voteChoices := map[string]string{"autorevocations": "no"} @@ -160,8 +184,11 @@ func run() int { TreasuryPolicy: treasury, } - _, err = vClient.PayFee(context.TODO(), payFeeReq, commitmentAddr) + _, err = vClient.PayFee(ctx, payFeeReq, commitmentAddr) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("payFee error: %v", err) continue } @@ -170,10 +197,13 @@ func run() int { TicketHash: ticketHash, } - _, err = vClient.TicketStatus(context.TODO(), ticketStatusReq, commitmentAddr) + _, err = vClient.TicketStatus(ctx, ticketStatusReq, commitmentAddr) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("getTicketStatus error: %v", err) - break + return 1 } voteChoices["autorevocations"] = "yes" @@ -189,16 +219,22 @@ func run() int { TreasuryPolicy: treasury, } - _, err = vClient.SetVoteChoices(context.TODO(), voteChoiceReq, commitmentAddr) + _, err = vClient.SetVoteChoices(ctx, voteChoiceReq, commitmentAddr) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("setVoteChoices error: %v", err) - break + return 1 } - _, err = vClient.TicketStatus(context.TODO(), ticketStatusReq, commitmentAddr) + _, err = vClient.TicketStatus(ctx, ticketStatusReq, commitmentAddr) if err != nil { + if errors.Is(err, context.Canceled) { + return 0 + } log.Errorf("getTicketStatus error: %v", err) - break + return 1 } time.Sleep(1 * time.Second)