Don't put request bytes into context.

This commit is contained in:
jholdstock 2020-06-22 14:32:38 +01:00 committed by David Hill
parent 5d8215e6b7
commit ed21f0af64
5 changed files with 11 additions and 17 deletions

View File

@ -9,7 +9,6 @@ import (
"github.com/decred/vspd/database" "github.com/decred/vspd/database"
"github.com/decred/vspd/rpc" "github.com/decred/vspd/rpc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
) )
// addrMtx protects getNewFeeAddress. // addrMtx protects getNewFeeAddress.
@ -62,7 +61,6 @@ func getCurrentFee(dcrdClient *rpc.DcrdRPC) (dcrutil.Amount, error) {
func feeAddress(c *gin.Context) { func feeAddress(c *gin.Context) {
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
rawRequest := c.MustGet("RawRequest").([]byte)
ticket := c.MustGet("Ticket").(database.Ticket) ticket := c.MustGet("Ticket").(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool) knownTicket := c.MustGet("KnownTicket").(bool)
commitmentAddress := c.MustGet("CommitmentAddress").(string) commitmentAddress := c.MustGet("CommitmentAddress").(string)
@ -74,7 +72,7 @@ func feeAddress(c *gin.Context) {
} }
var feeAddressRequest FeeAddressRequest var feeAddressRequest FeeAddressRequest
if err := binding.JSON.BindBody(rawRequest, &feeAddressRequest); err != nil { if err := c.ShouldBindJSON(&feeAddressRequest); err != nil {
log.Warnf("Bad feeaddress request from %s: %v", c.ClientIP(), err) log.Warnf("Bad feeaddress request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c) sendErrorWithMsg(err.Error(), errBadRequest, c)
return return

View File

@ -1,6 +1,8 @@
package webapi package webapi
import ( import (
"bytes"
"io/ioutil"
"net/http" "net/http"
"strings" "strings"
@ -107,16 +109,16 @@ func withWalletClients(wallets rpc.WalletConnect) gin.HandlerFunc {
// use. // use.
func vspAuth() gin.HandlerFunc { func vspAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Read request bytes. // Read request bytes and then replace the request reader for
reqBytes, err := c.GetRawData() // downstream handlers to use.
reqBytes, err := ioutil.ReadAll(c.Request.Body)
if err != nil { if err != nil {
log.Warnf("Error reading request from %s: %v", c.ClientIP(), err) log.Warnf("Error reading request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c) sendErrorWithMsg(err.Error(), errBadRequest, c)
return return
} }
c.Request.Body.Close()
// Add raw request to context for downstream handlers to use. c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(reqBytes))
c.Set("RawRequest", reqBytes)
// Parse request and ensure there is a ticket hash included. // Parse request and ensure there is a ticket hash included.
var request ticketHashRequest var request ticketHashRequest

View File

@ -11,14 +11,12 @@ import (
"github.com/decred/vspd/database" "github.com/decred/vspd/database"
"github.com/decred/vspd/rpc" "github.com/decred/vspd/rpc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
) )
// payFee is the handler for "POST /payfee". // payFee is the handler for "POST /payfee".
func payFee(c *gin.Context) { func payFee(c *gin.Context) {
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
rawRequest := c.MustGet("RawRequest").([]byte)
ticket := c.MustGet("Ticket").(database.Ticket) ticket := c.MustGet("Ticket").(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool) knownTicket := c.MustGet("KnownTicket").(bool)
dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC) dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC)
@ -35,7 +33,7 @@ func payFee(c *gin.Context) {
} }
var payFeeRequest PayFeeRequest var payFeeRequest PayFeeRequest
if err := binding.JSON.BindBody(rawRequest, &payFeeRequest); err != nil { if err := c.ShouldBindJSON(&payFeeRequest); err != nil {
log.Warnf("Bad payfee request from %s: %v", c.ClientIP(), err) log.Warnf("Bad payfee request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c) sendErrorWithMsg(err.Error(), errBadRequest, c)
return return

View File

@ -6,14 +6,12 @@ import (
"github.com/decred/vspd/database" "github.com/decred/vspd/database"
"github.com/decred/vspd/rpc" "github.com/decred/vspd/rpc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
) )
// setVoteChoices is the handler for "POST /setvotechoices". // setVoteChoices is the handler for "POST /setvotechoices".
func setVoteChoices(c *gin.Context) { func setVoteChoices(c *gin.Context) {
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
rawRequest := c.MustGet("RawRequest").([]byte)
ticket := c.MustGet("Ticket").(database.Ticket) ticket := c.MustGet("Ticket").(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool) knownTicket := c.MustGet("KnownTicket").(bool)
walletClients := c.MustGet("WalletClients").([]*rpc.WalletRPC) walletClients := c.MustGet("WalletClients").([]*rpc.WalletRPC)
@ -31,7 +29,7 @@ func setVoteChoices(c *gin.Context) {
} }
var setVoteChoicesRequest SetVoteChoicesRequest var setVoteChoicesRequest SetVoteChoicesRequest
if err := binding.JSON.BindBody(rawRequest, &setVoteChoicesRequest); err != nil { if err := c.ShouldBindJSON(&setVoteChoicesRequest); err != nil {
log.Warnf("Bad setvotechoices request from %s: %v", c.ClientIP(), err) log.Warnf("Bad setvotechoices request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c) sendErrorWithMsg(err.Error(), errBadRequest, c)
return return

View File

@ -5,14 +5,12 @@ import (
"github.com/decred/vspd/database" "github.com/decred/vspd/database"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
) )
// ticketStatus is the handler for "GET /ticketstatus". // ticketStatus is the handler for "GET /ticketstatus".
func ticketStatus(c *gin.Context) { func ticketStatus(c *gin.Context) {
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
rawRequest := c.MustGet("RawRequest").([]byte)
ticket := c.MustGet("Ticket").(database.Ticket) ticket := c.MustGet("Ticket").(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool) knownTicket := c.MustGet("KnownTicket").(bool)
@ -23,7 +21,7 @@ func ticketStatus(c *gin.Context) {
} }
var ticketStatusRequest TicketStatusRequest var ticketStatusRequest TicketStatusRequest
if err := binding.JSON.BindBody(rawRequest, &ticketStatusRequest); err != nil { if err := c.ShouldBindJSON(&ticketStatusRequest); err != nil {
log.Warnf("Bad ticketstatus request from %s: %v", c.ClientIP(), err) log.Warnf("Bad ticketstatus request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c) sendErrorWithMsg(err.Error(), errBadRequest, c)
return return