Use constants for web context keys.

This commit is contained in:
jholdstock 2021-11-15 09:10:09 +00:00 committed by Jamie Holdstock
parent 6acc5be10c
commit b1a68a94fe
9 changed files with 59 additions and 45 deletions

View File

@ -47,11 +47,11 @@ type searchResult struct {
} }
func dcrdStatus(c *gin.Context) DcrdStatus { func dcrdStatus(c *gin.Context) DcrdStatus {
hostname := c.MustGet("DcrdHostname").(string) hostname := c.MustGet(dcrdHostKey).(string)
status := DcrdStatus{Host: hostname} status := DcrdStatus{Host: hostname}
dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC) dcrdClient := c.MustGet(dcrdKey).(*rpc.DcrdRPC)
dcrdErr := c.MustGet("DcrdClientErr") dcrdErr := c.MustGet(dcrdErrorKey)
if dcrdErr != nil { if dcrdErr != nil {
log.Errorf("could not get dcrd client: %v", dcrdErr.(error)) log.Errorf("could not get dcrd client: %v", dcrdErr.(error))
return status return status
@ -72,8 +72,8 @@ func dcrdStatus(c *gin.Context) DcrdStatus {
} }
func walletStatus(c *gin.Context) map[string]WalletStatus { func walletStatus(c *gin.Context) map[string]WalletStatus {
walletClients := c.MustGet("WalletClients").([]*rpc.WalletRPC) walletClients := c.MustGet(walletsKey).([]*rpc.WalletRPC)
failedWalletClients := c.MustGet("FailedWalletClients").([]string) failedWalletClients := c.MustGet(failedWalletsKey).([]string)
walletStatus := make(map[string]WalletStatus) walletStatus := make(map[string]WalletStatus)
for _, v := range walletClients { for _, v := range walletClients {
@ -236,7 +236,7 @@ func downloadDatabaseBackup(c *gin.Context) {
// setAdminStatus stores the authentication status of the current session and // setAdminStatus stores the authentication status of the current session and
// redirects the client to GET /admin. // redirects the client to GET /admin.
func setAdminStatus(admin interface{}, c *gin.Context) { func setAdminStatus(admin interface{}, c *gin.Context) {
session := c.MustGet("session").(*sessions.Session) session := c.MustGet(sessionKey).(*sessions.Session)
session.Values["admin"] = admin session.Values["admin"] = admin
err := session.Save(c.Request, c.Writer) err := session.Save(c.Request, c.Writer)
if err != nil { if err != nil {

View File

@ -72,17 +72,17 @@ func feeAddress(c *gin.Context) {
const funcName = "feeAddress" const funcName = "feeAddress"
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
ticket := c.MustGet("Ticket").(database.Ticket) ticket := c.MustGet(ticketKey).(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool) knownTicket := c.MustGet(knownTicketKey).(bool)
commitmentAddress := c.MustGet("CommitmentAddress").(string) commitmentAddress := c.MustGet(commitmentAddressKey).(string)
dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC) dcrdClient := c.MustGet(dcrdKey).(*rpc.DcrdRPC)
dcrdErr := c.MustGet("DcrdClientErr") dcrdErr := c.MustGet(dcrdErrorKey)
if dcrdErr != nil { if dcrdErr != nil {
log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error)) log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error))
sendError(errInternalError, c) sendError(errInternalError, c)
return return
} }
reqBytes := c.MustGet("RequestBytes").([]byte) reqBytes := c.MustGet(requestBytesKey).([]byte)
if cfg.VspClosed { if cfg.VspClosed {
sendError(errVspClosed, c) sendError(errVspClosed, c)

View File

@ -49,7 +49,7 @@ func withSession(store *sessions.CookieStore) gin.HandlerFunc {
} }
} }
c.Set("session", session) c.Set(sessionKey, session)
} }
} }
@ -57,7 +57,7 @@ func withSession(store *sessions.CookieStore) gin.HandlerFunc {
// authenticated as an admin, otherwise it will render the login template. // authenticated as an admin, otherwise it will render the login template.
func requireAdmin() gin.HandlerFunc { func requireAdmin() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
session := c.MustGet("session").(*sessions.Session) session := c.MustGet(sessionKey).(*sessions.Session)
admin := session.Values["admin"] admin := session.Values["admin"]
if admin == nil { if admin == nil {
@ -78,9 +78,9 @@ func withDcrdClient(dcrd rpc.DcrdConnect) gin.HandlerFunc {
client, hostname, err := dcrd.Client(c, cfg.NetParams) client, hostname, err := dcrd.Client(c, cfg.NetParams)
// Don't handle the error here, add it to the context and let downstream // Don't handle the error here, add it to the context and let downstream
// handlers decide what to do with it. // handlers decide what to do with it.
c.Set("DcrdClient", client) c.Set(dcrdKey, client)
c.Set("DcrdHostname", hostname) c.Set(dcrdHostKey, hostname)
c.Set("DcrdClientErr", err) c.Set(dcrdErrorKey, err)
} }
} }
@ -96,8 +96,8 @@ func withWalletClients(wallets rpc.WalletConnect) gin.HandlerFunc {
log.Errorf("Failed to connect to %d wallet(s), proceeding with only %d", log.Errorf("Failed to connect to %d wallet(s), proceeding with only %d",
len(failedConnections), len(clients)) len(failedConnections), len(clients))
} }
c.Set("WalletClients", clients) c.Set(walletsKey, clients)
c.Set("FailedWalletClients", failedConnections) c.Set(failedWalletsKey, failedConnections)
} }
} }
@ -176,8 +176,8 @@ func broadcastTicket() gin.HandlerFunc {
parentHash := parentTx.TxHash() parentHash := parentTx.TxHash()
// Check if local dcrd already knows the parent tx. // Check if local dcrd already knows the parent tx.
dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC) dcrdClient := c.MustGet(dcrdKey).(*rpc.DcrdRPC)
dcrdErr := c.MustGet("DcrdClientErr") dcrdErr := c.MustGet(dcrdErrorKey)
if dcrdErr != nil { if dcrdErr != nil {
log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error)) log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error))
sendError(errInternalError, c) sendError(errInternalError, c)
@ -273,7 +273,7 @@ func vspAuth() gin.HandlerFunc {
// Add request bytes to request context for downstream handlers to reuse. // Add request bytes to request context for downstream handlers to reuse.
// Necessary because the request body reader can only be used once. // Necessary because the request body reader can only be used once.
c.Set("RequestBytes", reqBytes) c.Set(requestBytesKey, reqBytes)
// Parse request and ensure there is a ticket hash included. // Parse request and ensure there is a ticket hash included.
var request struct { var request struct {
@ -316,8 +316,8 @@ func vspAuth() gin.HandlerFunc {
if ticketFound { if ticketFound {
commitmentAddress = ticket.CommitmentAddress commitmentAddress = ticket.CommitmentAddress
} else { } else {
dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC) dcrdClient := c.MustGet(dcrdKey).(*rpc.DcrdRPC)
dcrdErr := c.MustGet("DcrdClientErr") dcrdErr := c.MustGet(dcrdErrorKey)
if dcrdErr != nil { if dcrdErr != nil {
log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error)) log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error))
sendError(errInternalError, c) sendError(errInternalError, c)
@ -378,9 +378,9 @@ func vspAuth() gin.HandlerFunc {
// Add ticket information to context so downstream handlers don't need // Add ticket information to context so downstream handlers don't need
// to access the db for it. // to access the db for it.
c.Set("Ticket", ticket) c.Set(ticketKey, ticket)
c.Set("KnownTicket", ticketFound) c.Set(knownTicketKey, ticketFound)
c.Set("CommitmentAddress", commitmentAddress) c.Set(commitmentAddressKey, commitmentAddress)
} }
} }

