open-vault/api/renewer_test.go

236 lines
5.0 KiB
Go
Raw Normal View History

2017-06-20 04:34:11 +00:00
package api
import (
"errors"
"fmt"
2017-06-20 04:34:11 +00:00
"testing"
"time"
"github.com/go-test/deep"
2017-06-20 04:34:11 +00:00
)
func TestRenewer_NewRenewer(t *testing.T) {
t.Parallel()
client, err := NewClient(DefaultConfig())
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
i *RenewerInput
e *Renewer
err bool
}{
{
name: "nil",
i: nil,
e: nil,
err: true,
2017-06-20 04:34:11 +00:00
},
{
name: "missing_secret",
i: &RenewerInput{
2017-06-20 04:34:11 +00:00
Secret: nil,
},
e: nil,
err: true,
2017-06-20 04:34:11 +00:00
},
{
name: "default_grace",
i: &RenewerInput{
2017-06-20 04:34:11 +00:00
Secret: &Secret{},
},
e: &Renewer{
2017-06-20 04:34:11 +00:00
secret: &Secret{},
},
err: false,
2017-06-20 04:34:11 +00:00
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
2017-06-20 04:34:11 +00:00
v, err := client.NewRenewer(tc.i)
if (err != nil) != tc.err {
t.Fatal(err)
}
if v == nil {
return
}
// Zero-out channels because reflect
v.client = nil
2017-06-29 00:38:03 +00:00
v.random = nil
2017-06-20 04:34:11 +00:00
v.doneCh = nil
v.renewCh = nil
2017-06-20 04:34:11 +00:00
v.stopCh = nil
if diff := deep.Equal(tc.e, v); diff != nil {
t.Error(diff)
2017-06-20 04:34:11 +00:00
}
})
}
}
// TestLifetimeWatcher drives LifetimeWatcher.doRenewWithOptions with stubbed
// renew functions and checks, per case, whether a renewal is published on
// RenewCh and which error (if any) the watcher terminates with.
func TestLifetimeWatcher(t *testing.T) {
	t.Parallel()

	client, err := NewClient(DefaultConfig())
	if err != nil {
		t.Fatal(err)
	}

	// Note that doRenewWithOptions starts its loop with an initial renewal.
	// This has a big impact on the particulars of the following cases.

	// renewedSecret is the value every successful stub returns; renewals are
	// matched against it by pointer identity below.
	renewedSecret := &Secret{}

	// Failure counters for the stubs that fail a fixed number of times before
	// succeeding.
	var caseOneErrorCount int
	var caseManyErrorsCount int

	cases := []struct {
		maxTestTime          time.Duration // timeout applied to each wait in the receive loop
		name                 string
		leaseDurationSeconds int
		incrementSeconds     int
		renew                renewFunc
		expectError          error
		expectRenewal        bool
	}{
		{
			maxTestTime:          time.Second,
			name:                 "no_error",
			leaseDurationSeconds: 60,
			incrementSeconds:     60,
			renew: func(_ string, _ int) (*Secret, error) {
				return renewedSecret, nil
			},
			expectError:   nil,
			expectRenewal: true,
		},
		{
			maxTestTime:          time.Second,
			name:                 "short_increment_duration",
			leaseDurationSeconds: 60,
			incrementSeconds:     10,
			renew: func(_ string, _ int) (*Secret, error) {
				return renewedSecret, nil
			},
			expectError:   nil,
			expectRenewal: true,
		},
		{
			maxTestTime:          5 * time.Second,
			name:                 "one_error",
			leaseDurationSeconds: 15,
			incrementSeconds:     15,
			renew: func(_ string, _ int) (*Secret, error) {
				// Fail the first call only.
				if caseOneErrorCount == 0 {
					caseOneErrorCount++
					return nil, fmt.Errorf("renew failure")
				}
				return renewedSecret, nil
			},
			expectError:   nil,
			expectRenewal: true,
		},
		{
			maxTestTime:          15 * time.Second,
			name:                 "many_errors",
			leaseDurationSeconds: 15,
			incrementSeconds:     15,
			renew: func(_ string, _ int) (*Secret, error) {
				// Fail the first three calls, then succeed.
				if caseManyErrorsCount == 3 {
					return renewedSecret, nil
				}
				caseManyErrorsCount++
				return nil, fmt.Errorf("renew failure")
			},
			expectError:   nil,
			expectRenewal: true,
		},
		{
			maxTestTime:          15 * time.Second,
			name:                 "only_errors",
			leaseDurationSeconds: 15,
			incrementSeconds:     15,
			renew: func(_ string, _ int) (*Secret, error) {
				return nil, fmt.Errorf("renew failure")
			},
			expectError:   nil,
			expectRenewal: false,
		},
		{
			maxTestTime:          15 * time.Second,
			name:                 "negative_lease_duration",
			leaseDurationSeconds: -15,
			incrementSeconds:     15,
			renew: func(_ string, _ int) (*Secret, error) {
				return renewedSecret, nil
			},
			expectError:   nil,
			expectRenewal: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := client.NewLifetimeWatcher(&LifetimeWatcherInput{
				Secret: &Secret{
					LeaseDuration: tc.leaseDurationSeconds,
				},
				Increment: tc.incrementSeconds,
			})
			if err != nil {
				t.Fatal(err)
			}

			// Buffered so the goroutine can always deliver its result and exit,
			// even if the test has already failed and stopped receiving.
			doneCh := make(chan error, 1)
			go func() {
				doneCh <- v.doRenewWithOptions(false, false,
					tc.leaseDurationSeconds, "myleaseID", tc.renew, time.Second)
			}()
			defer v.Stop()

			receivedRenewal := false
			receivedDone := false

		ChannelLoop:
			for {
				select {
				// time.After is re-armed on every iteration, so maxTestTime
				// bounds each individual wait rather than the whole sub-test.
				case <-time.After(tc.maxTestTime):
					t.Fatalf("timed out after %s waiting for the watcher (receivedRenewal=%t, receivedDone=%t)",
						tc.maxTestTime, receivedRenewal, receivedDone)

				case r := <-v.RenewCh():
					if !tc.expectRenewal {
						t.Fatal("expected no renewals")
					}
					if r.Secret != renewedSecret {
						t.Fatalf("expected secret %v, got %v", renewedSecret, r.Secret)
					}
					receivedRenewal = true
					// Keep looping until the watcher has also reported completion.
					if receivedDone {
						break ChannelLoop
					}

				case err := <-doneCh:
					receivedDone = true
					switch {
					case tc.expectError != nil && !errors.Is(err, tc.expectError):
						t.Fatalf("expected error %q, got: %v", tc.expectError, err)
					case tc.expectError == nil && err != nil:
						t.Fatalf("expected no error, got: %v", err)
					}
					// We might have received the stop before the renew call on
					// the channel; in that case keep waiting for the renewal.
					if !tc.expectRenewal || receivedRenewal {
						break ChannelLoop
					}
				}
			}

			// Guard only: the loop above exits solely once an expected renewal
			// has been observed, so this should not trigger.
			if tc.expectRenewal && !receivedRenewal {
				t.Fatal("expected at least one renewal, got none.")
			}
		})
	}
}