package http

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"net/textproto"
	"net/url"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/go-test/deep"
	"github.com/hashicorp/go-cleanhttp"
	kv "github.com/hashicorp/vault-plugin-secrets-kv"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/audit"
	auditFile "github.com/hashicorp/vault/builtin/audit/file"
	"github.com/hashicorp/vault/helper/namespace"
	"github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/vault"
)

// TestHandler_parseMFAHandler exercises parseMFAHeader. It checks that
// "method:value" entries in the MFA header are grouped per method name, that a
// bare method name yields an empty credential list, that credentials for one
// method spread over several header values are merged, and that entries with
// a missing method name or missing value are rejected.
func TestHandler_parseMFAHandler(t *testing.T) {
	var err error
	var expectedMFACreds logical.MFACreds

	req := &logical.Request{
		Headers: make(map[string][]string),
	}

	headerName := textproto.CanonicalMIMEHeaderKey(MFAHeaderName)

	// Set TOTP passcode in the MFA header
	req.Headers[headerName] = []string{
		"my_totp:123456",
		"my_totp:111111",
		"my_second_mfa:hi=hello",
		"my_third_mfa",
	}
	err = parseMFAHeader(req)
	if err != nil {
		t.Fatal(err)
	}

	// Verify that it is being parsed properly
	expectedMFACreds = logical.MFACreds{
		"my_totp": []string{
			"123456",
			"111111",
		},
		"my_second_mfa": []string{
			"hi=hello",
		},
		"my_third_mfa": []string{},
	}
	if !reflect.DeepEqual(expectedMFACreds, req.MFACreds) {
		t.Fatalf("bad: parsed MFACreds; expected: %#v\n actual: %#v\n", expectedMFACreds, req.MFACreds)
	}

	// Split the creds of a method type in different headers and check if they
	// all get merged together
	req.Headers[headerName] = []string{
		"my_mfa:passcode=123456",
		"my_mfa:month=july",
		"my_mfa:day=tuesday",
	}
	err = parseMFAHeader(req)
	if err != nil {
		t.Fatal(err)
	}

	expectedMFACreds = logical.MFACreds{
		"my_mfa": []string{
			"passcode=123456",
			"month=july",
			"day=tuesday",
		},
	}
	if !reflect.DeepEqual(expectedMFACreds, req.MFACreds) {
		t.Fatalf("bad: parsed MFACreds; expected: %#v\n actual: %#v\n", expectedMFACreds, req.MFACreds)
	}

	// Header without method name should error out
	req.Headers[headerName] = []string{
		":passcode=123456",
	}
	err = parseMFAHeader(req)
	if err == nil {
		t.Fatalf("expected an error; actual: %#v\n", req.MFACreds)
	}

	// Header without method name and method value should error out
	req.Headers[headerName] = []string{
		":",
	}
	err = parseMFAHeader(req)
	if err == nil {
		t.Fatalf("expected an error; actual: %#v\n", req.MFACreds)
	}

	// Header with a method name but without a method value should error out
	req.Headers[headerName] = []string{
		"my_totp:",
	}
	err = parseMFAHeader(req)
	if err == nil {
		t.Fatalf("expected an error; actual: %#v\n", req.MFACreds)
	}
}

