webapi: Move VSP closed check to middleware.

Checking this in middleware not only removes duplicated code from each of the handlers, but also makes things more efficient. Previously, dcrd/dcrwallet/database clients were initialized (and could possibly cause errors) **before** checking if the VSP is closed. Now the check happens before those other things.

Also bundled with some improvements to setaltsignaddr_test.go which I implemented along the way.
This commit is contained in:
Jamie Holdstock 2023-02-17 20:44:13 +00:00 committed by GitHub
parent 824c41c2a5
commit 7718476caf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 123 additions and 96 deletions

View File

@ -90,11 +90,6 @@ func (s *Server) feeAddress(c *gin.Context) {
} }
reqBytes := c.MustGet(requestBytesKey).([]byte) reqBytes := c.MustGet(requestBytesKey).([]byte)
if s.cfg.VspClosed {
s.sendError(types.ErrVspClosed, c)
return
}
var request types.FeeAddressRequest var request types.FeeAddressRequest
if err := binding.JSON.BindBody(reqBytes, &request); err != nil { if err := binding.JSON.BindBody(reqBytes, &request); err != nil {
s.log.Warnf("%s: Bad request (clientIP=%s): %v", funcName, c.ClientIP(), err) s.log.Warnf("%s: Bad request (clientIP=%s): %v", funcName, c.ClientIP(), err)

View File

@ -111,6 +111,13 @@ func drainAndReplaceBody(req *http.Request) ([]byte, error) {
return reqBytes, nil return reqBytes, nil
} }
func (s *Server) vspMustBeOpen(c *gin.Context) {
if s.cfg.VspClosed {
s.sendError(types.ErrVspClosed, c)
return
}
}
// broadcastTicket will ensure that the local dcrd instance is aware of the // broadcastTicket will ensure that the local dcrd instance is aware of the
// provided ticket. // provided ticket.
// Ticket hash, ticket hex, and parent hex are parsed from the request body and // Ticket hash, ticket hex, and parent hex are parsed from the request body and

View File

@ -36,11 +36,6 @@ func (s *Server) payFee(c *gin.Context) {
} }
reqBytes := c.MustGet(requestBytesKey).([]byte) reqBytes := c.MustGet(requestBytesKey).([]byte)
if s.cfg.VspClosed {
s.sendError(types.ErrVspClosed, c)
return
}
if !knownTicket { if !knownTicket {
s.log.Warnf("%s: Unknown ticket (clientIP=%s)", funcName, c.ClientIP()) s.log.Warnf("%s: Unknown ticket (clientIP=%s)", funcName, c.ClientIP())
s.sendError(types.ErrUnknownTicket, c) s.sendError(types.ErrUnknownTicket, c)

View File

@ -40,11 +40,6 @@ func (s *Server) setAltSignAddr(c *gin.Context) {
} }
reqBytes := c.MustGet(requestBytesKey).([]byte) reqBytes := c.MustGet(requestBytesKey).([]byte)
if s.cfg.VspClosed {
s.sendError(types.ErrVspClosed, c)
return
}
var request types.SetAltSignAddrRequest var request types.SetAltSignAddrRequest
if err := binding.JSON.BindBody(reqBytes, &request); err != nil { if err := binding.JSON.BindBody(reqBytes, &request); err != nil {
s.log.Warnf("%s: Bad request (clientIP=%s): %v", funcName, c.ClientIP(), err) s.log.Warnf("%s: Bad request (clientIP=%s): %v", funcName, c.ClientIP(), err)

View File

@ -10,6 +10,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -130,17 +131,18 @@ func (n *testNode) GetRawTransaction(txHash string) (*dcrdtypes.TxRawResult, err
func TestSetAltSignAddress(t *testing.T) { func TestSetAltSignAddress(t *testing.T) {
const testAddr = "DsVoDXNQqyF3V83PJJ5zMdnB4pQuJHBAh15" const testAddr = "DsVoDXNQqyF3V83PJJ5zMdnB4pQuJHBAh15"
tests := []struct { tests := map[string]struct {
name string
dcrdClientErr bool dcrdClientErr bool
vspClosed bool
deformReq int deformReq int
addr string addr string
node *testNode node *testNode
isExistingAltSignAddr bool isExistingAltSignAddr bool
wantCode int wantHTTPStatus int
}{{ // wantErrCode and wantErrMsg only checked if wantHTTPStatus != 200.
name: "ok", wantErrCode types.ErrorCode
wantErrMsg string
}{
"ok": {
addr: testAddr, addr: testAddr,
node: &testNode{ node: &testNode{
getRawTransaction: &dcrdtypes.TxRawResult{ getRawTransaction: &dcrdtypes.TxRawResult{
@ -149,37 +151,42 @@ func TestSetAltSignAddress(t *testing.T) {
getRawTransactionErr: nil, getRawTransactionErr: nil,
existsLiveTicket: true, existsLiveTicket: true,
}, },
wantCode: http.StatusOK, wantHTTPStatus: http.StatusOK,
}, { },
name: "vsp closed", "dcrd client error": {
vspClosed: true,
wantCode: http.StatusBadRequest,
}, {
name: "dcrd client error",
dcrdClientErr: true, dcrdClientErr: true,
wantCode: http.StatusInternalServerError, wantHTTPStatus: http.StatusInternalServerError,
}, { wantErrCode: types.ErrInternalError,
wantErrMsg: types.ErrInternalError.DefaultMessage(),
name: "bad request", },
"bad request": {
deformReq: 1, deformReq: 1,
wantCode: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
}, { wantErrCode: types.ErrBadRequest,
name: "bad addr", wantErrMsg: "json: cannot unmarshal string into Go value of type types.SetAltSignAddrRequest",
},
"bad addr": {
addr: "xxx", addr: "xxx",
wantCode: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
}, { wantErrCode: types.ErrBadRequest,
name: "addr wrong type", wantErrMsg: "failed to decoded address \"xxx\": invalid format: version and/or checksum bytes missing",
},
"addr wrong type": {
addr: "DkM3ZigNyiwHrsXRjkDQ8t8tW6uKGW9g61qEkG3bMqQPQWYEf5X3J", addr: "DkM3ZigNyiwHrsXRjkDQ8t8tW6uKGW9g61qEkG3bMqQPQWYEf5X3J",
wantCode: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
}, { wantErrCode: types.ErrBadRequest,
name: "getRawTransaction error from dcrd client", wantErrMsg: "wrong type for alternate signing address",
},
"getRawTransaction error from dcrd client": {
addr: testAddr, addr: testAddr,
node: &testNode{ node: &testNode{
getRawTransactionErr: errors.New("getRawTransaction error"), getRawTransactionErr: errors.New("getRawTransaction error"),
}, },
wantCode: http.StatusInternalServerError, wantHTTPStatus: http.StatusInternalServerError,
}, { wantErrCode: types.ErrInternalError,
name: "existsLiveTicket error from dcrd client", wantErrMsg: types.ErrInternalError.DefaultMessage(),
},
"existsLiveTicket error from dcrd client": {
addr: testAddr, addr: testAddr,
node: &testNode{ node: &testNode{
getRawTransaction: &dcrdtypes.TxRawResult{ getRawTransaction: &dcrdtypes.TxRawResult{
@ -187,9 +194,11 @@ func TestSetAltSignAddress(t *testing.T) {
}, },
existsLiveTicketErr: errors.New("existsLiveTicket error"), existsLiveTicketErr: errors.New("existsLiveTicket error"),
}, },
wantCode: http.StatusInternalServerError, wantHTTPStatus: http.StatusInternalServerError,
}, { wantErrCode: types.ErrInternalError,
name: "ticket can't vote", wantErrMsg: types.ErrInternalError.DefaultMessage(),
},
"ticket can't vote": {
addr: testAddr, addr: testAddr,
node: &testNode{ node: &testNode{
getRawTransaction: &dcrdtypes.TxRawResult{ getRawTransaction: &dcrdtypes.TxRawResult{
@ -197,20 +206,25 @@ func TestSetAltSignAddress(t *testing.T) {
}, },
existsLiveTicket: false, existsLiveTicket: false,
}, },
wantCode: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
}, { wantErrCode: types.ErrTicketCannotVote,
name: "only one alt sign addr allowed", wantErrMsg: types.ErrTicketCannotVote.DefaultMessage(),
},
"only one alt sign addr allowed": {
addr: testAddr, addr: testAddr,
node: &testNode{ node: &testNode{
getRawTransaction: &dcrdtypes.TxRawResult{}, getRawTransaction: &dcrdtypes.TxRawResult{},
existsLiveTicket: true, existsLiveTicket: true,
}, },
isExistingAltSignAddr: true, isExistingAltSignAddr: true,
wantCode: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
}} wantErrCode: types.ErrBadRequest,
wantErrMsg: "alternate sign address data already exists",
},
}
for _, test := range tests { for testName, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(testName, func(t *testing.T) {
ticketHash := randString(64, hexCharset) ticketHash := randString(64, hexCharset)
req := &types.SetAltSignAddrRequest{ req := &types.SetAltSignAddrRequest{
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
@ -238,8 +252,6 @@ func TestSetAltSignAddress(t *testing.T) {
} }
} }
api.cfg.VspClosed = test.vspClosed
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, r := gin.CreateTestContext(w) c, r := gin.CreateTestContext(w)
@ -266,8 +278,31 @@ func TestSetAltSignAddress(t *testing.T) {
r.ServeHTTP(w, c.Request) r.ServeHTTP(w, c.Request)
if test.wantCode != w.Code { if test.wantHTTPStatus != w.Code {
t.Fatalf("expected status %d, got %d", test.wantCode, w.Code) t.Fatalf("expected http status %d, got %d", test.wantHTTPStatus, w.Code)
}
if test.wantHTTPStatus != http.StatusOK {
respBytes, err := io.ReadAll(w.Body)
if err != nil {
t.Fatalf("failed reading response body bytes: %v", err)
}
var apiError types.ErrorResponse
err = json.Unmarshal(respBytes, &apiError)
if err != nil {
t.Fatalf("could not unmarshal error response: %v", err)
}
if int64(test.wantErrCode) != apiError.Code {
t.Fatalf("incorrect error code, expected %d, actual %d",
test.wantErrCode, apiError.Code)
}
if test.wantErrMsg != apiError.Message {
t.Fatalf("incorrect error message, expected %q, actual %q",
test.wantErrMsg, apiError.Message)
}
} }
altsig, err := api.db.AltSignAddrData(ticketHash) altsig, err := api.db.AltSignAddrData(ticketHash)
@ -275,7 +310,7 @@ func TestSetAltSignAddress(t *testing.T) {
t.Fatalf("unable to get alt sign addr data: %v", err) t.Fatalf("unable to get alt sign addr data: %v", err)
} }
if test.wantCode != http.StatusOK && !test.isExistingAltSignAddr { if test.wantHTTPStatus != http.StatusOK && !test.isExistingAltSignAddr {
if altsig != nil { if altsig != nil {
t.Fatalf("expected no alt sign addr saved for errored state") t.Fatalf("expected no alt sign addr saved for errored state")
} }

View File

@ -236,10 +236,10 @@ func (s *Server) router(cookieSecret []byte, dcrd rpc.DcrdConnect, wallets rpc.W
api := router.Group("/api/v3") api := router.Group("/api/v3")
api.GET("/vspinfo", s.vspInfo) api.GET("/vspinfo", s.vspInfo)
api.POST("/setaltsignaddr", s.withDcrdClient(dcrd), s.broadcastTicket, s.vspAuth, s.setAltSignAddr) api.POST("/setaltsignaddr", s.vspMustBeOpen, s.withDcrdClient(dcrd), s.broadcastTicket, s.vspAuth, s.setAltSignAddr)
api.POST("/feeaddress", s.withDcrdClient(dcrd), s.broadcastTicket, s.vspAuth, s.feeAddress) api.POST("/feeaddress", s.vspMustBeOpen, s.withDcrdClient(dcrd), s.broadcastTicket, s.vspAuth, s.feeAddress)
api.POST("/ticketstatus", s.withDcrdClient(dcrd), s.vspAuth, s.ticketStatus) api.POST("/ticketstatus", s.withDcrdClient(dcrd), s.vspAuth, s.ticketStatus)
api.POST("/payfee", s.withDcrdClient(dcrd), s.vspAuth, s.payFee) api.POST("/payfee", s.vspMustBeOpen, s.withDcrdClient(dcrd), s.vspAuth, s.payFee)
api.POST("/setvotechoices", s.withDcrdClient(dcrd), s.withWalletClients(wallets), s.vspAuth, s.setVoteChoices) api.POST("/setvotechoices", s.withDcrdClient(dcrd), s.withWalletClients(wallets), s.vspAuth, s.setVoteChoices)
// Website routes. // Website routes.