deb.open-vault/api/renewer_test.go

286 lines
7.0 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package api
import (
"errors"
"fmt"
"math/rand"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/go-test/deep"
)
// TestRenewer_NewRenewer checks the result of Client.NewRenewer for a nil
// input, an input without a secret, and a minimal valid input.
func TestRenewer_NewRenewer(t *testing.T) {
	t.Parallel()

	client, err := NewClient(DefaultConfig())
	if err != nil {
		t.Fatal(err)
	}

	cases := []struct {
		name string
		i    *RenewerInput // input handed to NewRenewer
		e    *Renewer      // expected result after per-instance fields are cleared
		err  bool          // whether NewRenewer is expected to return an error
	}{
		{
			name: "nil",
			i:    nil,
			e:    nil,
			err:  true,
		},
		{
			name: "missing_secret",
			i: &RenewerInput{
				Secret: nil,
			},
			e:   nil,
			err: true,
		},
		{
			name: "default_grace",
			i: &RenewerInput{
				Secret: &Secret{},
			},
			e: &Renewer{
				secret: &Secret{},
			},
			err: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := client.NewRenewer(tc.i)
			if (err != nil) != tc.err {
				// Report both sides: a bare t.Fatal(err) prints only "<nil>"
				// when an error was expected but none was returned.
				t.Fatalf("NewRenewer() error = %v, want error: %t", err, tc.err)
			}

			if v == nil {
				return
			}

			// Clear the per-instance fields (client, random source, channels)
			// that the expected value leaves unset, so the comparison below
			// only covers the remaining fields.
			v.client = nil
			v.random = nil
			v.doneCh = nil
			v.renewCh = nil
			v.stopCh = nil

			if diff := deep.Equal(tc.e, v); diff != nil {
				t.Error(diff)
			}
		})
	}
}
// TestLifetimeWatcher runs doRenewWithOptions in a goroutine against stubbed
// renew functions and checks, for each case, whether a renewal is published on
// RenewCh and which error the renewal loop returns, all within maxTestTime.
func TestLifetimeWatcher(t *testing.T) {
	t.Parallel()

	client, err := NewClient(DefaultConfig())
	if err != nil {
		t.Fatal(err)
	}

	// Note that doRenewWithOptions starts its loop with an initial renewal.
	// This has a big impact on the particulars of the following cases.

	renewedSecret := &Secret{}
	var caseOneErrorCount int
	var caseManyErrorsCount int
	cases := []struct {
		maxTestTime          time.Duration // upper bound on how long the whole case may run
		name                 string
		leaseDurationSeconds int
		incrementSeconds     int
		renew                renewFunc // stub invoked by the renewal loop
		expectError          error     // error expected from doRenewWithOptions (nil for none)
		expectRenewal        bool      // whether at least one renewal must arrive on RenewCh
	}{
		{
			maxTestTime:          time.Second,
			name:                 "no_error",
			leaseDurationSeconds: 60,
			incrementSeconds:     60,
			renew: func(_ string, _ int) (*Secret, error) {
				return renewedSecret, nil
			},
			expectError:   nil,
			expectRenewal: true,
		},
		{
			maxTestTime:          time.Second,
			name:                 "short_increment_duration",
			leaseDurationSeconds: 60,
			incrementSeconds:     10,
			renew: func(_ string, _ int) (*Secret, error) {
				return renewedSecret, nil
			},
			expectError:   nil,
			expectRenewal: true,
		},
		{
			maxTestTime:          5 * time.Second,
			name:                 "one_error",
			leaseDurationSeconds: 15,
			incrementSeconds:     15,
			renew: func(_ string, _ int) (*Secret, error) {
				// Fail only the first call.
				if caseOneErrorCount == 0 {
					caseOneErrorCount++
					return nil, fmt.Errorf("renew failure")
				}
				return renewedSecret, nil
			},
			expectError:   nil,
			expectRenewal: true,
		},
		{
			maxTestTime:          15 * time.Second,
			name:                 "many_errors",
			leaseDurationSeconds: 15,
			incrementSeconds:     15,
			renew: func(_ string, _ int) (*Secret, error) {
				// Fail the first three calls, then succeed.
				if caseManyErrorsCount == 3 {
					return renewedSecret, nil
				}
				caseManyErrorsCount++
				return nil, fmt.Errorf("renew failure")
			},
			expectError:   nil,
			expectRenewal: true,
		},
		{
			maxTestTime:          15 * time.Second,
			name:                 "only_errors",
			leaseDurationSeconds: 15,
			incrementSeconds:     15,
			renew: func(_ string, _ int) (*Secret, error) {
				return nil, fmt.Errorf("renew failure")
			},
			expectError:   nil,
			expectRenewal: false,
		},
		{
			maxTestTime:          15 * time.Second,
			name:                 "negative_lease_duration",
			leaseDurationSeconds: -15,
			incrementSeconds:     15,
			renew: func(_ string, _ int) (*Secret, error) {
				return renewedSecret, nil
			},
			expectError:   nil,
			expectRenewal: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := client.NewLifetimeWatcher(&LifetimeWatcherInput{
				Secret: &Secret{
					LeaseDuration: tc.leaseDurationSeconds,
				},
				Increment: tc.incrementSeconds,
			})
			if err != nil {
				t.Fatal(err)
			}

			// Buffered so the goroutine can always deliver its result and
			// exit, even if the test has already failed and stopped reading.
			doneCh := make(chan error, 1)
			go func() {
				doneCh <- v.doRenewWithOptions(false, false,
					tc.leaseDurationSeconds, "myleaseID", tc.renew, time.Second)
			}()
			defer v.Stop()

			// A single timer bounds the whole case. Creating the timeout
			// inside the select would restart the deadline after every
			// received event, so maxTestTime would not be a real upper bound.
			timeout := time.NewTimer(tc.maxTestTime)
			defer timeout.Stop()

			receivedRenewal := false
			receivedDone := false

		ChannelLoop:
			for {
				select {
				case <-timeout.C:
					t.Fatalf("timed out after %s waiting for renewal or completion", tc.maxTestTime)

				case r := <-v.RenewCh():
					if !tc.expectRenewal {
						t.Fatal("expected no renewals")
					}
					if r.Secret != renewedSecret {
						t.Fatalf("expected secret %v, got %v", renewedSecret, r.Secret)
					}
					receivedRenewal = true
					// Keep looping until the renewal loop has also reported
					// its result.
					if receivedDone {
						break ChannelLoop
					}

				case err := <-doneCh:
					receivedDone = true
					if tc.expectError != nil && !errors.Is(err, tc.expectError) {
						t.Fatalf("expected error %q, got: %v", tc.expectError, err)
					}
					if tc.expectError == nil && err != nil {
						t.Fatalf("expected no error, got: %v", err)
					}
					// We might have received the stop before the renew call
					// on the channel, so only leave once the expected renewal
					// (if any) has been seen.
					if !tc.expectRenewal || receivedRenewal {
						break ChannelLoop
					}
				}
			}

			if tc.expectRenewal && !receivedRenewal {
				t.Fatal("expected at least one renewal, got none.")
			}
		})
	}
}
// TestCalcSleepPeriod uses property-based testing to evaluate the
// calculateSleepDuration method of LifetimeWatcher. Each run also calls
// calculateGrace first, so that method is exercised as well.
// Both of these functions support the vital functionality of the
// LifetimeWatcher and therefore should be tested rigorously.
func TestCalcSleepPeriod(t *testing.T) {
	// Largest value rand.Int63 can return; declared locally so the generator
	// needs no additional import.
	const maxInt63 = 1<<63 - 1

	c := quick.Config{
		MaxCount: 10000,
		Values: func(values []reflect.Value, r *rand.Rand) {
			// rand.Int63n panics when its argument is <= 0. Each draw is
			// offset so the next bound is always strictly positive, which
			// also enforces the strict ordering
			//   leaseDuration > priorDuration > remainingLeaseDuration > increment >= 0.
			leaseDuration := 3 + r.Int63n(maxInt63-3)               // [3, maxInt63-1]
			priorDuration := 2 + r.Int63n(leaseDuration-2)          // [2, leaseDuration-1]
			remainingLeaseDuration := 1 + r.Int63n(priorDuration-1) // [1, priorDuration-1]
			increment := r.Int63n(remainingLeaseDuration)           // [0, remainingLeaseDuration-1]

			values[0] = reflect.ValueOf(r)
			values[1] = reflect.ValueOf(time.Duration(leaseDuration))
			values[2] = reflect.ValueOf(time.Duration(priorDuration))
			values[3] = reflect.ValueOf(time.Duration(remainingLeaseDuration))
			values[4] = reflect.ValueOf(time.Duration(increment))
		},
	}

	// Tests that calculateSleepDuration always returns a value less than the
	// remaining lease duration given a random leaseDuration, priorDuration,
	// remainingLeaseDuration, and increment.
	// Inputs are generated so that:
	// leaseDuration > priorDuration > remainingLeaseDuration
	// and remainingLeaseDuration > increment
	property := func(r *rand.Rand, leaseDuration, priorDuration, remainingLeaseDuration, increment time.Duration) bool {
		lw := LifetimeWatcher{
			grace:     0,
			increment: int(increment.Seconds()),
			random:    r,
		}

		lw.calculateGrace(remainingLeaseDuration, increment)

		// Ensure that we sleep for less than the remaining lease.
		return lw.calculateSleepDuration(remainingLeaseDuration, priorDuration) < remainingLeaseDuration
	}

	if err := quick.Check(property, &c); err != nil {
		t.Error(err)
	}
}