Add infrastructure and helpers for skipping diagnose checks (#11593)

* Add infrastructure for skipping tests

* Add infrastructure for skipping tests

* Set it

* Update vault/diagnose/helpers.go

Co-authored-by: swayne275 <swayne275@gmail.com>

* Implement type alias for test functions

Co-authored-by: swayne275 <swayne275@gmail.com>
This commit is contained in:
Scott Miller 2021-05-12 12:54:40 -05:00 committed by GitHub
parent 98f239498a
commit 9dbf1a7dba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 9 deletions

View File

@ -133,6 +133,7 @@ func (c *OperatorDiagnoseCommand) RunWithParsedFlags() int {
c.UI.Output(version.GetVersion().FullVersionNumber(true))
ctx := diagnose.Context(context.Background(), c.diagnose)
err := c.offlineDiagnostics(ctx)
c.diagnose.SetSkipList(c.flagSkips)
if err != nil {
return 1

View File

@ -13,6 +13,7 @@ import (
const (
warningEventName = "warning"
skippedEventName = "skipped"
actionKey = "actionKey"
spotCheckOkEventName = "spot-check-ok"
spotCheckWarnEventName = "spot-check-warn"
@ -25,10 +26,13 @@ const (
var diagnoseSession = struct{}{}
var noopTracer = trace.NewNoopTracerProvider().Tracer("vault-diagnose")
type testFunction func(context.Context) error
type Session struct {
tc *TelemetryCollector
tracer trace.Tracer
tp *sdktrace.TracerProvider
skip map[string]bool
}
// New initializes a Diagnose tracing session. In particular this wires a TelemetryCollector, which
@ -47,15 +51,39 @@ func New() *Session {
tp: tp,
tc: tc,
tracer: tracer,
skip: make(map[string]bool),
}
return sess
}
func (s *Session) SetSkipList(ls []string) {
for _, e := range ls {
s.skip[e] = true
}
}
// IsSkipped returns true if skipName is present in the skip list. Can be used in combination with Skip to mark a
// span skipped and conditionally skip some logic.
func (s *Session) IsSkipped(skipName string) bool {
return s.skip[skipName]
}
// Context returns a new context with a defined diagnose session
func Context(ctx context.Context, sess *Session) context.Context {
return context.WithValue(ctx, diagnoseSession, sess)
}
// CurrentSession retrieves the active diagnose session from the context, or nil if none.
func CurrentSession(ctx context.Context) *Session {
sessionCtxVal := ctx.Value(diagnoseSession)
if sessionCtxVal != nil {
return sessionCtxVal.(*Session)
}
return nil
}
// Finalize ends the Diagnose session, returning the root of the result tree. This will be empty until
// the outermost span ends.
func (s *Session) Finalize(ctx context.Context) *Result {
@ -65,10 +93,8 @@ func (s *Session) Finalize(ctx context.Context) *Result {
// StartSpan starts a "diagnose" span, which is really just an OpenTelemetry Tracing span.
func StartSpan(ctx context.Context, spanName string, options ...trace.SpanOption) (context.Context, trace.Span) {
sessionCtxVal := ctx.Value(diagnoseSession)
if sessionCtxVal != nil {
session := sessionCtxVal.(*Session)
session := CurrentSession(ctx)
if session != nil {
return session.tracer.Start(ctx, spanName, options...)
} else {
return noopTracer.Start(ctx, spanName, options...)
@ -88,6 +114,12 @@ func Error(ctx context.Context, err error, options ...trace.EventOption) error {
return err
}
// Skipped marks the current span skipped
func Skipped(ctx context.Context) {
span := trace.SpanFromContext(ctx)
span.AddEvent(skippedEventName)
}
// Warn records a warning on the current span
func Warn(ctx context.Context, msg string) {
span := trace.SpanFromContext(ctx)
@ -139,7 +171,7 @@ func SpotCheck(ctx context.Context, checkName string, f func() error) error {
// Test creates a new named span, and executes the provided function within it. If the function returns an error,
// the span is considered to have failed.
func Test(ctx context.Context, spanName string, function func(context.Context) error, options ...trace.SpanOption) error {
func Test(ctx context.Context, spanName string, function testFunction, options ...trace.SpanOption) error {
ctx, span := StartSpan(ctx, spanName, options...)
defer span.End()
@ -154,7 +186,7 @@ func Test(ctx context.Context, spanName string, function func(context.Context) e
// complete within the timeout, e.g.
//
// diagnose.Test(ctx, "my-span", diagnose.WithTimeout(5 * time.Second, myTestFunc))
func WithTimeout(d time.Duration, f func(context.Context) error) func(ctx context.Context) error {
func WithTimeout(d time.Duration, f testFunction) testFunction {
return func(ctx context.Context) error {
rch := make(chan error)
t := time.NewTimer(d)
@ -168,3 +200,19 @@ func WithTimeout(d time.Duration, f func(context.Context) error) func(ctx contex
}
}
}
// Skippable wraps a Test function with logic that will not run the test if the skipName
// was in the session's skip list
func Skippable(skipName string, f testFunction) testFunction {
return func(ctx context.Context) error {
session := CurrentSession(ctx)
if session != nil {
if !session.IsSkipped(skipName) {
return f(ctx)
} else {
Skipped(ctx)
}
}
return nil
}
}

View File

@ -31,9 +31,14 @@ func TestDiagnoseOtelResults(t *testing.T) {
Status: ErrorStatus,
Message: "no scones",
},
{
Name: "dispose-grounds",
Status: SkippedStatus,
},
},
}
sess := New()
sess.SetSkipList([]string{"dispose-grounds"})
ctx := Context(context.Background(), sess)
func() {
@ -70,6 +75,7 @@ func makeCoffee(ctx context.Context) error {
SpotCheck(ctx, "pick-scone", pickScone)
Test(ctx, "dispose-grounds", Skippable("dispose-grounds", disposeGrounds))
return nil
}
@ -89,3 +95,8 @@ func brewCoffee(ctx context.Context) error {
func pickScone() error {
return errors.New("no scones")
}
func disposeGrounds(_ context.Context) error {
//Done!
return nil
}

View File

@ -17,12 +17,14 @@ import (
const (
status_unknown = "[ ] "
status_ok = "\u001b[32m[ ok ]\u001b[0m "
status_failed = "\u001b[31m[failed]\u001b[0m "
status_failed = "\u001b[31m[ fail ]\u001b[0m "
status_warn = "\u001b[33m[ warn ]\u001b[0m "
status_skipped = "\u001b[90m[ skip ]\u001b[0m "
same_line = "\u001b[F"
ErrorStatus = "error"
WarningStatus = "warn"
OkStatus = "ok"
SkippedStatus = "skipped"
)
var errUnimplemented = errors.New("unimplemented")
@ -133,6 +135,8 @@ func (t *TelemetryCollector) getOrBuildResult(id trace.SpanID) *Result {
r.Warnings = append(r.Warnings, a.Value.AsString())
}
}
case skippedEventName:
r.Status = SkippedStatus
case ErrorStatus:
var message string
var action string
@ -218,11 +222,13 @@ func (t *TelemetryCollector) getOrBuildResult(id trace.SpanID) *Result {
case codes.Unset:
if len(r.Warnings) > 0 {
r.Status = WarningStatus
} else {
} else if r.Status != SkippedStatus {
r.Status = OkStatus
}
case codes.Ok:
r.Status = OkStatus
if r.Status != SkippedStatus {
r.Status = OkStatus
}
case codes.Error:
r.Status = ErrorStatus
}
@ -251,6 +257,8 @@ func (r *Result) write(sb *strings.Builder, depth int) {
sb.WriteString(status_warn)
case ErrorStatus:
sb.WriteString(status_failed)
case SkippedStatus:
sb.WriteString(status_skipped)
}
sb.WriteString(r.Name)