Don't put request bytes into context.

This commit is contained in:
jholdstock 2020-06-22 14:32:38 +01:00 committed by David Hill
parent 5d8215e6b7
commit ed21f0af64
5 changed files with 11 additions and 17 deletions

View File

@ -9,7 +9,6 @@ import (
"github.com/decred/vspd/database"
"github.com/decred/vspd/rpc"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)
// addrMtx protects getNewFeeAddress.
@ -62,7 +61,6 @@ func getCurrentFee(dcrdClient *rpc.DcrdRPC) (dcrutil.Amount, error) {
func feeAddress(c *gin.Context) {
// Get values which have been added to context by middleware.
rawRequest := c.MustGet("RawRequest").([]byte)
ticket := c.MustGet("Ticket").(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool)
commitmentAddress := c.MustGet("CommitmentAddress").(string)
@ -74,7 +72,7 @@ func feeAddress(c *gin.Context) {
}
var feeAddressRequest FeeAddressRequest
if err := binding.JSON.BindBody(rawRequest, &feeAddressRequest); err != nil {
if err := c.ShouldBindJSON(&feeAddressRequest); err != nil {
log.Warnf("Bad feeaddress request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c)
return

View File

@ -1,6 +1,8 @@
package webapi
import (
"bytes"
"io/ioutil"
"net/http"
"strings"
@ -107,16 +109,16 @@ func withWalletClients(wallets rpc.WalletConnect) gin.HandlerFunc {
// use.
func vspAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// Read request bytes.
reqBytes, err := c.GetRawData()
// Read request bytes and then replace the request reader for
// downstream handlers to use.
reqBytes, err := ioutil.ReadAll(c.Request.Body)
if err != nil {
log.Warnf("Error reading request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c)
return
}
// Add raw request to context for downstream handlers to use.
c.Set("RawRequest", reqBytes)
c.Request.Body.Close()
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(reqBytes))
// Parse request and ensure there is a ticket hash included.
var request ticketHashRequest

View File

@ -11,14 +11,12 @@ import (
"github.com/decred/vspd/database"
"github.com/decred/vspd/rpc"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)
// payFee is the handler for "POST /payfee".
func payFee(c *gin.Context) {
// Get values which have been added to context by middleware.
rawRequest := c.MustGet("RawRequest").([]byte)
ticket := c.MustGet("Ticket").(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool)
dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC)
@ -35,7 +33,7 @@ func payFee(c *gin.Context) {
}
var payFeeRequest PayFeeRequest
if err := binding.JSON.BindBody(rawRequest, &payFeeRequest); err != nil {
if err := c.ShouldBindJSON(&payFeeRequest); err != nil {
log.Warnf("Bad payfee request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c)
return

View File

@ -6,14 +6,12 @@ import (
"github.com/decred/vspd/database"
"github.com/decred/vspd/rpc"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)
// setVoteChoices is the handler for "POST /setvotechoices".
func setVoteChoices(c *gin.Context) {
// Get values which have been added to context by middleware.
rawRequest := c.MustGet("RawRequest").([]byte)
ticket := c.MustGet("Ticket").(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool)
walletClients := c.MustGet("WalletClients").([]*rpc.WalletRPC)
@ -31,7 +29,7 @@ func setVoteChoices(c *gin.Context) {
}
var setVoteChoicesRequest SetVoteChoicesRequest
if err := binding.JSON.BindBody(rawRequest, &setVoteChoicesRequest); err != nil {
if err := c.ShouldBindJSON(&setVoteChoicesRequest); err != nil {
log.Warnf("Bad setvotechoices request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c)
return

View File

@ -5,14 +5,12 @@ import (
"github.com/decred/vspd/database"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)
// ticketStatus is the handler for "GET /ticketstatus".
func ticketStatus(c *gin.Context) {
// Get values which have been added to context by middleware.
rawRequest := c.MustGet("RawRequest").([]byte)
ticket := c.MustGet("Ticket").(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool)
@ -23,7 +21,7 @@ func ticketStatus(c *gin.Context) {
}
var ticketStatusRequest TicketStatusRequest
if err := binding.JSON.BindBody(rawRequest, &ticketStatusRequest); err != nil {
if err := c.ShouldBindJSON(&ticketStatusRequest); err != nil {
log.Warnf("Bad ticketstatus request from %s: %v", c.ClientIP(), err)
sendErrorWithMsg(err.Error(), errBadRequest, c)
return