client: Add tests for server signature validation.

New tests to ensure that the correct error is returned when the server signature is invalid.
This commit is contained in:
jholdstock 2023-03-01 09:24:53 +00:00 committed by Jamie Holdstock
parent 79fd2f7bbe
commit 8303aa8536

View File

@ -2,6 +2,8 @@ package client
import (
"context"
"crypto/ed25519"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
@ -103,3 +105,77 @@ func TestErrorDetails(t *testing.T) {
})
}
}
// TestSignatureValidation ensures that responses with invalid signatures are
// flagged.
func TestSignatureValidation(t *testing.T) {
// Generate some test data for the valid signature case.
privKey := ed25519.NewKeyFromSeed([]byte("00000000000000000000000000000000"))
pubKey, _ := privKey.Public().(ed25519.PublicKey)
emptyJSON := []byte("{}")
validSig := base64.StdEncoding.EncodeToString(ed25519.Sign(privKey, emptyJSON))
tests := map[string]struct {
responseSig string
expectErr bool
expectErrStr string
}{
"valid signature": {
responseSig: validSig,
expectErr: false,
},
"invalid signature": {
responseSig: "1234",
expectErr: true,
expectErrStr: "authenticate server response: invalid signature",
},
"no signature": {
responseSig: "",
expectErr: true,
expectErrStr: "authenticate server response: no signature provided",
},
"failed to decode signature": {
responseSig: "0xp",
expectErr: true,
expectErrStr: "authenticate server response: failed to decode signature: illegal base64 data at input byte 0",
},
}
for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
res.Header().Add("VSP-Server-Signature", testData.responseSig)
res.WriteHeader(http.StatusOK)
_, err := res.Write(emptyJSON)
if err != nil {
t.Fatalf("writing response body failed: %v", err)
}
}))
client := Client{
URL: testServer.URL,
PubKey: pubKey,
Log: slog.Disabled,
}
var resp interface{}
err := client.do(context.TODO(), http.MethodGet, "", nil, &resp, nil)
testServer.Close()
if testData.expectErr {
if err.Error() != testData.expectErrStr {
t.Fatalf("client.do returned incorrect error, expected %q, got %q",
testData.expectErrStr, err.Error())
}
} else {
if err != nil {
t.Fatalf("client.do returned unexpected error: %v", err)
}
}
})
}
}