diff --git a/client/client_test.go b/client/client_test.go index 29e88e9..f7c7d27 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -2,6 +2,8 @@ package client import ( "context" + "crypto/ed25519" + "encoding/base64" "errors" "net/http" "net/http/httptest" @@ -103,3 +105,77 @@ func TestErrorDetails(t *testing.T) { }) } } + +// TestSignatureValidation ensures that responses with invalid signatures are +// flagged. +func TestSignatureValidation(t *testing.T) { + + // Generate some test data for the valid signature case. + privKey := ed25519.NewKeyFromSeed([]byte("00000000000000000000000000000000")) + pubKey, _ := privKey.Public().(ed25519.PublicKey) + emptyJSON := []byte("{}") + validSig := base64.StdEncoding.EncodeToString(ed25519.Sign(privKey, emptyJSON)) + + tests := map[string]struct { + responseSig string + expectErr bool + expectErrStr string + }{ + "valid signature": { + responseSig: validSig, + expectErr: false, + }, + "invalid signature": { + responseSig: "1234", + expectErr: true, + expectErrStr: "authenticate server response: invalid signature", + }, + "no signature": { + responseSig: "", + expectErr: true, + expectErrStr: "authenticate server response: no signature provided", + }, + "failed to decode signature": { + responseSig: "0xp", + expectErr: true, + expectErrStr: "authenticate server response: failed to decode signature: illegal base64 data at input byte 0", + }, + } + + for testName, testData := range tests { + t.Run(testName, func(t *testing.T) { + + testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + res.Header().Add("VSP-Server-Signature", testData.responseSig) + res.WriteHeader(http.StatusOK) + _, err := res.Write(emptyJSON) + if err != nil { + t.Fatalf("writing response body failed: %v", err) + } + })) + + client := Client{ + URL: testServer.URL, + PubKey: pubKey, + Log: slog.Disabled, + } + + var resp interface{} + err := client.do(context.TODO(), http.MethodGet, "", nil, &resp, nil) + + testServer.Close() + + if testData.expectErr { + if err.Error() != testData.expectErrStr { + t.Fatalf("client.do returned incorrect error, expected %q, got %q", + testData.expectErrStr, err.Error()) + } + } else { + if err != nil { + t.Fatalf("client.do returned unexpected error: %v", err) + } + } + + }) + } +}