// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package api import ( "errors" "fmt" "math/rand" "reflect" "testing" "testing/quick" "time" "github.com/go-test/deep" ) func TestRenewer_NewRenewer(t *testing.T) { t.Parallel() client, err := NewClient(DefaultConfig()) if err != nil { t.Fatal(err) } cases := []struct { name string i *RenewerInput e *Renewer err bool }{ { name: "nil", i: nil, e: nil, err: true, }, { name: "missing_secret", i: &RenewerInput{ Secret: nil, }, e: nil, err: true, }, { name: "default_grace", i: &RenewerInput{ Secret: &Secret{}, }, e: &Renewer{ secret: &Secret{}, }, err: false, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { v, err := client.NewRenewer(tc.i) if (err != nil) != tc.err { t.Fatal(err) } if v == nil { return } // Zero-out channels because reflect v.client = nil v.random = nil v.doneCh = nil v.renewCh = nil v.stopCh = nil if diff := deep.Equal(tc.e, v); diff != nil { t.Error(diff) } }) } } func TestLifetimeWatcher(t *testing.T) { t.Parallel() client, err := NewClient(DefaultConfig()) if err != nil { t.Fatal(err) } // Note that doRenewWithOptions starts its loop with an initial renewal. // This has a big impact on the particulars of the following cases. renewedSecret := &Secret{} var caseOneErrorCount int var caseManyErrorsCount int cases := []struct { maxTestTime time.Duration name string leaseDurationSeconds int incrementSeconds int renew renewFunc expectError error expectRenewal bool }{ { maxTestTime: time.Second, name: "no_error", leaseDurationSeconds: 60, incrementSeconds: 60, renew: func(_ string, _ int) (*Secret, error) { return renewedSecret, nil }, expectError: nil, expectRenewal: true, }, { maxTestTime: time.Second, name: "short_increment_duration", leaseDurationSeconds: 60, incrementSeconds: 10, renew: func(_ string, _ int) (*Secret, error) { return renewedSecret, nil }, expectError: nil, expectRenewal: true, }, { maxTestTime: 5 * time.Second, name: "one_error", leaseDurationSeconds: 15, incrementSeconds: 15, renew: func(_ string, _ int) (*Secret, error) { if caseOneErrorCount == 0 { caseOneErrorCount++ return nil, fmt.Errorf("renew failure") } return renewedSecret, nil }, expectError: nil, expectRenewal: true, }, { maxTestTime: 15 * time.Second, name: "many_errors", leaseDurationSeconds: 15, incrementSeconds: 15, renew: func(_ string, _ int) (*Secret, error) { if caseManyErrorsCount == 3 { return renewedSecret, nil } caseManyErrorsCount++ return nil, fmt.Errorf("renew failure") }, expectError: nil, expectRenewal: true, }, { maxTestTime: 15 * time.Second, name: "only_errors", leaseDurationSeconds: 15, incrementSeconds: 15, renew: func(_ string, _ int) (*Secret, error) { return nil, fmt.Errorf("renew failure") }, expectError: nil, expectRenewal: false, }, { maxTestTime: 15 * time.Second, name: "negative_lease_duration", leaseDurationSeconds: -15, incrementSeconds: 15, renew: func(_ string, _ int) (*Secret, error) { return renewedSecret, nil }, expectError: nil, expectRenewal: true, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { v, err := client.NewLifetimeWatcher(&LifetimeWatcherInput{ Secret: &Secret{ LeaseDuration: tc.leaseDurationSeconds, }, Increment: tc.incrementSeconds, }) if err != nil { t.Fatal(err) } doneCh := make(chan error, 1) go func() { doneCh <- v.doRenewWithOptions(false, false, tc.leaseDurationSeconds, "myleaseID", tc.renew, time.Second) }() defer v.Stop() receivedRenewal := false receivedDone := false ChannelLoop: for { select { case <-time.After(tc.maxTestTime): t.Fatalf("renewal didn't happen") case r := <-v.RenewCh(): if !tc.expectRenewal { t.Fatal("expected no renewals") } if r.Secret != renewedSecret { t.Fatalf("expected secret %v, got %v", renewedSecret, r.Secret) } receivedRenewal = true if !receivedDone { continue ChannelLoop } break ChannelLoop case err := <-doneCh: receivedDone = true if tc.expectError != nil && !errors.Is(err, tc.expectError) { t.Fatalf("expected error %q, got: %v", tc.expectError, err) } if tc.expectError == nil && err != nil { t.Fatalf("expected no error, got: %v", err) } if tc.expectRenewal && !receivedRenewal { // We might have received the stop before the renew call on the channel. continue ChannelLoop } break ChannelLoop } } if tc.expectRenewal && !receivedRenewal { t.Fatalf("expected at least one renewal, got none.") } }) } } // TestCalcSleepPeriod uses property based testing to evaluate the calculateSleepDuration // function of LifeTimeWatchers, but also incidentally tests "calculateGrace". // This is on account of "calculateSleepDuration" performing the "calculateGrace" // function in particular instances. // Both of these functions support the vital functionality of the LifeTimeWatcher // and therefore should be tested rigorously. func TestCalcSleepPeriod(t *testing.T) { c := quick.Config{ MaxCount: 10000, Values: func(values []reflect.Value, r *rand.Rand) { leaseDuration := r.Int63() priorDuration := r.Int63n(leaseDuration) remainingLeaseDuration := r.Int63n(priorDuration) increment := r.Int63n(remainingLeaseDuration) values[0] = reflect.ValueOf(r) values[1] = reflect.ValueOf(time.Duration(leaseDuration)) values[2] = reflect.ValueOf(time.Duration(priorDuration)) values[3] = reflect.ValueOf(time.Duration(remainingLeaseDuration)) values[4] = reflect.ValueOf(time.Duration(increment)) }, } // tests that "calculateSleepDuration" will always return a value less than // the remaining lease duration given a random leaseDuration, priorDuration, remainingLeaseDuration, and increment. // Inputs are generated so that: // leaseDuration > priorDuration > remainingLeaseDuration // and remainingLeaseDuration > increment if err := quick.Check(func(r *rand.Rand, leaseDuration, priorDuration, remainingLeaseDuration, increment time.Duration) bool { lw := LifetimeWatcher{ grace: 0, increment: int(increment.Seconds()), random: r, } lw.calculateGrace(remainingLeaseDuration, increment) // ensure that we sleep for less than the remaining lease. return lw.calculateSleepDuration(remainingLeaseDuration, priorDuration) < remainingLeaseDuration }, &c); err != nil { t.Error(err) } }