// TestHandler_cors verifies the handler's CORS behaviour:
//   - requests from a disallowed origin are rejected with 403;
//   - preflight requests for an unknown method get 405;
//   - an allowed preflight response carries the expected Access-Control-* and
//     Vary headers.
func TestHandler_cors(t *testing.T) {
	core, _, _ := vault.TestCoreUnsealed(t)
	ln, addr := TestServer(t, core)
	defer ln.Close()

	// Enable CORS and allow only the test server's own address as an origin.
	corsConfig := core.CORSConfig()
	err := corsConfig.Enable(context.Background(), []string{addr}, nil)
	if err != nil {
		t.Fatalf("Error enabling CORS: %s", err)
	}

	req, err := http.NewRequest(http.MethodOptions, addr+"/v1/sys/seal-status", nil)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	req.Header.Set("Origin", "BAD ORIGIN")

	// Requests from unacceptable origins will be rejected with a 403.
	client := cleanhttp.DefaultClient()
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	// The deferred call binds this response's body; later reassignments of
	// resp get their own deferred Close below.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Fatalf("Bad status:\nexpected: 403 Forbidden\nactual: %s", resp.Status)
	}

	//
	// Test preflight requests
	//

	// Set a valid origin
	req.Header.Set("Origin", addr)

	// Server should NOT accept arbitrary methods.
	req.Header.Set("Access-Control-Request-Method", "FOO")

	client = cleanhttp.DefaultClient()
	resp, err = client.Do(req)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer resp.Body.Close()

	// Fail if an arbitrary method is accepted.
	if resp.StatusCode != http.StatusMethodNotAllowed {
		t.Fatalf("Bad status:\nexpected: 405 Method Not Allowed\nactual: %s", resp.Status)
	}

	// Server SHOULD accept acceptable methods.
	req.Header.Set("Access-Control-Request-Method", http.MethodPost)

	client = cleanhttp.DefaultClient()
	resp, err = client.Do(req)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer resp.Body.Close()

	//
	// Test that the CORS headers are applied correctly.
	//
	expHeaders := map[string]string{
		"Access-Control-Allow-Origin":  addr,
		"Access-Control-Allow-Headers": strings.Join(vault.StdAllowedHeaders, ","),
		"Access-Control-Max-Age":       "300",
		"Vary":                         "Origin",
	}

	for expHeader, expected := range expHeaders {
		actual := resp.Header.Get(expHeader)
		if actual == "" {
			t.Fatalf("bad:\nHeader: %#v was not on response.", expHeader)
		}

		if actual != expected {
			t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n", expected, actual)
		}
	}
}

// TestHandler_HostnameHeader checks that the X-Vault-Hostname response header
// is present only when EnableResponseHeaderHostname is set in the core config.
// It also checks that no X-Vault-Raft-Node-ID header is returned, since these
// cores do not run a raft cluster.
func TestHandler_HostnameHeader(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		description   string
		config        *vault.CoreConfig
		headerPresent bool
	}{
		{
			description:   "with no header configured",
			config:        nil,
			headerPresent: false,
		},
		{
			description: "with header configured",
			config: &vault.CoreConfig{
				EnableResponseHeaderHostname: true,
			},
			headerPresent: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			var core *vault.Core

			if tc.config == nil {
				core, _, _ = vault.TestCoreUnsealed(t)
			} else {
				core, _, _ = vault.TestCoreUnsealedWithConfig(t, tc.config)
			}

			ln, addr := TestServer(t, core)
			defer ln.Close()

			req, err := http.NewRequest("GET", addr+"/v1/sys/seal-status", nil)
			if err != nil {
				t.Fatalf("err: %s", err)
			}

			client := cleanhttp.DefaultClient()
			resp, err := client.Do(req)
			if err != nil {
				t.Fatalf("err: %s", err)
			}

			if resp == nil {
				t.Fatal("nil response")
			}
			defer resp.Body.Close()

			hnHeader := resp.Header.Get("X-Vault-Hostname")
			if tc.headerPresent && hnHeader == "" {
				t.Logf("header configured = %t", core.HostnameHeaderEnabled())
				t.Fatal("missing 'X-Vault-Hostname' header entry in response")
			}
			if !tc.headerPresent && hnHeader != "" {
				t.Fatal("didn't expect 'X-Vault-Hostname' header but it was present anyway")
			}

			rniHeader := resp.Header.Get("X-Vault-Raft-Node-ID")
			if rniHeader != "" {
				t.Fatalf("no raft node ID header was expected, since we're not running a raft cluster. instead, got %s", rniHeader)
			}
		})
	}
}

