diff --git a/command/operator_diagnose.go b/command/operator_diagnose.go index 8b8511c4d..a395c7706 100644 --- a/command/operator_diagnose.go +++ b/command/operator_diagnose.go @@ -133,6 +133,7 @@ func (c *OperatorDiagnoseCommand) RunWithParsedFlags() int { c.UI.Output(version.GetVersion().FullVersionNumber(true)) ctx := diagnose.Context(context.Background(), c.diagnose) err := c.offlineDiagnostics(ctx) + c.diagnose.SetSkipList(c.flagSkips) if err != nil { return 1 diff --git a/vault/diagnose/helpers.go b/vault/diagnose/helpers.go index 599bda309..ad41b421c 100644 --- a/vault/diagnose/helpers.go +++ b/vault/diagnose/helpers.go @@ -13,6 +13,7 @@ import ( const ( warningEventName = "warning" + skippedEventName = "skipped" actionKey = "actionKey" spotCheckOkEventName = "spot-check-ok" spotCheckWarnEventName = "spot-check-warn" @@ -25,10 +26,13 @@ const ( var diagnoseSession = struct{}{} var noopTracer = trace.NewNoopTracerProvider().Tracer("vault-diagnose") +type testFunction func(context.Context) error + type Session struct { tc *TelemetryCollector tracer trace.Tracer tp *sdktrace.TracerProvider + skip map[string]bool } // New initializes a Diagnose tracing session. In particular this wires a TelemetryCollector, which @@ -47,15 +51,39 @@ func New() *Session { tp: tp, tc: tc, tracer: tracer, + skip: make(map[string]bool), } return sess } +func (s *Session) SetSkipList(ls []string) { + for _, e := range ls { + s.skip[e] = true + } +} + +// IsSkipped returns true if skipName is present in the skip list. Can be used in combination with Skip to mark a +// span skipped and conditionally skip some logic. +func (s *Session) IsSkipped(skipName string) bool { + return s.skip[skipName] +} + // Context returns a new context with a defined diagnose session func Context(ctx context.Context, sess *Session) context.Context { return context.WithValue(ctx, diagnoseSession, sess) } +// CurrentSession retrieves the active diagnose session from the context, or nil if none. +func CurrentSession(ctx context.Context) *Session { + sessionCtxVal := ctx.Value(diagnoseSession) + if sessionCtxVal != nil { + + return sessionCtxVal.(*Session) + + } + return nil +} + // Finalize ends the Diagnose session, returning the root of the result tree. This will be empty until // the outermost span ends. func (s *Session) Finalize(ctx context.Context) *Result { @@ -65,10 +93,8 @@ func (s *Session) Finalize(ctx context.Context) *Result { // StartSpan starts a "diagnose" span, which is really just an OpenTelemetry Tracing span. func StartSpan(ctx context.Context, spanName string, options ...trace.SpanOption) (context.Context, trace.Span) { - sessionCtxVal := ctx.Value(diagnoseSession) - if sessionCtxVal != nil { - - session := sessionCtxVal.(*Session) + session := CurrentSession(ctx) + if session != nil { return session.tracer.Start(ctx, spanName, options...) } else { return noopTracer.Start(ctx, spanName, options...) @@ -88,6 +114,12 @@ func Error(ctx context.Context, err error, options ...trace.EventOption) error { return err } +// Skipped marks the current span skipped +func Skipped(ctx context.Context) { + span := trace.SpanFromContext(ctx) + span.AddEvent(skippedEventName) +} + // Warn records a warning on the current span func Warn(ctx context.Context, msg string) { span := trace.SpanFromContext(ctx) @@ -139,7 +171,7 @@ func SpotCheck(ctx context.Context, checkName string, f func() error) error { // Test creates a new named span, and executes the provided function within it. If the function returns an error, // the span is considered to have failed. -func Test(ctx context.Context, spanName string, function func(context.Context) error, options ...trace.SpanOption) error { +func Test(ctx context.Context, spanName string, function testFunction, options ...trace.SpanOption) error { ctx, span := StartSpan(ctx, spanName, options...) defer span.End() @@ -154,7 +186,7 @@ func Test(ctx context.Context, spanName string, function func(context.Context) e // complete within the timeout, e.g. // // diagnose.Test(ctx, "my-span", diagnose.WithTimeout(5 * time.Second, myTestFunc)) -func WithTimeout(d time.Duration, f func(context.Context) error) func(ctx context.Context) error { +func WithTimeout(d time.Duration, f testFunction) testFunction { return func(ctx context.Context) error { rch := make(chan error) t := time.NewTimer(d) @@ -168,3 +200,19 @@ func WithTimeout(d time.Duration, f func(context.Context) error) func(ctx contex } } } + +// Skippable wraps a Test function with logic that will not run the test if the skipName +// was in the session's skip list +func Skippable(skipName string, f testFunction) testFunction { + return func(ctx context.Context) error { + session := CurrentSession(ctx) + if session != nil { + if !session.IsSkipped(skipName) { + return f(ctx) + } else { + Skipped(ctx) + } + } + return nil + } +} diff --git a/vault/diagnose/helpers_test.go b/vault/diagnose/helpers_test.go index ffa45e6dd..ebe92f597 100644 --- a/vault/diagnose/helpers_test.go +++ b/vault/diagnose/helpers_test.go @@ -31,9 +31,14 @@ func TestDiagnoseOtelResults(t *testing.T) { Status: ErrorStatus, Message: "no scones", }, + { + Name: "dispose-grounds", + Status: SkippedStatus, + }, }, } sess := New() + sess.SetSkipList([]string{"dispose-grounds"}) ctx := Context(context.Background(), sess) func() { @@ -70,6 +75,7 @@ func makeCoffee(ctx context.Context) error { SpotCheck(ctx, "pick-scone", pickScone) + Test(ctx, "dispose-grounds", Skippable("dispose-grounds", disposeGrounds)) return nil } @@ -89,3 +95,8 @@ func brewCoffee(ctx context.Context) error { func pickScone() error { return errors.New("no scones") } + +func disposeGrounds(_ context.Context) error { + //Done! + return nil +} diff --git a/vault/diagnose/output.go b/vault/diagnose/output.go index 3b281d4e8..2fdad10ff 100644 --- a/vault/diagnose/output.go +++ b/vault/diagnose/output.go @@ -17,12 +17,14 @@ import ( const ( status_unknown = "[ ] " status_ok = "\u001b[32m[ ok ]\u001b[0m " - status_failed = "\u001b[31m[failed]\u001b[0m " + status_failed = "\u001b[31m[ fail ]\u001b[0m " status_warn = "\u001b[33m[ warn ]\u001b[0m " + status_skipped = "\u001b[90m[ skip ]\u001b[0m " same_line = "\u001b[F" ErrorStatus = "error" WarningStatus = "warn" OkStatus = "ok" + SkippedStatus = "skipped" ) var errUnimplemented = errors.New("unimplemented") @@ -133,6 +135,8 @@ func (t *TelemetryCollector) getOrBuildResult(id trace.SpanID) *Result { r.Warnings = append(r.Warnings, a.Value.AsString()) } } + case skippedEventName: + r.Status = SkippedStatus case ErrorStatus: var message string var action string @@ -218,11 +222,13 @@ func (t *TelemetryCollector) getOrBuildResult(id trace.SpanID) *Result { case codes.Unset: if len(r.Warnings) > 0 { r.Status = WarningStatus - } else { + } else if r.Status != SkippedStatus { r.Status = OkStatus } case codes.Ok: - r.Status = OkStatus + if r.Status != SkippedStatus { + r.Status = OkStatus + } case codes.Error: r.Status = ErrorStatus } @@ -251,6 +257,8 @@ func (r *Result) write(sb *strings.Builder, depth int) { sb.WriteString(status_warn) case ErrorStatus: sb.WriteString(status_failed) + case SkippedStatus: + sb.WriteString(status_skipped) } sb.WriteString(r.Name)