From 80f5e6f55cc853d534ada2b8abdb0f1c0b99325e Mon Sep 17 00:00:00 2001 From: jholdstock Date: Mon, 22 Jun 2020 15:15:06 +0100 Subject: [PATCH] Extract code to decode tx and validate tickets --- rpc/dcrd.go | 33 --------------------------------- webapi/errors.go | 5 +++++ webapi/helpers.go | 26 ++++++++++++++++++++++++++ webapi/middleware.go | 29 +++++++++++++++++++++++++++-- webapi/payfee.go | 19 ++----------------- 5 files changed, 60 insertions(+), 52 deletions(-) diff --git a/rpc/dcrd.go b/rpc/dcrd.go index 57da623..eb62f02 100644 --- a/rpc/dcrd.go +++ b/rpc/dcrd.go @@ -7,7 +7,6 @@ import ( "fmt" "strings" - "github.com/decred/dcrd/blockchain/stake/v3" "github.com/decred/dcrd/chaincfg/v3" dcrdtypes "github.com/decred/dcrd/rpc/jsonrpc/types/v2" "github.com/decred/dcrd/wire" @@ -140,38 +139,6 @@ func (c *DcrdRPC) SendRawTransaction(txHex string) error { return nil } -func (c *DcrdRPC) GetTicketCommitmentAddress(ticketHash string, netParams *chaincfg.Params) (string, error) { - // Retrieve and parse the transaction. - resp, err := c.GetRawTransaction(ticketHash) - if err != nil { - return "", err - } - msgHex, err := hex.DecodeString(resp.Hex) - if err != nil { - return "", err - } - msgTx := wire.NewMsgTx() - if err = msgTx.FromBytes(msgHex); err != nil { - return "", err - } - - // Ensure transaction is a valid ticket. - if !stake.IsSStx(msgTx) { - return "", errors.New("invalid transaction - not sstx") - } - if len(msgTx.TxOut) != 3 { - return "", fmt.Errorf("invalid transaction - expected 3 outputs, got %d", len(msgTx.TxOut)) - } - - // Get ticket commitment address. - addr, err := stake.AddrFromSStxPkScrCommitment(msgTx.TxOut[1].PkScript, netParams) - if err != nil { - return "", err - } - - return addr.Address(), nil -} - func (c *DcrdRPC) NotifyBlocks() error { return c.Call(c.ctx, "notifyblocks", nil) } diff --git a/webapi/errors.go b/webapi/errors.go index b1a89f7..667754c 100644 --- a/webapi/errors.go +++ b/webapi/errors.go @@ -18,6 +18,7 @@ const ( errBadSignature errInvalidPrivKey errFeeNotReceived + errInvalidTicket ) // httpStatus maps application error codes to HTTP status codes. @@ -49,6 +50,8 @@ func (e apiError) httpStatus() int { return http.StatusBadRequest case errFeeNotReceived: return http.StatusBadRequest + case errInvalidTicket: + return http.StatusBadRequest default: return http.StatusInternalServerError } @@ -83,6 +86,8 @@ func (e apiError) defaultMessage() string { return "invalid private key" case errFeeNotReceived: return "no fee tx received for ticket" + case errInvalidTicket: + return "not a valid ticket tx" default: return "unknown error" } diff --git a/webapi/helpers.go b/webapi/helpers.go index f223534..db2ba9a 100644 --- a/webapi/helpers.go +++ b/webapi/helpers.go @@ -1,11 +1,14 @@ package webapi import ( + "encoding/hex" "errors" "fmt" + "github.com/decred/dcrd/blockchain/stake/v3" "github.com/decred/dcrd/chaincfg/v3" "github.com/decred/dcrd/dcrutil/v3" + "github.com/decred/dcrd/wire" "github.com/gin-gonic/gin" ) @@ -58,3 +61,26 @@ func validateSignature(reqBytes []byte, commitmentAddress string, c *gin.Context } return nil } + +func decodeTransaction(txHex string) (*wire.MsgTx, error) { + msgHex, err := hex.DecodeString(txHex) + if err != nil { + return nil, err + } + msgTx := wire.NewMsgTx() + if err = msgTx.FromBytes(msgHex); err != nil { + return nil, err + } + return msgTx, nil +} + +func isValidTicket(tx *wire.MsgTx) error { + if !stake.IsSStx(tx) { + return errors.New("invalid transaction - not sstx") + } + if len(tx.TxOut) != 3 { + return fmt.Errorf("invalid transaction - expected 3 outputs, got %d", len(tx.TxOut)) + } + + return nil +} diff --git a/webapi/middleware.go b/webapi/middleware.go index 55c0859..e2916ab 100644 --- a/webapi/middleware.go +++ b/webapi/middleware.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + "github.com/decred/dcrd/blockchain/stake/v3" "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/vspd/rpc" "github.com/gin-gonic/gin" @@ -159,12 +160,36 @@ func vspAuth() gin.HandlerFunc { commitmentAddress = ticket.CommitmentAddress } else { dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC) - commitmentAddress, err = dcrdClient.GetTicketCommitmentAddress(hash, cfg.NetParams) + + resp, err := dcrdClient.GetRawTransaction(hash) if err != nil { - log.Errorf("GetTicketCommitmentAddress error: %v", err) + log.Errorf("GetRawTransaction error: %v", err) sendError(errInternalError, c) return } + + msgTx, err := decodeTransaction(resp.Hex) + if err != nil { + log.Errorf("decodeTransaction error: %v", err) + sendError(errInternalError, c) + return + } + + err = isValidTicket(msgTx) + if err != nil { + log.Warnf("Invalid ticket from %s: %v", c.ClientIP(), err) + sendError(errInvalidTicket, c) + return + } + + addr, err := stake.AddrFromSStxPkScrCommitment(msgTx.TxOut[1].PkScript, cfg.NetParams) + if err != nil { + log.Errorf("AddrFromSStxPkScrCommitment error: %v", err) + sendError(errInternalError, c) + return + } + + commitmentAddress = addr.Address() } // Validate request signature to ensure ticket ownership. diff --git a/webapi/payfee.go b/webapi/payfee.go index e84e3ab..136a39a 100644 --- a/webapi/payfee.go +++ b/webapi/payfee.go @@ -1,13 +1,11 @@ package webapi import ( - "encoding/hex" "time" "github.com/decred/dcrd/dcrec" "github.com/decred/dcrd/dcrutil/v3" "github.com/decred/dcrd/txscript/v3" - "github.com/decred/dcrd/wire" "github.com/decred/vspd/database" "github.com/decred/vspd/rpc" "github.com/gin-gonic/gin" @@ -95,20 +93,13 @@ func payFee(c *gin.Context) { } // Validate FeeTx. - feeTxBytes, err := hex.DecodeString(payFeeRequest.FeeTx) + feeTx, err := decodeTransaction(payFeeRequest.FeeTx) if err != nil { log.Warnf("Failed to decode tx: %v", err) sendError(errInvalidFeeTx, c) return } - feeTx := wire.NewMsgTx() - if err = feeTx.FromBytes(feeTxBytes); err != nil { - log.Warnf("Failed to deserialize tx: %v", err) - sendError(errInvalidFeeTx, c) - return - } - // Loop through transaction outputs until we find one which pays to the // expected fee address. Record how much is being paid to the fee address. var feePaid dcrutil.Amount @@ -150,18 +141,12 @@ findAddress: } // Decode ticket transaction to get its voting address. - ticketBytes, err := hex.DecodeString(rawTicket.Hex) + ticketTx, err := decodeTransaction(rawTicket.Hex) if err != nil { log.Warnf("Failed to decode tx: %v", err) sendError(errInternalError, c) return } - ticketTx := wire.NewMsgTx() - if err = ticketTx.FromBytes(ticketBytes); err != nil { - log.Errorf("Failed to deserialize tx: %v", err) - sendError(errInternalError, c) - return - } // Get ticket voting address. _, votingAddr, _, err := txscript.ExtractPkScriptAddrs(scriptVersion, ticketTx.TxOut[0].PkScript, cfg.NetParams)