// TestHandler_CacheControlNoStore sends an authenticated GET to /v1/sys/mounts
// with a 60s wrap TTL header. It asserts that the response carries
// "Cache-Control: no-store".
func TestHandler_CacheControlNoStore(t *testing.T) {
	core, _, token := vault.TestCoreUnsealed(t)
	ln, addr := TestServer(t, core)
	defer ln.Close()

	req, err := http.NewRequest("GET", addr+"/v1/sys/mounts", nil)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	req.Header.Set(consts.AuthHeaderName, token)
	req.Header.Set(WrapTTLHeaderName, "60s")

	client := cleanhttp.DefaultClient()
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	if resp == nil {
		t.Fatalf("nil response")
	}
	defer resp.Body.Close()

	actual := resp.Header.Get("Cache-Control")

	if actual == "" {
		t.Fatalf("missing 'Cache-Control' header entry in response writer")
	}

	if actual != "no-store" {
		t.Fatalf("bad: Cache-Control. Expected: 'no-store', Actual: %q", actual)
	}
}

// TestHandler_InFlightRequest queries /v1/sys/in-flight-req. It expects at
// least one entry, and every entry must report the query's own path as its
// request_path.
func TestHandler_InFlightRequest(t *testing.T) {
	core, _, token := vault.TestCoreUnsealed(t)
	ln, addr := TestServer(t, core)
	defer ln.Close()
	TestServerAuth(t, addr, token)

	req, err := http.NewRequest("GET", addr+"/v1/sys/in-flight-req", nil)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	req.Header.Set(consts.AuthHeaderName, token)

	client := cleanhttp.DefaultClient()
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	if resp == nil {
		t.Fatalf("nil response")
	}
	defer resp.Body.Close()

	var actual map[string]interface{}
	testResponseStatus(t, resp, 200)
	testResponseBody(t, resp, &actual)
	// len of a nil map is 0, so this also covers the nil case.
	if len(actual) == 0 {
		t.Fatal("expected to get at least one in-flight request, got nil or zero length map")
	}
	for _, v := range actual {
		reqInfo, ok := v.(map[string]interface{})
		if !ok {
			t.Fatal("failed to read in-flight request")
		}
		if reqInfo["request_path"] != "/v1/sys/in-flight-req" {
			// Report the offending entry's path, not a lookup on the outer map.
			t.Fatalf("expected /v1/sys/in-flight-req in-flight request path, got %v", reqInfo["request_path"])
		}
	}
}

