diff --git a/changelog/11324.txt b/changelog/11324.txt new file mode 100644 index 000000000..e638d26f0 --- /dev/null +++ b/changelog/11324.txt @@ -0,0 +1,3 @@ +```release-note:enhancement +core: Add per-listener config to allow unauthenticated pprof requests, and collect a few more pprof targets. +``` diff --git a/command/debug.go b/command/debug.go index 03c3a7e2e..14b8cbd00 100644 --- a/command/debug.go +++ b/command/debug.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "net/url" "os" "path/filepath" "strconv" @@ -581,7 +582,10 @@ func (c *DebugCommand) capturePollingTargets() error { if strutil.StrListContains(c.flagTargets, "log") { g.Add(func() error { - _ = c.writeLogs(ctx) + c.writeLogs(ctx) + // If writeLogs returned earlier due to an error, wait for context + // to terminate so we don't abort everything. + <-ctx.Done() return nil }, func(error) { cancelFunc() @@ -670,16 +674,12 @@ func (c *DebugCommand) collectMetrics(ctx context.Context) { } // Check replication status. We skip on processing metrics if we're one - // of the following (since the request will be forwarded): - // 1. Any type of DR Node - // 2. Non-DR, non-performance standby nodes + // a DR node, though non-perf standbys will fail if they aren't using + // unauthenticated_metrics_access. switch { case healthStatus.ReplicationDRMode == "secondary": c.logger.Info("skipping metrics capture on DR secondary node") continue - case healthStatus.Standby && !healthStatus.PerformanceStandby: - c.logger.Info("skipping metrics on standby node") - continue } // Perform metrics request @@ -731,35 +731,37 @@ func (c *DebugCommand) collectPprof(ctx context.Context) { var wg sync.WaitGroup - // Capture goroutines + for _, target := range []string{"threadcreate", "allocs", "block", "mutex", "goroutine", "heap"} { + wg.Add(1) + go func(target string) { + defer wg.Done() + data, err := pprofTarget(ctx, c.cachedClient, target, nil) + if err != nil { + c.captureError("pprof."+target, err) + return + } + + err = ioutil.WriteFile(filepath.Join(dirName, target+".prof"), data, 0o644) + if err != nil { + c.captureError("pprof."+target, err) + } + }(target) + } + + // As a convenience, we'll also fetch the goroutine target using debug=2, which yields a text + // version of the stack traces that don't require using `go tool pprof` to view. wg.Add(1) go func() { defer wg.Done() - data, err := pprofGoroutine(ctx, c.cachedClient) + data, err := pprofTarget(ctx, c.cachedClient, "goroutine", url.Values{"debug": []string{"2"}}) if err != nil { - c.captureError("pprof.goroutine", err) + c.captureError("pprof.goroutines-text", err) return } - err = ioutil.WriteFile(filepath.Join(dirName, "goroutine.prof"), data, 0o644) + err = ioutil.WriteFile(filepath.Join(dirName, "goroutines.txt"), data, 0o644) if err != nil { - c.captureError("pprof.goroutine", err) - } - }() - - // Capture heap - wg.Add(1) - go func() { - defer wg.Done() - data, err := pprofHeap(ctx, c.cachedClient) - if err != nil { - c.captureError("pprof.heap", err) - return - } - - err = ioutil.WriteFile(filepath.Join(dirName, "heap.prof"), data, 0o644) - if err != nil { - c.captureError("pprof.heap", err) + c.captureError("pprof.goroutines-text", err) } }() @@ -911,24 +913,11 @@ func (c *DebugCommand) compress(dst string) error { return nil } -func pprofGoroutine(ctx context.Context, client *api.Client) ([]byte, error) { - req := client.NewRequest("GET", "/v1/sys/pprof/goroutine") - resp, err := client.RawRequestWithContext(ctx, req) - if err != nil { - return nil, err +func pprofTarget(ctx context.Context, client *api.Client, target string, params url.Values) ([]byte, error) { + req := client.NewRequest("GET", "/v1/sys/pprof/"+target) + if params != nil { + req.Params = params } - defer resp.Body.Close() - - data, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - return data, nil -} - -func pprofHeap(ctx context.Context, client *api.Client) ([]byte, error) { - req := client.NewRequest("GET", "/v1/sys/pprof/heap") resp, err := client.RawRequestWithContext(ctx, req) if err != nil { return nil, err @@ -994,16 +983,18 @@ func (c *DebugCommand) captureError(target string, err error) { c.errLock.Unlock() } -func (c *DebugCommand) writeLogs(ctx context.Context) error { +func (c *DebugCommand) writeLogs(ctx context.Context) { out, err := os.Create(filepath.Join(c.flagOutput, "vault.log")) if err != nil { - return err + c.captureError("log", err) + return } defer out.Close() logCh, err := c.cachedClient.Sys().Monitor(ctx, "trace") if err != nil { - return err + c.captureError("log", err) + return } for { @@ -1011,10 +1002,11 @@ func (c *DebugCommand) writeLogs(ctx context.Context) error { case log := <-logCh: _, err = out.WriteString(log) if err != nil { - return err + c.captureError("log", err) + return } case <-ctx.Done(): - return nil + return } } } diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index e51409022..ff2326c59 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -711,6 +711,12 @@ listener "tcp" { tls_max_version = "tls13" tls_require_and_verify_client_cert = true tls_disable_client_certs = true + telemetry { + unauthenticated_metrics_access = true + } + profiling { + unauthenticated_pprof_access = true + } }`)) config := Config{ @@ -742,6 +748,12 @@ listener "tcp" { TLSMaxVersion: "tls13", TLSRequireAndVerifyClientCert: true, TLSDisableClientCerts: true, + Telemetry: configutil.ListenerTelemetry{ + UnauthenticatedMetricsAccess: true, + }, + Profiling: configutil.ListenerProfiling{ + UnauthenticatedPProfAccess: true, + }, }, }, }, diff --git a/http/handler.go b/http/handler.go index dc52598ef..3002f6398 100644 --- a/http/handler.go +++ b/http/handler.go @@ -11,6 +11,7 @@ import ( "mime" "net" "net/http" + "net/http/pprof" "net/textproto" "net/url" "os" @@ -130,7 +131,6 @@ func Handler(props *vault.HandlerProperties) http.Handler { // Handle non-forwarded paths mux.Handle("/v1/sys/config/state/", handleLogicalNoForward(core)) mux.Handle("/v1/sys/host-info", handleLogicalNoForward(core)) - mux.Handle("/v1/sys/pprof/", handleLogicalNoForward(core)) mux.Handle("/v1/sys/init", handleSysInit(core)) mux.Handle("/v1/sys/seal-status", handleSysSealStatus(core)) @@ -177,6 +177,19 @@ func Handler(props *vault.HandlerProperties) http.Handler { mux.Handle("/v1/sys/metrics", handleLogicalNoForward(core)) } + if props.ListenerConfig != nil && props.ListenerConfig.Profiling.UnauthenticatedPProfAccess { + for _, name := range []string{"goroutine", "threadcreate", "heap", "allocs", "block", "mutex"} { + mux.Handle("/v1/sys/pprof/"+name, pprof.Handler(name)) + } + mux.Handle("/v1/sys/pprof/", http.HandlerFunc(pprof.Index)) + mux.Handle("/v1/sys/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + mux.Handle("/v1/sys/pprof/profile", http.HandlerFunc(pprof.Profile)) + mux.Handle("/v1/sys/pprof/symbol", http.HandlerFunc(pprof.Symbol)) + mux.Handle("/v1/sys/pprof/trace", http.HandlerFunc(pprof.Trace)) + } else { + mux.Handle("/v1/sys/pprof/", handleLogicalNoForward(core)) + } + additionalRoutes(mux, core) } diff --git a/http/sys_metrics_test.go b/http/sys_metrics_test.go index 492e6514b..6f1a7a62e 100644 --- a/http/sys_metrics_test.go +++ b/http/sys_metrics_test.go @@ -56,3 +56,41 @@ func TestSysMetricsUnauthenticated(t *testing.T) { resp = testHttpGet(t, "", addr+"/v1/sys/metrics?format=prometheus") testResponseStatus(t, resp, 200) } + +func TestSysPProfUnauthenticated(t *testing.T) { + conf := &vault.CoreConfig{} + core, _, token := vault.TestCoreUnsealedWithConfig(t, conf) + ln, addr := TestServer(t, core) + TestServerAuth(t, addr, token) + + // Default: Only authenticated access + resp := testHttpGet(t, "", addr+"/v1/sys/pprof/cmdline") + testResponseStatus(t, resp, 400) + resp = testHttpGet(t, token, addr+"/v1/sys/pprof/cmdline") + testResponseStatus(t, resp, 200) + + // Close listener + ln.Close() + + // Setup new custom listener with unauthenticated metrics access + ln, addr = TestListener(t) + props := &vault.HandlerProperties{ + Core: core, + ListenerConfig: &configutil.Listener{ + Profiling: configutil.ListenerProfiling{ + UnauthenticatedPProfAccess: true, + }, + }, + } + TestServerWithListenerAndProperties(t, ln, addr, core, props) + defer ln.Close() + TestServerAuth(t, addr, token) + + // Test without token + resp = testHttpGet(t, "", addr+"/v1/sys/pprof/cmdline") + testResponseStatus(t, resp, 200) + + // Should also work with token + resp = testHttpGet(t, token, addr+"/v1/sys/pprof/cmdline") + testResponseStatus(t, resp, 200) +} diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index bd7c905a2..ed9c90e32 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -21,6 +21,11 @@ type ListenerTelemetry struct { UnauthenticatedMetricsAccessRaw interface{} `hcl:"unauthenticated_metrics_access"` } +type ListenerProfiling struct { + UnauthenticatedPProfAccess bool `hcl:"-"` + UnauthenticatedPProfAccessRaw interface{} `hcl:"unauthenticated_pprof_access"` +} + // Listener is the listener configuration for the server. type Listener struct { RawConfig map[string]interface{} @@ -81,6 +86,7 @@ type Listener struct { SocketGroup string `hcl:"socket_group"` Telemetry ListenerTelemetry `hcl:"telemetry"` + Profiling ListenerProfiling `hcl:"profiling"` // RandomPort is used only for some testing purposes RandomPort bool `hcl:"-"` @@ -315,6 +321,17 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { } } + // Profiling + { + if l.Profiling.UnauthenticatedPProfAccessRaw != nil { + if l.Profiling.UnauthenticatedPProfAccess, err = parseutil.ParseBool(l.Profiling.UnauthenticatedPProfAccessRaw); err != nil { + return multierror.Prefix(fmt.Errorf("invalid value for profiling.unauthenticated_pprof_access: %w", err), fmt.Sprintf("listeners.%d", i)) + } + + l.Profiling.UnauthenticatedPProfAccessRaw = nil + } + } + // CORS { if l.CorsEnabledRaw != nil { diff --git a/vault/external_tests/pprof/pprof_test.go b/vault/external_tests/pprof/pprof_test.go index 5da8c2bb3..32df691cc 100644 --- a/vault/external_tests/pprof/pprof_test.go +++ b/vault/external_tests/pprof/pprof_test.go @@ -1,6 +1,7 @@ package pprof import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -9,8 +10,11 @@ import ( "testing" "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/vault/api" vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/vault" + "github.com/stretchr/testify/require" "golang.org/x/net/http2" ) @@ -168,6 +172,45 @@ func TestSysPprof_MaxRequestDuration(t *testing.T) { t.Fatalf("expected error response, got: %v", httpResp) } if len(errs) == 0 || !strings.Contains(errs[0].(string), "exceeds max request duration") { - t.Fatalf("unexptected error returned: %v", errs) + t.Fatalf("unexpected error returned: %v", errs) } } + +func TestSysPprof_Standby(t *testing.T) { + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + DisablePerformanceStandby: true, + }, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + DefaultHandlerProperties: vault.HandlerProperties{ + ListenerConfig: &configutil.Listener{ + Profiling: configutil.ListenerProfiling{ + UnauthenticatedPProfAccess: true, + }, + }, + }, + }) + cluster.Start() + defer cluster.Cleanup() + + pprof := func(client *api.Client) (string, error) { + req := client.NewRequest("GET", "/v1/sys/pprof/cmdline") + resp, err := client.RawRequestWithContext(context.Background(), req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + data, err := ioutil.ReadAll(resp.Body) + return string(data), err + } + + cmdline, err := pprof(cluster.Cores[0].Client) + require.Nil(t, err) + require.NotEmpty(t, cmdline) + t.Log(cmdline) + + cmdline, err = pprof(cluster.Cores[1].Client) + require.Nil(t, err) + require.NotEmpty(t, cmdline) + t.Log(cmdline) +} diff --git a/vault/logical_system_pprof.go b/vault/logical_system_pprof.go index 03db5495b..ce7fc4b27 100644 --- a/vault/logical_system_pprof.go +++ b/vault/logical_system_pprof.go @@ -59,6 +59,50 @@ render pages.`, }, }, }, + { + Pattern: "pprof/allocs", + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handlePprofAllocs, + Summary: "Returns a sampling of all past memory allocations.", + Description: "Returns a sampling of all past memory allocations.", + }, + }, + }, + { + Pattern: "pprof/threadcreate", + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handlePprofThreadcreate, + Summary: "Returns stack traces that led to the creation of new OS threads", + Description: "Returns stack traces that led to the creation of new OS threads", + }, + }, + }, + { + Pattern: "pprof/block", + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handlePprofBlock, + Summary: "Returns stack traces that led to blocking on synchronization primitives", + Description: "Returns stack traces that led to blocking on synchronization primitives", + }, + }, + }, + { + Pattern: "pprof/mutex", + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handlePprofMutex, + Summary: "Returns stack traces of holders of contended mutexes", + Description: "Returns stack traces of holders of contended mutexes", + }, + }, + }, { Pattern: "pprof/profile", @@ -146,6 +190,42 @@ func (b *SystemBackend) handlePprofHeap(ctx context.Context, req *logical.Reques return nil, nil } +func (b *SystemBackend) handlePprofAllocs(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + if err := checkRequestHandlerParams(req); err != nil { + return nil, err + } + + pprof.Handler("allocs").ServeHTTP(req.ResponseWriter, req.HTTPRequest) + return nil, nil +} + +func (b *SystemBackend) handlePprofThreadcreate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + if err := checkRequestHandlerParams(req); err != nil { + return nil, err + } + + pprof.Handler("threadcreate").ServeHTTP(req.ResponseWriter, req.HTTPRequest) + return nil, nil +} + +func (b *SystemBackend) handlePprofBlock(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + if err := checkRequestHandlerParams(req); err != nil { + return nil, err + } + + pprof.Handler("block").ServeHTTP(req.ResponseWriter, req.HTTPRequest) + return nil, nil +} + +func (b *SystemBackend) handlePprofMutex(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + if err := checkRequestHandlerParams(req); err != nil { + return nil, err + } + + pprof.Handler("mutex").ServeHTTP(req.ResponseWriter, req.HTTPRequest) + return nil, nil +} + func (b *SystemBackend) handlePprofProfile(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { if err := checkRequestHandlerParams(req); err != nil { return nil, err