// open-consul/agent/consul/watch/server_local_test.go
//
// 458 lines
// 12 KiB
// Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package watch
import (
"context"
"fmt"
"testing"
"time"
"github.com/hashicorp/consul/lib/retry"
"github.com/hashicorp/go-memdb"
mock "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// mockStoreProvider is a testify mock whose methods stand in for the getStore,
// query and notify callbacks passed to ServerLocalBlockingQuery and
// ServerLocalNotify in these tests.
type mockStoreProvider struct {
	mock.Mock
}
// newMockStoreProvider returns an empty mockStoreProvider and registers a test
// cleanup that asserts every expectation configured with On(...) was met.
func newMockStoreProvider(t *testing.T) *mockStoreProvider {
	t.Helper()

	p := new(mockStoreProvider)
	t.Cleanup(func() { p.AssertExpectations(t) })
	return p
}
// getStore is the mocked getStore callback. It returns the *MockStateStore
// configured via On("getStore").Return(store).
//
// An untyped nil return value (Return(nil)) yields a nil store rather than
// panicking on the type assertion; any other non-*MockStateStore value still
// panics so misconfigured expectations are caught.
func (m *mockStoreProvider) getStore() *MockStateStore {
	// m.Called must be invoked directly from this method: testify derives the
	// expected method name ("getStore") from the caller.
	ret := m.Called()

	var store *MockStateStore
	if v := ret.Get(0); v != nil {
		store = v.(*MockStateStore)
	}
	return store
}
// testResult is the result type produced by the mocked query callback and
// passed on to the mocked notify callback.
type testResult struct {
	value string
}
// query is the mocked query callback. It records the call with its watch set
// and store, then returns the (index, result, error) triple configured via
// On("query", ...).Return(...).
//
// The result may be configured either as a typed (*testResult)(nil) or as an
// untyped nil; the latter previously panicked on the type assertion. Any other
// non-*testResult value still panics so misconfigured expectations surface.
func (m *mockStoreProvider) query(ws memdb.WatchSet, store *MockStateStore) (uint64, *testResult, error) {
	// m.Called must be invoked directly from this method: testify derives the
	// expected method name ("query") from the caller.
	ret := m.Called(ws, store)

	index := ret.Get(0).(uint64)

	var result *testResult
	if v := ret.Get(1); v != nil {
		result = v.(*testResult)
	}

	// Arguments.Error returns nil when the configured value is nil.
	return index, result, ret.Error(2)
}
// notify is the mocked notify callback. It only records the call so that
// expectations such as On("notify", ctx, id, result, err) can be matched.
func (m *mockStoreProvider) notify(ctx context.Context, correlationID string, result *testResult, err error) {
	// Nothing is returned from this mock; the recorded call is all that matters.
	_ = m.Called(ctx, correlationID, result, err)
}
func TestServerLocalBlockingQuery_getStoreNotProvided(t *testing.T) {
_, _, err := ServerLocalBlockingQuery(
context.Background(),
nil,
0,
true,
func(memdb.WatchSet, *MockStateStore) (uint64, struct{}, error) {
return 0, struct{}{}, nil
},
)
require.Error(t, err)
require.Contains(t, err.Error(), "no getStore function was provided")
}
func TestServerLocalBlockingQuery_queryNotProvided(t *testing.T) {
var query func(memdb.WatchSet, *MockStateStore) (uint64, struct{}, error)
_, _, err := ServerLocalBlockingQuery(
context.Background(),
func() *MockStateStore { return nil },
0,
true,
query,
)
require.Error(t, err)
require.Contains(t, err.Error(), "no query function was provided")
}
// TestServerLocalBlockingQuery_NonBlocking checks that with a minimum index of
// 0 the result of a single query is returned as-is.
func TestServerLocalBlockingQuery_NonBlocking(t *testing.T) {
	abandonCh := make(chan struct{})
	t.Cleanup(func() { close(abandonCh) })

	store := NewMockStateStore(t)
	store.On("AbandonCh").Return(closeChan(abandonCh)).Once()

	provider := newMockStoreProvider(t)
	provider.On("getStore").Return(store).Once()
	provider.On("query", mock.Anything, store).
		Return(uint64(1), &testResult{value: "foo"}, nil).
		Once()

	gotIdx, gotResult, err := ServerLocalBlockingQuery(
		context.Background(),
		provider.getStore,
		0,
		true,
		provider.query,
	)

	require.NoError(t, err)
	require.EqualValues(t, 1, gotIdx)
	require.Equal(t, &testResult{value: "foo"}, gotResult)
}
// TestServerLocalBlockingQuery_Index0 checks that a query reporting index 0 is
// surfaced to the caller as index 1.
func TestServerLocalBlockingQuery_Index0(t *testing.T) {
	abandonCh := make(chan struct{})
	t.Cleanup(func() { close(abandonCh) })

	store := NewMockStateStore(t)
	store.On("AbandonCh").Return(closeChan(abandonCh)).Once()

	provider := newMockStoreProvider(t)
	provider.On("getStore").Return(store).Once()
	// ServerLocalBlockingQuery is expected to translate the index 0 returned
	// here into 1.
	provider.On("query", mock.Anything, store).
		Return(uint64(0), &testResult{value: "foo"}, nil).
		Once()

	gotIdx, gotResult, err := ServerLocalBlockingQuery(
		context.Background(),
		provider.getStore,
		0,
		true,
		provider.query,
	)

	require.NoError(t, err)
	require.EqualValues(t, 1, gotIdx)
	require.Equal(t, &testResult{value: "foo"}, gotResult)
}
// TestServerLocalBlockingQuery_NotFound checks that ErrorNotFound from the
// query is not reported as a failure: the caller receives the query's index,
// a nil result and no error.
func TestServerLocalBlockingQuery_NotFound(t *testing.T) {
	abandonCh := make(chan struct{})
	t.Cleanup(func() { close(abandonCh) })

	store := NewMockStateStore(t)
	store.On("AbandonCh").Return(closeChan(abandonCh)).Once()

	provider := newMockStoreProvider(t)
	provider.On("getStore").Return(store).Once()

	var missing *testResult
	provider.On("query", mock.Anything, store).
		Return(uint64(1), missing, ErrorNotFound).
		Once()

	gotIdx, gotResult, err := ServerLocalBlockingQuery(
		context.Background(),
		provider.getStore,
		0,
		true,
		provider.query,
	)

	require.NoError(t, err)
	require.EqualValues(t, 1, gotIdx)
	require.Nil(t, gotResult)
}
// TestServerLocalBlockingQuery_NotFoundBlocks checks that not-found results do
// not end a blocking query; it keeps re-running the query until real data is
// available.
func TestServerLocalBlockingQuery_NotFoundBlocks(t *testing.T) {
	abandonCh := make(chan struct{})
	t.Cleanup(func() { close(abandonCh) })

	store := NewMockStateStore(t)
	store.On("AbandonCh").Return(closeChan(abandonCh)).Times(5)

	provider := newMockStoreProvider(t)
	provider.On("getStore").Return(store).Times(3)

	var missing *testResult

	// First query: not found, at index 4, which is below the blocking index of
	// 5. Nothing should be returned to the caller. addReadyWatchSet adds an
	// already-closed channel to the watch set so the wait ends right away.
	provider.On("query", mock.Anything, store).
		Return(uint64(4), missing, ErrorNotFound).
		Run(addReadyWatchSet).
		Once()

	// Second query: the index moved on but the value still does not exist, so
	// the caller should still not receive anything.
	provider.On("query", mock.Anything, store).
		Return(uint64(6), missing, ErrorNotFound).
		Run(addReadyWatchSet).
		Once()

	// Third query: real data exists and is handed back to the caller.
	provider.On("query", mock.Anything, store).
		Return(uint64(7), &testResult{value: "foo"}, nil).
		Once()

	gotIdx, gotResult, err := ServerLocalBlockingQuery(
		context.Background(),
		provider.getStore,
		5,
		true,
		provider.query,
	)

	require.NoError(t, err)
	require.EqualValues(t, 7, gotIdx)
	require.Equal(t, &testResult{value: "foo"}, gotResult)
}
// TestServerLocalBlockingQuery_Error checks that a query error is returned to
// the caller together with the index reported by the failing query.
func TestServerLocalBlockingQuery_Error(t *testing.T) {
	abandonCh := make(chan struct{})
	t.Cleanup(func() { close(abandonCh) })

	store := NewMockStateStore(t)
	store.On("AbandonCh").Return(closeChan(abandonCh)).Once()

	provider := newMockStoreProvider(t)
	provider.On("getStore").Return(store).Once()

	var missing *testResult
	provider.On("query", mock.Anything, store).
		Return(uint64(10), missing, fmt.Errorf("synthetic error")).
		Once()

	gotIdx, gotResult, err := ServerLocalBlockingQuery(
		context.Background(),
		provider.getStore,
		4,
		true,
		provider.query,
	)

	require.Error(t, err)
	require.Contains(t, err.Error(), "synthetic error")
	require.EqualValues(t, 10, gotIdx)
	require.Nil(t, gotResult)
}
// TestServerLocalBlockingQuery_ContextCancellation checks that cancelling the
// caller's context while the query is blocked ends it with the latest result
// and without an error.
func TestServerLocalBlockingQuery_ContextCancellation(t *testing.T) {
	abandonCh := make(chan struct{})
	t.Cleanup(func() { close(abandonCh) })

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	store := NewMockStateStore(t)
	store.On("AbandonCh").Return(closeChan(abandonCh)).Once()

	provider := newMockStoreProvider(t)
	provider.On("getStore").Return(store).Once()

	// Index 4 is below the blocking index of 8, so on its own this result
	// should not end the query. Cancelling the context while the query runs
	// makes the memdb WatchCtx call error instead.
	provider.On("query", mock.Anything, store).
		Return(uint64(4), &testResult{value: "foo"}, nil).
		Once().
		Run(func(_ mock.Arguments) { cancel() })

	gotIdx, gotResult, err := ServerLocalBlockingQuery(
		ctx,
		provider.getStore,
		8,
		true,
		provider.query,
	)

	// The internal cancellation error must not leak out to the caller.
	require.NoError(t, err)
	require.EqualValues(t, 4, gotIdx)
	require.Equal(t, &testResult{value: "foo"}, gotResult)
}
// TestServerLocalBlockingQuery_StateAbandoned checks that when the state
// store's abandon channel is closed while a blocking query is waiting, the
// query ends with the most recent result and no error.
func TestServerLocalBlockingQuery_StateAbandoned(t *testing.T) {
	abandonCh := make(chan struct{})

	store := NewMockStateStore(t)
	store.On("AbandonCh").
		Return(closeChan(abandonCh)).
		Twice()

	provider := newMockStoreProvider(t)
	provider.On("getStore").
		Return(store).
		Once()
	provider.On("query", mock.Anything, store).
		// Return an index that should not cause the blocking query to return.
		Return(uint64(4), &testResult{value: "foo"}, nil).
		Once().
		Run(func(_ mock.Arguments) {
			// Close the abandon channel to simulate the state store being
			// abandoned while the blocking query waits for changes. The
			// channel is closed here, so no t.Cleanup close is registered.
			close(abandonCh)
		})

	idx, result, err := ServerLocalBlockingQuery(
		context.Background(),
		provider.getStore,
		8,
		true,
		provider.query,
	)

	// Abandonment of the state store must not be surfaced as an error; the
	// last query result is returned instead.
	require.NoError(t, err)
	require.EqualValues(t, 4, idx)
	require.Equal(t, &testResult{value: "foo"}, result)
}
// TestServerLocalNotify_Validations checks that ServerLocalNotify rejects each
// missing required argument with the corresponding sentinel error.
func TestServerLocalNotify_Validations(t *testing.T) {
	provider := newMockStoreProvider(t)

	type testCase struct {
		ctx      context.Context
		getStore func() *MockStateStore
		query    func(memdb.WatchSet, *MockStateStore) (uint64, *testResult, error)
		notify   func(context.Context, string, *testResult, error)
		err      error
	}

	// Each case leaves exactly one argument at its zero value.
	cases := map[string]testCase{
		"nil-context": {
			getStore: provider.getStore,
			query:    provider.query,
			notify:   provider.notify,
			err:      errNilContext,
		},
		"nil-getStore": {
			ctx:    context.Background(),
			query:  provider.query,
			notify: provider.notify,
			err:    errNilGetStore,
		},
		"nil-query": {
			ctx:      context.Background(),
			getStore: provider.getStore,
			notify:   provider.notify,
			err:      errNilQuery,
		},
		"nil-notify": {
			ctx:      context.Background(),
			getStore: provider.getStore,
			query:    provider.query,
			err:      errNilNotify,
		},
	}

	for desc, tc := range cases {
		t.Run(desc, func(t *testing.T) {
			gotErr := ServerLocalNotify(tc.ctx, "test", tc.getStore, tc.query, tc.notify)
			require.ErrorIs(t, gotErr, tc.err)
		})
	}
}
// TestServerLocalNotify runs serverLocalNotify against a mocked store and
// query. It checks that each successful query result is passed to notify
// exactly once and that the background routine finishes after the notify
// context is cancelled.
func TestServerLocalNotify(t *testing.T) {
	notifyCtx, notifyCancel := context.WithCancel(context.Background())
	t.Cleanup(notifyCancel)

	abandonCh := make(chan struct{})

	store := NewMockStateStore(t)
	store.On("AbandonCh").
		Return(closeChan(abandonCh)).
		Times(3)

	provider := newMockStoreProvider(t)
	provider.On("getStore").
		Return(store).
		Times(3)
	provider.On("query", mock.Anything, store).
		Return(uint64(4), &testResult{value: "foo"}, nil).
		Once()
	provider.On("notify", notifyCtx, t.Name(), &testResult{value: "foo"}, nil).Once()
	provider.On("query", mock.Anything, store).
		Return(uint64(6), &testResult{value: "bar"}, nil).
		Once()
	provider.On("notify", notifyCtx, t.Name(), &testResult{value: "bar"}, nil).Once()
	// The third query cancels the notify context. No notify expectation exists
	// for its result, so notifying it would fail the test.
	provider.On("query", mock.Anything, store).
		Return(uint64(7), &testResult{value: "baz"}, context.Canceled).
		Run(func(mock.Arguments) {
			notifyCancel()
		})

	doneCtx, routineDone := context.WithCancel(context.Background())
	err := serverLocalNotify(notifyCtx, t.Name(), provider.getStore, provider.query, provider.notify, routineDone, defaultWaiter())
	require.NoError(t, err)

	// Wait for the context cancellation, which happens when "query" runs for
	// the third time. The backgrounded goroutine "cancels" doneCtx when it has
	// actually finished. Waiting ensures that all mocked calls have been made
	// and that no extra calls get made. The wait is bounded so a routine that
	// never finishes fails this test promptly instead of hanging until the
	// global `go test` timeout.
	select {
	case <-doneCtx.Done():
	case <-time.After(10 * time.Second):
		t.Fatal("timed out waiting for the serverLocalNotify routine to finish")
	}
}
// TestServerLocalNotify_internal drives serverLocalNotifyRoutine synchronously
// through three failing queries followed by a successful one. The mock
// expectations, asserted at cleanup, define which calls must happen.
func TestServerLocalNotify_internal(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	abandonCh := make(chan struct{})

	store := NewMockStateStore(t)
	store.On("AbandonCh").
		Return(closeChan(abandonCh)).
		Times(4)

	provider := newMockStoreProvider(t)
	provider.On("getStore").
		Return(store).
		Times(4)

	var missing *testResult
	injectedErr := fmt.Errorf("injected error")

	// Only the first failure is expected to be notified: each failing query
	// reports index 0, which becomes 1, and on the second and third failures
	// that index of 1 won't exceed the min index.
	provider.On("query", mock.Anything, store).
		Return(uint64(0), missing, injectedErr).
		Times(3)
	provider.On("notify", ctx, "test", missing, injectedErr).
		Once()

	// The successful result is notified, and that notification cancels ctx,
	// which is what lets the synchronous routine call below return.
	provider.On("query", mock.Anything, store).
		Return(uint64(7), &testResult{value: "foo"}, nil).
		Once()
	provider.On("notify", ctx, "test", &testResult{value: "foo"}, nil).
		Once().
		Run(func(mock.Arguments) { cancel() })

	// Tiny wait bounds (1ms min, 50ms max) keep the test fast.
	waiter := &retry.Waiter{
		MinFailures: 1,
		MinWait:     time.Millisecond,
		MaxWait:     50 * time.Millisecond,
		Jitter:      retry.NewJitter(100),
		Factor:      2 * time.Millisecond,
	}

	serverLocalNotifyRoutine(ctx, "test", provider.getStore, provider.query, provider.notify, noopDone, waiter)
}
// addReadyWatchSet is a mock Run hook for the query mock. It adds an
// already-closed channel to the memdb.WatchSet passed as the first argument,
// so the watch set fires as soon as it is waited on.
func addReadyWatchSet(args mock.Arguments) {
	ready := make(chan struct{})
	close(ready)
	args.Get(0).(memdb.WatchSet).Add(ready)
}
// closeChan converts a bidirectional channel into a receive-only one. Despite
// its name it does NOT close the channel. It only exists so call sites can
// avoid the harder-to-read (<-chan struct{})(ch) conversion syntax.
func closeChan(ch chan struct{}) <-chan struct{} {
	var recvOnly <-chan struct{} = ch
	return recvOnly
}