View File

@ -24,16 +24,16 @@ func payFee(c *gin.Context) {
const funcName = "payFee" const funcName = "payFee"
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
ticket := c.MustGet("Ticket").(database.Ticket) ticket := c.MustGet(ticketKey).(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool) knownTicket := c.MustGet(knownTicketKey).(bool)
dcrdClient := c.MustGet("DcrdClient").(*rpc.DcrdRPC) dcrdClient := c.MustGet(dcrdKey).(*rpc.DcrdRPC)
dcrdErr := c.MustGet("DcrdClientErr") dcrdErr := c.MustGet(dcrdErrorKey)
if dcrdErr != nil { if dcrdErr != nil {
log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error)) log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error))
sendError(errInternalError, c) sendError(errInternalError, c)
return return
} }
reqBytes := c.MustGet("RequestBytes").([]byte) reqBytes := c.MustGet(requestBytesKey).([]byte)
if cfg.VspClosed { if cfg.VspClosed {
sendError(errVspClosed, c) sendError(errVspClosed, c)

View File

@ -31,14 +31,14 @@ func setAltSig(c *gin.Context) {
const funcName = "setAltSig" const funcName = "setAltSig"
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
dcrdClient := c.MustGet("DcrdClient").(Node) dcrdClient := c.MustGet(dcrdKey).(Node)
dcrdErr := c.MustGet("DcrdClientErr") dcrdErr := c.MustGet(dcrdErrorKey)
if dcrdErr != nil { if dcrdErr != nil {
log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error)) log.Errorf("%s: could not get dcrd client: %v", funcName, dcrdErr.(error))
sendError(errInternalError, c) sendError(errInternalError, c)
return return
} }
reqBytes := c.MustGet("RequestBytes").([]byte) reqBytes := c.MustGet(requestBytesKey).([]byte)
if cfg.VspClosed { if cfg.VspClosed {
sendError(errVspClosed, c) sendError(errVspClosed, c)

View File

@ -220,9 +220,9 @@ func TestSetAltSig(t *testing.T) {
} }
handle := func(c *gin.Context) { handle := func(c *gin.Context) {
c.Set("DcrdClient", tNode) c.Set(dcrdKey, tNode)
c.Set("DcrdClientErr", dcrdErr) c.Set(dcrdErrorKey, dcrdErr)
c.Set("RequestBytes", b[test.deformReq:]) c.Set(requestBytesKey, b[test.deformReq:])
setAltSig(c) setAltSig(c)
} }

View File

@ -20,10 +20,10 @@ func setVoteChoices(c *gin.Context) {
const funcName = "setVoteChoices" const funcName = "setVoteChoices"
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
ticket := c.MustGet("Ticket").(database.Ticket) ticket := c.MustGet(ticketKey).(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool) knownTicket := c.MustGet(knownTicketKey).(bool)
walletClients := c.MustGet("WalletClients").([]*rpc.WalletRPC) walletClients := c.MustGet(walletsKey).([]*rpc.WalletRPC)
reqBytes := c.MustGet("RequestBytes").([]byte) reqBytes := c.MustGet(requestBytesKey).([]byte)
// If we cannot set the vote choices on at least one voting wallet right // If we cannot set the vote choices on at least one voting wallet right
// now, don't update the database, just return an error. // now, don't update the database, just return an error.

View File

@ -17,9 +17,9 @@ func ticketStatus(c *gin.Context) {
const funcName = "ticketStatus" const funcName = "ticketStatus"
// Get values which have been added to context by middleware. // Get values which have been added to context by middleware.
ticket := c.MustGet("Ticket").(database.Ticket) ticket := c.MustGet(ticketKey).(database.Ticket)
knownTicket := c.MustGet("KnownTicket").(bool) knownTicket := c.MustGet(knownTicketKey).(bool)
reqBytes := c.MustGet("RequestBytes").([]byte) reqBytes := c.MustGet(requestBytesKey).([]byte)
if !knownTicket { if !knownTicket {
log.Warnf("%s: Unknown ticket (clientIP=%s)", funcName, c.ClientIP()) log.Warnf("%s: Unknown ticket (clientIP=%s)", funcName, c.ClientIP())

View File

@ -49,6 +49,20 @@ const (
feeAddressExpiration = 1 * time.Hour feeAddressExpiration = 1 * time.Hour
) )
// Hard-coded keys used for storing values in the web context.
const (
sessionKey = "Session"
dcrdKey = "DcrdClient"
dcrdHostKey = "DcrdHostname"
dcrdErrorKey = "DcrdClientErr"
walletsKey = "WalletClients"
failedWalletsKey = "FailedWalletClients"
requestBytesKey = "RequestBytes"
ticketKey = "Ticket"
knownTicketKey = "KnownTicket"
commitmentAddressKey = "CommitmentAddress"
)
var cfg Config var cfg Config
var db *database.VspDatabase var db *database.VspDatabase
var addrGen *addressGenerator var addrGen *addressGenerator