Extract code to decode tx and validate tickets

This commit is contained in:
jholdstock 2020-06-22 15:15:06 +01:00 committed by David Hill
parent ed21f0af64
commit 80f5e6f55c
5 changed files with 60 additions and 52 deletions

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/decred/dcrd/blockchain/stake/v3"
"github.com/decred/dcrd/chaincfg/v3" "github.com/decred/dcrd/chaincfg/v3"
dcrdtypes "github.com/decred/dcrd/rpc/jsonrpc/types/v2" dcrdtypes "github.com/decred/dcrd/rpc/jsonrpc/types/v2"
"github.com/decred/dcrd/wire" "github.com/decred/dcrd/wire"
@ -140,38 +139,6 @@ func (c *DcrdRPC) SendRawTransaction(txHex string) error {
return nil return nil
} }
func (c *DcrdRPC) GetTicketCommitmentAddress(ticketHash string, netParams *chaincfg.Params) (string, error) {
// Retrieve and parse the transaction.
resp, err := c.GetRawTransaction(ticketHash)
if err != nil {
return "", err
}
msgHex, err := hex.DecodeString(resp.Hex)
if err != nil {
return "", err
}
msgTx := wire.NewMsgTx()
if err = msgTx.FromBytes(msgHex); err != nil {
return "", err
}
// Ensure transaction is a valid ticket.
if !stake.IsSStx(msgTx) {
return "", errors.New("invalid transaction - not sstx")
}
if len(msgTx.TxOut) != 3 {
return "", fmt.Errorf("invalid transaction - expected 3 outputs, got %d", len(msgTx.TxOut))
}
// Get ticket commitment address.
addr, err := stake.AddrFromSStxPkScrCommitment(msgTx.TxOut[1].PkScript, netParams)
if err != nil {
return "", err
}
return addr.Address(), nil
}
func (c *DcrdRPC) NotifyBlocks() error { func (c *DcrdRPC) NotifyBlocks() error {
return c.Call(c.ctx, "notifyblocks", nil) return c.Call(c.ctx, "notifyblocks", nil)
} }

View File

@ -18,6 +18,7 @@ const (
errBadSignature errBadSignature
errInvalidPrivKey errInvalidPrivKey
errFeeNotReceived errFeeNotReceived
errInvalidTicket
) )
// httpStatus maps application error codes to HTTP status codes. // httpStatus maps application error codes to HTTP status codes.
@ -49,6 +50,8 @@ func (e apiError) httpStatus() int {
return http.StatusBadRequest return http.StatusBadRequest
case errFeeNotReceived: case errFeeNotReceived:
return http.StatusBadRequest return http.StatusBadRequest
case errInvalidTicket:
return http.StatusBadRequest
default: default:
return http.StatusInternalServerError return http.StatusInternalServerError
} }
@ -83,6 +86,8 @@ func (e apiError) defaultMessage() string {
return "invalid private key" return "invalid private key"
case errFeeNotReceived: case errFeeNotReceived:
return "no fee tx received for ticket" return "no fee tx received for ticket"
case errInvalidTicket:
return "not a valid ticket tx"
default: default:
return "unknown error" return "unknown error"
} }

View File