// TestHandler_MissingToken tests the response / error code if a request comes
// in with a missing client token.
See // https://github.com/hashicorp/vault/issues/8377 func TestHandler_MissingToken(t *testing.T) { // core, _, token := vault.TestCoreUnsealed(t) core, _, _ := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) defer ln.Close() req, err := http.NewRequest("GET", addr+"/v1/sys/internal/ui/mounts/cubbyhole", nil) if err != nil { t.Fatalf("err: %s", err) } req.Header.Set(WrapTTLHeaderName, "60s") client := cleanhttp.DefaultClient() resp, err := client.Do(req) if err != nil { t.Fatal(err) } if resp.StatusCode != 403 { t.Fatalf("expected code 403, got: %d", resp.StatusCode) } } func TestHandler_Accepted(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) defer ln.Close() req, err := http.NewRequest("POST", addr+"/v1/auth/token/tidy", nil) if err != nil { t.Fatalf("err: %s", err) } req.Header.Set(consts.AuthHeaderName, token) client := cleanhttp.DefaultClient() resp, err := client.Do(req) if err != nil { t.Fatalf("err: %s", err) } testResponseStatus(t, resp, 202) } // We use this test to verify header auth func TestSysMounts_headerAuth(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) defer ln.Close() req, err := http.NewRequest("GET", addr+"/v1/sys/mounts", nil) if err != nil { t.Fatalf("err: %s", err) } req.Header.Set(consts.AuthHeaderName, token) client := cleanhttp.DefaultClient() resp, err := client.Do(req) if err != nil { t.Fatalf("err: %s", err) } var actual map[string]interface{} expected := map[string]interface{}{ "lease_id": "", "renewable": false, "lease_duration": json.Number("0"), "wrap_info": nil, "warnings": nil, "auth": nil, "data": map[string]interface{}{ "secret/": map[string]interface{}{ "description": "key/value secret storage", "type": "kv", "external_entropy_access": false, "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, }, "local": false, "seal_wrap": false, "options": 
map[string]interface{}{"version": "1"}, }, "sys/": map[string]interface{}{ "description": "system endpoints used for control, policy and debugging", "type": "system", "external_entropy_access": false, "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, "passthrough_request_headers": []interface{}{"Accept"}, }, "local": false, "seal_wrap": false, "options": interface{}(nil), }, "cubbyhole/": map[string]interface{}{ "description": "per-token private secret storage", "type": "cubbyhole", "external_entropy_access": false, "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, }, "local": true, "seal_wrap": false, "options": interface{}(nil), }, "identity/": map[string]interface{}{ "description": "identity store", "type": "identity", "external_entropy_access": false, "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, "passthrough_request_headers": []interface{}{"Authorization"}, }, "local": false, "seal_wrap": false, "options": interface{}(nil), }, }, "secret/": map[string]interface{}{ "description": "key/value secret storage", "type": "kv", "external_entropy_access": false, "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, }, "local": false, "seal_wrap": false, "options": map[string]interface{}{"version": "1"}, }, "sys/": map[string]interface{}{ "description": "system endpoints used for control, policy and debugging", "type": "system", "external_entropy_access": false, "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, "passthrough_request_headers": []interface{}{"Accept"}, }, "local": false, "seal_wrap": false, "options": interface{}(nil), }, "cubbyhole/": 
map[string]interface{}{ "description": "per-token private secret storage", "type": "cubbyhole", "external_entropy_access": false, "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, }, "local": true, "seal_wrap": false, "options": interface{}(nil), }, "identity/": map[string]interface{}{ "description": "identity store", "type": "identity", "external_entropy_access": false, "config": map[string]interface{}{ "default_lease_ttl": json.Number("0"), "max_lease_ttl": json.Number("0"), "force_no_cache": false, "passthrough_request_headers": []interface{}{"Authorization"}, }, "local": false, "seal_wrap": false, "options": interface{}(nil), }, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] for k, v := range actual["data"].(map[string]interface{}) { if v.(map[string]interface{})["accessor"] == "" { t.Fatalf("no accessor from %s", k) } if v.(map[string]interface{})["uuid"] == "" { t.Fatalf("no uuid from %s", k) } expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] expected[k].(map[string]interface{})["uuid"] = v.(map[string]interface{})["uuid"] expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] expected["data"].(map[string]interface{})[k].(map[string]interface{})["uuid"] = v.(map[string]interface{})["uuid"] } if diff := deep.Equal(actual, expected); len(diff) > 0 { t.Fatalf("bad, diff: %#v", diff) } } // We use this test to verify header auth wrapping func TestSysMounts_headerAuth_Wrapped(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) defer ln.Close() req, err := http.NewRequest("GET", addr+"/v1/sys/mounts", nil) if err != nil { t.Fatalf("err: %s", err) } req.Header.Set(consts.AuthHeaderName, token) req.Header.Set(WrapTTLHeaderName, "60s") client := cleanhttp.DefaultClient() 
resp, err := client.Do(req) if err != nil { t.Fatalf("err: %s", err) } var actual map[string]interface{} expected := map[string]interface{}{ "request_id": "", "lease_id": "", "renewable": false, "lease_duration": json.Number("0"), "data": nil, "wrap_info": map[string]interface{}{ "ttl": json.Number("60"), }, "warnings": nil, "auth": nil, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) actualToken, ok := actual["wrap_info"].(map[string]interface{})["token"] if !ok || actualToken == "" { t.Fatal("token missing in wrap info") } expected["wrap_info"].(map[string]interface{})["token"] = actualToken actualCreationTime, ok := actual["wrap_info"].(map[string]interface{})["creation_time"] if !ok || actualCreationTime == "" { t.Fatal("creation_time missing in wrap info") } expected["wrap_info"].(map[string]interface{})["creation_time"] = actualCreationTime actualCreationPath, ok := actual["wrap_info"].(map[string]interface{})["creation_path"] if !ok || actualCreationPath == "" { t.Fatal("creation_path missing in wrap info") } expected["wrap_info"].(map[string]interface{})["creation_path"] = actualCreationPath actualAccessor, ok := actual["wrap_info"].(map[string]interface{})["accessor"] if !ok || actualAccessor == "" { t.Fatal("accessor missing in wrap info") } expected["wrap_info"].(map[string]interface{})["accessor"] = actualAccessor if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n%T %T", expected, actual, actual["warnings"], actual["data"]) } } func TestHandler_sealed(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) defer ln.Close() core.Seal(token) resp, err := http.Get(addr + "/v1/secret/foo") if err != nil { t.Fatalf("err: %s", err) } testResponseStatus(t, resp, 503) } func TestHandler_ui_default(t *testing.T) { core := vault.TestCoreUI(t, false) ln, addr := TestServer(t, core) defer ln.Close() resp, err := http.Get(addr + "/ui/") if err != nil { t.Fatalf("err: %s", 
err) } testResponseStatus(t, resp, 404) } func TestHandler_ui_enabled(t *testing.T) { core := vault.TestCoreUI(t, true) ln, addr := TestServer(t, core) defer ln.Close() resp, err := http.Get(addr + "/ui/") if err != nil { t.Fatalf("err: %s", err) } testResponseStatus(t, resp, 200) } func TestHandler_error(t *testing.T) { w := httptest.NewRecorder() respondError(w, 500, errors.New("test Error")) if w.Code != 500 { t.Fatalf("expected 500, got %d", w.Code) } // The code inside of the error should override // the argument to respondError w2 := httptest.NewRecorder() e := logical.CodedError(403, "error text") respondError(w2, 500, e) if w2.Code != 403 { t.Fatalf("expected 403, got %d", w2.Code) } // vault.ErrSealed is a special case w3 := httptest.NewRecorder() respondError(w3, 400, consts.ErrSealed) if w3.Code != 503 { t.Fatalf("expected 503, got %d", w3.Code) } } func TestHandler_requestAuth(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) rootCtx := namespace.RootContext(nil) te, err := core.LookupToken(rootCtx, token) if err != nil { t.Fatalf("err: %s", err) } rWithAuthorization, err := http.NewRequest("GET", "v1/test/path", nil) if err != nil { t.Fatalf("err: %s", err) } rWithAuthorization.Header.Set("Authorization", "Bearer "+token) rWithVault, err := http.NewRequest("GET", "v1/test/path", nil) if err != nil { t.Fatalf("err: %s", err) } rWithVault.Header.Set(consts.AuthHeaderName, token) for _, r := range []*http.Request{rWithVault, rWithAuthorization} { req := logical.TestRequest(t, logical.ReadOperation, "test/path") r = r.WithContext(rootCtx) requestAuth(r, req) err = core.PopulateTokenEntry(rootCtx, req) if err != nil { t.Fatalf("err: %s", err) } if req.ClientToken != token { t.Fatalf("client token should be filled with %s, got %s", token, req.ClientToken) } if req.TokenEntry() == nil { t.Fatal("token entry should not be nil") } if !reflect.DeepEqual(req.TokenEntry(), te) { t.Fatalf("token entry should be the same as the core") } if 
req.ClientTokenAccessor == "" { t.Fatal("token accessor should not be empty") } } rNothing, err := http.NewRequest("GET", "v1/test/path", nil) if err != nil { t.Fatalf("err: %s", err) } req := logical.TestRequest(t, logical.ReadOperation, "test/path") requestAuth(rNothing, req) err = core.PopulateTokenEntry(rootCtx, req) if err != nil { t.Fatalf("expected no error, got %s", err) } if req.ClientToken != "" { t.Fatalf("client token should not be filled, got %s", req.ClientToken) } } func TestHandler_getTokenFromReq(t *testing.T) { r := http.Request{Header: http.Header{}} tok, _ := getTokenFromReq(&r) if tok != "" { t.Fatalf("expected '' as result, got '%s'", tok) } r.Header.Set("Authorization", "Bearer TOKEN NOT_GOOD_TOKEN") token, fromHeader := getTokenFromReq(&r) if !fromHeader { t.Fatal("expected from header") } else if token != "TOKEN NOT_GOOD_TOKEN" { t.Fatal("did not get expected token value") } else if r.Header.Get("Authorization") == "" { t.Fatal("expected value to be passed through") } r.Header.Set(consts.AuthHeaderName, "NEWTOKEN") tok, _ = getTokenFromReq(&r) if tok == "TOKEN" { t.Fatalf("%s header should be prioritized", consts.AuthHeaderName) } else if tok != "NEWTOKEN" { t.Fatalf("expected 'NEWTOKEN' as result, got '%s'", tok) } r.Header = http.Header{} r.Header.Set("Authorization", "Basic TOKEN") tok, fromHeader = getTokenFromReq(&r) if tok != "" { t.Fatalf("expected '' as result, got '%s'", tok) } else if fromHeader { t.Fatal("expected not from header") } } func TestHandler_nonPrintableChars(t *testing.T) { testNonPrintable(t, false) testNonPrintable(t, true) } func testNonPrintable(t *testing.T, disable bool) { core, _, token := vault.TestCoreUnsealedWithConfig(t, &vault.CoreConfig{ DisableKeyEncodingChecks: disable, }) ln, addr := TestListener(t) props := &vault.HandlerProperties{ Core: core, DisablePrintableCheck: disable, } TestServerWithListenerAndProperties(t, ln, addr, core, props) defer ln.Close() req, err := http.NewRequest("PUT", 
addr+"/v1/cubbyhole/foo\u2028bar", strings.NewReader(`{"zip": "zap"}`)) if err != nil { t.Fatalf("err: %s", err) } req.Header.Set(consts.AuthHeaderName, token) client := cleanhttp.DefaultClient() resp, err := client.Do(req) if err != nil { t.Fatalf("err: %s", err) } if disable { testResponseStatus(t, resp, 204) } else { testResponseStatus(t, resp, 400) } } func TestHandler_Parse_Form(t *testing.T) { cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() defer cluster.Cleanup() cores := cluster.Cores core := cores[0].Core vault.TestWaitActive(t, core) c := cleanhttp.DefaultClient() c.Transport = &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: cluster.RootCAs, }, } values := url.Values{ "zip": []string{"zap"}, "abc": []string{"xyz"}, "multi": []string{"first", "second"}, "empty": []string{}, } req, err := http.NewRequest("POST", cores[0].Client.Address()+"/v1/secret/foo", nil) if err != nil { t.Fatal(err) } req.Body = ioutil.NopCloser(strings.NewReader(values.Encode())) req.Header.Set("x-vault-token", cluster.RootToken) req.Header.Set("content-type", "application/x-www-form-urlencoded") resp, err := c.Do(req) if err != nil { t.Fatal(err) } if resp.StatusCode != 204 { t.Fatalf("bad response: %#v\nrequest was: %#v\nurl was: %#v", *resp, *req, req.URL) } client := cores[0].Client client.SetToken(cluster.RootToken) apiResp, err := client.Logical().Read("secret/foo") if err != nil { t.Fatal(err) } if apiResp == nil { t.Fatal("api resp is nil") } expected := map[string]interface{}{ "zip": "zap", "abc": "xyz", "multi": "first,second", } if diff := deep.Equal(expected, apiResp.Data); diff != nil { t.Fatal(diff) } } func TestHandler_Patch_BadContentTypeHeader(t *testing.T) { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "kv": kv.VersionedKVFactory, }, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() 
defer cluster.Cleanup() cores := cluster.Cores core := cores[0].Core c := cluster.Cores[0].Client vault.TestWaitActive(t, core) // Mount a KVv2 backend err := c.Sys().Mount("kv", &api.MountInput{ Type: "kv-v2", }) if err != nil { t.Fatal(err) } kvData := map[string]interface{}{ "data": map[string]interface{}{ "bar": "a", }, } resp, err := c.Logical().Write("kv/data/foo", kvData) if err != nil { t.Fatalf("write failed - err :%#v, resp: %#v\n", err, resp) } resp, err = c.Logical().Read("kv/data/foo") if err != nil { t.Fatalf("read failed - err :%#v, resp: %#v\n", err, resp) } req := c.NewRequest("PATCH", "/v1/kv/data/foo") req.Headers = http.Header{ "Content-Type": []string{"application/json"}, } if err := req.SetJSONBody(kvData); err != nil { t.Fatal(err) } apiResp, err := c.RawRequestWithContext(context.Background(), req) if err == nil || apiResp.StatusCode != http.StatusUnsupportedMediaType { t.Fatalf("expected PATCH request to fail with %d status code - err :%#v, resp: %#v\n", http.StatusUnsupportedMediaType, err, apiResp) } } func kvRequestWithRetry(t *testing.T, req func() (*api.Secret, error)) (*api.Secret, error) { t.Helper() var err error var resp *api.Secret // Loop until return message does not indicate upgrade, or timeout. 
timeout := time.After(20 * time.Second) ticker := time.Tick(time.Second) for { select { case <-timeout: t.Error("timeout expired waiting for upgrade") case <-ticker: resp, err = req() if err == nil { return resp, nil } responseError := err.(*api.ResponseError) if !strings.Contains(responseError.Error(), "Upgrading from non-versioned to versioned data") { return resp, err } } } } func TestHandler_Patch_Audit(t *testing.T) { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "kv": kv.VersionedKVFactory, }, AuditBackends: map[string]audit.Factory{ "file": auditFile.Factory, }, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() defer cluster.Cleanup() cores := cluster.Cores core := cores[0].Core c := cluster.Cores[0].Client vault.TestWaitActive(t, core) if err := c.Sys().Mount("kv/", &api.MountInput{ Type: "kv-v2", }); err != nil { t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) } auditLogFile, err := ioutil.TempFile("", "httppatch") if err != nil { t.Fatal(err) } err = c.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ Type: "file", Options: map[string]string{ "file_path": auditLogFile.Name(), }, }) if err != nil { t.Fatal(err) } writeData := map[string]interface{}{ "data": map[string]interface{}{ "bar": "a", }, } resp, err := kvRequestWithRetry(t, func() (*api.Secret, error) { return c.Logical().Write("kv/data/foo", writeData) }) if err != nil { t.Fatalf("write request failed, err: %#v, resp: %#v\n", err, resp) } patchData := map[string]interface{}{ "data": map[string]interface{}{ "baz": "b", }, } resp, err = kvRequestWithRetry(t, func() (*api.Secret, error) { return c.Logical().JSONMergePatch(context.Background(), "kv/data/foo", patchData) }) if err != nil { t.Fatalf("patch request failed, err: %#v, resp: %#v\n", err, resp) } patchRequestLogCount := 0 patchResponseLogCount := 0 decoder := json.NewDecoder(auditLogFile) var auditRecord 
map[string]interface{} for decoder.Decode(&auditRecord) == nil { auditRequest := map[string]interface{}{} if req, ok := auditRecord["request"]; ok { auditRequest = req.(map[string]interface{}) } if auditRequest["operation"] == "patch" && auditRecord["type"] == "request" { patchRequestLogCount += 1 } else if auditRequest["operation"] == "patch" && auditRecord["type"] == "response" { patchResponseLogCount += 1 } } if patchRequestLogCount != 1 { t.Fatalf("expected 1 patch request audit log record, saw %d\n", patchRequestLogCount) } if patchResponseLogCount != 1 { t.Fatalf("expected 1 patch response audit log record, saw %d\n", patchResponseLogCount) } }