@ -1,11 +1,14 @@
package webapi package webapi
import ( import (
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"github.com/decred/dcrd/blockchain/stake/v3"
"github.com/decred/dcrd/chaincfg/v3" "github.com/decred/dcrd/chaincfg/v3"
"github.com/decred/dcrd/dcrutil/v3" "github.com/decred/dcrd/dcrutil/v3"
"github.com/decred/dcrd/wire"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -58,3 +61,26 @@ func validateSignature(reqBytes []byte, commitmentAddress string, c *gin.Context
} }
return nil return nil
} }
func decodeTransaction(txHex string) (*wire.MsgTx, error) {
msgHex, err := hex.DecodeString(txHex)
if err != nil {
return nil, err
}
msgTx := wire.NewMsgTx()
if err = msgTx.FromBytes(msgHex); err != nil {
return nil, err
}
return msgTx, nil
}
func isValidTicket(tx *wire.MsgTx) error {
if !stake.IsSStx(tx) {
return errors.New("invalid transaction - not sstx")
}
if len(tx.TxOut) != 3 {
return fmt.Errorf("invalid transaction - expected 3 outputs, got %d", len(tx.TxOut))
}
return nil
}

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/decred/dcrd/blockchain/stake/v3"
"github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/chaincfg/chainhash"
"github.com/decred/vspd/rpc" "github.com/decred/vspd/rpc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -159,12 +160,36 @@ func vspAuth() gin.HandlerFunc {
commitmentAddress = ticket.CommitmentAddress commitmentAddress = ticket.CommitmentAddress
} else { } else {
dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC) dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC)
commitmentAddress, err = dcrdClient.GetTicketCommitmentAddress(hash, cfg.NetParams)
resp, err := dcrdClient.GetRawTransaction(hash)
if err != nil { if err != nil {
log.Errorf("GetTicketCommitmentAddress error: %v", err) log.Errorf("GetRawTransaction error: %v", err)
sendError(errInternalError, c) sendError(errInternalError, c)
return return
} }
msgTx, err := decodeTransaction(resp.Hex)
if err != nil {
log.Errorf("decodeTransaction error: %v", err)
sendError(errInternalError, c)
return
}
err = isValidTicket(msgTx)
if err != nil {
log.Warnf("Invalid ticket from %s: %v", c.ClientIP(), err)
sendError(errInvalidTicket, c)
return
}
addr, err := stake.AddrFromSStxPkScrCommitment(msgTx.TxOut[1].PkScript, cfg.NetParams)
if err != nil {
log.Errorf("AddrFromSStxPkScrCommitment error: %v", err)
sendError(errInternalError, c)
return
}
commitmentAddress = addr.Address()
} }
// Validate request signature to ensure ticket ownership. // Validate request signature to ensure ticket ownership.

View File

@ -1,13 +1,11 @@
package webapi package webapi
import ( import (
"encoding/hex"
"time" "time"
"github.com/decred/dcrd/dcrec" "github.com/decred/dcrd/dcrec"
"github.com/decred/dcrd/dcrutil/v3" "github.com/decred/dcrd/dcrutil/v3"
"github.com/decred/dcrd/txscript/v3" "github.com/decred/dcrd/txscript/v3"
"github.com/decred/dcrd/wire"
"github.com/decred/vspd/database" "github.com/decred/vspd/database"
"github.com/decred/vspd/rpc" "github.com/decred/vspd/rpc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -95,20 +93,13 @@ func payFee(c *gin.Context) {
} }
// Validate FeeTx. // Validate FeeTx.
feeTxBytes, err := hex.DecodeString(payFeeRequest.FeeTx) feeTx, err := decodeTransaction(payFeeRequest.FeeTx)
if err != nil { if err != nil {
log.Warnf("Failed to decode tx: %v", err) log.Warnf("Failed to decode tx: %v", err)
sendError(errInvalidFeeTx, c) sendError(errInvalidFeeTx, c)
return return
} }
feeTx := wire.NewMsgTx()
if err = feeTx.FromBytes(feeTxBytes); err != nil {
log.Warnf("Failed to deserialize tx: %v", err)
sendError(errInvalidFeeTx, c)
return
}
// Loop through transaction outputs until we find one which pays to the // Loop through transaction outputs until we find one which pays to the
// expected fee address. Record how much is being paid to the fee address. // expected fee address. Record how much is being paid to the fee address.
var feePaid dcrutil.Amount var feePaid dcrutil.Amount
@ -150,18 +141,12 @@ findAddress:
} }
// Decode ticket transaction to get its voting address. // Decode ticket transaction to get its voting address.
ticketBytes, err := hex.DecodeString(rawTicket.Hex) ticketTx, err := decodeTransaction(rawTicket.Hex)
if err != nil { if err != nil {
log.Warnf("Failed to decode tx: %v", err) log.Warnf("Failed to decode tx: %v", err)
sendError(errInternalError, c) sendError(errInternalError, c)
return return
} }
ticketTx := wire.NewMsgTx()
if err = ticketTx.FromBytes(ticketBytes); err != nil {
log.Errorf("Failed to deserialize tx: %v", err)
sendError(errInternalError, c)
return
}
// Get ticket voting address. // Get ticket voting address.
_, votingAddr, _, err := txscript.ExtractPkScriptAddrs(scriptVersion, ticketTx.TxOut[0].PkScript, cfg.NetParams) _, votingAddr, _, err := txscript.ExtractPkScriptAddrs(scriptVersion, ticketTx.TxOut[0].PkScript, cfg.NetParams)